[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure · Pull Request #132931 · pytorch/pytorch

Closed · wants to merge 10 commits

Changes from 1 commit
Update on "[c10d] Remove Option for ProcessGroup and Expose backend O…
…ptions to reflect the correct code structure"


We introduced the dispatchable backend for a ProcessGroup and its collectives in #86225. This PR is a follow-up cleanup that removes the Options of a ProcessGroup and asks users either to set the timeout or backend later on, or to create the backend directly after creating a PG.

Also, ProcessGroupNCCL has been using the Options class from ProcessGroup, but it should actually use the Options from the Backend class. This PR aligns the type and name with what we do on the C++ side. The signature of the public API is unchanged, so it still takes an argument named "pg_options".
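A minimal sketch of the resulting user-facing pattern (an illustration, not code from this PR; it assumes a CUDA build with NCCL available and env:// rendezvous variables such as MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE already set):

from datetime import timedelta

import torch.distributed as dist
from torch.distributed import ProcessGroupNCCL

# Options now come from the backend class (ProcessGroupNCCL.Options),
# not from ProcessGroup; the public argument keeps its old name.
opts = ProcessGroupNCCL.Options()
opts.is_high_priority_stream = False   # NCCL-specific knob
opts._timeout = timedelta(minutes=5)   # timeout lives on the backend options
dist.init_process_group("nccl", pg_options=opts)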


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin wz337 wconstab d4l3k c-p-i-o ezyang gchanan

[ghstack-poisoned]
fduwjj committed Aug 28, 2024
commit 144800c73f23b31e956b7a36a258693099dfee6e
1 change: 1 addition & 0 deletions test/allowlist_for_publicAPI.json
@@ -63,6 +63,7 @@
"BroadcastOptions",
"BuiltinCommHookType",
"Callable",
"C10dBackend",
"DebugLevel",
"Dict",
"Enum",
9 changes: 8 additions & 1 deletion test/distributed/test_c10d_common.py
@@ -1815,6 +1815,7 @@ def test_init_process_group_optional_backend(self):

def test_init_process_group_for_all_backends(self):
for backend in dist.Backend.backend_list:
expected_backend = backend
# skip if the backend is not available on the system
if backend == dist.Backend.UNDEFINED:
continue
@@ -1830,6 +1831,11 @@ def test_init_process_group_for_all_backends(self):
elif backend == dist.Backend.UCC:
if not dist.is_ucc_available():
continue
# Multi-threaded PG is defined as a pure Python class.
# Its pg.name() does not go through pybind, so its backend name
# is still "threaded" instead of "custom".
elif backend != "threaded":
expected_backend = "custom"

with tempfile.NamedTemporaryFile(delete=False) as f:
store = dist.FileStore(f.name, self.world_size)
@@ -1842,7 +1848,8 @@ def test_init_process_group_for_all_backends(self):
pg = c10d._get_default_group()
self.assertEqual(pg.rank(), self.rank)
self.assertEqual(pg.size(), self.world_size)
self.assertEqual(pg.name(), str(backend))
self.assertEqual(pg.name(), str(expected_backend))

dist.destroy_process_group()

3 changes: 2 additions & 1 deletion test/distributed/test_device_mesh.py
@@ -232,7 +232,8 @@ def test_set_mesh_dim_group_options(self):

mesh_tensor = torch.arange(4).reshape(2, 2)
mesh = DeviceMesh(device_type, mesh_tensor)
self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake")
# The fake PG only has BackendType::CUSTOM as its backend type.
self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")


class DeviceMeshTestNDim(DTensorTestBase):
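A runnable single-rank sketch of what these two tests now pin down (assuming only the built-in gloo backend, since the "threaded" and "fake" setups require test-internal registration):

import tempfile

import torch.distributed as dist
from torch.distributed import distributed_c10d as c10d

with tempfile.NamedTemporaryFile(delete=False) as f:
    store = dist.FileStore(f.name, 1)
    dist.init_process_group("gloo", store=store, rank=0, world_size=1)
    pg = c10d._get_default_group()
    # Built-in backends keep their names; a backend registered via
    # dist.Backend.register_backend would now report "custom" instead
    # of its registered string.
    assert pg.name() == "gloo"
    dist.destroy_process_group()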
1 change: 1 addition & 0 deletions torch/_C/_distributed_c10d.pyi
@@ -498,6 +498,7 @@ class ProcessGroup:
@property
def _device_types(self) -> list[torch.device]: ...
def _get_backend(self, device: torch.device) -> Backend: ...
def _set_default_backend(self, backend_type: BackendType) -> None: ...
def _register_backend(
self,
device: torch.device,
9 changes: 7 additions & 2 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -65,8 +65,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
case BackendType::MPI:
return "mpi";
case BackendType::UNDEFINED:
default:
return "undefined";
default:
return "custom";
}
};

@@ -646,6 +647,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return backendTypeToBackend_.at(backendType_);
}

void setDefaultBackend(const BackendType& backendType) {
backendType_ = backendType;
}

c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);

c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
@@ -718,7 +723,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int size_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const BackendType backendType_;
BackendType backendType_;
std::string pg_desc_;

// Debug level setting. It is parsed once when ProcessGroup is constructed and
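For reference, a Python mirror of the C++ backendTypeToString mapping above (a sketch; the gloo/nccl/ucc cases sit outside the visible hunk and are inferred from the surrounding switch):

from torch.distributed import ProcessGroup

_BT = ProcessGroup.BackendType
_BUILTIN_NAMES = {
    _BT.GLOO: "gloo",
    _BT.NCCL: "nccl",
    _BT.UCC: "ucc",
    _BT.MPI: "mpi",
    _BT.UNDEFINED: "undefined",
}

def backend_type_to_string(bt) -> str:
    # Anything that is not a built-in backend type now stringifies to
    # "custom" instead of falling through to "undefined".
    return _BUILTIN_NAMES.get(bt, "custom")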
8 changes: 8 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -2130,6 +2130,14 @@ communication mechanism.
},
py::arg("device"),
py::call_guard<py::gil_scoped_release>())
.def(
"_set_default_backend",
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
const ::c10d::ProcessGroup::BackendType& backendType) {
return self->setDefaultBackend(backendType);
},
py::arg("backend_type"),
py::call_guard<py::gil_scoped_release>())
.def(
"_register_on_completion_hook",
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
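A sketch of the new private binding in use (continuing the single-rank gloo example above; this mirrors the call the PR adds to _new_process_group_helper):

from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d

pg = c10d._get_default_group()
# Sets the PG's default backend type; pg.name() follows it from now on.
pg._set_default_backend(ProcessGroup.BackendType.GLOO)
assert pg.name() == "gloo"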
8 changes: 6 additions & 2 deletions torch/distributed/distributed_c10d.py
@@ -145,7 +145,6 @@ def _export_c_types() -> None:
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
C10dBackend,
GatherOptions,
PrefixStore,
ProcessGroup,
@@ -269,6 +268,7 @@ class Backend(str):
GLOO: ProcessGroup.BackendType.GLOO,
NCCL: ProcessGroup.BackendType.NCCL,
UCC: ProcessGroup.BackendType.UCC,
MPI: ProcessGroup.BackendType.MPI,
}

def __new__(cls, name: str):
@@ -1714,6 +1714,8 @@ def _new_process_group_helper(
group_rank,
group_size,
)
assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
pg._set_default_backend(Backend.backend_type_map[backend])
Member commented:

If backend isn't provided (e.g. backend=None), then the backend type will be undefined, right? Will this break the APIs that depend on looking at the backend type, like getSequenceNumber or the hooks that we talked about?

I'm not sure if we have tests that cover these cases, but it could be worthwhile to check.

Contributor (author) replied:

The backend is not optional for this function. Do we see a use case where backend=None? This logic will break when the backend is something like "cpu:gloo,gpu:nccl".
if device_id:
pg.bound_device_id = device_id
backend_config = BackendConfig(backend)
@@ -4451,7 +4453,9 @@ def split_group(
group_rank,
len(my_group),
)
backend_type = ProcessGroup.BackendType.NCCL
pg.bound_device_id = device_id
pg._set_default_backend(backend_type)

pg_options._timeout = timeout
pg_options.split_from = parent_backend
@@ -4461,7 +4465,6 @@
backend_class = ProcessGroupNCCL(
prefix_store, group_rank, len(my_group), pg_options
)
backend_type = ProcessGroup.BackendType.NCCL
backend_class._set_sequence_number_for_group()

pg._register_backend(torch.device("cuda"), backend_type, backend_class)
@@ -4608,6 +4611,7 @@ def _new_group_with_tag(
if not backend:
backend = default_backend
backend = Backend(backend)

# this timeout defaulting/validation is used for all the new_groups/new_subgroups variants,
# which may just pass their timeout value (or None)