fixup · pytorch/pytorch@b1393e5 · GitHub
Commit b1393e5

committed
fixup
1 parent 2256d36 commit b1393e5

1 file changed: +21 -24 lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 21 additions & 24 deletions
@@ -381,22 +381,26 @@ def _nll_loss_2d(
         msg = "Expected 3 dimensions for input but got {}."
         raise ValueError(msg.format(input.ndim))
 
+    if input.shape[0] != target.shape[0]:
+        msg = "Expected input batch size ({}) to match target batch size ({})."
+        raise ValueError(msg.format(input.shape[0], target.shape[0]))
+
     _check_reduction_value(reduction)
 
-    current_target = torch.reshape(target, [-1])
-    ignore_classes_mask = torch.eq(current_target, ignore_index)
+    flat_target = torch.reshape(target, [-1])
+    ignore_classes_mask = torch.eq(flat_target, ignore_index)
+    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
     if weight is None:
         current_weight = torch.where(
             ignore_classes_mask,
-            torch.scalar_tensor(0, dtype=input.dtype, device=input.device),
+            ignore_class_weight,
             torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
         )
     else:
-        ignore_class_weight = torch.scalar_tensor(
-            0, dtype=input.dtype, device=input.device
-        ).expand_as(current_target)
         current_weight = torch.where(
-            ignore_classes_mask, ignore_class_weight, weight[current_target]
+            ignore_classes_mask,
+            ignore_class_weight.expand_as(flat_target),
+            weight[flat_target],
         )
 
     batch_size = input.shape[0]
@@ -409,7 +413,7 @@ def _nll_loss_2d(
     hdx = (torch.arange(numel) - (bdx * extent)) // width
     wdx = torch.arange(numel) % width
 
-    loss = -input[bdx, current_target, hdx, wdx] * current_weight
+    loss = -input[bdx, flat_target, hdx, wdx] * current_weight
     loss = torch.reshape(loss, target.shape)
 
     if reduction == "none":
@@ -427,8 +431,8 @@ def _nll_loss_1d(
     reduction: str,
     ignore_index: int,
 ) -> TensorLikeType:
-    if input.ndim < 1:
-        msg = "Expected 1 or more dimension for input but got {}."
+    if input.ndim < 1 or input.ndim > 2:
+        msg = "Expected 1 or 2 dimensions for input but got {}."
         raise ValueError(msg.format(input.ndim))
 
     if input.ndim != 1 and input.shape[0] != target.shape[0]:
@@ -437,31 +441,24 @@ def _nll_loss_1d(
 
     _check_reduction_value(reduction)
 
-    if target.ndim <= 1:
-        current_target = target
-    else:
-        current_target = target[:, 0]
-
-    ignore_classes_mask = torch.eq(current_target, ignore_index)
+    ignore_classes_mask = torch.eq(target, ignore_index)
+    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
     if weight is None:
         current_weight = torch.where(
             ignore_classes_mask,
-            torch.scalar_tensor(0, dtype=input.dtype, device=input.device),
+            ignore_class_weight,
             torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
         )
     else:
-        ignore_class_weight = torch.scalar_tensor(
-            0, dtype=input.dtype, device=input.device
-        ).expand_as(current_target)
         current_weight = torch.where(
-            ignore_classes_mask, ignore_class_weight, weight[current_target]
+            ignore_classes_mask, ignore_class_weight.expand_as(target), weight[target]
         )
 
-    batch_size = input.shape[0]
     if input.ndim == 1:
-        loss = -input[current_target] * current_weight
+        loss = -input[target] * current_weight
     else:
-        loss = -input[torch.arange(batch_size), current_target] * current_weight
+        batch_size = input.shape[0]
+        loss = -input[torch.arange(batch_size), target] * current_weight
 
     if reduction == "none":
         return loss
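
For readers skimming the diff: both the 1d and 2d hunks converge on the same ignore_index handling, where a single scalar zero weight is created up front and broadcast (via expand_as or torch.where broadcasting) over the target, rather than being built only inside the weighted branch. Below is a minimal, self-contained sketch of that pattern, not the committed code; the name demo_nll_1d is made up for illustration, and it assumes ignore_index refers to an in-range class so the gather on input stays valid.

import torch

def demo_nll_1d(input, target, weight=None, ignore_index=-100):
    # Mask the entries whose target equals ignore_index.
    ignore_classes_mask = torch.eq(target, ignore_index)
    # One scalar zero weight, reused by both branches (the pattern in this commit).
    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
    if weight is None:
        current_weight = torch.where(
            ignore_classes_mask,
            ignore_class_weight,
            torch.scalar_tensor(1, dtype=input.dtype, device=input.device),
        )
    else:
        current_weight = torch.where(
            ignore_classes_mask, ignore_class_weight.expand_as(target), weight[target]
        )
    if input.ndim == 1:
        loss = -input[target] * current_weight
    else:
        batch_size = input.shape[0]
        loss = -input[torch.arange(batch_size), target] * current_weight
    return loss  # "none" reduction only, for brevity

# Example: 3 samples over 4 classes; class 3 is treated as the ignored index,
# so the last sample contributes zero loss.
logp = torch.log_softmax(torch.randn(3, 4), dim=1)
target = torch.tensor([0, 2, 3])
print(demo_nll_1d(logp, target, ignore_index=3))

The _nll_loss_2d change follows the same shape, with the target first flattened to flat_target and the per-element losses reshaped back to target.shape.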

0 commit comments