8000 Propagate CreationMeta when chaining views (#51061) · pytorch/pytorch@0b5303e · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 0b5303e

Browse files
jbschlosser authored and facebook-github-bot committed
Propagate CreationMeta when chaining views (#51061)
Summary: Fixes #49824

## Background

When creating a view of a view, there was a possibility that the new view would be less restrictive than the previous view, incorrectly sidestepping the error that should be thrown when using in-place operations on the new view.

The fix addresses this by propagating `CreationMeta` from the previous view to the new view. Currently, the old view's `creation_meta` is only propagated when the new view's `creation_meta == CreationMeta::DEFAULT`. This ensures that the new view is not less restrictive than the previous view with respect to allowing in-place operations.

Pull Request resolved: #51061

Test Plan:
```
python test/test_autograd.py TestAutogradDeviceTypeCPU.test_inplace_view_of_multiple_output_view_cpu
python test/test_autograd.py TestAutogradDeviceTypeCUDA.test_inplace_view_of_multiple_output_view_cuda
python test/test_autograd.py TestAutogradDeviceTypeCPU.test_inplace_multiple_output_view_of_view_cpu
python test/test_autograd.py TestAutogradDeviceTypeCUDA.test_inplace_multiple_output_view_of_view_cuda
```

Reviewed By: heitorschueroff

Differential Revision: D26076434

Pulled By: jbschlosser

fbshipit-source-id: c47f0ddcef9b8449427b671aff9ad08edca70fcd
1 parent 5ec2e26 commit 0b5303e

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

test/test_autograd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7377,6 +7377,20 @@ def test_inplace_view_multiple_outputs(self, device):
73777377
with self.assertRaises(RuntimeError):
73787378
v1[0].mul_(2)
73797379

7380+
def test_inplace_view_of_multiple_output_view(self, device):
7381+
a = torch.rand(10, device=device, requires_grad=True).clone()
7382+
b = a.unbind(0)
7383+
c = b[0].view_as(b[0])
7384+
with self.assertRaises(RuntimeError):
7385+
c.mul_(2)
7386+
7387+
def test_inplace_multiple_output_view_of_view(self, device):
7388+
a = torch.rand(10, device=device, requires_grad=True).clone()
7389+
b = a.view_as(a)
7390+
c = b.unbind(0)
7391+
with self.assertRaises(RuntimeError):
7392+
c[0].mul_(2)
7393+
73807394
def test_inplace_view_makes_base_require_grad(self, device):
73817395
# in-place modification to view makes base require grad
73827396
a = torch.randn(4, 4, device=device, requires_grad=False)

torch/csrc/autograd/VariableTypeUtils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif
145145
if (base.is_view()) {
146146
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
147147
const auto& base_bw_info = diff_view_meta->get_backward_view();
148+
creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
148149
return make_variable_differentiable_view(tensor, base_bw_info.chain(base, tensor, view_func),
149150
c10::nullopt, creation_meta, allow_tensor_metadata_change);
150151
} else {
@@ -188,6 +189,10 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif
188189
}
189190

190191
if (is_fw_differentiable || is_bw_differentiable) {
192+
if (base.is_view()) {
193+
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
194+
creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
195+
}
191196
return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info),
192197
creation_meta, allow_tensor_metadata_change);
193198
} else {
@@ -234,6 +239,11 @@ inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& ten
234239
}
235240
}
236241

242+
if ((is_fw_differentiable || is_bw_differentiable) && base.is_view()) {
243+
auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
244+
creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
245+
}
246+
237247
for(Tensor &tensor : tensors) {
238248
if (is_fw_differentiable || is_bw_differentiable) {
239249
tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, creation_meta);

torch/csrc/autograd/variable.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,15 @@ struct TORCH_API ViewInfo {
502502
enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NODE,
503503
NO_GRAD_MODE, MULTI_OUTPUT_SAFE };
504504

505+
/// Handles correctly propagating CreationMeta when a new view is created from a previous view.
506+
/// In general, we don't want the new view to be _less_ restrictive than the previous view
507+
/// (it's okay to be _more_ restrictive). A CreationMeta value of DEFAULT is currently the least
508+
/// restrictive, as the behavior for all other CreationMeta values is to error out for in-place ops.
509+
/// If this changes, the logic here will need to be updated to properly handle the new semantics.
510+
inline CreationMeta propagate_creation_meta(CreationMeta prev_view_creation_meta, CreationMeta new_view_creation_meta) {
511+
return (new_view_creation_meta == CreationMeta::DEFAULT) ? prev_view_creation_meta : new_view_creation_meta;
512+
}
513+
505514
/// Unified function to handle error checking when rebase happens
506515
/// indirect=true means that the caller is not doing the inplace, but the inplace happened
507516
/// somewhere else.

0 commit comments

Comments (0)