add __torch_function__ API override mechanism by prasunanand · Pull Request #27064 · pytorch/pytorch · GitHub


Closed
wants to merge 173 commits into from

Conversation

prasunanand
Copy link
Contributor
@prasunanand prasunanand commented Sep 30, 2019

Closes #24015 (see description of that issue for more details).

For a toy example, see the DiagonalTensor and SubDiagonalTensor class in test/test_overrides.py.

This PR currently contains:

  • tests for __torch_function__ behavior
  • modifications to gen_python_functions to parse function signatures and dispatch to the correct overloaded argument.

This feature is inspired by and analogous to NumPy's __array_function__ protocol (see NumPy Enhancement Proposal 18).
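For readers unfamiliar with the protocol, the dispatch flow can be sketched in pure Python. This is an illustrative mock in the spirit of the tests in test/test_overrides.py; `handle_torch_function`, `mock_trace`, and this `DiagonalTensor` are stand-ins for demonstration, not the actual implementation in this PR:

```python
# Illustrative pure-Python mock of the __torch_function__ protocol.
# handle_torch_function, mock_trace and DiagonalTensor are stand-ins,
# not the code added by this PR.

def handle_torch_function(func, args, kwargs):
    """Ask each overloaded argument's __torch_function__ to handle the call."""
    for arg in args:
        overload = getattr(type(arg), "__torch_function__", None)
        if overload is not None:
            result = overload(arg, func, args, kwargs)
            if result is not NotImplemented:
                return result
    raise TypeError(f"no __torch_function__ implementation for {func.__name__}")

def mock_trace(t):
    """Stand-in for a torch function that consults __torch_function__."""
    if hasattr(type(t), "__torch_function__"):
        return handle_torch_function(mock_trace, (t,), {})
    raise TypeError("expected a tensor-like argument")

class DiagonalTensor:
    """Toy tensor-like: an n x n matrix whose diagonal entries all equal value."""
    def __init__(self, n, value):
        self.n, self.value = n, value

    def __torch_function__(self, func, args, kwargs):
        if func is mock_trace:
            return self.n * self.value  # trace = sum of the diagonal entries
        return NotImplemented  # unhandled functions fall through

print(mock_trace(DiagonalTensor(5, 2.0)))  # 10.0
```

The key property, as in NEP 18, is that the function itself never runs its default implementation when an argument's `__torch_function__` claims the call.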

Benchmarks:

See Nathan's comment below: #27064 (comment)

prasunanand and others added 30 commits September 3, 2019 16:51
Remove utility code that we can simply import from NumPy for now.
Things import again and can be tested.
Note, it does not get called normally (there's just a check that it exists);
the dispatcher calls the function implementation directly.
Manual check of current performance:
```
In [10]: %timeit mock_concatenate([Tensor(1), Tensor(2)])
2.58 µs ± 7.92 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [11]: %timeit mock_broadcast_tensors(Tensor(1))
1.65 µs ± 3.71 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```

Run with ASV:
```
$ asv run --python=same --dry-run
· Discovering benchmarks
· Running 6 total benchmarks (1 commits * 1 environments * 6 benchmarks)
[  0.00%] ·· Benchmarking existing-py_home_rgommers_anaconda3_envs_pytorch_bin_python
[  8.33%] ··· Running (bench_overrides.TorchFunction.time_mock_broadcast_tensors_duck--)......
[ 58.33%] ··· ...orchFunction.time_mock_broadcast_tensors_duck            793±5ns
[ 66.67%] ··· ...rchFunction.time_mock_broadcast_tensors_torch           867±70ns
[ 75.00%] ··· ...ides.TorchFunction.time_mock_concatenate_duck         1.44±0.1μs
[ 83.33%] ··· ...ides.TorchFunction.time_mock_concatenate_many           86.2±7μs
[ 91.67%] ··· ...des.TorchFunction.time_mock_concatenate_mixed        2.33±0.01μs
[100.00%] ··· ...des.TorchFunction.time_mock_concatenate_torch            902±9ns
```

So performance is as expected for a pure-Python implementation.
That behavior of forwarding the sum() function to a sum() method is
specific to NumPy.
The removed test checked handling of >32 input parameters.
NumPy limits this to 32, with NPY_MAXARGS. PyTorch doesn't have that
limitation.
This way, the ASV benchmarks can be run on master. Individual benchmarks
will fail, but not the ASV run itself.  This is an ASV feature; you
can go back in time and run a benchmark suite on older commits.
@cpuhrsch
Copy link
Contributor
cpuhrsch commented Dec 4, 2019

But, to be fair, torch.nn.functional.conv2d is covered by overwriting torch.conv2d - which again is also not documented (but that is unrelated to this PR).

@cpuhrsch
Copy link
Contributor
cpuhrsch commented Dec 4, 2019

From what I can tell by the tests, by default, this PR enables support for the majority of torch functions and torch.nn.functionals? However there appears to be a long list of functions that are ignored. It'd be great to make this easily accessible. Then I can, on the NestedTensor side, step through the torch functions that are already supported via torch_function and make sure I have an implementation for them. That'd make testing for me really easy.

@ngoldbaum
Copy link
Contributor

Right now we haven't explicitly added overrides for everything in torch.nn.functional. There is some stuff that has overrides because it also appears in the torch namespace. My plan was to add overrides for everything in torch.nn, torch.nn.functional and torch.functional (and explicitly test that everything is overridable) after this initial PR landed. Would that be sufficient for your needs?

It's not totally clear to me what you want in terms of documentation, if just adding the overrides is insufficient.

However there appears to be a long list of functions that are ignored.

Yes, these are all functions that don't take tensors in any of their arguments. As a general rule anything that accepts tensors is overridable.

You're right that it would be helpful for authors of tensor-like implementations to have some way of accessing that data though.
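To illustrate the "accepts tensors" rule above, here is a hedged pure-Python sketch. `Tensor`, `add`, `set_num_threads`, and `accepts_tensors` are illustrative stand-ins; PyTorch's real coverage logic works off the generated signatures, not Python annotations:

```python
# Illustrative sketch (not PyTorch's implementation): a function is a
# candidate for __torch_function__ dispatch only if some parameter is
# tensor-typed. The names below are made up for demonstration.
import inspect

class Tensor:
    """Stand-in for torch.Tensor."""

def add(a: Tensor, b: Tensor): ...        # takes tensors -> overridable
def set_num_threads(n: int): ...          # no tensors -> ignored

def accepts_tensors(func):
    """True if any annotated parameter is a Tensor (so it is overridable)."""
    sig = inspect.signature(func)
    return any(p.annotation is Tensor for p in sig.parameters.values())

print([f.__name__ for f in (add, set_num_threads) if accepts_tensors(f)])
# ['add']
```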

@cpuhrsch
Copy link
Contributor
cpuhrsch commented Dec 4, 2019

The docs also seem to indicate that I have to effectively reimplement the pytorch python argparser to make sure I can fully cover all signatures for NestedTensor

@cpuhrsch
Copy link
Contributor
cpuhrsch commented Dec 4, 2019

Right now we haven't explicitly added overrides for everything in torch.nn.functional. There is some stuff that has overrides because it also appears in the torch namespace. My plan was to add overrides for everything in torch.nn, torch.nn.functional and torch.functional (and explicitly test that everything is overridable) after this initial PR landed. Would that be sufficient for your needs?

I will also want to overwrite forward functions e.g. torch.nn.LSTM.forward, however that doesn't seem related here?

It's not totally clear to me what you want in terms of documentation, if just adding the overrides is insufficient.

The torch_function_dispatch decorator does have docs, but it seems relevant enough to be mentioned in the tutorials.

You're right that it would be helpful for authors of tensor-like implementations to have some way of accessing that data though.

Yes this would make my life much easier :)

@ngoldbaum
Copy link
Contributor

The torch_function_dispatch decorator does have docs, but it seems relevant enough to be mentioned in the tutorials.

My thinking when writing those docs was that only pytorch developers would need to touch torch_function_dispatch, and that those docs were more focused on pytorch users who want to write tensor-like implementations. Can you share a little bit more about why you want to use torch_function_dispatch?

I think ultimately the goal is for every function in the pytorch API that accepts tensor-likes to work with the __torch_function__ dispatch, so that you wouldn't ever need to manually hook into the dispatch mechanism.

@cpuhrsch
Copy link
Contributor
cpuhrsch commented Dec 4, 2019

I'd want to use torch_function_dispatch to add support for uncovered functions within the torch modules - but if you say that these will eventually all be covered, I'll base my support on the given coverage.

To keep track of development for NestedTensor it'd be good to have full insight into which op I should implement and which op I don't need to have support for yet because it's not been hooked into __torch_function__. Then I can run tests to make sure I fully cover them (akin to the overwrite test here).

facebook-github-bot pushed a commit that referenced this pull request Dec 4, 2019
Summary:
This is a re-do of #27064, which was reverted (b8792c0). This was landed at the same time as other work that added new operators to the `torch` namespace so the check for whether the `torch` namespace is exhaustively checked for overridability was triggering test failures.

I've temporarily disabled that check and added an explanatory comment that the check will be re-enabled in a future PR that will be merged during a time when the commit velocity on PyTorch is lower.
Pull Request resolved: #30730

Differential Revision: D18813270

Pulled By: ezyang

fbshipit-source-id: 70477c4656dca8fea6e7bc59259555041fcfbf68
@cpuhrsch
Copy link
Contributor
cpuhrsch commented Dec 4, 2019

For visibility, I'm moving over NestedTensor to use __torch_function__ instead of the custom monkey patching code in this branch. I'll finish it as soon as we have support for the ops it needs within core pytorch.

Thank you so much again for working on this @ngoldbaum!

@rgommers
Copy link
Collaborator
rgommers commented Dec 5, 2019

To keep track of development for NestedTensor it'd be good to have full insight into which op I should implement and which op I don't need to have support for yet because it's not been hooked into

This is probably a matter of a week or so I think, right @ngoldbaum? Should be fairly mechanical.

I will also want to overwrite forward functions e.g. torch.nn.LSTM.forward, however that doesn't seem related here?

Indeed, not related, and it hasn't come before in our discussions as far as I can remember. Looks like LSTM.forward is simple and pure Python, so easy to monkeypatch if needed. Making all methods of all classes in PyTorch overridable seems to me like it'd be a few bridges too far, at least right now.
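A minimal sketch of the monkeypatching approach suggested above, using a stand-in class (`SimpleRNN` here is hypothetical; the same pattern would apply to any pure-Python method such as LSTM.forward):

```python
# Hedged sketch of monkeypatching a pure-Python forward method.
# SimpleRNN is a made-up stand-in, not a torch.nn module.

class SimpleRNN:
    def forward(self, x):
        return [v * 2 for v in x]

original_forward = SimpleRNN.forward  # keep a reference to fall back on

def patched_forward(self, x):
    # Custom behavior layered on top of the original implementation.
    out = original_forward(self, x)
    return [v + 1 for v in out]

SimpleRNN.forward = patched_forward

print(SimpleRNN().forward([1, 2]))  # [3, 5]
```

This works precisely because the method is plain Python; methods backed by C++ bindings would not be patchable this way, which is one reason making all methods overridable is a larger project.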

@dylanbespalko
Copy link
Contributor
dylanbespalko commented Dec 5, 2019

@rgommers, @ngoldbaum, @prasunanand, and rest.

Great job guys. Very much appreciated!

@cpuhrsch
Copy link
Contributor

FYI I just switched NestedTensor over to only use __torch_function__ for its dispatch.

@rgommers
Copy link
Collaborator

Awesome, looks like that switch went well - thanks for letting us know @cpuhrsch.

facebook-github-bot pushed a commit that referenced this pull request Jan 29, 2020
Summary:
In-tree changes to pytorch to support complex numbers are being submitted here.
Out-of-tree support for CUDA complex numbers is here: [pytorch-cuda-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex)

Changes:
[x] Fixed performance issue raised in #30704 so that non-complex numbers do not call `conj()` and `real()`.
[x] Fixed tensor_to_numpy() conversion likely broken by a `checkBackend()` in #27064.
[x] Fixed some ReduceOps and TensorCompare Ops that recently added a `checkBackend()`.
    - `checkBackend()` is replaced with a device type check and a layout check.
    - This ensures the ComplexCPU Type ID is supported.
[x] Added AVX support for complex `exp()`, as requested in #755
Pull Request resolved: #30871

Differential Revision: D19200726

Pulled By: ezyang

fbshipit-source-id: d7e1be0b0a89c5d6e5f4a68ce5fcd2adc5b88277
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
Closes pytorch#24015 (see description of that issue for more details).

For a toy example, see the `DiagonalTensor` and `SubDiagonalTensor` class in test/test_overrides.py.

This PR currently contains:

* tests for `__torch_function__` behavior
* modifications to `gen_python_functions` to parse function signatures and dispatch to the correct overloaded argument.

This feature is inspired by and analogous to NumPy's `__array_function__` protocol ([see NumPy Enhancement Proposal 18](https://numpy.org/neps/nep-0018-array-function-protocol.html#trying-array-function-methods-until-the-right-one-works)).

### Benchmarks:
See Nathan's comment below: pytorch#27064 (comment)
Pull Request resolved: pytorch#27064

Differential Revision: D18645954

Pulled By: ezyang

fbshipit-source-id: 54b5e4344d7afdbcf996bb57191b0bdadc7b1767
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
This is a re-do of pytorch#27064, which was reverted (pytorch@b8792c0). This was landed at the same time as other work that added new operators to the `torch` namespace so the check for whether the `torch` namespace is exhaustively checked for overridability was triggering test failures.

I've temporarily disabled that check and added an explanatory comment that the check will be re-enabled in a future PR that will be merged during a time when the commit velocity on PyTorch is lower.
Pull Request resolved: pytorch#30730

Differential Revision: D18813270

Pulled By: ezyang

fbshipit-source-id: 70477c4656dca8fea6e7bc59259555041fcfbf68
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
In-tree changes to pytorch to support complex numbers are being submitted here.
Out-of-tree support for CUDA complex numbers is here: [pytorch-cuda-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex)

Changes:
[x] Fixed performance issue raised in pytorch#30704 so that non-complex numbers do not call `conj()` and `real()`.
[x] Fixed tensor_to_numpy() conversion likely broken by a `checkBackend()` in pytorch#27064.
[x] Fixed some ReduceOps and TensorCompare Ops that recently added a `checkBackend()`.
    - `checkBackend()` is replaced with a device type check and a layout check.
    - This ensures the ComplexCPU Type ID is supported.
[x] Added AVX support for complex `exp()`, as requested in pytorch#755
Pull Request resolved: pytorch#30871

Differential Revision: D19200726

Pulled By: ezyang

fbshipit-source-id: d7e1be0b0a89c5d6e5f4a68ce5fcd2adc5b88277
@cpuhrsch
Copy link
Contributor
cpuhrsch commented Feb 11, 2020

@rgommers - are there any updates on programmatically getting a list of functions that currently have __torch_function__ support so I can use those as a way of guaranteeing a high level of operator coverage for NestedTensors?

@ngoldbaum
Copy link
Contributor

@cpuhrsch right now we test that all of the functions in the torch namespace that satisfy these predicates are overridable:

https://github.com/pytorch/pytorch/blob/master/test/test_overrides.py#L724-L740

Note, however, that there are a few exceptions: the test that makes sure the API is exhaustively covered is currently disabled, because it needs to be re-enabled during a merge window when no new operators are added at the same time.

I'm currently working on #32799, which expands testing to all functions in torch.functional and torch.nn.functional. I got distracted last week with our company retreat but I'm hoping to finish that up this week.

Presumably the code I'm using in those tests to generate the list of overridden functions could be moved to a place where users could access it? I'm not sure what the best API would be.
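For example, if such a list were exposed, a downstream library could check its coverage with a few lines of set arithmetic (the function names and sets below are made up for illustration; they are not a real API):

```python
# Illustrative coverage check a library like NestedTensor might run,
# assuming an exposed list of overridable functions. Both sets here
# are hypothetical examples, not real PyTorch data.

overridable = {"torch.add", "torch.cat", "torch.conv2d", "torch.mm"}
implemented = {"torch.add", "torch.cat"}

missing = sorted(overridable - implemented)
coverage = len(implemented & overridable) / len(overridable)

print(f"coverage: {coverage:.0%}")  # coverage: 50%
print(f"missing: {missing}")        # missing: ['torch.conv2d', 'torch.mm']
```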

@rgommers
Copy link
Collaborator

Presumably the code I'm using in those tests to generate the list of overrided functions for the tests could be moved to a place where users could access it? I'm not sure what the best API would be.

It probably shouldn't be public API since it's a meta-thing mostly useful for testing correctness/completeness in other libraries. Either a standalone script or a private API that library authors can use with the understanding it may change without deprecation would be my preference.
Could also make sense as a list in the docs.

@cpuhrsch
Copy link
Contributor

Yes, a private API that exposes the list of functions in that test could be very useful. In the same way that the JIT has functions such as torch.jit._builtins._get_builtin_table(). The entire TENSOR_LIKE_TORCH_IMPLEMENTATIONS is actually quite useful because it gets you the name plus the Python signature of that function.

I'd then use this to write a test quite similar to TensorLike that compares the result of a call to that of a reference value.

For now I'll copy-paste TENSOR_LIKE_TORCH_IMPLEMENTATIONS and continue to track the development, but I think this would be useful to have.

@ngoldbaum
Copy link
Contributor

I've opened #33182 to track this feature request.

facebook-github-bot pushed a commit that referenced this pull request Feb 21, 2020
…onal (#32799)

Summary:
This adds `__torch_function__` support for all functions in `torch.functional` and `torch.nn.functional`.

The changes to C++ code and codegen scripts are to facilitate adding `__torch_function__` support for the native functions in `torch._C._nn`. Note that I moved the `handle_torch_function` C++ function to a header that both `python_torch_functions.cpp` and `python_nn_functions.cpp` include. The changes to `python_nn_functions.cpp` mirror the changes I made to `python_torch_functions.cpp` when `__torch_function__` support was first added in #27064. Due to the somewhat different way the `torch._C` and `torch._C._nn` namespaces are initialized I needed to create a new static reference to the `torch._C._nn` namespace (`THPNNVariableFunctions`). I'm not sure if that is the best way to do this. In principle I could import these namespaces in each kernel and avoid the global variable but that would have a runtime cost.

I added `__torch_function__` support to the Python functions in `torch.nn.functional` following the approach in #32194.

I re-enabled the test that checks if all functions in the `torch` namespace are explicitly tested for `__torch_function__` support. I also generalized the check to work for `torch.functional` and `torch.nn.functional` as well. This test was explicitly disabled in #30730 and I'm happy to disable it again if you think that's appropriate. I figured now was as good a time as any to try to re-enable it.

Finally I adjusted the existing torch API tests to suppress deprecation warnings and add keyword arguments used by some of the code in `torch.nn.functional` that were missed when I originally added the tests in #27064.
Pull Request resolved: #32799

Differential Revision: D19956809

Pulled By: ezyang

fbshipit-source-id: 40d34e0109cc4b9f3ef62f409d2d35a1d84e3d22
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
In-tree changes to pytorch to support complex numbers are being submitted here.
Out-of-tree support for CUDA complex numbers is here: [pytorch-cuda-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex)

Changes:
[x] Fixed performance issue raised in pytorch#30704 so that non-complex numbers do not call `conj()` and `real()`.
[x] Fixed tensor_to_numpy() conversion likely broken by a `checkBackend()` in pytorch#27064.
[x] Fixed some ReduceOps and TensorCompare Ops that recently added a `checkBackend()`.
    - `checkBackend()` is replaced with a device type check and a layout check.
    - This ensures the ComplexCPU Type ID is supported.
[x] Added AVX support for complex `exp()`, as requested in pytorch#755
Pull Request resolved: pytorch#30871

Differential Revision: D19200726

Pulled By: ezyang

fbshipit-source-id: d7e1be0b0a89c5d6e5f4a68ce5fcd2adc5b88277
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
…onal (pytorch#32799)

Summary:
This adds `__torch_function__` support for all functions in `torch.functional` and `torch.nn.functional`.

The changes to C++ code and codegen scripts are to facilitate adding `__torch_function__` support for the native functions in `torch._C._nn`. Note that I moved the `handle_torch_function` C++ function to a header that both `python_torch_functions.cpp` and `python_nn_functions.cpp` include. The changes to `python_nn_functions.cpp` mirror the changes I made to `python_torch_functions.cpp` when `__torch_function__` support was first added in pytorch#27064. Due to the somewhat different way the `torch._C` and `torch._C._nn` namespaces are initialized I needed to create a new static reference to the `torch._C._nn` namespace (`THPNNVariableFunctions`). I'm not sure if that is the best way to do this. In principle I could import these namespaces in each kernel and avoid the global variable but that would have a runtime cost.

I added `__torch_function__` support to the Python functions in `torch.nn.functional` following the approach in pytorch#32194.

I re-enabled the test that checks if all functions in the `torch` namespace are explicitly tested for `__torch_function__` support. I also generalized the check to work for `torch.functional` and `torch.nn.functional` as well. This test was explicitly disabled in pytorch#30730 and I'm happy to disable it again if you think that's appropriate. I figured now was as good a time as any to try to re-enable it.

Finally I adjusted the existing torch API tests to suppress deprecation warnings and add keyword arguments used by some of the code in `torch.nn.functional` that were missed when I originally added the tests in pytorch#27064.
Pull Request resolved: pytorch#32799

Differential Revision: D19956809

Pulled By: ezyang

fbshipit-source-id: 40d34e0109cc4b9f3ef62f409d2d35a1d84e3d22
Labels
caffe2 module: autograd Related to torch.autograd, and the autograd engine in general module: docs Related to our documentation, both in docs/ and docblocks module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

Implement __torch_function__ to let Tensor-like objects override torch functions