8000 Add API for listing functions overridable by __torch_function__ (#33791) · ttumiel/pytorch@b62fce8 · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit b62fce8

Browse files
ngoldbaumttumiel
authored andcommitted
Add API for listing functions overridable by __torch_function__ (pytorch#33791)
Summary: Fixes pytorch#33182 This adds private API functions that developers of types that implement `__torch_function__` can use to ensure full coverage of the subset of the PyTorch API that can be overrided. I've refactored some of the code in the tests into a new `torch._overrides.get_overridable_functions` function. I've also changed `TENSOR_LIKE_TORCH_OVERRIDES` into `torch._overrides.get_testing_overrides` and `IGNORED_TORCH_FUNCTIONS` into `torch._overrides.get_ignored_functions`. Making these two static global variables in the tests into functions should allow rewriting their implementation to construct their return values instead of just statically defining the return value as is done here. Currently that is blocked on not being able to inspect function signatures of compiled kernels in PyTorch (see pytorch#28233). See the docs I've added for usage examples of these new functions. I also refactored the existing override tests to make use of these new functions, which should be a good forcing function to make sure they're kept up-to-date. Finally, while working on this I discovered that `TestTorchFunctionOverrides.test_mean` and `TestTorchFunctionOverrides.test_mm` weren't ever being run because they were getting clobbered by the other dynamically generated override tests. I fixed that by renaming the tests and then fixing the actual test code. I've verified that all the subclassing semantics is correct and that the updated test answers are correct. I'm happy to put the fixes to the existing tests in as a separate pull request if that would be easier to review. ping cpuhrsch since the feature request originally came from them. Pull Request resolved: pytorch#33791 Differential Revision: D20195053 Pulled By: cpuhrsch fbshipit-source-id: 1585f4e405f5223932b410eae03a288dc8eb627e
1 parent 51364d0 commit b62fce8

File tree

3 files changed

+724
-628
lines changed

3 files changed

+724
-628
lines changed

docs/source/notes/extending.rst

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,54 @@ a case the rules are:
484484
* If all of the ``__torch_function__`` implementations return
485485
``NotImplemented``, PyTorch raises a ``TypeError``.
486486

487+
Testing Coverage of Overrides for the PyTorch API
488+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
489+
490+
One troublesome aspect of implementing ``__torch_function__`` is that if some
491+
operations do and others do not have overrides, users will at best see an
492+
inconsistent experience, or at worst will see errors raised at runtime when they
493+
use a function that does not have an override. To ease this process, PyTorch
494+
provides a developer-facing API for ensuring full support for
495+
``__torch_function__`` overrides. This API is private and may be subject to
496+
changes without warning in the future.
497+
498+
First, to get a listing of all overridable functions, use
499+
``torch._overrides.get_overridable_functions``. This returns a dictionary whose
500+
keys are namespaces in the ``PyTorch`` Python API and whose values are a list of
501+
functions in that namespace that can be overriden. For example, let's print the
502+
names of the first 5 functions in ``torch.nn.functional`` that can be
503+
overriden::
504+
505+
>>> from torch._overrides import get_overridable_functions
506+
>>> func_dict = get_overridable_functions()
507+
>>> nn_funcs = func_dict[torch.nn.functional]
508+
>>> print([f.__name__ for f in nn_funcs[:5])
509+
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
510+
'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']
511+
512+
This listing of functions makes it possible to iterate over all overridable
513+
functions, however in practice this is not enough to write tests for all of
514+
these functions without laboriously and manually copying the signature of each
515+
function for each test. To ease this process, the
516+
``torch._overrides.get_testing_overrides`` function returns a dictionary mapping
517+
overridable functions in the ``PyTorch`` API to dummy lambda functions that have
518+
the same signature as the original function but unconditionally return -1. These
519+
functions are most useful to use with ``inspect`` to analyze the function
520+
signature of the original ``PyTorch`` function::
521+
522+
>>> import inspect
523+
>>> from torch._overrides import get_testing_overrides
524+
>>> override_dict = get_testing_overrides()
525+
>>> dummy_add = override_dict[torch.add]
526+
>>> inspect.signature(dummy_add)
527+
<Signature (input, other, out=None)>
528+
529+
Finally, ``torch._overrides.get_ignored_functions`` returns a tuple of functions
530+
that explicitly cannot be overrided by ``__torch_function__``. This list can be
531+
useful to confirm that a function that isn't present in the dictionary returned
532+
by ``get_overridable_functions`` cannot be overriden.
533+
534+
487535
Writing custom C++ extensions
488536
-----------------------------
489537

0 commit comments

Comments
 (0)
0