-
Notifications
You must be signed in to change notification settings - Fork 24.3k
add __torch_function__ API override mechanism #27064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
173 commits
Select commit
Hold shift + click to select a range
e966598
first try
prasunanand cda2057
modify dispatcher
prasunanand 9f710a4
signatures matched
prasunanand fee5ef1
Fix mixed tabs/spaces
rgommers 1bc13fd
Add implement_torch_function (in Python) implementation.
rgommers f9ea2ac
Move code from torch/__init__.py to torch/_overrides.py
rgommers 5965063
Remove TORCH_FUNCTION_ENABLED, this is an env var we don't need
rgommers 6ea9718
Fix flake8 warnings
rgommers 8982470
Add TODO for temporary addition to torch/__init__.py
rgommers 04af379
Add some imports, comments, and dummy __torch_function__
rgommers 4c05cfc
Implement __torch_function__ in Python.
rgommers 8e80113
Add an example of using the override for gemm in test/test_overrides.py
rgommers 4b4ca02
Add ASV benchmarks. Also fix an issue with Tensor.__torch_function__
rgommers c3e8731
Add some documentation for writing and running ASV benchmarks
rgommers 3753da7
Add a few overloads, and adds docs on what dispatcher functions should
rgommers da6ed8e
Another documentation tweak.
rgommers 47f4703
adopt tests from numpy
prasunanand fb8986e
modify the unittests and mark a few of them to be skipped
prasunanand fce2ef8
correct type order for subclass tests
prasunanand cb22ee1
Remove `assert_` again in favor of plain `assert`
rgommers 3a35770
Remove `TORCH_FUNCTION_ENABLED`, it wasn't doing anything.
rgommers 70e9e29
Remove torch.gemm, and change test to use torch.unique
rgommers c361848
Fix __torch_function__ subclass ordering test
rgommers a70d13b
Remove irrelevant test for `__torch_function__`
rgommers e0fedb9
Fix a couple more `__torch_function__` tests.
rgommers 20e535e
Fix an issue with subclasses for `__torch_function__`.
rgommers b757e11
Fix one more test, and remove an unnecessary test
rgommers 26885eb
Fix last failing test for `__torch_function__`.
rgommers fb71ad0
Treat imports properly in override benchmarks.
rgommers 58de6bb
Fix some typos in comments
rgommers 29bdccc
Make flake8 happy
rgommers f012e5b
Remove a spurious tab in asv.conf.json
rgommers d9dbf16
Fix two more pylint issues
rgommers a344329
Torch function overrides in cpp
prasunanand 8c7b2d4
Added comments to the code
prasunanand 7537bb2
Parse modified
prasunanand 6cae1b9
Test for NN
prasunanand 36fbd71
Skip overrides test
prasunanand 08ecb78
Added comments to parse and removed some unused code
prasunanand 8335d41
Find attribute only if present
prasunanand e05da42
Get rid of Python code and modify overrides tests
prasunanand 2556ded
Parse works! Remove duplicate code
prasunanand fe954b2
Python 2 support
prasunanand faa517a
Lint fix: Add new line
prasunanand 145309b
Fix Clang tidy error: {nullptr}
prasunanand 56498fc
Flake errors fix
prasunanand 53c0f21
Flake errors fix
prasunanand a69abe4
Check for overheads on add and multiply
prasunanand 9cd4ef7
Add __torch_function__ to other Torch APIs
prasunanand 8783bb2
Rebase with master
prasunanand c98161e
Modify the benchmark code
prasunanand 8242cfd
Subclass of torch.Tensor should check for __torch_function__
prasunanand 4b0468c
Benchmark SubTensors with __torch_function__ defined
prasunanand 175a118
Minor tweaks for lesser overhead
prasunanand 298bad4
Handle subclasses of Tensor, fix test
prasunanand de139c6
Fix Python3.7 Lint errors
prasunanand 2931eb0
Fix Python Lint errors
prasunanand af89388
Fix clang tidy
prasunanand b63f58b
Test overrides of Torch public APIs
prasunanand 5f87cc2
More Test overrides of Torch public APIs
prasunanand 8dfabca
make args default to () in __torch_function__ implementations used fo…
ngoldbaum 6496500
Fix duplicate function name
ngoldbaum 1ab345b
reduce boilerplate in override tests by defining ImplementationMeta m…
ngoldbaum c3465f8
Merge pull request #1 from ngoldbaum/torch_function
prasunanand 5a8a2c7
Merge branch 'master' into torch_function
ngoldbaum df60370
add support for torch functions defined in python
ngoldbaum b949180
autogenerate tests for the full torch API
ngoldbaum b67c39e
add override tests for some more functions
ngoldbaum 89695d0
Merge pull request #2 from ngoldbaum/torch_function
prasunanand f547e73
include test_overrides in main test runner
ngoldbaum 1b612a6
Move helpers from python_variable.h to python_arg_parser.h
ngoldbaum 509a05d
fix python2.7 SyntaxError
ngoldbaum 81a9aed
appease clang-tidy
ngoldbaum f687f15
Merge remote-tracking branch 'prasun/torch_function' into torch_function
ngoldbaum 271d94d
use functools.wraps for the dispatch decorators
ngoldbaum 6f0da73
rename HANDLED_FUNCTIONS to HANDLED_FUNCTIONS_DIAGONAL
ngoldbaum 630a155
remove ImplementationMeta to make tests follow suggested implementation
ngoldbaum c391bc4
reorganize so that dispatch tables and dispatch decorators are groupe…
ngoldbaum d0de54c
expand comments explaining dispatch tables
ngoldbaum a2b046e
expand comments
ngoldbaum 604c87c
use only one TestCase subclass
ngoldbaum 9687422
make override tests runnable in pytest
ngoldbaum d6c852b
Add comments and small corrections
prasunanand db6ddf2
Reference to Numpy and minor edits related to review
prasunanand d5c9eb8
Remove check_exact, instead parse a boolean
prasunanand 8b3741e
remove usage of getargspec from numpy
ngoldbaum a82499b
make it clearer that everything besides torch_function_dispatch is pr…
ngoldbaum 5e98d3b
remove unused docs_from_dispatcher keyword for torch_function_dispatch
ngoldbaum f96f94a
expand docs for torch_function_dispatch decorator
ngoldbaum baae730
add py2/py3 compat code in test_overrides
ngoldbaum 9102244
simplify testing somewhat
ngoldbaum eb25705
fix spelling and rst syntax
ngoldbaum 7e150ac
Merge branch 'master' into torch_function
ngoldbaum c472d48
bring back keyword arguments for TensorLike API override tests
ngoldbaum e710eb2
add new keyword argument to cdist
ngoldbaum 2b9064b
pass the function object instead of the name to __torch_function__
ngoldbaum bc67494
remove unused _torch_function function in torch.tensor module
ngoldbaum db06350
remove unused python bindings for _promote_types
ngoldbaum 16c4122
remove unnecessary __torch_function__ checking for functions that can…
ngoldbaum 30936bc
make torch.numel overridable
ngoldbaum b8b783e
simplify argument parsing logic and add explanatory comments
ngoldbaum 92ce07a
remove breakpoint
ngoldbaum f1cf69d
make overloaded_args a std::vector to remove signature size limit
ngoldbaum 2166bb1
don't initialize overloaded_args with 32 nullptr
ngoldbaum fd4f3d7
remove unnecessary returns
ngoldbaum a0faf8d
remove PythonArgs::get_overload_arg
ngoldbaum efe5d2e
update comment wording
ngoldbaum 9f0a146
check if __torch_function__ returns NotImplemented and call next-high…
ngoldbaum 52f1b5b
expand tests for __torch_function__ return semantics
ngoldbaum f39297c
Merge branch 'master' into torch_function
ngoldbaum 6893a11
fix clang_tidy nit
ngoldbaum e096ef3
update asv benchmarks and benchmark docs
ngoldbaum b82cb5c
expand benchmarks
ngoldbaum a0aacf0
reduce code duplication in the code generation
ngoldbaum 78b6f87
fix indentation
ngoldbaum bcdcf67
remove unused get_tensor_torch_function
ngoldbaum dc781f7
ignore torch.sparse_coo_tensor for overriding purposes
ngoldbaum 00f3035
fix indentation
ngoldbaum 377b6d1
explicitly check __torch_function__ doesn't return nullptr
ngoldbaum f10ba9f
use a range-based for loop to simplify handle_torch_function
ngoldbaum a4cd209
remove unused testing code
ngoldbaum 08d6d9f
fix reference counting in handle_torch_function
ngoldbaum 11525b1
remove unused get_torch_function
ngoldbaum fa37c21
move check_has_torch_function to the arg parser header
ngoldbaum bcb498f
fix reference leak in check_has_torch_function
ngoldbaum 994b83e
simplify logic in PythonArgParse::parse
ngoldbaum 5f44700
do reference counting for objects in overloaded_args vector
ngoldbaum c00fa4d
expand docs of new C-level helper functions
ngoldbaum a4651b5
revert added whitespace
ngoldbaum 6de2423
combine THPVariable_Check and THPVariable_CheckExact to simplify logi…
ngoldbaum 8ad8070
update tests to check for TypeError
ngoldbaum 0892635
add doc comment for handle_torch_function
ngoldbaum a622a0f
refactor to use pybind11 wrappers, make handle_torch_function raise e…
ngoldbaum fc45687
ensure exceptions raised in user implementations are propagated
ngoldbaum 0cf2227
cross-reference python and C++ implementations of __torch_function__ …
ngoldbaum 7a2f055
fix deprecation warning
ngoldbaum e2ec5a2
refactor overloaded_args handling in parser into a helper function
ngoldbaum d596d56
use range-based for loops
ngoldbaum 724d248
expand documentation for torch_override helper functions
ngoldbaum 14b479b
fix review nits in python code
ngoldbaum 1265134
fix compiler error
ngoldbaum 64b1962
Merge branch 'master' into torch_function
ngoldbaum 14b1634
use the default py::object initializer for the return value of handle…
ngoldbaum e28d602
use reinterpret_steal instead of reinterpret_borrow to make reference…
ngoldbaum 00cb4e7
explicitly throw errors using python_error()
ngoldbaum c70c7d1
refactor tests to explicitly compare python and C++ dispatch
ngoldbaum 8eced71
remove unnecessary comments
ngoldbaum 98da931
reword docstring for precedence tests
ngoldbaum a3a98e7
add explanatory comment to overrides tests
ngoldbaum c55c2f1
attempt to reduce branching using templates
ngoldbaum 9717cf7
Merge branch 'master' into torch_function
ngoldbaum c36da88
add bitwise_xor test
ngoldbaum 082bd6d
use a template parameter instead of partial specialization to reduce …
ngoldbaum b9bfb17
use release() instead of incref() to avoid changing the reference cou…
ngoldbaum e2286f3
add a check for tensor types in PyTorch_LookupSpecial
ngoldbaum 274c59d
only do exact checking on tensor operands
ngoldbaum ec804c4
refactor to move overloaded_args to FunctionSignature
ngoldbaum 1a1fab6
remove unnecessary parens
ngoldbaum d38f2c3
add documentation for __torch_function__
ngoldbaum 0be6167
Merge branch 'master' into torch_function
ngoldbaum bc6d00b
Merge remote-tracking branch 'origin/master' into torch_function
ezyang 0aaf456
add PYBIND11_EXPORT to FunctionSignature
ngoldbaum ff7d3c0
update expected answers for ONNX tests
ngoldbaum 99b0412
add hyperlinks for numpy's __array_function__ to docs
ngoldbaum 903df1d
move __torch_function__ documentation to the end of notes/extending.rst
ngoldbaum 7571dc8
make class definitions more copy/pasteable
ngoldbaum d5ebcdc
reword intro section
ngoldbaum ae791ca
add note that HANDLED_FUNCTIONS pattern isn't required
ngoldbaum 488d55a
rewording
ngoldbaum 07793ce
respond to doc comments
ngoldbaum 403d45c
Merge remote-tracking branch 'upstream/master' into torch_function
ngoldbaum 3e72316
Merge remote-tracking branch 'origin/master' into torch_function
ezyang d2d9c12
remove asv benchmarks
ngoldbaum File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.