Full static typing for `torch.distributions` by randolf-scholz · Pull Request #144219 · pytorch/pytorch · GitHub
Full static typing for torch.distributions #144219


Status: Open. Wants to merge 27 commits into base: main

Conversation

Contributor
@randolf-scholz randolf-scholz commented Jan 5, 2025

Fixes #144196
Extends #144197 #144106 #144110

Open Problems / LSP Violations

  • mixture_same_family.py: cdf and log_prob violate LSP (argument named x instead of value).
    • suggestion: IMO these kinds of methods should use positional-only parameters, at least in base classes.
  • exp_family.py: LSP problem with _log_normalizer (parent class requires (*natural_params: Tensor) -> Tensor, subclasses implement (a: Tensor, b: Tensor) -> Tensor).
    • suggestion: change the parent class signature to (natural_params: Tuple[Tensor, ...]) -> Tensor. While this is BC-breaking, (a) this is a private method, i.e. an implementation detail, and (b) no one other than torch seems to override it.
  • constraints.py: dependent_property: mypy does not apply the same special casing to subclasses of property as it does to property itself, hence the need for type: ignore[assignment] statements.
    • affects: relaxed_bernoulli.py, relaxed_categorical.py, logistic_normal.py, log_normal.py, kumaraswamy.py, half_cauchy.py, half_normal.py, inverse_gamma.py, gumbel.py, weibull.py.
    • suggestion: consider a construction similar to lazy_property in distributions/utils.
  • constraints.py public interface not usable as type hints.
    • A crisper design would likely have one class per constraint, instead of the current mix of classes and instances.
    • suggestion: add one class per constraint to the public interface; these can be subclasses of the existing ones.
    • As a workaround, I currently added a bunch of TypeAlias-variants, but that is likely not the best solution.
  • transforms.py: _InverseTransform.with_cache violates LSP.
    • suggestion: change with_cache to return _InverseTransform.
  • test_distributions.py: One test uses Dist.arg_constraints.get, hence assumes arg_constraints is a class-attribute, but the base class Distribution defines it as a @property.
  • test_distributions.py: One test uses Dist.support.event_dim, hence assumes support is a class-attribute, but the base class Distribution defines it as a @property.
  • test_distributions.py: Multiple tests use dist.cdf(float), but the base class annotates cdf(Tensor) -> Tensor.
    • suggestion: replace float values with tensors in test, unless floats should be officially supported. Note that floats are nonsensical for multivariate distributions, so supporting it would probably require introducing a subclass for univariate distributions.
  • test_distributions.py: Multiple tests use dist.log_prob(float), but the base class annotates log_prob(Tensor) -> Tensor.
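For the mixture_same_family.py item above, a minimal sketch of how positional-only parameters sidestep the parameter-name mismatch (the class and method names here are illustrative stand-ins, not the real torch classes):

```python
from abc import ABC, abstractmethod


class Base(ABC):
    # `value` is positional-only, so subclasses may rename it freely
    # without violating the Liskov substitution principle.
    @abstractmethod
    def log_prob(self, value: float, /) -> float: ...


class Sub(Base):
    # Renaming the parameter to `x` is fine: callers can never pass
    # it by keyword, so the override remains substitutable.
    def log_prob(self, x: float, /) -> float:
        return -x * x


print(Sub().log_prob(2.0))  # -4.0
```

Without the `/`, a caller could write `base.log_prob(value=...)`, which would break on `Sub`, and type checkers flag the rename accordingly.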

Notes

  • __init__.py: use += instead of .extend (ruff PYI056)
  • binomial.py: Allow float arguments in probs and logits (gets used in tests)
  • constraints.py: made _DependentProperty a generic class, and _DependentProperty.__call__ polymorphic.
  • constraint_registry.py: Made ConstraintRegistry.register a polymorphic method, checking that the factory is compatible with the constraint.
  • constraint_registry.py: Needed to add type: ignore comments to functions that try to register multiple different constraints at once.
    • maybe split them up?
  • dirichlet.py: @once_differentiable is untyped, requires type: ignore[misc] comment.
  • dirichlet.py: ctx: Any could be replaced with ctx: FunctionContext, however, the type lacks the saved_tensors attribute.
  • distribution.py: Distribution._get_checked_instance: accessing "__init__" on an instance is unsound, requires type: ignore comment.
  • distribution.py: Changed support from Optional[Constraint] to Constraint (consistent with the existing docstring, and several functions in tests rely on this assumption)
  • exp_family.py: small update to ExponentialFamily.entropy to fix type error.
  • independent.py: fixed type bug in Independent.support.
  • multivariate_normal.py: Added type: ignore comments to _batch_mahalanobis, caused by the torch.Size typing issue (see footnote 1).
  • relaxed_bernoulli.py: Allow float temperature argument (used in tests)
  • relaxed_categorical.py: Allow float temperature argument (used in tests)
  • transforms.py: Needed to change ComposeTransform.__init__ signature to accept Sequence[Transform] rather than just list[Transform] (covariance!)
  • transformed_distribution.py: Needed to change TransformedDistribution.__init__ signature to accept Sequence[Transform] rather than just list[Transform] (covariance!)
  • transformed_distribution.py: TransformedDistribution.support is problematic, because the parent class defines it as @property but several subclasses define it as an attribute, violating LSP.
  • von_mises.py: fixed result type being initialized as float instead of Tensor.
  • von_mises.py: @torch.jit.script_if_tracing is untyped, requires type: ignore[misc] comment.
  • von_mises.py: Allow float loc and scale (used in tests)
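The `Sequence[Transform]` changes for ComposeTransform and TransformedDistribution above come down to variance: `list` is invariant in its element type, while `Sequence` is covariant. A minimal illustration with stand-in classes (not the real torch transforms):

```python
from typing import Sequence


class Transform: ...
class AffineTransform(Transform): ...


def takes_list(ts: list[Transform]) -> int:
    return len(ts)


def takes_seq(ts: Sequence[Transform]) -> int:
    return len(ts)


affine_only: list[AffineTransform] = [AffineTransform()]

# takes_list(affine_only)  # mypy error: list is invariant in its element type
print(takes_seq(affine_only))  # 1 -- OK, Sequence is covariant
```

Accepting the read-only `Sequence` lets callers pass a `list[AffineTransform]` without a cast, which is exactly what the tests do.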

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @malfet @xuzhao9 @gramster

Footnotes

  1. torch.Size is not correctly typed, causing mypy to think Size + Size is tuple[int, ...] instead of Size, see https://github.com/pytorch/pytorch/issues/144218.

pytorch-bot bot commented Jan 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144219

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2a8e2ec with merge base aec3ef1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@randolf-scholz
Contributor Author

@pytorchbot label "module: typing"

@pytorch-bot pytorch-bot bot added the module: typing Related to mypy type annotations label Jan 5, 2025
@randolf-scholz
Contributor Author

@pytorchbot label "module: distributions"

@pytorch-bot pytorch-bot bot added the module: distributions Related to torch.distributions label Jan 5, 2025
@randolf-scholz
Contributor Author

@pytorchbot label "release notes: python_frontend"

@pytorch-bot pytorch-bot bot added the release notes: python_frontend python frontend release notes category label Jan 5, 2025
@randolf-scholz randolf-scholz marked this pull request as ready for review January 5, 2025 22:05
@cpuhrsch cpuhrsch requested a review from Skylion007 January 7, 2025 06:33
@cpuhrsch cpuhrsch added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 7, 2025
@randolf-scholz randolf-scholz (Contributor Author) left a comment

More accurate would be using some sort of polymorphic mapping like

class _KL_REGISTRY_TYPE(Protocol):
    def __iter__(self) -> Iterator[tuple[type[Distribution], type[Distribution]]]: ...
    def __getitem__(self, key: tuple[type[P], type[Q]], /) -> _KL[P, Q]: ...
    def __setitem__(self, key: tuple[type[P], type[Q]], value: _KL[P, Q], /) -> None: ...
    def __delitem__(self, key: tuple[type[Distribution], type[Distribution]], /) -> None: ...
    def clear(self) -> None: ...

but this likely overcomplicates things unnecessarily.

@randolf-scholz

This comment was marked as outdated.

@Skylion007
Collaborator

@randolf-scholz Still interested in merging this?

@Skylion007 Skylion007 requested a review from malfet May 19, 2025 17:36
@randolf-scholz
Contributor Author

@Skylion007 Yes, this was quite a bit of work, and it would be a shame if it goes to waste...

As I wrote in my last comment and in the OP, there are a few remaining open problems that mostly stem from LSP violations. It would be good to get some feedback on these.

Also, @malfet suggested in the other PR (the one adding __init__ signatures), which was also rather large, to split it up into multiple PRs to lessen the review burden. However, I am not sure how to do that reasonably. Is there a tool to automate this? Maybe git-explode?

@randolf-scholz randolf-scholz force-pushed the distributions_full_typing branch 2 times, most recently from fbd5731 to 1ffc7c1 Compare May 22, 2025 14:17
@randolf-scholz
Contributor Author

@Skylion007 I rebased onto main and squashed some of my commits

@@ -206,16 +214,19 @@ def support(self):

     def __init__(
         self,
-        fn: Optional[Callable[..., Any]] = None,
+        fn: Optional[Callable[[T], R]] = None,
Collaborator

Why not ParamSpec here? T isn't used elsewhere?

Contributor Author


ParamSpec does not make sense here, since this is a property decorator, and properties usually do not take arguments. (Note that T is actually the type the property gets bound to.)
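The point about T being the bound type can be seen in a minimal property-like descriptor (a sketch; the class name and example are invented, this is not the actual lazy_property implementation):

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")  # the owner instance the property gets bound to
R = TypeVar("R")  # the property's value type


class simple_property(Generic[T, R]):
    """Minimal property-like descriptor. T parameterizes the owner,
    not call arguments, which is why ParamSpec would not model it."""

    def __init__(self, fn: Optional[Callable[[T], R]] = None) -> None:
        self.fn = fn

    def __get__(self, obj: Optional[T], owner: Optional[type] = None):
        if obj is None:
            return self  # accessed on the class, return the descriptor
        assert self.fn is not None
        return self.fn(obj)  # accessed on an instance, compute the value


class Point:
    def __init__(self, x: int) -> None:
        self.x = x

    @simple_property
    def doubled(self) -> int:
        return 2 * self.x


print(Point(21).doubled)  # 42
```

The getter takes exactly one argument (the instance), so `Callable[[T], R]` is the natural shape; there is no variadic parameter list for a ParamSpec to capture.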

@Skylion007 Skylion007 (Collaborator) left a comment


I have some nits, but this is definitely an improvement for what was there before.

__all__ = ["register_kl", "kl_divergence"]

P = TypeVar("P", bound=Distribution)
Collaborator

We want these types to be public?

Contributor Author


The codebase seems to be inconsistent with respect to that, but more often than not uses private variables. Personally, I strongly prefer non-private, because it makes hints that show up for instance with pylance more readable. Moreover, when support for 3.11 is dropped and PEP 695 is used there is really no reason anymore to use an underscore prefix.

Contributor Author


But for the purposes of this PR, I am fine with changing it if necessary.

Collaborator


We're the ones setting a lot of the standards for type naming!

I personally like the _.

But "The names, they convey nothing!" :-D

DistributionT1 = TypeVar("DistributionT1", bound=Distribution)
DistributionT2 = TypeVar("DistributionT2", bound=Distribution)
DistributionT3 = TypeVar("DistributionT3", bound=Distribution)
DistributionT4 = TypeVar("DistributionT4", bound=Distribution)

and

DistributionBinaryFunc: Callable[[DistributionT1, DistributionT2], Tensor]


-def register_kl(type_p, type_q):
+def register_kl(type_p: type[P], type_q: type[Q]) -> _KL_Decorator:
Collaborator


Hold on, when are P and Q different from P2 and Q2? This feels like the PERFECT place for a static type check that the function you are decorating matches here.

@randolf-scholz randolf-scholz (Contributor Author) commented May 25, 2025


That would be nice, but I do not see how to do it currently; I think it requires HKTs (higher-kinded types). That is because below there are cases where the same function gets decorated multiple times, like

@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(
    p: Normal, q: Union[Beta, ContinuousBernoulli, Exponential, Gamma, Pareto, Uniform]
) -> Tensor:
    return _infinite_like(p.loc)

So, the first decorator must return Callable[[Normal, Union[Beta, ContinuousBernoulli, Exponential, Gamma, Pareto, Uniform]], Tensor], and not just Callable[[Normal, Uniform], Tensor], otherwise the next decorator will cause a type error.

What would be ideal would be something like

class _KL_Decorator[P, Q](Protocol):
    def __call__[P2 :> P, Q2 :> Q](self, arg: _KL[P2, Q2], /) -> _KL[P2, Q2]: ...

def register_kl(type_p: type[P], type_q: type[Q]) -> _KL_Decorator[P, Q]:

So that for instance the first @register_kl(Normal, Uniform) would produce a

class _KL_Decorator[Normal, Uniform](Protocol):
    def __call__[P2: Normal, Q2: Uniform](self, arg: _KL[P2, Q2], /) -> _KL[P2, Q2]: ...

Because then when this gets applied to Callable[[Normal, Uniform | Pareto |...], Tensor], it gives back Callable[[Normal, Uniform | Pareto |...], Tensor].

But this requires HKTs, which are currently not available in the python typing system.

Contributor Author


Actually, with mypy==1.15 similar issues crop up in the constraint_registry.py file; it seems one also needs to loosen the Factory type hint in a similar manner.

@@ -83,7 +80,7 @@ def expand(self, batch_shape, _instance=None):
         new._validate_args = self._validate_args
         return new

-    def _new(self, *args, **kwargs):
+    def _new(self, *args: Any, **kwargs: Any) -> Tensor:
Collaborator


You might be able to use a TypeVarTuple for *args here.

Contributor Author


Possibly, but all this function does is call TensorBase.new, which currently has an *args: Any overload. What would really be needed here is the ability to reference the signature of another function, which is not currently a feature supported by the type system.

@randolf-scholz randolf-scholz force-pushed the distributions_full_typing branch from 1b65b97 to e625cb4 Compare May 25, 2025 18:35
@randolf-scholz

This comment was marked as outdated.

+ removed '## mypy: allow-untyped-defs' comments
@randolf-scholz randolf-scholz force-pushed the distributions_full_typing branch from e625cb4 to 73f160e Compare May 26, 2025 19:19
@@ -735,3 +759,34 @@ def check(self, value):
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack

# Type aliases.
Collaborator


These are now all re-exported and demand docstrings, I think?

@Skylion007 Skylion007 requested review from rec and benjaminglass1 May 28, 2025 19:45
@rec
Collaborator
rec commented May 29, 2025

@skylion: Sigh, the delta is greater than 2k lines, and this makes the "sanity check" test fail.

2025-05-27T18:27:27.4803708Z Your PR is 3049 LOC which is more than the 2000 maximum
2025-05-27T18:27:27.4804568Z allowed within PyTorch infra. PLease make sure to split up
2025-05-27T18:27:27.4805366Z your PR into smaller pieces that can be reviewed.
2025-05-27T18:27:27.4806075Z If you think that this rule should not apply to your PR,
2025-05-27T18:27:27.4806753Z please contact @albanD or @seemethere.

It's generally pretty easy to split typing pull requests.

Starting a full review now, don't split before that! 🙂

@rec rec (Collaborator) left a comment


Whew, that's a big one!

I read every line though.

Thanks for doing this all!


def build_constraint(
    constraint_fn: Union[C, type[C]],
    args: tuple,
Collaborator


What, we can just do that as a synonym for tuple[Any, ...]? I had gotten the impression that this wouldn't work with mypy?

This is a test, is it even being type checked at all?

Contributor Author


Currently, the lintrunner is ignoring these, but I do check them locally because the runtime code helps debugging the annotations.

Contributor Author


$ mypy torch/distributions/ test/distributions/ --warn-unused-ignores 
test/distributions/test_distributions.py:3677:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor"  [arg-type]
test/distributions/test_distributions.py:3687:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor"  [arg-type]
test/distributions/test_distributions.py:3697:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor"  [arg-type]
test/distributions/test_distributions.py:3707:19: error: Argument 1 to "cdf" of "TransformedDistribution" has incompatible type "float"; expected "Tensor"  [arg-type]
test/distributions/test_distributions.py:5225:41: error: Argument 1 to "log_prob" of "Gamma" has incompatible type "int"; expected "Tensor"  [arg-type]
test/distributions/test_distributions.py:5251:40: error: Argument 1 to "log_prob" of "Gamma" has incompatible type "int"; expected "Tensor"  [arg-type]
Found 6 errors in 1 file (checked 52 source files)

"""
Create 179B s a pair of distributions `Dist` initialized to test each element of
param with each other.
"""
params1 = [torch.tensor([p] * len(p)) for p in params]
params2 = [p.transpose(0, 1) for p in params1]
return Dist(*params1), Dist(*params2)
return Dist(*params1), Dist(*params2) # type: ignore[arg-type]
Collaborator


Wait, why does this fail? It should understand that both params1 and params2 are Sequence[Tensor] and not have an issue?

Contributor Author


mypy infers D as Distribution and Distribution.__init__ only expects batch_shape, event_shape and validate_args.

What we could do is make these arguments keyword-only in Distribution.__init__, then the error goes away. Probably a good idea from a design POV, but backward incompatible!

Contributor Author


But I think maybe this is better for a follow-up PR.
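The keyword-only change discussed above could look roughly like this (a sketch with simplified shape types, not the real Distribution signature):

```python
from typing import Optional


class DistributionSketch:
    # Making the shape arguments keyword-only means mypy can no longer
    # match a subclass's positional tensor parameters against them.
    def __init__(
        self,
        *,
        batch_shape: tuple[int, ...] = (),
        event_shape: tuple[int, ...] = (),
        validate_args: Optional[bool] = None,
    ) -> None:
        self.batch_shape = batch_shape
        self.event_shape = event_shape
        self.validate_args = validate_args


d = DistributionSketch(batch_shape=(3,))
print(d.batch_shape)  # (3,)

# DistributionSketch((3,))  # TypeError: takes 1 positional argument
```

As noted, this is backward incompatible for any caller currently passing batch_shape positionally, hence the suggestion to defer it to a follow-up PR.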

@@ -1266,7 +1287,9 @@ def _check_forward_ad(self, fn):
             torch.count_nonzero(fwAD.unpack_dual(dual_out).tangent).item(), 0
         )

-    def _check_log_prob(self, dist, asset_fn):
+    def _check_log_prob(
+        self, dist: Distribution, asset_fn: Callable[Concatenate[int, ...], None]
Collaborator


so... cool... I had no idea you could do that, it's obvious only in hindsight.

Contributor Author


Defining partial function signatures can be really handy; I wish this were even better supported (for instance, when writing callback Protocols).

@@ -28,7 +30,14 @@ class Pareto(TransformedDistribution):
alpha (float or Tensor): Shape parameter of the distribution
"""

-arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
+arg_constraints: ClassVar[dict[str, Constraint]] = {
Collaborator


This could conceivably be a TypedDict?

Contributor Author


Yes, that would be a nice enhancement. Thinking about it, it should probably even be Final[ClassVar[SomeTypedDict]], but that would require running mypy with --python-version 3.13. (currently 3.11 in mypy.ini)
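The TypedDict idea could look like this (a sketch with a stand-in Constraint class; the real keys would come from each distribution's parameters):

```python
from typing import TypedDict


class Constraint:  # stand-in for torch.distributions.constraints.Constraint
    ...


class ParetoConstraints(TypedDict):
    # Per-distribution TypedDict: the checker now knows exactly
    # which keys exist and rejects typos like "aplha".
    alpha: Constraint
    scale: Constraint


arg_constraints: ParetoConstraints = {
    "alpha": Constraint(),
    "scale": Constraint(),
}

print(sorted(arg_constraints))  # ['alpha', 'scale']
```

At runtime a TypedDict is still a plain dict, so this would be a typing-only change; the cost is declaring one TypedDict per distribution.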

if self._cache_size == cache_size:
return self
if type(self).__init__ is Transform.__init__:
return type(self)(cache_size=cache_size)
raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
Collaborator


Why object over Any?

Contributor Author


Any is for gradual typing; there is really no good reason to use it here. The signature mirrors that of object.__eq__.
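The object-over-Any choice in practice: with `object`, mypy forces an isinstance narrowing before any attribute access, whereas with `Any` a typo would slip through (a sketch with an invented class, not the real Transform):

```python
class CachePolicy:
    def __init__(self, cache_size: int) -> None:
        self.cache_size = cache_size

    def __eq__(self, other: object) -> bool:
        # With `object`, mypy requires narrowing before touching
        # attributes; with `Any`, `other.cach_size` would go unflagged.
        if not isinstance(other, CachePolicy):
            return NotImplemented
        return self.cache_size == other.cache_size


print(CachePolicy(1) == CachePolicy(1))  # True
print(CachePolicy(1) == "x")  # False
```

Returning `NotImplemented` for foreign types lets Python fall back to the reflected comparison, so `CachePolicy(1) == "x"` still evaluates to False rather than raising.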

@randolf-scholz
Contributor Author

@rec I implemented most of your suggestions.

As for splitting it up, I think it would make most sense to first make a PR for constraints and constraint_registry, and then for distributions and transforms, since the annotations for the other modules depend on these.

Regarding constraints, I added all these type-aliases at the bottom. The CI complained that these are not listed in __all__, so should I just add them to __all__?

Alternatively, one could just make the classes they are pointing to public, but it's not my decision to make.

@rec
Collaborator
rec commented May 29, 2025

Well, this was extremely educational, with at least one head-slapper.

You really covered everything in your response and I also agree with your plan to split.

Regarding the type-aliases in constraints, IIRC linting requires them to be in __all__ exactly if they do not start with _, so it's up to you.

I think it's good to go!

Labels: module: distributions, module: typing, open source, release notes: python_frontend, triaged

Successfully merging this pull request may close: [typing] Add static type hints to torch.distributions.

6 participants