xpu: support sycl with torch.utils.cpp_extension APIs (#132945) · pytorch/pytorch@d27ecf8 · GitHub

Commit d27ecf8

dvrogozh authored and malfet committed
xpu: support sycl with torch.utils.cpp_extension APIs (#132945)
This patch adds support for building SYCL kernels via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline`, and the (new) `SyclExtension` class. Files with the `.sycl` extension are treated as containing SYCL kernels and are compiled with `icpx` (Intel's DPC++ SYCL compiler); files with other extensions, such as `.cpp` and `.cu`, are handled as before. The API supports building SYCL sources together with other file types into a single extension.

Note that the `.sycl` file extension is a PyTorch convention for files containing SYCL code, which I propose to adopt. We followed up with the compiler team about introducing such a file extension in the compiler, but they are opposed to it. At the same time, discussion is ongoing about a SYCL file extension and about adding SYCL language support to tools such as CMake; CMake may eventually introduce a file extension convention for SYCL as well. I hope we can further encourage the CMake and compiler communities to adopt the `.sycl` file extension more broadly.

By default, SYCL kernels are compiled for all Intel GPU devices that PyTorch's native ATen SYCL kernels are compiled for (currently `pvc,xe-lpg`). This behavior can be overridden by setting the `TORCH_XPU_ARCH_LIST` environment variable to a comma-separated list of devices to compile for.

Fixes: #132944

CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5

Pull Request resolved: #132945
Approved by: https://github.com/albanD, https://github.com/guangyey, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent dd5d0ea commit d27ecf8
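
For orientation, a minimal usage sketch of the APIs this commit adds (not part of the commit itself). The extension name "my_sycl_ext" and the source file "my_kernels.sycl" are placeholders, and the sketch assumes an XPU-enabled PyTorch build with ninja available for the JIT path.

# Ahead-of-time build: a hypothetical setup.py using the new SyclExtension class.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, SyclExtension

setup(
    name="my_sycl_ext",
    ext_modules=[
        SyclExtension(
            "my_sycl_ext",
            ["my_kernels.sycl"],  # .sycl sources are compiled with icpx
            extra_compile_args={"cxx": ["-O2"], "sycl": ["-O2"]},
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

The same source can also be built just in time; setting TORCH_XPU_ARCH_LIST narrows the device targets as described above.

import os

import torch.utils.cpp_extension

# Optional: compile only for selected devices instead of the default list
# (the devices the native ATen SYCL kernels were built for, currently pvc,xe-lpg).
os.environ["TORCH_XPU_ARCH_LIST"] = "pvc"

module = torch.utils.cpp_extension.load(
    name="my_sycl_ext",
    sources=["my_kernels.sycl"],
    verbose=True,
)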

File tree

7 files changed: +521 -29 lines


docs/source/cpp_extension.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ torch.utils.cpp_extension
 .. currentmodule:: torch.utils.cpp_extension
 .. autofunction:: CppExtension
 .. autofunction:: CUDAExtension
+.. autofunction:: SyclExtension
 .. autofunction:: BuildExtension
 .. autofunction:: load
 .. autofunction:: load_inline

test/cpp_extensions/setup.py

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,7 @@
     CUDA_HOME,
     CUDAExtension,
     ROCM_HOME,
+    SyclExtension,
 )


@@ -69,6 +70,15 @@
     )
     ext_modules.append(extension)

+if torch.xpu.is_available() and USE_NINJA:
+    extension = SyclExtension(
+        "torch_test_cpp_extension.sycl",
+        ["xpu_extension.sycl"],
+        extra_compile_args={"cxx": CXX_FLAGS, "sycl": ["-O2"]},
+    )
+    ext_modules.append(extension)
+
+
 # todo(mkozuki): Figure out the root cause
 if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
     # malfet: One should not assume that PyTorch re-exports CUDA dependencies
test/cpp_extensions/xpu_extension.sycl

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+#include <c10/xpu/XPUStream.h>
+#include <torch/extension.h>
+#include <sycl/sycl.hpp>
+
+void sigmoid_add_kernel(const float* x,
+                        const float* y,
+                        float* output,
+                        const int size,
+                        const sycl::nd_item<3> &item_ct1) {
+  const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                    item_ct1.get_local_id(2);
+  if (index < size) {
+    const float sigmoid_x = 1.0f / (1.0f + sycl::native::exp(-x[index]));
+    const float sigmoid_y = 1.0f / (1.0f + sycl::native::exp(-y[index]));
+    output[index] = sigmoid_x + sigmoid_y;
+  }
+}
+
+class SigmoidAddKernel {
+public:
+  void operator()(const sycl::nd_item<3> &item_ct1) const {
+    sigmoid_add_kernel(x, y, output, size, item_ct1);
+  }
+  SigmoidAddKernel(const float* _x, const float* _y, float* _output, int _size):
+    x(_x),
+    y(_y),
+    output(_output),
+    size(_size)
+  {}
+private:
+  const float* x;
+  const float* y;
+  float* output;
+  int size;
+};
+
+void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
+  SigmoidAddKernel krn(x, y, output, size);
+  const int threads = 1024;
+  const int blocks = (size + threads - 1) / threads;
+
+  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
+  queue.submit([&](sycl::handler &cgh) {
+    cgh.parallel_for<SigmoidAddKernel>(
+        sycl::nd_range<3>(
+            sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
+            sycl::range<3>(1, 1, threads)),
+        krn);
+  });
+}
+
+torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
+  TORCH_CHECK(x.device().is_xpu(), "x must be a XPU tensor");
+  TORCH_CHECK(y.device().is_xpu(), "y must be a XPU tensor");
+  auto output = torch::zeros_like(x);
+  sigmoid_add_xpu(
+      x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
+  return output;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
+}

test/test_cpp_extensions_aot.py

Lines changed: 17 additions & 0 deletions
@@ -18,6 +18,7 @@
     IS_WINDOWS,
     shell,
     skipIfTorchDynamo,
+    TEST_XPU,
     xfailIfTorchDynamo,
 )

@@ -113,6 +114,22 @@ def test_mps_extension(self):

         self.assertEqual(cpu_output, mps_output.to("cpu"))

+    @unittest.skipIf(not TEST_XPU, "XPU not found")
+    @unittest.skipIf(
+        os.getenv("USE_NINJA", "0") == "0",
+        "sycl extension requires ninja to build",
+    )
+    def test_sycl_extension(self):
+        import torch_test_cpp_extension.sycl as sycl_extension
+
+        x = torch.zeros(100, device="xpu", dtype=torch.float32)
+        y = torch.zeros(100, device="xpu", dtype=torch.float32)
+
+        z = sycl_extension.sigmoid_add(x, y).cpu()
+
+        # 2 * sigmoid(0) = 2 * 0.5 = 1
+        self.assertEqual(z, torch.ones_like(z))
+
     @common.skipIfRocm
     @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
     @unittest.skipIf(not TEST_CUDA, "CUDA not found")

test/test_cpp_extensions_jit.py

Lines changed: 95 additions & 1 deletion
@@ -17,7 +17,7 @@
 import torch.testing._internal.common_utils as common
 import torch.utils.cpp_extension
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
-from torch.testing._internal.common_utils import gradcheck
+from torch.testing._internal.common_utils import gradcheck, TEST_XPU
 from torch.utils.cpp_extension import (
     _TORCH_PATH,
     check_compiler_is_gcc,

@@ -116,6 +116,26 @@ def test_jit_cuda_extension(self):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))

+    @unittest.skipIf(not (TEST_XPU), "XPU not found")
+    def test_jit_xpu_extension(self):
+        # NOTE: The name of the extension must equal the name of the module.
+        module = torch.utils.cpp_extension.load(
+            name="torch_test_xpu_extension",
+            sources=[
+                "cpp_extensions/xpu_extension.sycl",
+            ],
+            verbose=True,
+            keep_intermediates=False,
+        )
+
+        x = torch.zeros(100, device="xpu", dtype=torch.float32)
+        y = torch.zeros(100, device="xpu", dtype=torch.float32)
+
+        z = module.sigmoid_add(x, y).cpu()
+
+        # 2 * sigmoid(0) = 2 * 0.5 = 1
+        self.assertEqual(z, torch.ones_like(z))
+
     @unittest.skipIf(not TEST_MPS, "MPS not found")
     def test_mps_extension(self):
         module = torch.utils.cpp_extension.load(

@@ -442,6 +462,80 @@ def test_inline_jit_compile_custom_op_cuda(self):
         z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
         self.assertEqual(z, x.cos() + y.cos())

+    @unittest.skipIf(not TEST_XPU, "XPU not found")
+    def test_inline_jit_compile_extension_xpu(self):
+        sycl_source = """
+        #include <c10/xpu/XPUStream.h>
+
+        class CosAddKernel {
+        public:
+          void operator()(const sycl::nd_item<3> &item_ct1) const {
+            const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                  item_ct1.get_local_id(2);
+            if (index < size) {
+              output[index] = cosf(x[index]) + cosf(y[index]);
+            }
+          }
+          CosAddKernel(const float* _x, const float* _y, float* _output, int _size):
+            x(_x),
+            y(_y),
+            output(_output),
+            size(_size)
+          {}
+        private:
+          const float* x;
+          const float* y;
+          float* output;
+          int size;
+        };
+
+        void cos_add_kernel(
+            const float* x,
+            const float* y,
+            float* output,
+            int size) {
+          CosAddKernel krn(x, y, output, size);
+          const int threads = 1024;
+          const int blocks = (size + threads - 1) / threads;
+
+          sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
+          queue.submit([&](sycl::handler &cgh) {
+            cgh.parallel_for<CosAddKernel>(
+                sycl::nd_range<3>(
+                    sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
+                    sycl::range<3>(1, 1, threads)),
+                krn);
+          });
+        }
+
+        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
+          auto output = torch::zeros_like(x);
+          const int threads = 1024;
+          const int blocks = (output.numel() + threads - 1) / threads;
+          cos_add_kernel(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
+          return output;
+        }
+        """
+
+        # Here, the C++ source need only declare the function signature.
+        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"
+
+        module = torch.utils.cpp_extension.load_inline(
+            name="inline_jit_extension_xpu",
+            cpp_sources=cpp_source,
+            sycl_sources=sycl_source,
+            functions=["cos_add"],
+            verbose=True,
+        )
+
+        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")
+
+        x = torch.randn(4, 4, device="xpu", dtype=torch.float32)
+        y = torch.randn(4, 4, device="xpu", dtype=torch.float32)
+
+        z = module.cos_add(x, y)
+        self.assertEqual(z, x.cos() + y.cos())
+
     def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
         with self.assertRaises(ValueError):
             torch.utils.cpp_extension.load_inline(

torch/utils/_cpp_extension_versioner.py

Lines changed: 2 additions & 0 deletions
@@ -40,13 +40,15 @@ def bump_version_if_changed(self,
                                 build_arguments,
                                 build_directory,
                                 with_cuda,
+                                with_sycl,
                                 is_python_module,
                                 is_standalone):
         hash_value = 0
         hash_value = hash_source_files(hash_value, source_files)
         hash_value = hash_build_arguments(hash_value, build_arguments)
         hash_value = update_hash(hash_value, build_directory)
         hash_value = update_hash(hash_value, with_cuda)
+        hash_value = update_hash(hash_value, with_sycl)
         hash_value = update_hash(hash_value, is_python_module)
         hash_value = update_hash(hash_value, is_standalone)
