Full static typing for torch.distributions
#144219
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144219
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2a8e2ec with merge base aec3ef1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "module: typing" |
@pytorchbot label "module: distributions" |
@pytorchbot label "release notes: python_frontend" |
More accurate would be using some sort of polymorphic mapping like

    class _KL_REGISTRY_TYPE(Protocol):
        def __iter__(self) -> Iterator[tuple[type[Distribution], type[Distribution]]]: ...
        def __getitem__(self, key: tuple[type[P], type[Q]], /) -> _KL[P, Q]: ...
        def __setitem__(self, key: tuple[type[P], type[Q]], value: _KL[P, Q], /) -> None: ...
        def __delitem__(self, key: tuple[type[Distribution], type[Distribution]], /) -> None: ...
        def clear(self) -> None: ...

but likely this overcomplicates things unnecessarily.
@randolf-scholz Still interested in merging this?
@Skylion007 Yes, this was quite a bit of work, and it would be a shame if it went to waste... As I wrote in my last comment and in the OP, there are a few remaining open problems that mostly stem from LSP violations. It would be good to get some feedback on these. Also, @malfet suggested in the other PR (the one adding
Force-pushed from fbd5731 to 1ffc7c1
@Skylion007 I rebased onto main and squashed some of my commits
@@ -206,16 +214,19 @@ def support(self):

     def __init__(
         self,
-        fn: Optional[Callable[..., Any]] = None,
+        fn: Optional[Callable[[T], R]] = None,
Why not ParamSpec here? T isn't used elsewhere?
`ParamSpec` does not make sense here, since this is a property decorator, and properties are usually not supposed to take arguments. (Note that `T` is actually the type the property gets bound to.)
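A minimal self-contained sketch of that pattern (`simple_property`, `Box`, `T`, and `R` are illustrative here, not the actual torch code): the wrapped callable only ever receives the instance it is bound to, which is why `Callable[[T], R]` fits where a `ParamSpec` would not.

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")  # type of the instance the property is bound to
R = TypeVar("R")  # return type of the getter


class simple_property(Generic[T, R]):
    """Bare-bones property-like descriptor: the wrapped fn takes only `self`."""

    def __init__(self, fn: Optional[Callable[[T], R]] = None) -> None:
        self.fn = fn

    def __get__(self, instance: T, owner: Optional[type] = None) -> R:
        assert self.fn is not None
        return self.fn(instance)


class Box:
    def __init__(self, v: int) -> None:
        self.v = v

    @simple_property
    def doubled(self) -> int:
        return 2 * self.v


print(Box(3).doubled)  # 6 -- T is bound to Box, R to int
```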
I have some nits, but this is definitely an improvement over what was there before.
__all__ = ["register_kl", "kl_divergence"] | ||
|
||
P = TypeVar("P", bound=Distribution) |
We want these types to be public?
The codebase seems to be inconsistent with respect to that, but more often than not uses private variables. Personally, I strongly prefer non-private, because it makes hints that show up for instance with pylance more readable. Moreover, when support for 3.11 is dropped and PEP 695 is used, there is really no reason anymore to use an underscore prefix.
But for the purposes of this PR, I am fine with changing it if necessary.
We're the ones setting a lot of the standards for type naming! I personally like the `_`. But "The names, they convey nothing!" :-D

    DistributionT1 = TypeVar("DistributionT1", bound=Distribution)
    DistributionT2 = TypeVar("DistributionT2", bound=Distribution)
    DistributionT3 = TypeVar("DistributionT3", bound=Distribution)
    DistributionT4 = TypeVar("DistributionT4", bound=Distribution)

and

    DistributionBinaryFunc: Callable[[DistributionT1, DistributionT2], Tensor]
-def register_kl(type_p, type_q):
+def register_kl(type_p: type[P], type_q: type[Q]) -> _KL_Decorator:
Hold on, when are P and Q different from P2 and Q2? This feels like the PERFECT place for a static type check that the function you are decorating matches here?
That would be nice, but I do not see how to do it currently; I think it requires HKTs. That's because below you have some cases where the same function gets decorated multiple times, like
@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(
    p: Normal, q: Union[Beta, ContinuousBernoulli, Exponential, Gamma, Pareto, Uniform]
) -> Tensor:
    return _infinite_like(p.loc)
So, the first decorator must return `Callable[[Normal, Union[Beta, ContinuousBernoulli, Exponential, Gamma, Pareto, Uniform]], Tensor]`, and not just `Callable[[Normal, Uniform], Tensor]`, otherwise the next decorator will cause a type error.
What would be ideal would be something like
class _KL_Decorator[P, Q](Protocol):
    def __call__[P2 :> P, Q2 :> Q](self, arg: _KL[P2, Q2], /) -> _KL[P2, Q2]: ...

def register_kl(type_p: type[P], type_q: type[Q]) -> _KL_Decorator[P, Q]:
So that for instance the first `@register_kl(Normal, Uniform)` would produce a

    class _KL_Decorator[Normal, Uniform](Protocol):
        def __call__[P2: Normal, Q2: Uniform](self, arg: _KL[P2, Q2], /) -> _KL[P2, Q2]: ...

Because then when this gets applied to `Callable[[Normal, Uniform | Pareto | ...], Tensor]`, it gives back `Callable[[Normal, Uniform | Pareto | ...], Tensor]`. But this requires HKTs, which are currently not available in the Python typing system.
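For contrast, a hedged sketch of what is expressible today without HKTs (one possible shape only, not necessarily what this PR implements): making `__call__` itself generic means each `@register_kl(...)` application simply hands back the decorated callable with its signature unchanged, which is what keeps the stacked decorators above happy.

```python
from typing import Callable, Protocol, TypeVar

from torch import Tensor
from torch.distributions import Distribution

P = TypeVar("P", bound=Distribution)
Q = TypeVar("Q", bound=Distribution)

# assumed alias for a registered KL implementation
_KL = Callable[[P, Q], Tensor]


class _KL_Decorator(Protocol):
    # __call__ is generic in P/Q, so the decorator returns whatever callable it
    # receives, signature intact; the (type_p, type_q) key is only checked at
    # runtime, not statically.
    def __call__(self, fn: _KL[P, Q], /) -> _KL[P, Q]: ...
```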
Actually, with `mypy==1.15` similar issues crop up in the `constraint_registry.py` file; it seems one also needs to loosen the `Factory` type hint in a similar manner.
@@ -83,7 +80,7 @@ def expand(self, batch_shape, _instance=None):
         new._validate_args = self._validate_args
         return new

-    def _new(self, *args, **kwargs):
+    def _new(self, *args: Any, **kwargs: Any) -> Tensor:
You might be able to do a TypeVarTuple for `*args` here.
Possibly, but all this function does is call `TensorBase.new`, which currently has an `*args: Any` overload. Really what would be needed here is the ability to reference the signature of another function, which is currently not a feature supported by the type system.
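For reference, the `TypeVarTuple` version would look roughly like this sketch (`new_like` is a made-up stand-in, Python 3.11+ only), and it buys little for exactly the reason above:

```python
from typing import Any, TypeVarTuple, Unpack

import torch
from torch import Tensor

Ts = TypeVarTuple("Ts")


def new_like(*args: Unpack[Ts], **kwargs: Any) -> Tensor:
    # Ts captures the exact positional argument types at the call site, but that
    # precision is lost again here because Tensor.new is itself typed *args: Any.
    return torch.empty(0).new(*args, **kwargs)
```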
Force-pushed from 1b65b97 to e625cb4
+ removed '## mypy: allow-untyped-defs' comments
+ removed unused 'type: ignore' comments
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Force-pushed from e625cb4 to 73f160e
@@ -735,3 +759,34 @@ def check(self, value):
 positive_definite = _PositiveDefinite()
 cat = _Cat
 stack = _Stack
+
+# Type aliases.
These are now all re-exported and demand docstrings, I think?
@skylion: Sigh, the delta is greater than 2k lines, and this makes the "sanity check" test fail.
It's generally pretty easy to split typing pull requests. Starting a full review now, don't split before that! 🙂
Whew, that's a big one!
I read every line though.
Thanks for doing this all!
def build_constraint(
    constraint_fn: Union[C, type[C]],
    args: tuple,
What, we can just do that as a synonym for `tuple[Any, ...]`? I had gotten the impression that this wouldn't work with mypy? This is a test, is it even being type checked at all?
Currently, the `lintrunner` is ignoring these, but I do check them locally, because the runtime code helps with debugging the annotations.
$ mypy torch/distributions/ test/distributions/ --warn-unused-ignores
test/distributions/test_distributions.py:3677:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:3687:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:3697:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:3707:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:5225:41: error: Argument 1 to "log_prob" of "Gamma" has incompatible type "int"; expected "Tensor" [arg-type]
test/distributions/test_distributions.py:5251:40: error: Argument 1 to "log_prob" of "Gamma" has incompatible type "int"; expected "Tensor" [arg-type]
Found 6 errors in 1 file (checked 52 source files)
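To the earlier question about the bare `tuple` annotation: mypy does treat it as `tuple[Any, ...]`, so heterogeneous argument tuples pass. A minimal sketch (the `arity` helper is illustrative, not part of the PR):

```python
def arity(args: tuple) -> int:
    # bare `tuple` is equivalent to `tuple[Any, ...]` for mypy
    return len(args)


print(arity((1, "two", 3.0)))  # 3, and mypy accepts the heterogeneous tuple
```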
""" | ||
Create 179B s a pair of distributions `Dist` initialized to test each element of | ||
param with each other. | ||
""" | ||
params1 = [torch.tensor([p] * len(p)) for p in params] | ||
params2 = [p.transpose(0, 1) for p in params1] | ||
return Dist(*params1), Dist(*params2) | ||
return Dist(*params1), Dist(*params2) # type: ignore[arg-type] |
Wait, why does this fail? It should understand that both `params1` and `params2` are `Sequence[Tensor]` and not have an issue?
`mypy` infers `D` as `Distribution`, and `Distribution.__init__` only expects `batch_shape`, `event_shape` and `validate_args`.

What we could do is make these arguments keyword-only in `Distribution.__init__`, then the error goes away. Probably a good idea from a design POV, but backward incompatible!
But I think maybe this is better for a follow-up PR.
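For concreteness, the keyword-only variant being discussed would look roughly like this (a sketch only, not part of this PR, and `DistributionSketch` is a stand-in name):

```python
from typing import Optional

import torch


class DistributionSketch:
    def __init__(
        self,
        *,  # everything below becomes keyword-only
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ) -> None:
        self.batch_shape = batch_shape
        self.event_shape = event_shape
        self._validate_args = validate_args
```

With that signature, positional `Dist(*params)` calls could no longer be matched against `batch_shape`, which appears to be what trips mypy in the test helper above.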
@@ -1266,7 +1287,9 @@ def _check_forward_ad(self, fn):
             torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
         )

-    def _check_log_prob(self, dist, asset_fn):
+    def _check_log_prob(
+        self, dist: Distribution, asset_fn: Callable[Concatenate[int, ...], None]
so... cool... I had no idea you could do that, it's obvious only in hindsight.
Defining partial function signatures can be really handy; I wish this were even better supported (for instance when writing callback Protocols).
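A small self-contained illustration of the partial-signature pattern used above (`check_each` and `example_assert` are made-up names): only the first positional parameter is pinned to `int`, and the rest of the signature stays gradually typed.

```python
from __future__ import annotations  # keep the annotation unevaluated at runtime

from typing import Callable, Concatenate  # Concatenate: Python 3.10+


def check_each(assert_fn: Callable[Concatenate[int, ...], None]) -> None:
    for i in range(3):
        assert_fn(i)  # only the leading int is statically checked


def example_assert(i: int, msg: str = "ok") -> None:
    assert i >= 0, msg


check_each(example_assert)
```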
__all__ = ["register_kl", "kl_divergence"] | ||
|
||
P = TypeVar("P", bound=Distribution) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're the ones setting a lot of the standards for type naming!
I personally like the _
.
But "The names, they convey nothing!" :-D
DistributionT1 = TypeVar("DistributionT1", bound=Distribution)
DistributionT2 = TypeVar("DistributionT2", bound=Distribution)
DistributionT3 = TypeVar("DistributionT3", bound=Distribution)
DistributionT4= TypeVar("DistributionT3", bound=Distribution)
and
DistributionBinaryFunc: Callable[[DistributionT1, DistributionT2], Tensor]
@@ -28,7 +30,14 @@ class Pareto(TransformedDistribution):
         alpha (float or Tensor): Shape parameter of the distribution
     """

-    arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
+    arg_constraints: ClassVar[dict[str, Constraint]] = {
This could conceivably be a `TypedDict`?
Yes, that would be a nice enhancement. Thinking about it, it should probably even be `Final[ClassVar[SomeTypedDict]]`, but that would require running `mypy` with `--python-version 3.13` (currently 3.11 in `mypy.ini`).
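A hedged sketch of what that could look like (not what this PR does; `ParetoConstraints` and `ParetoLike` are illustrative names):

```python
from typing import ClassVar, TypedDict

from torch.distributions import constraints
from torch.distributions.constraints import Constraint


class ParetoConstraints(TypedDict):
    # pins down exactly which keys arg_constraints may contain
    alpha: Constraint
    scale: Constraint


class ParetoLike:
    arg_constraints: ClassVar[ParetoConstraints] = {
        "alpha": constraints.positive,
        "scale": constraints.positive,
    }
```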
         if self._cache_size == cache_size:
             return self
         if type(self).__init__ is Transform.__init__:
             return type(self)(cache_size=cache_size)
         raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
Why `object` over `Any`?
`Any` is for gradual types, really no good reason to use it here. The signature mirrors that of `object.__eq__`.
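A tiny illustration of the difference (generic `Point`, not torch code): with `other: object`, the type checker requires the isinstance narrowing below before touching attributes, whereas `Any` would silently skip it.

```python
class Point:
    def __init__(self, x: float) -> None:
        self.x = x

    def __eq__(self, other: object) -> bool:
        # `object` forces an isinstance check before attribute access;
        # with `other: Any`, `other.x` would typecheck even for a str.
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x
```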
@rec I implemented most of your suggestions. So, splitting it up I think it would make most sense to first make a PR for Regarding Alternatively, one could just make the classes they are pointing to public, but it's not my decision to make.
Well, this was extremely educational, with at least one head-slapper. You really covered everything in your response and I also agree with your plan to split. Regarding the type-aliases in I think it's good to go!
Fixes #144196 Part of #144219 Pull Request resolved: #154712 Approved by: https://github.com/Skylion007
Fixes #144196
Extends #144197 #144106 #144110
Open Problems / LSP violations

- `mixture_same_family.py`: `cdf` and `log_prob` violate LSP (argument named `x` instead of `value`); a minimal sketch follows after this list.
- `exp_family.py`: LSP problem with `_log_normalizer` (parent class requires `(*natural_params: Tensor) -> Tensor`, subclasses implement `(a: Tensor, b: Tensor) -> Tensor`).
  - A possible fix is `(natural_params: Tuple[Tensor, ...]) -> Tensor`. While this is BC breaking, (a) this is a private method, i.e. an implementation detail, and (b) no one other than torch seems to overwrite it.
- `constraints.py`: `dependent_property`: mypy does not apply the same special casing to subclasses of `property` as it does to `property` itself, hence the need for `type: ignore[assignment]` statements.
  - `relaxed_bernoulli.py`, `relaxed_categorical.py`, `logistic_normal.py`, `log_normal.py`, `kumaraswamy.py`, `half_cauchy.py`, `half_normal.py`, `inverse_gamma.py`, `gumbel.py`, `weibull.py`.
  - `lazy_property` in `distributions/utils`.
- `constraints.py`: public interface not usable as type hints.
  - `TypeAlias`-variants would be one option, but that is likely not the best solution.
- `transforms.py`: `_InverseTransform.with_cache` violates LSP.
  - Would need `with_cache` to return `_InverseTransform`.
- `test_distributions.py`: One test uses `Dist.arg_constraints.get`, hence assumes `arg_constraints` is a class attribute, but the base class `Distribution` defines it as a `@property`.
- `test_distributions.py`: One test uses `Dist.support.event_dim`, hence assumes `support` is a class attribute, but the base class `Distribution` defines it as a `@property`.
- `test_distributions.py`: Multiple tests use `dist.cdf(float)`, but the base class annotates `cdf(Tensor) -> Tensor`.
- `test_distributions.py`: Multiple tests use `dist.log_prob(float)`, but the base class annotates `log_prob(Tensor) -> Tensor`.
Notes

- `__init__.py`: use `+=` instead of `extend` (ruff PYI056).
- `binomial.py`: Allow `float` arguments in `probs` and `logits` (gets used in tests).
- `constraints.py`: made `_DependentProperty` a generic class, and `_DependentProperty.__call__` polymorphic.
- `constraint_registry.py`: Made `ConstraintRegistry.register` a polymorphic method, checking that the factory is compatible with the constraint.
- `constraint_registry.py`: Needed to add `type: ignore` comments to functions that try to register multiple different constraints at once.
- `dirichlet.py`: `@once_differentiable` is untyped, requires a `type: ignore[misc]` comment.
- `dirichlet.py`: `ctx: Any` could be replaced with `ctx: FunctionContext`; however, the type lacks the `saved_tensors` attribute.
- `distribution.py`: `Distribution._get_checked_instance`: Accessing `"__init__"` on an instance is unsound, requires a `type: ignore` comment.
- `distribution.py`: Changed `support` from `Optional[Constraint]` to `Constraint` (consistent with the existing docstring, and several functions in tests rely on this assumption).
- `exp_family.py`: small update to `ExponentialFamily.entropy` to fix a type error.
- `independent.py`: fixed a type bug in `Independent.support`.
- `multivariate_normal.py`: Added `type: ignore` comments to `_batch_mahalanobis` (caused by footnote 1).
- `relaxed_bernoulli.py`: Allow a float temperature argument (used in tests).
- `relaxed_categorical.py`: Allow a float temperature argument (used in tests).
- `transforms.py`: Needed to change the `ComposeTransform.__init__` signature to accept `Sequence[Transform]` rather than just `list[Transform]` (covariance! see the sketch after this list).
- `transformed_distribution.py`: Needed to change the `TransformedDistribution.__init__` signature to accept `Sequence[Transform]` rather than just `list[Transform]` (covariance!).
- `transformed_distribution.py`: `TransformedDistribution.support` is problematic, because the parent class defines it as a `@property` but several subclasses define it as an attribute, violating LSP.
- `von_mises.py`: fixed the `result` type being initialized as `float` instead of `Tensor`.
- `von_mises.py`: `@torch.jit.script_if_tracing` is untyped, requires a `type: ignore[misc]` comment.
- `von_mises.py`: Allow float `loc` and `scale` (used in tests).

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster
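As referenced in the `transforms.py` note above, a short sketch of the covariance point (the helper functions are illustrative, not from the PR):

```python
from typing import Sequence

from torch.distributions.transforms import AffineTransform, Transform


def takes_list(parts: list[Transform]) -> None: ...
def takes_seq(parts: Sequence[Transform]) -> None: ...


parts: list[AffineTransform] = [AffineTransform(0.0, 1.0)]

takes_seq(parts)     # OK: Sequence is covariant, list[AffineTransform] is accepted
# takes_list(parts)  # mypy error: list is invariant, list[AffineTransform] != list[Transform]
```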
Footnotes

1. `torch.Size` is not correctly typed, causing `mypy` to think `Size + Size` is `tuple[int, ...]` instead of `Size`; see https://github.com/pytorch/pytorch/issues/144218.