[Needs someone to complete] Reduce sum on many axes #2116
Conversation
Thanks for the PR!
I wonder if it would be better to remove the implementation from cwrap, to avoid conflicts? I think this is better than implementing these ops in cwrap, because we automatically get support for autograd.
Also, what is the behavior of numpy for operations like median when multiple axes are passed? Does it perform multiple kernel calls, or does it transpose + view + single kernel call? For sum it might not matter, but for other ops it might make a difference.
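For reference, here is a rough sketch of the two strategies being asked about, using numpy only to show that they agree for sum. This is illustrative and makes no claim about numpy's internals; the array, axes, and variable names are made up for the example:

```python
import numpy as np

x = np.random.randn(2, 3, 4, 5)
axes = (1, 3)

# Strategy A: one reduction per axis, highest axis first so the
# remaining axis indices stay valid after each squeeze.
a = x
for ax in sorted(axes, reverse=True):
    a = a.sum(axis=ax)

# Strategy B: move the reduced axes to the end, collapse them into a
# single dimension, then issue one reduction over that dimension.
kept = [d for d in range(x.ndim) if d not in axes]
b = np.transpose(x, kept + list(axes))
b = b.reshape([x.shape[d] for d in kept] + [-1]).sum(axis=-1)

assert np.allclose(a, b)
```

For sum the two strategies give the same numbers; for an order statistic like median, a median of per-axis medians is not in general the median over the combined axes, so the choice of strategy affects the result as well as the number of kernel calls.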
@@ -2,6 +2,7 @@
 from ._utils import _range
 from operator import mul
 from functools import reduce
+import collections
torch/functional.py
Outdated

                    input = input.sum(ax, keepdims=True)
            else:
                for ax in sorted(axes, reverse=True):
                    input = input.sum(ax)
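The reverse-sorted iteration is the key detail here: with keepdims=False each reduction drops a dimension, so the indices of any later axes shift; iterating from the highest axis down keeps the remaining axis numbers valid. A tiny illustration with hypothetical shapes, assuming the default squeeze-on-reduce behaviour:

```python
import torch

x = torch.randn(2, 3, 4)

# Reducing axis 1 first would drop that dimension and turn the old
# axis 2 into axis 1; going from the highest axis down avoids the
# reindexing.
out = x
for ax in sorted((1, 2), reverse=True):
    out = out.sum(ax)          # (2, 3, 4) -> (2, 3) -> (2,)

assert out.shape == torch.Size([2])
```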
torch/functional.py
Outdated

def sum(input, axes, keepdims=False, out=None):
    if isinstance(axes, collections.Iterable):
        if input.dim() > 3:
            if keepdims:
            # permute
            # reduce single dim
    else:
        return torch._C.sum(input, axes, keepdims, out)
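One possible way to fill in the `# permute` / `# reduce single dim` placeholders, along the lines of the transpose + view + single-kernel strategy discussed above. This is only a sketch, not the merged implementation; the helper name `_sum_many_axes` and the keepdims restoration are assumptions:

```python
import torch

def _sum_many_axes(input, axes, keepdims=False):
    # Normalize (possibly negative) axes and split the dimensions into
    # the ones we keep and the ones we reduce.
    axes = sorted(set(int(ax) % input.dim() for ax in axes))
    kept = [d for d in range(input.dim()) if d not in axes]
    # Move the reduced axes to the end, fold them into one dimension,
    # and reduce that single dimension with one call.
    permuted = input.permute(*(kept + axes)).contiguous()
    kept_sizes = [input.size(d) for d in kept]
    out = permuted.view(*(kept_sizes + [-1])).sum(-1)
    if keepdims:
        # Re-insert the reduced dimensions as size-1 dims.
        for ax in axes:
            out = out.unsqueeze(ax)
    return out
```

Compared with the per-axis loop in the diff above, this issues a single reduction kernel, which is the difference the earlier review comment asks about; for sum both forms give the same values.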
I am trying to perform std over many axes, pretty similar to what you are doing with sum. Has this problem been solved?
I also want to do var over many axes. Is this solved? (Same question as @bernardohenz, except with var.) I should note that numpy supports this, and the only way to do it in pytorch currently is to compute the mean, subtract (using expand), square, and then take the mean again. Basically manually.
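For anyone landing here before multi-axis var exists, the manual recipe described above can be written roughly like this. It is a sketch: `var_over_axes` is a made-up helper, it computes the biased estimate, and it relies on broadcasting instead of an explicit expand:

```python
import torch

def var_over_axes(x, axes):
    # Mean over the given axes, kept as size-1 dims so it broadcasts
    # against x (broadcasting stands in for the explicit expand()).
    mu = x
    for ax in sorted(axes, reverse=True):
        mu = mu.mean(ax, keepdim=True)
    # Mean of the squared deviations over the same axes.
    out = (x - mu) ** 2
    for ax in sorted(axes, reverse=True):
        out = out.mean(ax)
    return out
```

Recent PyTorch releases accept a tuple of dims directly in sum, mean, and var, so a helper like this is only needed on older versions.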
@tstandley we are working on mean, variance and std on multiple axes. @colesbury should put up a PR soon for it.
For signposting:
Resolves #2006
keepdim to keepdims