@@ -774,41 +774,41 @@ def is_cpp_backend(device):
774
774
775
775
def skip_if_cpu (fn ):
776
776
@functools .wraps (fn )
777
- def wrapper (self ):
777
+ def wrapper (self , * args , ** kwargs ):
778
778
if self .device == "cpu" :
779
779
raise unittest .SkipTest ("cpu not supported" )
780
- return fn (self )
780
+ return fn (self , * args , ** kwargs )
781
781
782
782
return wrapper
783
783
784
784
785
785
def skip_if_halide (fn ):
786
786
@functools .wraps (fn )
787
- def wrapper (self ):
787
+ def wrapper (self , * args , ** kwargs ):
788
788
if is_halide_backend (self .device ):
789
789
raise unittest .SkipTest ("halide not supported" )
790
- return fn (self )
790
+ return fn (self , * args , ** kwargs )
791
791
792
792
return wrapper
793
793
794
794
795
795
def xfail_if_mps (fn ):
796
796
@functools .wraps (fn )
797
- def wrapper (self ):
797
+ def wrapper (self , * args , ** kwargs ):
798
798
if not is_mps_backend (self .device ):
799
- return fn (self )
799
+ return fn (self , * args , ** kwargs )
800
800
with self .assertRaises (Exception ):
801
- return fn (self )
801
+ return fn (self , * args , ** kwargs )
802
802
803
803
return wrapper
804
804
805
805
806
806
def skip_if_triton (fn ):
807
807
@functools .wraps (fn )
808
- def wrapper (self ):
808
+ def wrapper (self , * args , ** kwargs ):
809
809
if is_triton_backend (self .device ):
810
810
raise unittest .SkipTest ("triton not supported" )
811
- return fn (self )
811
+ return fn (self , * args , ** kwargs )
812
812
813
813
return wrapper
814
814
@@ -825,10 +825,10 @@ def wrapper(self, *args, **kwargs):
825
825
826
826
def skip_if_dynamic (fn ):
827
827
@functools .wraps (fn )
828
- def wrapper (self ):
828
+ def wrapper (self , * args , ** kwargs ):
829
829
if ifdynstaticdefault (True , False ) or torch ._dynamo .config .dynamic_shapes :
830
830
raise unittest .SkipTest ("associtaive_scan doesn's support lifted SymInts." )
831
- return fn (self )
831
+ return fn (self , * args , ** kwargs )
832
832
833
833
return wrapper
834
834
@@ -884,13 +884,13 @@ def xfail_if_triton_cpu(fn):
884
884
885
885
def skip_if_gpu_halide (fn ):
886
886
@functools .wraps (fn )
887
- def wrapper (self ):
887
+ def wrapper (self , * args , ** kwargs ):
888
888
if (
889
889
is_halide_backend (self .device )
890
890
and getattr (self .device , "type" , self .device ) == "cuda"
891
891
):
892
892
raise unittest .SkipTest ("halide not supported" )
893
- return fn (self )
893
+ return fn (self , * args , ** kwargs )
894
894
895
895
return wrapper
896
896
@@ -899,12 +899,12 @@ class skip_if_cpp_wrapper:
899
899
def __init__ (self , reason : str = "" ) -> None :
900
900
self .reason = reason
901
901
902
- def __call__ (self , fn ):
902
+ def __call__ (self , fn , * args , ** kwargs ):
903
903
@functools .wraps (fn )
904
904
def wrapper (test_self ):
905
905
if config .cpp_wrapper :
906
906
raise unittest .SkipTest (f"cpp wrapper bug to be fixed: { self .reason } " )
907
- return fn (test_self )
907
+ return fn (test_self , * args , ** kwargs )
908
908
909
909
return wrapper
910
910
0 commit comments