@@ -15,43 +15,44 @@
 
 using namespace at;
 using namespace torch;
-using namespace torch::cuda::nccl;
-using namespace torch::cuda::nccl::detail;
+using namespace torch::cuda;
+
+namespace pynccl = torch::cuda::nccl;
 
 static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";
 
 PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
-  return PyLong_FromUnsignedLongLong(version());
+  return PyLong_FromUnsignedLongLong(pynccl::version());
 }
 
 PyObject* THCPModule_nccl_version_suffix(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  return PyBytes_FromString(version_suffix());
+  return PyBytes_FromString(pynccl::version_suffix());
   END_HANDLE_TH_ERRORS
 }
 
 PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
-  ncclUniqueId id;
+  pynccl::ncclUniqueId id;
   get_unique_id(id);
   return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
   END_HANDLE_TH_ERRORS
 }
 
-static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
-  ncclComm_t comm =
-      (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
+static pynccl::ncclComm_t unpack_nccl_comm(PyObject* capsule) {
+  pynccl::ncclComm_t comm =
+      (pynccl::ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
   if (!comm)
     throw python_error();
   return comm;
 }
 
 static void destroy_nccl_comm(PyObject* capsule) {
   HANDLE_TH_ERRORS
-  ncclComm_t comm = unpack_nccl_comm(capsule);
+  pynccl::ncclComm_t comm = unpack_nccl_comm(capsule);
   {
     pybind11::gil_scoped_release no_gil;
-    comm_destroy(comm);
+    pynccl::comm_destroy(comm);
   }
   END_HANDLE_TH_ERRORS_RET()
 }
@@ -73,19 +74,19 @@ static std::vector<std::optional<at::cuda::CUDAStream>> unpack_streams(
 static at::Tensor extract_tensor(PyObject* obj);
 static std::vector<at::Tensor> extract_tensors(PyObject* obj);
 
-static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
+static std::vector<pynccl::ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
   if (obj == Py_None) {
-    return std::vector<ncclComm_t>();
+    return std::vector<pynccl::ncclComm_t>();
   }
-  std::vector<ncclComm_t> comms;
+  std::vector<pynccl::ncclComm_t> comms;
   if (PyCapsule_CheckExact(obj)) {
     comms = {unpack_nccl_comm(obj)};
   } else {
     auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
     if (!seq)
       throw python_error();
     auto size = PySequence_Fast_GET_SIZE(seq.get());
-    comms = std::vector<ncclComm_t>(size);
+    comms = std::vector<pynccl::ncclComm_t>(size);
     for (const auto i : c10::irange(size)) {
       comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
     }
@@ -116,12 +117,12 @@ PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
       id_len,
       ")");
 
-  ncclUniqueId commId;
+  pynccl::ncclUniqueId commId;
   memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
-  ncclComm_t comm = nullptr;
+  pynccl::ncclComm_t comm = nullptr;
   {
     pybind11::gil_scoped_release no_gil;
-    comm = comm_init_rank(nranks, commId, rank);
+    comm = pynccl::comm_init_rank(nranks, commId, rank);
   }
   return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
   END_HANDLE_TH_ERRORS
@@ -153,7 +154,7 @@ PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
 
   {
     pybind11::gil_scoped_release no_gil;
-    torch::cuda::nccl::reduce(inputs, output, root, op, streams, user_comms);
+    pynccl::reduce(inputs, output, root, op, streams, user_comms);
   }
 
   Py_RETURN_NONE;
@@ -186,7 +187,7 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
 
   {
     pybind11::gil_scoped_release no_gil;
-    all_reduce(inputs, outputs, op, streams, user_comms);
+    pynccl::all_reduce(inputs, outputs, op, streams, user_comms);
   }
 
   Py_RETURN_NONE;
@@ -217,7 +218,7 @@ PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
 
   {
     pybind11::gil_scoped_release no_gil;
-    torch::cuda::nccl::broadcast(inputs, streams, user_comms);
+    pynccl::broadcast(inputs, streams, user_comms);
   }
 
   Py_RETURN_NONE;
@@ -249,7 +250,7 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
 
   {
     pybind11::gil_scoped_release no_gil;
-    all_gather(inputs, outputs, streams, user_comms);
+    pynccl::all_gather(inputs, outputs, streams, user_comms);
   }
 
   Py_RETURN_NONE;
@@ -282,7 +283,7 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
 
   {
     pybind11::gil_scoped_release no_gil;
-    reduce_scatter(inputs, outputs, op, streams, user_comms);
+    pynccl::reduce_scatter(inputs, outputs, op, streams, user_comms);
   }
 
   Py_RETURN_NONE;
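
The change above replaces blanket `using namespace torch::cuda::nccl;` / `::detail` directives with a `pynccl` namespace alias and qualifies every NCCL helper call through it. Below is a minimal standalone sketch of that pattern using hypothetical toy namespaces (not the real torch::cuda::nccl API), just to show how an alias keeps call sites explicit without pulling every symbol into the enclosing scope:

#include <cstdio>

// Toy stand-in for a nested library namespace (hypothetical names, not torch::cuda::nccl).
namespace mylib { namespace gpu { namespace comm {
using commHandle_t = int;
inline commHandle_t comm_init_rank(int nranks, int rank) { return nranks * 100 + rank; }
inline unsigned long long version() { return 21804ULL; }
}}} // namespace mylib::gpu::comm

// Instead of `using namespace mylib::gpu::comm;`, introduce one short alias...
namespace pycomm = mylib::gpu::comm;

int main() {
  // ...and qualify each call through it, mirroring the pynccl:: qualifications in the diff.
  pycomm::commHandle_t comm = pycomm::comm_init_rank(/*nranks=*/2, /*rank=*/0);
  std::printf("version=%llu comm=%d\n", pycomm::version(), comm);
  return 0;
}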