pytorch/pytorch · Commit 0f4578c
Enable XPU path for FlexAttention
1 parent 19d8bba commit 0f4578c
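
As a hedged end-to-end sketch of what this commit enables: once the Inductor kernels below accept "xpu" tensors, FlexAttention can be compiled for Intel GPUs roughly as follows. The import path matches current PyTorch releases and the score_mod is arbitrary; both are illustrative assumptions, not part of this commit.

import torch
from torch.nn.attention.flex_attention import flex_attention

def halve_odd_diagonals(score, b, h, q_idx, kv_idx):
    # Arbitrary score_mod, just to exercise the generated kernel.
    return torch.where((q_idx + kv_idx) % 2 == 0, score, score * 0.5)

if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPU builds only
    q, k, v = (
        torch.randn(1, 8, 1024, 64, device="xpu", dtype=torch.float16)
        for _ in range(3)
    )
    compiled = torch.compile(flex_attention)
    out = compiled(q, k, v, score_mod=halve_odd_diagonals)
    print(out.shape)  # torch.Size([1, 8, 1024, 64])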

File tree: 3 files changed, +28 −9 lines

torch/_inductor/kernel/flex_attention.py

Lines changed: 20 additions & 7 deletions
@@ -687,6 +687,8 @@ class Mode(Enum):
     fwd = auto()
     bwd = auto()
 
+def _get_xpu_config(query, mode: Mode) -> Tuple[int, int, int, int]:
+    return (64, 64, 4, 3)
 
 def _get_rocm_config(query, mode: Mode) -> Tuple[int, int, int, int]:
     dtype = query.get_dtype()
@@ -770,18 +772,29 @@ def _get_nv_config(query, mode: Mode) -> Tuple[int, int, int, int]:
 
 
 def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
-    if torch.version.hip is None:
-        return _get_nv_config(query, mode=Mode.fwd)
+    device_type = query.device.type
+    if device_type == "cuda":
+        if torch.version.hip is None:
+            return _get_nv_config(query, mode=Mode.fwd)
+        else:
+            return _get_rocm_config(query, mode=Mode.fwd)
+    elif device_type == "xpu":
+        return _get_xpu_config(query, mode=Mode.fwd)
     else:
-        return _get_rocm_config(query, mode=Mode.fwd)
+        raise NotImplementedError(f"Unsupported device type: {device_type}")
 
 
 def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
-    if torch.version.hip is None:
-        return _get_nv_config(query, mode=Mode.bwd)
+    device_type = query.device.type
+    if device_type == "cuda":
+        if torch.version.hip is None:
+            return _get_nv_config(query, mode=Mode.bwd)
+        else:
+            return _get_rocm_config(query, mode=Mode.bwd)
+    elif device_type == "xpu":
+        return _get_xpu_config(query, mode=Mode.bwd)
     else:
-        return _get_rocm_config(query, mode=Mode.bwd)
-
+        raise NotImplementedError(f"Unsupported device type: {device_type}")
 
 def create_num_blocks_fake_generator(sparse_indices):
     # The idea here is that we need to create a real tensor with real data
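
Both default-config helpers now branch on query.device.type instead of relying on build flags alone: "cuda" covers NVIDIA and ROCm builds (separated by torch.version.hip), "xpu" picks the new single conservative tuple, and anything else raises. A quick, hedged illustration of how device types resolve at runtime; plain tensors stand in for the Inductor query node here, and the comments map each result to the branches above:

import torch

q = torch.randn(2, 8, 128, 64)           # CPU tensor
print(q.device.type)                      # "cpu"  -> would hit NotImplementedError
if torch.cuda.is_available():
    print(q.to("cuda").device.type)       # "cuda" on NVIDIA and ROCm builds alike;
                                          # torch.version.hip tells the two apart
if hasattr(torch, "xpu") and torch.xpu.is_available():
    print(q.to("xpu").device.type)        # "xpu"  -> _get_xpu_config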

torch/_inductor/kernel/flex_decoding.py

Lines changed: 7 additions & 2 deletions
@@ -298,7 +298,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
 
 
 def get_split_k(B: int, H: int, Mk: int) -> int:
-    num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count
+    if torch.xpu.is_available():
+        num_SM = torch.xpu.get_device_properties("xpu").gpu_subslice_count
+    else:
+        num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count
     bh = max(B * H, 1)  # NOTE: Handle B*h=0 case
     assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
     split_k = num_SM // bh * 2  # Each SM should at least get one block.
@@ -312,8 +315,10 @@ def get_split_k(B: int, H: int, Mk: int) -> int:
 def _get_decoding_default_config(key) -> Tuple[int, int, int]:
     dtype = key.get_dtype()
     head_dim = key.get_size()[-1]
-    sm_version = torch.cuda.get_device_capability()
     default_config = (64, 2, 1)
+    if key.get_device().type == "xpu":
+        return default_config
+    sm_version = torch.cuda.get_device_capability()
     if sm_version >= (9, 0):
         if head_dim > 128 and dtype == torch.float32:
             return default_config
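
Two things change here: get_split_k sizes split-k against the Intel GPU's subslice count instead of CUDA SMs, and _get_decoding_default_config returns the generic (64, 2, 1) early on XPU so the CUDA-only torch.cuda.get_device_capability() call is never reached. Below is a standalone, runnable sketch of the visible split-k arithmetic; the clamping that follows in get_split_k lies outside this hunk and is omitted, and the CPU fallback count is an assumption so the sketch runs anywhere:

from typing import Optional

import torch

def get_split_k_sketch(B: int, H: int, num_SM: Optional[int] = None) -> int:
    if num_SM is None:
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            # Intel GPUs report subslices where CUDA reports SMs.
            num_SM = torch.xpu.get_device_properties("xpu").gpu_subslice_count
        elif torch.cuda.is_available():
            num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count
        else:
            num_SM = 32  # assumed fallback so this runs on CPU-only machines
    bh = max(B * H, 1)       # handle B*H == 0
    return num_SM // bh * 2  # each SM/subslice should get at least one block

print(get_split_k_sketch(B=2, H=8, num_SM=64))  # -> 8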

torch/_ops.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
     DispatchKey.BackendSelect,
     DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
     DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
+    DispatchKey.AutocastXPU,
 ]
 
 
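AutocastXPU joins AutocastCPU and AutocastCUDA in this dispatch-key list, so operator resolution treats XPU autocast the same way as the existing backends. A hedged usage sketch of the autocast context this unblocks, guarded so it is a no-op without an Intel GPU build:

import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    x = torch.randn(8, 8, device="xpu")
    with torch.autocast(device_type="xpu", dtype=torch.float16):
        y = x @ x  # matmul autocasts to float16 inside the context
    print(y.dtype)  # torch.float16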

0 commit comments