add generalized pareto distribution (GPD) (#135968) · pytorch/pytorch@2ed2cb5 · GitHub
Commit 2ed2cb5

kashif and StatMixedML authored and committed
add generalized pareto distribution (GPD) (#135968)
Add the GPD as a distribution class.

Pull Request resolved: #135968
Approved by: https://github.com/albanD
Co-authored-by: Alexander März <statmixedmlgit@gmail.com>
1 parent 7e2081f commit 2ed2cb5
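
For orientation, a minimal sketch of the API this commit adds. The constructor and method names are taken from the diff below; the shapes in the comments follow from the broadcasting rules rather than from a recorded run:

```python
import torch
from torch.distributions import GeneralizedPareto

# loc, scale, and concentration broadcast against each other; scale must be positive.
d = GeneralizedPareto(
    loc=torch.tensor([0.1]),
    scale=torch.tensor([2.0]),
    concentration=torch.tensor([0.4]),
)

x = d.rsample((3,))        # reparameterized draws via inverse-CDF sampling; shape (3, 1)
lp = d.log_prob(x)         # elementwise log-density; shape (3, 1)
m, v = d.mean, d.variance  # closed-form moments (NaN where undefined)
```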

File tree

4 files changed: +243 −1 lines changed

docs/source/distributions.rst

Lines changed: 10 additions & 0 deletions
@@ -122,6 +122,15 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:
 
+:hidden:`GeneralizedPareto`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torch.distributions.generalized_pareto
+.. autoclass:: GeneralizedPareto
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 :hidden:`Geometric`
 ~~~~~~~~~~~~~~~~~~~
 
@@ -419,6 +428,7 @@ Probability distributions - torch.distributions
 .. py:module:: torch.distributions.exponential
 .. py:module:: torch.distributions.fishersnedecor
 .. py:module:: torch.distributions.gamma
+.. py:module:: torch.distributions.generalized_pareto
 .. py:module:: torch.distributions.geometric
 .. py:module:: torch.distributions.gumbel
 .. py:module:: torch.distributions.half_cauchy

test/distributions/test_distributions.py

Lines changed: 81 additions & 1 deletion
@@ -57,6 +57,7 @@
     ExponentialFamily,
     FisherSnedecor,
     Gamma,
+    GeneralizedPareto,
     Geometric,
     Gumbel,
     HalfCauchy,
@@ -151,7 +152,7 @@ def is_all_nan(tensor):
 Example = namedtuple("Example", ["Dist", "params"])
 
 
-# Register all distributions for generic tests.
+# Register all distributions for generic tests by appending to this list.
 def _get_examples():
     return [
         Example(
@@ -800,9 +801,20 @@ def _get_examples():
                 },
             ],
         ),
+        Example(
+            GeneralizedPareto,
+            [
+                {
+                    "loc": torch.randn(5, 5, requires_grad=True).mul(10),
+                    "scale": torch.randn(5, 5).abs().requires_grad_(),
+                    "concentration": torch.randn(5, 5).div(10).requires_grad_(),
+                },
+            ],
+        ),
     ]
 
 
+# Register all distributions for bad examples by appending to this list.
 def _get_bad_examples():
     return [
         Example(
@@ -1199,6 +1211,21 @@ def _get_bad_examples():
                 },
             ],
         ),
+        Example(
+            GeneralizedPareto,
+            [
+                {
+                    "loc": torch.tensor([0.0, 0.0], requires_grad=True),
+                    "scale": torch.tensor([-1.0, -100.0], requires_grad=True),
+                    "concentration": torch.tensor([0.0, 0.0], requires_grad=True),
+                },
+                {
+                    "loc": torch.tensor([1.0, 1.0], requires_grad=True),
+                    "scale": torch.tensor([0.0, 0.0], requires_grad=True),
+                    "concentration": torch.tensor([-1.0, -100.0], requires_grad=True),
+                },
+            ],
+        ),
     ]
@@ -3498,6 +3525,51 @@ def test_pareto_sample(self):
         )
 
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_generalized_pareto(self):
+        loc = torch.randn(2, 3).requires_grad_()
+        scale = torch.randn(2, 3).abs().requires_grad_()
+        concentration = torch.randn(2, 3).requires_grad_()
+        loc_1d = torch.randn(1).requires_grad_()
+        scale_1d = torch.randn(1).abs().requires_grad_()
+        concentration_1d = torch.randn(1).requires_grad_()
+        self.assertEqual(
+            GeneralizedPareto(loc, scale, concentration).sample().size(), (2, 3)
+        )
+        self.assertEqual(
+            GeneralizedPareto(loc, scale, concentration).sample((5,)).size(), (5, 2, 3)
+        )
+        self.assertEqual(
+            GeneralizedPareto(loc_1d, scale_1d, concentration_1d).sample((1,)).size(),
+            (1, 1),
+        )
+        self.assertEqual(
+            GeneralizedPareto(loc_1d, scale_1d, concentration_1d).sample().size(), (1,)
+        )
+        self.assertEqual(GeneralizedPareto(1.0, 1.0, 1.0).sample().size(), ())
+        self.assertEqual(GeneralizedPareto(1.0, 1.0, 1.0).sample((1,)).size(), (1,))
+
+        def ref_log_prob(idx, x, log_prob):
+            l = loc.view(-1)[idx].detach()
+            s = scale.view(-1)[idx].detach()
+            c = concentration.view(-1)[idx].detach()
+            expected = scipy.stats.genpareto.logpdf(x, c, loc=l, scale=s)
+            self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
+
+        self._check_log_prob(GeneralizedPareto(loc, scale, concentration), ref_log_prob)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_generalized_pareto_sample(self):
+        set_rng_seed(1)  # see note [Randomized statistical tests]
+        for loc, scale, concentration in product(
+            [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0], [-0.5, 0.0, 0.5]
+        ):
+            self._check_sampler_sampler(
+                GeneralizedPareto(loc, scale, concentration),
+                scipy.stats.genpareto(c=concentration, loc=loc, scale=scale),
+                f"GeneralizedPareto(loc={loc}, scale={scale}, concentration={concentration})",
+                failure_rate=7e-4,
+            )
+
     def test_gumbel(self):
         loc = torch.randn(2, 3, requires_grad=True)
         scale = torch.randn(2, 3).abs().requires_grad_()
@@ -6321,6 +6393,14 @@ def setUp(self):
                 Gumbel(random_var, positive_var2),
                 scipy.stats.gumbel_r(random_var, positive_var2),
             ),
+            (
+                GeneralizedPareto(
+                    loc=random_var, scale=positive_var, concentration=random_var / 10
+                ),
+                scipy.stats.genpareto(
+                    c=random_var / 10, loc=random_var, scale=positive_var
+                ),
+            ),
             (HalfCauchy(positive_var), scipy.stats.halfcauchy(scale=positive_var)),
             (HalfNormal(positive_var2), scipy.stats.halfnorm(scale=positive_var2)),
             (

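The tests above compare log_prob and the sampler against SciPy's genpareto. A standalone version of that cross-check, assuming SciPy is installed (the parameter values here are arbitrary, not taken from the test suite):

```python
import scipy.stats
import torch
from torch.distributions import GeneralizedPareto

loc, scale, concentration = 0.0, 1.0, 0.3
d = GeneralizedPareto(loc, scale, concentration)
x = d.sample((5,))

# scipy.stats.genpareto uses `c` for the shape (concentration) parameter.
expected = scipy.stats.genpareto.logpdf(x.numpy(), concentration, loc=loc, scale=scale)
torch.testing.assert_close(
    d.log_prob(x),
    torch.as_tensor(expected, dtype=torch.float32),
    atol=1e-3,  # same tolerance the test uses for log_prob comparisons
    rtol=0,
)
```
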
torch/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,7 @@
 from .exponential import Exponential
 from .fishersnedecor import FisherSnedecor
 from .gamma import Gamma
+from .generalized_pareto import GeneralizedPareto
 from .geometric import Geometric
 from .gumbel import Gumbel
 from .half_cauchy import HalfCauchy
@@ -135,6 +136,7 @@
     "ExponentialFamily",
     "FisherSnedecor",
     "Gamma",
+    "GeneralizedPareto",
     "Geometric",
     "Gumbel",
     "HalfCauchy",
torch/distributions/generalized_pareto.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+# mypy: allow-untyped-defs
+import math
+from numbers import Number, Real
+
+import torch
+from torch import inf, nan
+from torch.distributions import constraints, Distribution
+from torch.distributions.utils import broadcast_all
+
+
+__all__ = ["GeneralizedPareto"]
+
+
+class GeneralizedPareto(Distribution):
+    r"""
+    Creates a Generalized Pareto distribution parameterized by :attr:`loc`, :attr:`scale`, and :attr:`concentration`.
+
+    The Generalized Pareto distribution is a family of continuous probability distributions on the real line.
+    Special cases include Exponential (when :attr:`loc` = 0, :attr:`concentration` = 0), Pareto (when :attr:`concentration` > 0,
+    :attr:`loc` = :attr:`scale` / :attr:`concentration`), and Uniform (when :attr:`concentration` = -1).
+
+    This distribution is often used to model the tails of other distributions. This implementation is based on the
+    implementation in TensorFlow Probability.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4]))
+        >>> m.sample()  # sample from a Generalized Pareto distribution with loc=0.1, scale=2.0, and concentration=0.4
+        tensor([ 1.5623])
+
+    Args:
+        loc (float or Tensor): Location parameter of the distribution
+        scale (float or Tensor): Scale parameter of the distribution
+        concentration (float or Tensor): Concentration parameter of the distribution
+    """
+
+    arg_constraints = {
+        "loc": constraints.real,
+        "scale": constraints.positive,
+        "concentration": constraints.real,
+    }
+    has_rsample = True
+
+    def __init__(self, loc, scale, concentration, validate_args=None):
+        self.loc, self.scale, self.concentration = broadcast_all(
+            loc, scale, concentration
+        )
+        if (
+            isinstance(loc, Number)
+            and isinstance(scale, Number)
+            and isinstance(concentration, Number)
+        ):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.loc.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(GeneralizedPareto, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        new.concentration = self.concentration.expand(batch_shape)
+        super(GeneralizedPareto, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
+        return self.icdf(u)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        z = self._z(value)
+        eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
+        safe_conc = torch.where(
+            eq_zero, torch.ones_like(self.concentration), self.concentration
+        )
+        y = 1 / safe_conc + torch.ones_like(z)
+        where_nonzero = torch.where(y == 0, y, y * torch.log1p(safe_conc * z))
+        log_scale = (
+            math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
+        )
+        return -log_scale - torch.where(eq_zero, z, where_nonzero)
+
+    def log_survival_function(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        z = self._z(value)
+        eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
+        safe_conc = torch.where(
+            eq_zero, torch.ones_like(self.concentration), self.concentration
+        )
+        where_nonzero = -torch.log1p(safe_conc * z) / safe_conc
+        return torch.where(eq_zero, -z, where_nonzero)
+
+    def log_cdf(self, value):
+        return torch.log1p(-torch.exp(self.log_survival_function(value)))
+
+    def cdf(self, value):
+        return torch.exp(self.log_cdf(value))
+
+    def icdf(self, value):
+        loc = self.loc
+        scale = self.scale
+        concentration = self.concentration
+        eq_zero = torch.isclose(concentration, torch.zeros_like(concentration))
+        safe_conc = torch.where(eq_zero, torch.ones_like(concentration), concentration)
+        logu = torch.log1p(-value)
+        where_nonzero = loc + scale / safe_conc * torch.expm1(-safe_conc * logu)
+        where_zero = loc - scale * logu
+        return torch.where(eq_zero, where_zero, where_nonzero)
+
+    def _z(self, x):
+        return (x - self.loc) / self.scale
+
+    @property
+    def mean(self):
+        concentration = self.concentration
+        valid = concentration < 1
+        safe_conc = torch.where(valid, concentration, 0.5)
+        result = self.loc + self.scale / (1 - safe_conc)
+        return torch.where(valid, result, nan)
+
+    @property
+    def variance(self):
+        concentration = self.concentration
+        valid = concentration < 0.5
+        safe_conc = torch.where(valid, concentration, 0.25)
+        result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
+        return torch.where(valid, result, nan)
+
+    def entropy(self):
+        ans = torch.log(self.scale) + self.concentration + 1
+        return torch.broadcast_to(ans, self._batch_shape)
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @constraints.dependent_property(is_discrete=False, event_dim=0)
+    def support(self):
+        lower = self.loc
+        upper = torch.where(
+            self.concentration < 0, lower - self.scale / self.concentration, inf
+        )
+        return constraints.interval(lower, upper)
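
Two notes on the implementation above. First, log_prob, log_survival_function, and icdf all route the concentration through a safe_conc substitute inside torch.where: both branches of torch.where are evaluated eagerly, so the near-zero concentration is replaced with a harmless value in the branch that is ultimately discarded, keeping values and gradients free of NaN and inf. Second, the special cases named in the docstring can be sanity-checked directly. A small sketch, with the test points and comparison distributions chosen by me rather than taken from the commit:

```python
import torch
from torch.distributions import Exponential, GeneralizedPareto, Uniform

x = torch.linspace(0.1, 0.9, 5)

# concentration = 0 and loc = 0 reduce the GPD to Exponential(rate = 1 / scale).
gpd_exp = GeneralizedPareto(loc=0.0, scale=2.0, concentration=0.0)
torch.testing.assert_close(gpd_exp.log_prob(x), Exponential(rate=0.5).log_prob(x))

# concentration = -1 reduces the GPD to Uniform(loc, loc + scale).
gpd_unif = GeneralizedPareto(loc=0.0, scale=1.0, concentration=-1.0)
torch.testing.assert_close(gpd_unif.log_prob(x), Uniform(0.0, 1.0).log_prob(x))
```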

0 commit comments