@@ -3,13 +3,15 @@
 import sys
 import unittest
 from functools import partial, wraps
+from unittest.mock import patch

 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.distributed.distributed_c10d as c10d
 import torch.distributed.tensor as dt
 from functorch import make_fx
+from torch._dynamo.metrics_context import MetricsContext
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -31,7 +33,6 @@
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
     TestCase,
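The hunks that follow swap the internal `skipIfHpu` decorator for a plain `unittest.skipIf(TEST_HPU, ...)` gate on each affected test class. A minimal standalone sketch of that pattern, with `TEST_HPU` stubbed as a local boolean rather than imported from `torch.testing._internal.common_utils`:

import unittest

# Assumption for this sketch only: TEST_HPU is a locally defined flag; in the real
# test suite it comes from torch.testing._internal.common_utils and is True on HPU builds.
TEST_HPU = False


@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
class ExampleTest(unittest.TestCase):
    def test_trivial(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()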
@@ -90,7 +91,7 @@ def new_subgroups(group_size: int, pg_tag=None):
     return cur_subgroup, subgroups


-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestExpand(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -180,7 +181,7 @@ def test_expand_device_mesh_tuple(self):
         self.assertEqual(2, group_size)


-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestPgTag(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -257,7 +258,7 @@ def test_find_root_pg(self):


 @instantiate_parametrized_tests
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestTraceableCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -403,7 +404,7 @@ def test_all_reduce(self):
         self.assertEqual(x.size(), out.size())


-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestGradCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -656,7 +657,7 @@ def test_permute_tensor_with_sub_group(self, device):


 @instantiate_parametrized_tests
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestFunctionalAutograd(MultiThreadedTestCase):
     def setUp(self):
         super().setUp()
@@ -666,6 +667,13 @@ def setUp(self):
     def world_size(self):
         return 2

+    # `compilation_metric` attempts to update the `is_forward` field of `metrics_context`. Since
+    # `metrics_context` is a singleton, a runtime error will occur if multiple threads try to update it,
+    # because `MetricsContext` does not allow updating existing fields when `overwrite` is False.
+    # So we need to patch the `update` method of `MetricsContext`.
+    def _metrics_context_update(self, *args, **kwargs) -> None:
+        pass
+
     @parametrize("compile", [True, False])
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name
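As context for the helper added above: the no-op `_metrics_context_update` is swapped in for `MetricsContext.update` around the backward calls in the hunks that follow, via `unittest.mock.patch.object`. Below is a minimal single-process sketch of that patching pattern, assuming only what the diff itself shows (that `torch._dynamo.metrics_context.MetricsContext` exposes an `update` method which a compiled backward pass may invoke); it demonstrates the mechanics, not the multi-threaded collision the comment describes.

import torch
from unittest.mock import patch
from torch._dynamo.metrics_context import MetricsContext


def _noop_update(self, *args, **kwargs) -> None:
    # Drop every metric update so concurrent callers cannot re-set an
    # existing field (such as is_forward) on the shared singleton.
    pass


compiled = torch.compile(lambda t: (t * 2).sum())
x = torch.randn(4, requires_grad=True)
with patch.object(MetricsContext, "update", _noop_update):
    compiled(x).backward()  # the compiled backward would normally record compilation metrics
assert torch.equal(x.grad, torch.full_like(x, 2.0))

Patching at the class level means every instance, including the shared singleton, sees the no-op for the duration of the with block, which is why the tests below wrap only the backward calls.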
@@ -691,7 +699,8 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
-        loss.backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            loss.backward()
         self.assertEqual(t.grad, torch.full_like(t, 2.0))

     def test_all_to_all_single_inductor(self) -> None:
@@ -711,7 +720,8 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:

         def run_with_backward():
             out = compiled(t, self.world_size)
-            out.backward()
+            with patch.object(MetricsContext, "update", self._metrics_context_update):
+                out.backward()

         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
@@ -751,7 +761,8 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         gathered_tensor = compiled(local_tensor, dim)
         self.assertEqual(gathered_tensor, torch.ones(output_size))

-        gathered_tensor.sum().backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            gathered_tensor.sum().backward()
         self.assertEqual(
             local_tensor.grad,
             torch.full((3, 3, 3), fill_value=float(self.world_size)),
@@ -786,7 +797,8 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         rs_tensor = compiled(input_tensor, dim)
         res_num = 1 * group_size
         self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
-        rs_tensor.sum().backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            rs_tensor.sum().backward()
         self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
