Merge branch 'pytorch:main' into minjean/xpu_nested_layernorm · pytorch/pytorch@9c429c4 · GitHub

Commit 9c429c4

Merge branch 'pytorch:main' into minjean/xpu_nested_layernorm
2 parents: 1a42c02 + d6d670a

65 files changed: +1091 −1351 lines

.ci/docker/build.sh

Lines changed: 4 additions & 0 deletions

@@ -1,4 +1,8 @@
 #!/bin/bash
+# The purpose of this script is to:
+# 1. Extract the set of parameters to be used for a docker build based on the provided image name.
+# 2. Run docker build with the parameters found in step 1.
+# 3. Run the built image and print out the expected and actual versions of packages installed.

 set -ex

.ci/docker/libtorch/build.sh

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ case ${GPU_ARCH_TYPE} in
         BASE_TARGET=rocm
         DOCKER_TAG=rocm${GPU_ARCH_VERSION}
         GPU_IMAGE=rocm/dev-ubuntu-20.04:${GPU_ARCH_VERSION}-complete
-        PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx1102;gfx942"
+        PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
         DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
         ;;
     *)

.ci/docker/manywheel/build.sh

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ case ${GPU_ARCH_TYPE} in
             DEVTOOLSET_VERSION="11"
             GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
         fi
-        PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
+        PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
         DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
         ;;
     xpu)

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -16,6 +16,6 @@ psutil
 pyyaml
 requests
 setuptools
-sympy==1.13.3
+sympy>=1.13.3
 types-dataclasses
 typing-extensions>=4.10.0

setup.py

Lines changed: 1 addition & 1 deletion

@@ -1120,7 +1120,7 @@ def main():
         "filelock",
         "typing-extensions>=4.10.0",
         'setuptools ; python_version >= "3.12"',
-        "sympy==1.13.3",
+        "sympy>=1.13.3",
         "networkx",
         "jinja2",
         "fsspec",

test/distributed/tensor/test_attention.py

Lines changed: 1 addition & 11 deletions

@@ -21,7 +21,6 @@
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -42,8 +41,6 @@
     backends.append(SDPBackend.FLASH_ATTENTION)
 if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
     backends.append(SDPBackend.EFFICIENT_ATTENTION)
-if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
-    backends.append(SDPBackend.CUDNN_ATTENTION)

 rotater_enum_to_str = {
     _RotateMethod.ALL_GATHER: "allgather",
@@ -89,10 +86,6 @@ def _test_ring_attention_sdpa(
         rotater: _RotateMethod,
         test_forward_only: bool,
     ) -> None:
-        # TODO: DTensor does not support backward on SDPBackend.CUDNN_ATTENTION so far
-        if not test_forward_only and backend == SDPBackend.CUDNN_ATTENTION:
-            return
-
         def fn_eval(fn, *args, **kwargs):
             if test_forward_only:
                 with torch.no_grad():
@@ -116,10 +109,7 @@ def fn_eval(fn, *args, **kwargs):
         nheads = 8
         torch.manual_seed(10)
         dtype = (
-            torch.bfloat16
-            if backend == SDPBackend.FLASH_ATTENTION
-            or backend == SDPBackend.CUDNN_ATTENTION
-            else torch.float32
+            torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
         )

         _cp_options.enable_load_balance = load_balance
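
The hunks above remove SDPBackend.CUDNN_ATTENTION from the backends exercised by the ring-attention test, leaving flash and memory-efficient attention. As a rough sketch of how such a backend list is typically consumed, the snippet below gates backends on the same platform flags the test imports and pins SDPA to one backend at a time; it is not code from the test file and assumes a CUDA device is available.

# Sketch only: build a backend list from the platform flags, then pin
# scaled_dot_product_attention to a single backend at a time.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)

backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
    backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)

if torch.cuda.is_available():
    # (batch, heads, seq, head_dim); bf16 keeps both backends eligible
    q = k = v = torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.bfloat16)
    for backend in backends:
        with sdpa_kernel(backend):  # restrict SDPA to one backend, as the test does
            out = F.scaled_dot_product_attention(q, k, v)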

test/dynamo/test_decorators.py

Lines changed: 151 additions & 12 deletions

@@ -204,6 +204,36 @@ def fn(a):
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 5)

+    def test_allow_in_graph_no_id_reuse(self):
+        cnts = torch._dynamo.testing.CompileCounter()
+
+        def do_allow_in_graph(x):
+            return x + 1
+
+        torch._dynamo.allow_in_graph(do_allow_in_graph)
+        del do_allow_in_graph
+
+        # `id(dont_allow_in_graph)` would likely match `id(do_allow_in_graph)`
+        # We want to make sure Dynamo always trace through
+        # `dont_allow_in_graph`, by checking for the explicit graph break.
+        def dont_allow_in_graph(x):
+            torch._dynamo.graph_break()
+            return x + 1
+
+        @torch.compile(backend=cnts)
+        def fn(a):
+            x = torch.add(a, 1)
+            x = torch.add(x, 1)
+            x = dont_allow_in_graph(x)
+            x = torch.add(x, 1)
+            x = torch.add(x, 1)
+            return x
+
+        fn(torch.randn(10))
+
+        # Check for graph break
+        self.assertEqual(cnts.frame_count, 3)
+
     def test_incorrect_usage_disallow_in_graph(self):
         with self.assertRaises(IncorrectUsage):
@@ -441,6 +471,49 @@ def fn(x, y):
         res = opt_fn(x, y)
         self.assertEqual(ref, res)

+    def test_nonstrict_trace_pre_existing_register_constant_type_guard(self):
+        class State:
+            def __init__(self, n):
+                self.n = n
+
+            def get_num(self):
+                torch._dynamo.graph_break()
+                return self.n
+
+            def __eq__(self, other):
+                return isinstance(other, State) and self.n == other.n
+
+            def __hash__(self):
+                return hash(self.n)
+
+        # Assume `State` is implemented in C, and the author didn't bother to
+        # provide a pytree decomposition for it, and its instances are safe to
+        # treat as a constant by `torch.compile`.
+        torch.utils._pytree.register_constant(State)
+
+        @torch._dynamo.nonstrict_trace
+        def trace_me(x, s):
+            return x * s.get_num()
+
+        cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
+
+        @torch.compile(fullgraph=True, backend=cnts)
+        def fn(x, s):
+            res = trace_me(x, s)
+            return res
+
+        x = torch.ones(10)
+        # Make sure recompilation didn't happen.
+        self.assertEqual(cnts.frame_count, 0)
+        fn(x, State(42))
+        self.assertEqual(cnts.frame_count, 1)
+        fn(x, State(42))
+        self.assertEqual(cnts.frame_count, 1)
+
+        # Make sure recompilation did happen.
+        fn(x, State(41))
+        self.assertEqual(cnts.frame_count, 2)
+
     def test_nonstrict_trace_tuple_and_sym_int_output(self):
         @torch._dynamo.nonstrict_trace
         def trace_me(x):
@@ -602,6 +675,7 @@ def fn(p):
         except torch._dynamo.exc.Unsupported as e:
             msg = """
 For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
+  * `torch.utils._pytree.register_constant`
   * `torch.utils._pytree.register_dataclass`
   * `torch.utils._pytree.register_pytree_node`
 """  # NOQA: B950
@@ -653,39 +727,104 @@ def fn(x, y):
         except torch._dynamo.exc.Unsupported as e:
             msg = """
 For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <DecoratorTests.test_nonstrict_trace_nested_custom_class_error.<locals>.Point>, please use one of the following to register the type with pytree:
+  * `torch.utils._pytree.register_constant`
   * `torch.utils._pytree.register_dataclass`
   * `torch.utils._pytree.register_pytree_node`
 """  # NOQA: B950
             self.assertIn(msg, str(e))

-    def test_nonstrict_trace_pytree_register_constant_error(self):
+    def test_nonstrict_newly_constructed_trace_register_constant_type_error(self):
+        class State:
+            def __init__(self, n):
+                self.n = n
+
+            def get_num(self):
+                torch._dynamo.graph_break()
+                return self.n
+
+            def __eq__(self, other):
+                return isinstance(other, State) and self.n == other.n
+
+            def __hash__(self):
+                return hash(self.n)
+
+        # Assume `State` is implemented in C, and the author didn't bother to
+        # provide a pytree decomposition for it, and its instances are safe to
+        # treat as a constant by `torch.compile`.
+        torch.utils._pytree.register_constant(State)
+
+        @torch._dynamo.nonstrict_trace
+        def trace_me(x, s):
+            return x * s.get_num()
+
+        @torch.compile(fullgraph=True, backend="aot_eager")
+        def fn(x):
+            s = State(10)
+            res = trace_me(x, s)
+            return res
+
+        try:
+            x = torch.ones(10)
+            fn(x)
+            self.assertFalse(True)  # must raise error before this
+        except torch._dynamo.exc.Unsupported as e:
+            msg = """
+You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <DecoratorTests.test_nonstrict_newly_constructed_trace_register_constant_type_error.<locals>.State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region.
+
+Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.
+"""  # NOQA: B950
+            self.assertIn(msg, str(e))
+
+    def test_nonstrict_trace_object_in_context_error(self):
         class Point:
-            x: int
-            y: int
+            x: torch.Tensor
+            y: torch.Tensor

             def __init__(self, x, y):
                 self.x = x
                 self.y = y

-        torch.utils._pytree.register_constant(Point)
+        class PointTensor:
+            p: Point
+            t: torch.Tensor
+
+            def __init__(self, p, t):
+                self.p = p
+                self.t = t
+
+        torch.utils._pytree.register_pytree_node(
+            PointTensor,
+            lambda pt: ((pt.t,), pt.p),
+            lambda ts, p: PointTensor(p, ts[0]),
+        )

         @torch._dynamo.nonstrict_trace
-        def trace_me(x, p):
+        def trace_me(pt):
             torch._dynamo.graph_break()
-            return x * p.x + p.y
+            return pt.t + pt.p.x * pt.p.y

         @torch.compile(fullgraph=True, backend="aot_eager")
-        def fn(x, p):
-            res = trace_me(x, p)
-            return res + 1
+        def fn(x, y):
+            p = Point(x, y)
+            t = x + y
+            pt = PointTensor(p, t)
+            res = trace_me(pt)
+            return res

         try:
-            p = Point(3, 4)
-            fn(torch.ones(10), p)
+            x, y = torch.ones(10), torch.ones(1)
+            fn(x, y)
             self.assertFalse(True)  # must raise error before this
         except torch._dynamo.exc.Unsupported as e:
             msg = """
-This error is most likely due to a call to `nonstrict_trace`-ed function, where one of the argument contains object of a type that has been (or needs to be) `torch.utils._pytree.register_constant`-ed. We currently don't support that.
+You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point> into the context.
+
+Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <DecoratorTests.test_nonstrict_trace_object_in_context_error.<locals>.Point>
+  * `torch.utils._pytree.register_constant`
+  * `torch.utils._pytree.register_dataclass`
+  * `torch.utils._pytree.register_pytree_node`
+
+If the above doesn't work, please subtmit an issue to GitHub.
 """  # NOQA: B950
             self.assertIn(msg, str(e))
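
Taken together, the new tests and error messages describe a usage contract for `nonstrict_trace`: a type registered with `torch.utils._pytree.register_constant` must be hashable and comparable, and its instances must be constructed outside the compiled region; tensor-carrying containers are better served by `register_dataclass` or `register_pytree_node`. The following is a condensed, illustrative sketch of the supported pattern; `Config` is a made-up class, not something from the test file.

# Illustrative sketch distilled from the tests above; not part of the commit.
import torch

class Config:
    def __init__(self, scale):
        self.scale = scale

    def __eq__(self, other):
        return isinstance(other, Config) and self.scale == other.scale

    def __hash__(self):
        return hash(self.scale)

# Treat Config instances as opaque constants; equality is what gets guarded,
# so calling with a different Config triggers a recompile (per the type-guard test).
torch.utils._pytree.register_constant(Config)

@torch._dynamo.nonstrict_trace
def scale_by_config(x, cfg):
    return x * cfg.scale

@torch.compile(fullgraph=True, backend="aot_eager")
def fn(x, cfg):
    return scale_by_config(x, cfg)

cfg = Config(2.0)             # constructed *outside* the compiled region
out = fn(torch.ones(4), cfg)  # as the new error message requires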

test/dynamo/test_flat_apply.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ def distance(a, b, norm):
     return (a.x - b.x).abs() + (a.y - b.y).abs()


-@dataclass
+@dataclass(frozen=True)
 class Norm:
     typ: str
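
The only change here adds frozen=True to the Norm dataclass. A frozen dataclass (with the default eq=True) gets an auto-generated __hash__ and rejects field mutation, which is presumably what lets Norm instances act as hashable, constant-like data in this test. A standalone illustration of that Python behavior, with hypothetical class names not taken from the test file:

# Plain-Python illustration of why frozen=True matters; names are made up.
from dataclasses import dataclass

@dataclass
class MutableNorm:
    typ: str

@dataclass(frozen=True)
class FrozenNorm:
    typ: str

print(hash(FrozenNorm("l1")))    # works: frozen + eq=True generates __hash__
try:
    hash(MutableNorm("l1"))      # eq=True without frozen sets __hash__ = None
except TypeError as e:
    print(e)                     # "unhashable type: 'MutableNorm'"

try:
    FrozenNorm("l1").typ = "l2"  # frozen instances reject attribute assignment
except Exception as e:
    print(type(e).__name__)      # FrozenInstanceError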

test/dynamo/test_repros.py

Lines changed: 13 additions & 0 deletions

@@ -4893,6 +4893,19 @@ def fn(x_weak, weight, y):
         self.assertEqual(ref, res)
         self.assertEqual(cnt.frame_count, 2)

+    def test_return_weakref(self):
+        def f(t):
+            t = t * 2
+            wr = weakref.ref(t)
+            return wr, t
+
+        ref_t = torch.randn(2, 2, requires_grad=True)
+        ref_y = f(ref_t)
+
+        t = ref_t.detach().clone().requires_grad_()
+        y = torch.compile(f, backend="eager", fullgraph=True)(t)
+        self.assertEqual(ref_y[0](), y[0]())
+
     def test_weakref_del(self):
         def fn(x_weak, y):
             x = x_weak()
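
The new test returns a weakref created inside the compiled function and checks that it dereferences to the same tensor as in eager mode. As a plain-Python reminder of the weakref semantics involved, independent of torch.compile:

# Plain-Python refresher; not part of the diff. The final assertion relies on
# CPython's immediate refcount-based collection.
import weakref
import torch

t = torch.randn(2, 2)
wr = weakref.ref(t)
assert wr() is t     # while a strong reference exists, the weakref resolves to it

del t                # drop the last strong reference...
assert wr() is None  # ...and the weakref now dereferences to None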
