Type hints for distributions/utils by randolf-scholz · Pull Request #154712 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Type hints for distributions/utils #154712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions torch/distributions/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# mypy: allow-untyped-defs
from collections.abc import Sequence
from functools import update_wrapper
from typing import Any, Callable, Generic, overload, Union
from typing_extensions import TypeVar
from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import SymInt, Tensor
from torch.overrides import is_tensor_like
from torch.types import _Number, Number
from torch.types import _dtype, _Number, Device, Number


euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant

__all__ = [
"broadcast_all",
Expand Down Expand Up @@ -59,7 +58,11 @@ def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
return torch.broadcast_tensors(*values)


def _standard_normal(shape, dtype, device):
def _standard_normal(
shape: Sequence[Union[int, SymInt]],
dtype: Optional[_dtype],
device: Optional[Device],
) -> Tensor:
if torch._C._get_tracing_state():
# [JIT WORKAROUND] lack of support for .normal_()
return torch.normal(
Expand All @@ -69,7 +72,7 @@ def _standard_normal(shape, dtype, device):
return torch.empty(shape, dtype=dtype, device=device).normal_()


def _sum_rightmost(value, dim):
def _sum_rightmost(value: Tensor, dim: int) -> Tensor:
r"""
Sum out ``dim`` many rightmost dimensions of a given tensor.

Expand All @@ -83,7 +86, 97C7 7 @@ def _sum_rightmost(value, dim):
return value.reshape(required_shape).sum(-1)


def logits_to_probs(logits, is_binary=False):
def logits_to_probs(logits: Tensor, is_binary: bool = False) -> Tensor:
r"""
Converts a tensor of logits into probabilities. Note that for the
binary case, each value denotes log odds, whereas for the
Expand All @@ -95,7 +98,7 @@ def logits_to_probs(logits, is_binary=False):
return F.softmax(logits, dim=-1)


def clamp_probs(probs):
def clamp_probs(probs: Tensor) -> Tensor:
"""Clamps the probabilities to be in the open interval `(0, 1)`.

The probabilities would be clamped between `eps` and `1 - eps`,
Expand All @@ -121,7 +124,7 @@ def clamp_probs(probs):
return probs.clamp(min=eps, max=1 - eps)


def probs_to_logits(probs, is_binary=False):
def probs_to_logits(probs: Tensor, is_binary: bool = False) -> Tensor:
r"""
Converts a tensor of probabilities into logits. For the binary case,
this denotes the probability of occurrence of the event indexed by `1`.
Expand Down
Loading
0