Re-enable FakeTensor caching for SymInts · pytorch/pytorch@bbf0c41 · GitHub

Commit bbf0c41
Re-enable FakeTensor caching for SymInts
Summary: This backs out D60320595, which itself turned off FakeTensor caching when a SymInt was present. Tests seem to pass, so I'm assuming some dynamic-shape work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

ghstack-source-id: 0d17332
Pull Request resolved: #152662
1 parent e38001e commit bbf0c41
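For context, the guard this commit removes (see the torch/_subclasses/fake_tensor.py hunk below) bypassed the FakeTensor dispatch cache whenever a tensor's storage size was symbolic. A minimal sketch of just that condition, using plain eager PyTorch for illustration (the helper name is hypothetical):

import torch

def has_symbolic_nbytes(t: torch.Tensor) -> bool:
    # Under dynamic shapes, untyped_storage().nbytes() can be a torch.SymInt
    # rather than a plain int; the removed guard raised _BypassDispatchCache
    # ("symbolic nbytes") in exactly that case.
    return isinstance(t.untyped_storage().nbytes(), torch.SymInt)

print(has_symbolic_nbytes(torch.randn(3)))  # False for an ordinary eager tensor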

File tree

4 files changed: +9, -17 lines

aten/src/ATen/EmptyTensor.cpp

Lines changed: 1 addition & 2 deletions
@@ -28,8 +28,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
     opt_device_type = at::getAccelerator(false);
   }
   if (opt_device_type.has_value()) {
-    return at::globalContext().getPinnedMemoryAllocator(
-        opt_device_type.value());
+    return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
   } else {
     TORCH_CHECK(
         false, "Need to provide pin_memory allocator to use pin memory.")

test/dynamo/test_subclasses.py

Lines changed: 4 additions & 4 deletions
@@ -2535,9 +2535,9 @@ def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f
     clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None

     view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
-    sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
+    sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
     view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
-    return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
+    return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
 """,  # noqa: B950
         )

@@ -2591,9 +2591,9 @@ def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f
     clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None

     view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
-    sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
+    sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
     view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
-    return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
+    return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
 """,  # noqa: B950
         )
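Both expected graphs change only in how the symbolic element count 3*s16 is produced: instead of aten.sym_numel on the clone, it is now read off dimension 0 of the flattened view. A quick eager-mode sanity check of that equivalence (concrete, hypothetical sizes standing in for the symbolic ones):

import torch

t = torch.randn(3, 5)   # stand-in for the dynamically shaped f32[3, s16]
view = t.view(-1)       # f32[3*s16] in the traced graph
assert view.size(0) == t.numel() == 3 * 5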

torch/_subclasses/fake_tensor.py

Lines changed: 1 addition & 9 deletions
@@ -1624,10 +1624,6 @@ def _prep_args_for_hash(
                 raise _BypassDispatchCache("constant attribute")
             if is_sparse_any(arg):
                 raise _BypassDispatchCache(f"{arg.layout} tensor")
-            # FIXME: For now back out caching when there are symbolic nbytes
-            # - this doesn't seem to play nice with set(). See T196779132 for examples.
-            if isinstance(arg.untyped_storage().nbytes(), SymInt):
-                raise _BypassDispatchCache("symbolic nbytes")
             metadata = extract_tensor_metadata(arg)
             metadata._flatten_into(result, self, state)
         elif isinstance(arg, Tensor):

@@ -1929,11 +1925,7 @@ def _output_from_cache_entry(
         if entry.is_output_tuple:
             outputs = [
                 self._get_output_tensor_from_cache_entry(
-                    state,
-                    output_info,
-                    key,
-                    func,
-                    args,
+                    state, output_info, key, func, args
                 )
                 for output_info in entry.output_infos
             ]
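The _prep_args_for_hash hunk above relies on FakeTensor's bypass convention: raising _BypassDispatchCache while building the cache key makes the caller skip the cache and dispatch normally. A rough sketch of that control flow, with hypothetical names rather than the actual FakeTensorMode code:

class _BypassCache(Exception):
    """Signals that an argument cannot safely be folded into a cache key."""

def make_key(func, args):
    parts = [func.__name__]
    for arg in args:
        if isinstance(arg, (set, dict)):
            # Unhashable/unordered arguments can't form a stable key.
            raise _BypassCache("unhashable argument")
        parts.append(repr(arg))
    return tuple(parts)

def cached_dispatch(cache, func, *args):
    try:
        key = make_key(func, args)
    except _BypassCache:
        return func(*args)            # bypass: compute without caching
    if key not in cache:
        cache[key] = func(*args)      # miss: compute and remember
    return cache[key]

cache = {}
print(cached_dispatch(cache, max, 2, 7))    # cached normally
print(cached_dispatch(cache, max, {1, 2}))  # bypassed (set argument)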

torch/fx/experimental/symbolic_shapes.py

Lines changed: 3 additions & 2 deletions
@@ -6895,7 +6895,8 @@ def _evaluate_expr(
         ):
             return orig_expr

-        # Don't track this one
+        # Don't track this one. (Because this cache is inside this function the
+        # cache only lasts for the invocation of this function call)
         @functools.lru_cache(None)
         def compute_concrete_val() -> sympy.Basic:
             if hint is None:

@@ -7100,7 +7101,7 @@ def compute_concrete_val() -> sympy.Basic:
                 insts, frame.f_lasti, key=lambda x: x.offset
             )
         else:
-            # For Pyhton <= 3.10, instructions are always 2 bytes.
+            # For Python <= 3.10, instructions are always 2 bytes.
             cur = frame.f_lasti // 2

         if sys.version_info >= (3, 13):
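The new comment in the first hunk is about cache scope: compute_concrete_val is defined, and decorated with functools.lru_cache, inside _evaluate_expr, so its cache is recreated on every call of the enclosing function and never persists across invocations. A standalone illustration of that scoping, with hypothetical names:

import functools

def evaluate(x):
    calls = 0

    @functools.lru_cache(None)
    def compute():            # cached only within this call of evaluate()
        nonlocal calls
        calls += 1
        return x * x

    compute()                 # computes
    compute()                 # served from the per-invocation cache
    return compute(), calls

assert evaluate(3) == (9, 1)  # cached within one invocation...
assert evaluate(3) == (9, 1)  # ...but recomputed on the next one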

0 commit comments