8000 Add architecture to XPU device property (#138186) · fmo-mt/pytorch@d53fef0 · GitHub
[go: up one dir, main page]

Skip to content
< 8000 script crossorigin="anonymous" defer="defer" type="application/javascript" src="https://github.githubassets.com/assets/sessions-1e75b15ae60a.js">

Commit d53fef0

Browse files
guangyeyfmo-mtauthored andcommitted
Add architecture to XPU device property (pytorch#138186)
# Motivation Add `architecture` to XPU device property. In some cases, low-level application code can use special features or do specific optimizations depending on the device architecture, and this PR enables such applications. Modified from https://github.com/pytorch/pytorch/pull/129675/files Pull Request resolved: pytorch#138186 Approved by: https://github.com/ezyang
1 parent 6c555c3 commit d53fef0

File tree

5 files changed

+47
-4
lines changed

5 files changed

+47
-4
lines changed

c10/xpu/XPUDeviceProp.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ namespace c10::xpu {
152152
* device. */ \
153153
_(subgroup_2d_block_io)
154154

155+
#define AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(_) \
156+
/* the device architecture of this SYCL device. */ \
157+
_(architecture)
158+
155159
#define _DEFINE_SYCL_PROP(ns, property, member) \
156160
ns::property::return_type member;
157161

@@ -166,6 +170,10 @@ namespace c10::xpu {
166170

167171
#define DEFINE_DEVICE_ASPECT(member) bool has_##member;
168172

173+
#define DEFINE_EXP_DEVICE_PROP(property) \
174+
_DEFINE_SYCL_PROP( \
175+
sycl::ext::oneapi::experimental::info::device, property, property)
176+
169177
struct C10_XPU_API DeviceProp {
170178
AT_FORALL_XPU_DEVICE_PROPERTIES(DEFINE_DEVICE_PROP);
171179

@@ -177,12 +185,17 @@ struct C10_XPU_API DeviceProp {
177185
AT_FORALL_XPU_DEVICE_ASPECT(DEFINE_DEVICE_ASPECT);
178186

179187
AT_FORALL_XPU_EXP_CL_ASPECT(DEFINE_DEVICE_ASPECT);
188+
189+
#if SYCL_COMPILER_VERSION >= 20250000
190+
AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(DEFINE_EXP_DEVICE_PROP);
191+
#endif
180192
};
181193

182194
#undef _DEFINE_SYCL_PROP
183195
#undef DEFINE_DEVICE_PROP
184196
#undef DEFINE_PLATFORM_PROP
185197
#undef DEFINE_EXT_DEVICE_PROP
186198
#undef DEFINE_DEVICE_ASPECT
199+
#undef DEFINE_EXP_DEVICE_PROP
187200

188201
} // namespace c10::xpu

c10/xpu/XPUFunctions.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {
9898
device_prop->has_##member = raw_device.ext_oneapi_supports_cl_extension( \
9999
"cl_intel_" #member, &cl_version);
100100

101+
#define ASSIGN_EXP_DEVICE_PROP(property) \
102+
device_prop->property = \
103+
raw_device.get_info<oneapi::experimental::info::device::property>();
104+
101105
AT_FORALL_XPU_DEVICE_PROPERTIES(ASSIGN_DEVICE_PROP);
102106

103107
device_prop->platform_name =
@@ -110,6 +114,11 @@ void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {
110114
// TODO: Remove cl_version since it is unnecessary.
111115
sycl::ext::oneapi::experimental::cl_version cl_version;
112116
AT_FORALL_XPU_EXP_CL_ASPECT(ASSIGN_EXP_CL_ASPECT);
117+
118+
#if SYCL_COMPILER_VERSION >= 20250000
119+
AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(ASSIGN_EXP_DEVICE_PROP);
120+
#endif
121+
113122
return;
114123
}
115124

test/test_xpu.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def test_get_device_properties(self):
127127
device_properties.has_subgroup_2d_block_io,
128128
device_capability["has_subgroup_2d_block_io"],
129129
)
130+
if int(torch.version.xpu) >= 20250000:
131+
self.assertEqual(
132+
device_properties.architecture,
133+
device_capability["architecture"],
134+
)
130135

131136
def test_wrong_xpu_fork(self):
132137
stderr = TestCase.runWithPytorchAPIUsageStderr(

torch/_C/__init__.pyi.in

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,16 +2127,21 @@ class _XpuDeviceProperties:
21272127
vendor: str
21282128
driver_version: str
21292129
version: str
2130-
total_memory: _int
21312130
max_compute_units: _int
21322131
gpu_eu_count: _int
2133-
gpu_subslice_count: _int
21342132
max_work_group_size: _int
21352133
max_num_sub_groups: _int
21362134
sub_group_sizes: List[_int]
21372135
has_fp16: _bool
21382136
has_fp64: _bool
21392137
has_atomic64: _bool
2138+
has_bfloat16_conversions: _bool
2139+
has_subgroup_matrix_multiply_accumulate: _bool
2140+
has_subgroup_matrix_multiply_accumulate_tensor_float32: _bool
2141+
has_subgroup_2d_block_io: _bool
2142+
total_memory: _int
2143+
gpu_subslice_count: _int
2144+
architecture: _int
21402145
type: str
21412146

21422147
# Defined in torch/csrc/xpu/Stream.cpp

torch/csrc/xpu/Module.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,11 @@ static void registerXpuDeviceProperties(PyObject* module) {
306306
auto gpu_subslice_count = [](const DeviceProp& prop) {
307307
return (prop.gpu_eu_count / prop.gpu_eu_count_per_subslice);
308308
};
309+
#if SYCL_COMPILER_VERSION >= 20250000
310+
auto get_device_architecture = [](const DeviceProp& prop) {
311+
return static_cast<int64_t>(prop.architecture);
312+
};
313+
#endif
309314
auto m = py::handle(module).cast<py::module>();
310315

311316
#define DEFINE_READONLY_MEMBER(member) \
@@ -334,6 +339,9 @@ static void registerXpuDeviceProperties(PyObject* module) {
334339
THXP_FORALL_DEVICE_PROPERTIES(DEFINE_READONLY_MEMBER)
335340
.def_readonly("total_memory", &DeviceProp::global_mem_size)
336341
.def_property_readonly("gpu_subslice_count", gpu_subslice_count)
342+
#if SYCL_COMPILER_VERSION >= 20250000
343+
.def_property_readonly("architecture", get_device_architecture)
344+
#endif
337345
.def_property_readonly("type", get_device_type)
338346
.def(
339347
"__repr__",
@@ -343,8 +351,11 @@ static void registerXpuDeviceProperties(PyObject* module) {
343351
<< "', platform_name='" << prop.platform_name << "', type='"
344352
<< get_device_type(prop) << "', driver_version='"
345353
<< prop.driver_version << "', total_memory="
346-
<< prop.global_mem_size / (1024ull * 1024)
347-
<< "MB, max_compute_units=" << prop.max_compute_units
354+
<< prop.global_mem_size / (1024ull * 1024) << "MB"
355+
#if SYCL_COMPILER_VERSION >= 20250000
356+
<< ", architecture=" << get_device_architecture(prop)
357+
#endif
358+
<< ", max_compute_units=" << prop.max_compute_units
348359
<< ", gpu_eu_count=" << prop.gpu_eu_count
349360
<< ", gpu_subslice_count=" << gpu_subslice_count(prop)
350361
<< ", max_work_group_size=" << prop.max_work_group_size

0 commit comments

Comments
 (0)
0