8000 Update · pytorch/pytorch@c3632c5 · GitHub
[go: up one dir, main page]

Skip to content

Commit c3632c5

Browse files
committed
Update
[ghstack-poisoned]
1 parent 4ad332e commit c3632c5

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

test/inductor/test_torchinductor.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -774,41 +774,41 @@ def is_cpp_backend(device):
774774

775775
def skip_if_cpu(fn):
776776
@functools.wraps(fn)
777-
def wrapper(self):
777+
def wrapper(self, *args, **kwargs):
778778
if self.device == "cpu":
779779
raise unittest.SkipTest("cpu not supported")
780-
return fn(self)
780+
return fn(self, *args, **kwargs)
781781

782782
return wrapper
783783

784784

785785
def skip_if_halide(fn):
786786
@functools.wraps(fn)
787-
def wrapper(self):
787+
def wrapper(self, *args, **kwargs):
788788
if is_halide_backend(self.device):
789789
raise unittest.SkipTest("halide not supported")
790-
return fn(self)
790+
return fn(self, *args, **kwargs)
791791

792792
return wrapper
793793

794794

795795
def xfail_if_mps(fn):
796796
@functools.wraps(fn)
797-
def wrapper(self):
797+
def wrapper(self, *args, **kwargs):
798798
if not is_mps_backend(self.device):
799-
return fn(self)
799+
return fn(self, *args, **kwargs)
800800
with self.assertRaises(Exception):
801-
return fn(self)
801+
return fn(self, *args, **kwargs)
802802

803803
return wrapper
804804

805805

806806
def skip_if_triton(fn):
807807
@functools.wraps(fn)
808-
def wrapper(self):
808+
def wrapper(self, *args, **kwargs):
809809
if is_triton_backend(self.device):
810810
raise unittest.SkipTest("triton not supported")
811-
return fn(self)
811+
return fn(self, *args, **kwargs)
812812

813813
return wrapper
814814

@@ -825,10 +825,10 @@ def wrapper(self, *args, **kwargs):
825825

826826
def skip_if_dynamic(fn):
827827
@functools.wraps(fn)
828-
def wrapper(self):
828+
def wrapper(self, *args, **kwargs):
829829
if ifdynstaticdefault(True, False) or torch._dynamo.config.dynamic_shapes:
830830
raise unittest.SkipTest("associtaive_scan doesn's support lifted SymInts.")
831-
return fn(self)
831+
return fn(self, *args, **kwargs)
832832

833833
return wrapper
834834

@@ -884,13 +884,13 @@ def xfail_if_triton_cpu(fn):
884884

885885
def skip_if_gpu_halide(fn):
886886
@functools.wraps(fn)
887-
def wrapper(self):
887+
def wrapper(self, *args, **kwargs):
888888
if (
889889
is_halide_backend(self.device)
890890
and getattr(self.device, "type", self.device) == "cuda"
891891
):
892892
raise unittest.SkipTest("halide not supported")
893-
return fn(self)
893+
return fn(self, *args, **kwargs)
894894

895895
return wrapper
896896

@@ -899,12 +899,12 @@ class skip_if_cpp_wrapper:
899899
def __init__(self, reason: str = "") -> None:
900900
self.reason = reason
901901

902-
def __call__(self, fn):
902+
def __call__(self, fn, *args, **kwargs):
903903
@functools.wraps(fn)
904904
def wrapper(test_self):
905905
if config.cpp_wrapper:
906906
raise unittest.SkipTest(f"cpp wrapper bug to be fixed: {self.reason}")
907-
return fn(test_self)
907+
return fn(test_self, *args, **kwargs)
908908

909909
return wrapper
910910

0 commit comments

Comments
 (0)
0