8000 Add windows-specific test skips · pytorch/pytorch@44497f9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 44497f9

Browse files
committed
Add windows-specific test skips
1 parent 71cff26 commit 44497f9

11 files changed

+27
-9
lines changed

test/export/test_db.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
parametrize,
1616
run_tests,
1717
TestCase,
18+
IS_WINDOWS
1819
)
1920

20-
21+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
2122
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
2223
class ExampleTests(TestCase):
2324
# TODO Maybe we should make this tests actually show up in a file?

test/export/test_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def branch_on_shape(x: torch.Tensor):
131131
# Being able to export means shape is preserved as static
132132
export(WrapperModule(branch_on_shape), inp)
133133

134-
134+
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
135135
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
136136
class TestExport(TestCase):
137137
def _test_export_same_as_eager(self, f, args, kwargs=None):

test/export/test_pass_infra.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
99
from torch.export import export
1010
from torch.fx.passes.infra.pass_base import PassResult
11-
from torch.testing._internal.common_utils import run_tests, TestCase
11+
from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
1212

1313

1414
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
@@ -41,6 +41,7 @@ class NullPass(_ExportPassBaseDeprecatedDoNotUse):
4141
self.assertEqual(new_node.op, old_node.op)
4242
self.assertEqual(new_node.target, old_node.target)
4343

44+
@unittest.skipIf(IS_WINDOWS, "Windows not supported")
4445
def test_cond(self) -> None:
4546
class M(torch.nn.Module):
4647
def __init__(self):

test/export/test_passes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torch.fx.passes.infra.partitioner import Partition
2727
from torch.fx.passes.operator_support import OperatorSupport
2828
from torch.testing import FileCheck
29-
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
29+
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, IS_WINDOWS
3030
from torch.utils import _pytree as pytree
3131

3232

@@ -276,6 +276,7 @@ def forward(self, x):
276276
new_inp = torch.tensor([1, 1, 1, 1])
277277
self.assertEqual(mod(new_inp), ep.module()(new_inp))
278278

279+
@unittest.skipIf(IS_WINDOWS, "Windows not supported")
279280
def test_runtime_assert_inline_constraints_for_cond(self) -> None:
280281
class M(torch.nn.Module):
281282
def __init__(self):

test/export/test_serialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
210210
g.nodes[1].inputs[0].arg.as_tensor.name
211211
)
212212

213-
213+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
214214
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
215215
class TestDeserialize(TestCase):
216216
def check_graph(self, fn, inputs, dynamic_shapes=None, _check_meta=True) -> None:

test/export/test_unflatten.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def forward(self, x):
210210
id(getattr(unflattened_module.sub_net, "2")),
211211
)
212212

213+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
213214
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
214215
def test_unflatten_preserve_signature(self):
215216
class NestedChild(torch.nn.Module):

test/export/test_upgrade.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch.testing._internal.common_utils import (
1111
run_tests,
1212
TestCase,
13+
IS_WINDOWS,
1314
)
1415

1516
TEST_UPGRADERS = {
@@ -115,6 +116,7 @@ def forward(self, a: torch.Tensor, b):
115116
custom_op_count = count_op(upgraded.graph, "aten::div__Scalar_mode_0_3")
116117
self.assertEqual(custom_op_count, 1)
117118

119+
@unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows")
118120
def test_div_upgrader_pass_return_new_op_after_retrace(self):
119121
class Foo(torch.nn.Module):
120122
def forward(self, a: torch.Tensor, b):

test/export/test_verifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from torch._export.verifier import SpecViolationError, Verifier
1111
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
12-
from torch.testing._internal.common_utils import run_tests, TestCase
12+
from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
1313

1414
@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
1515
class TestVerifier(TestCase):
@@ -56,6 +56,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5656
with self.assertRaises(SpecViolationError):
5757
verifier.check(ep)
5858

59+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
5960
def test_verifier_higher_order(self) -> None:
6061
class Foo(torch.nn.Module):
6162
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -76,6 +77,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
7677
verifier = Verifier()
7778
verifier.check(ep)
7879

80+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
7981
def test_verifier_nested_invalid_module(self) -> None:
8082
class Foo(torch.nn.Module):
8183
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

test/functorch/test_aotdispatch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
run_tests,
1414
IS_ARM64,
1515
IS_MACOS,
16+
IS_WINDOWS,
1617
IS_X86,
1718
compare_equal_outs_and_grads,
1819
outs_and_grads,
@@ -3261,6 +3262,7 @@ def fn(p, x):
32613262
):
32623263
aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1)
32633264

3265+
@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
32643266
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run")
32653267
def test_aot_export_with_torch_cond(self):
32663268
class M(torch.nn.Module):

test/functorch/test_control_flow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from functorch.experimental.control_flow import UnsupportedAliasMutationException, cond
1111
from torch._higher_order_ops.while_loop import while_loop
1212
from torch.fx.experimental.proxy_tensor import make_fx
13-
from torch.testing._internal.common_utils import run_tests, TestCase
13+
from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
1414
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
1515
from torch._subclasses.functional_tensor import FunctionalTensor, CppFunctionalizeAPI, PythonFunctionalizeAPI, FunctionalTensorMode
1616

@@ -130,7 +130,7 @@ def forward(self, *operands):
130130
return self._reduce(*operands)
131131

132132

133-
133+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
134134
@skipIfNoDynamoSupport
135135
class TestControlFlow(TestCase):
136136
def setUp(self):
@@ -316,6 +316,7 @@ def fwbw(map_op, f, x, y):
316316
self.assertEqual(true_outs, fake_outs)
317317

318318

319+
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
319320
@skipIfNoDynamoSupport
320321
class TestControlFlowTraced(TestCase):
321322
def setUp(self):

0 commit comments

Comments
 (0)
0