8000 Merge pull request #44 from bashtage/doc-and-other-small-fixes · mattip/numpy@fe9cd13 · GitHub
[go: up one dir, main page]

Skip to content

Commit fe9cd13

Browse files
authored
Merge pull request #44 from bashtage/doc-and-other-small-fixes
BUG: Protect gamma generation from 0 input
2 parents 3233d69 + cf72d30 commit fe9cd13

File tree

5 files changed

+62
-2
lines changed

5 files changed

+62
-2
lines changed

_randomgen/randomgen/src/distributions/distributions.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ static NPY_INLINE double standard_gamma_zig(brng_t *brng_state, double shape) {
346346

347347
if (shape == 1.0) {
348348
return random_standard_exponential_zig(brng_state);
349+
} else if (shape == 0.0) {
350+
return 0.0;
349351
} else if (shape < 1.0) {
350352
for (;;) {
351353
U = next_double(brng_state);

_randomgen/randomgen/src/legacy/distributions-boxmuller.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ double legacy_standard_gamma(aug_brng_t *aug_state, double shape) {
3939

4040
if (shape == 1.0) {
4141
return legacy_standard_exponential(aug_state);
42+
}
43+
else if (shape == 0.0) {
44+
return 0.0;
4245
} else if (shape < 1.0) {
4346
for (;;) {
4447
U = legacy_double(aug_state);
@@ -84,6 +87,9 @@ double legacy_pareto(aug_brng_t *aug_state, double a) {
8487
}
8588

8689
double legacy_weibull(aug_brng_t *aug_state, double a) {
90+
if (a == 0.0) {
91+
return 0.0;
92+
}
8793
return pow(legacy_standard_exponential(aug_state), 1. / a);
8894
}
8995

_randomgen/randomgen/tests/test_legacy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ def test_pickle():
1010
lg2 = pickle.loads(pickle.dumps(lg))
1111
assert lg.standard_normal() == lg2.standard_normal()
1212
assert lg.random_sample() == lg2.random_sample()
13+
14+
15+
def test_weibull():
16+
lg = LegacyGenerator()
17+
assert lg.weibull(0.0) == 0.0

_randomgen/randomgen/tests/test_numpy_mt19937.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,11 +1161,13 @@ def test_normal(self):
11611161
actual = normal(loc * 3, scale)
11621162
assert_array_almost_equal(actual, desired, decimal=14)
11631163
assert_raises(ValueError, normal, loc * 3, bad_scale)
1164+
assert_raises(ValueError, mt19937.normal, loc * 3, bad_scale)
11641165

11651166
self.set_seed()
11661167
actual = normal(loc, scale * 3)
11671168
assert_array_almost_equal(actual, desired, decimal=14)
11681169
assert_raises(ValueError, normal, loc, bad_scale * 3)
1170+
assert_raises(ValueError, mt19937.normal, loc, bad_scale * 3)
11691171

11701172
def test_beta(self):
11711173
a = [1]
@@ -1182,12 +1184,14 @@ def test_beta(self):
11821184
assert_array_almost_equal(actual, desired, decimal=14)
11831185
assert_raises(ValueError, beta, bad_a * 3, b)
11841186
assert_raises(ValueError, beta, a * 3, bad_b)
1187+
assert_raises(ValueError, mt19937.beta, bad_a * 3, b)
1188+
assert_raises(ValueError, mt19937.beta, a * 3, bad_b)
11851189

11861190
self.set_seed()
11871191
actual = beta(a, b * 3)
11881192
assert_array_almost_equal(actual, desired, decimal=14)
1189-
assert_raises(ValueError, beta, bad_a, b * 3)
1190-
assert_raises(ValueError, beta, a, bad_b * 3)
1193+
assert_raises(ValueError, mt19937.beta, bad_a, b * 3)
1194+
assert_raises(ValueError, mt19937.beta, a, bad_b * 3)
11911195

11921196
def test_exponential(self):
11931197
scale = [1]
@@ -1201,6 +1205,7 @@ def test_exponential(self):
12011205
actual = exponential(scale * 3)
12021206
assert_array_almost_equal(actual, desired, decimal=14)
12031207
assert_raises(ValueError, exponential, bad_scale * 3)
1208+
assert_raises(ValueError, mt19937.exponential, bad_scale * 3)
12041209

12051210
def test_standard_gamma(self):
12061211
shape = [1]
@@ -1214,6 +1219,7 @@ def test_standard_gamma(self):
12141219
actual = std_gamma(shape * 3)
12151220
assert_array_almost_equal(actual, desired, decimal=14)
12161221
assert_raises(ValueError, std_gamma, bad_shape * 3)
1222+
assert_raises(ValueError, mt19937.standard_gamma, bad_shape * 3)
12171223

12181224
def test_gamma(self):
12191225
shape = [1]
@@ -1230,12 +1236,16 @@ def test_gamma(self):
12301236
assert_array_almost_equal(actual, desired, decimal=14)
12311237
assert_raises(ValueError, gamma, bad_shape * 3, scale)
12321238
assert_raises(ValueError, gamma, shape * 3, bad_scale)
1239+
assert_raises(ValueError, mt19937.gamma, bad_shape * 3, scale)
1240+
assert_raises(ValueError, mt19937.gamma, shape * 3, bad_scale)
12331241

12341242
self.set_seed()
12351243
actual = gamma(shape, scale * 3)
12361244
assert_array_almost_equal(actual, desired, decimal=14)
12371245
assert_raises(ValueError, gamma, bad_shape, scale * 3)
12381246
assert_raises(ValueError, gamma, shape, bad_scale * 3)
1247+
assert_raises(ValueError, mt19937.gamma, bad_shape, scale * 3)
1248+
assert_raises(ValueError, mt19937.gamma, shape, bad_scale * 3)
12391249

12401250
def test_f(self):
12411251
dfnum = [1]
@@ -1252,12 +1262,16 @@ def test_f(self):
12521262
assert_array_almost_equal(actual, desired, decimal=14)
12531263
assert_raises(ValueError, f, bad_dfnum * 3, dfden)
12541264
assert_raises(ValueError, f, dfnum * 3, bad_dfden)
1265+
assert_raises(ValueError, mt19937.f, bad_dfnum * 3, dfden)
1266+
assert_raises(ValueError, mt19937.f, dfnum * 3, bad_dfden)
12551267

12561268
self.set_seed()
12571269
actual = f(dfnum, dfden * 3)
12581270
assert_array_almost_equal(actual, desired, decimal=14)
12591271
assert_raises(ValueError, f, bad_dfnum, dfden * 3)
12601272
assert_raises(ValueError, f, dfnum, bad_dfden * 3)
1273+
assert_raises(ValueError, mt19937.f, bad_dfnum, dfden * 3)
1274+
assert_raises(ValueError, mt19937.f, dfnum, bad_dfden * 3)
12611275

12621276
def test_noncentral_f(self):
12631277
dfnum = [2]
@@ -1277,20 +1291,29 @@ def test_noncentral_f(self):
12771291
assert_raises(ValueError, nonc_f, bad_dfnum * 3, dfden, nonc)
12781292
assert_raises(ValueError, nonc_f, dfnum * 3, bad_dfden, nonc)
12791293
assert_raises(ValueError, nonc_f, dfnum * 3, dfden, bad_nonc)
1294+
assert_raises(ValueError, mt19937.noncentral_f, bad_dfnum * 3, dfden, nonc)
1295+
assert_raises(ValueError, mt19937.noncentral_f, dfnum * 3, bad_dfden, nonc)
1296+
assert_raises(ValueError, mt19937.noncentral_f, dfnum * 3, dfden, bad_nonc)
12801297

12811298
self.set_seed()
12821299
actual = nonc_f(dfnum, dfden * 3, nonc)
12831300
assert_array_almost_equal(actual, desired, decimal=14)
12841301
assert_raises(ValueError, nonc_f, bad_dfnum, dfden * 3, nonc)
12851302
assert_raises(ValueError, nonc_f, dfnum, bad_dfden * 3, nonc)
12861303
assert_raises(ValueError, nonc_f, dfnum, dfden * 3, bad_nonc)
1304+
assert_raises(ValueError, mt19937.noncentral_f, bad_dfnum, dfden * 3, nonc)
1305+
assert_raises(ValueError, mt19937.noncentral_f, dfnum, bad_dfden * 3, nonc)
1306+
assert_raises(ValueError, mt19937.noncentral_f, dfnum, dfden * 3, bad_nonc)
12871307

12881308
self.set_seed()
12891309
actual = nonc_f(dfnum, dfden, nonc * 3)
12901310
assert_array_almost_equal(actual, desired, decimal=14)
12911311
assert_raises(ValueError, nonc_f, bad_dfnum, dfden, nonc * 3)
12921312
assert_raises(ValueError, nonc_f, dfnum, bad_dfden, nonc * 3)
12931313
assert_raises(ValueError, nonc_f, dfnum, dfden, bad_nonc * 3)
1314+
assert_raises(ValueError, mt19937.noncentral_f, bad_dfnum, dfden, nonc * 3)
1315+
assert_raises(ValueError, mt19937.noncentral_f, dfnum, bad_dfden, nonc * 3)
1316+
assert_raises(ValueError, mt19937.noncentral_f, dfnum, dfden, bad_nonc * 3)
12941317

12951318
def test_chisquare(self):
12961319
df = [1]
@@ -1320,12 +1343,16 @@ def test_noncentral_chisquare(self):
13201343
assert_array_almost_equal(actual, desired, decimal=14)
13211344
assert_raises(ValueError, nonc_chi, bad_df * 3, nonc)
13221345
assert_raises(ValueError, nonc_chi, df * 3, bad_nonc)
1346+
assert_raises(ValueError, mt19937.noncentral_chisquare, bad_df * 3, nonc)
1347+
assert_raises(ValueError, mt19937.noncentral_chisquare, df * 3, bad_nonc)
13231348

13241349
self.set_seed()
13251350
actual = nonc_chi(df, nonc * 3)
13261351
assert_array_almost_equal(actual, desired, decimal=14)
13271352
assert_raises(ValueError, nonc_chi, bad_df, nonc * 3)
13281353
assert_raises(ValueError, nonc_chi, df, bad_nonc * 3)
1354+
assert_raises(ValueError, mt19937.noncentral_chisquare, bad_df, nonc * 3)
1355+
assert_raises(ValueError, mt19937.noncentral_chisquare, df, bad_nonc * 3)
13291356

13301357
def test_standard_t(self):
13311358
df = [1]
@@ -1339,6 +1366,7 @@ def test_standard_t(self):
13391366
actual = t(df * 3)
13401367
assert_array_almost_equal(actual, desired, decimal=14)
13411368
assert_raises(ValueError, t, bad_df * 3)
1369+
assert_raises(ValueError, mt19937.standard_t, bad_df * 3)
13421370

13431371
def test_vonmises(self):
13441372
mu = [2]
@@ -1371,6 +1399,7 @@ def test_pareto(self):
13711399
actual = pareto(a * 3)
13721400
assert_array_almost_equal(actual, desired, decimal=14)
13731401
assert_raises(ValueError, pareto, bad_a * 3)
1402+
assert_raises(ValueError, mt19937.pareto, bad_a * 3)
13741403

13751404
def test_weibull(self):
13761405
a = [1]
@@ -1384,6 +1413,7 @@ def test_weibull(self):
13841413
actual = weibull(a * 3)
13851414
assert_array_almost_equal(actual, desired, decimal=14)
13861415
assert_raises(ValueError, weibull, bad_a * 3)
1416+
assert_raises(ValueError, mt19937.weibull, bad_a * 3)
13871417

13881418
def test_power(self):
13891419
a = [1]
@@ -1397,6 +1427,7 @@ def test_power(self):
13971427
actual = power(a * 3)
13981428
assert_array_almost_equal(actual, desired, decimal=14)
13991429
assert_raises(ValueError, power, bad_a * 3)
1430+
assert_raises(ValueError, mt19937.power, bad_a * 3)
14001431

14011432
def test_laplace(self):
14021433
loc = [0]
@@ -1468,11 +1499,13 @@ def test_lognormal(self):
14681499
actual = lognormal(mean * 3, sigma)
14691500
assert_array_almost_equal(actual, desired, decimal=14)
14701501
assert_raises(ValueError, lognormal, mean * 3, bad_sigma)
1502+
assert_raises(ValueError, mt19937.lognormal, mean * 3, bad_sigma)
14711503

14721504
self.set_seed()
14731505
actual = lognormal(mean, sigma * 3)
14741506
assert_array_almost_equal(actual, desired, decimal=14)
14751507
assert_raises(ValueError, lognormal, mean, bad_sigma * 3)
1508+
assert_raises(ValueError, mt19937.lognormal, mean, bad_sigma * 3)
14761509

14771510
def test_rayleigh(self):
14781511
scale = [1]
@@ -1502,12 +1535,16 @@ def test_wald(self):
15021535
assert_array_almost_equal(actual, desired, decimal=14)
15031536
assert_raises(ValueError, wald, bad_mean * 3, scale)
15041537
assert_raises(ValueError, wald, mean * 3, bad_scale)
1538+
assert_raises(ValueError, mt19937.wald, bad_mean * 3, scale)
1539+
assert_raises(ValueError, mt19937.wald, mean * 3, bad_scale)
15051540

15061541
self.set_seed()
15071542
actual = wald(mean, scale * 3)
15081543
assert_array_almost_equal(actual, desired, decimal=14)
15091544
assert_raises(ValueError, wald, bad_mean, scale * 3)
15101545
assert_raises(ValueError, wald, mean, bad_scale * 3)
1546+
assert_raises(ValueError, mt19937.wald, bad_mean, scale * 3)
1547+
assert_raises(ValueError, mt19937.wald, mean, bad_scale * 3)
15111548

15121549
def test_triangular(self):
15131550
left = [1]
@@ -1583,13 +1620,19 @@ def test_negative_binomial(self):
15831620
assert_raises(ValueError, neg_binom, bad_n * 3, p)
15841621
assert_raises(ValueError, neg_binom, n * 3, bad_p_one)
15851622
assert_raises(ValueError, neg_binom, n * 3, bad_p_two)
1623+
assert_raises(ValueError, mt19937.negative_binomial, bad_n * 3, p)
1624+
assert_raises(ValueError, mt19937.negative_binomial, n * 3, bad_p_one)
1625+
assert_raises(ValueError, mt19937.negative_binomial, n * 3, bad_p_two)
15861626

15871627
self.set_seed()
15881628
actual = neg_binom(n, p * 3)
15891629
assert_array_equal(actual, desired)
15901630
assert_raises(ValueError, neg_binom, bad_n, p * 3)
15911631
assert_raises(ValueError, neg_binom, n, bad_p_one * 3)
15921632
assert_raises(ValueError, neg_binom, n, bad_p_two * 3)
1633+
assert_raises(ValueError, mt19937.negative_binomial, bad_n, p * 3)
1634+
assert_raises(ValueError, mt19937.negative_binomial, n, bad_p_one * 3)
1635+
assert_raises(ValueError, mt19937.negative_binomial, n, bad_p_two * 3)
15931636

15941637
def test_poisson(self):
15951638
max_lam = random.poisson_lam_max

_randomgen/randomgen/tests/test_numpy_mt19937_regressions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,7 @@ def __array__(self):
158158
perm = mt19937.permutation(m)
159159
assert_array_equal(perm, np.array([2, 1, 4, 0, 3]))
160160
assert_array_equal(m.__array__(), np.arange(5))
161+
162+
def test_gamma_0(self):
163+
assert mt19937.standard_gamma(0.0) == 0.0
164+
assert_array_equal(mt19937.standard_gamma([0.0]), 0.0)

0 commit comments

Comments
 (0)
0