@@ -427,33 +427,35 @@ def _nll_loss_nd(
     reduction: str,
     ignore_index: int,
 ) -> TensorLikeType:
-    if input.ndim == 3 or input.ndim > 4:
-        msg = "Expected input dimension to be either [1, 2, 4] but recieved {}."
-        raise ValueError(msg.format(input.ndim))
+    utils.check(
+        input.ndim <= 4 and input.ndim != 3,
+        lambda: f"Expected input dimension to be either [1, 2, 4] but received {input.ndim}.",
+    )

-    if input.ndim != 1 and input.shape[0] != target.shape[0]:
-        msg = "Expected input batch size ({}) to match target batch size ({})."
-        raise ValueError(msg.format(input.shape[0], target.shape[0]))
+    utils.check(
+        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
+        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
+    )

     _check_reduction_value(reduction)

     flat_target = torch.reshape(target, [-1])
     ignore_classes_mask = torch.eq(flat_target, ignore_index)
     ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
-    default_class_weight = torch.scalar_tensor(
-        1, dtype=input.dtype, device=input.device
-    )

     # TODO: This check does not work with FakeTensor inputs
     """
     num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
     valid_classes_mask = torch.logical_and((flat_target >= 0), (flat_target < num_classes))
     if not torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)):
-        print(target, num_classes, ignore_index)
         raise ValueError("Target class is out-of-bounds and not ignore index")
     """

+    # TODO Add comment for expansion
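+    # current_weight holds one weight per element of flat_target: 0 where the
+    # target equals ignore_index, and the class weight otherwise. When no
+    # weight tensor is given, torch.where broadcasts (expands) the scalar
+    # default weight of 1 against ignore_classes_mask.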
     if weight is None:
+        default_class_weight = torch.scalar_tensor(
+            1, dtype=input.dtype, device=input.device
+        )
         current_weight = torch.where(
             ignore_classes_mask,
             ignore_class_weight,
@@ -466,6 +468,7 @@ def _nll_loss_nd(
             weight[flat_target],
         )

+    # TODO Add comments for each case
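+    # input.ndim == 1: a single unbatched example, so target holds one class
+    # index and we pick out a single log-probability. input.ndim == 2: the
+    # standard [batch, classes] layout; pick each row's target class. The
+    # 4-D view produced by the reshape in nll_loss is handled further below.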
     if input.ndim == 1:
         loss = -input[target] * current_weight
     elif input.ndim == 2:
@@ -488,6 +491,7 @@ def _nll_loss_nd(
     elif reduction == "sum":
         return torch.sum(loss)
     else:
+        # TODO Add comments "mean" reduction case
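+        # "mean" is a weighted mean: the summed loss is divided by the total
+        # weight of the contributing targets (ignore_index entries carry
+        # weight 0), not by the element count.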
         return torch.sum(loss) / torch.sum(current_weight)


@@ -506,34 +510,39 @@ def nll_loss(
     # msg = "size_average and reduce args are deprecated, please use reduction argument."
     reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

-    if input.ndim == 3 or input.ndim > 4:
-        # input ndim is == 3 or > 4
-        batch_size = input.shape[0]
-        num_classes = input.shape[1]
-        out_size = [batch_size] + list(input.shape[2:])
-
-        if target.shape[1:] != input.shape[2:]:
-            msg = "Expected target size {} but got {}"
-            raise ValueError(msg.format(out_size, target.shape))
-
-        # support empty batches, see #15870
-        if input.numel() > 0:
-            input = torch.reshape(input, [batch_size, num_classes, 1, -1])
-        else:
-            input = torch.reshape(input, [batch_size, num_classes, 0, 0])
-
-        if target.numel() > 0:
-            target = torch.reshape(target, [batch_size, 1, -1])
-        else:
-            target = torch.reshape(target, [batch_size, 0, 0])
-
-        if reduction == "none":
-            return _nll_loss_nd(input, target, weight, reduction, ignore_index)
-        else:
-            result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
-            return torch.reshape(result, out_size)
+    # TODO Can input be zero- or one-dimensional? If so, how do we interpret that?
+    # The documentation suggests that input should have at least two dimensions.
+    # Why are inputs with three or four dimensions special?
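+    # (Note: the torch.nn.NLLLoss documentation describes input shapes (C),
+    # (N, C), and (N, C, d1, ..., dK), which suggests 1-D unbatched input is
+    # legal while 0-D input is not.)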
+    if input.ndim <= 4 and input.ndim != 3:
+        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
+
+    # TODO Add comment for this case
+    # input.ndim == 3 or input.ndim > 4
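+    # Inputs shaped [N, C, d1] (3-D) or [N, C, d1, ..., dk] (more than 4-D)
+    # are canonicalized here: the trailing spatial dimensions are flattened
+    # into a 4-D [N, C, 1, -1] view that _nll_loss_nd accepts, and an
+    # unreduced ("none") result is reshaped back to out_size at the end.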
+    batch_size = input.shape[0]
+    num_classes = input.shape[1]
+    out_size = [batch_size] + list(input.shape[2:])
+
+    utils.check(
+        target.shape[1:] == input.shape[2:],
+        lambda: f"Expected target shape {out_size} but got {target.shape}",
+    )
+
+    # support empty batches, see #15870
+    if input.numel() > 0:
+        input = torch.reshape(input, [batch_size, num_classes, 1, -1])
+    else:
+        input = torch.reshape(input, [batch_size, num_classes, 0, 0])
+
+    if target.numel() > 0:
+        target = torch.reshape(target, [batch_size, 1, -1])
     else:
+        target = torch.reshape(target, [batch_size, 0, 0])
+
+    if reduction != "none":
         return _nll_loss_nd(input, target, weight, reduction, ignore_index)
+    else:
+        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
+        return torch.reshape(result, out_size)


# tanhshrink does not use _make_elementwise_unary_reference because it does not support out