Labels: enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix) · module: bootcamp (We plan to do a full writeup on the issue, and then get someone to do it for onboarding) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🚀 The feature, motivation and pitch
Currently torch.split can accept a tuple of ints, in which case it dispatches to split_with_sizes (I propose that torch.split_with_sizes then be made private / deprecated instead of documented: #58181).

Sometimes it is convenient to pass the split sizes as a tensor instead of a tuple. The sizes tensor could then also be a CUDA tensor, without forcing a GPU->CPU copy. This can be useful when synchronizing a "nested" list of tensors, e.g. in the all-gather helpers further below. Today, passing a tensor of sizes fails:
```python
>>> torch.rand(10, device = 'cuda').split(torch.tensor([5, 5], device = 'cuda'))
Traceback (most recent call last):
  File ".../vadim/prefix/miniconda/lib/python3.9/site-packages/torch/_tensor.py", line 510, in split
    split_size = int(split_size)
ValueError: only one element tensors can be converted to Python scalars

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../vadim/prefix/miniconda/lib/python3.9/site-packages/torch/_tensor.py", line 513, in split
    return super(Tensor, self).split_with_sizes(split_size, dim)
TypeError: split_with_sizes(): argument 'split_sizes' (position 1) must be tuple of ints, not Tensor

>>> torch.rand(10, device = 'cuda').split(torch.tensor([5, 5], device = 'cuda').tolist())
(tensor([0.0725, 0.3508, 0.9368, 0.2485, 0.2627], device='cuda:0'), tensor([0.3593, 0.4366, 0.6633, 0.6635, 0.5102], device='cuda:0'))
```
The motivating use case is gathering variably sized tensors across ranks, where the final `counts.tolist()` currently forces a device sync:

```python
import torch
import torch.nn.functional as F
import torch.distributed

def all_gather_tensors(tensor):
    world_size = torch.distributed.get_world_size()
    if world_size == 1:
        return [tensor]
    # exchange every rank's shape first, since all_gather needs equal sizes
    shape_tensor = torch.tensor(tensor.shape, dtype=torch.int64, device=tensor.device)
    shapes = list(torch.zeros([world_size, len(tensor.shape)], dtype=torch.int64, device=tensor.device).unbind())
    torch.distributed.all_gather(shapes, shape_tensor)
    # pad the local tensor up to the per-dimension maximum across ranks
    max_shape = torch.stack(shapes).amax(dim=0).tolist()
    pad = sum([[0, max_dim - dim] for max_dim, dim in reversed(list(zip(max_shape, tensor.shape)))], [])
    padded_tensor = F.pad(tensor, pad=pad)
    tensors = list(torch.zeros([world_size, *padded_tensor.shape], dtype=padded_tensor.dtype, device=padded_tensor.device).unbind())
    torch.distributed.all_gather(tensors, padded_tensor)
    # slice the padding back off using each rank's original shape
    for i, shape in enumerate(shapes):
        tensors[i] = tensors[i][tuple(slice(x) for x in shape.tolist())]
    return tensors

def all_gather_tensors_cat(tensors):
    return torch.cat(all_gather_tensors(torch.cat(tensors)))

def all_gather_tensors_nested(tensors):
    # gather the per-tensor lengths, then re-split the gathered flat tensor;
    # the final counts.tolist() is the GPU->CPU sync this proposal would avoid
    counts = torch.tensor(list(map(len, tensors)), dtype=torch.int64, device=tensors[0].device)
    counts = torch.cat(all_gather_tensors(counts))
    tensor = all_gather_tensors_cat(tensors)
    return tensor.split(counts.tolist())
```
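To show how these helpers would be driven, here is a minimal sketch of a two-rank job; the NCCL process-group setup and the example data are assumptions for illustration, not part of the snippet above:

```python
import torch
import torch.distributed

# assumed launch: e.g. torchrun --nproc-per-node=2 this_script.py, one GPU per rank
torch.distributed.init_process_group(backend='nccl')
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank)

# each rank holds a differently sized list of 1-D tensors
local = [torch.rand(2 + rank, device='cuda'), torch.rand(3, device='cuda')]

gathered = all_gather_tensors_nested(local)
# `gathered` is a tuple of tensors covering every rank's lists; the
# counts.tolist() inside all_gather_tensors_nested is where the GPU->CPU
# sync happens that a tensor-accepting torch.split would remove.
```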
Alternatives
No response
Additional context
No response