File tree Expand file tree Collapse file tree 1 file changed +10
-16
lines changed
torch/_refs/nn/functional Expand file tree Collapse file tree 1 file changed +10
-16
lines changed Original file line number Diff line number Diff line change @@ -451,22 +451,16 @@ def _nll_loss_nd(
451
451
raise ValueError("Target class is out-of-bounds and not ignore index")
452
452
"""
453
453
454
- # TODO Add comment for expansion
455
- if weight is None :
456
- default_class_weight = torch .scalar_tensor (
457
- 1 , dtype = input .dtype , device = input .device
458
- )
459
- current_weight = torch .where (
460
- ignore_classes_mask ,
461
- ignore_class_weight ,
462
- default_class_weight ,
463
- )
464
- else :
465
- current_weight = torch .where (
466
- ignore_classes_mask ,
467
- ignore_class_weight .expand_as (flat_target ),
468
- weight [flat_target ],
469
- )
454
+ class_weight = (
455
+ torch .scalar_tensor (1 , dtype = input .dtype , device = input .device )
456
+ if weight is None
457
+ else weight [flat_target ]
458
+ )
459
+ current_weight = torch .where (
460
+ ignore_classes_mask ,
461
+ ignore_class_weight ,
462
+ class_weight ,
463
+ )
470
464
471
465
# TODO Add comments for each case
472
466
if input .ndim == 1 :
You can’t perform that action at this time.
0 commit comments