Revert "Constant folding for dynamic shape node (#129686)" · pytorch/pytorch@9df4bc6 · GitHub
Commit 9df4bc6

Revert "Constant folding for dynamic shape node (#129686)"
This reverts commit b7d287f. Reverted #129686 on behalf of https://github.com/atalman because it was failing internally. Test: https://github.com/pytorch/ao/blob/main/test/prototype/mx_formats/test_mx_linear.py ([comment](#129686 (comment)))
1 parent 7cd48df commit 9df4bc6

9 files changed: +21 additions, -191 deletions

benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ hf_Bert_large,pass,6
-hf_BigBird,pass,6
+hf_BigBird,fail_to_run,3

test/inductor/test_cpu_repro.py
Lines changed: 1 addition & 1 deletion

@@ -1522,7 +1522,7 @@ def fn(x, y):
     def test_int_div(self):
         def fn(x, y):
             s3 = x.size(1)
-            a = torch.ones((1 + s3) // 2)
+            a = torch.zeros((1 + s3) // 2)
             a += y
             return a, s3

test/inductor/test_cudagraph_trees.py
Lines changed: 1 addition & 1 deletion

@@ -1154,7 +1154,7 @@ def foo(x):
         for _ in range(3):
             out = foo(inp)
         node = self.curr_node()
-        self.assertEqual(len(list(node.path_live_weakrefs())), 1)
+        self.assertEqual(len(list(node.path_live_weakrefs())), 2)

         @torch.compile(mode="reduce-overhead")
         def foo(x):

test/inductor/test_torchinductor.py
Lines changed: 2 additions & 2 deletions

@@ -5153,7 +5153,7 @@ def test_unbacked_floordiv_simplify(self):
         def fn(x, y):
            z = y.item()
            torch._check(z // 2 == 3)
-           return x + x.new_ones(z)
+           return x + x.new_zeros(z)

        self.common(
            fn,
@@ -11171,7 +11171,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        UniformValueConstantFolder(mod).run()

        # there are a couple extra tensors created in `insertable_tensor_check`
-       self.assertTrue(max_live_tensors == 3)
+       self.assertTrue(max_live_tensors == 4)

    # See https://github.com/pytorch/pytorch/issues/100348
    def test_inductor_detach_view(self):

test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Lines changed: 0 additions & 3 deletions

@@ -323,9 +323,6 @@ def run(*ex, **kwargs):
    "test_list_clearing_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu"), is_skip=True
    ),
-   "test_dropout_trivial_1_dynamic_shapes": TestFailure(
-       ("cpu", "cuda", "xpu"), is_skip=True
-   ),
    "test_dropout2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
    "test_dropout3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
    "test_masked_fill_promotion_dynamic_shapes": TestFailure(

test/inductor/test_torchinductor_dynamic_shapes.py
Lines changed: 0 additions & 47 deletions

@@ -18,9 +18,7 @@
 from torch._inductor.codegen.cpp import CppScheduling
 from torch._inductor.codegen.wrapper import WrapperCodeGen
 from torch._inductor.test_case import TestCase
-from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
-from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCPU,
@@ -148,51 +146,6 @@ def tearDown(self):
         TestCase.tearDown(self)
         torch._dynamo.reset()

-    def test_constant_fold_uniform_value_dynamic(self, device):
-        def full_add_zero(x):
-            a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
-            b = a - 1
-            return x + b
-
-        def full_mul_one(x):
-            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
-            b = 2 + a
-            return x * b
-
-        def full_view_op(x):
-            a = torch.ones([1], dtype=x.dtype, device=x.device)
-            a = a[:, None]
-            return x * a
-
-        def full_mul_symint(x):
-            a = torch.full(x.shape, -1, dtype=x.dtype, device=x.device)
-            b = 2 + a
-            return b * x.shape[0]
-
-        fns = (full_add_zero, full_mul_one, full_view_op)
-
-        x = torch.randn((2, 4), device=device)
-        y = torch.randn((3, 4), device=device)
-
-        for dynamic in [False, True]:
-            torch._dynamo.reset()
-            for fn in fns:
-                ref = fn(x)
-                fn_c = torch.compile(fn, dynamic=dynamic)
-
-                actual, source_codes = run_and_get_code(fn_c, x)
-
-                if fn is not full_mul_symint:
-                    # due to constant folding, fn returns x directly.
-                    if device == "cpu":
-                        FileCheck().check_not("cpp_fused").run(source_codes[0])
-                    else:
-                        FileCheck().check_not("triton.jit").run(source_codes[0])
-
-                self.assertEqual(ref, actual)
-                self.assertEqual(fn(x), fn_c(x))
-                self.assertEqual(fn(y), fn_c(y))
-
     def test_arange_dynamic(self, device):
         def fn(a):
             batch_size = a.numel()

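For reference, the deleted test exercised Inductor's uniform-value constant folding through run_and_get_code and FileCheck. The listing below is a minimal sketch of that testing pattern, not the reverted test itself; it uses dynamic=False because static-shape folding remains supported after this revert, and the helper name check_constant_folded is illustrative only.

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def full_add_zero(x):
    # x + (full(x.shape, 1) - 1) adds a tensor of zeros, so the whole
    # epilogue can be folded away and the compiled function returns x directly.
    a = torch.full(x.shape, 1, dtype=x.dtype, device=x.device)
    return x + (a - 1)


def check_constant_folded():
    x = torch.randn(2, 4)
    compiled = torch.compile(full_add_zero, dynamic=False)
    actual, source_codes = run_and_get_code(compiled, x)
    # If folding succeeded on CPU, no fused C++ kernel should be emitted.
    FileCheck().check_not("cpp_fused").run(source_codes[0])
    torch.testing.assert_close(actual, full_add_zero(x))
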
test/test_torch.py
Lines changed: 1 addition & 2 deletions

@@ -4591,10 +4591,9 @@ def test_multinomial_device_constrain(self, device):
     # FIXME: move to test distributions
     @deviceCountAtLeast(2)
     @onlyCUDA
-    @skipIfTorchInductor("FIXME: error not thrown")
     def test_multinomial_gpu_device_constrain(self, devices):
         x = torch.empty(3, device=devices[0])
-        y = torch.empty(3, device=devices[1], dtype=torch.long)
+        y = torch.empty(3, device=devices[1])
         self.assertRaisesRegex(
             RuntimeError, "Expected all tensors to be on the same device",
             lambda: torch.multinomial(x, 2, out=y))

torch/_inductor/constant_folding.py
Lines changed: 3 additions & 15 deletions

@@ -60,13 +60,6 @@ def __init__(
         # is the output
         self.user_to_last_uses = self.node_to_last_non_output_use()

-    def _support_dynamic_shape(self):
-        # ConstantFolder not support dynamic shape now
-        return False
-
-    def _deduce_value(self, node):
-        return super().run_node(node)
-
     def is_impure(self, node: torch.fx.node.Node):
         if (
             node.target == torch.ops.prims.convert_element_type.default
@@ -166,9 +159,7 @@ def set_env(arg):
         ):
             return self.unknown_value

-        out = self._deduce_value(node)
-        if out == self.unknown_value:
-            return self.unknown_value
+        out = super().run_node(node)

         if node.op != "get_attr" and isinstance(out, torch.Tensor):
             if out.device.type == "meta":
@@ -203,13 +194,10 @@ def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
         self.node_replacements[node] = tensor

     def run(self):
-        env: Dict[torch.fx.Node, Any] = {}
-        self.insert_placerholder_values(env)
-        return super().run(initial_env=env)
-
-    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
+        env = {}
         for n in self.module.graph.find_nodes(op="placeholder"):
             env[n] = self.unknown_value
+        return super().run(initial_env=env)


 @torch.utils._python_dispatch._disable_current_modes()

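In effect, the revert returns ConstantFolder to its pre-#129686 behaviour: every placeholder is seeded with unknown_value, so any node that transitively depends on a graph input is skipped and only fully constant subgraphs are evaluated. The toy interpreter below is a simplified sketch of that idea, not the Inductor class itself; ToyConstantFolder and UNKNOWN are made-up names for illustration.

import torch
import torch.fx
import torch.utils._pytree as pytree


class _Unknown:
    """Sentinel standing in for ConstantFolder.unknown_value."""


UNKNOWN = _Unknown()


class ToyConstantFolder(torch.fx.Interpreter):
    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.node_replacements = {}

    def run(self):
        # Seed every graph input with the sentinel so nothing that depends
        # on runtime data gets evaluated.
        env = {n: UNKNOWN for n in self.module.graph.find_nodes(op="placeholder")}
        return super().run(initial_env=env)

    def run_node(self, node: torch.fx.Node):
        args, kwargs = self.fetch_args_kwargs_from_env(node)
        if any(
            isinstance(a, _Unknown)
            for a in pytree.arg_tree_leaves(*args, **kwargs)
        ):
            return UNKNOWN  # depends on a runtime input; cannot fold
        out = super().run_node(node)
        if node.op != "output" and isinstance(out, torch.Tensor):
            # A fully constant tensor: remember it as a replacement candidate.
            self.node_replacements[node] = out
        return out
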
torch/_inductor/fx_passes/joint_graph.py
Lines changed: 12 additions & 119 deletions

@@ -3,13 +3,11 @@
 import logging
 import typing
 from collections import Counter
-from typing import Any, Dict, List, Set, Union
+from typing import Dict, List, Set, Union

 import torch
 import torch._guards
-import torch.utils._pytree as pytree
 from torch._inductor.constant_folding import ConstantFolder
-from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
 from torch.fx.experimental.symbolic_shapes import statically_known_true
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -203,112 +201,21 @@ def __init__(self, gm, skip_constructors=False):
         # see: [constant folding refining of symints]
         self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

-        # initialize symint -> node mapping so that we can
-        # use symint nodes in full constructors
-        self.symint_nodes = _SymHashingDict()
-        for n in self.module.graph.nodes:
-            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
-                self.symint_nodes[n.meta["val"]] = n
-
-        # reference from torch/_funtorch/partitioners.py:get_default_op_list
-        self.view_op_packets = [
-            aten.squeeze,
-            aten.unsqueeze,
-            aten.alias,
-            aten.view,
-            aten.slice,
-            aten.t,
-            prims.broadcast_in_dim,
-            aten.expand,
-            aten.as_strided,
-            aten.permute,
-        ]
-
-        self.indexing_op_packets = {
-            aten.slice,
-        }
-
-    def _support_dynamic_shape(self):
-        return True
-
     def insertable_tensor_check(self, t: torch.Tensor) -> bool:
-        return True
+        # TODO - we could also Tensors which get replaced with arange here
+        return (
+            t.numel() != 0
+            and bool((t == t.flatten()[0]).all())
+            and torch._C._has_storage(t)
+            and t.layout == torch.strided
+        )

     def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
         self.node_replacements[node] = tensor.flatten()[0].item()
-        self.node_replacements_shapes[node] = node.meta["val"].shape
         self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())
-
-    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
-        for n in self.module.graph.find_nodes(op="placeholder"):
-            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
-                env[n] = n.meta["val"]
-            else:
-                env[n] = self.unknown_value
-
-    def _deduce_value(self, node: torch.fx.Node):
-        # deduce value for full-like nodes
-        # 1. for constructors, substitute value is a tensor of size [1]
-        # 2. for view ops/indexing, substitute value is the same as the input
-        # 3. for pointwise ops, run node to get the substitute value
-        # 4. deal with some special ops
-        # otherwise, stop deduce value and return unknown value
-
-        # TODO: cat, more indexing
-        # TODO - do on cpu to avoid syncs
-
-        # single-elem attrs
-        if node.op == "get_attr" or (
-            node.op == "call_function"
-            and node.target == torch.ops.aten.lift_fresh_copy.default
-        ):
-            out = super(ConstantFolder, self).run_node(node)
-            if isinstance(out, torch.Tensor) and out.numel() == 1:
-                return out
-
-        # constructors ops
-        if (
-            node.op == "call_function"
-            and node.target == aten.full.default
-            and len(node.args) == 2
-        ):
-            args, kwargs = self.fetch_args_kwargs_from_env(node)
-            new_args = [[1], args[1]]
-            return aten.full.default(*new_args, **node.kwargs)
-
-        # view ops, return input tensor, the first argument
-        if hasattr(node.target, "overloadpacket") and (
-            node.target.overloadpacket in self.view_op_packets
-            or node.target.overloadpacket in self.indexing_op_packets
-        ):
-            assert isinstance(node.args[0], torch.fx.Node)
-            return self.env[node.args[0]]
-
-        # we don't want to return unknown value for symints so that we can
-        # still constant fold through their use in constructors or views
-        # if we see them in a pointwise node (e.g., tensor * symint)
-        # we will bail
-        if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
-            return node.meta["val"]
-
-        # pointwise ops
-        if isinstance(node.target, torch._ops.OpOverload) and (
-            torch.Tag.pointwise in node.target.tags
-            or node.target is torch.ops.aten.scalar_tensor.default
-        ):
-            args, kwargs = self.fetch_args_kwargs_from_env(node)
-            flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
-
-            if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
-                return self.unknown_value
-
-            # we run the ops with dim 1, so remove memory_format to avoid error
-            kwargs = dict(kwargs)
-            kwargs.pop("memory_format", None)
-
-            return node.target(*args, **kwargs)
-
-        return self.unknown_value
+        shape = list(tensor.shape)
+        assert all(type(dim) is int for dim in shape)
+        self.node_replacements_shapes[node] = shape


 @torch.utils._python_dispatch._disable_current_modes()
@@ -373,24 +280,10 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
            ):
                torch._check(runtime_size == compile_time_size)

-           # replace SymInt as Node before creating a new full node
-           # e.g. (1, s0) -> (1, arg0_1)
-           node_shape = node_replacements_shapes[node]
-           if not all(
-               not isinstance(s, torch.SymInt) or s in cf.symint_nodes
-               for s in node_shape
-           ):
-               continue
-
-           shapes = [
-               cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
-               for s in node_replacements_shapes[node]
-           ]
-
            # zeros and ones just get traced into full, so we insert those
            new_node = graph.call_function(
                aten.full.default,
-               args=(shapes, value),
+               args=(node_replacements_shapes[node], value),
                kwargs={
                    "dtype": fake_tensor.dtype,
                    "layout": torch.strided,

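The key piece restored here is insertable_tensor_check, which only lets the folder replace a tensor when it is provably a single repeated value. Below is a standalone illustration of that predicate; is_uniform_value is a hypothetical helper for this page, not an Inductor API.

import torch


def is_uniform_value(t: torch.Tensor) -> bool:
    # Mirrors the restored check: non-empty, every element equal to the
    # first one, backed by real storage, and an ordinary strided layout.
    return (
        t.numel() != 0
        and bool((t == t.flatten()[0]).all())
        and torch._C._has_storage(t)
        and t.layout == torch.strided
    )


assert is_uniform_value(torch.full((3, 4), 7.0))  # foldable to the scalar 7.0
assert not is_uniform_value(torch.arange(6))      # not a single uniform value

When the check passes, add_node_replacement records the scalar (tensor.flatten()[0].item()) together with a purely static shape (the assert that every dim is an int), which is why the SymInt handling above could be dropped in this revert.
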
0 commit comments