Add ref for relu6, fixes hardshrink and improves testing of related ops · pytorch/pytorch@28776c4 · GitHub
Commit 28776c4

soulitzer authored and pytorchmergebot committed

Add ref for relu6, fixes hardshrink and improves testing of related ops (#81142)

This PR:
- Adds a ref for relu6 and makes its OpInfo a UnaryUfuncInfo
- Corrects the hardshrink ref when lambd < 0 and when inputs are NaN
- Corrects the NaN behavior of the vectorized hardshrink implementation (fixes #81138)
- Makes the OpInfos for {hard,soft}shrink and hardtanh UnaryUfuncInfos and adds error_inputs for softshrink

Pull Request resolved: #81142
Approved by: https://github.com/Lezcano, https://github.com/ngimel, https://github.com/mruberry
1 parent 9ee3120 commit 28776c4

File tree

3 files changed: +111 -65 lines changed

aten/src/ATen/native/cpu/Activation.cpp

Lines changed: 3 additions & 2 deletions

@@ -558,14 +558,15 @@ void hardsigmoid_backward_kernel(TensorIteratorBase& iter) {
 void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
   AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "hardshrink_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
+    using Vec = Vectorized<scalar_t>;
     cpu_kernel_vec(
         iter,
         [=](scalar_t self_val) {
           return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0)
                                                                    : self_val;
         },
-        [=](Vectorized<scalar_t> self_val) {
-          return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
+        [=](Vec self_val) {
+          return Vec::blendv(self_val, Vec(0), (self_val >= -lambd_val) & (self_val <= lambd_val));
         });
   });
 }
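Why the vectorized path changed: the old formulation built a bitwise mask from the two comparisons and ANDed it with the input, so a NaN input (for which both comparisons are false) was silently mapped to 0, while the scalar path returned the NaN unchanged (#81138). Vec::blendv instead keeps self_val whenever the condition is false, so NaN propagates and the two paths agree. Below is a minimal Python sketch of the two semantics; it only models the behavior and is not the ATen kernel, and the helper names are illustrative.

import torch

def hardshrink_old_vectorized_model(x, lambd=0.5):
    # Old path, modeled: mask & x. For NaN both comparisons are false, the mask
    # is all zeros, and the bitwise AND turned NaN into 0.
    keep = (x < -lambd) | (x > lambd)
    return torch.where(keep, x, torch.zeros_like(x))

def hardshrink_new_vectorized_model(x, lambd=0.5):
    # New path, modeled: blendv(x, 0, -lambd <= x <= lambd). When the condition
    # is false (including for NaN), x is kept, matching the scalar kernel.
    zero = (x >= -lambd) & (x <= lambd)
    return torch.where(zero, torch.zeros_like(x), x)

x = torch.tensor([float("nan"), -1.0, 0.2, 1.0])
print(hardshrink_old_vectorized_model(x))  # tensor([0., -1., 0., 1.])  - NaN lost
print(hardshrink_new_vectorized_model(x))  # tensor([nan, -1., 0., 1.]) - NaN kept
print(torch.nn.functional.hardshrink(x))   # agrees with the new model on a build with this fix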

torch/_refs/nn/functional/__init__.py

Lines changed: 18 additions & 1 deletion

@@ -272,7 +272,7 @@ def hardshrink(a: TensorLikeType, lambd: float = 0.5):
     # hardshrink(x) = x if x > lambd
     #              = x if x < -lambd
     #              = 0 otherwise
-    return refs.where(abs(a) > abs(lambd), a, 0)
+    return refs.where(refs.logical_and(a >= -lambd, a <= lambd), 0, a)


 @register_decomposition(torch.ops.aten.softshrink)
@@ -282,6 +282,10 @@ def softshrink(a: TensorLikeType, lambd: float = 0.5):
     # softshrink(x) = x - lambd if x > lambd
     #              = x + lambd if x < -lambd
     #              = 0 otherwise
+    check(
+        lambd >= 0,
+        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
+    )
     ge_mask = a > lambd
     le_mask = a < -lambd
     zero_mask = torch.logical_not(refs.logical_or(ge_mask, le_mask))
@@ -538,3 +542,16 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
     )

     return refs.where(a > 0, a, a * weight)
+
+
+def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
+    """
+    Reference implementation of torch.nn.functional.relu6
+    """
+    if inplace:
+        raise NotImplementedError
+
+    # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126
+    # It may be better to use clamp here, but we use hardtanh to replicate
+    # the behavior of the existing implementation
+    return refs.nn.functional.hardtanh(a, 0, 6)
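With these changes the hardshrink ref zeroes a value only when it lies inside [-lambd, lambd]; NaN never satisfies that test, so it propagates, and a negative lambd makes the interval empty so nothing is zeroed, matching the eager kernel. The relu6 ref simply clamps to [0, 6] via hardtanh. A small sanity-check sketch, assuming a PyTorch build that includes this PR and the torch._refs.nn.functional module path used here:

import torch
import torch._refs.nn.functional as refs_nn_f

x = torch.tensor([float("nan"), -0.3, 0.3, 2.0])

# NaN is kept, and with lambd = -0.5 the band [-lambd, lambd] is empty,
# so nothing is zeroed -- the same behavior as the eager op.
print(refs_nn_f.hardshrink(x, lambd=0.5))
print(refs_nn_f.hardshrink(x, lambd=-0.5))
print(torch.nn.functional.hardshrink(x, lambd=-0.5))  # should agree with the ref

# The new relu6 ref clamps to [0, 6] via hardtanh.
print(refs_nn_f.relu6(torch.tensor([-1.0, 3.0, 7.0])))  # tensor([0., 3., 6.])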

torch/testing/_internal/common_methods_invocations.py

Lines changed: 90 additions & 62 deletions

@@ -7287,12 +7287,41 @@ def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **
     for batch, m, n in product(batches, ns, ns):
         yield SampleInput(make_arg(batch + (m, n)))

+def error_inputs_softshrink(op, device, **kwargs):
+    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}),
+                     error_regex="lambda must be greater or equal to 0, but found to be -0.5")

-def sample_inputs_softshrink_hardshrink_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs):
-    N = 10
-    tensors = [SampleInput(make_tensor((N, N), device=device, dtype=dtype,
-                                       requires_grad=requires_grad)) for _ in range(1, N)]
-    return tensors
+def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The additional sample is to check additional values of lambd beyond the default
+    # value (what is already checked by sample_inputs_elementwise_unary)
+    for lbda in (0., 0.5):
+        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})
+
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
+
+def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The additional sample is to check additional values of lambd beyond the default
+    # value (what is already checked by sample_inputs_elementwise_unary)
+    # Note that unlike softshrink, lambd is allowed to be negative for hardshrink
+    for lbda in (-0.5, 0., 0.5):
+        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})
+
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
+
+
+def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The additional sample is to check additional values of min_val and max_val beyond the default
+    # value (what is already checked by sample_inputs_elementwise_unary)
+    for max_val, min_val in ((-0.5, 0.5), (0.5, -0.5), (0., 0.)):
+        yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val})
+
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)

 def sample_inputs_eig(op_info, device, dtype, requires_grad=False, **kwargs):
     eigvecs = make_tensor((S, S), device=device, dtype=dtype,
@@ -14992,46 +15021,43 @@ def error_inputs_mean(op_info, device, **kwargs):
     # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float
     # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
     # )),
-    OpInfo('nn.functional.softshrink',
-           aten_name="softshrink",
-           aten_backward_name='softshrink_backward',
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-           supports_autograd=True,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           assert_autodiffed=False,
-           sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh,
-           supports_gradgrad=True,
-           ),
-    OpInfo('nn.functional.hardshrink',
-           aten_name="hardshrink",
-           aten_backward_name='hardshrink_backward',
-           dtypes=floating_types_and(torch.bfloat16,),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
-           supports_autograd=True,
-           assert_autodiffed=True,
-           sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh,
-           supports_gradgrad=True,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           autodiff_nonfusible_nodes=["aten::hardshrink"]),
-    OpInfo('nn.functional.hardtanh',
-           aten_name="hardtanh",
-           aten_backward_name='hardtanh_backward',
-           dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16),
-           backward_dtypes=all_types(),
-           dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_types_and(torch.float16),
-           supports_autograd=True,
-           assert_autodiffed=True,
-           sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh,
-           supports_gradgrad=True,
-           supports_out=False,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           autodiff_nonfusible_nodes=["aten::hardtanh"],
-           ),
+    UnaryUfuncInfo('nn.functional.softshrink',
+                   ref=_NOTHING,
+                   aten_name="softshrink",
+                   aten_backward_name='softshrink_backward',
+                   dtypes=floating_types_and(torch.bfloat16),
+                   dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   assert_autodiffed=False,
+                   sample_inputs_func=sample_inputs_softshrink,
+                   error_inputs_func=error_inputs_softshrink),
+    UnaryUfuncInfo('nn.functional.hardshrink',
+                   ref=_NOTHING,
+                   aten_name="hardshrink",
+                   aten_backward_name='hardshrink_backward',
+                   dtypes=floating_types_and(torch.bfloat16,),
+                   dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+                   assert_autodiffed=True,
+                   sample_inputs_func=sample_inputs_hardshrink,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   autodiff_nonfusible_nodes=["aten::hardshrink"]),
+    UnaryUfuncInfo('nn.functional.hardtanh',
+                   ref=_NOTHING,
+                   aten_name="hardtanh",
+                   aten_backward_name='hardtanh_backward',
+                   dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16),
+                   backward_dtypes=all_types(),
+                   dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16,
+                                                   torch.bfloat16),
+                   backward_dtypesIfCUDA=floating_types_and(torch.float16),
+                   assert_autodiffed=True,
+                   sample_inputs_func=sample_inputs_hardtanh,
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   autodiff_nonfusible_nodes=["aten::hardtanh"]),
     OpInfo('nn.functional.gelu',
            aten_name="gelu",
            aten_backward_name='gelu_backward',
@@ -15050,20 +15076,18 @@ def error_inputs_mean(op_info, device, **kwargs):
                # AssertionError: Tensor-likes are not close!
                # May not replicate in CI
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),)),
-    OpInfo('nn.functional.relu6',
-           aten_name="relu6",
-           dtypes=all_types_and(torch.bfloat16),
-           backward_dtypes=floating_types(),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_types_and(torch.float16),
-           supports_autograd=True,
-           assert_autodiffed=True,
-           sample_inputs_func=sample_inputs_softshrink_hardshrink_hardtanh,
-           supports_gradgrad=True,
-           supports_out=False,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           autodiff_nonfusible_nodes=["aten::relu6"]),
+    UnaryUfuncInfo('nn.functional.relu6',
+                   ref=_NOTHING,
+                   aten_name="relu6",
+                   dtypes=all_types_and(torch.bfloat16),
+                   backward_dtypes=floating_types(),
+                   dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                   backward_dtypesIfCUDA=floating_types_and(torch.float16),
+                   assert_autodiffed=True,
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   autodiff_nonfusible_nodes=["aten::relu6"]),
     OpInfo('mm',
            dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
@@ -20512,7 +20536,7 @@ def __init__(
         "_refs.nn.functional.elu",
         torch_opinfo_name="nn.functional.elu",
     ),
-    PythonRefInfo(
+    ElementwiseUnaryPythonRefInfo(
        "_refs.nn.functional.hardtanh",
        torch_opinfo_name="nn.functional.hardtanh",
        supports_nvfuser=False,
@@ -20544,6 +20568,10 @@ def __init__(
        torch_opinfo_name="nn.functional.relu",
        supports_nvfuser=False,
    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.relu6",
+        torch_opinfo_name="nn.functional.relu6",
+    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.nn.functional.mish",
        torch_opinfo_name="nn.functional.mish",
@@ -20580,12 +20608,12 @@ def __init__(
        "_refs.nn.functional.tanhshrink",
        torch_opinfo_name="nn.functional.tanhshrink",
    ),
-    PythonRefInfo(
+    ElementwiseUnaryPythonRefInfo(
        "_refs.nn.functional.hardshrink",
        torch_opinfo_name="nn.functional.hardshrink",
        supports_nvfuser=False,
    ),
-    PythonRefInfo(
+    ElementwiseUnaryPythonRefInfo(
        "_refs.nn.functional.softshrink",
        torch_opinfo_name="nn.functional.softshrink",
        supports_nvfuser=False,
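The new error_inputs_softshrink entry encodes that a negative lambd must be rejected, while sample_inputs_hardshrink deliberately includes lambd = -0.5 because hardshrink does accept negative values. A quick manual check outside the OpInfo harness, as a hedged sketch (the exact error text is whatever matches the regex above):

import torch

x = torch.randn(3)

# softshrink should reject a negative lambd (this is what error_inputs_softshrink tests).
try:
    torch.nn.functional.softshrink(x, lambd=-0.5)
except RuntimeError as err:
    print(err)  # expected to match "lambda must be greater or equal to 0, but found to be -0.5"

# hardshrink, by contrast, accepts a negative lambd; with lambd = -0.5 the
# zeroing band [-lambd, lambd] is empty, so the input passes through unchanged.
print(torch.nn.functional.hardshrink(x, lambd=-0.5))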
