8000 add normalization · pytorch/pytorch@e79d3c6 · GitHub
[go: up one dir, main page]

Skip to content

Commit e79d3c6

Browse files
abhilash1910pytorchmergebot
authored andcommitted
add normalization
1 parent 7c9fb75 commit e79d3c6

File tree

1 file changed

+2
-0
lines changed
  • torch/distributed/tensor/parallel

1 file changed

+2
-0
lines changed

torch/distributed/tensor/parallel/loss.py

+2
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def _nll_loss_forward_handler(
281281

282282
channel_dim = 1 if x.dim() >= 2 else 0
283283
spec = x._spec
284+
dim = normalize_dim(dim, x.dim())
284285
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
285286

286287
# Check user input: if target and weight are not DTensors, convert them to DTensors;
@@ -427,6 +428,7 @@ def _nll_loss_backward_handler(
427428

428429
channel_dim = 1 if x.dim() >= 2 else 0
429430
spec = x._spec
431+
dim = normalize_dim(dim, x.dim())
430432
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
431433

432434
# if target and weight are not DTensors, convert them to DTensors

0 commit comments

Comments
 (0)
0