@@ -56,6 +56,13 @@ def get_src_dest_devices(self, case, device):
     @skipCUDAIf(True, "Does not work for CUDA")
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     @skipXLA
+    @parametrize(
+        "op",
+        [
+            "_lazy_clone",
+            "to",
+        ],
+    )
     @parametrize("materialize_first", ("src", "dest"))
     @parametrize(
         "case",
@@ -67,7 +74,7 @@ def get_src_dest_devices(self, case, device):
             "from_1_to_0",
         ],
     )
-    def test_interdevice_materialize(self, device, materialize_first, case):
+    def test_interdevice_materialize(self, device, op, materialize_first, case):
         src_device, dest_device = self.get_src_dest_devices(case, device)

         src_device_check = torch.empty(0, device=src_device).device
@@ -76,7 +83,15 @@ def test_interdevice_materialize(self, device, materialize_first, case):

         a = torch.randn(10, device=src_device, pin_memory=pin_memory)
         orig_data_ptr = torch._C._data_address_resolve_unified(a)
-        b = a._lazy_clone(device=dest_device)
+
+        if op == "_lazy_clone":
+            b = a._lazy_clone(device=dest_device)
+        elif op == "to":
+            if torch.device(device).type != "mps":
+                self.skipTest("op='to' only runs if device='mps'")
+            b = a.to(device=dest_device)
+        else:
+            raise AssertionError(f"op='{op}' not recognized")

         self.assertEqual(a.device, src_device_check)
         self.assertEqual(b.device, dest_device_check)
@@ -146,6 +161,13 @@ def test_interdevice_materialize(self, device, materialize_first, case):
     @skipCUDAIf(True, "Does not work for CUDA")
     @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
     @skipXLA
+    @parametrize(
+        "op",
+        [
+            "_lazy_clone",
+            "to",
+        ],
+    )
     @parametrize(
         "case",
         [
@@ -156,7 +178,7 @@ def test_interdevice_materialize(self, device, materialize_first, case):
             "from_1_to_0",
         ],
     )
-    def test_interdevice_read(self, device, case):
+    def test_interdevice_read(self, device, op, case):
         src_device, dest_device = self.get_src_dest_devices(case, device)

         src_device_check = torch.empty(0, device=src_device).device
@@ -168,7 +190,14 @@ def test_interdevice_read(self, device, case):
         a.copy_(orig_tensor)

         orig_data_ptr = torch._C._data_address_resolve_unified(a)
-        b = a._lazy_clone(device=dest_device)
+        if op == "_lazy_clone":
+            b = a._lazy_clone(device=dest_device)
+        elif op == "to":
+            if torch.device(device).type != "mps":
+                self.skipTest("op='to' only runs if device='mps'")
+            b = a.to(device=dest_device)
+        else:
+            raise AssertionError(f"op='{op}' not recognized")

         self.assertEqual(a.device, src_device_check)
         self.assertEqual(b.device, dest_device_check)
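For context, the new @parametrize("op", [...]) decorator fans each listed op value out into its own generated test case. Below is a minimal, self-contained sketch of that pattern using PyTorch's torch.testing._internal.common_utils helpers; the class and test names are hypothetical and not part of this PR.

# Sketch of the @parametrize pattern the diff relies on (hypothetical
# names, not from the PR). Each value of "op" becomes its own test case.
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleTest(TestCase):
    @parametrize("op", ["_lazy_clone", "to"])
    def test_op_variants(self, op):
        # A real test would branch on `op` as the diff does; here we just
        # check that the parameter is delivered as expected.
        self.assertIn(op, ("_lazy_clone", "to"))


# Expands the parametrized method into one concrete test per "op" value.
instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()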