[feature request] Support tensor count vector argument in torch.split · Issue #73175 · pytorch/pytorch · GitHub

@vadimkantorov

Description

🚀 The feature, motivation and pitch

Currently torch.split accepts a tuple of ints, in which case it implements split_with_sizes (I propose making torch.split_with_sizes private or deprecating it, rather than documenting it: #58181).
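
For reference, a small sketch of the two argument forms torch.split accepts today (a single int, or a list/tuple of ints that ends up in split_with_sizes, as the traceback below also shows). It assumes a recent PyTorch build; the variable names are illustrative only.

```python
import torch

x = torch.arange(10)

# A single int: fixed-size chunks along dim 0 (the last chunk may be smaller)
by_int = torch.split(x, 4)

# A list/tuple of ints: explicit chunk sizes, i.e. the split_with_sizes case
by_sizes = torch.split(x, [3, 7])

# Calling the underlying Tensor method directly yields the same chunks
direct = x.split_with_sizes([3, 7])
```

Neither form accepts a tensor of sizes, which is what this request is about.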

Sometimes it is convenient to pass the sizes as a tensor instead of a tuple. The sizes tensor could then also be a CUDA tensor, without forcing a GPU->CPU copy. This is useful, for example, when synchronizing (all-gathering) a "nested" list of tensors across processes:

>>> torch.rand(10, device = 'cuda').split(torch.tensor([5, 5], device = 'cuda'))
Traceback (most recent call last):
  File ".../vadim/prefix/miniconda/lib/python3.9/site-packages/torch/_tensor.py", line 510, in split
    split_size = int(split_size)
ValueError: only one element tensors can be converted to Python scalars

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../vadim/prefix/miniconda/lib/python3.9/site-packages/torch/_tensor.py", line 513, in split
    return super(Tensor, self).split_with_sizes(split_size, dim)
TypeError: split_with_sizes(): argument 'split_sizes' (position 1) must be tuple of ints, not Tensor

>>> torch.rand(10, device = 'cuda').split(torch.tensor([5, 5], device = 'cuda').tolist())
(tensor([0.0725, 0.3508, 0.9368, 0.2485, 0.2627], device='cuda:0'), tensor([0.3593, 0.4366, 0.6633, 0.6635, 0.5102], device='cuda:0'))
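
One possible stopgap, not proposed in the issue itself: torch.tensor_split already accepts a 1-D int64 tensor, but of split indices rather than sizes, and its documentation requires that tensor to be on the CPU. Below is a minimal sketch of a hypothetical split_by_size_tensor helper built on it; because of the CPU requirement it still copies CUDA sizes to the host, so it only makes the call site tidier and does not remove the synchronization.

```python
import torch

def split_by_size_tensor(x: torch.Tensor, sizes: torch.Tensor, dim: int = 0):
    """Hypothetical helper: split x along dim by a 1-D tensor of chunk sizes."""
    # tensor_split takes split *indices*: the running totals of the sizes,
    # excluding the final total (which is just the length along dim).
    indices = sizes.cumsum(0)[:-1]
    # tensor_split requires an index tensor to be an int64 tensor on the CPU,
    # so a CUDA sizes tensor is still copied to the host here.
    indices = indices.to(device="cpu", dtype=torch.int64)
    # Unlike torch.split, this does not check that sizes sum to x.shape[dim];
    # the last chunk simply starts at the last index and runs to the end.
    return torch.tensor_split(x, indices, dim=dim)


x = torch.arange(10)
parts = split_by_size_tensor(x, torch.tensor([3, 2, 5]))
```

Here the sizes [3, 2, 5] become the indices [3, 5], giving chunks of length 3, 2 and 5, the same as x.split([3, 2, 5]).
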
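
For context, these are the helpers that motivate the request (the final .tolist() in all_gather_tensors_nested is the copy this issue would like to avoid):

import torch
import torch.distributed
import torch.nn.functional as F

def all_gather_tensors(tensor):
    # Gather tensors whose shapes may differ across ranks: exchange shapes,
    # pad to the elementwise max shape, all_gather, then trim the padding.
    world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    if world_size == 1:
        return [tensor]

    shape_tensor = torch.tensor(tensor.shape, dtype = torch.int64, device = tensor.device)
    shapes = list(torch.zeros([world_size, len(tensor.shape)], dtype = torch.int64, device = tensor.device).unbind())
    torch.distributed.all_gather(shapes, shape_tensor)

    # F.pad takes (left, right) pairs starting from the last dimension
    max_shape = torch.stack(shapes).amax(dim=0).tolist()
    pad = sum([[0, max_dim - dim] for max_dim, dim in reversed(list(zip(max_shape, tensor.shape)))], [])
    padded_tensor = F.pad(tensor, pad = pad)

    tensors = list(torch.zeros([world_size, *padded_tensor.shape], dtype=padded_tensor.dtype, device=padded_tensor.device).unbind())
    torch.distributed.all_gather(tensors, padded_tensor)

    # Trim each gathered tensor back to its rank's original shape
    # (index with a tuple of slices, one per dimension)
    for i, shape in enumerate(shapes):
        tensors[i] = tensors[i][tuple(slice(x) for x in shape.tolist())]

    return tensors

def all_gather_tensors_cat(tensors):
    return torch.cat(all_gather_tensors(torch.cat(tensors)))

def all_gather_tensors_nested(tensors):
    # Keep the counts on the data's device: device-only backends such as NCCL
    # cannot all_gather CPU tensors
    counts = torch.tensor(list(map(len, tensors)), dtype = torch.int64, device = tensors[0].device)
    counts = torch.cat(all_gather_tensors(counts))
    tensor = all_gather_tensors_cat(tensors)
    # .tolist() forces a device->host copy; this issue asks for tensor.split(counts)
    return tensor.split(counts.tolist())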
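
To make the data flow of all_gather_tensors_nested easier to follow, here is a single-process sketch in which the two all_gather calls are replaced by concatenating per-"rank" results held in a local list. pack and unpack are hypothetical names for illustration, not PyTorch APIs.

```python
import torch

def pack(tensors):
    # Flatten a list of tensors (same trailing dims) into one tensor, plus a
    # count vector with each tensor's length along dim 0.
    counts = torch.tensor([len(t) for t in tensors], dtype=torch.int64)
    return torch.cat(tensors), counts

def unpack(flat, counts):
    # torch.split needs Python ints today, hence .tolist(); with CUDA counts
    # this is the device-to-host copy the feature request wants to avoid.
    return list(flat.split(counts.tolist()))

# Pretend two ranks each hold a "nested" list of tensors.
rank_lists = [
    [torch.ones(2, 3), torch.zeros(1, 3)],
    [torch.full((4, 3), 2.0)],
]

# Stand-in for the all_gather calls: concatenate data and counts across ranks.
packed = [pack(lst) for lst in rank_lists]
flat = torch.cat([data for data, _ in packed])
counts = torch.cat([c for _, c in packed])

gathered = unpack(flat, counts)
```

The result is the original three tensors, in rank order, recovered from one flat tensor and one count vector.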

Alternatives

No response

Additional context

No response

Metadata

Assignees

No one assigned

    Labels

    enhancement - Not as big of a feature, but technically not a bug. Should be easy to fix
    module: bootcamp - We plan to do a full writeup on the issue, and then get someone to do it for onboarding
    triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
