Add device guard for xpu conv on multi device (#153345) · pytorch/pytorch@3bfe071 · GitHub

Commit 3bfe071

pytorchbot and guangyey authored
Add device guard for xpu conv on multi device (#153345)
Add device guard for xpu conv on multi device (#153067)

# Motivation
fixes #153022

The root cause is that the XPU backend registers the convolution op using `m.impl`, which bypasses the device guard logic typically added by the code generation system. This can lead to unexpected behavior if the current device isn't explicitly set.

# Additional Context
Run the following script:

```python
import torch
import torchvision.models as models

torch.manual_seed(0)
model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

device = torch.device('xpu:1')  # 'xpu:0'
model = model.to(device=device, dtype=torch.float16)
data = data.to(device, dtype=torch.float16)

with torch.no_grad():
    ret = model(data)

print(ret)
print("Execution finished")
```

The output is

```bash
-9.2102e-02, -7.7588e-01, -1.4111e+00, -9.2383e-01,  6.4551e-01,
-6.0730e-03, -7.8271e-01, -1.1904e+00, -4.1602e-01,  3.2715e-02,
-4.9854e-01, -6.3623e-01, -8.5107e-01, -6.8555e-01, -9.4434e-01,
-8.8672e-01, -6.7969e-01, -6.9824e-01, -2.8882e-01,  2.0312e+00]],
       device='xpu:1', dtype=torch.float16)
Execution finished
```

Pull Request resolved: #153067
Approved by: https://github.com/albanD, https://github.com/EikanWang

(cherry picked from commit e06a080)

Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
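For readers unfamiliar with the mechanism the commit message refers to: kernels registered directly with `m.impl` are invoked without the device-guard wrapper that the code generator emits for ops declared in `native_functions.yaml`, so such a kernel must set the current device itself. Below is a minimal standalone sketch of that pattern using a hypothetical custom op (`myops::my_conv`); it is not the actual XPU convolution registration, only an illustration of where the explicit `c10::DeviceGuard` goes.

```cpp
#include <ATen/ATen.h>
#include <c10/core/DeviceGuard.h>
#include <torch/library.h>

// Hypothetical kernel, for illustration only (not the real XPU convolution).
at::Tensor my_conv_xpu(const at::Tensor& input, const at::Tensor& weight) {
  // m.impl registers this function directly, so no generated wrapper sets the
  // current device. If `input` lives on xpu:1 while the current device is
  // still xpu:0, downstream allocations and kernel launches may target the
  // wrong device. Constructing the guard from the input's device fixes that;
  // the previous device is restored when the guard goes out of scope.
  c10::DeviceGuard guard(input.device());
  // ... launch the real computation on input.device() ...
  return at::matmul(input, weight);  // stand-in computation
}

// Schema definition and backend-specific registration for the hypothetical op.
TORCH_LIBRARY(myops, m) {
  m.def("my_conv(Tensor input, Tensor weight) -> Tensor");
}

TORCH_LIBRARY_IMPL(myops, XPU, m) {
  m.impl("my_conv", my_conv_xpu);
}
```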
1 parent fa98236 · commit 3bfe071

File tree

2 files changed, +35 -0 lines changed


aten/src/ATen/native/mkldnn/xpu/Conv.cpp

Lines changed: 15 additions & 0 deletions
@@ -401,6 +401,11 @@ Tensor _convolution_out(
     int64_t groups_,
     Attr attr,
     IntArrayRef pad_nd = IntArrayRef({})) {
+  CheckedFrom c = "xpu_convolution";
+  TensorArg input_t{input_r, "input", 1}, weight_t{weight_r, "weight", 2};
+  checkAllSameType(c, {input_t, weight_t});
+  checkAllSameGPU(c, {input_t, weight_t});
+  c10::DeviceGuard device_guard(input_r.device());
   auto ndim = input_r.ndimension();
   TORCH_CHECK(
       3 == ndim || 4 == ndim || 5 == ndim,
@@ -611,6 +616,8 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
     IntArrayRef output_padding,
     int64_t groups,
     std::array<bool, 3> output_mask) {
+  CheckedFrom c = "xpu_convolution_backward";
+  c10::DeviceGuard device_guard(grad_output.device());
   auto ndim = input.ndimension();
   TORCH_CHECK(
       3 == ndim || 4 == ndim || 5 == ndim,
@@ -675,6 +682,10 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
     grad_bias = at::empty({grad_output_.size(1)}, opt);

   if (output_mask[0]) {
+    TensorArg grad_output_t{grad_output, "grad_output", 1},
+        input_t{input, "input", 2};
+    checkAllSameType(c, {grad_output_t, input_t});
+    checkAllSameGPU(c, {grad_output_t, input_t});
     if (input.numel() > 0) {
       if (transposed_) {
         onednn::deconvolution_backward_data(
@@ -701,6 +712,10 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
     }
   }
   if (output_mask[1] || output_mask[2]) {
+    TensorArg grad_output_t{grad_output, "grad_output", 1},
+        weight_t{weight, "weight", 2};
+    checkAllSameType(c, {grad_output_t, weight_t});
+    checkAllSameGPU(c, {grad_output_t, weight_t});
     if (input.numel() > 0) {
       if (transposed_) {
         onednn::deconvolution_backward_weights(
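For reference, the `CheckedFrom`/`TensorArg`/`checkAllSame*` helpers used in the hunks above come from `ATen/TensorUtils.h`; they raise a readable error that names the operator and the mismatched arguments. A minimal standalone sketch of the same pattern (a hypothetical `check_conv_inputs` helper, not the convolution code itself):

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

// Hypothetical pre-flight check mirroring the pattern added in the diff.
void check_conv_inputs(const at::Tensor& input, const at::Tensor& weight) {
  at::CheckedFrom c = "xpu_convolution";         // context string used in the error message
  at::TensorArg input_t{input, "input", 1};      // tensor, argument name, argument position
  at::TensorArg weight_t{weight, "weight", 2};
  at::checkAllSameType(c, {input_t, weight_t});  // same dtype, or a descriptive error
  at::checkAllSameGPU(c, {input_t, weight_t});   // same device index, or a descriptive error
}
```

If, say, `input` lived on `xpu:0` and `weight` on `xpu:1`, `checkAllSameGPU` would throw an error listing both argument names, which makes the multi-device failure mode from #153022 much easier to diagnose than a silent wrong-device launch.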

test/xpu/test_conv.py

Lines changed: 20 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: intel"]

+import copy
 import itertools
 import math
 import unittest
@@ -1191,6 +1192,25 @@ def test_conv2d_no_grad(self, device, dtype):
         output = m(input)
         self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5)

+    @unittest.skipIf(torch.xpu.device_count() < 2, "only one GPU detected")
+    @dtypes(torch.double, torch.float, torch.half)
+    def test_conv2d_on_multi_device(self, dtype):
+        input = torch.randn(3, 256, 224, 224, dtype=dtype, requires_grad=True)
+        conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, dtype=dtype)
+        output_grad = torch.randn(3, 256, 224, 224, dtype=dtype)
+        input_0 = input.to(device="xpu:0")
+        conv_0 = copy.deepcopy(conv).to(device="xpu:0")
+        output_0 = conv_0(input_0)
+        input_1 = input.to(device="xpu:1")
+        conv_1 = copy.deepcopy(conv).to(device="xpu:1")
+        output_1 = conv_1(input_1)
+        self.assertEqual(output_0.cpu(), output_1.cpu())
+        output_grad_0 = output_grad.to(device="xpu:0")
+        output_0.backward(output_grad_0)
+        output_grad_1 = output_grad.to(device="xpu:1")
+        output_1.backward(output_grad_1)
+        self.assertEqual(output_grad_0.cpu(), output_grad_1.cpu())
+
     def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
         input = torch.randn(2, 3, 6, device=device)
         weight = torch.randn(3, 3, 3, device=device)

0 commit comments
