Update base for Update on "[Cutlass] Fix wrapper code generation breakage" · pytorch/pytorch@03e9fb1 · GitHub


Commit 03e9fb1

Update base for Update on "[Cutlass] Fix wrapper code generation breakage"
Fixes issues introduced by #159355. The issue got past OSS CI because the H100 tag wasn't added; it's unclear how to prevent this kind of breakage in the future. Perhaps we should run the H100 jobs on Inductor PRs?

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2 parents d35b27d + 50eac81 commit 03e9fb1

File tree

19 files changed: 258 additions, 130 deletions

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-11ec6354315768a85da41032535e3b7b99c5f706
+f7888497a1eb9e98d4c07537f0d0bcfe180d1363

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-29ae4c76c026185f417a25e841d2cd5e65f087a3
+b6a5b82b9948b610fa4c304d0d869c82b8f17db1

.github/workflows/inductor-periodic.yml

Lines changed: 15 additions & 15 deletions

@@ -81,21 +81,21 @@ jobs:
       sync-tag: rocm-build
       test-matrix: |
         { include: [
-          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
+          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
         ]}
     secrets: inherit

aten/src/ATen/native/ComparisonUtils.cpp

Lines changed: 23 additions & 0 deletions

@@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
   }
 }
 
+template<>
+void _assert_match<c10::Device, std::optional<c10::Device>>(
+    const c10::Device& original,
+    const std::optional<c10::Device>& compared,
+    const std::string& name) {
+  if (compared) {
+    const c10::Device& expected = compared.value();
+    if (original.type() != expected.type()) {
+      std::stringstream msg;
+      msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
+      throw std::runtime_error(msg.str());
+    }
+
+    // If the expected device doesn't have an index (e.g., just "cuda"),
+    // or if both devices have the same index, consider them equal
+    if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
+      std::stringstream msg;
+      msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
+      throw std::runtime_error(msg.str());
+    }
+  }
+}
+
 void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
   _assert_match(tensor.sym_sizes(), sizes, "sizes");
   _assert_match(tensor.sym_strides(), strides, "strides");

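The specialization added above relaxes the device check used by aten::_assert_tensor_metadata: device types must match, but an expected device without an index (a plain "cuda") now accepts a tensor that actually lives on "cuda:0". The following is a minimal Python sketch of that matching rule for illustration only; the devices_match helper is hypothetical and not part of the change.

import torch

def devices_match(original: torch.device, expected: torch.device) -> bool:
    # Mirror of the C++ logic above: types must agree, and indices only
    # matter when both devices actually carry one.
    if original.type != expected.type:
        return False
    if expected.index is not None and original.index is not None:
        return original.index == expected.index
    return True

# An index-less "cuda" accepts a tensor placed on "cuda:0" ...
assert devices_match(torch.device("cuda", 0), torch.device("cuda"))
# ... while explicit, differing indices still mismatch.
assert not devices_match(torch.device("cuda", 0), torch.device("cuda", 1))
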
test/dynamo/test_guard_manager.py

Lines changed: 11 additions & 3 deletions

@@ -931,7 +931,7 @@ def hook(guard_wrapper, f_locals, builder):
 
             # Check types of foo.x
             foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source)
-            self.assertTrue(foo_x_mgr.is_guarded_value_dict())
+            self.assertTrue(issubclass(foo_x_mgr.get_type_of_guarded_value(), dict))
 
             # Check types of foo.x["a"]
             foo_x_a_source = DictGetItemSource(foo_x_source, "a")
@@ -946,12 +946,14 @@ def hook(guard_wrapper, f_locals, builder):
             # Check types of foo.z
             foo_z_source = AttrSource(foo_source, "z")
             foo_z_mgr = builder.get_guard_manager_from_source(foo_z_source)
-            self.assertTrue(foo_z_mgr.is_guarded_value_empty_dict())
+            self.assertTrue(issubclass(foo_z_mgr.get_type_of_guarded_value(), dict))
 
             # Check types of mod
             mod_source = LocalSource("mod")
             mod_mgr = builder.get_guard_manager_from_source(mod_source)
-            self.assertTrue(mod_mgr.is_guarded_value_nn_module())
+            self.assertTrue(
+                issubclass(mod_mgr.get_type_of_guarded_value(), torch.nn.Module)
+            )
 
         opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
         with install_guard_manager_testing_hook(hook):
@@ -1006,6 +1008,12 @@ def hook(guard_wrapper, f_locals, builder):
             from torch._dynamo.source import AttrSource, LocalSource
 
             foo_source = LocalSource("foo")
+            foo_mgr = builder.get_guard_manager_from_source(foo_source)
+            for accessor in foo_mgr.get_accessors():
+                if isinstance(accessor, GetAttrGuardAccessor):
+                    self.assertTrue(
+                        accessor.get_attr_name() in ("a", "b", "c", "d", "e")
+                    )
 
             # Check types of foo.a
             foo_a_source = AttrSource(foo_source, "a")

test/export/test_export.py

Lines changed: 17 additions & 0 deletions

@@ -59,6 +59,7 @@
     OutputSpec,
     TensorArgument,
 )
+from torch.export.passes import move_to_device_pass
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.testing import FileCheck
@@ -15914,6 +15915,22 @@ def forward(self, x):
             len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
         )
 
+    @requires_cuda
+    def test_assert_tensor_metadata_device_index(self):
+        class N(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                x = x.float()
+                y = y.float()
+                return x + y
+
+        inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
+        ep = export(N(), inp)
+        ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
+        ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))
+
     def test_input_output_no_stacktrace(self):
         class M(torch.nn.Module):
             def forward(self, x):

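The new test pairs the relaxed device assertion with torch.export.passes.move_to_device_pass: a program exported against index-less "cuda" inputs can later be fed tensors on "cuda:0". A standalone sketch of that flow, assuming a CUDA-enabled build, is shown below; it simply restates the test above outside the test harness.

import torch
from torch.export import export
from torch.export.passes import move_to_device_pass

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x.float() + y.float()

inp = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
ep = export(Add(), inp)
# Map the recorded "cuda:0" back to an index-less "cuda" so the exported
# program's _assert_tensor_metadata check accepts any CUDA device index.
ep = move_to_device_pass(ep, {"cuda:0": "cuda"})
out = ep.module()(torch.randn(3, device="cuda:0"), torch.randn(3, device="cuda:0"))
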
test/inductor/test_compiled_autograd.py

Lines changed: 21 additions & 6 deletions

@@ -29,6 +29,7 @@
 from torch._dynamo.testing import normalize_gm
 from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
+from torch._inductor.cpp_builder import is_msvc_cl
 from torch._inductor.test_case import run_tests, TestCase
 from torch.nn.attention.flex_attention import flex_attention
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -40,6 +41,7 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_S390X,
+    IS_WINDOWS,
     parametrize,
     scoped_load_inline,
     skipIfWindows,
@@ -193,6 +195,18 @@ def model(i):
         for _ in range(3):
             self.run_as_subprocess(script)
 
+    def gen_cache_miss_log_prefix(self):
+        if IS_WINDOWS:
+            if is_msvc_cl():
+                return "Cache miss due to new autograd node: struct "
+            else:
+                self.fail(
+                    "Compilers other than msvc have not yet been verified on Windows."
+                )
+            return ""
+        else:
+            return "Cache miss due to new autograd node: "
+
     def test_reset(self):
         compiled_autograd.compiled_autograd_enabled = True
         torch._C._dynamo.compiled_autograd.set_autograd_compiler(lambda: None, True)
@@ -3146,7 +3160,7 @@ def test_logs(self):
         self.assertEqual(counters["compiled_autograd"]["compiles"], 1)
         assert "torch::autograd::AccumulateGrad (NodeCall" in logs.getvalue()
         assert (
-            "Cache miss due to new autograd node: torch::autograd::GraphRoot"
+            self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot"
             not in logs.getvalue()
         )
 
@@ -3353,7 +3367,6 @@ def fn(x, obj):
             sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs)
         )
 
-    @skipIfWindows(msg="AssertionError: Scalars are not equal!")
     def test_verbose_logs_cpp(self):
         torch._logging.set_logs(compiled_autograd_verbose=True)
 
@@ -3381,8 +3394,9 @@ def fn():
         self.check_output_and_recompiles(fn)
 
         patterns1 = [
-            r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), "
-            r"previous key sizes=\[\]\n",
+            r".*"
+            + self.gen_cache_miss_log_prefix()
+            + r"torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), previous key sizes=\[\]\n",
         ]
 
         all_logs = logs.getvalue()
@@ -3420,7 +3434,8 @@ def test_verbose_logs_dynamic_shapes(self):
 
         actual_logs = logs.getvalue()
         expected_logs = [
-            "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]",
+            self.gen_cache_miss_log_prefix()
+            + "torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]",
         ]
         for expected in expected_logs:
             self.assertTrue(expected in actual_logs)
@@ -3451,7 +3466,7 @@ def fn():
         fn()
 
         unexpected_logs = [
-            "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0)"
+            self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot (NodeCall 0)"
         ]
 
         self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

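The gen_cache_miss_log_prefix helper exists because MSVC builds print C++ class names with a leading "struct ", so the cache-miss log line differs between Windows/MSVC and other platforms. A brief hedged illustration of how the expected string is assembled; the msvc flag below is illustrative, whereas the real code calls is_msvc_cl().

def cache_miss_prefix(msvc: bool = False) -> str:
    # On MSVC the logged node name reads "struct torch::autograd::GraphRoot",
    # so the expected prefix gains a trailing "struct ".
    base = "Cache miss due to new autograd node: "
    return base + "struct " if msvc else base

expected = cache_miss_prefix() + "torch::autograd::GraphRoot (NodeCall 0)"
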
test/run_test.py

Lines changed: 0 additions & 2 deletions

@@ -182,7 +182,6 @@ def __contains__(self, item):
     "dynamo/test_misc",
     "inductor/test_cpu_repro",
     "inductor/test_cpu_select_algorithm",
-    "inductor/test_aot_inductor_arrayref",
     "inductor/test_torchinductor_codegen_dynamic_shapes",
     "lazy/test_meta_kernel",
     "onnx/test_utility_funs",
@@ -240,7 +239,6 @@ def __contains__(self, item):
     # some false errors
     "doctests",
     # new failures to investigate and fix
-    "cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic",
     "test_tensorboard",
     # onnx + protobuf failure, see
     # https://github.com/protocolbuffers/protobuf/issues/22104

torch/_C/_dynamo/guards.pyi

Lines changed: 3 additions & 0 deletions

@@ -142,6 +142,9 @@ class GetGenericDictGuardAccessor(GuardAccessor): ...
 class TypeDictGuardAccessor(GuardAccessor): ...
 class TypeMROGuardAccessor(GuardAccessor): ...
 
+class GetAttrGuardAccessor(GuardAccessor):
+    def get_attr_name(self) -> str: ...
+
 def install_object_aliasing_guard(
     guard_managers: list[GuardManager],
     tensor_names: list[str],

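The new stub exposes the attribute name behind a GetAttrGuardAccessor, which the updated test_guard_manager.py test uses to check which attributes of foo are guarded. Below is a small hedged sketch of combining it with get_accessors(); the attr_names helper is hypothetical and not part of the change.

from torch._C._dynamo.guards import GetAttrGuardAccessor

def attr_names(guard_manager):
    # Collect the attribute names guarded through GetAttr accessors on a
    # GuardManager (e.g. one obtained from a builder in the testing hook).
    return [
        accessor.get_attr_name()
        for accessor in guard_manager.get_accessors()
        if isinstance(accessor, GetAttrGuardAccessor)
    ]
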
torch/_dynamo/guards.py

Lines changed: 6 additions & 6 deletions

@@ -355,7 +355,7 @@ def find_tag_safe_roots(self):
         def visit_dict_manager(node):
             # Just recurse through the key and value dict managers and check if
             # all of them are tag safe nodes.
-            assert node.is_guarded_value_dict()
+            assert issubclass(node.get_type_of_guarded_value(), dict)
 
             tag_safe_roots = []
             is_subtree_tag_safe = True
@@ -394,12 +394,12 @@ def visit_manager(node):
             # If the node guards a tensor, mark it tag safe only if there
             # are no accessors. Presence of accessors means presence of
             # symbolic shape guards.
-            if node.is_guarded_value_tensor():
+            if issubclass(node.get_type_of_guarded_value(), torch.Tensor):
                 if node.has_no_accessors() and not node.has_object_aliasing_guard():
                     node.mark_tag_safe()
                 else:
                     node.mark_tag_safe()
-            elif node.is_guarded_value_dict():
+            elif issubclass(node.get_type_of_guarded_value(), dict):
                 accessors = node.get_accessors()
                 child_mgrs = node.get_child_managers()
                 is_subtree_tag_safe = all(
@@ -408,7 +408,7 @@ def visit_manager(node):
                 )
                 if is_subtree_tag_safe:
                     node.mark_tag_safe()
-            elif node.is_guarded_value_nn_module():
+            elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
                 accessors = node.get_accessors()
                 child_mgrs = node.get_child_managers()
                 is_subtree_tag_safe = all(
@@ -434,7 +434,7 @@ def visit(node):
 
         tag_safe_roots = visit(self.root)
         for node in tag_safe_roots:
-            if node.is_guarded_value_nn_module():
+            if issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
                 node.mark_tag_safe_root()
 
     def populate_diff_guard_manager(self):
@@ -468,7 +468,7 @@ def get_manager_line(self, guard_manager, accessor_str=None):
         s = t + ": source=" + source
         if accessor_str:
             s += ", " + accessor_str
-        s += f", type={guard_manager.type_of_guarded_value()}"
+        s += f", type={guard_manager.get_type_of_guarded_value()}"
         s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})"
         return s

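Throughout guards.py the per-type predicates (is_guarded_value_dict, is_guarded_value_tensor, is_guarded_value_nn_module) are replaced by subclass checks against get_type_of_guarded_value(). A hedged sketch of the pattern follows; the guards_value_of_type helper is illustrative only and not part of the patch.

import torch

def guards_value_of_type(manager, cls) -> bool:
    # get_type_of_guarded_value() returns the Python type of the value the
    # manager guards, so one subclass check covers all the old predicates.
    return issubclass(manager.get_type_of_guarded_value(), cls)

# e.g. node.is_guarded_value_dict()      -> guards_value_of_type(node, dict)
#      node.is_guarded_value_nn_module() -> guards_value_of_type(node, torch.nn.Module)
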
0 commit comments