8000 refactor helper function · pytorch/pytorch@370bc60 · GitHub
[go: up one dir, main page]

Skip to content

Commit 370bc60

Browse files
committed
refactor helper function
1 parent 96cc303 commit 370bc60

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

torch/_refs/nn/functional/__init__.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,24 @@ def _check_reduction_value(reduction: str):
304304
raise ValueError("{} is not a valid value for reduction".format(reduction))
305305

306306

307+
# This helper function maps depreciated arguments, "size_average" and "reduce"
308+
# to their corresponding "reduction" string argument
309+
def _get_string_reduction_arg(
310+
*, size_average: Optional[bool], reduce: Optional[bool]
311+
) -> str:
312+
if size_average is None:
313+
size_average = True
314+
if reduce is None:
315+
reduce = True
316+
if size_average and reduce:
317+
ret = "mean"
318+
elif reduce:
319+
ret = "sum"
320+
else:
321+
ret = "none"
322+
return ret
323+
324+
307325
@register_decomposition(torch.ops.aten.margin_ranking_loss)
308326
def margin_ranking_loss(
309327
input1: TensorLikeType,
@@ -350,22 +368,6 @@ def hinge_embedding_loss(
350368
return _apply_loss_reduction(loss, reduction)
351369

352370

353-
def _get_string_reduction_arg(
354-
size_average: Optional[bool], reduce: Optional[bool]
355-
) -> str:
356-
if size_average is None:
357-
size_average = True
358-
if reduce is None:
359-
reduce = True
360-
if size_average and reduce:
361-
ret = "mean"
362-
elif reduce:
363-
ret = "sum"
364-
else:
365-
ret = "none"
366-
return ret
367-
368-
369371
def _nll_loss_nd(
370372
input: TensorLikeType,
371373
target: TensorLikeType,
@@ -437,6 +439,7 @@ def _nll_loss_nd(
437439
return torch.sum(loss) / torch.sum(current_weight)
438440

439441

442+
@register_decomposition(torch.ops.aten.nll_loss)
440443
def nll_loss(
441444
input: TensorLikeType,
442445
target: TensorLikeType,
@@ -449,7 +452,7 @@ def nll_loss(
449452
if size_average is not None or reduce is not None:
450453
# TODO: raise exception instead of converting value
451454
# msg = "size_average and reduce args are deprecated, please use reduction argument."
452-
reduction = _get_string_reduction_arg(size_average, reduce)
455+
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
453456

454457
if input.ndim == 3 or input.ndim > 4:
455458
# input ndim is == 3 or > 4

0 commit comments

Comments
 (0)
0