8000 Use inductor TestCase for test_replicate_with_compiler.py by masnesral · Pull Request #131193 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Use inductor TestCase for test_replicate_with_compiler.py #131193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions test/distributed/_composable/test_replicate_with_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch._C import FileCheck
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import counters
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms.ddp_comm_hooks import (
Expand Down Expand Up @@ -69,7 +70,15 @@ def inner_compiler(gm_, example_inputs_):
return _compiler_fn


class ReplicateTest(MultiProcessTestCase):
class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
"""
A version of MultiProcessTestCase that derives from the Inductor TestCase
to handle isolation of the inductor cache dir.
"""
pass


class ReplicateTest(MultiProcessInductorTestCase):
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
Expand Down Expand Up @@ -350,7 +359,7 @@ def test_bucketing_concat_op(self):
fc.run(code)


class DDP_TP_Test(MultiProcessTestCase):
class DDP_TP_Test(MultiProcessInductorTestCase):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
Expand Down
Loading
0