Fix #150779 · pytorch/pytorch@178ff0e · GitHub
Commit 178ff0e

Fix #150779

ghstack-source-id: a7a0ede
Pull Request resolved: #151315

1 parent 032ef48 · commit 178ff0e
File tree: 4 files changed, +40 −19 lines

test/distributed/test_functional_api.py
22 additions & 10 deletions

@@ -3,13 +3,15 @@
 import sys
 import unittest
 from functools import partial, wraps
+from unittest.mock import patch

 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.distributed.distributed_c10d as c10d
 import torch.distributed.tensor as dt
 from functorch import make_fx
+from torch._dynamo.metrics_context import MetricsContext
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests

@@ -31,7 +33,6 @@
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
     TestCase,

@@ -90,7 +91,7 @@ def new_subgroups(group_size: int, pg_tag=None):
     return cur_subgroup, subgroups


-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestExpand(MultiThreadedTestCase):
     @property
     def world_size(self):

@@ -180,7 +181,7 @@ def test_expand_device_mesh_tuple(self):
         self.assertEqual(2, group_size)


-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestPgTag(MultiThreadedTestCase):
     @property
     def world_size(self):

@@ -257,7 +258,7 @@ def test_find_root_pg(self):


 @instantiate_parametrized_tests
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestTraceableCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):

@@ -403,7 +404,7 @@ def test_all_reduce(self):
         self.assertEqual(x.size(), out.size())


-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestGradCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):

@@ -656,7 +657,7 @@ def test_permute_tensor_with_sub_group(self, device):


 @instantiate_parametrized_tests
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestFunctionalAutograd(MultiThreadedTestCase):
     def setUp(self):
         super().setUp()

@@ -666,6 +667,13 @@ def setUp(self):
     def world_size(self):
         return 2

+    # `compilation_metric` attempts to update the `is_forward` field of `metrics_context`. Since
+    # `metrics_context` is a singleton, a runtime error will occur if multiple threads try to
+    # update it, because `MetricsContext` does not allow updating existing fields when
+    # `overwrite` is False. So we need to patch the `update` method of `MetricsContext`.
+    def _metrics_context_update(self, *args, **kwargs) -> None:
+        pass
+
     @parametrize("compile", [True, False])
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name

@@ -691,7 +699,8 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
-        loss.backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            loss.backward()
         self.assertEqual(t.grad, torch.full_like(t, 2.0))

     def test_all_to_all_single_inductor(self) -> None:

@@ -711,7 +720,8 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:

         def run_with_backward():
             out = compiled(t, self.world_size)
-            out.backward()
+            with patch.object(MetricsContext, "update", self._metrics_context_update):
+                out.backward()

         _, codes = run_and_get_code(run_with_backward)
         for code in codes:

@@ -751,7 +761,8 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         gathered_tensor = compiled(local_tensor, dim)
         self.assertEqual(gathered_tensor, torch.ones(output_size))

-        gathered_tensor.sum().backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            gathered_tensor.sum().backward()
         self.assertEqual(
             local_tensor.grad,
             torch.full((3, 3, 3), fill_value=float(self.world_size)),

@@ -786,7 +797,8 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         rs_tensor = compiled(input_tensor, dim)
         res_num = 1 * group_size
         self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
-        rs_tensor.sum().backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            rs_tensor.sum().backward()
         self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
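The `_metrics_context_update` stub above, together with `patch.object`, is what makes the multi-threaded backward passes safe: every thread that would otherwise write `is_forward` into the singleton `MetricsContext` hits the no-op instead. A minimal standalone sketch of the same pattern, assuming only `unittest.mock` and the `MetricsContext` import shown in the diff (the `run_backward` wrapper is illustrative, not part of the change):

from unittest.mock import patch

import torch
from torch._dynamo.metrics_context import MetricsContext


def _noop_update(self, *args, **kwargs) -> None:
    # Drop all metric updates so concurrent threads cannot trip the
    # no-overwrite check in the singleton MetricsContext.
    pass


def run_backward(loss: torch.Tensor) -> None:
    # Patch only for the duration of backward(); normal metrics
    # recording resumes once the with-block exits.
    with patch.object(MetricsContext, "update", _noop_update):
        loss.backward()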
test/inductor/test_autoheuristic.py
13 additions & 5 deletions

@@ -4,17 +4,22 @@

 import torch
 import torch._inductor.config as inductor_config
-from torch._dynamo.device_interface import get_interface_for_device
 from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback
 from torch._inductor.autoheuristic.autoheuristic_utils import AHContext
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import get_gpu_shared_memory
-from torch.testing._internal.common_utils import skipIfXpu
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_A100, IS_H100
+from torch.testing._internal.common_utils import TEST_XPU
+from torch.testing._internal.inductor_utils import (
+    GPU_TYPE,
+    HAS_CUDA,
+    HAS_GPU,
+    IS_A100,
+    IS_H100,
+)


-@skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack")
+@unittest.skipIf(TEST_XPU, "AutoHeuristic doesn't currently work on the XPU stack")
 class AutoHeuristicTest(TestCase):
     def count_lines_in_file(self, file_path):
         with open(file_path) as file:

@@ -102,7 +107,9 @@ def feedback_fn(choice):
         self.assertEqual(num_lines, 5)

         shared_memory = get_gpu_shared_memory()
-        (fst, snd) = get_interface_for_device(GPU_TYPE).get_device_capability()
+
+        self.assertTrue(HAS_CUDA)
+        (fst, snd) = torch.cuda.get_device_capability()

         with open(path) as file:
             lines = file.readlines()

@@ -151,6 +158,7 @@ def fn(a, b):
         fx_graph_cache=False,
         fx_graph_remote_cache=False,
     )
+    @unittest.skipIf(not IS_A100, "heuristic only run on A100")
     def test_global_feedback(self):
         self.run_mixed_mm()
         path = self.get_path_to_autoheuristic_log("mixed_mm")
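For context, the capability query the diff switches to is the plain CUDA runtime call: `torch.cuda.get_device_capability()` returns the device's `(major, minor)` compute capability, which the A100/H100 checks key on. A hedged sketch of the guarded query, where the helper name `cuda_capability_or_skip` is invented for illustration:

import unittest

import torch


def cuda_capability_or_skip(test_case: unittest.TestCase) -> tuple[int, int]:
    # Mirror the diff's intent: require CUDA before asking the CUDA
    # runtime for the device's (major, minor) compute capability.
    if not torch.cuda.is_available():
        test_case.skipTest("CUDA required")
    return torch.cuda.get_device_capability()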
test/inductor/test_b2b_gemm.py
2 additions & 2 deletions

@@ -6,11 +6,11 @@
 from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_utils import skipIfXpu
+from torch.testing._internal.common_utils import TEST_XPU
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


-@skipIfXpu(msg="Segmentation fault on CI machine")
+@unittest.skipIf(TEST_XPU, "Segmentation fault on CI machine")
 class B2BGEMMTest(TestCase):
     device = GPU_TYPE
test/inductor/test_layout_optim.py
3 additions & 2 deletions

@@ -2,14 +2,15 @@
 import copy
 import os
 import random
+import unittest

 import torch
 from torch import nn
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.common_cuda import tf32_off
-from torch.testing._internal.common_utils import skipIfXpu
+from torch.testing._internal.common_utils import TEST_XPU
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

@@ -34,7 +35,7 @@ def get_example_inputs(self):
         return (torch.rand(2, 3, 16, 16),)


-@skipIfXpu(msg="ccl doesn't currently work on the XPU stack")
+@unittest.skipIf(TEST_XPU, "ccl doesn't currently work on the XPU stack")
 class TestLayoutOptim(TestCase):
     @classmethod
     def setUpClass(cls):
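All four files make the same substitution: the internal `skipIfHpu`/`skipIfXpu` helpers give way to the stdlib `unittest.skipIf` applied at class level, keyed on the `TEST_HPU`/`TEST_XPU` flags. A self-contained sketch of that pattern, with the flag stubbed by a plain boolean (in the real tests it comes from `torch.testing._internal.common_utils`):

import unittest

TEST_XPU = False  # stand-in for torch.testing._internal.common_utils.TEST_XPU


@unittest.skipIf(TEST_XPU, "feature doesn't currently work on the XPU stack")
class ExampleTest(unittest.TestCase):
    def test_something(self) -> None:
        # When TEST_XPU is True the whole class is reported as skipped
        # with the reason above; otherwise the test runs normally.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()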