[Autograd] Improve error for leaf tensors as out argument to fallback… · pytorch/pytorch@34a28f0 · GitHub


Commit 34a28f0

peterbell10 authored and pytorchmergebot committed
[Autograd] Improve error for leaf tensors as out argument to fallback (#121089)
Closes #120988

Currently, operators that hit the autograd fallback call `check_inplace` on all mutated inputs, including out arguments. This leads to a slightly confusing error message:

```
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```

Compared to functions that don't fall back, which raise:

```
RuntimeError: add(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
```

This changes the error message to make clear the issue is with the out argument, but does not tighten the check to outright ban out arguments that require grad. Instead, I use the same checks from `check_inplace`, which allow non-leaf tensors that require grad to pass without error.

Pull Request resolved: #121089
Approved by: https://github.com/lezcano, https://github.com/soulitzer
ghstack dependencies: #121142
1 parent eae9751 commit 34a28f0
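
For context, the second error quoted above is what a regular out= overload such as `torch.add` already raises when the out tensor requires grad; a minimal sketch (not part of this commit) reproducing it:

```
import torch

x = torch.rand(2)
y = torch.rand(2)
# A leaf tensor that requires grad, passed as the out= argument.
out = torch.empty(2, requires_grad=True)

try:
    torch.add(x, y, out=out)
except RuntimeError as e:
    print(e)
    # add(): functions with out=... arguments don't support automatic
    # differentiation, but one of the arguments requires grad.
```

After this change, operators that go through the autograd fallback report the same out=-specific wording instead of the generic leaf-Variable in-place error.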

File tree

4 files changed: +95 −14 lines

test/test_ops.py

Lines changed: 32 additions & 0 deletions
@@ -999,6 +999,38 @@ def _case_four_transform(t):
         with self.assertRaises(RuntimeError, msg=msg_fail):
             op_out(out=out)
 
+    @ops(
+        [op for op in op_db if op.supports_out and (op.supports_autograd or op.is_factory_function)],
+        dtypes=OpDTypes.supported,
+        allowed_dtypes=[torch.float, torch.cfloat]
+    )
+    def test_out_requires_grad_error(self, device, dtype, op):
+        sample = first_sample(self, op.sample_inputs(device, dtype))
+
+        # Call op to get prototype for out arguments
+        expect = op(sample.input, *sample.args, **sample.kwargs)
+        any_requires_grad = False
+
+        def set_requires_grad(x):
+            nonlocal any_requires_grad
+            if isinstance(x, torch.Tensor) and (
+                x.is_floating_point() or x.is_complex()
+            ):
+                any_requires_grad = True
+                x.requires_grad_(True)
+            return x
+
+        out = pytree.tree_map_(set_requires_grad, expect)
+        if not any_requires_grad:
+            # Skip ops without any floating point outputs, e.g. isnan
+            return
+
+        msg = (
+            "functions with out=... arguments don't support automatic "
+            "differentiation, but one of the arguments requires grad."
+        )
+        with self.assertRaises(RuntimeError, msg=msg):
+            op(sample.input, *sample.args, **sample.kwargs, out=out)
 
     @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
     def test_out_integral_dtype(self, device, dtype, op):
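
The new test relies on `pytree.tree_map_` to flip `requires_grad` in place on every floating-point or complex output. A standalone sketch of that helper on a hypothetical tuple of outputs (illustrative, not taken from the test suite):

```
import torch
import torch.utils._pytree as pytree

expect = (torch.rand(2), torch.rand(2, dtype=torch.cfloat))
any_requires_grad = False

def set_requires_grad(x):
    global any_requires_grad
    if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex()):
        any_requires_grad = True
        x.requires_grad_(True)  # mutate in place; tree_map_ returns the original tree
    return x

out = pytree.tree_map_(set_requires_grad, expect)
print(any_requires_grad)               # True
print([t.requires_grad for t in out])  # [True, True]
```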

torch/csrc/autograd/VariableTypeUtils.h

Lines changed: 41 additions & 12 deletions
@@ -31,31 +31,60 @@
 
 namespace torch {
 namespace autograd {
+enum class can_mutate_inplace_result {
+  success,
+  non_default_backward_view,
+  view_of_leaf,
+  is_leaf,
+};
 
 // The requires_grad argument is used to know if the inplace operation needs
 // gradient to be setup for it.
 // In particular, we can have tensor.requires_grad() != requires_grad when
 // writing a Tensor that requires gradients inplace into a Tensor that does not
 // require gradients: a = torch.rand(2) b = torch.rand(2, requires_grad=True)
 // a.copy_(b)
+inline can_mutate_inplace_result can_mutate_inplace(
+    const at::Tensor& tensor,
+    bool requires_grad) {
+  if (!requires_grad || !GradMode::is_enabled()) {
+    return can_mutate_inplace_result::success;
+  }
+  auto diff_view_meta = impl::get_view_autograd_meta(tensor);
+  if (diff_view_meta && diff_view_meta->has_bw_view()) {
+    if (diff_view_meta->get_creation_meta() != CreationMeta::DEFAULT) {
+      return can_mutate_inplace_result::non_default_backward_view;
+    }
+    if (tensor.requires_grad() && tensor._base().is_leaf()) {
+      return can_mutate_inplace_result::view_of_leaf;
+    }
+  }
+  if (tensor.requires_grad() && tensor.is_leaf()) {
+    return can_mutate_inplace_result::is_leaf;
+  }
+  return can_mutate_inplace_result::success;
+}
+
 inline void check_inplace(const at::Tensor& tensor, bool requires_grad) {
-  if (requires_grad && GradMode::is_enabled()) {
-    auto diff_view_meta = impl::get_view_autograd_meta(tensor);
-    if (diff_view_meta && diff_view_meta->has_bw_view()) {
-      // This can throw or warn
-      handle_view_on_rebase(diff_view_meta);
-      if (tensor.requires_grad() && tensor._base().is_leaf()) {
-        TORCH_CHECK(
-            false,
-            "a view of a leaf Variable that requires grad is being used in an in-place operation.");
-      }
+  switch (can_mutate_inplace(tensor, requires_grad)) {
+    case can_mutate_inplace_result::success:
+      return;
+    case can_mutate_inplace_result::non_default_backward_view: {
+      return handle_view_on_rebase(impl::get_view_autograd_meta(tensor));
     }
-    if (tensor.requires_grad() && tensor.is_leaf()) {
+    case can_mutate_inplace_result::view_of_leaf:
+      TORCH_CHECK(
+          false,
+          "a view of a leaf Variable that requires grad is being used in an in-place operation.");
+      break;
+
+    case can_mutate_inplace_result::is_leaf:
       TORCH_CHECK(
           false,
           "a leaf Variable that requires grad is being used in an in-place operation.");
-    }
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false);
 }
 
 inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) {
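
The enum cases correspond to familiar user-visible situations. A small Python sketch (illustrative only; the enum itself is internal to the C++ autograd layer) of which tensors the in-place checks reject:

```
import torch

leaf = torch.rand(2, requires_grad=True)   # is_leaf: in-place mutation is rejected
view_of_leaf = leaf.view(2)                # view_of_leaf: also rejected
non_leaf = leaf * 2                        # requires grad but is not a leaf: allowed

non_leaf.add_(1)  # fine, gradients can still flow back to `leaf`

for t in (leaf, view_of_leaf):
    try:
        t.add_(1)
    except RuntimeError as e:
        print(e)
# a leaf Variable that requires grad is being used in an in-place operation.
# a view of a leaf Variable that requires grad is being used in an in-place operation.
```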

torch/csrc/autograd/autograd_not_implemented_fallback.cpp

Lines changed: 16 additions & 1 deletion
@@ -299,11 +299,26 @@ static void autogradNotImplementedFallbackImpl(
       num_arguments);
 
   const bool any_requires_grad = !tensors_requiring_grad_on_stack.empty();
+  const bool has_out_arg = std::any_of(
+      schema.arguments().begin(),
+      schema.arguments().end(),
+      [](const c10::Argument& arg) { return arg.is_out(); });
 
   _foreach_tensor(
       [&](size_t _, size_t i, const at::Tensor& t) {
         if (schema.is_mutable({c10::SchemaArgType::input, i})) {
-          check_inplace(t, any_requires_grad);
+          if (has_out_arg) {
+            // Normally out argument overloads would not support any arguments
+            // that require grad. However, we loosen this check to maintain
+            // backward compatibility.
+            // See https://github.com/pytorch/pytorch/issues/120988
+            if (can_mutate_inplace(t, any_requires_grad) !=
+                can_mutate_inplace_result::success) {
+              throw_error_out_requires_grad(schema.name().c_str());
+            }
+          } else {
+            check_inplace(t, any_requires_grad);
+          }
         }
       },
       stack,
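
A rough Python model of the decision the fallback now makes for mutated out= arguments; the helper names and string results below are illustrative stand-ins for the C++ helpers above, not a PyTorch API:

```
import torch

def can_mutate_inplace(t: torch.Tensor, requires_grad: bool) -> str:
    # Simplified model of the C++ helper: only leaf tensors (or views of
    # leaves) that require grad are rejected; the non-default backward-view
    # case is omitted here for brevity.
    if not requires_grad or not torch.is_grad_enabled():
        return "success"
    if t._is_view() and t.requires_grad and t._base.is_leaf:
        return "view_of_leaf"
    if t.requires_grad and t.is_leaf:
        return "is_leaf"
    return "success"

def check_out_argument(t: torch.Tensor, any_requires_grad: bool, op_name: str) -> None:
    # Out arguments now raise the dedicated out=... error, but only for the
    # tensors that would also be rejected for an in-place mutation, so
    # non-leaf tensors that require grad keep working (backward compatibility).
    if can_mutate_inplace(t, any_requires_grad) != "success":
        raise RuntimeError(
            f"{op_name}(): functions with out=... arguments don't support "
            "automatic differentiation, but one of the arguments requires grad."
        )

leaf = torch.empty(2, requires_grad=True)
non_leaf = leaf.clone()
print(can_mutate_inplace(leaf, True))      # is_leaf
print(can_mutate_inplace(non_leaf, True))  # success
```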

torch/testing/_internal/common_methods_invocations.py

Lines changed: 6 additions & 1 deletion
@@ -14005,6 +14005,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            operator_variant=None,
            inplace_operator_variant=None,
            check_batched_gradgrad=False,
+           supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            check_batched_forward_grad=False,
@@ -15567,7 +15568,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            sample_inputs_func=sample_inputs_split_with_sizes,
            supports_out=True,
            supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True),
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # No error raised
+               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"),
+           )),
     BinaryUfuncInfo('__radd__',
                     op=torch.Tensor.__radd__,
                     dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),

0 commit comments
