8000 [c10d] Remove Option for ProcessGroup and Expose backend Options to r… · pytorch/pytorch@65864d0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 65864d0

Browse files
fduwjj authored and pytorchmergebot committed
[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)
We introduced the dispatchable backend for a ProcessGroup and collective in #86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG. Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options" We need to make changes to the test to make it aligned with the change. Pull Request resolved: #132931 Approved by: https://github.com/H-Huang
1 parent 8b4c487 commit 65864d0

File tree

8 files changed

+113
-117
lines changed

8 files changed

+113
-117
lines changed

test/distributed/test_c10d_common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1815,6 +1815,7 @@ def test_init_process_group_optional_backend(self):
18151815

18161816
def test_init_process_group_for_all_backends(self):
18171817
for backend in dist.Backend.backend_list:
1818+
excepted_backend = backend
18181819
# skip if the backend is not available on the system
18191820
if backend == dist.Backend.UNDEFINED:
18201821
continue
@@ -1830,6 +1831,11 @@ def test_init_process_group_for_all_backends(self):
18301831
elif backend == dist.Backend.UCC:
18311832
if not dist.is_ucc_available():
18321833
continue
1834+
# Multi-threaded PG is defined as a pure python class.
1835+
# Its pg.name() does not going through Pybind, so its backend name
1836+
# is still "threaded" instead of "custom".
1837+
elif backend != "threaded":
1838+
excepted_backend = "custom"
18331839

18341840
with tempfile.NamedTemporaryFile(delete=False) as f:
18351841
store = dist.FileStore(f.name, self.world_size)
@@ -1842,7 +1848,7 @@ def test_init_process_group_for_all_backends(self):
18421848
pg = c10d._get_default_group()
18431849
self.assertEqual(pg.rank(), self.rank)
18441850
self.assertEqual(pg.size(), self.world_size)
1845-
self.assertEqual(pg.name(), str(backend))
1851+
self.assertEqual(pg.name(), str(excepted_backend))
18461852

18471853
dist.destroy_process_group()
18481854

test/distributed/test_device_mesh.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def test_set_mesh_dim_group_options(self):
232232

233233
mesh_tensor = torch.arange(4).reshape(2, 2)
234234
mesh = DeviceMesh(device_type, mesh_tensor)
235-
self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake")
235+
# Fake pg only have BackendType as BackendType::CUSTOM.
236+
self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")
236237

237238

238239
class DeviceMeshTestNDim(DTensorTestBase):

torch/_C/_distributed_c10d.pyi

Lines changed: 3 additions & 12 deletions
F438
Original file line numberDiff line numberDiff line change
@@ -296,15 +296,6 @@ class Backend:
296296
def _set_default_timeout(self, timeout: timedelta) -> None: ...
297297

