Update · pytorch/pytorch@0a437b7 · GitHub

Commit 0a437b7

Update
[ghstack-poisoned]
2 parents 71e0a33 + f342255 commit 0a437b7

File tree

6 files changed: +135 -18 lines


.ci/manywheel/build_common.sh

Lines changed: 2 additions & 2 deletions
@@ -321,8 +321,8 @@ for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.w
         # ROCm workaround for roctracer dlopens
         if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then
             patchedpath=$(fname_without_so_number $destpath)
-        # Keep the so number for XPU dependencies
-        elif [[ "$DESIRED_CUDA" == *"xpu"* ]]; then
+        # Keep the so number for XPU dependencies and libgomp.so.1 to avoid twice load
+        elif [[ "$DESIRED_CUDA" == *"xpu"* || "$filename" == "libgomp.so.1" ]]; then
             patchedpath=$destpath
         else
             patchedpath=$(fname_with_sha256 $destpath)
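
Why keeping the original file name matters here: dlopen reuses a library that is already loaded under the same name, while a sha256-renamed copy is treated as a new library, so a second OpenMP runtime would be initialized in the same process. A minimal sketch of that behavior; the system path and the renamed file name below are illustrative, not the actual wheel contents:

    import ctypes
    import os
    import shutil
    import tempfile

    SYSTEM_GOMP = "/usr/lib64/libgomp.so.1"  # assumed path, as in check_gomp.py below

    # Loading the same file name twice reuses the already-loaded object ...
    h1 = ctypes.CDLL(SYSTEM_GOMP)
    h2 = ctypes.CDLL(SYSTEM_GOMP)
    print(h1._handle == h2._handle)  # True: a single libgomp instance

    # ... but a renamed copy (in the spirit of what fname_with_sha256 produces;
    # the exact name here is made up) is a different library, so a second,
    # independent libgomp would come up next to the first one.
    renamed = os.path.join(tempfile.mkdtemp(), "libgomp-deadbeef.so.1")
    shutil.copy(SYSTEM_GOMP, renamed)
    h3 = ctypes.CDLL(renamed)
    print(h1._handle == h3._handle)  # False: two libgomp copies in one process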

.ci/pytorch/smoke_test/check_gomp.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import ctypes
+import os
+import sys
+from pathlib import Path
+
+
+def get_gomp_thread():
+    """
+    Retrieves the maximum number of OpenMP threads after loading the `libgomp.so.1` library
+    and the `libtorch_cpu.so` library. It then queries the
+    maximum number of threads available for OpenMP parallel regions using the
+    `omp_get_max_threads` function.
+
+    Returns:
+        int: The maximum number of OpenMP threads available.
+
+    Notes:
+        - The function assumes the default path for `libgomp.so.1` on AlmaLinux OS.
+        - The path to `libtorch_cpu.so` is constructed based on the Python executable's
+          installation directory.
+        - This function is specific to environments where PyTorch and OpenMP are used
+          together and may require adjustments for other setups.
+    """
+    python_path = Path(sys.executable).resolve()
+    python_prefix = (
+        python_path.parent.parent
+    )  # Typically goes to the Python installation root
+
+    # Get the additional ABI flags (if any); it may be an empty string.
+    abiflags = getattr(sys, "abiflags", "")
+
+    # Construct the Python directory name correctly (e.g., "python3.13t").
+    python_version = (
+        f"python{sys.version_info.major}.{sys.version_info.minor}{abiflags}"
+    )
+
+    libtorch_cpu_path = (
+        python_prefix
+        / "lib"
+        / python_version
+        / "site-packages"
+        / "torch"
+        / "lib"
+        / "libtorch_cpu.so"
+    )
+
+    # use the default gomp path of AlmaLinux OS
+    libgomp_path = "/usr/lib64/libgomp.so.1"
+
+    os.environ["GOMP_CPU_AFFINITY"] = "0-3"
+
+    libgomp = ctypes.CDLL(libgomp_path)
+    libgomp = ctypes.CDLL(libtorch_cpu_path)
+
+    libgomp.omp_get_max_threads.restype = ctypes.c_int
+    libgomp.omp_get_max_threads.argtypes = []
+
+    omp_max_threads = libgomp.omp_get_max_threads()
+    return omp_max_threads
+
+
+def main():
+    omp_max_threads = get_gomp_thread()
+    print(
+        f"omp_max_threads after loading libgomp.so and libtorch_cpu.so: {omp_max_threads}"
+    )
+    if omp_max_threads == 1:
+        raise RuntimeError(
+            "omp_max_threads is 1. Check whether libgomp.so is loaded twice."
+        )
+
+
+if __name__ == "__main__":
+    main()
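
The script sets GOMP_CPU_AFFINITY, loads the system libgomp.so.1 and then libtorch_cpu.so, and fails if omp_get_max_threads() comes back as 1, the symptom flagged in its error message when libgomp ends up loaded twice. A complementary, Linux-only diagnostic sketch (it assumes torch is importable in the current environment): list every libgomp mapped into the process after importing torch.

    import torch  # noqa: F401  -- pulls in libtorch_cpu.so and its dependencies

    with open("/proc/self/maps") as maps:
        gomp_paths = sorted({line.split()[-1] for line in maps if "libgomp" in line})
    # Expect a single path here; two distinct entries would indicate a
    # duplicated OpenMP runtime in the process.
    print(gomp_paths)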

.circleci/scripts/binary_linux_test.sh

Lines changed: 5 additions & 0 deletions
@@ -101,6 +101,11 @@ if [[ "\$GPU_ARCH_TYPE" != *s390x* && "\$GPU_ARCH_TYPE" != *xpu* && "\$GPU_ARCH_
   else
     python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled $extra_parameters
   fi
+
+  if [[ "\$GPU_ARCH_TYPE" != *cpu-aarch64* ]]; then
+    # https://github.com/pytorch/pytorch/issues/149422
+    python /pytorch/.ci/pytorch/smoke_test/check_gomp.py
+  fi
 fi

 # Clean temp files

torch/csrc/distributed/c10d/nvshmem_extension.cu

Lines changed: 32 additions & 9 deletions
@@ -202,22 +202,30 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_spli
   auto source_offsets = in_out_splits + npes * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
+  int blocks_per_peer = max(gridDim.x / npes, 1);

   // Calculate the output offsets
   __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
   prefixSum(peer_offsets, output_splits, npes);
   __syncthreads();

-  // Each block targets a different peer
-  for (int i = bid; i < npes; i += gridDim.x) {
+  // Target a different peer based on bid
+  for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) {
     int peer = (mype + i) % npes;
-    auto size = output_splits[peer] * stride;
-    auto source_offset = source_offsets[peer] * stride;
-    auto write_offset = peer_offsets[peer] * stride;
+    // Total amount from `peer`
+    auto peer_size = output_splits[peer] * stride;
+    // Amount to get from `peer` in this block
+    auto block_size = peer_size / blocks_per_peer;
+    // Being lazy here, we should handle the residual if the division is not exact
+    CUDA_KERNEL_ASSERT(block_size * blocks_per_peer == peer_size);
+    // This block's offset in the data from `peer`
+    auto block_offset = block_size * (bid % blocks_per_peer);
+    auto source_offset = source_offsets[peer] * stride + block_offset;
+    auto write_offset = peer_offsets[peer] * stride + block_offset;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
         (char*)send_data + source_offset,
-        size,
+        block_size,
         peer);
   }
   // Write out the output offsets (to the scratchpad line)
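
The new indexing assigns blocks_per_peer consecutive blocks to each peer instead of one block per peer. A small host-side sketch of the same arithmetic; the grid size, peer count, and rank are made-up values, and the kernel performs this per thread block:

    # Host-side sketch of the block-to-peer mapping above (illustrative values).
    grid_dim_x, npes, mype = 16, 4, 0            # 16 CTAs, 4 peers, this rank is 0
    blocks_per_peer = max(grid_dim_x // npes, 1)

    for bid in range(grid_dim_x):
        i = bid // blocks_per_peer               # peer index this block starts from
        peer = (mype + i) % npes                 # rotate peers so ranks spread load
        chunk = bid % blocks_per_peer            # which slice of that peer's data
        print(f"block {bid:2d} -> peer {peer}, slice {chunk} of {blocks_per_peer}")

    # Each peer's data is split into blocks_per_peer equal slices, and each block
    # copies its slice at offset block_size * chunk (the kernel asserts the split
    # is exact).
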
@@ -266,11 +274,26 @@ at::Tensor nvshmem_all_to_all_vdev(
       0,
       stream);

-  // All to all data exchange
-  // Limit the number of blocks to 16
-  int num_blocks = std::min(world_size, 16);
+  // CTA Tuning
+  // Intra-node: use multiple blocks per peer to increase data parallelism, up to 8.
+  // Up to 1 MB -> 1 block
+  // Up to 2 MB -> 2 blocks
+  // Up to 4 MB -> 4 blocks
+  // More -> 8 blocks
+  auto input_size = input.numel() * input.element_size();
+  const int max_blocks_per_peer = input_size < 1024 * 1024 ? 1 :
+      (input_size < 2 * 1024 * 1024 ? 2 :
+      (input_size < 4 * 1024 * 1024 ? 4 : 8));
+
+  // Inter-node: limit the total the number of blocks to 8 which is able to
+  // drive 57 GB/s bandwidth in test, enough to drive a 400 Gb/s NIC.
+  // TODO: better intra vs inter detection, currently it is based on world_size
+  int num_blocks = world_size > 8 ? 8 : max_blocks_per_peer * world_size;
+
   // Stride at dim 0 (assuming input is contiguous, TODO)
   size_t stride_bytes = input.stride(0) * input.element_size();
+
+  // All to all data exchange
   void* args1[] = {
     &input_ptr,
     &output_ptr,
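
To sanity-check the resulting launch sizes, the helper below just re-implements the thresholds from the hunk above for illustration; the values in the example calls are arbitrary:

    # Re-implementation of the CTA-count rule above, for illustration only.
    def num_blocks(input_size_bytes: int, world_size: int) -> int:
        mb = 1024 * 1024
        if input_size_bytes < mb:
            max_blocks_per_peer = 1
        elif input_size_bytes < 2 * mb:
            max_blocks_per_peer = 2
        elif input_size_bytes < 4 * mb:
            max_blocks_per_peer = 4
        else:
            max_blocks_per_peer = 8
        return 8 if world_size > 8 else max_blocks_per_peer * world_size

    print(num_blocks(512 * 1024, 8))        # 8: intra-node, small input -> 1 block per peer
    print(num_blocks(8 * 1024 * 1024, 8))   # 64: intra-node, large input -> 8 blocks per peer
    print(num_blocks(8 * 1024 * 1024, 16))  # 8: world_size > 8 is treated as inter-node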

torch/utils/viz/MemoryViz.js

Lines changed: 4 additions & 1 deletion
@@ -1228,6 +1228,7 @@ function create_trace_view(
   dst.selectAll('svg').remove();
   dst.selectAll('div').remove();

+  max_entries = Math.min(max_entries, data.elements_length);
   const d = dst.append('div');
   d.append('input')
     .attr('type', 'range')
@@ -1237,7 +1238,9 @@
     .on('change', function () {
       create_trace_view(dst, snapshot, device, plot_segments, this.value);
     });
-  d.append('label').text('Detail');
+  d.append('label').text(
+    `Detail: ${max_entries} of ${data.elements_length} entries`,
+  );

   const grid_container = dst
     .append('div')

torchgen/utils.py

Lines changed: 18 additions & 6 deletions
@@ -10,8 +10,8 @@
 from dataclasses import fields, is_dataclass
 from enum import auto, Enum
 from pathlib import Path
-from typing import Any, Callable, Generic, Literal, TYPE_CHECKING, TypeVar
-from typing_extensions import assert_never, Self
+from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
+from typing_extensions import assert_never, deprecated, Self

 from torchgen.code_template import CodeTemplate

@@ -97,6 +97,15 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
         raise


+if TYPE_CHECKING:
+    # A little trick from https://github.com/python/mypy/issues/6366
+    # for getting mypy to do exhaustiveness checking
+    # TODO: put this somewhere else, maybe
+    @deprecated("Use typing_extensions.assert_never instead")
+    def assert_never(x: NoReturn) -> NoReturn:  # type: ignore[misc] # noqa: F811
+        raise AssertionError(f"Unhandled type: {type(x).__name__}")
+
+
 @functools.cache
 def _read_template(template_fn: str) -> CodeTemplate:
     return CodeTemplate.from_file(template_fn)
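
For context, the trick being kept here (now marked deprecated in favor of typing_extensions.assert_never) is the usual exhaustiveness-checking pattern. A generic illustration, not code from this repository:

    from enum import auto, Enum
    from typing_extensions import assert_never

    class Target(Enum):        # made-up enum for illustration
        declaration = auto()
        definition = auto()

    def render(t: Target) -> str:
        if t is Target.declaration:
            return "decl"
        elif t is Target.definition:
            return "defn"
        else:
            # If a new member is added and not handled above, the argument here
            # is no longer of type NoReturn and mypy reports an error.
            assert_never(t)
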
@@ -173,14 +182,17 @@ def substitute_with_template(
     }
     template = _read_template(template_path)
     substitute_out = template.substitute(env)
-    # Ensure an extra blank line before the class/function definition
-    # if it is followed by a docstring
+    # Ensure an extra blank line between the class/function definition
+    # and the docstring of the previous class/function definition.
+    # NB: It is generally not recommended to have docstrings in pyi stub
+    # files. But if there are any, we need to ensure that the file
+    # is properly formatted.
     return re.sub(
         r'''
         (""")\n+            # match triple quotes
         (
-            ([ ]*@.+\n)*    # match decorators if any
-            [ ]*(class|def) # match class/function definition
+            (\s*@.+\n)*     # match decorators if any
+            \s*(class|def)  # match class/function definition
         )
         ''',
         r"\g<1>\n\n\g<2>",

0 commit comments
