@@ -304,6 +304,24 @@ def _check_reduction_value(reduction: str):
304
304
raise ValueError ("{} is not a valid value for reduction" .format (reduction ))
305
305
306
306
307
+ # This helper function maps depreciated arguments, "size_average" and "reduce"
308
+ # to their corresponding "reduction" string argument
309
+ def _get_string_reduction_arg (
310
+ * , size_average : Optional [bool ], reduce : Optional [bool ]
311
+ ) -> str :
312
+ if size_average is None :
313
+ size_average = True
314
+ if reduce is None :
315
+ reduce = True
316
+ if size_average and reduce :
317
+ ret = "mean"
318
+ elif reduce :
319
+ ret = "sum"
320
+ else :
321
+ ret = "none"
322
+ return ret
323
+
324
+
307
325
@register_decomposition (torch .ops .aten .margin_ranking_loss )
308
326
def margin_ranking_loss (
309
327
input1 : TensorLikeType ,
@@ -350,22 +368,6 @@ def hinge_embedding_loss(
350
368
return _apply_loss_reduction (loss , reduction )
351
369
352
370
353
- def _get_string_reduction_arg (
354
- size_average : Optional [bool ], reduce : Optional [bool ]
355
- ) -> str :
356
- if size_average is None :
357
- size_average = True
358
- if reduce is None :
359
- reduce = True
360
- if size_average and reduce :
361
- ret = "mean"
362
- elif reduce :
363
- ret = "sum"
364
- else :
365
- ret = "none"
366
- return ret
367
-
368
-
369
371
def _nll_loss_nd (
370
372
input : TensorLikeType ,
371
373
target : TensorLikeType ,
@@ -437,6 +439,7 @@ def _nll_loss_nd(
437
439
return torch .sum (loss ) / torch .sum (current_weight )
438
440
439
441
442
+ @register_decomposition (torch .ops .aten .nll_loss )
440
443
def nll_loss (
441
444
input : TensorLikeType ,
442
445
target : TensorLikeType ,
@@ -449,7 +452,7 @@ def nll_loss(
449
452
if size_average is not None or reduce is not None :
450
453
# TODO: raise exception instead of converting value
451
454
# msg = "size_average and reduce args are deprecated, please use reduction argument."
452
- reduction = _get_string_reduction_arg (size_average , reduce )
455
+ reduction = _get_string_reduction_arg (size_average = size_average , reduce = reduce )
453
456
454
457
if input .ndim == 3 or input .ndim > 4 :
455
458
# input ndim is == 3 or > 4
0 commit comments