298298
class ProcessGroup:
299-
class Options:
300-
def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
301-
@property
302-
def backend(self) -> str: ...
303-
@property
304-
def _timeout(self) -> timedelta: ...
305-
@_timeout.setter
306-
def _timeout(self, val: timedelta) -> None: ...
307-
308299
class BackendType(Enum):
309300
UNDEFINED = ...
310301
GLOO = ...
@@ -318,7 +309,6 @@ class ProcessGroup:
318309
store: Store,
319310
rank: int,
320311
size: int,
321-
options: Options,
322312
) -> None: ...
323313
def rank(self) -> int: ...
324314
def size(self) -> int: ...
@@ -508,6 +498,7 @@ class ProcessGroup:
508498
@property
509499
def _device_types(self) -> list[torch.device]: ...
510500
def _get_backend(self, device: torch.device) -> Backend: ...
501+
def _set_default_backend(self, backend_type: BackendType) -> None: ...
511502
def _register_backend(
512503
self,
513504
device: torch.device,
@@ -532,7 +523,7 @@ class ProcessGroup:
532523
class ProcessGroupGloo(Backend):
533524
class Device: ...
534525

535-
class Options(ProcessGroup.Options):
526+
class Options(Backend.Options):
536527
devices: list[ProcessGroupGloo.Device]
537528
threads: int
538529

@@ -562,7 +553,7 @@ class ProcessGroupNCCL(Backend):
562553
min_ctas: int
563554
max_ctas: int
564555

565-
class Options(ProcessGroup.Options):
556+
class Options(Backend.Options):
566557
config: ProcessGroupNCCL.NCCLConfig
567558
is_high_priority_stream: bool
568559
split_from: ProcessGroupNCCL

torch/csrc/distributed/c10d/ProcessGroup.cpp

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,6 @@
1414

1515
namespace c10d {
1616

17-
static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
18-
if (backend == "undefined") {
19-
return ProcessGroup::BackendType::UNDEFINED;
20-
} else if (backend == "gloo") {
21-
return ProcessGroup::BackendType::GLOO;
22-
} else if (backend == "nccl") {
23-
return ProcessGroup::BackendType::NCCL;
24-
} else if (backend == "ucc") {
25-
return ProcessGroup::BackendType::UCC;
26-
} else if (backend == "mpi") {
27-
return ProcessGroup::BackendType::MPI;
28-
} else {
29-
return ProcessGroup::BackendType::CUSTOM;
30-
}
31-
}
32-
3317
std::string opTypeToString(OpType opType) {
3418
switch (opType) {
3519
case OpType::BROADCAST:
@@ -119,13 +103,11 @@ c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
119103
ProcessGroup::ProcessGroup(
120104
const c10::intrusive_ptr<::c10d::Store>& store,
121105
int rank,
122-
int size,
123-
c10::intrusive_ptr<Options> options)
106+
int size)
124107
: store_(store),
125108
rank_(rank),
126109
size_(size),
127-
options_(std::move(options)),
128-
backendType_(strToBackendType(options_->backend)),
110+
backendType_(BackendType::UNDEFINED),
129111
dist_debug_level_(debug_level()) {
130112
C10_LOG_API_USAGE_ONCE("c10d.process_group");
131113
}

torch/csrc/distributed/c10d/ProcessGroup.hpp

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,6 @@ namespace c10d {
4545
//
4646
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
4747
public:
48-
// ProcessGroup Options is a base struct that defines the basic options
49-
// when constructing a ProcessGroup. Each ProcessGroup subclass should
50-
// extend this struct and define its options if it wants to provide more
51-
// config options (beyond basic ones defined here) to end user.
52-
struct TORCH_API Options : torch::CustomClassHolder {
53-
explicit Options(
54-
std::string backend,
55-
std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
56-
: timeout(timeout), backend(std::move(backend)) {}
57-
~Options() override = default;
58-
59-
std::chrono::milliseconds timeout;
60-
61-
// backend name
62-
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
63-
const std::string backend;
64-
};
65-
6648
enum BackendType : uint8_t {
6749
UNDEFINED = 0,
6850
GLOO = 1,
@@ -72,15 +54,31 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
7254
CUSTOM = 5,
7355
};
7456

57+
static std::string backendTypeToString(BackendType type) {
58+
switch (type) {
59+
case BackendType::GLOO:
60+
return "gloo";
61+
case BackendType::NCCL:
62+
return "nccl";
63+
case BackendType::UCC:
64+
return "ucc";
65+
case BackendType::MPI:
66+
return "mpi";
67+
case BackendType::UNDEFINED:
68+
return "undefined";
69+
default:
70+
return "custom";
71+
}
72+
};
73+
7574
// Not used, set for backwards compatibility and only used for TypeDef in
7675
// Ops.cpp
7776
explicit ProcessGroup(int rank, int size);
7877

7978
explicit ProcessGroup(
8079
const c10::intrusive_ptr<::c10d::Store>& store,
8180
int rank,
82-
int size,
83-
c10::intrusive_ptr<Options> options);
81+
int size);
8482
~ProcessGroup() override;
8583

8684
int getRank() const {
@@ -103,7 +101,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
103101
}
104102

105103
virtual const std::string getBackendName() const {
106-
return options_->backend;
104+
return backendTypeToString(backendType_);
107105
};
108106

109107
BackendType getBackendType() const {
@@ -609,10 +607,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
609607
opts.timeout.count());
610608
}
611609

612-
c10::intrusive_ptr<Options> getOptions() {
613-
return options_;
614-
}
615-
616610
bool hasBackends() {
617611
return !deviceTypeToBackendType_.empty();
618612
}
@@ -653,6 +647,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
653647
return backendTypeToBackend_.at(backendType_);
654648
}
655649

650+
void setDefaultBackend(const BackendType& backendType) {
651+
backendType_ = backendType;
652+
}
653+
656654
c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);
657655

658656
c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
@@ -725,9 +723,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
725723
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
726724
const int size_;
727725
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
728-
const c10::intrusive_ptr<Options> options_;
729-
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
730-
const BackendType backendType_;
726+
BackendType backendType_;
731727
std::string pg_desc_;
732728

733729
// Debug level setting. It is parsed once when ProcessGroup is constructed and

