8000 added type hints to lazy_property · pytorch/pytorch@ef2329f · GitHub
[go: up one dir, main page]

Skip to content

Commit ef2329f

Browse files
added type hints to lazy_property
1 parent 0431d47 commit ef2329f

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

torch/distributions/utils.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# mypy: allow-untyped-defs
22
from functools import update_wrapper
33
from numbers import Number
4-
from typing import Any, Dict
4+
from typing import Any, Callable, Dict, Generic, overload, TypeVar
55

66
import torch
77
import torch.nn.functional as F
@@ -130,19 +130,34 @@ def probs_to_logits(probs, is_binary=False):
130130
return torch.log(ps_clamped)
131131

132132

133-
class lazy_property:
133+
T = TypeVar("T", covariant=True)
134+
135+
136+
class lazy_property(Generic[T]):
134137
r"""
135138
Used as a decorator for lazy loading of class attributes. This uses a
136139
non-data descriptor that calls the wrapped method to compute the property on
137140
first call; thereafter replacing the wrapped method into an instance
138141
attribute.
139142
"""
140143

141-
def __init__(self, wrapped):
142-
self.wrapped = wrapped
144+
def __init__(self, wrapped: Callable[..., T]) -> None:
145+
self.wrapped: Callable[..., T] = wrapped
143146
update_wrapper(self, wrapped) # type:ignore[arg-type]
144147

145-
def __get__(self, instance, obj_type=None):
148+
@overload
149+
def __get__(
150+
self, instance: None, obj_type: Any = None
151+
) -> "_lazy_property_and_property[T]":
152+
...
153+
154+
@overload
155+
def __get__(self, instance: object, obj_type: Any = None) -> T:
156+
...
157+
158+
def __get__(
159+
self, instance: object, obj_type: Any = None
160+
) -> "T | _lazy_property_and_property[T]":
146161
if instance is None:
147162
return _lazy_property_and_property(self.wrapped)
148163
with torch.enable_grad():
@@ -151,14 +166,14 @@ def __get__(self, instance, obj_type=None):
151166
return value
152167

153168

154-
class _lazy_property_and_property(lazy_property, property):
169+
class _lazy_property_and_property(lazy_property[T], property):
155170
"""We want lazy properties to look like multiple things.
156171
157172
* property when Sphinx autodoc looks
158173
* lazy_property when Distribution validate_args looks
159174
"""
160175

161-
def __init__(self, wrapped):
176+
def __init__(self, wrapped: Callable[..., T]) -> None:
162177
property.__init__(self, wrapped)
163178

164179

0 commit comments

Comments
 (0)
0