1
1
# mypy: allow-untyped-defs
2
2
from functools import update_wrapper
3
3
from numbers import Number
4
- from typing import Any , Dict
4
+ from typing import Any , Callable , Dict , Generic , overload , TypeVar
5
5
6
6
import torch
7
7
import torch .nn .functional as F
@@ -130,19 +130,34 @@ def probs_to_logits(probs, is_binary=False):
130
130
return torch .log (ps_clamped )
131
131
132
132
133
- class lazy_property :
133
+ T = TypeVar ("T" , covariant = True )
134
+
135
+
136
+ class lazy_property (Generic [T ]):
134
137
r"""
135
138
Used as a decorator for lazy loading of class attributes. This uses a
136
139
non-data descriptor that calls the wrapped method to compute the property on
137
140
first call; thereafter replacing the wrapped method into an instance
138
141
attribute.
139
142
"""
140
143
141
- def __init__ (self , wrapped ) :
142
- self .wrapped = wrapped
144
+ def __init__ (self , wrapped : Callable [..., T ]) -> None :
145
+ self .wrapped : Callable [..., T ] = wrapped
143
146
update_wrapper (self , wrapped ) # type:ignore[arg-type]
144
147
145
- def __get__ (self , instance , obj_type = None ):
148
+ @overload
149
+ def __get__ (
150
+ self , instance : None , obj_type : Any = None
151
+ ) -> "_lazy_property_and_property[T]" :
152
+ ...
153
+
154
+ @overload
155
+ def __get__ (self , instance : object , obj_type : Any = None ) -> T :
156
+ ...
157
+
158
+ def __get__ (
159
+ self , instance : object , obj_type : Any = None
160
+ ) -> "T | _lazy_property_and_property[T]" :
146
161
if instance is None :
147
162
return _lazy_property_and_property (self .wrapped )
148
163
with torch .enable_grad ():
@@ -151,14 +166,14 @@ def __get__(self, instance, obj_type=None):
151
166
return value
152
167
153
168
154
- class _lazy_property_and_property (lazy_property , property ):
169
+ class _lazy_property_and_property (lazy_property [ T ] , property ):
155
170
"""We want lazy properties to look like multiple things.
156
171
157
172
* property when Sphinx autodoc looks
158
173
* lazy_property when Distribution validate_args looks
159
174
"""
160
175
161
- def __init__ (self , wrapped ) :
176
+ def __init__ (self , wrapped : Callable [..., T ]) -> None :
162
177
property .__init__ (self , wrapped )
163
178
164
179
0 commit comments