torch/csrc/distributed/c10d/init.cpp

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,8 +1814,7 @@ communication mechanism.
18141814
py::init<
18151815
const c10::intrusive_ptr<::c10d::Store>&,
18161816
int,
1817-
int,
1818-
c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
1817+
int>(),
18191818
py::call_guard<py::gil_scoped_release>())
18201819
.def("rank", &::c10d::ProcessGroup::getRank)
18211820
.def("size", &::c10d::ProcessGroup::getSize)
@@ -1825,7 +1824,6 @@ communication mechanism.
18251824
"_backend_id",
18261825
&::c10d::ProcessGroup::getBackendID,
18271826
py::arg("backend_type"))
1828-
.def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
18291827
.def(
18301828
"broadcast",
18311829
&::c10d::ProcessGroup::broadcast,
@@ -2135,6 +2133,14 @@ communication mechanism.
21352133
},
21362134
py::arg("device"),
21372135
py::call_guard<py::gil_scoped_release>())
2136+
.def(
2137+
"_set_default_backend",
2138+
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2139+
const ::c10d::ProcessGroup::BackendType& backendType) {
2140+
return self->setDefaultBackend(backendType);
2141+
},
2142+
py::arg("backend_type"),
2143+
py::call_guard<py::gil_scoped_release>())
21382144
.def(
21392145
"_register_on_completion_hook",
21402146
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
@@ -2237,27 +2243,6 @@ The hook must have the following signature:
22372243
.value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
22382244
.export_values();
22392245

2240-
// base ProcessGroup::Options binding
2241-
auto processGroupOptions =
2242-
intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
2243-
processGroup,
2244-
"Options",
2245-
R"(
2246-
Base class for all processes group options implementations, such as the nccl
2247-
options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
2248-
)")
2249-
.def(
2250-
py::init([](const std::string& backend,
2251-
const std::chrono::milliseconds& timeout) {
2252-
return c10::make_intrusive<::c10d::ProcessGroup::Options>(
2253-
backend, timeout);
2254-
}),
2255-
py::arg("backend"),
2256-
py::arg("timeout") = kProcessGroupDefaultTimeout,
2257-
py::call_guard<py::gil_scoped_release>())
2258-
.def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
2259-
.def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
2260-
22612246
// TODO: The collection definitions handles direct instantiation of
22622247
// ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported
22632248
// and should be removed once all tests are transitioned
@@ -2556,6 +2541,29 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
25562541
&::c10d::Backend::endCoalescing,
25572542
py::call_guard<py::gil_scoped_release>());
25582543

2544+
// base Backend::Options binding
2545+
// TODO: Maybe we can consider how to merge this with
2546+
// `DistributedBackendOptions`.
2547+
auto backendOptions =
2548+
intrusive_ptr_class_<::c10d::Backend::Options>(
2549+
backend,
2550+
"Options",
2551+
R"(
2552+
Base class for all backend options implementations, such as the nccl
2553+
options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
2554+
)")
2555+
.def(
2556+
py::init([](const std::string& backend,
2557+
const std::chrono::milliseconds& timeout) {
2558+
return c10::make_intrusive<::c10d::Backend::Options>(
2559+
backend, timeout);
2560+
}),
2561+
py::arg("backend"),
2562+
py::arg("timeout") = kProcessGroupDefaultTimeout,
2563+
py::call_guard<py::gil_scoped_release>())
2564+
.def_readonly("backend", &::c10d::Backend::Options::backend)
2565+
.def_readwrite("_timeout", &::c10d::Backend::Options::timeout);
2566+
25592567
#ifdef USE_C10D_GLOO
25602568
static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
25612569

@@ -2567,7 +2575,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
25672575
shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
25682576

25692577
intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
2570-
processGroupGloo, "_Options", processGroupOptions)
2578+
processGroupGloo, "_Options", backendOptions)
25712579
.def(py::init<>())
25722580
.def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
25732581
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
@@ -2794,7 +2802,7 @@ for details.
27942802
intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
27952803
processGroupNCCL,
27962804
"Options",
2797-
processGroupOptions,
2805+
backendOptions,
27982806
R"(
27992807
ProcessGroup options for the NCCL backend
28002808

torch/distributed/device_mesh.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _init_device_mesh_stub():
3636

3737

3838
else:
39+
from torch._C._distributed_c10d import Backend as C10dBackend
3940
from torch.distributed.distributed_c10d import (
4041
_find_pg_by_ranks_and_tag,
4142
_get_default_group,
@@ -66,7 +67,7 @@ def __init__(self) -> None:
6667
self.mesh_stack: List[DeviceMesh] = []
6768
self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {}
6869
self.mesh_dim_group_options: Dict[
69-
int, Tuple[str, Optional[ProcessGroup.Options]]
70+
int, Tuple[str, Optional[C10dBackend.Options]]
7071
] = {}
7172
self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
7273
# Record flatten mesh name to its mesh dim index in root mesh.
@@ -279,7 +280,7 @@ def _set_mesh_dim_group_options(
279280
self,
280281
dim: int,
281282
backend: str,
282-
pg_options: Optional[ProcessGroup.Options] = None,
283+
pg_options: Optional[C10dBackend.Options] = None,
283284
) -> None:
284285
self.mesh_dim_group_options[dim] = (backend, pg_options)
285286

0 commit comments

Comments
 (0)
0