10000 Relax aten.to restriction (#142420) · pytorch/pytorch@0e1675a · GitHub
[go: up one dir, main page]

Skip to content

Commit 0e1675a

Browse files
yushangdi authored and pytorchmergebot committed
Relax aten.to restriction (#142420)
Summary: if we have a.to(b), and b has a different dtype with a, then it must be a copy. In this case, we do not need to freeze the tensor. Instead, we use torch.ops.aten._assert_tensor_metadata.default to ensure that a must not have the same dtype as b. Fixes #139718 Update executorch pin to include pytorch/executorch#7277. Test Plan: ``` buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_float_conversion buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_device_to_mutation_float ``` Differential Revision: D66988295 Pull Request resolved: #142420 Approved by: https://github.com/bdhirsh
1 parent 768d73f commit 0e1675a

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
6f638937d64e3396793956d75ee3e14802022745
1+
a29b208a06ab378bb29ab1aa68932e412f8e09f1

test/export/test_export.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5108,6 +5108,29 @@ def forward(self, x):
51085108
for op in ops:
51095109
self.assertIn(op, (torch.ops.aten._to_copy.default,))
51105110

5111+
def test_float_conversion_from_int(self):
5112+
class Module(torch.nn.Module):
5113+
def forward(self, x):
5114+
return x.float()
5115+
5116+
ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions(
5117+
{}
5118+
)
5119+
ops = []
5120+
for node in ep.graph.nodes:
5121+
if node.op == "call_function":
5122+
ops.append(node.target)
5123+
self.assertGreater(len(ops), 0)
5124+
self.assertIn(torch.ops.aten._to_copy.default, ops)
5125+
self.assertIn(torch.ops.aten._assert_tensor_metadata.default, ops)
5126+
5127+
self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1)
5128+
5129+
# Raises error because the input dtype is not the same as the input
5130+
# tensor when exporting.
5131+
with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"):
5132+
ep.module()(torch.tensor(1, dtype=torch.float32))
5133+
51115134
def test_device_to_mutation_float(self):
51125135
class Module(torch.nn.Module):
51135136
def forward(self, x):

torch/_subclasses/functional_tensor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,27 @@ def unwrap(x):
535535
torch.ops.aten.dropout.default,
536536
torch.ops.aten._to_copy.default,
537537
):
538-
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
538+
539+
def must_copy():
540+
"""
541+
Return True if the output of the op must be copied, not an alias
542+
"""
543+
# output dtype is different from input
544+
return (
545+
func == torch.ops.aten._to_copy.default
546+
and "dtype" in kwargs
547+
and kwargs["dtype"] != args_unwrapped[0].dtype
548+
)
549+
550+
if must_copy():
551+
# We can further relax to args_unwrapped[0] != kwargs["dtype"], but I don't think
552+
# we have an aten op for that.
553+
torch.ops.aten._assert_tensor_metadata.default(
554+
torch._from_functional_tensor(args_unwrapped[0]),
555+
dtype=args_unwrapped[0].dtype,
556+
)
557+
else:
558+
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
539559
outs_wrapped = pytree.tree_map_only(
540560
torch.Tensor, wrap, outs_unwrapped
541561
)

0 commit comments

Comments
 (0)
0