[c10d] Expose backend Options to reflect the correct code structure · pytorch/pytorch@3e447cf · GitHub
[go: up one dir, main page]

Skip to content

Commit 3e447cf

Browse files
committed
[c10d] Expose backend Options to reflect the correct code structure
ghstack-source-id: 5e9e7a6
Pull Request resolved: #132931
1 parent 8705313 commit 3e447cf

File tree

2 files changed

+58
-41
lines changed

2 files changed

+58
-41
lines changed

torch/csrc/distributed/c10d/init.cpp

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2221,26 +2221,20 @@ The hook must have the following signature:
22212221
.value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
22222222
.export_values();
22232223

2224-
// base ProcessGroup::Options binding
2225-
auto processGroupOptions =
2226-
intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
2227-
processGroup,
2228-
"Options",
2229-
R"(
2230-
Base class for all processes group options implementations, such as the nccl
2231-
options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
2232-
)")
2233-
.def(
2234-
py::init([](const std::string& backend,
2235-
const std::chrono::milliseconds& timeout) {
2236-
return c10::make_intrusive<::c10d::ProcessGroup::Options>(
2237-
backend, timeout);
2238-
}),
2239-
py::arg("backend"),
2240-
py::arg("timeout") = kProcessGroupDefaultTimeout,
2241-
py::call_guard<py::gil_scoped_release>())
2242-
.def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
2243-
.def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
2224+
// ProcessGroup::Options binding
2225+
intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
2226+
processGroup, "Options", R"(Class for processes group options.)")
2227+
.def(
2228+
py::init([](const std::string& backend,
2229+
const std::chrono::milliseconds& timeout) {
2230+
return c10::make_intrusive<::c10d::ProcessGroup::Options>(
2231+
backend, timeout);
2232+
}),
2233+
py::arg("backend"),
2234+
py::arg("timeout") = kProcessGroupDefaultTimeout,
2235+
py::call_guard<py::gil_scoped_release>())
2236+
.def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
2237+
.def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
22442238

22452239
#ifndef _WIN32
22462240
module.def(
@@ -2556,6 +2550,27 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
25562550
&::c10d::Backend::endCoalescing,
25572551
py::call_guard<py::gil_scoped_release>());
25582552

2553+
// base Backend::Options binding
2554+
auto backendOptions =
2555+
intrusive_ptr_class_<::c10d::Backend::Options>(
2556+
backend,
2557+
"Options",
2558+
R"(
2559+
Base class for all backend options implementations, such as the nccl
2560+
options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
2561+
)")
2562+
.def(
2563+
py::init([](const std::string& backend,
2564+
const std::chrono::milliseconds& timeout) {
2565+
return c10::make_intrusive<::c10d::Backend::Options>(
2566+
backend, timeout);
2567+
}),
2568+
py::arg("backend"),
2569+
py::arg("timeout") = kProcessGroupDefaultTimeout,
2570+
py::call_guard<py::gil_scoped_release>())
2571+
.def_readonly("backend", &::c10d::Backend::Options::backend)
2572+
.def_readwrite("_timeout", &::c10d::Backend::Options::timeout);
2573+
25592574
#ifdef USE_C10D_GLOO
25602575
static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
25612576

@@ -2567,7 +2582,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
25672582
shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
25682583

25692584
intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
2570-
processGroupGloo, "_Options", processGroupOptions)
2585+
processGroupGloo, "_Options", backendOptions)
25712586
.def(py::init<>())
25722587
.def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
25732588
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
@@ -2794,7 +2809,7 @@ for details.
27942809
intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
27952810
processGroupNCCL,
27962811
"Options",
2797-
processGroupOptions,
2812+
backendOptions,
27982813
R"(
27992814
ProcessGroup options for the NCCL backend
28002815

torch/distributed/distributed_c10d.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,7 +1533,7 @@ def init_process_group(
15331533
backend,
15341534
store,
15351535
group_name,
1536-
pg_options=pg_options,
1536+
backend_options=pg_options,
15371537
timeout=timeout,
15381538
device_id=device_id,
15391539
group_desc="default_pg",
@@ -1630,7 +1630,7 @@ def _new_process_group_helper(
16301630
backend,
16311631
store,
16321632
group_name,
1633-
pg_options=None,
1633+
backend_options=None,
16341634
timeout=None,
16351635
pg_tag=None,
16361636
device_id=None,
@@ -1750,28 +1750,30 @@ def _new_process_group_helper(
17501750
elif backend_str == Backend.NCCL:
17511751
if not is_nccl_available():
17521752
raise RuntimeError("Distributed package doesn't have NCCL built in")
1753-
if pg_options is not None:
1753+
if backend_options is not None:
17541754
assert isinstance(
1755-
pg_options, ProcessGroupNCCL.Options
1756-
), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
1757-
if pg_options._timeout != timeout:
1755+
backend_options, ProcessGroupNCCL.Options
1756+
), "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
1757+
if backend_options._timeout != timeout:
17581758
warnings.warn(
1759-
"pg_options._timeout was specified, "
1759+
"backend_options._timeout was specified, "
17601760
"but timeout kwarg has a default value that will always override it. "
17611761
)
17621762
else:
1763-
# default pg_options for NCCL
1764-
pg_options = ProcessGroupNCCL.Options()
1765-
pg_options.is_high_priority_stream = False
1766-
pg_options._timeout = timeout
1763+
# default backend_options for NCCL
1764+
backend_options = ProcessGroupNCCL.Options()
1765+
backend_options.is_high_priority_stream = False
1766+
backend_options._timeout = timeout
17671767

17681768
if split_from:
1769-
pg_options.split_from = split_from
1770-
pg_options.split_color = _process_group_color(global_ranks_in_group)
1771-
pg_options.global_ranks_in_group = global_ranks_in_group
1772-
pg_options.group_name = group_name
1769+
backend_options.split_from = split_from
1770+
backend_options.split_color = _process_group_color(
1771+
global_ranks_in_group
1772+
)
1773+
backend_options.global_ranks_in_group = global_ranks_in_group
1774+
backend_options.group_name = group_name
17731775
backend_class = ProcessGroupNCCL(
1774-
backend_prefix_store, group_rank, group_size, pg_options
1776+
backend_prefix_store, group_rank, group_size, backend_options
17751777
)
17761778
backend_type = ProcessGroup.BackendType.NCCL
17771779
elif backend_str == Backend.UCC and is_ucc_available():
@@ -1806,7 +1808,7 @@ def _new_process_group_helper(
18061808
dist_backend_opts.group_id = group_name
18071809
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
18081810

1809-
backend_class = creator_fn(dist_backend_opts, pg_options)
1811+
backend_class = creator_fn(dist_backend_opts, backend_options)
18101812

18111813
# Set sequence numbers for gloo and nccl backends.
18121814
if backend_str == Backend.GLOO:
@@ -4570,7 +4572,7 @@ def _new_group_with_tag(
45704572
ranks=None,
45714573
timeout=None,
45724574
backend=None,
4573-
pg_options=None,
4575+
backend_options=None,
45744576
pg_tag=None,
45754577
use_local_synchronization=False,
45764578
group_desc=None,
@@ -4645,7 +4647,7 @@ def _new_group_with_tag(
46454647
backend,
46464648
default_store,
46474649
group_name,
4648-
pg_options=pg_options,
4650+
backend_options=backend_options,
46494651
timeout=timeout,
46504652
pg_tag=pg_tag,
46514653
device_id=device_id,

0 commit comments

Comments
 (0)
0