Add optional generator to distribution sampler/rsample methods. by vladoovtcharov · Pull Request #146333 · pytorch/pytorch

Open · wants to merge 1 commit into base: main
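The diff below threads an optional `torch.Generator` through the `sample`/`rsample` methods of `torch.distributions`. A minimal sketch of the intended usage once the patch is applied (the distribution, seed, and shapes here are illustrative, not taken from the PR):

```python
import torch
from torch.distributions import Bernoulli

# A dedicated generator makes draws reproducible without touching
# torch's global RNG state.
g = torch.Generator().manual_seed(0)
dist = Bernoulli(probs=torch.tensor([0.3, 0.7]))

a = dist.sample(torch.Size([4]), generator=g)
g.manual_seed(0)  # rewind the private stream
b = dist.sample(torch.Size([4]), generator=g)
assert torch.equal(a, b)  # identical draws; the global RNG is untouched
```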
9 changes: 5 additions & 4 deletions torch/distributions/bernoulli.py
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import nan, Tensor
+from torch import nan, Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import (
@@ -12,7 +14,6 @@
 from torch.nn.functional import binary_cross_entropy_with_logits
 from torch.types import _Number
 
-
 __all__ = ["Bernoulli"]
 
 
@@ -100,10 +101,10 @@ def probs(self) -> Tensor:
     def param_shape(self) -> torch.Size:
         return self._param.size()
 
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape=torch.Size(), generator: Optional[Generator] = None):
         shape = self._extended_shape(sample_shape)
         with torch.no_grad():
-            return torch.bernoulli(self.probs.expand(shape))
+            return torch.bernoulli(self.probs.expand(shape), generator=generator)
 
     def log_prob(self, value):
         if self._validate_args:
9 changes: 5 additions & 4 deletions torch/distributions/beta.py
@@ -1,13 +1,14 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.dirichlet import Dirichlet
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import broadcast_all
 from torch.types import _Number, _size
 
-
 __all__ = ["Beta"]
 
 
@@ -73,8 +74,8 @@ def variance(self) -> Tensor:
         total = self.concentration1 + self.concentration0
         return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))
 
-    def rsample(self, sample_shape: _size = ()) -> Tensor:
-        return self._dirichlet.rsample(sample_shape).select(-1, 0)
+    def rsample(self, sample_shape: _size = (), generator: Optional[Generator] = None) -> Tensor:
+        return self._dirichlet.rsample(sample_shape, generator).select(-1, 0)
 
     def log_prob(self, value):
         if self._validate_args:
8 changes: 5 additions & 3 deletions torch/distributions/binomial.py
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import (
@@ -116,11 +118,11 @@ def probs(self) -> Tensor:
     def param_shape(self) -> torch.Size:
         return self._param.size()
 
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape=torch.Size(), generator: Optional[Generator] = None):
         shape = self._extended_shape(sample_shape)
         with torch.no_grad():
             return torch.binomial(
-                self.total_count.expand(shape), self.probs.expand(shape)
+                self.total_count.expand(shape), self.probs.expand(shape), generator
             )
 
     def log_prob(self, value):
10 changes: 6 additions & 4 deletions torch/distributions/categorical.py
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import nan, Tensor
+from torch import nan, Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits
 
-
 __all__ = ["Categorical"]
 
 
@@ -127,11 +128,12 @@ def variance(self) -> Tensor:
             device=self.probs.device,
         )
 
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape=torch.Size(), generator: Optional[Generator] = None):
         if not isinstance(sample_shape, torch.Size):
             sample_shape = torch.Size(sample_shape)
         probs_2d = self.probs.reshape(-1, self._num_events)
-        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
+        samples_2d = torch.multinomial(
+            probs_2d, sample_shape.numel(), True, generator=generator).T
         return samples_2d.reshape(self._extended_shape(sample_shape))
 
     def log_prob(self, value):
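`Categorical.sample` flattens the batched probabilities to 2-D for `torch.multinomial` (whose `generator` argument is keyword-only), then transposes and reshapes back. A standalone shape walk-through of that logic (tensor sizes are made up for illustration):

```python
import torch

g = torch.Generator().manual_seed(1)
probs = torch.softmax(torch.randn(2, 3, 5), dim=-1)  # batch (2, 3), 5 events
sample_shape = torch.Size([7])

probs_2d = probs.reshape(-1, 5)                      # (6, 5)
samples_2d = torch.multinomial(
    probs_2d, sample_shape.numel(), True, generator=g
).T                                                  # (7, 6)
samples = samples_2d.reshape(sample_shape + probs.shape[:-1])
print(samples.shape)  # torch.Size([7, 2, 3])
```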
8 changes: 4 additions & 4 deletions torch/distributions/cauchy.py
@@ -1,14 +1,14 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional
 
 import torch
-from torch import inf, nan, Tensor
+from torch import inf, nan, Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
 from torch.types import _Number, _size
 
-
 __all__ = ["Cauchy"]
 
 
@@ -66,9 +66,9 @@ def variance(self) -> Tensor:
             self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
         )
 
-    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
+    def rsample(self, sample_shape: _size = torch.Size(), generator: Optional[Generator] = None) -> Tensor:
         shape = self._extended_shape(sample_shape)
-        eps = self.loc.new(shape).cauchy_()
+        eps = self.loc.new(shape).cauchy_(generator=generator)
         return self.loc + eps * self.scale
 
     def log_prob(self, value):
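Note: the in-place RNG methods used here and below (`Tensor.cauchy_`, `exponential_`, `uniform_`) all accept a keyword-only `generator`, which is what lets the loc-scale trick draw from an explicit stream. A standalone sketch of the same trick (values are illustrative):

```python
import torch

# Standard Cauchy noise from a private stream, then shift and scale,
# mirroring Cauchy.rsample above.
g = torch.Generator().manual_seed(3)
loc, scale = torch.tensor(1.0), torch.tensor(2.0)
eps = torch.empty(4).cauchy_(generator=g)
print(loc + eps * scale)
```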
9 changes: 5 additions & 4 deletions torch/distributions/continuous_bernoulli.py
@@ -1,8 +1,9 @@
 # mypy: allow-untyped-defs
 import math
+from typing import Optional
 
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import (
@@ -15,7 +16,6 @@
 from torch.nn.functional import binary_cross_entropy_with_logits
 from torch.types import _Number, _size
 
-
 __all__ = ["ContinuousBernoulli"]
 
 
@@ -168,9 +168,10 @@ def sample(self, sample_shape=torch.Size()):
         with torch.no_grad():
             return self.icdf(u)
 
-    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
+    def rsample(self, sample_shape: _size = torch.Size(), generator: Optional[Generator] = None) -> Tensor:
         shape = self._extended_shape(sample_shape)
-        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+        u = torch.rand(
+            shape, dtype=self.probs.dtype, device=self.probs.device, generator=generator)
         return self.icdf(u)
 
     def log_prob(self, value):
13 changes: 7 additions & 6 deletions torch/distributions/dirichlet.py
@@ -1,13 +1,14 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.types import _size
 
-
 __all__ = ["Dirichlet"]
 
 
@@ -20,8 +21,8 @@ def _Dirichlet_backward(x, concentration, grad_output):
 
 class _Dirichlet(Function):
     @staticmethod
-    def forward(ctx, concentration):
-        x = torch._sample_dirichlet(concentration)
+    def forward(ctx, concentration, generator: Optional[Generator] = None):
+        x = torch._sample_dirichlet(concentration, generator)
         ctx.save_for_backward(x, concentration)
         return x
 
@@ -72,10 +73,10 @@ def expand(self, batch_shape, _instance=None):
         new._validate_args = self._validate_args
         return new
 
-    def rsample(self, sample_shape: _size = ()) -> Tensor:
+    def rsample(self, sample_shape: _size = (), generator: Optional[Generator] = None) -> Tensor:
         shape = self._extended_shape(sample_shape)
         concentration = self.concentration.expand(shape)
-        return _Dirichlet.apply(concentration)
+        return _Dirichlet.apply(concentration, generator)
 
     def log_prob(self, value):
         if self._validate_args:
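A subtlety in the `_Dirichlet` change: the generator becomes a second input of the `autograd.Function`, so `backward` (not shown in this diff) has to return one gradient entry per `forward` input, with `None` for the non-differentiable generator. A simplified sketch of that pattern, assuming the gradient formula from the existing `_Dirichlet_backward` helper:

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class _DirichletSketch(Function):  # hypothetical name, for illustration only
    @staticmethod
    def forward(ctx, concentration, generator=None):
        x = torch._sample_dirichlet(concentration, generator)
        ctx.save_for_backward(x, concentration)
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, concentration = ctx.saved_tensors
        total = concentration.sum(-1, True).expand_as(concentration)
        grad = torch._dirichlet_grad(x, concentration, total)
        # One entry per forward() input: a Tensor for `concentration`,
        # None for `generator`.
        return grad * (grad_output - (grad_output * x).sum(-1, True)), None
```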
8 changes: 4 additions & 4 deletions torch/distributions/distribution.py
@@ -4,7 +4,7 @@
 from typing_extensions import deprecated
 
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.utils import lazy_property
 from torch.types import _size
@@ -159,15 +159,15 @@ def stddev(self) -> Tensor:
         """
         return self.variance.sqrt()
 
-    def sample(self, sample_shape: _size = torch.Size()) -> Tensor:
+    def sample(self, sample_shape: _size = torch.Size(), generator: Optional[Generator] = None) -> Tensor:
         """
         Generates a sample_shape shaped sample or sample_shape shaped batch of
         samples if the distribution parameters are batched.
         """
         with torch.no_grad():
-            return self.rsample(sample_shape)
+            return self.rsample(sample_shape, generator)
 
-    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
+    def rsample(self, sample_shape: _size = torch.Size(), generator: Optional[Generator] = None) -> Tensor:
         """
         Generates a sample_shape shaped reparameterized sample or sample_shape
         shaped batch of reparameterized samples if the distribution parameters
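Since the base `Distribution.sample` wraps `rsample` in `torch.no_grad()` and forwards the generator, a reparameterized subclass only has to accept the argument in `rsample` to support it in both entry points. A hypothetical subclass (not part of this PR) illustrating the updated contract:

```python
from typing import Optional

import torch
from torch import Generator, Tensor
from torch.distributions import Distribution, constraints


class Logistic(Distribution):  # hypothetical example distribution
    arg_constraints = {}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc: Tensor, scale: Tensor) -> None:
        self.loc, self.scale = loc, scale
        super().__init__(batch_shape=loc.shape, validate_args=False)

    def rsample(self, sample_shape=torch.Size(), generator: Optional[Generator] = None) -> Tensor:
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device, generator=generator)
        return self.loc + self.scale * (u.log() - (-u).log1p())  # logit(u)


g = torch.Generator().manual_seed(7)
d = Logistic(torch.zeros(3), torch.ones(3))
x = d.sample(generator=g)  # no_grad draw through the forwarded generator
```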
9 changes: 5 additions & 4 deletions torch/distributions/exponential.py
@@ -1,12 +1,13 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import broadcast_all
 from torch.types import _Number, _size
 
-
 __all__ = ["Exponential"]
 
 
@@ -58,9 +59,9 @@ def expand(self, batch_shape, _instance=None):
         new._validate_args = self._validate_args
         return new
 
-    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
+    def rsample(self, sample_shape: _size = torch.Size(), generator: Optional[Generator] = None) -> Tensor:
         shape = self._extended_shape(sample_shape)
-        return self.rate.new(shape).exponential_() / self.rate
+        return self.rate.new(shape).exponential_(generator=generator) / self.rate
 
     def log_prob(self, value):
         if self._validate_args:
11 changes: 6 additions & 5 deletions torch/distributions/fishersnedecor.py
@@ -1,13 +1,14 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import nan, Tensor
+from torch import nan, Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.gamma import Gamma
 from torch.distributions.utils import broadcast_all
 from torch.types import _Number, _size
 
-
 __all__ = ["FisherSnedecor"]
 
 
@@ -75,12 +76,12 @@ def variance(self) -> Tensor:
             / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
         )
 
-    def rsample(self, sample_shape: _size = torch.Size(())) -> Tensor:
+    def rsample(self, sample_shape: _size = torch.Size(()), generator: Optional[Generator] = None) -> Tensor:
         shape = self._extended_shape(sample_shape)
         # X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
         # Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
-        X1 = self._gamma1.rsample(sample_shape).view(shape)
-        X2 = self._gamma2.rsample(sample_shape).view(shape)
+        X1 = self._gamma1.rsample(sample_shape, generator).view(shape)
+        X2 = self._gamma2.rsample(sample_shape, generator).view(shape)
         tiny = torch.finfo(X2.dtype).tiny
         X2.clamp_(min=tiny)
         Y = X1 / X2
13 changes: 7 additions & 6 deletions torch/distributions/gamma.py
@@ -1,17 +1,18 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import broadcast_all
 from torch.types import _Number, _size
 
-
 __all__ = ["Gamma"]
 
 
-def _standard_gamma(concentration):
-    return torch._standard_gamma(concentration)
+def _standard_gamma(concentration, generator):
+    return torch._standard_gamma(concentration, generator)
 
 
 class Gamma(ExponentialFamily):
@@ -68,9 +69,9 @@ def expand(self, batch_shape, _instance=None):
         new._validate_args = self._validate_args
         return new
 
-    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
+    def rsample(self, sample_shape: _size = torch.Size(), generator: Optional[Generator] = None) -> Tensor:
         shape = self._extended_shape(sample_shape)
-        value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(
+        value = _standard_gamma(self.concentration.expand(shape), generator) / self.rate.expand(
             shape
         )
         value.detach().clamp_(
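`rsample` keeps pathwise gradients intact, so an explicit generator lets a stochastic loss be recomputed with identical noise. A quick sketch, assuming the patch is applied (seed and shapes are arbitrary):

```python
import torch
from torch.distributions import Gamma

conc = torch.tensor([2.0], requires_grad=True)
g = torch.Generator().manual_seed(42)

x = Gamma(conc, torch.tensor([1.0])).rsample((5,), generator=g)
x.sum().backward()  # pathwise gradient w.r.t. the concentration
print(conc.grad)
```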
11 changes: 6 additions & 5 deletions torch/distributions/geometric.py
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
+from typing import Optional
+
 import torch
-from torch import Tensor
+from torch import Tensor, Generator
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import (
@@ -12,7 +14,6 @@
 from torch.nn.functional import binary_cross_entropy_with_logits
 from torch.types import _Number
 
-
 __all__ = ["Geometric"]
 
 
@@ -103,16 +104,16 @@ def logits(self) -> Tensor:
     def probs(self) -> Tensor:
         return logits_to_probs(self.logits, is_binary=True)
 
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape=torch.Size(), generator: Optional[Generator] = None):
         shape = self._extended_shape(sample_shape)
         tiny = torch.finfo(self.probs.dtype).tiny
         with torch.no_grad():
             if torch._C._get_tracing_state():
                 # [JIT WORKAROUND] lack of support for .uniform_()
-                u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+                u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device, generator=generator)
                 u = u.clamp(min=tiny)
             else:
-                u = self.probs.new(shape).uniform_(tiny, 1)
+                u = self.probs.new(shape).uniform_(tiny, 1, generator=generator)
         return (u.log() / (-self.probs).log1p()).floor()
 
     def log_prob(self, value):
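`Geometric.sample` uses inverse-CDF sampling: with `u ~ Uniform(tiny, 1)`, `floor(log(u) / log1p(-p))` is geometric with success probability `p`, counting failures before the first success, so its mean is `(1 - p) / p`. A quick numerical check of that transform (sample size and seed are arbitrary):

```python
import torch

p = torch.tensor(0.25)
g = torch.Generator().manual_seed(0)

u = torch.rand(100_000, generator=g).clamp(min=torch.finfo(torch.float32).tiny)
k = (u.log() / (-p).log1p()).floor()
print(k.mean())  # ~3.0, matching (1 - 0.25) / 0.25
```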