xpu: support sycl with torch.utils.cpp_extension APIs by dvrogozh · Pull Request #132945 · pytorch/pytorch

Closed · wants to merge 13 commits
1 change: 1 addition & 0 deletions docs/source/cpp_extension.rst
@@ -4,6 +4,7 @@ torch.utils.cpp_extension
.. currentmodule:: torch.utils.cpp_extension
.. autofunction:: CppExtension
.. autofunction:: CUDAExtension
.. autofunction:: SyclExtension
.. autofunction:: BuildExtension
.. autofunction:: load
.. autofunction:: load_inline
10 changes: 10 additions & 0 deletions test/cpp_extensions/setup.py
@@ -11,6 +11,7 @@
    CUDA_HOME,
    CUDAExtension,
    ROCM_HOME,
    SyclExtension,
)


@@ -69,6 +70,15 @@
    )
    ext_modules.append(extension)

if torch.xpu.is_available() and USE_NINJA:
    extension = SyclExtension(
        "torch_test_cpp_extension.sycl",
        ["xpu_extension.sycl"],
        extra_compile_args={"cxx": CXX_FLAGS, "sycl": ["-O2"]},
    )
    ext_modules.append(extension)


# todo(mkozuki): Figure out the root cause
if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
    # malfet: One should not assume that PyTorch re-exports CUDA dependencies
63 changes: 63 additions & 0 deletions test/cpp_extensions/xpu_extension.sycl
Collaborator:
.sycl is not documented by the SYCL spec or by the Intel SYCL compiler implementation. For now, I don't think it is the proper time to deliver this usage to the community. We are following up on the feature with the compiler team. @EikanWang, please correct me. BTW, it is a good example to show the compiler team.

Contributor Author:
> .sycl is not documented by SYCL spec and Intel SYCL compiler implementation.

Actually, I used a documented feature to support files named with the .sycl extension: while this extension is not automatically recognized by the compiler, you can pass the -x <lang> option to specify the type of the file being compiled. I used -x c++ file.sycl.

$ icpx --help | grep "\-x "
  -x <language>           Treat subsequent input files as having type <language>

I agree that we should follow up with the DPC++ compiler team to ask for automatic support of the .sycl extension; I will file an issue for that tomorrow. But I believe we can proceed in the meantime with the approach described above.
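
For illustration, a standalone compile line following this approach could look like the sketch below (the exact flag set the cpp_extension build emits will differ; the point is only -fsycl plus -x c++ in front of the .sycl input):

$ icpx -fsycl -x c++ -c xpu_extension.sycl -o xpu_extension.o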

Contributor Author:

Filed intel/llvm#15015 with a request for the .sycl extension.

Collaborator:

We'd prefer to leave the flexibility to the SYCL compiler community to provide the solution. If the SYCL compiler community decides to use a file extension to support this case, it is their freedom to decide what the file extension for SYCL source files should be.

Contributor Author:

Here is a summary of the discussions with our compiler team and the compiler community. At the moment they oppose introducing a .sycl file extension into the compiler. They also encourage handling SYCL/C++ compilation differences at the build-system level, using custom file extensions agreed on by the build system or other methods to logically separate sources. That discussion needs to happen here for PyTorch. A similar discussion is ongoing around SYCL support in CMake.

Overall, for the PyTorch cpp_extension feature we have 2 options to proceed (sketched below):

  • Option 1. We agree on and introduce a custom file extension for SYCL sources in PyTorch. That's what this PR is currently doing. The proposal is to adopt .sycl as a file extension specific to the PyTorch ecosystem and further influence other communities to align on that.
  • Option 2. As an alternative, we can introduce another logical separation for SYCL sources. In particular we can:
    • Have a sycl_sources = [ ... ] argument take SYCL sources in torch.utils.cpp_extension.load (this would be a new addition; CUDA does not have that)
    • torch.utils.cpp_extension.load_inline already has cuda_sources, and this PR introduces sycl_sources alongside it
    • Have both sources = [ ... ] and sycl_sources = [ ... ] arguments on the SyclExtension class (that would be a difference vs. how the CUDAExtension class is defined)

Currently the PR follows Option 1. Please let me know your opinions on which option is better.
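
To make the two options concrete, here is a rough sketch of how user-facing code would differ (a hypothetical illustration; names and parameters are not a final API):

# Option 1: SYCL sources identified by the .sycl file extension.
module = torch.utils.cpp_extension.load(
    name="my_ext",
    sources=["bindings.cpp", "kernels.sycl"],
)

# Option 2 (hypothetical): SYCL sources passed via a dedicated
# sycl_sources argument, keeping a regular .cpp extension.
module = torch.utils.cpp_extension.load(
    name="my_ext",
    sources=["bindings.cpp"],
    sycl_sources=["kernels.cpp"],
)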

@@ -0,0 +1,63 @@
#include <c10/xpu/XPUStream.h>
#include <torch/extension.h>
#include <sycl/sycl.hpp>

void sigmoid_add_kernel(const float* x,
                        const float* y,
                        float* output,
                        const int size,
                        const sycl::nd_item<3> &item_ct1) {
  const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
      item_ct1.get_local_id(2);
  if (index < size) {
    const float sigmoid_x = 1.0f / (1.0f + sycl::native::exp(-x[index]));
    const float sigmoid_y = 1.0f / (1.0f + sycl::native::exp(-y[index]));
    output[index] = sigmoid_x + sigmoid_y;
  }
}

class SigmoidAddKernel {
 public:
  void operator()(const sycl::nd_item<3> &item_ct1) const {
    sigmoid_add_kernel(x, y, output, size, item_ct1);
  }
  SigmoidAddKernel(const float* _x, const float* _y, float* _output, int _size):
    x(_x),
    y(_y),
    output(_output),
    size(_size)
  {}
 private:
  const float* x;
  const float* y;
  float* output;
  int size;
};

void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
  SigmoidAddKernel krn(x, y, output, size);
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  queue.submit([&](sycl::handler &cgh) {
    cgh.parallel_for<SigmoidAddKernel>(
        sycl::nd_range<3>(
            sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
            sycl::range<3>(1, 1, threads)),
        krn);
  });
}

torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
  TORCH_CHECK(x.device().is_xpu(), "x must be an XPU tensor");
  TORCH_CHECK(y.device().is_xpu(), "y must be an XPU tensor");
  auto output = torch::zeros_like(x);
  sigmoid_add_xpu(
      x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}
17 changes: 17 additions & 0 deletions test/test_cpp_extensions_aot.py
@@ -18,6 +18,7 @@
    IS_WINDOWS,
    shell,
    skipIfTorchDynamo,
    TEST_XPU,
    xfailIfTorchDynamo,
)

@@ -113,6 +114,22 @@ def test_mps_extension(self):

        self.assertEqual(cpu_output, mps_output.to("cpu"))

    @unittest.skipIf(not TEST_XPU, "XPU not found")
    @unittest.skipIf(
        os.getenv("USE_NINJA", "0") == "0",
        "sycl extension requires ninja to build",
    )
    def test_sycl_extension(self):
        import torch_test_cpp_extension.sycl as sycl_extension

        x = torch.zeros(100, device="xpu", dtype=torch.float32)
        y = torch.zeros(100, device="xpu", dtype=torch.float32)

        z = sycl_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
96 changes: 95 additions & 1 deletion test/test_cpp_extensions_jit.py
@@ -17,7 +17,7 @@
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_utils import gradcheck
from torch.testing._internal.common_utils import gradcheck, TEST_XPU
from torch.utils.cpp_extension import (
    _TORCH_PATH,
    check_compiler_is_gcc,
@@ -116,6 +116,26 @@ def test_jit_cuda_extension(self):
        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_XPU, "XPU not found")
    def test_jit_xpu_extension(self):
        # NOTE: The name of the extension must equal the name of the module.
        module = torch.utils.cpp_extension.load(
            name="torch_test_xpu_extension",
            sources=[
                "cpp_extensions/xpu_extension.sycl",
            ],
            verbose=True,
            keep_intermediates=False,
        )

        x = torch.zeros(100, device="xpu", dtype=torch.float32)
        y = torch.zeros(100, device="xpu", dtype=torch.float32)

        z = module.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not TEST_MPS, "MPS not found")
    def test_mps_extension(self):
        module = torch.utils.cpp_extension.load(
@@ -442,6 +462,80 @@ def test_inline_jit_compile_custom_op_cuda(self):
        z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())

    @unittest.skipIf(not TEST_XPU, "XPU not found")
    def test_inline_jit_compile_extension_xpu(self):
        sycl_source = """
        #include <c10/xpu/XPUStream.h>

        class CosAddKernel {
         public:
          void operator()(const sycl::nd_item<3> &item_ct1) const {
            const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                item_ct1.get_local_id(2);
            if (index < size) {
              output[index] = cosf(x[index]) + cosf(y[index]);
            }
          }
          CosAddKernel(const float* _x, const float* _y, float* _output, int _size):
            x(_x),
            y(_y),
            output(_output),
            size(_size)
          {}
         private:
          const float* x;
          const float* y;
          float* output;
          int size;
        };

        void cos_add_kernel(
            const float* x,
            const float* y,
            float* output,
            int size) {
          CosAddKernel krn(x, y, output, size);
          const int threads = 1024;
          const int blocks = (size + threads - 1) / threads;

          sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
          queue.submit([&](sycl::handler &cgh) {
            cgh.parallel_for<CosAddKernel>(
                sycl::nd_range<3>(
                    sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
                    sycl::range<3>(1, 1, threads)),
                krn);
          });
        }

        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
          auto output = torch::zeros_like(x);
          cos_add_kernel(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
          return output;
        }
        """

        # Here, the C++ source need only declare the function signature.
        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"

        module = torch.utils.cpp_extension.load_inline(
            name="inline_jit_extension_xpu",
            cpp_sources=cpp_source,
            sycl_sources=sycl_source,
            functions=["cos_add"],
            verbose=True,
        )

        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")

        x = torch.randn(4, 4, device="xpu", dtype=torch.float32)
        y = torch.randn(4, 4, device="xpu", dtype=torch.float32)

        z = module.cos_add(x, y)
        self.assertEqual(z, x.cos() + y.cos())

    def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
        with self.assertRaises(ValueError):
            torch.utils.cpp_extension.load_inline(
2 changes: 2 additions & 0 deletions torch/utils/_cpp_extension_versioner.py
@@ -40,13 +40,15 @@ def bump_version_if_changed(self,
                                build_arguments,
                                build_directory,
                                with_cuda,
                                with_sycl,
                                is_python_module,
                                is_standalone):
        hash_value = 0
        hash_value = hash_source_files(hash_value, source_files)
        hash_value = hash_build_arguments(hash_value, build_arguments)
        hash_value = update_hash(hash_value, build_directory)
        hash_value = update_hash(hash_value, with_cuda)
        hash_value = update_hash(hash_value, with_sycl)
        hash_value = update_hash(hash_value, is_python_module)
        hash_value = update_hash(hash_value, is_standalone)
