Update on "[AOTI] Fix #140546 and support AOTI package load for Intel… · pytorch/pytorch@2abadda · GitHub
[go: up one dir, main page]

Skip to content

Commit 2abadda

Browse files
committed
Update on "[AOTI] Fix #140546 and support AOTI package load for Intel GPU."
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #140686 * __->__ #140664 * #140269 * #140268 * #135320 * #135318 * #139026 Fix #140546 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
2 parents cd6ca16 + 9779a9b commit 2abadda

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,8 +1268,16 @@ def false_fn(x):
12681268

12691269
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
12701270

1271-
input1 = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
1272-
input2 = (torch.ones(10, 3), torch.ones(6), torch.ones(10, 3))
1271+
input1 = (
1272+
torch.ones(3, 3, device=self.device),
1273+
torch.ones(5, device=self.device),
1274+
torch.ones(3, 3, device=self.device),
1275+
)
1276+
input2 = (
1277+
torch.ones(10, 3, device=self.device),
1278+
torch.ones(6, device=self.device),
1279+
torch.ones(10, 3, device=self.device),
1280+
)
12731281
inputs = (input1, input2)
12741282
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
12751283
self.check_model_with_multiple_inputs(
@@ -1395,6 +1403,9 @@ def forward(self, x):
13951403
self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))
13961404

13971405
def test_zero_grid_with_backed_symbols(self):
1406+
if self.device != GPU_TYPE:
1407+
raise unittest.SkipTest("requires GPU")
1408+
13981409
class Repro(torch.nn.Module):
13991410
def __init__(self) -> None:
14001411
super().__init__()
@@ -1417,7 +1428,7 @@ def forward(self, x, b):
14171428
example_inputs,
14181429
dynamic_shapes=dynamic_shapes,
14191430
)
1420-
aot_inductor_module = AOTIRunnerUtil.load(GPU_TYPE, so_path)
1431+
aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
14211432
aot_inductor_module(*example_inputs)
14221433

14231434
# Re-run where dynamic dim size is 0.
@@ -1920,7 +1931,7 @@ def __init__(self) -> None:
19201931
def forward(self, x):
19211932
return torch.ops.aten.normal_functional.default(x)
19221933

1923-
self.check_model(Model(), (torch.empty(4, 1, 4, 4),))
1934+
self.check_model(Model(), (torch.empty(4, 1, 4, 4, device=self.device),))
19241935

19251936
def test_empty_graph(self):
19261937
class Model(torch.nn.Module):

torch/csrc/inductor/aoti_runtime/model.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,8 @@ inline void parse_device_str(
9999

100100
if (sm[1].str() == "cpu") {
101101
device_type = aoti_torch_device_type_cpu();
102-
#ifdef USE_CUDA
103102
} else if (sm[1].str() == "cuda") {
104103
device_type = aoti_torch_device_type_cuda();
105-
#endif
106104
#ifdef USE_XPU
107105
} else if (sm[1].str() == "xpu") {
108106
device_type = aoti_torch_device_type_xpu();

0 commit comments

Comments (0)