Avoid the builtin `numbers` module. #144788

Closed · randolf-scholz opened this issue Jan 14, 2025 · 5 comments
Labels
actionable · module: distributions (Related to torch.distributions) · module: typing (Related to mypy type annotations) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@randolf-scholz (Contributor) commented Jan 14, 2025

🚀 The feature, motivation and pitch

Currently, torch uses the builtin numbers module in a few places (only ~40 hits). However, the numbers module is problematic for multiple reasons:

  1. The numbers module is incompatible with type annotations (see int is not a Number? python/mypy#3186, example: mypy-playground).

  2. Since it's just an abstract base class, it requires users to call Number.register(my_number_type) to make isinstance succeed for their own numeric types (see the registration sketch after the demo below).

  3. Internally, torch.tensor doesn't seem to care whether something is a numbers.Number; in fact, the supported types appear to be

    • symbolic torch scalars torch.SymBool, torch.SymInt and torch.SymFloat
    • numpy scalars numpy.int32, numpy.int64, numpy.float32, etc.
    • python built-in scalars bool, int, float, complex
    • things that can be converted to built-in scalars via __bool__, __int__, __index__, __float__ or __complex__ (requires specifying dtype)

    (see /torch/csrc/utils/tensor_new.cpp and torch/_refs/__init__.py)

    demo
    import torch
    from numbers import Real


    class MyReal(Real):
        """Simple wrapper class for float."""

        __slots__ = ("val",)

        def __init__(self, x) -> None:
            self.val = float(x)

        def __float__(self): return self.val.__float__()
        def __complex__(self): return self.val.__complex__()

        @property
        def real(self): return MyReal(self.val.real)
        @property
        def imag(self): return MyReal(self.val.imag)
        def conjugate(self): return MyReal(self.val.conjugate())

        def __abs__(self): return MyReal(self.val.__abs__())
        def __neg__(self): return MyReal(self.val.__neg__())
        def __pos__(self): return MyReal(self.val.__pos__())
        # __trunc__/__floor__/__ceil__ are expected to return an Integral
        def __trunc__(self): return self.val.__trunc__()
        def __floor__(self): return self.val.__floor__()
        def __ceil__(self): return self.val.__ceil__()
        # float.__round__ takes ndigits positionally only
        def __round__(self, ndigits=None): return MyReal(round(self.val, ndigits))

        # comparisons must return bool, not MyReal
        def __eq__(self, other): return self.val == other
        def __lt__(self, other): return self.val < other
        def __le__(self, other): return self.val <= other

        def __add__(self, other): return MyReal(self.val.__add__(other))
        def __radd__(self, other): return MyReal(self.val.__radd__(other))
        def __mul__(self, other): return MyReal(self.val.__mul__(other))
        def __rmul__(self, other): return MyReal(self.val.__rmul__(other))
        def __truediv__(self, other): return MyReal(self.val.__truediv__(other))
        def __rtruediv__(self, other): return MyReal(self.val.__rtruediv__(other))
        def __floordiv__(self, other): return MyReal(self.val.__floordiv__(other))
        def __rfloordiv__(self, other): return MyReal(self.val.__rfloordiv__(other))
        def __mod__(self, other): return MyReal(self.val.__mod__(other))
        def __rmod__(self, other): return MyReal(self.val.__rmod__(other))

        def __pow__(self, exponent): return MyReal(self.val.__pow__(exponent))
        def __rpow__(self, base): return MyReal(self.val.__rpow__(base))

    class Pi:
        def __float__(self) -> float: return 3.14

    torch.tensor(MyReal(3.14), dtype=float)  # ✅
    torch.tensor(Pi(), dtype=float)  # ✅

    torch.tensor(MyReal(3.14))  # ❌ RuntimeError: Could not infer dtype of MyReal
    torch.tensor(Pi())  # ❌ RuntimeError: Could not infer dtype of Pi
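
To make point 2 above concrete, here is a minimal sketch of the registration requirement (the Quantity class is hypothetical): a duck-typed numeric type is invisible to isinstance(..., Number) until it is explicitly registered with the ABC.

from numbers import Number

class Quantity:
    """Hypothetical scalar wrapper; not derived from any numbers ABC."""
    def __init__(self, val):
        self.val = float(val)
    def __float__(self):
        return self.val

print(isinstance(Quantity(1.0), Number))  # False
Number.register(Quantity)                 # explicit opt-in is required
print(isinstance(Quantity(1.0), Number))  # True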

Alternatives

There are 3 main alternatives:

  1. Use a Union type of the supported types (checked as a tuple at runtime on Python 3.9). torch already provides such unions, for example torch.types.Number and torch._prims_common.Number.
  2. Use builtin Protocol types like typing.SupportsFloat.
  • The main disadvantage here is that Tensor, since it implements __float__, is itself a SupportsFloat, which could require changing some existing if-else tests.
  3. Provide a custom Protocol type (a rough sketch of 1 and 3 follows this list).
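
As a rough sketch of alternatives 1 and 3 (the names NumberLike and SupportsScalar are illustrative, not an existing torch API):

from typing import Protocol, Union, runtime_checkable

import torch

# Alternative 1: a plain Union of the supported types
# (on Python 3.9 the runtime check uses the equivalent tuple).
NumberLike = Union[bool, int, float, complex, torch.SymBool, torch.SymInt, torch.SymFloat]

# Alternative 3: a custom Protocol; @runtime_checkable makes isinstance work,
# but it only checks for the *presence* of __float__, not its behavior.
@runtime_checkable
class SupportsScalar(Protocol):
    def __float__(self) -> float: ...

assert isinstance(3.14, SupportsScalar)
assert isinstance(torch.tensor(3.14), SupportsScalar)  # Tensor matches too (the caveat from alternative 2)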

Additional context

One concern could be the speed of `isinstance(x, Number)`; below is a comparison between the approaches.
import torch
from numbers import Real
import numpy as np
from typing import SupportsFloat

T1 = Real
T2 = SupportsFloat
T3 = (bool, int, float, complex, torch.SymBool, torch.SymInt, torch.SymFloat, np.number)

print("Testing float")
x = 3.14
%timeit isinstance(x, T1)  # 237 ns ± 0.374 ns 
%timeit isinstance(x, T2)  # 214 ns ± 0.325 ns 
%timeit isinstance(x, T3)  #  35 ns ± 0.844 ns
print("Testing np.float32")
y = np.float32(3.14)
%timeit isinstance(y, T1)  # 106 ns ± 2.3 ns 
%timeit isinstance(y, T2)  # 223 ns ± 2.33 ns
%timeit isinstance(y, T3)  # 104 ns ± 0.52 ns
print("Testing Tensor")
z = torch.tensor(3.14)
%timeit isinstance(z, T1)  # 117 ns ± 0.962 ns 
%timeit isinstance(z, T2)  # 226 ns ± 0.508 ns 
%timeit isinstance(z, T3)  # 99.1 ns ± 0.699 ns
print("Testing string (non-match)")
w = "3.14"
%timeit isinstance(w, T1)  # 114 ns ± 1.47 ns 
%timeit isinstance(w, T2)  # 2.21 μs ± 79.2 ns
%timeit isinstance(w, T3)  # 95 ns ± 0.887 ns 

One can see that isinstance(val, SupportsFloat) is roughly twice as slow as isinstance(val, Real) for a positive match, but can be a lot slower for a negative one. The Union can be a lot faster, but the speed depends on the order of its members (if we put float last, the first benchmark takes ~90 ns, since the argument is checked sequentially against the provided types).
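
The ordering effect is easy to reproduce; a minimal sketch (absolute timings are machine-dependent):

import timeit

FLOAT_FIRST = (float, int, bool, complex)
FLOAT_LAST = (bool, int, complex, float)

x = 3.14
# isinstance checks the tuple members left to right and short-circuits,
# so a match on the first entry is cheaper than one on the last:
print(timeit.timeit(lambda: isinstance(x, FLOAT_FIRST)))
print(timeit.timeit(lambda: isinstance(x, FLOAT_LAST)))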

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster

@vmoens (Contributor) commented Jan 15, 2025

Thanks for this.

I agree with all your points. I actually ran into the same issues with SymInt in tensordict and found that using numbers.Number was eventually more painful than anything.

The first alternative seems the most appropriate to me (and probably the easiest to compile). As you pointed out, Protocols can be slow for isinstance checks (which people will ultimately do).

@malfet If we remove all refs to numbers and that becomes a rule, maybe we should also incorporate that check in the linter?

@ezyang (Contributor) commented Jan 15, 2025

yeah, let's use our builtin Number

@randolf-scholz (Contributor, Author) commented Jan 16, 2025

Since torch still supports 3.9, one question is how to replace the isinstance(x, Number) checks. There seem to be three options:

  1. Use if isinstance(x, _Number), where _Number = (bool, int, float).
  2. Use if isinstance(x, typing.get_args(Number)) (however, this loses a lot of type info, since get_args returns tuple[Any, ...]).
  3. Use if is_number(x), with a helper function def is_number(x) -> TypeIs[Number] (sketched after the snippet below).

① seems to be the one that's easiest to replace once 3.9 reaches EOL this year. I also tried the following, but mypy doesn't like it for 3.10+:

if sys.version_info < (3, 10):
    _Number = (bool, int, float)
else:
    _Number: TypeAlias = Number
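
A minimal sketch of option ③, assuming a recent typing_extensions is available (TypeIs was added to typing itself only in Python 3.13):

from typing import Union
from typing_extensions import TypeAlias, TypeIs

Number: TypeAlias = Union[bool, int, float]
_NUMBER_TYPES = (bool, int, float)  # runtime counterpart of the alias

def is_number(x: object) -> TypeIs[Number]:
    """Narrowing helper: mypy treats x as Number in the True branch."""
    return isinstance(x, _NUMBER_TYPES)

def demo(x: object) -> float:
    if is_number(x):
        return float(x)  # x is narrowed to bool | int | float here
    raise TypeError(f"expected a number, got {type(x).__name__}")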

@ezyang (Contributor) commented Jan 16, 2025

I think the tuple (1) is what most people expected

@drisspg added the triaged and actionable labels and removed the triage review label on Jan 27, 2025
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this issue Jan 27, 2025
@randolf-scholz (Contributor, Author) commented Jan 28, 2025

@ezyang This should probably be reopened, since #145086 was only a partial fix covering the torch.distributions module.

After this, there are 22 files left that use the numbers module in some form.

One issue that still needs to be addressed: complex numbers (which torch.distributions doesn't really use). According to https://pytorch.org/docs/2.6/complex_numbers.html, these are still experimental?

But likely this means that we need both a real-valued type, which is the current torch.types.Number, and a potentially complex-valued type. _C/__init__.pyi uses Union[Tensor, torch.types.Number, complex] for the most part, so I guess that's what's supposed to be used in this case (instead of separate aliases for real and potentially complex numbers)?

I used the name "Number" in the PR, because that is what the docstrings in torch.distributions used, but really, these probably should have been numbers.Real all along, fwiw.
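
For illustration, a minimal sketch of the two aliases this would imply (names hypothetical; to my knowledge torch.types.Number is currently Union[int, float, bool]):

from typing import Union
from typing_extensions import TypeAlias

# Real-valued scalars -- what torch.types.Number covers today.
RealNumber: TypeAlias = Union[bool, int, float]

# Potentially complex-valued scalars, mirroring the
# Union[Tensor, torch.types.Number, complex] pattern in _C/__init__.pyi
# (minus Tensor, which is handled separately).
ComplexNumber: TypeAlias = Union[bool, int, float, complex]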
