torch.nonzero(t, as_tuple=...) does not work with the JIT because the as_tuple signatures are not exposed properly #45499


Open
mruberry opened this issue Sep 29, 2020 · 20 comments
Labels
module: numpy - Related to numpy support, and also numpy compatibility of our operators
module: python array api - Issues related to the Python Array API
oncall: jit - Add this issue/PR to JIT oncall triage queue
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mruberry
Collaborator
mruberry commented Sep 29, 2020
def foo(t):
    return torch.nonzero(t, as_tuple=False)

torch.jit.script(foo)
# RuntimeError: Arguments for call are not valid.

This is because of the Python arg parsing done to handle the "as_tuple" kwarg. Note that tracing works fine.
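
For illustration, tracing the same function succeeds because the trace records the concrete call rather than resolving the overload from the exposed signatures. A minimal sketch (input values are arbitrary):

import torch

def foo(t):
    return torch.nonzero(t, as_tuple=False)

# Tracing records the concrete aten::nonzero call, so it compiles and runs:
traced = torch.jit.trace(foo, (torch.tensor([0.0, 1.0, 2.0]),))
print(traced(torch.tensor([3.0, 0.0, 4.0])))  # tensor([[0], [2]])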

cc @mruberry @rgommers @heitorschueroff @pmeier @gmagogsfm

@mruberry mruberry added the triaged and module: numpy labels Sep 29, 2020
@facebook-github-bot facebook-github-bot added the oncall: jit label Sep 29, 2020
@eellison
Contributor

Currently it's hard to support this op because the return type depends on the value of as_tuple. I think this would not be that hard to support by adding Literal[True] and Literal[False] and making the argument rule that we only accept constants for literal arguments. Integration with native_functions.yaml would be a little trickier, or we could just register it as a prim op and separately in Python, as is done currently.

@mruberry
Collaborator Author

> Currently it's hard to support this op because the return type depends on the value of as_tuple. I think this would not be that hard to support by adding Literal[True] and Literal[False] and making the argument rule that we only accept constants for literal arguments. Integration with native_functions.yaml would be a little trickier, or we could just register it as a prim op and separately in Python, as is done currently.

Or maybe we should remove the kwarg and deprecate functions that return different types of objects based on their input values.

@eellison
Contributor

That also sounds good to me! I don't know whether nonzero was on the path to deprecation, or what the eager-mode concerns would be around deprecating it outside of the JIT. I know that numpy has quite a large number of functions whose output type depends on their inputs.

The mechanism to return a different type depending on an input value is implemented for Python functions, just not for builtins; see boolean_dispatch in torch/functional.py:

_return_inverse_false = boolean_dispatch(
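
For context, the idea behind boolean_dispatch is to build a wrapper that forwards to one of two implementations, each with a single fixed return type, based on a flag that the JIT requires to be a compile-time constant. A minimal self-contained sketch of the pattern, not PyTorch's actual implementation (the names _nonzero_tuple and _nonzero_tensor are hypothetical):

from typing import Callable

def boolean_dispatch_sketch(arg_name: str, default: bool,
                            if_true: Callable, if_false: Callable) -> Callable:
    # Return a wrapper that forwards to one of two fixed-return-type
    # implementations depending on a boolean keyword argument.
    def dispatched(*args, **kwargs):
        flag = kwargs.pop(arg_name, default)
        if flag:
            return if_true(*args, **kwargs)
        return if_false(*args, **kwargs)
    return dispatched

# Hypothetical usage, mirroring what torch/functional.py does for unique:
# nonzero = boolean_dispatch_sketch("as_tuple", default=False,
#                                   if_true=_nonzero_tuple,
#                                   if_false=_nonzero_tensor)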

@ngimel
Collaborator
ngimel commented Oct 1, 2020

@eellison what did you do for unique? It also returns either a tensor or a tuple of tensors depending on the value of its kwargs (and btw unique has the same behavior in numpy).

@ppwwyyxx
Collaborator
ppwwyyxx commented Oct 1, 2020

Same issue as #38718

@eellison
Contributor
eellison commented Oct 1, 2020

@ngimel it uses the dispatching method we have implemented in Python: https://github.com/pytorch/pytorch/blob/master/torch/functional.py#L803. Since torch.nonzero already dispatches to two different aten ops (nonzero and nonzero_numpy), maybe we could move the nonzero binding from C++ pybind into Python and support it without much trouble?

@rafale77

Just ran into this as well. Quite annoying. Looking forward to a fix.

@eellison
Contributor
eellison commented Jan 27, 2021

I can give guidance on how this might be fixed but I don't know if I have bandwidth to fix this. Adding back to triage.

@gmagogsfm
Contributor

> I can give guidance on how this might be fixed but I don't know if I have bandwidth to fix this. Adding back to triage.

Could you write up a bit of detail about how to fix it? Thanks!

@eellison
Contributor

To do this "the right way", I would:

Add literal types (Literal[True], Literal[False]) and overload on them:

# (sketch; assuming: from typing import Literal, Tuple; from torch import Tensor)

@torch.jit.overload
def unique(tensor: Tensor, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...

@torch.jit.overload
def unique(tensor: Tensor, as_tuple: Literal[False]) -> Tensor: ...

def unique(tensor: Tensor, as_tuple: bool):
    if as_tuple:
        return torch._C.unique_tuple(tensor)  # hypothetical builtin
    else:
        return torch._C.unique(tensor)

But this might be more trouble than it's worth. We might consider throwing an error, exposing the tuple/non-tuple signatures, and asking the user to directly invoke the unique_tuple vs unique overloads (those may or may not exist right now, I'm not sure).

@ngimel
Collaborator
ngimel commented Jan 28, 2021

Where by unique you must mean nonzero, @eellison?

@nlgranger

FYI, 649e683 (following #45499) fixes the type annotations of nonzero.

@rgommers
Collaborator
rgommers commented Feb 8, 2021

> But this might be more trouble than it's worth.

It does seem worth doing - this has already come up for unique, nonzero, and various torch.linalg functions.

@ppwwyyxx
Collaborator
ppwwyyxx commented May 23, 2021

Any updates? It has been a year since #38718 was opened, and it was once part of the 1.6.0 milestone (#38718 (comment)).

@eellison eellison removed their assignment May 24, 2021
@gmagogsfm
Contributor

Thanks @eellison for proposing a potential solution. Indeed, introducing a Literal type looks like more trouble than it is worth. Besides, since the length of the output tuple depends on the rank of the input tensor, it doesn't seem we can handle the as_tuple=True case easily in IR emission either.

What if we model the output as Any and ask the user, who may know the specific type of the output tuple, to refine the result?
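
For what that could look like from the user side, a sketch using torch.jit.isinstance (available in recent PyTorch) to refine an Any value. This assumes a hypothetical nonzero variant typed to return Any, which does not exist today:

import torch
from typing import Any, Tuple

@torch.jit.script
def use_refined(out: Any) -> torch.Tensor:
    # The user refines the dynamically typed result to the tuple they expect.
    if torch.jit.isinstance(out, Tuple[torch.Tensor, torch.Tensor]):
        return out[0] + out[1]
    raise RuntimeError("unexpected result type")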

@mruberry mruberry added and removed the module: python array api label May 30, 2021
@mruberry
Collaborator Author
mruberry commented May 30, 2021

Note that the Python Array API specifies that nonzero returns a tuple by default (see https://data-apis.org/array-api/latest/API_specification/searching_functions.html#nonzero-x), although we may be able to argue it should return a tensor, @ngimel @rgommers.

@rgommers
Collaborator
rgommers commented Jun 1, 2021

From data-apis/array-api#23:

> The reason apparently was that it's easier for advanced indexing (you can then simply do x[*tuple_idx], with the non-tuple form it'd be more annoying).
> ...
> The tuple form for nonzero allows boolean array indices to be equivalent to nonzero (except for scalar booleans), so x[boolean_array] == x[nonzero(boolean_array)] (if there are multiple indices, the nonzeros are unpacked in the tuple).
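
To make the quoted equivalence concrete in PyTorch terms, a small sketch (as_tuple=True yields exactly the per-dimension index tuple that boolean-mask indexing selects by):

import torch

x = torch.arange(12.0).reshape(3, 4)
mask = x > 5

# Boolean indexing and tuple-form nonzero select the same elements:
assert torch.equal(x[mask], x[torch.nonzero(mask, as_tuple=True)])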

@ngimel
Collaborator
ngimel commented Jun 1, 2021

Without the tuple form, you can do x[mask.nonzero().unbind(-1)]; I don't think it's more annoying.
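
A quick sketch of that workaround: nonzero() returns an (N, ndim) tensor of coordinates, and unbind(-1) splits it into the per-dimension index tuple that advanced indexing expects:

import torch

x = torch.arange(12.0).reshape(3, 4)
mask = x > 5

idx = mask.nonzero()  # shape (N, 2): one row of coordinates per match
assert torch.equal(x[mask], x[idx.unbind(-1)])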

@rgommers
Collaborator
rgommers commented Jun 2, 2021

True, unbind comes in pretty handy here. That's PyTorch-specific though, and the comments above are about a portable solution (NumPy, TF, Dask, CuPy, MXNet).

@rgommers
Collaborator
rgommers commented Jun 10, 2021

I investigated the state some more and added a bit of history on NumPy's nonzero returning a tuple in data-apis/array-api#23 (comment). I also searched the NumPy mailing list and the issue trackers of NumPy, CuPy, JAX, TensorFlow, and Dask, and asked a few maintainers of other libraries. There were no real complaints/issues about nonzero returning a tuple.

There were a number of issues related to nonzero output shape depending on values in the input and that being a problem for preallocating/JIT-ing. But that's unrelated to the tuple-return thing.

In practice, there's not much of a difference one way or the other - it's just that PyTorch made a different choice than other libraries for return type early on. It's not going to change in the array API or in NumPy. So for that aspect, I think the choice is between PyTorch remaining different, or doing a hard BC-breaking change. There's something to say for the former here.

EDIT: discussion on data-apis/array-api#23 continued after the above; there's also argwhere, which is yet another flavor of this functionality.
