8000 [typing] Add type hints to `__init__` methods in `torch.distributions`. by randolf-scholz · Pull Request #144197 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[typing] Add type hints to __init__ methods in torch.distributions. #144197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
ef2329f
added type hints to lazy_property
randolf-scholz Jan 2, 2025
0937a2a
added type hints to lazy_properties & simplified 'torch.Tensor' hints…
randolf-scholz Jan 2, 2025
61366fe
fixed typing errors
randolf-scholz Jan 2, 2025
1fc1bd9
replace 'Dict' with 'dict'
randolf-scholz Jan 2, 2025
aa3c769
replace 'Tuple' with 'tuple'
randolf-scholz Jan 2, 2025
dac5f4d
replace 'Type' with 'type'
randolf-scholz Jan 2, 2025
a9f72e1
fixed typing of transform._inv
randolf-scholz Jan 2, 2025
9e9a1bb
fixed circular import
randolf-scholz Jan 2, 2025
092be0b
added unit test
randolf-scholz Jan 2, 2025
def92d5
lintrunner fixes
randolf-scholz Jan 3, 2025
6e54911
added type hints for __init__ signatures
randolf-scholz Jan 4, 2025
d568418
replaced Any with more concrete type in Distribution.support
randolf-scholz Jan 4, 2025
6c5c878
Merge branch 'annotate_dist_properties' into distributions_type_hint_…
randolf-scholz Jan 4, 2025
d250c91
added second typevar to lazy_property
randolf-scholz Jan 4, 2025
14c0db1
Merge branch 'annotate_dist_properties' into distributions_type_hint_…
randolf-scholz Jan 4, 2025
c5a4d02
use typing.Union for 3.9 compat
randolf-scholz Jan 5, 2025
64d9269
wishart.py: use Tensor instead of torch.Tensor annotation
randolf-scholz Jan 5, 2025
b2852f6
merged
randolf-scholz Jan 5, 2025
9ca3df5
use typing.Union for 3.9 compat
randolf-scholz Jan 5, 2025
dc1010d
added missing type hints
randolf-scholz Jan 5, 2025
70b0634
added missing type hints and fixed type error
randolf-scholz Jan 5, 2025
a028526
use typing.Union for 3.9 compat
randolf-scholz Jan 5, 2025
98d12c3
wishart.py added missing type hint
randolf-scholz Jan 5, 2025
4da39a8
fixed torch.Size related failure
randolf-scholz Jan 5, 2025
e18d00a
fixed missing type hint for cache_size
randolf-scholz Jan 5, 2025
9c8c6ba
Update torch/distributions/geometric.py
randolf-scholz Jan 5, 2025
7656785
added missing type hint in CatTransform.__init__
randolf-scholz Jan 5, 2025
c9f1516
Merge branch 'main' into distributions_type_hint_init
randolf-scholz Jan 8, 2025
92f07f6
made distributiions.Independent a generic class of the base distribut…
randolf-scholz Jan 8, 2025
5dd4a6b
Merge branch 'main' into distributions_type_hint_init
randolf-scholz Jan 9, 2025
dad0e1d
removed non-typing changes
randolf-scholz Jan 9, 2025
ce07b51
removed unused type: ignore comments
randolf-scholz Jan 9, 2025
c344806
updated binomial
randolf-scholz Jan 15, 2025
a5e7396
merged main
randolf-scholz Jan 15, 2025
9caeb08
Merge branch 'main' into distributions_type_hint_init
randolf-scholz Jan 28, 2025
3924079
lintrunner fixes
randolf-scholz Jan 28, 2025
212325e
use assert logits is not None instead of elif branch
randolf-scholz Jan 28, 2025
0116bfc
removed spurious newlines
randolf-scholz Jan 28, 2025
aafae2d
Merge branch 'main' into distributions_type_hint_init
randolf-scholz Feb 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added second typevar to lazy_property
  • Loading branch information
randolf-scholz committed Jan 4, 2025
commit d250c9129002ea40bc33f45f2ff109b27c43b1b5
24 changes: 13 additions & 11 deletions torch/distributions/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
from functools import update_wrapper
from numbers import Number
from typing import Any, Callable, Generic, overload, TypeVar
from typing import Any, Callable, Generic, overload
from typing_extensions import TypeVar

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -131,34 +132,35 @@ def probs_to_logits(probs, is_binary=False):
return torch.log(ps_clamped)


T = TypeVar("T", covariant=True)
T = TypeVar("T", contravariant=True)
R = TypeVar("R", covariant=True)


class lazy_property(Generic[T]):
class lazy_property(Generic[T, R]):
r"""
Used as a decorator for lazy loading of class attributes. This uses a
non-data descriptor that calls the wrapped method to compute the property on
first call; thereafter replacing the wrapped method into an instance
attribute.
"""

def __init__(self, wrapped: Callable[..., T]) -> None:
self.wrapped: Callable[..., T] = wrapped
def __init__(self, wrapped: Callable[[T], R]) -> None:
self.wrapped: Callable[[T], R] = wrapped
update_wrapper(self, wrapped) # type:ignore[arg-type]

@overload
def __get__(
self, instance: None, obj_type: Any = None
) -> "_lazy_property_and_property[T]":
) -> "_lazy_property_and_property[T, R]":
...

@overload
def __get__(self, instance: object, obj_type: Any = None) -> T:
def __get__(self, instance: T, obj_type: Any = None) -> R:
...

def __get__(
self, instance: object, obj_type: Any = None
) -> "T | _lazy_property_and_property[T]":
self, instance: T | None, obj_type: Any = None
) -> "R | _lazy_property_and_property[T, R]":
if instance is None:
return _lazy_property_and_property(self.wrapped)
with torch.enable_grad():
Expand All @@ -167,14 +169,14 @@ def __get__(
return value


class _lazy_property_and_property(lazy_property[T], property):
class _lazy_property_and_property(lazy_property[T, R], property):
"""We want lazy properties to look like multiple things.

* property when Sphinx autodoc looks
* lazy_property when Distribution validate_args looks
"""

def __init__(self, wrapped: Callable[..., T]) -> None:
def __init__(self, wrapped: Callable[[T], R]) -> None:
property.__init__(self, wrapped)


Expand Down
Loading
0