[Break XPU][Inductor] Generalize device-bias code and fix test_graph_partition for XPU · pytorch/pytorch@762724f

Commit 762724f

etaf authored and pytorchmergebot committed
[Break XPU][Inductor] Generalize device-bias code and fix test_graph_partition for XPU (#148178)
This PR generalizes the device-bias code introduced by #147038 and aligns the behavior between XPU and CUDA on the add + mm + pointwise pattern (for XPU, from addmm + pointwise to mm + fused_add_pointwise), which fixes the failing test case `test_graph_partition` on XPU.

Pull Request resolved: #148178
Approved by: https://github.com/benjaminglass1, https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #148155
1 parent ab78bf5 · commit 762724f
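A minimal sketch of the add + mm + pointwise pattern the message describes, with hypothetical shapes and names (an illustration, not code from the PR): Inductor's post-grad pass can rewrite addmm(inp, a, b) into mm(a, b) plus a trailing add when the matmul result is only consumed by pointwise ops, so the add fuses into the pointwise kernel.

import torch

def f(inp, a, b):
    # Matched as addmm(inp, a, b) = inp + a @ b ...
    y = torch.addmm(inp, a, b)
    # ... followed by a pointwise consumer. When unfusing is preferred,
    # Inductor emits mm(a, b) and fuses `+ inp` with relu into one kernel.
    return torch.relu(y)

device = "cuda"  # or "xpu"; this commit aligns both backends on the pattern
inp = torch.randn(20, device=device)      # bias, broadcast over rows
a = torch.randn(10, 15, device=device)
b = torch.randn(15, 20, device=device)
out = torch.compile(f)(inp, a, b)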

File tree

test/inductor/test_kernel_benchmark.py
test/inductor/test_pattern_matcher.py
test/inductor/test_torchinductor.py
torch/_inductor/fx_passes/post_grad.py

4 files changed: +10 −17 lines changed

test/inductor/test_kernel_benchmark.py (−6 lines)

@@ -362,12 +362,6 @@ def f(a, b, c):
         # num_gb = (1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000) * 2/ 1e9
         #        = 0.008
         num_gb = "0.008"
-        if GPU_TYPE == "xpu":
-            # In XPU backend, mm + add + add will be fused as admm + add
-            # And CUDA prefer not fuse add + mm, please check in function
-            # `should_prefer_unfused_addmm` in torch/_inductor/fx_passes/post_grad.py
-            num_gb = "0.006"
-
         self.check_bandwidth(compiled_module, num_gb)

     def test_mm_slice_add_bandwidth_computation_2(self):
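As a sanity check on the `num_gb` comment in the hunk above: the counted traffic is 1,000,000 + 2,000,000 + 1,000,000 = 4,000,000 element accesses, and the trailing `* 2` is presumably 2 bytes per element (e.g. fp16), giving 8,000,000 bytes = 0.008 GB:

elems = 1000 * 1000 + 2 * 1000 * 1000 + 1000 * 1000  # 4,000,000 element accesses
num_gb = elems * 2 / 1e9                             # 2 bytes/element (assumed) -> 0.008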

test/inductor/test_pattern_matcher.py (−1 line)

@@ -1127,7 +1127,6 @@ def fn(a, b):
         self.assertIn("return (buf0, )", code[0])
         self.assertNotIn("async_compile.cpp", code[0])

-    @expectedFailureXPU
     def test_unfuse_bias_addmm(self):
         args = [
             torch.randn(20, device=GPU_TYPE),

test/inductor/test_torchinductor.py (+8 −8 lines)

@@ -14024,7 +14024,7 @@ def fn(x):
             return x.sin()

         fn_c = torch.compile(fn)
-        x = torch.rand(16, device="cuda")
+        x = torch.rand(16, device=GPU_TYPE)

         _, code = run_and_get_code(fn_c, x)

@@ -14039,7 +14039,7 @@ def f(x, y):
             y1 = y + 1
             y_cpu = y1.cpu() + 1
             z = x @ y
-            return x1 + y1 + z + y_cpu.cuda()
+            return x1 + y1 + z + y_cpu.to(GPU_TYPE)

         x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)]
         x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]

@@ -14065,7 +14065,7 @@ def f(x, y):
             y1 = y + 1
             y_cpu = y1.cpu() + 1
             z = x @ y
-            return x1 + y1 + z + y_cpu.cuda()
+            return x1 + y1 + z + y_cpu.to(GPU_TYPE)

         def g(x):
             return x + 1

@@ -14106,7 +14106,7 @@ def f(x, y):
             y1 = y + 1
             y_cpu = y1.cpu() + 1
             z = x @ y
-            return x1 + y1 + z + y_cpu.cuda()
+            return x1 + y1 + z + y_cpu.to(GPU_TYPE)

         f_compiled = torch.compile(f)
         x, y = torch.ones(3, 3, device=self.device), torch.randn(

@@ -14128,7 +14128,7 @@ def f(x, y):
             y1 = y + 1
             y_cpu = y1.cpu() + 1
             z = x @ y
-            return x1 + y1 + z + y_cpu.cuda()
+            return x1 + y1 + z + y_cpu.to(GPU_TYPE)

         f_compiled = torch.compile(f)
         x, y = torch.ones(3, 3, device=self.device), torch.randn(

@@ -14149,11 +14149,11 @@ def f(x, y):
             y1 = y + 1
             y_cpu = y1.cpu() + 1
             z = x1 + y1 + x @ y
-            u = (y_cpu.cuda() + 2) @ y + 3
+            u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3
             u_cpu = u.cpu() + 2
-            return z + u_cpu.cuda()
+            return z + u_cpu.to(GPU_TYPE)

-        x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
+        x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)]
         x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
         eager_out = f(x, y)
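The mechanical change in all of these hunks: `Tensor.cuda()` hard-codes CUDA, while `Tensor.to(device)` accepts any device string, so parameterizing on `GPU_TYPE` (the test suite's device-string constant, "cuda" or "xpu") lets the same test exercise either backend. A minimal sketch, using a simplified stand-in for the test helper (an assumption, not the suite's actual definition):

import torch

# Simplified stand-in for the test suite's GPU_TYPE constant (assumption).
GPU_TYPE = "xpu" if getattr(torch, "xpu", None) and torch.xpu.is_available() else "cuda"

t = torch.ones(2, 2)
t_gpu = t.to(GPU_TYPE)  # backend-agnostic; t.cuda() would hard-code CUDA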

torch/_inductor/fx_passes/post_grad.py (+2 −2 lines)

@@ -44,7 +44,7 @@
     register_graph_pattern,
     stable_topological_sort,
 )
-from ..utils import decode_device, get_gpu_type, is_pointwise_use
+from ..utils import decode_device, get_gpu_type, is_gpu, is_pointwise_use
 from ..virtualized import V
 from .b2b_gemm import B2B_GEMM_PASS
 from .ddp_fusion import fuse_ddp_communication

@@ -888,7 +888,7 @@ def view_to_reshape(gm):


 def should_prefer_unfused_addmm(match):
     inp = match.kwargs["inp"]
-    if not inp.meta["val"].is_cuda:
+    if not is_gpu(inp.meta["val"].device.type):
         return False

     output = match.output_node()