-
Notifications
You must be signed in to change notification settings - Fork 24.3k
[typing] Add type hints to __init__
methods in torch.distributions
.
#144197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[typing] Add type hints to __init__
methods in torch.distributions
.
#144197
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144197
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit aafae2d with merge base 87a63a9 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@pytorchbot label "module: typing" |
@pytorchbot label "module: distributions" |
@pytorchbot label "release notes: python_frontend" |
@randolf-scholz but I feel like 50+% of this change could be landed separately even with mypy suppression for the whole file, say something like torch/distributions/weibull.py which will make incremental reviews faster. |
One can do that, but I think it would make most sense to first fix the |
I added annotations to |
def broadcast_all(*values): | ||
# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added. | ||
# See https://github.com/python/typing/issues/1216#issuecomment-2126153831 | ||
def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, the annotation does not cover the case of __torch_function__
.
This is somewhat outside the scope of this PR, I think possibly a Protocol
type should be added to torch.types
that covers the concrete semantics.
fd46589
to
212325e
Compare
One nice thing to also add would be overloads for pytorch/torch/distributions/bernoulli.py Line 43 in 87a63a9
However, the current API does not make use of keyword-only arguments for some reason. The following would be much nicer, imo: @overload
def __init__(self, probs: Tensor, *, validate_args: Optional[bool]=None) -> None: ...
@overload
def __init__(self, *, logits: Tensor, validate_args: Optional[bool]=None) -> None: ...
def __init__(self, probs: Optional[Tensor]=None, *, logits: Optional[Tensor]=None, validate_args: Optional[bool]=None) -> None: However, this would require some backward compatibility breaking changes, for instance the argument order in |
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
@randolf-scholz Mind doing a rebase? @malfet Think this is blocked by your review |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/14295053243 |
@pytorchbot merge -f "Rebase fails, let's see if merge will be fine" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…`. (pytorch#144197) Fixes pytorch#144196 Extends pytorch#144106 and pytorch#144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. pytorch#144197 (comment) # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ pytorch#144402 - `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. pytorch#144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ pytorch#144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ pytorch#144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`. ## Remark `TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`. ```python import torch from torch.distributions import * b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0])) b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1)) t = StickBreakingTransform() d1 = TransformedDistribution(b1, t) d2 = TransformedDistribution(b2, t) print(d1.base_dist) # Independent with 1 dimension print(d2.base_dist) # MultivariateNormal ``` One could consider changing this to `if reinterpreted_batch_ndims > 1:`. [^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <python/mypy#3186>). This results in us having to add type-ignore comments in several places [^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. Pull Request resolved: pytorch#144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
…`. (pytorch#144197) Fixes pytorch#144196 Extends pytorch#144106 and pytorch#144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. pytorch#144197 (comment) # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~` 9A82 dirichlet.py`: replaced `axis` with `dim` arguments.~~ pytorch#144402 - `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. pytorch#144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ pytorch#144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ pytorch#144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`. ## Remark `TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`. ```python import torch from torch.distributions import * b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0])) b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1)) t = StickBreakingTransform() d1 = TransformedDistribution(b1, t) d2 = TransformedDistribution(b2, t) print(d1.base_dist) # Independent with 1 dimension print(d2.base_dist) # MultivariateNormal ``` One could consider changing this to `if reinterpreted_batch_ndims > 1:`. [^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see <python/mypy#3186>). This results in us having to add type-ignore comments in several places [^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. Pull Request resolved: pytorch#144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Fixes #144196
Extends #144106 and #144110
Open Problems:
numbers.Number
is a bad idea, should consider usingfloat
,SupportsFloat
or someProcotol
. [typing] Add type hints to__init__
methods intorch.distributions
. #144197 (comment)Notes
beta.py
: needed to addtype: ignore
sincebroadcast_all
is untyped.categorical.py
: convertedelse
branches of mutually exclusive arguments toif
branch1.dirichlet.py
: replacedaxis
withdim
arguments.Dirichlet.mode
: usedim=
instead ofaxis=
#144402gemoetric.py
: convertedelse
branches of mutually exclusive arguments toif
branch1.EDIT: turns out the bug is related to typing ofindependent.py
: fixed bug inIndependent.__init__
wheretuple[int, ...]
could be passed toDistribution.__init__
instead oftorch.Size
.torch.Size
. Improve static typing fortorch.Size
#144218independent.py
: madeIndependent
a generic class of its base distribution.multivariate_normal.py
: convertedelse
branches of mutually exclusive arguments toif
branch1.relaxed_bernoulli.py
: added class-level type hint forbase_dist
.relaxed_categorical.py
: added class-level type hint forbase_dist
.ReshapeTransform: added missing argument in docstring #144401transforms.py
: Added missing argument to docstring ofReshapeTransform
Fixtransforms.py
: Fixed bug inAffineTransform.sign
(could returnTensor
instead ofint
).AffineTransform.sign
#144400transforms.py
: Addedtype: ignore
comments toAffineTransform.log_abs_det_jacobian
2; replacedtorch.abs(scale)
withscale.abs()
.transforms.py
: Addedtype: ignore
comments toAffineTransform.__eq__
2.transforms.py
: Fixed type hint onCumulativeDistributionTransform.domain
. Note that this is still an LSP violation, becauseTransform.domain
is defined asConstraint
, butDistribution.domain
is defined asOptional[Constraint]
.constraints.py
,constraints_registry.py
,kl.py
,utils.py
,exp_family.py
,__init__.py
.Remark
TransformedDistribution
:__init__
uses the checkif reinterpreted_batch_ndims > 0:
, which can lead to the creation ofIndependent
distributions with only 1 component. This results in awkward code likebase_dist.base_dist
inLogisticNormal
.One could consider changing this to
if reinterpreted_batch_ndims > 1:
.cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster
Footnotes
Otherwise, we would have to add a bunch of
type: ignore
comments to makemypy
happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. ↩ ↩2 ↩3Usage of
isinstance(value, numbers.Real)
leads to problems with static typing, as thenumbers
module is not supported bymypy
(see https://github.com/python/mypy/issues/3186). This results in us having to add type-ignore comments in several places ↩ ↩2