use assert logits is not None instead of elif branch · pytorch/pytorch@212325e · GitHub

Commit 212325e
use assert logits is not None instead of elif branch
1 parent 3924079 commit 212325e

11 files changed (+24, -13 lines)
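The change is the same across the distributions below: each constructor takes mutually exclusive Optional parameters, and mypy cannot see the one-of-two invariant that earlier validation enforces. A minimal sketch of how the asserted `else` branch narrows the Optional (the `pick_param` helper is hypothetical, not part of the commit):

from typing import Optional

from torch import Tensor


def pick_param(probs: Optional[Tensor] = None, logits: Optional[Tensor] = None) -> Tensor:
    # Hypothetical reduction of the constructors touched here: callers are
    # validated earlier so that exactly one of probs/logits is non-None.
    if probs is not None:
        return probs
    else:
        # Inside a bare `else`, mypy still types `logits` as Optional[Tensor];
        # the assert narrows it to Tensor (and documents the invariant).
        assert logits is not None  # helps mypy
        return logits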

torch/distributions/bernoulli.py

Lines changed: 2 additions & 0 deletions
@@ -56,8 +56,10 @@ def __init__(
             is_scalar = isinstance(probs, _Number)
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
+
         self._param = self.probs if probs is not None else self.logits
         if is_scalar:
             batch_shape = torch.Size()

torch/distributions/beta.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def __init__(
                 concentration1, concentration0
             )
             concentration1_concentration0 = torch.stack(
-                [concentration1, concentration0], -1  # type: ignore[list-item]
+                [concentration1, concentration0], -1
             )
         self._dirichlet = Dirichlet(
             concentration1_concentration0, validate_args=validate_args

torch/distributions/binomial.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def __init__(
             ) = broadcast_all(total_count, probs)
             self.total_count = self.total_count.type_as(self.probs)
         else:
+            assert logits is not None  # helps mypy
             (
                 self.total_count,
                 self.logits,

torch/distributions/categorical.py

Lines changed: 3 additions & 1 deletion
@@ -66,11 +66,13 @@ def __init__(
             if probs.dim() < 1:
                 raise ValueError("`probs` parameter must be at least one-dimensional.")
             self.probs = probs / probs.sum(-1, keepdim=True)
-        if logits is not None:  # Note: 'if is None' instead of 'else' to help mypy
+        else:
+            assert logits is not None  # helps mypy
             if logits.dim() < 1:
                 raise ValueError("`logits` parameter must be at least one-dimensional.")
             # Normalize
             self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
+
         self._param = self.probs if probs is not None else self.logits
         self._num_events = self._param.size()[-1]
         batch_shape = (
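Categorical (and geometric, below) previously used a second independent `if logits is not None:` purely so mypy would narrow the type. A rough before/after sketch of why the commit prefers `else` plus `assert` (function names here are illustrative, not from the source):

from typing import Optional

from torch import Tensor


def before(probs: Optional[Tensor], logits: Optional[Tensor]) -> None:
    if probs is not None:
        ...  # handle probs
    if logits is not None:  # narrows `logits`, but the branches are not exclusive
        ...  # handle logits


def after(probs: Optional[Tensor], logits: Optional[Tensor]) -> None:
    if probs is not None:
        ...  # handle probs
    else:
        assert logits is not None  # helps mypy; branches now mutually exclusive
        ...  # handle logits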

torch/distributions/continuous_bernoulli.py

Lines changed: 2 additions & 0 deletions
@@ -72,8 +72,10 @@ def __init__(
                 raise ValueError("The parameter probs has invalid values")
             self.probs = clamp_probs(self.probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
+
         self._param = self.probs if probs is not None else self.logits
         if is_scalar:
             batch_shape = torch.Size()

torch/distributions/geometric.py

Lines changed: 5 additions & 2 deletions
@@ -58,14 +58,17 @@ def __init__(
         )
         if probs is not None:
             (self.probs,) = broadcast_all(probs)
-        if logits is not None:  # Note: 'if is None' instead of 'else' to help mypy
+        else:
+            assert logits is not None  # helps mypy
             (self.logits,) = broadcast_all(logits)
+
         probs_or_logits = probs if probs is not None else logits
         if isinstance(probs_or_logits, _Number):
             batch_shape = torch.Size()
         else:
-            assert probs_or_logits is not None
+            assert probs_or_logits is not None  # helps mypy
             batch_shape = probs_or_logits.size()
+
         super().__init__(batch_shape, validate_args=validate_args)
         if self._validate_args and probs is not None:
             # Add an extra check beyond unit_interval

torch/distributions/multivariate_normal.py

Lines changed: 2 additions & 6 deletions
@@ -166,7 +166,8 @@ def __init__(
                 covariance_matrix.shape[:-2], loc.shape[:-1]
             )
             self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
-        elif precision_matrix is not None:
+        else:
+            assert precision_matrix is not None  # helps mypy
             if precision_matrix.dim() < 2:
                 raise ValueError(
                     "precision_matrix must be at least two-dimensional, "
@@ -176,11 +177,6 @@ def __init__(
                 precision_matrix.shape[:-2], loc.shape[:-1]
             )
             self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
-        else:  # Note: redundant check, only here to make linters happy
-            raise ValueError(
-                "At least one of covariance_matrix, precision_matrix or scale_tril must be specified."
-            )
-
         self.loc = loc.expand(batch_shape + (-1,))

         event_shape = self.loc.shape[-1:]
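Here the rewrite also deletes dead code: with a trailing `elif`, linters saw a fall-through path, so the old code carried an unreachable `raise` just to satisfy them. A simplified stand-in for this dispatch (hypothetical names, not the real `MultivariateNormal.__init__`):

from typing import Optional

from torch import Tensor


def expand_param(cov: Optional[Tensor] = None, prec: Optional[Tensor] = None) -> Tensor:
    # Simplified stand-in: the real constructor rejects invalid combinations
    # up front, so exactly one of cov/prec reaches the dispatch below.
    if (cov is None) == (prec is None):
        raise ValueError("Exactly one of cov or prec must be specified.")
    if cov is not None:
        return cov.expand(2, -1, -1)
    else:
        # `else` + assert is exhaustive: no unreachable `raise` is needed to
        # convince mypy or linters that every path returns a value.
        assert prec is not None  # helps mypy
        return prec.expand(2, -1, -1)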

torch/distributions/negative_binomial.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ def __init__(
             ) = broadcast_all(total_count, probs)
             self.total_count = self.total_count.type_as(self.probs)
         else:
+            assert logits is not None  # helps mypy
             (
                 self.total_count,
                 self.logits,

torch/distributions/relaxed_bernoulli.py

Lines changed: 2 additions & 0 deletions
@@ -58,8 +58,10 @@ def __init__(
             is_scalar = isinstance(probs, _Number)
             (self.probs,) = broadcast_all(probs)
         else:
+            assert logits is not None  # helps mypy
             is_scalar = isinstance(logits, _Number)
             (self.logits,) = broadcast_all(logits)
+
         self._param = self.probs if probs is not None else self.logits
         if is_scalar:
             batch_shape = torch.Size()

torch/distributions/transforms.py

Lines changed: 1 addition & 1 deletion
@@ -593,7 +593,7 @@ def with_cache(self, cache_size=1):

     @lazy_property
     def sign(self) -> int:  # type: ignore[override]
-        return self.exponent.sign()
+        return self.exponent.sign()  # type: ignore[return-value]

     def __eq__(self, other):
         if not isinstance(other, PowerTransform):
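The suppression moves because the mismatch is in the return value, not the override itself: `PowerTransform.sign` is annotated `-> int` (already carrying an override ignore against the base class), but `Tensor.sign()` returns a tensor of signs. A quick check of that behavior:

import torch

exponent = torch.tensor([-2.0, 0.0, 3.0])
print(exponent.sign())  # tensor([-1., 0., 1.]) -- a Tensor, not an int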

torch/distributions/utils.py

Lines changed: 4 additions & 2 deletions
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.overrides import is_tensor_like
-from torch.types import _Number
+from torch.types import _Number, Number


 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
@@ -23,7 +23,9 @@
 ]


-def broadcast_all(*values):
+# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added.
+# See https://github.com/python/typing/issues/1216#issuecomment-2126153831
+def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]:
     r"""
     Given a list of values (possibly containing numbers), returns a list where each
     value is broadcasted based on the following rules:
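With the new annotation, `broadcast_all` accepts a mix of tensors and Python numbers and always returns tensors. For example:

import torch
from torch.distributions.utils import broadcast_all

# Numbers are promoted to tensors and broadcast against tensor arguments.
probs, temps = broadcast_all(0.25, torch.ones(3))
print(probs)  # tensor([0.2500, 0.2500, 0.2500])
print(temps)  # tensor([1., 1., 1.])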
