If we re-fakeify a FakeTensor with the same ShapeEnv, preserve symbols · pytorch/pytorch@b8b3c26 · GitHub

Commit b8b3c26

voznesenskym authored and pytorchmergebot committed
If we re-fakeify a FakeTensor with the same ShapeEnv, preserve symbols (#113651)
Subsumes half of #113605.

We support fakeifying an already fake tensor, which gives you a new fake tensor mirroring the structure of the original; this is what #113643 needs. However, when this refakeification happens, we naively reallocate new symbolic sizes for every dimension of the fake tensor. That is the right thing to do if you are re-fakeifying onto a fresh ShapeEnv (because you are reparametrizing the sizes), but if the two fake tensor modes share a shape environment, you would rather reuse the original sizes/strides/storage offset from the original fake tensor.

This ends up being pretty simple. I recommend viewing with whitespace diff turned off. There's some fuzz around jagged tensor handling; that code is probably not quite right, but I fixed it for this particular case in the most straightforward way.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #113651
Approved by: https://github.com/albanD, https://github.com/eellison, https://github.com/bdhirsh
1 parent cab039f commit b8b3c26
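
In short: when a fake tensor whose mode shares the target ShapeEnv is fakeified again, its existing symbolic sizes/strides/storage offset are now reused instead of fresh symbols being allocated. A minimal sketch of the scenario (it mirrors the test added in this PR and uses the from_tensor/dynamic_dims API as it exists at this commit):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic

# Two fake tensor modes sharing a single ShapeEnv.
shape_env = ShapeEnv()
mode1 = FakeTensorMode(shape_env=shape_env)
mode2 = FakeTensorMode(shape_env=shape_env)

# Fakeify a real tensor with a dynamic dim; size(0) becomes a symbol (e.g. s0).
t1 = mode1.from_tensor(torch.randn(10), dynamic_dims=[DimDynamic.DYNAMIC])

# Re-fakeify the already-fake tensor under the second mode. With this change
# the existing symbol is reused rather than a new one being allocated.
t2 = mode2.from_tensor(t1)
assert str(t2.size(0)) == str(t1.size(0))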

File tree

3 files changed: +87 additions, -29 deletions

test/test_fake_tensor.py
torch/_subclasses/meta_utils.py
torch/fx/experimental/symbolic_shapes.py

test/test_fake_tensor.py

Lines changed: 38 additions & 1 deletion
@@ -15,7 +15,7 @@
     DynamicOutputShapeException,
     UnsupportedOperatorException,
 )
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols
 from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.common_device_type import ops
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
@@ -517,6 +517,43 @@ def test_tolist(self):
         x = torch.rand([10])
         x.tolist()
 
+    def test_same_shape_env_preserved(self):
+        shape_env = ShapeEnv()
+        mode1 = FakeTensorMode(shape_env=shape_env)
+        t1 = mode1.from_tensor(torch.randn(10), dynamic_dims=[DimDynamic.DYNAMIC])
+        mode2 = FakeTensorMode(shape_env=shape_env)
+        t2 = mode2.from_tensor(t1)
+        # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
+        self.assertIsNot(t2, t1)
+        self.assertIs(t1.fake_mode, mode1)
+        self.assertIs(t2.fake_mode, mode2)
+        self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
+        self.assertEqual(str(t2.size(0)), str(t1.size(0)))
+
+    def test_jagged_fake_to_fake_preserved(self):
+        from torch.nested._internal.nested_tensor import jagged_from_list
+
+        S0, S1, S2 = 3, 4, 5
+        D = 4
+        a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
+        b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
+        c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
+        offsets = None
+        jt, _ = jagged_from_list([a, b, c], offsets)
+        shape_env = ShapeEnv()
+        mode1 = FakeTensorMode(shape_env=shape_env)
+        t1 = mode1.from_tensor(jt)
+        mode2 = FakeTensorMode(shape_env=shape_env)
+        t2 = mode2.from_tensor(t1)
+        # It's not obvious that the invocation above makes it dynamic but it
+        # does!
+        self.assertTrue(free_symbols(t1.size()))
+        self.assertIsNot(t2, t1)
+        self.assertIs(t1.offsets().fake_mode, mode1)
+        self.assertIs(t2.offsets().fake_mode, mode2)
+        self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
+        self.assertEqual(str(t2.size(1)), str(t1.size(1)))
+
     def checkMetaProps(self, t1, t2):
         prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

torch/_subclasses/meta_utils.py

Lines changed: 48 additions & 27 deletions
@@ -1,7 +1,7 @@
 import contextlib
 import warnings
 import weakref
-from typing import ContextManager, List, Optional, TYPE_CHECKING
+from typing import ContextManager, List, Optional, Tuple, TYPE_CHECKING
 
 import torch
 from torch._C._functorch import (
@@ -187,6 +187,8 @@ def meta_tensor(
         dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
         constraint_dims: "Optional[DimList[DimConstraint]]" = None,
     ):
+        from torch._subclasses.fake_tensor import FakeTensor
+
         if source is None:
             from torch._dynamo.source import ConstantSource
 
@@ -233,18 +235,25 @@
         if shape_env is not None:
             maybe_suppress = shape_env.suppress_guards
 
-        def sym_sizes_strides_storage_offset(t, src):
+        def sym_sizes_strides_storage_offset(
+            t, src
+        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
             if shape_env is not None:
-                return shape_env.create_symbolic_sizes_strides_storage_offset(
-                    t,
-                    src,
-                    # Assume that the set of dims that are dynamic are the same between
-                    # the wrapper tensor and any inner tensors.
-                    # We can revisit this if this assumption does not hold
-                    # for any important subclasses later.
-                    dynamic_dims=dynamic_dims,
-                    constraint_dims=constraint_dims,
-                )
+                if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
+                    # Don't reallocate the sizes; the shape envs are the same,
+                    # so reuse the old sizes/strides/etc
+                    return (t.size(), t.stride(), t.storage_offset())
+                else:
+                    return shape_env.create_symbolic_sizes_strides_storage_offset(
+                        t,
+                        src,
+                        # Assume that the set of dims that are dynamic are the same between
+                        # the wrapper tensor and any inner tensors.
+                        # We can revisit this if this assumption does not hold
+                        # for any important subclasses later.
+                        dynamic_dims=dynamic_dims,
+                        constraint_dims=constraint_dims,
+                    )
             else:
                 assert dynamic_dims is None
                 assert constraint_dims is None
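
The hunk above is the core of the change: if the tensor being converted is already a FakeTensor and its mode's ShapeEnv is the very same object as the target shape_env, the existing (possibly symbolic) sizes, strides, and storage offset are returned as-is; otherwise fresh symbols are allocated as before. A standalone sketch of that decision, with illustrative helper names that are not part of the patch:

from typing import Tuple

import torch
from torch._subclasses.fake_tensor import FakeTensor


def can_reuse_symbols(t: torch.Tensor, shape_env) -> bool:
    # The reuse condition added above. Identity comparison matters: the two
    # modes must share the same ShapeEnv object, not merely equal ones.
    return isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env


def reused_metadata(t: FakeTensor) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
    # When the condition holds, the fake tensor's existing (possibly symbolic)
    # sizes/strides/storage offset are handed back untouched.
    return (t.size(), t.stride(), t.storage_offset())

When the condition does not hold, the existing call to shape_env.create_symbolic_sizes_strides_storage_offset(...) still runs, so behavior on a fresh ShapeEnv is unchanged.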
@@ -474,31 +483,43 @@ def empty_create(inner_t, inner_src):
                 # so we can insert some special processing on ctx
                 attrs, ctx = t.__tensor_flatten__()
                 transformed_tensors_dict = {}
+                orig_shape_env = None
                 for attr in attrs:
                     inner_t = getattr(t, attr)
+                    if orig_shape_env is None:
+                        orig_shape_env = (
+                            inner_t.fake_mode.shape_env
+                            if isinstance(inner_t, FakeTensor)
+                            else None
+                        )
                     transformed_tensors_dict[attr] = callback(
                         lambda: empty_create(
                             inner_t, AttrSource(source, attr)
                         )
                     )
                 # We expect JaggedTensor to have a 'ragged_size' in
                 # its context
-                assert isinstance(ctx, dict) and "ragged_size" in ctx
-                assert (
-                    isinstance(t._size[1], torch.SymInt)
-                    and t._size[1].node.singleton_int() is not None
-                )
-                # Replace the eager ragged size with our freshly
-                # allocated jagged size that has a source
-                ctx["ragged_size"] = shape_env.create_symintnode(
-                    shape_env.create_symbol(
-                        t._size[1],
-                        TensorPropertySource(
-                            source, TensorProperty.SIZE, 1
+                assert isinstance(ctx, dict)
+                assert "ragged_size" in ctx
+                assert isinstance(t._size[1], torch.SymInt)
+                if orig_shape_env is shape_env:
+                    # It's already fake and the shape envs line up, reuse the old size
+                    # Do not assert singleton_int; it may already
+                    # be a variable
+                    ctx["ragged_size"] = t._size[1]
+                else:
+                    assert t._size[1].node.singleton_int() is not None
+                    # Replace the eager ragged size with our freshly
+                    # allocated jagged size that has a source
+                    ctx["ragged_size"] = shape_env.create_symintnode(
+                        shape_env.create_symbol(
+                            t._size[1],
+                            TensorPropertySource(
+                                source, TensorProperty.SIZE, 1
+                            ),
                         ),
-                    ),
-                    hint=t._size[1],
-                )
+                        hint=t._size[1],
+                    )
                 r = type(t).__tensor_unflatten__(
                     transformed_tensors_dict, ctx
                 )
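
This second hunk extends the same idea to traceable wrapper subclasses (the jagged NestedTensor path): orig_shape_env records the ShapeEnv of the first inner fake tensor, and when it is identical to the target shape_env, the existing ragged_size SymInt in the flatten context is reused instead of a new symbol being minted through create_symbol/create_symintnode. A hedged end-to-end sketch of the effect, mirroring the new test (it leans on the internal jagged_from_list helper that the test uses, so treat it as illustrative only):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.nested._internal.nested_tensor import jagged_from_list

# Variable-length rows -> jagged NestedTensor (same setup as the new test).
rows = [
    torch.randn(n, 4, requires_grad=True, dtype=torch.float64) for n in (3, 4, 5)
]
jt, _ = jagged_from_list(rows, None)

shape_env = ShapeEnv()
t1 = FakeTensorMode(shape_env=shape_env).from_tensor(jt)
t2 = FakeTensorMode(shape_env=shape_env).from_tensor(t1)

# The ragged dimension keeps its original symbol across re-fakeification.
assert str(t2.size(1)) == str(t1.size(1))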

torch/fx/experimental/symbolic_shapes.py

Lines changed: 1 addition & 1 deletion
@@ -2017,7 +2017,7 @@ def _create_symbolic_sizes_strides_storage_offset(
         # TODO: This should be DYNAMIC, using DUCK for BC
         dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK
 
-        assert len(dynamic_dims) == dim
+        assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
         assert len(constraint_dims) == dim
 
         from torch._dynamo.source import TensorPropertySource, TensorProperty

0 commit comments