[feature request] Reduction (torch.add / torch.logaddexp / torch.max / torch.min / torch.mean) of several tensors without extra copies / allocations / memory accesses / TensorList inputs support #27522
Comments
I agree that would be nice. Would you be willing to submit a PR for this? @gchanan What would this take on the dispatch side? I imagine this would require changing the core op to take N inputs, so it might be more complicated than it seems. |
Also wondering how hard it would be to automatically parallelize this with streams (when suited) in the backend. |
do you mean |
fwiw NumPy supports |
This is not equivalent to
In fact, NumPy converts its input into an array in
But:

```python
import numpy as np

a = [np.random.rand(2, 2, 2) for _ in range(5)]
print(sum(a).shape)
# gives (2, 2, 2)
print(np.sum(a).shape)
# gives ()
```

TF has a dedicated function for that. The JIT can definitely optimize the particular pattern |
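For context, a sketch of the PyTorch side of the same comparison (my own illustration; the exact error depends on the PyTorch version): Python's builtin sum works over a list of tensors but builds N-1 temporaries, while torch.sum does not accept a list at all.

```python
import torch

a = [torch.rand(2, 2, 2) for _ in range(5)]

print(sum(a).shape)                 # torch.Size([2, 2, 2]) -- builtin sum, allocates N-1 temporaries
print(torch.stack(a).sum(0).shape)  # torch.Size([2, 2, 2]) -- current idiom, materializes one big copy

try:
    torch.sum(a)                    # unlike np.sum, torch.sum does not take a list of tensors
except (TypeError, RuntimeError) as e:
    print('torch.sum(list) raises:', e)
```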
@fmassa can we adapt |
@gchanan I meant equivalent of |
@vadimkantorov thanks for clarifying. I think all the dispatch building blocks are there for this, it's the same as |
we could, but |
@fmassa One way is to factor out the fused multiply-add into a separate method |
It would also be nice to have torch.logsumexp support multiple input tensors, so that the prior torch.stack can be avoided. Actually, this kind of API makes sense for all reduction functions. Alternatively, maybe this could be supported without an API change if a NestedTensor becomes reality (so that torch.stack is done only logically, without materialization). E.g. consider a simplistic impl of CTC:

```python
import torch

@torch.jit.script
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank : int = 0, reduction : str = 'none'):
    targets_ = torch.full((targets.shape[0], 2 * targets.shape[-1] + 1), blank, device = targets.device, dtype = targets.dtype)
    temporal_mask = torch.arange(targets.shape[-1], device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(0) < target_lengths.unsqueeze(1)
    targets_[:, 1::2] = temporal_mask * targets + (~temporal_mask) * targets_[:, 1::2]
    max_target_length = int(target_lengths.max())
    max_target_length_ = 2 * max_target_length + 1
    targets_ = targets_[:, :max_target_length_]
    neginf = torch.as_tensor([float('-inf')], device = log_probs.device, dtype = log_probs.dtype)
    batch_size = targets.shape[0]
    log_alpha = torch.empty(batch_size, log_probs.shape[0], 2 + max_target_length_, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[:, :3].fill_(neginf.sum())
    log_alpha[:, 0, 2 + 0] = log_probs[0, :, blank]
    log_alpha[:, 0, 2 + 1] = log_probs[0, torch.arange(batch_size), targets_[:, 1]]
    log_probs_ = log_probs.gather(-1, targets_.expand(len(log_probs), -1, -1))
    la3_ = torch.cat([torch.as_tensor([[True, True]], device = targets.device).expand(batch_size, -1), targets_[:, 2:] != targets_[:, :-2]], dim = 1)
    for t in range(1, len(log_probs)):
        la3 = log_alpha[:, t - 1, 0:-2]
        la2 = log_alpha[:, t - 1, 1:-1]
        la1 = log_alpha[:, t - 1, 2:]
        log_alpha[:, t, 2:] = log_probs_[t] + torch.logsumexp(torch.stack([la1, la2, torch.where(la3_, la3, neginf)]), dim = 0)
    l1 = log_alpha[:, input_lengths - 1, 2 + target_lengths * 2].diag()
    l2 = log_alpha[:, input_lengths - 1, 2 + target_lengths * 2 - 1].diag()
    return -torch.logsumexp(torch.stack([l1, l2]), dim = 0)
```

In order to do logsumexp of several tensors I have to first stack them. |
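A possible workaround sketch for that last point (my own illustration, not an existing TensorList API): reduce the list with pairwise torch.logaddexp, so the torch.stack materialization is avoided at the cost of N-1 kernel launches.

```python
import functools
import torch

def logsumexp_list(tensors):
    # log(sum_i exp(t_i)) computed pairwise, without stacking the inputs into one big tensor
    return functools.reduce(torch.logaddexp, tensors)

la1, la2, la3 = (torch.randn(4, 7) for _ in range(3))
out = logsumexp_list([la1, la2, la3])
assert torch.allclose(out, torch.logsumexp(torch.stack([la1, la2, la3]), dim=0), atol=1e-6)
```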
In #32100 it seems that NumPy actually supports lists of tensors as input to reduction functions, instead of a stacked tensor |
Related: #38377 |
Given that logaddexp is merged, I thought of bringing this up again: maybe via nested tensor. Another instance that could be useful: torch.argmax across two tensors (without stacking them first). E.g. if we pass two tensors in that dimension, it would return for every element 0 or 1. NumPy seems to support it for arbitrary functions via ufuncs: |
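A small sketch of the two-tensor argmax idea (my own illustration): the element-wise 0/1 index can be obtained from a comparison, without stacking.

```python
import torch

a, b = torch.randn(3, 4), torch.randn(3, 4)
which = (b > a).long()  # 0 where a wins, 1 where b wins (random floats make ties vanishingly unlikely)
assert torch.equal(which, torch.stack([a, b]).argmax(dim=0))
```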
Yay! NestedTensor seems to support reductions now: https://github.com/pytorch/nestedtensor/blob/master/examples/naryops_and_reduce.ipynb

Some API musings:

```python
# if the default nested dim is 0?
torch.add(NestedTensor([a, b]), dim = 0)

# maybe we can specify it semantically?
torch.add(NestedTensor([a, b]), dim = torch.nested_dimension)

# or have torch.stack return a nested tensor, since it's already an established way to bind tensors together, and new dim = 0 is expected
torch.add(torch.stack([a, b], layout = torch.nested), dim = 0)

# or maybe we can ask any op to do torch.as_tensor, and torch.as_nested_tensor if it's a list of tensors
torch.add([a, b], dim = torch.nested_dimension)
```
|
@izdeby Maybe foreach-related? |
Maybe if your test doesn't require grads they will be deallocated quickly. With gradients I think it should do N copies |
Yes, not O(N), but just O(2). Feel free to experiment:
and then 4 jupyter notebook cells:
|
From what I understand, in your test elements in |
That's a great point, Vadim. Let's try with grads. If I use:
then everything changes depending on the reduction function from which backward is called, e.g.
and with
note that in both cases, your version doesn't fare very differently from normal
If you want to break it down: |
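Since the actual notebook cells aren't shown here, a minimal sketch of the kind of experiment being discussed (my own, assuming a CUDA device): compare peak memory of the three reduction idioms with backward included.

```python
import functools
import torch

def peak_mem_mib(fn, tensors):
    # reset per-run stats, run forward + backward, and report the peak allocation in MiB
    for t in tensors:
        t.grad = None
    torch.cuda.reset_peak_memory_stats()
    fn(tensors).sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

tensors = [torch.randn(1024, 1024, device='cuda', requires_grad=True) for _ in range(16)]

variants = {
    'python sum':     lambda ts: sum(ts),
    'inplace reduce': lambda ts: functools.reduce(lambda acc, x: acc.add_(x), ts, torch.zeros_like(ts[0])),
    'stack + sum':    lambda ts: torch.stack(ts).sum(dim=0),
}
for name, fn in variants.items():
    print(f'{name:15s} {peak_mem_mib(fn, tensors):8.1f} MiB')
```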
Is it at all realistic to use Triton here? Surely their tutorial on summing vectors can be extended to multiple vectors. You'd then be left with implementing autograd for this operation, which can be achieved via autograd.Function. |
I think it doesn't make much sense to take a dependency on Triton just for this summation optimization. If it's used for all other multi-tensor reductions, it's more reasonable... Ideally, these multi-tensor reductions would be supported for all reductions via NestedTensor bindings (or just accepting tensorlists in place of a tensor) ... |
@vadimkantorov - You could guard against the presence of Triton and fall back to a less efficient implementation if it's not present. It should also allow you to unblock your perf explorations, such as measuring the overall influence on the model and the potential gains. |
I think for summation specifically manual |
But as #70533 mentions, probably the sum / functools.reduce workarounds aren't very TorchScript-able... |
If you need TorchScript support, you're unlikely to get around using the C++ extensibility features and shipping your own binary code at this point. Neither Triton nor functools is JIT-able. |
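To make the C++-extension route concrete, a minimal hedged sketch (my own example; requires a working C++ toolchain, and the name sum_list is arbitrary) using torch.utils.cpp_extension.load_inline:

```python
import torch
from torch.utils.cpp_extension import load_inline

# One in-place accumulation over a TensorList, compiled as a tiny extension.
cpp_src = r"""
#include <vector>
torch::Tensor sum_list(std::vector<torch::Tensor> tensors) {
  TORCH_CHECK(!tensors.empty(), "expected a non-empty list of tensors");
  auto out = torch::zeros_like(tensors[0]);
  for (const auto& t : tensors) out.add_(t);  // single allocation, no O(N) temporaries
  return out;
}
"""
ext = load_inline(name="sum_list_ext", cpp_sources=cpp_src, functions=["sum_list"])
print(ext.sum_list([torch.randn(3, 3) for _ in range(4)]).shape)
```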
@cpuhrsch a real-world use-case for torch.sum(tensor_list, dim = 0): https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/functional.py#L298

Also, could there be an improvement from using a single allocation instead of [torch.empty_like(tensor) for tensor in grad_outputs]? E.g. using nestedtensor (a _foreach_empty_like, or a new general multi/foreach-allocation utility?). Elements in the list do not have equal shapes, right? This should be quite simple if the tensors are all dense / not meta-tensors.

For torch.sum to benefit from more efficiency, nestedtensor/tensorlist should not force a copy (especially if the number of tensors to be summed is small). Ideally there should also be an in-place version of sum/add, torch.add_(res, *[list of extra summands]); this would be the most idiomatic replacement for the manual loop |
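A sketch of the single-allocation idea (my own illustration; _foreach_empty_like above is hypothetical): hand out views of one flat buffer, assuming the tensors are dense and share a dtype and device.

```python
import torch
from typing import List

def empty_like_list(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
    # one flat allocation, sliced into per-tensor views (shapes may differ, dtype/device must match)
    flat = torch.empty(sum(t.numel() for t in tensors),
                       dtype=tensors[0].dtype, device=tensors[0].device)
    outs, offset = [], 0
    for t in tensors:
        outs.append(flat[offset:offset + t.numel()].view(t.shape))
        offset += t.numel()
    return outs

grad_outputs = [torch.randn(3, 4), torch.randn(5), torch.randn(2, 2, 2)]
buffers = empty_like_list(grad_outputs)
print([b.shape for b in buffers])
```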
Another use case / argument for studying ops on multiple tensors and using or avoiding stack/copy: 5654e63#diff-8333640cff01dcc6a2ab20085003e8c998db905ad85c09603a20a6e363aff2e2R135 |
@cpuhrsch Also, one could imagine reduction across multiple tensors of different sizes without prior padding, e.g. using neutral elements as imputation: 0 for sum. Also might be possible for mean. Motivating example: #27522 (comment) where one computes logsumexp of tensors of slightly different sizes: N, N-1, N-2. Also related to #77876 |
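A sketch of the neutral-element idea for the logsumexp case above (my own illustration): pad the shorter tensors with -inf (the neutral element for logsumexp) before stacking; a TensorList-aware reduction could skip both the padding and the copy.

```python
import torch
import torch.nn.functional as F

xs = [torch.randn(n) for n in (10, 9, 8)]  # lengths N, N-1, N-2
max_len = max(x.shape[-1] for x in xs)
padded = torch.stack([F.pad(x, (0, max_len - x.shape[-1]), value=float('-inf')) for x in xs])
out = torch.logsumexp(padded, dim=0)  # shape (max_len,); the -inf entries contribute nothing
print(out.shape)
```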
@cpuhrsch A recent Apple transformers optimization stresses reducing memory copies and also uses softmax on a tensorlist: So supporting a copy-avoiding TensorList (in addition to the always-copying NestedTensor), at least for some ops, may be a good idea |
@vadimkantorov - while I agree, I'm wondering whether we couldn't achieve the same by mostly sticking to NestedTensors across the operations? As in, let's look at where the TensorList comes from to begin with and use a NestedTensor right away. |
Copy semantics may be unwanted overhead for some things (e.g. the optimizers' existing multi-tensor-apply, which may try to do things inplace as much as possible?). Regarding the frontend, allowing regular Python lists of tensors may be terser syntax, e.g. torch.softmax([...my tensor list], dim = -1) -> tensor list (but computed with parallelization if needed, especially on CPU, or with CUDA streams; related: #78507) |
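A sketch of that frontend (my own hypothetical helper, not a torch API): softmax over a plain Python list, optionally running each tensor on its own CUDA stream so the independent kernels can overlap; full stream-safety bookkeeping (record_stream etc.) is glossed over here.

```python
import torch
from typing import List

def softmax_list(tensors: List[torch.Tensor], dim: int = -1) -> List[torch.Tensor]:
    if tensors and tensors[0].is_cuda:
        outs = []
        for t in tensors:
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())  # inputs were produced on the current stream
            with torch.cuda.stream(s):
                outs.append(torch.softmax(t, dim=dim))
        torch.cuda.synchronize()  # make the results safe to use from the default stream
        return outs
    return [torch.softmax(t, dim=dim) for t in tensors]

outs = softmax_list([torch.randn(3, 5), torch.randn(2, 7)], dim=-1)
```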
Another useful method to support for TensorList inputs is torch.max / torch.min |
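A workaround sketch for that as well (my own, not a TensorList API): element-wise max/min across equal-shaped tensors via pairwise torch.maximum / torch.minimum, again without stacking.

```python
import functools
import torch

ts = [torch.randn(4, 4) for _ in range(5)]
max_out = functools.reduce(torch.maximum, ts)
min_out = functools.reduce(torch.minimum, ts)
assert torch.equal(max_out, torch.stack(ts).amax(dim=0))
assert torch.equal(min_out, torch.stack(ts).amin(dim=0))
```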
I guess there would be a perf difference due to the re-allocations / copies that NestedTensor currently forces |
If I understand correctly,
sum(tensor_list)
will allocate and keep O(N) intermediate tensors (same with a for loop), where N is the number of tensors, which can be quite large in the case of a big DenseNet. I propose to maybe generalize torch.add to support more than two tensors as input.

Currently one can do:
functools.reduce(lambda acc, x: acc.add_(x), tensor_list, torch.zeros_like(tensor_list[0]))
so it's not super-urgent, but a more idiomatic, TorchScript-able way may be nice.
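For reference, a runnable version of the workaround above plus a TorchScript-able equivalent (a sketch under the assumption that a plain loop is acceptable; functools itself is not scriptable):

```python
import functools
import torch
from typing import List

tensor_list = [torch.randn(3, 3) for _ in range(8)]

# the workaround from the issue text: O(1) extra memory, but not scriptable
acc = functools.reduce(lambda a, x: a.add_(x), tensor_list, torch.zeros_like(tensor_list[0]))

# a scriptable equivalent using a plain loop and in-place accumulation
@torch.jit.script
def sum_list(tensors: List[torch.Tensor]) -> torch.Tensor:
    out = torch.zeros_like(tensors[0])
    for t in tensors:
        out += t
    return out

assert torch.allclose(acc, sum_list(tensor_list))
```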