pytorch/pytorch · commit 2344eca

Revert "Fix skipIfXpu and skipIfHpu disables tests when used on class (#151315)"
This reverts commit ee096b8. Reverted #151315 on behalf of https://github.com/jeanschmidt because it seems to have introduced internal regressions; see [D74668899](https://www.internalfb.com/diff/D74668899). @malfet, may you help the author get this PR merged? ([comment](#151315 (comment)))
1 parent 2c19124 · commit 2344eca

File tree

4 files changed (+19, -40 lines)

test/distributed/test_functional_api.py

Lines changed: 10 additions & 22 deletions

@@ -3,15 +3,13 @@
 import sys
 import unittest
 from functools import partial, wraps
-from unittest.mock import patch

 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.distributed.distributed_c10d as c10d
 import torch.distributed.tensor as dt
 from functorch import make_fx
-from torch._dynamo.metrics_context import MetricsContext
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -33,6 +31,7 @@
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
     TestCase,
@@ -91,7 +90,7 @@ def new_subgroups(group_size: int, pg_tag=None):
     return cur_subgroup, subgroups


-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestExpand(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -181,7 +180,7 @@ def test_expand_device_mesh_tuple(self):
         self.assertEqual(2, group_size)


-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestPgTag(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -258,7 +257,7 @@ def test_find_root_pg(self):


 @instantiate_parametrized_tests
-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestTraceableCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -404,7 +403,7 @@ def test_all_reduce(self):
         self.assertEqual(x.size(), out.size())


-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestGradCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -657,7 +656,7 @@ def test_permute_tensor_with_sub_group(self, device):


 @instantiate_parametrized_tests
-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestFunctionalAutograd(MultiThreadedTestCase):
     def setUp(self):
         super().setUp()
@@ -667,13 +666,6 @@ def setUp(self):
     def world_size(self):
         return 2

-    # `compilation_metric` attempts to update the `is_forward` field of `metrics_context`. Since
-    # `metrics_context` is a singleton, a runtime error will occur if multiple threads try to update it
-    # because `MetricsContext` does not allow updating existing fields when `overwrite` is False.
-    # So, we need to patch the `update` function of MetricsContext
-    def _metrics_context_update(self, *args, **kwargs) -> None:
-        pass
-
     @parametrize("compile", [True, False])
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name
@@ -699,8 +691,7 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
-        with patch.object(MetricsContext, "update", self._metrics_context_update):
-            loss.backward()
+        loss.backward()
         self.assertEqual(t.grad, torch.full_like(t, 2.0))

     def test_all_to_all_single_inductor(self) -> None:
@@ -720,8 +711,7 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:

         def run_with_backward():
             out = compiled(t, self.world_size)
-            with patch.object(MetricsContext, "update", self._metrics_context_update):
-                out.backward()
+            out.backward()

         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
@@ -761,8 +751,7 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         gathered_tensor = compiled(local_tensor, dim)
         self.assertEqual(gathered_tensor, torch.ones(output_size))

-        with patch.object(MetricsContext, "update", self._metrics_context_update):
-            gathered_tensor.sum().backward()
+        gathered_tensor.sum().backward()
         self.assertEqual(
             local_tensor.grad,
             torch.full((3, 3, 3), fill_value=float(self.world_size)),
@@ -797,8 +786,7 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         rs_tensor = compiled(input_tensor, dim)
         res_num = 1 * group_size
         self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
-        with patch.object(MetricsContext, "update", self._metrics_context_update):
-            rs_tensor.sum().backward()
+        rs_tensor.sum().backward()
         self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
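The `_metrics_context_update` helper and the `patch.object` wrappers deleted above existed because `MetricsContext` is a singleton that raises when several threads update the same field with `overwrite=False`. Below is a minimal, self-contained sketch of that same patching pattern, assuming only a stock PyTorch build where `torch._dynamo.metrics_context` is importable; the toy tensor and the `_noop_update` name are illustrative, not taken from the diff:

import torch
from unittest.mock import patch

from torch._dynamo.metrics_context import MetricsContext


def _noop_update(self, *args, **kwargs) -> None:
    # Silently drop metric updates so concurrent threads cannot race on
    # the singleton's existing fields.
    pass


t = torch.ones(4, requires_grad=True)
loss = (t * 2).sum()
with patch.object(MetricsContext, "update", _noop_update):
    loss.backward()  # any compilation metrics recorded here would become no-ops
print(t.grad)  # tensor([2., 2., 2., 2.])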

test/inductor/test_autoheuristic.py

Lines changed: 5 additions & 13 deletions

@@ -4,22 +4,17 @@

 import torch
 import torch._inductor.config as inductor_config
+from torch._dynamo.device_interface import get_interface_for_device
 from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback
 from torch._inductor.autoheuristic.autoheuristic_utils import AHContext
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import get_gpu_shared_memory
-from torch.testing._internal.common_utils import TEST_XPU
-from torch.testing._internal.inductor_utils import (
-    GPU_TYPE,
-    HAS_CUDA,
-    HAS_GPU,
-    IS_A100,
-    IS_H100,
-)
+from torch.testing._internal.common_utils import skipIfXpu
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_A100, IS_H100


-@unittest.skipIf(TEST_XPU, "AutoHeuristic doesn't currently work on the XPU stack")
+@skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack")
 class AutoHeuristicTest(TestCase):
     def count_lines_in_file(self, file_path):
         with open(file_path) as file:
@@ -107,9 +102,7 @@ def feedback_fn(choice):
         self.assertEqual(num_lines, 5)

         shared_memory = get_gpu_shared_memory()
-
-        self.assertTrue(HAS_CUDA)
-        (fst, snd) = torch.cuda.get_device_capability()
+        (fst, snd) = get_interface_for_device(GPU_TYPE).get_device_capability()

         with open(path) as file:
             lines = file.readlines()
@@ -158,7 +151,6 @@ def fn(a, b):
         fx_graph_cache=False,
         fx_graph_remote_cache=False,
     )
-    @unittest.skipIf(not IS_A100, "heuristic only run on A100")
     def test_global_feedback(self):
         self.run_mixed_mm()
         path = self.get_path_to_autoheuristic_log("mixed_mm")
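The restored line queries device capability through Dynamo's device-interface abstraction instead of hard-coding `torch.cuda`, so the same test code works for whichever backend `GPU_TYPE` names. A hedged sketch of that call, assuming a CUDA machine (for other builds, pass the matching device string such as "xpu"):

import torch
from torch._dynamo.device_interface import get_interface_for_device

if torch.cuda.is_available():
    # For "cuda" this dispatches to torch.cuda.get_device_capability(),
    # but the same call works for any registered GPU backend.
    major, minor = get_interface_for_device("cuda").get_device_capability()
    print(f"device capability: {major}.{minor}")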

test/inductor/test_b2b_gemm.py

Lines changed: 2 additions & 2 deletions

@@ -6,11 +6,11 @@
 from torch._inductor.runtime.benchmarking import benchmarker
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_utils import TEST_XPU
+from torch.testing._internal.common_utils import skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


-@unittest.skipIf(TEST_XPU, "Segmentation fault on CI machine")
+@skipIfXpu(msg="Segmentation fault on CI machine")
 class B2BGEMMTest(TestCase):
     device = GPU_TYPE

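Every hunk in this revert swaps a class-level `unittest.skipIf(...)` back for `skipIfXpu`/`skipIfHpu`. The reverted PR #151315 addressed the failure mode named in its title: a function-style skip decorator applied to a class silently disables its tests everywhere. An illustrative sketch of that mechanism follows (a deliberately naive decorator, not PyTorch's actual implementation):

import unittest

SKIP_CONDITION = False  # stand-in for TEST_XPU / TEST_HPU


def naive_skip_if(condition, reason):
    # Function-oriented decorator: it returns a wrapper *function*,
    # whatever it was applied to.
    def decorator(fn):
        def wrapper(*args, **kwargs):
            if condition:
                raise unittest.SkipTest(reason)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@naive_skip_if(SKIP_CONDITION, "demo")
class Broken(unittest.TestCase):
    # `Broken` is now a plain function, so unittest's loader never
    # collects it: the test silently disappears on every platform.
    def test_ok(self):
        self.assertTrue(True)


@unittest.skipIf(SKIP_CONDITION, "demo")
class Works(unittest.TestCase):
    # unittest.skipIf handles classes natively, so this test still runs
    # (or is reported as skipped) as expected.
    def test_ok(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()  # collects Works.test_ok only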

test/inductor/test_layout_optim.py

Lines changed: 2 additions & 3 deletions

@@ -2,15 +2,14 @@
 import copy
 import os
 import random
-import unittest

 import torch
 from torch import nn
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.common_cuda import tf32_off
-from torch.testing._internal.common_utils import TEST_XPU
+from torch.testing._internal.common_utils import skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


@@ -35,7 +34,7 @@ def get_example_inputs(self):
         return (torch.rand(2, 3, 16, 16),)


-@unittest.skipIf(TEST_XPU, "ccl doesn't currently work on the XPU stack")
+@skipIfXpu(msg="ccl doesn't currently work on the XPU stack")
 class TestLayoutOptim(TestCase):
     @classmethod
     def setUpClass(cls):

0 commit comments