softmax: add device check for xpu with half_to_float by weishi-deng · Pull Request #150278 · pytorch/pytorch
Open

weishi-deng wants to merge 51 commits into pytorch:main from weishi-deng:xpu-softmax

Changes from all commits (51 commits)
58f8b21  softmax: add device check for xpu with half_to_float (weishi-deng, Mar 31, 2025)
54bd692  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
a7435e5  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
b3b8a32  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
9f9854c  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
88c89c4  Merge branch 'main' into xpu-softmax (weishi-deng, Apr 21, 2025)
745f7e6  softmax: add device check for xpu with half_to_float (weishi-deng, Mar 31, 2025)
269cac2  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
8afc462  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
1011f94  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
c78ce25  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
508df2c  Merge branch 'xpu-softmax' of https://github.com/weishi-deng/pytorch … (weishi-deng, May 7, 2025)
16479ba  add ut (weishi-deng, May 7, 2025)
8f74e86  Merge branch 'pytorch:main' into xpu-softmax (weishi-deng, May 9, 2025)
df16602  Update test/xpu/test_softmax.py (weishi-deng, May 9, 2025)
76f7464  update unit test (weishi-deng, May 9, 2025)
0e0b150  update ut (weishi-deng, May 9, 2025)
e001a8b  Update test/test_xpu.py (guangyey, May 9, 2025)
321fdba  Update test/test_xpu.py (guangyey, May 9, 2025)
5e844f2  Update test/test_xpu.py (guangyey, May 9, 2025)
2c1136d  Update test/test_xpu.py (guangyey, May 9, 2025)
698e641  Update test/test_xpu.py (guangyey, May 9, 2025)
0f19bb4  Update test/test_xpu.py (guangyey, May 9, 2025)
a8f156a  Update test_xpu.py (guangyey, May 9, 2025)
71b8d22  Update test/test_xpu.py (guangyey, May 9, 2025)
6ac21b2  Update test_xpu.py (guangyey, May 11, 2025)
653b161  Merge branch 'pytorch:main' into xpu-softmax (weishi-deng, May 12, 2025)
883fc0d  softmax: add device check for xpu with half_to_float (weishi-deng, Mar 31, 2025)
4d28339  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
3895903  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
ff7ea33  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
ded3cee  Update aten/src/ATen/native/SoftMax.cpp (weishi-deng, Apr 8, 2025)
af88dd9  add ut (weishi-deng, May 7, 2025)
53c8b95  Update test/xpu/test_softmax.py (weishi-deng, May 9, 2025)
a9b352b  update unit test (weishi-deng, May 9, 2025)
0cf9205  update ut (weishi-deng, May 9, 2025)
820b210  Update test/test_xpu.py (guangyey, May 9, 2025)
e208d9d  Update test/test_xpu.py (guangyey, May 9, 2025)
164a3b0  Update test/test_xpu.py (guangyey, May 9, 2025)
0f17fbd  Update test/test_xpu.py (guangyey, May 9, 2025)
0b40c23  Update test/test_xpu.py (guangyey, May 9, 2025)
dc6f6ca  Update test/test_xpu.py (guangyey, May 9, 2025)
7c322c6  Update test_xpu.py (guangyey, May 9, 2025)
5ba0f46  Update test/test_xpu.py (guangyey, May 9, 2025)
3bae7b7  Update test_xpu.py (guangyey, May 11, 2025)
400db32  Merge branch 'xpu-softmax' of https://github.com/weishi-deng/pytorch … (weishi-deng, May 12, 2025)
7d1b23d  update ut (weishi-deng, May 14, 2025)
e83021d  Merge branch 'pytorch:main' into xpu-softmax (weishi-deng, May 14, 2025)
6bc08b1  Merge branch 'pytorch:main' into xpu-softmax (weishi-deng, May 15, 2025)
0e5f1a3  Merge branch 'pytorch:main' into xpu-softmax (weishi-deng, May 19, 2025)
0de4614  Merge branch 'pytorch:main' into xpu-softmax (weishi-deng, Jun 26, 2025)
aten/src/ATen/native/SoftMax.cpp (8 changes: 4 additions & 4 deletions)

@@ -411,7 +411,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
auto result = [&]() {
NoNamesGuard guard;
- if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){
+ if ((input_.is_cuda() || input_.is_xpu()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
Collaborator:

@weishi-deng, what's the impact? Could you elaborate on the motivation? Does this change lead to any performance improvement, or is it intended to fix failing cases?

Contributor Author (weishi-deng):

> @weishi-deng, what's the impact? Could you elaborate on the motivation? Does this change lead to any performance improvement, or is it intended to fix failing cases?

This PR lets the softmax op take half-precision input and produce float output without an explicit cast. That functionality has already been added in torch-xpu-ops, so we need to add the device check here to enable it. Otherwise the input has to be cast before the softmax runs, as lines 417-418 show.
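
For context, here is a minimal usage sketch of the call pattern this device check affects (an illustration only; it assumes a PyTorch build with XPU support, and the tensor shape is arbitrary):

```python
import torch
import torch.nn.functional as F

# Half-precision input on an XPU device (illustrative shape).
x = torch.randn(8192, 64, dtype=torch.half, device="xpu")

# With this change, the half -> float promotion happens inside the fused
# _softmax(..., half_to_float=True) path for XPU as well, instead of
# casting `x` to float in ATen before dispatching.
y = F.softmax(x, dim=-1, dtype=torch.float)
assert y.dtype == torch.float32
```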

Collaborator (EikanWang):

@weishi-deng, as we discussed, this should not be framed as a feature improvement. The change is a performance improvement, since torch-xpu-ops provides an optimized implementation that fuses the data-type cast with the softmax. Please share the performance-improvement data in the PR description.

Contributor Author (weishi-deng):

> @weishi-deng, as we discussed, this should not be framed as a feature improvement. The change is a performance improvement, since torch-xpu-ops provides an optimized implementation that fuses the data-type cast with the softmax. Please share the performance-improvement data in the PR description.

Hi @EikanWang, the performance data has been added to the PR description. Please review and comment.
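
One rough way to collect such timings is sketched below. This is only an illustration; the shapes, iteration counts, and timing method are assumptions, not the methodology behind the numbers in the PR description:

```python
import time

import torch
import torch.nn.functional as F


def bench_softmax(shape, iters=100, device="xpu"):
    # Half input, float output: exercises the fused half_to_float path on XPU.
    x = torch.randn(*shape, dtype=torch.half, device=device)
    for _ in range(10):  # warm-up so compilation/launch overhead is excluded
        F.softmax(x, dim=-1, dtype=torch.float)
    torch.xpu.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        F.softmax(x, dim=-1, dtype=torch.float)
    torch.xpu.synchronize()
    return (time.perf_counter() - start) / iters


print(f"avg softmax time: {bench_softmax((8192, 8192)) * 1e3:.3f} ms")
```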

return at::_softmax(input_, dim_, true);
} else {
Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
@@ -428,7 +428,7 @@ Tensor& softmax_out(
std::optional<ScalarType> dtype,
Tensor& output_) {
Tensor output_temp;
- if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
+ if ((input_.is_cuda() || input_.is_xpu()) && input_.scalar_type() == ScalarType::Half &&
dtype == ScalarType::Float) {
if (!output_.is_contiguous()) {
auto options =
@@ -467,7 +467,7 @@ Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional<S
Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
auto result = [&]() {
NoNamesGuard guard;
- if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){
+ if ((input_.is_cuda() || input_.is_xpu()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
return at::_log_softmax(input_, dim_, true);
} else {
Tensor converted = dtype.has_value()? input_.toType(dtype.value()) : input_;
@@ -484,7 +484,7 @@ Tensor& log_softmax_out(
std::optional<ScalarType> dtype,
Tensor& output_) {
Tensor output_temp;
- if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half &&
+ if ((input_.is_cuda() || input_.is_xpu()) && input_.scalar_type() == ScalarType::Half &&
dtype == ScalarType::Float) {
if (!output_.is_contiguous()) {
auto options =
test/test_xpu.py (136 changes: 85 additions & 51 deletions)

@@ -12,8 +12,8 @@
from torch.testing import make_tensor
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_device_type import (
+ dtypes,
instantiate_device_type_tests,
- onlyXPU,
OpDTypes,
ops,
skipXPUIf,
@@ -378,56 +378,6 @@ def test_generator(self):
torch.xpu.set_rng_state(g_state0)
self.assertEqual(2024, torch.xpu.initial_seed())

@onlyXPU
@suppress_warnings
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
def test_compare_cpu(self, device, dtype, op):
def to_cpu(arg):
if isinstance(arg, torch.Tensor):
return arg.to(device="cpu")
return arg

samples = op.reference_inputs(device, dtype)

for sample in samples:
cpu_sample = sample.transform(to_cpu)
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

xpu_results = sample.output_process_fn_grad(xpu_results)
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

# Lower tolerance because we are running this as a `@slowTest`
# Don't want the periodic tests to fail frequently
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)

@onlyXPU
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
def test_non_standard_bool_values(self, device, dtype, op):
# Test boolean values other than 0x00 and 0x01 (gh-54789)
def convert_boolean_tensors(x):
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
return x

# Map False -> 0 and True -> Random value in [2, 255]
true_vals = torch.randint(
2, 255, x.shape, dtype=torch.uint8, device=x.device
)
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
x_int = torch.where(x, true_vals, false_vals)

ret = x_int.view(torch.bool)
self.assertEqual(ret, x)
return ret

for sample in op.sample_inputs(device, dtype):
expect = op(sample.input, *sample.args, **sample.kwargs)

transformed = sample.transform(convert_boolean_tensors)
actual = op(transformed.input, *transformed.args, **transformed.kwargs)

self.assertEqual(expect, actual)

def test_serialization_array_with_storage(self):
x = torch.randn(5, 5).xpu()
y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
@@ -747,5 +697,89 @@ def test_torch_config_for_xpu(self):
self.assertTrue(value.group(1) in ["OFF", "0"])


@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
class TestXpuOps(TestCase):
@dtypes(torch.float16)
def test_softmax_half_to_float(self, device, dtype):
shape = [
[8],
[7, 8],
[8192, 64],
[8192, 8192],
[7, 8, 512],
[7, 8, 11],
[16, 7, 8, 512],
[16, 7, 8, 512, 35],
[117, 7, 9, 513, 35],
]
output_type = torch.float
for i in range(len(shape)):
for j in range(len(shape[i])):
dim = j - 1
x = torch.randn(shape[i], dtype=dtype)
grad = torch.randn(shape[i]).to(output_type)
x_cpu = x.clone().requires_grad_()
y_cpu = torch.nn.functional.softmax(x_cpu, dim, dtype=output_type)
y_cpu.backward(grad.clone())

x_xpu = x.clone().to(device).requires_grad_()
y_xpu = torch.nn.functional.softmax(x_xpu, dim, dtype=output_type)
self.assertEqual(y_xpu.dtype, torch.float32)
y_xpu.backward(grad.clone().to(device))
self.assertEqual(y_cpu, y_xpu.cpu())
self.assertEqual(x_cpu.grad, x_xpu.grad.cpu())

@suppress_warnings
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
def test_compare_cpu(self, device, dtype, op):
def to_cpu(arg):
if isinstance(arg, torch.Tensor):
return arg.to(device="cpu")
return arg

samples = op.reference_inputs(device, dtype)

for sample in samples:
cpu_sample = sample.transform(to_cpu)
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

xpu_results = sample.output_process_fn_grad(xpu_results)
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

# Lower tolerance because we are running this as a `@slowTest`
# Don't want the periodic tests to fail frequently
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)

@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
def test_non_standard_bool_values(self, device, dtype, op):
# Test boolean values other than 0x00 and 0x01 (gh-54789)
def convert_boolean_tensors(x):
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
return x

# Map False -> 0 and True -> Random value in [2, 255]
true_vals = torch.randint(
2, 255, x.shape, dtype=torch.uint8, device=x.device
)
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
x_int = torch.where(x, true_vals, false_vals)

ret = x_int.view(torch.bool)
self.assertEqual(ret, x)
return ret

for sample in op.sample_inputs(device, dtype):
expect = op(sample.input, *sample.args, **sample.kwargs)

transformed = sample.transform(convert_boolean_tensors)
actual = op(transformed.input, *transformed.args, **transformed.kwargs)

self.assertEqual(expect, actual)


instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True)


if __name__ == "__main__":
run_tests()