Update · pytorch/pytorch@cc3eefe · GitHub

Commit cc3eefe

Update
[ghstack-poisoned]
2 parents bbad26b + a41c9d5 commit cc3eefe

36 files changed: +660 -358 lines changed
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-319c8d7fd3551bac63429334509de2663aa43f57
+8148603e3f3a618acef447a73bdeec9b749a95fb

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 2 additions & 6 deletions
@@ -1513,12 +1513,8 @@ static void addmm_impl_cpu_(
   // that will call then into Arm® Compute Library (ACL) GEMM kernel and also
   // additionally have support for running kernel with BF16 instructions
   if (transpose_c) {
-    bool apply_heur =
-        apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
-    if (apply_heur && transpose_a && !transpose_b &&
-        (result.scalar_type() == at::ScalarType::Float ||
-         result.scalar_type() == at::ScalarType::BFloat16 ||
-         result.scalar_type() == at::ScalarType::Half)) {
+    bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
+    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
       try {
         mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
         // We have dispatched to ACL GEMM for single precision float
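For context, this gate sits on the CPU addmm path: after the change only fp32 results are considered for the oneDNN/ACL heuristic here, while bf16 and fp16 no longer take this branch. A minimal Python sketch of the two cases; the shapes are arbitrary, and whether the fp32 call actually dispatches to mkldnn_matmul still depends on apply_mkldnn_matmul_heur and the build:

import torch

# fp32 result: still eligible for the mkldnn/ACL heuristic gate shown above.
a = torch.randn(256, 256, dtype=torch.float32)
b = torch.randn(256, 256, dtype=torch.float32)
c = a @ b

# bf16 (and fp16) results no longer take this mkldnn_matmul branch after the
# change; they fall through to the regular CPU GEMM path instead.
a_bf = a.to(torch.bfloat16)
b_bf = b.to(torch.bfloat16)
c_bf = a_bf @ b_bf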

aten/src/ATen/native/mkldnn/Matmul.cpp

Lines changed: 7 additions & 19 deletions
@@ -236,27 +236,15 @@ void mkldnn_matmul(
       "mkldnn_matmul: unsupported dims for mat and mat2");

 #if defined(__aarch64__)
-  // oneDNN fast-maths mode (enabled by setting the environment variable
-  // ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch fp32 inputs to bf16 kernels
-  // where HW permits. So, both fp32 and bf16 inputs are permitted.
-  TORCH_CHECK(
-      (mat1.scalar_type() == mat2.scalar_type()) &&
-          (mat1.scalar_type() == result.scalar_type()) &&
-          ((mat1.scalar_type() == at::kFloat) ||
-           (mat1.scalar_type() == at::kBFloat16) ||
-           (mat1.scalar_type() == at::kHalf)),
-      "mkldnn_matmul: only enabled for fp32, bf16 and fp16 path");
+  // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
+  // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
+  TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
+              ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
+              "mkldnn_matmul: only enabled for fp32 and bf16 path");
   // device needs to support bf16 if the inputs are of bf16 type
   if (mat1.scalar_type() == at::kBFloat16) {
-    TORCH_CHECK(
-        mkldnn_bf16_device_check_arm(),
-        "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
-  }
-  // device needs to support fp16 if the inputs are of fp16 type
-  if (mat1.scalar_type() == at::kHalf) {
-    TORCH_CHECK(
-        mkldnn_fp16_device_check_arm(),
-        "mkldnn_matmul: mkldnn_matmul fp16 path needs a cpu with fp16 support");
+    TORCH_CHECK(mkldnn_bf16_device_check_arm(),
+                "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
   }
 #else
   TORCH_CHECK(
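The retained comment describes oneDNN fast-math mode: with ONEDNN_DEFAULT_FPMATH_MODE=BF16 set in the environment, fp32 GEMMs may be lowered to bf16 kernels where the hardware allows. A small sketch of how that looks from Python, assuming the variable is set before the process starts so oneDNN picks it up:

# Launch with fast-math enabled in the environment, e.g.:
#   ONEDNN_DEFAULT_FPMATH_MODE=BF16 python fp32_matmul.py
import torch

x = torch.randn(512, 512, dtype=torch.float32)
w = torch.randn(512, 512, dtype=torch.float32)

# Inputs and output stay fp32, satisfying the TORCH_CHECK above; oneDNN may
# internally use bf16 kernels where the hardware permits.
y = x @ w
print(y.dtype)  # torch.float32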

aten/src/ATen/native/mkldnn/Utils.h

Lines changed: 1 addition & 9 deletions
@@ -90,10 +90,6 @@ inline bool mkldnn_bf16_device_check_arm() {
   return cpuinfo_initialize() && cpuinfo_has_arm_bf16();
 }

-inline bool mkldnn_fp16_device_check_arm() {
-  return cpuinfo_initialize() && cpuinfo_has_arm_neon_fp16();
-}
-
 inline bool is_arm_neoverse() {
   return (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 &&
           (cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1 ||
@@ -106,10 +102,6 @@ constexpr bool mkldnn_bf16_device_check_arm() {
   return false;
 }

-inline bool mkldnn_fp16_device_check_arm() {
-  return false;
-}
-
 constexpr bool is_arm_neoverse() {
   return false;
 }
@@ -129,7 +121,7 @@ inline bool mkldnn_fp16_device_check() {
 #if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))
   return ideep::has_fp16_type_support();
 #else
-  return mkldnn_fp16_device_check_arm();
+  return false;
 #endif
 }

cmake/Dependencies.cmake

Lines changed: 2 additions & 2 deletions
@@ -864,9 +864,9 @@ if(NOT Python_Interpreter_FOUND)
   message(FATAL_ERROR "Python3 could not be found.")
 endif()

-if(${Python_VERSION} VERSION_LESS 3.8)
+if(${Python_VERSION} VERSION_LESS 3.9)
   message(FATAL_ERROR
-    "Found Python libraries version ${Python_VERSION}. Python < 3.8 is no longer supported by PyTorch.")
+    "Found Python libraries version ${Python_VERSION}. Python < 3.9 is no longer supported by PyTorch.")
 endif()

 # ---[ Python + Numpy
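For reference, the same minimum can be expressed as a runtime check on the Python side; a hedged sketch of an equivalent guard (not taken from this commit):

import sys

# Python < 3.9 is no longer supported after this bump.
if sys.version_info < (3, 9):
    raise RuntimeError(
        f"Python {sys.version_info.major}.{sys.version_info.minor} detected; "
        "Python < 3.9 is no longer supported by PyTorch."
    )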

test/dynamo/test_dicts.py

Lines changed: 63 additions & 0 deletions
@@ -838,6 +838,69 @@ def fn(x):
         d["e"] = 5
         self.assertEqual(d["e"], res["e"])

+    def test_mapping_proxy_existing(self):
+        d = {"a": 2, "b": 3, "c": 5}
+
+        def fn(x, mp):
+            y = torch.sin(x * mp["a"])
+            for k, v in mp.items():
+                y += torch.cos(x * v)
+            return y
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        x = torch.randn(4)
+        mp = types.MappingProxyType(d)
+        ref = fn(x, mp)
+        res = opt_fn(x, mp)
+        self.assertEqual(ref, res)
+
+        d["a"] = 3
+        ref = fn(x, mp)
+        res = opt_fn(x, mp)
+        self.assertEqual(ref, res)
+
+        d.pop("b")
+        ref = fn(x, mp)
+        res = opt_fn(x, mp)
+        self.assertEqual(ref, res)
+
+    def test_mapping_proxy_existing_mutation(self):
+        d = {"a": 2, "b": 3, "c": 5}
+
+        mp = types.MappingProxyType(d)
+
+        def fn(x):
+            d["d"] = 4
+            y = torch.sin(x * mp["d"])
+            return y
+
+        opt_fn = torch.compile(fn, backend="eager")
+        x = torch.randn(4)
+        ref = torch.sin(x * 4)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(d.keys(), mp.keys())
+
+    def test_mapping_proxy_existing_local_mutation(self):
+        d = {"a": 2, "b": 3, "c": 5}
+
+        mp = types.MappingProxyType(d)
+
+        def fn(x):
+            # Dynamo should not cause a graph break here because it knows that
+            # the existing proxy can't point to this new dict
+            other_dict = {}
+            other_dict["d"] = 4
+            y = torch.sin(x * mp["c"])
+            return y
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        x = torch.randn(4)
+        ref = torch.sin(x * mp["c"])
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(d.keys(), mp.keys())
+
     def test_move_to_end(self):
         def fn(x):
             d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5})
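The tests above rely on the basic semantics of types.MappingProxyType: it is a read-only, live view of its backing dict, so mutations of the dict are visible through the proxy while writes through the proxy raise. A short plain-Python refresher, independent of torch.compile:

import types

d = {"a": 2, "b": 3}
mp = types.MappingProxyType(d)

d["a"] = 3          # mutating the backing dict...
print(mp["a"])      # ...is visible through the proxy: 3

try:
    mp["c"] = 1     # the proxy itself is read-only
except TypeError as e:
    print(e)        # 'mappingproxy' object does not support item assignment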

test/export/test_export.py

Lines changed: 28 additions & 33 deletions
@@ -11953,6 +11953,34 @@ def forward(self, x):
         ]
         self.assertEqual(len(shift_op), 1)

+    @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
+    def test_distributed_all_reduce(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 3)
+
+            def forward(self, x):
+                y = self.linear(x).abs().clamp(max=1.0) * 2
+                torch.distributed.all_reduce(y)
+                return y
+
+        try:
+            torch.distributed.init_process_group(
+                backend="fake",
+                world_size=2,
+                rank=0,
+                store=FakeStore(),
+            )
+
+            m = Foo()
+            ep = export(m, (torch.randn(4, 4),))
+            inp = (torch.randn(4, 4),)
+            self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
+
+        finally:
+            torch.distributed.destroy_process_group()
+

 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
@@ -12524,39 +12552,6 @@ def forward(self, x):
             ep.graph_module.code
         )

-    @unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
-    @testing.expectedFailureSerDerNonStrict  # nonstrict doesn't support allreduce
-    @testing.expectedFailureNonStrict
-    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # source_fn_stack failure
-    @testing.expectedFailureRetraceabilityNonStrict
-    @testing.expectedFailureLegacyExportNonStrict
-    def test_distributed_all_reduce(self):
-        class Foo(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(4, 3)
-
-            def forward(self, x):
-                y = self.linear(x).abs().clamp(max=1.0) * 2
-                torch.distributed.all_reduce(y)
-                return y
-
-        try:
-            torch.distributed.init_process_group(
-                backend="fake",
-                world_size=2,
-                rank=0,
-                store=FakeStore(),
-            )
-
-            m = Foo()
-            ep = export(m, (torch.randn(4, 4),))
-            inp = (torch.randn(4, 4),)
-            self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
-
-        finally:
-            torch.distributed.destroy_process_group()
-
     def test_preserve_cia_op(self):
         class StaticResizeBilinear2dModule(torch.nn.Module):
             def forward(self, x):
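The relocated test traces the collective through the "fake" process group, so no real workers are spawned. A minimal standalone sketch of that setup, assuming FakeStore is importable from torch.testing._internal.distributed.fake_pg as it is in test_export.py:

import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", world_size=2, rank=0, store=FakeStore())
try:
    t = torch.ones(4)
    dist.all_reduce(t)  # handled by the fake backend; no real communication
finally:
    dist.destroy_process_group()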

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4):
     ):
         input_kernel = 1
     if output.is_contiguous(memory_format=torch.contiguous_format) or (
-        TEST_ACL and (dtype == torch.bfloat16 or dtype == torch.half)
+        TEST_ACL and dtype == torch.bfloat16
     ):
         output_kernel = 1
     return input_kernel + output_kernel

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 29 additions & 0 deletions
@@ -20,6 +20,7 @@
 from torch.testing._internal.inductor_utils import (
     GPU_TYPE,
     HAS_GPU,
+    requires_gpu,
     skip_windows_ci,
     TRITON_HAS_CPU,
 )
@@ -895,6 +896,34 @@ def func(x, y):
         )
         self.assertTrue("Min" not in code[0])

+    @requires_gpu()  # FIXME this test failed on Triton-CPU
+    def test_3d_permute_tiling(self):
+        """
+        Test 3D tiling with permute.
+        """
+
+        def foo(x, y, z):
+            dims = [0, 2, 1]
+            a = x.permute(dims=dims) + y
+            b = (z + y).permute(dims=dims)
+            return a + b
+
+        inps = (torch.rand((51, 51, 51), device=self.device, dtype=torch.float32),) * 3
+        result, (code,) = run_and_compare(
+            self,
+            foo,
+            *inps,
+            expected_num_triton_kernels=1,
+            expected_num_block_pointers=3,
+            config_patches={
+                "triton.max_tiles": 3,
+                "triton.prefer_nd_tiling": True,
+            },
+        )
+
+        # Check for 3D tiling
+        self.assertIn("ZBLOCK", code)
+

 @unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
 @config.patch(cpu_backend="triton")
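The new test drives Inductor's N-D tiling through run_and_compare's config_patches; the ZBLOCK check is what confirms a third tiling dimension was emitted. A hedged sketch of applying the same knobs directly with torch._inductor.config.patch, assuming a CUDA device is available:

import torch
from torch._inductor import config


def foo(x, y, z):
    dims = [0, 2, 1]
    a = x.permute(dims=dims) + y
    b = (z + y).permute(dims=dims)
    return a + b


inps = (torch.rand((51, 51, 51), device="cuda"),) * 3

# Same settings the test patches: allow up to 3 tiles and prefer N-D tiling.
with config.patch({"triton.max_tiles": 3, "triton.prefer_nd_tiling": True}):
    out = torch.compile(foo)(*inps)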

torch/_C/_dynamo/guards.pyi

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ class GuardManager:
     def add_torch_function_mode_stack_guard(
         self, initial_stack, verbose_code_parts: list[str]
     ) -> None: ...
+    def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ...

 class RootGuardManager(GuardManager):
     def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...

torch/_dynamo/guards.py

Lines changed: 10 additions & 0 deletions
@@ -1808,6 +1808,16 @@ def WEAKREF_ALIVE(self, guard):
             get_verbose_code_parts(code, guard)
         )

+    def MAPPING_KEYS_CHECK(self, guard):
+        """Guard on the key order of types.MappingProxyType object"""
+        ref = self.arg_ref(guard)
+        value = self.get(guard.name)
+
+        code = []
+        code.append(f"list({ref}.keys()) == {list(value.keys())}")
+        self._set_guard_export_info(guard, code)
+        self.get_guard_manager(guard).add_mapping_keys_guard(value, code)
+
     def DICT_KEYS_MATCH(self, guard):
         """Insert guard to check that the keys of a dict are same"""
         ref = self.arg_ref(guard)
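For illustration, the guard installed by MAPPING_KEYS_CHECK boils down to a key-order comparison on the proxy. A small sketch of what the exported expression evaluates to for a concrete proxy (the variable names here are illustrative):

import types

d = {"a": 2, "b": 3, "c": 5}
mp = types.MappingProxyType(d)

# The generated guard body is essentially this expression, with the ref
# pointing at the traced mappingproxy argument.
print(list(mp.keys()) == ["a", "b", "c"])  # True -> guard passes, reuse compiled code

d.pop("b")
print(list(mp.keys()) == ["a", "b", "c"])  # False -> guard fails, recompile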

torch/_dynamo/side_effects.py

Lines changed: 12 additions & 0 deletions
@@ -94,6 +94,9 @@ def __init__(
         self.keepalive = keepalive or []
         self.save_for_backward = save_for_backward or []
         self.tensor_hooks = tensor_hooks or {}
+        # Used by MappingProxyVariable to graph break in case of any mutated
+        # dict
+        self._has_existing_dict_mutation = False
         # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph.
         # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
         self.ca_final_callbacks_var = None
@@ -536,6 +539,15 @@ def mutation(self, var):
         self.check_allowed_side_effect(var)
         if isinstance(var.mutation_type, ValueMutationExisting):
             var.mutation_type.is_modified = True
+            if (
+                var.source
+                and isinstance(var, variables.ConstDictVariable)
+                and not isinstance(var, variables.SetVariable)
+            ):
+                self._has_existing_dict_mutation = True
+
+    def has_existing_dict_mutation(self):
+        return self._has_existing_dict_mutation

     def _get_modified_vars(self):
         return [var for var in self.id_to_variable.values() if self.is_modified(var)]
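The new flag gives proxy-backed variables a cheap way to notice that some pre-existing dict was mutated during tracing. A hypothetical sketch of how a consumer might use it; the real call site lives in the dicts variable tracker, which is not part of this excerpt:

# Hypothetical consumer of the flag; only has_existing_dict_mutation()
# comes from this commit, the rest is illustrative.
class Unsupported(RuntimeError):
    pass


def read_through_proxy(side_effects, proxy, key):
    # A mappingproxy may alias any dict that is already being tracked, so once
    # some existing dict has been mutated during tracing, reads through the
    # proxy can no longer be proven sound without a graph break.
    if side_effects.has_existing_dict_mutation():
        raise Unsupported("mapping proxy read after existing dict mutation")
    return proxy[key]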

torch/_dynamo/variables/builder.py

Lines changed: 28 additions & 0 deletions
@@ -158,6 +158,7 @@
     DefaultDictVariable,
     DictKeySetVariable,
     FrozensetVariable,
+    MappingProxyVariable,
     SetVariable,
 )
 from .distributed import (
@@ -472,6 +473,7 @@ def _type_dispatch_impl(cls, trace_numpy):
             (weakref.ReferenceType, cls.wrap_weakref),
             (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle),
             (torch.jit.ScriptFunction, cls.wrap_jit_function),
+            (types.MappingProxyType, cls.wrap_mapping_proxy),
         ]

         if trace_numpy and np:
@@ -507,6 +509,32 @@ def wrap_jit_function(self, value):
             value, "_torchdynamo_inline", source=self.source
         )

+    def wrap_mapping_proxy(self, value):
+        self.install_guards(GuardBuilder.TYPE_MATCH)
+        # This might be suboptimal compared to dict guards. But mappingproxy is
+        # not very common, so it's ok to guard on all keys.
+        self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK)
+        all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
+
+        if not all_const:
+            unimplemented("mapping proxy type supports only const keys")
+
+        def build_key_value(k, v):
+            key = ConstantVariable.create(k)
+            source_key = k
+
+            source_value = GetItemSource(self.get_source(), source_key)
+            value = LazyVariableTracker.create(v, source_value)
+
+            return key, value
+
+        items = dict(build_key_value(k, v) for k, v in value.items())
+
+        # Create a dict_vt to be used in the mapping proxy variable
+        dict_vt = ConstDictVariable(items, source=None)
+        result = MappingProxyVariable(dict_vt, source=self.source)
+        return self.tx.output.side_effects.track_mutable(value, result)
+
     @classmethod
     @functools.lru_cache(None)
     def _id_dispatch(
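Taken together with the guard and side-effect changes above, wrap_mapping_proxy lets torch.compile trace reads from a mappingproxy without a graph break, guarded by TYPE_MATCH and MAPPING_KEYS_CHECK. A small usage sketch mirroring the new tests:

import types

import torch


def fn(x, mp):
    y = torch.sin(x * mp["a"])
    for k, v in mp.items():
        y = y + torch.cos(x * v)
    return y


opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

d = {"a": 2, "b": 3}
mp = types.MappingProxyType(d)
x = torch.randn(4)

print(torch.allclose(opt_fn(x, mp), fn(x, mp)))  # True

d["c"] = 7  # key-set change: MAPPING_KEYS_CHECK fails, the next call recompiles
print(torch.allclose(opt_fn(x, mp), fn(x, mp)))  # True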
