1 parent 7c9fb75 commit e79d3c6
torch/distributed/tensor/parallel/loss.py
@@ -281,6 +281,7 @@ def _nll_loss_forward_handler(
 
     channel_dim = 1 if x.dim() >= 2 else 0
     spec = x._spec
+    dim = normalize_dim(dim, x.dim())
     mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
 
     # Check user input: if target and weight are not DTensors, convert them to DTensors;
@@ -427,6 +428,7 @@ def _nll_loss_backward_handler(
 
     # if target and weight are not DTensors, convert them to DTensors
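
The added call canonicalizes a possibly negative dim (for example -1) into a non-negative axis index before it is used by the sharding logic. Below is a minimal sketch of that wrapping behavior, assuming normalize_dim follows PyTorch's usual dim-wrapping convention; the helper shown is illustrative only, not the library's implementation.

    def normalize_dim(dim: int, ndim: int) -> int:
        # Sketch (assumption): wrap Python-style negative dims for a tensor of rank ndim.
        # A 0-d tensor is treated as rank 1 for wrapping, matching common PyTorch behavior.
        rank = ndim if ndim > 0 else 1
        if not -rank <= dim < rank:
            raise IndexError(f"dim {dim} out of range for tensor of rank {ndim}")
        return dim + rank if dim < 0 else dim

    # With the added line in place, a caller passing dim=-1 for a rank-2 input
    # behaves the same as a caller passing dim=1.
    assert normalize_dim(-1, 2) == 1
    assert normalize_dim(0, 2) == 0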