update _unsafe_set_version_counter to accept lists of tensors (#137921) · pytorch/pytorch@e68f508 · GitHub
Commit e68f508

bdhirsh authored and pytorchmergebot committed
update _unsafe_set_version_counter to accept lists of tensors (#137921)
See the comment [here](#132014 (comment)) (cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @XilunWu @rec).

This PR updates `_unsafe_set_version_counter` to accept a list of tensors, for overhead-sensitive users (e.g. distributed) who need to hide version-counter bumps from autograd across a large list of tensors without paying the Python->C++ transition cost separately for each tensor in the list. I left the binding in pybind and used a `std::vector`; if we **really** need to optimize overhead even further, we could write a manual cpython binding.

I use this updated API in the next PR to fix FSDP2, so that it properly hides the version counters of all `all_gather_buffer` tensors in its call to `split_with_sizes_copy.out(all_gather_buffers)`.

Pull Request resolved: #137921
Approved by: https://github.com/awgu, https://github.com/albanD
1 parent 425aca4 · commit e68f508

File tree: 9 files changed, +94 -49 lines
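For orientation before the per-file diffs, here is a minimal before/after sketch of the API change. This is illustrative only; the pybind `std::vector` parameters accept any Python sequence, so a list or a tuple both work:

    import torch

    tensors = [torch.zeros(4) for _ in range(3)]
    versions = [t._version for t in tensors]
    for t in tensors:
        t.add_(1)  # each in-place op bumps the tensor's version counter

    # Before this PR: one Python->C++ crossing per tensor
    #   for t, v in zip(tensors, versions):
    #       torch._C._autograd._unsafe_set_version_counter(t, v)

    # After this PR: a single crossing for the whole batch
    torch._C._autograd._unsafe_set_version_counter(tensors, versions)
    assert all(t._version == v for t, v in zip(tensors, versions))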

test/inductor/test_distributed_patterns.py (+2, -2)

@@ -242,7 +242,7 @@ def fn(w, x):
     x = x.sin()
     v = w._version
     w.copy_(x + 1)
-    torch._C._autograd._unsafe_set_version_counter(w, v)
+    torch._C._autograd._unsafe_set_version_counter((w,), (v,))
     return w, v

 for v in (3, 0, 1):
@@ -266,7 +266,7 @@ def fn(w, x):
     with torch.no_grad():
         v = w._version
         w.copy_(x)
-        torch._C._autograd._unsafe_set_version_counter(w, v)
+        torch._C._autograd._unsafe_set_version_counter((w,), (v,))
     return r

 w1 = torch.randn(1, requires_grad=True)

test/test_autograd.py (+10, -2)

@@ -4799,10 +4799,18 @@ def test_unsafe_set_version_counter(self):
         # version counter doesn't change inside of the context manager
         self.assertEqual(2, x._version)

-        torch._C._autograd._unsafe_set_version_counter(x, 0)
+        torch._C._autograd._unsafe_set_version_counter((x,), (0,))
         self.assertEqual(0, x._version)
         with self.assertRaisesRegex(RuntimeError, "Cannot set"):
-            torch._C._autograd._unsafe_set_version_counter(x, -1)
+            torch._C._autograd._unsafe_set_version_counter((x,), (-1,))
+
+        y = torch.ones(2, requires_grad=True).clone()
+        with torch.autograd._unsafe_preserve_version_counter((x, y)):
+            x.mul_(2)
+            y.mul_(3)
+        # version counter doesn't change inside of the context manager
+        self.assertEqual(0, x._version)
+        self.assertEqual(0, y._version)

     def test_current_node(self):
         pr = []

torch/_C/_autograd.pyi (+3, -1)

@@ -115,7 +115,9 @@ def _push_saved_tensors_default_hooks(
     unpack_hook: Callable[[Any], torch.Tensor],
 ) -> None: ...
 def _pop_saved_tensors_default_hooks() -> None: ...
-def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
+def _unsafe_set_version_counter(
+    t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
+) -> None: ...
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
 def _profiler_type() -> ActiveProfilerType: ...

torch/_dynamo/tensor_version_op.py (+3, -3)

@@ -25,7 +25,7 @@ def _tensor_version_fake(fake_mode, self_tensor):


 _unsafe_set_version_counter = _make_prim(
-    schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()",
+    schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()",
     return_type=RETURN_TYPE.NEW,
     meta=lambda self, version: None,
     impl_aten=torch._C._autograd._unsafe_set_version_counter,
@@ -55,5 +55,5 @@ def _tensor_version_functional(mode, self):


 @_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
-def _unsafe_set_version_counter_functional(ctx, self, version):
-    torch._C._autograd._unsafe_set_version_counter(self, version)
+def _unsafe_set_version_counter_functional(ctx, tensors, versions):
+    torch._C._autograd._unsafe_set_version_counter(tensors, versions)

torch/_dynamo/variables/builtin.py (+1, -1)

@@ -1872,7 +1872,7 @@ def _lower_version_count_by_1(x):
     version = x._version
     if version > 0:
         version = version - 1
-    torch._C._autograd._unsafe_set_version_counter(x, version)
+    torch._C._autograd._unsafe_set_version_counter((x,), (version,))
     return x

 tx.output.create_proxy(

torch/_dynamo/variables/ctx_manager.py (+26, -7)

@@ -975,20 +975,39 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
     Wraps torch.autograd._unsafe_preserve_version_counter
     """

+    @staticmethod
+    def _create_lambda_from_tensors(tx, tensors):
+        if isinstance(tensors, variables.TensorVariable):
+            versions = variables.TupleVariable(
+                [x.var_getattr(tx, "_version") for x in [tensors]]
+            )
+            tensors = variables.TupleVariable([tensors])
+        else:
+            versions = variables.TupleVariable(
+                [x.var_getattr(tx, "_version") for x in tensors.items]
+            )
+        return PreserveVersionContextVariable(tensors, versions)
+
     @staticmethod
     def constructor(tx):
         return variables.LambdaVariable(
-            lambda tensor: PreserveVersionContextVariable(
-                tensor,
-                tensor.var_getattr(tx, "_version"),
+            lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors(
+                tx, tensors
             )
         )

-    def __init__(self, tensor, prev_version, **kwargs) -> None:
+    def __init__(self, tensors, prev_versions, **kwargs) -> None:
         kwargs.setdefault("target_values", None)
         super().__init__(**kwargs)
-        self.tensor = tensor
-        self.prev_version = prev_version
+        self.tensors = tensors
+        self.prev_versions = prev_versions
+        # The context manager accepts Union[Tensor, Tuple[Tensor]]
+        if isinstance(self.tensors, variables.TensorVariable):
+            self.tensors = variables.TupleVariable([self.tensors])
+        if isinstance(
+            self.prev_versions, (variables.ConstantVariable, variables.SymNodeVariable)
+        ):
+            self.prev_versions = variables.TupleVariable([self.prev_versions])

     def enter(self, tx):
         pass
@@ -998,7 +1017,7 @@ def exit(self, tx: "InstructionTranslator", *args):

         return variables.TorchInGraphFunctionVariable(
             _unsafe_set_version_counter
-        ).call_function(tx, [self.tensor, self.prev_version], {})
+        ).call_function(tx, [self.tensors, self.prev_versions], {})

     def reconstruct(self, codegen):
         unimplemented(
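With the normalization above, `PreserveVersionContextVariable` handles both the single-tensor and tuple call forms under compilation. A hedged sketch of the user-level pattern this enables (it may still graph-break in some configurations, since `reconstruct` is unimplemented):

    import torch

    @torch.compile(backend="eager")
    def fn(x, y):
        with torch.autograd._unsafe_preserve_version_counter((x, y)):
            x.mul_(2)  # version bumps are hidden from autograd
            y.mul_(3)
        return x + y

    x = torch.ones(2).clone()
    y = torch.ones(2).clone()
    fn(x, y)
    assert x._version == 0 and y._version == 0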

torch/autograd/grad_mode.py (+6, -5)

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any
+from typing import Any, Tuple, Union

 import torch
 from torch.utils._contextlib import (
@@ -386,12 +386,13 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):

     """

-    def __init__(self, tensor: torch.Tensor) -> None:
-        self.tensor = tensor
-        self.prev_version = tensor._version
+    def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
+        self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
+        assert isinstance(self.tensors, tuple)
+        self.prev_versions = tuple(t._version for t in self.tensors)

     def __enter__(self) -> None:
         pass

     def __exit__(self, *args) -> None:
-        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
+        torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
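In eager mode, the updated context manager keeps the old single-tensor form working while also accepting a tuple; a minimal usage sketch mirroring the new test:

    import torch

    x = torch.ones(2, requires_grad=True).clone()
    y = torch.ones(2, requires_grad=True).clone()

    with torch.autograd._unsafe_preserve_version_counter((x, y)):
        x.mul_(2)  # in-place writes whose version bumps are hidden
        y.mul_(3)
    assert x._version == 0 and y._version == 0

    with torch.autograd._unsafe_preserve_version_counter(x):  # single-tensor form
        x.add_(1)
    assert x._version == 0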

torch/csrc/autograd/init.cpp (+17, -4)

@@ -388,10 +388,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
     return activities;
   });

-  m.def("_unsafe_set_version_counter", [](const at::Tensor& t, int64_t i) {
-    auto vc = torch::autograd::impl::version_counter(t);
-    vc.set_version(i);
-  });
+  m.def(
+      "_unsafe_set_version_counter",
+      [](const std::vector<at::Tensor>& tensors,
+         const std::vector<int64_t>& versions) {
+        auto tensors_len = tensors.size();
+        auto versions_len = versions.size();
+        TORCH_CHECK(
+            tensors_len == versions_len,
+            "tensors_len=",
+            tensors_len,
+            ", versions_len=",
+            versions_len);
+        for (const auto i : c10::irange(tensors_len)) {
+          auto vc = torch::autograd::impl::version_counter(tensors[i]);
+          vc.set_version(versions[i]);
+        }
+      });

   m.def("_enable_profiler_legacy", enableProfilerLegacy);
   py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
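The new `TORCH_CHECK` surfaces in Python as a `RuntimeError` when the two sequences disagree in length; a small sketch of the binding's behavior:

    import torch

    x = torch.zeros(3)
    x.add_(1)  # version counter -> 1
    torch._C._autograd._unsafe_set_version_counter((x,), (0,))
    assert x._version == 0

    try:
        torch._C._autograd._unsafe_set_version_counter((x,), (0, 1))
    except RuntimeError as e:
        print(e)  # message includes tensors_len=1, versions_len=2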

torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py (+26, -24)

@@ -290,35 +290,37 @@ def foreach_all_gather_copy_out(
         out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out]
     else:
         out = [t.view(world_size, -1) for t in split_with_sizes_out]
-    torch.ops.fsdp.split_with_sizes_copy(
-        all_gather_output, all_gather_input_split_sizes, dim=1, out=out
-    )
+    with torch.autograd._unsafe_preserve_version_counter(tuple(out)):
+        torch.ops.fsdp.split_with_sizes_copy(
+            all_gather_output, all_gather_input_split_sizes, dim=1, out=out
+        )

     for fsdp_param, param_all_gather_outputs in shard_i_copy_infos:
         # Chunk-cat from the temporary to the final all-gather output tensors
         shard_dim = fsdp_param.fsdp_placement.dim
-        for param_all_gather_output, target_all_gather_output in zip(
-            param_all_gather_outputs, fsdp_param.all_gather_outputs
+
+        with torch.autograd._unsafe_preserve_version_counter(
+            tuple(fsdp_param.all_gather_outputs)
         ):
-            padded_sharded_size = (
-                fsdp_param.padded_sharded_param_size
-                if fsdp_param.sharded_state == ShardedState.SHARDED
-                else cast(
-                    torch.Tensor, fsdp_param._sharded_post_forward_param_data
-                ).size()
-            )
-            pre_param_size = list(padded_sharded_size)
-            pre_param_size[0] *= world_size
-            chunks = torch.chunk(
-                param_all_gather_output.view(pre_param_size), world_size, dim=0
-            )
-            post_param_size = list(padded_sharded_size)
-            post_param_size[shard_dim] *= world_size
-            cat_out = target_all_gather_output.view(post_param_size)
-            torch.cat(chunks, dim=shard_dim, out=cat_out)
-            torch._C._autograd._unsafe_set_version_counter(
-                target_all_gather_output, target_all_gather_output._version - 1
-            )
+            for param_all_gather_output, target_all_gather_output in zip(
+                param_all_gather_outputs, fsdp_param.all_gather_outputs
+            ):
+                padded_sharded_size = (
+                    fsdp_param.padded_sharded_param_size
+                    if fsdp_param.sharded_state == ShardedState.SHARDED
+                    else cast(
+                        torch.Tensor, fsdp_param._sharded_post_forward_param_data
+                    ).size()
+                )
+                pre_param_size = list(padded_sharded_size)
+                pre_param_size[0] *= world_size
+                chunks = torch.chunk(
+                    param_all_gather_output.view(pre_param_size), world_size, dim=0
+                )
+                post_param_size = list(padded_sharded_size)
+                post_param_size[shard_dim] *= world_size
+                cat_out = target_all_gather_output.view(post_param_size)
+                torch.cat(chunks, dim=shard_dim, out=cat_out)


 @torch.no_grad()
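The FSDP2 change above applies the same idea: rather than decrementing each `target_all_gather_output`'s version counter individually after `torch.cat(..., out=...)`, the whole chunk-cat loop runs under one `_unsafe_preserve_version_counter` over `fsdp_param.all_gather_outputs`, restoring every counter with a single Python->C++ crossing. A standalone sketch of the pattern, with hypothetical `bufs`/`srcs` standing in for the all-gather buffers:

    import torch

    bufs = [torch.empty(8) for _ in range(4)]  # hypothetical output buffers
    srcs = [torch.randn(8) for _ in range(4)]

    # One snapshot/restore covers every out= write in the loop:
    with torch.autograd._unsafe_preserve_version_counter(tuple(bufs)):
        for src, buf in zip(srcs, bufs):
            torch.add(src, 1, out=buf)  # out= writes bump version counters
    assert all(b._version == 0 for b in bufs)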
