@@ -381,22 +381,26 @@ def _nll_loss_2d(
381381 msg = "Expected 3 dimensions for input but got {}."
382382 raise ValueError (msg .format (input .ndim ))
383383
384+ if input .shape [0 ] != target .shape [0 ]:
385+ msg = "Expected input batch size ({}) to match target batch size ({})."
386+ raise ValueError (msg .format (input .shape [0 ], target .shape [0 ]))
387+
384388 _check_reduction_value (reduction )
385389
386- current_target = torch .reshape (target , [- 1 ])
387- ignore_classes_mask = torch .eq (current_target , ignore_index )
390+ flat_target = torch .reshape (target , [- 1 ])
391+ ignore_classes_mask = torch .eq (flat_target , ignore_index )
392+ ignore_class_weight = torch .scalar_tensor (0 , dtype = input .dtype , device = input .device )
388393 if weight is None :
389394 current_weight = torch .where (
390395 ignore_classes_mask ,
391- torch . scalar_tensor ( 0 , dtype = input . dtype , device = input . device ) ,
396+ ignore_class_weight ,
392397 torch .scalar_tensor (1 , dtype = input .dtype , device = input .device ),
393398 )
394399 else :
395- ignore_class_weight = torch .scalar_tensor (
396- 0 , dtype = input .dtype , device = input .device
397- ).expand_as (current_target )
398400 current_weight = torch .where (
399- ignore_classes_mask , ignore_class_weight , weight [current_target ]
401+ ignore_classes_mask ,
402+ ignore_class_weight .expand_as (flat_target ),
403+ weight [flat_target ],
400404 )
401405
402406 batch_size = input .shape [0 ]
@@ -409,7 +413,7 @@ def _nll_loss_2d(
409413 hdx = (torch .arange (numel ) - (bdx * extent )) // width
410414 wdx = torch .arange (numel ) % width
411415
412- loss = - input [bdx , current_target , hdx , wdx ] * current_weight
416+ loss = - input [bdx , flat_target , hdx , wdx ] * current_weight
413417 loss = torch .reshape (loss , target .shape )
414418
415419 if reduction == "none" :
@@ -427,8 +431,8 @@ def _nll_loss_1d(
427431 reduction : str ,
428432 ignore_index : int ,
429433) -> TensorLikeType :
430- if input .ndim < 1 :
431- msg = "Expected 1 or more dimension for input but got {}."
434+ if input .ndim < 1 or input . ndim > 2 :
435+ msg = "Expected 1 or 2 dimensions for input but got {}."
432436 raise ValueError (msg .format (input .ndim ))
433437
434438 if input .ndim != 1 and input .shape [0 ] != target .shape [0 ]:
@@ -437,31 +441,24 @@ def _nll_loss_1d(
437441
438442 _check_reduction_value (reduction )
439443
440- if target .ndim <= 1 :
441- current_target = target
442- else :
443- current_target = target [:, 0 ]
444-
445- ignore_classes_mask = torch .eq (current_target , ignore_index )
444+ ignore_classes_mask = torch .eq (target , ignore_index )
445+ ignore_class_weight = torch .scalar_tensor (0 , dtype = input .dtype , device = input .device )
446446 if weight is None :
447447 current_weight = torch .where (
448448 ignore_classes_mask ,
449- torch . scalar_tensor ( 0 , dtype = input . dtype , device = input . device ) ,
449+ ignore_class_weight ,
450450 torch .scalar_tensor (1 , dtype = input .dtype , device = input .device ),
451451 )
452452 else :
453- ignore_class_weight = torch .scalar_tensor (
454- 0 , dtype = input .dtype , device = input .device
455- ).expand_as (current_target )
456453 current_weight = torch .where (
457- ignore_classes_mask , ignore_class_weight , weight [current_target ]
454+ ignore_classes_mask , ignore_class_weight . expand_as ( target ) , weight [target ]
458455 )
459456
460- batch_size = input .shape [0 ]
461457 if input .ndim == 1 :
462- loss = - input [current_target ] * current_weight
458+ loss = - input [target ] * current_weight
463459 else :
464- loss = - input [torch .arange (batch_size ), current_target ] * current_weight
460+ batch_size = input .shape [0 ]
461+ loss = - input [torch .arange (batch_size ), target ] * current_weight
465462
466463 if reduction == "none" :
467464 return loss
0 commit comments