@@ -3,15 +3,13 @@
 import sys
 import unittest
 from functools import partial, wraps
-from unittest.mock import patch

 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.distributed.distributed_c10d as c10d
 import torch.distributed.tensor as dt
 from functorch import make_fx
-from torch._dynamo.metrics_context import MetricsContext
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -33,6 +31,7 @@
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
     TestCase,
@@ -91,7 +90,7 @@ def new_subgroups(group_size: int, pg_tag=None):
     return cur_subgroup, subgroups


-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestExpand(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -181,7 +180,7 @@ def test_expand_device_mesh_tuple(self):
         self.assertEqual(2, group_size)


-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestPgTag(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -258,7 +257,7 @@ def test_find_root_pg(self):


 @instantiate_parametrized_tests
-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestTraceableCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -404,7 +403,7 @@ def test_all_reduce(self):
         self.assertEqual(x.size(), out.size())


-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestGradCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -657,7 +656,7 @@ def test_permute_tensor_with_sub_group(self, device):


 @instantiate_parametrized_tests
-@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
+@skipIfHpu
 class TestFunctionalAutograd(MultiThreadedTestCase):
     def setUp(self):
         super().setUp()
@@ -667,13 +666,6 @@ def setUp(self):
     def world_size(self):
         return 2

-    # `compilation_metric` attempts to update the `is_forward` field of `metrics_context`. Since
-    # `metrics_context` is a singleton, a runtime error will occur if multiple threads try to update it
-    # because `MetricsContext` does not allow updating existing fields when `overwrite` is False.
-    # So, we need to patch the `update` function of MetricsContext
-    def _metrics_context_update(self, *args, **kwargs) -> None:
-        pass
-
     @parametrize("compile", [True, False])
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name
@@ -699,8 +691,7 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
-        with patch.object(MetricsContext, "update", self._metrics_context_update):
-            loss.backward()
+        loss.backward()
         self.assertEqual(t.grad, torch.full_like(t, 2.0))

     def test_all_to_all_single_inductor(self) -> None:
@@ -720,8 +711,7 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:

         def run_with_backward():
             out = compiled(t, self.world_size)
-            with patch.object(MetricsContext, "update", self._metrics_context_update):
-                out.backward()
+            out.backward()

         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
@@ -761,8 +751,7 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         gathered_tensor = compiled(local_tensor, dim)
         self.assertEqual(gathered_tensor, torch.ones(output_size))

-        with patch.object(MetricsContext, "update", self._metrics_context_update):
-            gathered_tensor.sum().backward()
+        gathered_tensor.sum().backward()
         self.assertEqual(
             local_tensor.grad,
             torch.full((3, 3, 3), fill_value=float(self.world_size)),
@@ -797,8 +786,7 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         rs_tensor = compiled(input_tensor, dim)
         res_num = 1 * group_size
         self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
-        with patch.object(MetricsContext, "update", self._metrics_context_update):
-            rs_tensor.sum().backward()
+        rs_tensor.sum().backward()
         self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))

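For readers unfamiliar with the `skipIfHpu` name added to the import list above (it sits in the same import block as `TEST_HPU`, `run_tests`, and `TestCase`, presumably `torch.testing._internal.common_utils`), the decorator is roughly equivalent to the repeated `unittest.skipIf(TEST_HPU, "Unsupported on HPU")` lines it replaces. Below is a minimal sketch of that equivalence, not the actual PyTorch helper; the local `TEST_HPU` flag and `ExampleTest` class exist only for illustration.

```python
# Illustrative sketch only -- not the real torch.testing._internal.common_utils
# helper. Assumes TEST_HPU is a module-level bool set by the test harness.
import unittest

TEST_HPU = False  # assumption: PyTorch derives this from the available backends

# unittest.skipIf(...) returns a decorator that works on both test classes and
# test methods, so a single @skipIfHpu can stand in for each repeated
# @unittest.skipIf(TEST_HPU, "Unsupported on HPU") line in the diff above.
skipIfHpu = unittest.skipIf(TEST_HPU, "Unsupported on HPU")


@skipIfHpu
class ExampleTest(unittest.TestCase):
    def test_noop(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```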