8000 Merge pull request #5822 from anntzer/0-scale-distributions · masasin/numpy@c62b20d · GitHub
[go: up one dir, main page]

Skip to content

Commit c62b20d

Browse files
authored
Merge pull request numpy#5822 from anntzer/0-scale-distributions
Allow many distributions to have a scale of 0.
2 parents de0fcbd + 0319d0c commit c62b20d

File tree

2 files changed

+88
-60
lines changed

2 files changed

+88
-60
lines changed

numpy/random/mtrand/mtrand.pyx

Lines changed: 48 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,15 +1903,13 @@ cdef class RandomState:
19031903
if oloc.shape == oscale.shape == ():
19041904
floc = PyFloat_AsDouble(loc)
19051905
fscale = PyFloat_AsDouble(scale)
1906-
1907-
if fscale <= 0:
1908-
raise ValueError("scale <= 0")
1909-
1906+
if np.signbit(fscale):
1907+
raise ValueError("scale < 0")
19101908
return cont2_array_sc(self.internal_state, rk_normal, size, floc,
19111909
fscale, self.lock)
19121910

1913-
if np.any(np.less_equal(oscale, 0)):
1914-
raise ValueError("scale <= 0")
1911+
if np.any(np.signbit(oscale)):
1912+
raise ValueError("scale < 0")
19151913
return cont2_array(self.internal_state, rk_normal, size, oloc, oscale,
19161914
self.lock)
19171915

@@ -2029,14 +2027,13 @@ cdef class RandomState:
20292027

20302028
if oscale.shape == ():
20312029
fscale = PyFloat_AsDouble(scale)
2032-
2033-
if fscale <= 0:
2034-
raise ValueError("scale <= 0")
2030+
if np.signbit(fscale):
2031+
raise ValueError("scale < 0")
20352032
return cont1_array_sc(self.internal_state, rk_exponential, size,
20362033
fscale, self.lock)
20372034

2038-
if np.any(np.less_equal(oscale, 0.0)):
2039-
raise ValueError("scale <= 0")
2035+
if np.any(np.signbit(oscale)):
2036+
raise ValueError("scale < 0")
20402037
return cont1_array(self.internal_state, rk_exponential, size, oscale,
20412038
self.lock)
20422039

@@ -2147,14 +2144,13 @@ cdef class RandomState:
21472144

21482145
if oshape.shape == ():
21492146
fshape = PyFloat_AsDouble(shape)
2150-
2151-
if fshape <= 0:
2152-
raise ValueError("shape <= 0")
2147+
if np.signbit(fshape):
2148+
raise ValueError("shape < 0")
21532149
return cont1_array_sc(self.internal_state, rk_standard_gamma,
21542150
size, fshape, self.lock)
21552151

2156-
if np.any(np.less_equal(oshape, 0.0)):
2157-
raise ValueError("shape <= 0")
2152+
if np.any(np.signbit(oshape)):
2153+
raise ValueError("shape < 0")
21582154
return cont1_array(self.internal_state, rk_standard_gamma, size,
21592155
oshape, self.lock)
21602156

@@ -2240,18 +2236,17 @@ cdef class RandomState:
22402236
if oshape.shape == oscale.shape == ():
22412237
fshape = PyFloat_AsDouble(shape)
22422238
fscale = PyFloat_AsDouble(scale)
2243-
2244-
if fshape <= 0:
2245-
raise ValueError("shape <= 0")
2246-
if fscale <= 0:
2247-
raise ValueError("scale <= 0")
2239+
if np.signbit(fshape):
2240+
raise ValueError("shape < 0")
2241+
if np.signbit(fscale):
2242+
raise ValueError("scale < 0")
22482243
return cont2_array_sc(self.internal_state, rk_gamma, size, fshape,
22492244
fscale, self.lock)
22502245

2251-
if np.any(np.less_equal(oshape, 0.0)):
2252-
raise ValueError("shape <= 0")
2253-
if np.any(np.less_equal(oscale, 0.0)):
2254-
raise ValueError("scale <= 0")
2246+
if np.any(np.signbit(oshape)):
2247+
raise ValueError("shape < 0")
2248+
if np.any(np.signbit(oscale)):
2249+
raise ValueError("scale < 0")
22552250
return cont2_array(self.internal_state, rk_gamma, size, oshape, oscale,
22562251
self.lock)
22572252

@@ -3122,14 +3117,13 @@ cdef class RandomState:
31223117

31233118
if oa.shape == ():
31243119
fa = PyFloat_AsDouble(a)
3125-
3126-
if fa <= 0:
3127-
raise ValueError("a <= 0")
3120+
if np.signbit(fa):
3121+
raise ValueError("a < 0")
31283122
return cont1_array_sc(self.internal_state, rk_weibull, size, fa,
31293123
self.lock)
31303124

3131-
if np.any(np.less_equal(oa, 0.0)):
3132-
raise ValueError("a <= 0")
3125+
if np.any(np.signbit(oa)):
3126+
raise ValueError("a < 0")
31333127
return cont1_array(self.internal_state, rk_weibull, size, oa,
31343128
self.lock)
31353129

@@ -3235,14 +3229,13 @@ cdef class RandomState:
32353229

32363230
if oa.shape == ():
32373231
fa = PyFloat_AsDouble(a)
3238-
3239-
if fa <= 0:
3240-
raise ValueError("a <= 0")
3232+
if np.signbit(fa):
3233+
raise ValueError("a < 0")
32413234
return cont1_array_sc(self.internal_state, rk_power, size, fa,
32423235
self.lock)
32433236

3244-
if np.any(np.less_equal(oa, 0.0)):
3245-
raise ValueError("a <= 0")
3237+
if np.any(np.signbit(oa)):
3238+
raise ValueError("a < 0")
32463239
return cont1_array(self.internal_state, rk_power, size, oa, self.lock)
32473240

32483241
def laplace(self, loc=0.0, scale=1.0, size=None):
@@ -3333,14 +3326,13 @@ cdef class RandomState:
33333326
if oloc.shape == oscale.shape == ():
33343327
floc = PyFloat_AsDouble(loc)
33353328
fscale = PyFloat_AsDouble(scale)
3336-
3337-
if fscale <= 0:
3338-
raise ValueError("scale <= 0")
3329+
if np.signbit(fscale):
3330+
raise ValueError("scale < 0")
33393331
return cont2_array_sc(self.internal_state, rk_laplace, size, floc,
33403332
fscale, self.lock)
33413333

3342-
if np.any(np.less_equal(oscale, 0.0)):
3343-
raise ValueError("scale <= 0")
3334+
if np.any(np.signbit(oscale)):
3335+
raise ValueError("scale < 0")
33443336
return cont2_array(self.internal_state, rk_laplace, size, oloc, oscale,
33453337
self.lock)
33463338

@@ -3465,14 +3457,13 @@ cdef class RandomState:
34653457
if oloc.shape == oscale.shape == ():
34663458
floc = PyFloat_AsDouble(loc)
34673459
fscale = PyFloat_AsDouble(scale)
3468-
3469-
if fscale <= 0:
3470-
raise ValueError("scale <= 0")
3460+
if np.signbit(fscale):
3461+
raise ValueError("scale < 0")
34713462
return cont2_array_sc(self.internal_state, rk_gumbel, size, floc,
34723463
fscale, self.lock)
34733464

3474-
if np.any(np.less_equal(oscale, 0.0)):
3475-
raise ValueError("scale <= 0")
3465+
if np.any(np.signbit(oscale)):
3466+
raise ValueError("scale < 0")
34763467
return cont2_array(self.internal_state, rk_gumbel, size, oloc, oscale,
34773468
self.lock)
34783469

@@ -3559,14 +3550,13 @@ cdef class RandomState:
35593550
if oloc.shape == oscale.shape == ():
35603551
floc = PyFloat_AsDouble(loc)
35613552
fscale = PyFloat_AsDouble(scale)
3562-
3563-
if fscale <= 0:
3564-
raise ValueError("scale <= 0")
3553+
if np.signbit(fscale):
3554+
raise ValueError("scale < 0")
35653555
return cont2_array_sc(self.internal_state, rk_logistic, size, floc,
35663556
fscale, self.lock)
35673557

3568-
if np.any(np.less_equal(oscale, 0.0)):
3569-
raise ValueError("scale <= 0")
3558+
if np.any(np.signbit(oscale)):
3559+
raise ValueError("scale < 0")
35703560
return cont2_array(self.internal_state, rk_logistic, size, oloc,
35713561
oscale, self.lock)
35723562

@@ -3684,14 +3674,13 @@ cdef class RandomState:
36843674
if omean.shape == osigma.shape == ():
36853675
fmean = PyFloat_AsDouble(mean)
36863676
fsigma = PyFloat_AsDouble(sigma)
3687-
3688-
if fsigma <= 0:
3689-
raise ValueError("sigma <= 0")
3677+
if np.signbit(fsigma):
3678+
raise ValueError("sigma < 0")
36903679
return cont2_array_sc(self.internal_state, rk_lognormal, size,
36913680
fmean, fsigma, self.lock)
36923681

3693-
if np.any(np.less_equal(osigma, 0.0)):
3694-
raise ValueError("sigma <= 0.0")
3682+
if np.any(np.signbit(osigma)):
3683+
raise ValueError("sigma < 0.0")
36953684
return cont2_array(self.internal_state, rk_lognormal, size, omean,
36963685
osigma, self.lock)
36973686

@@ -3764,14 +3753,13 @@ cdef class RandomState:
37643753

37653754
if oscale.shape == ():
37663755
fscale = PyFloat_AsDouble(scale)
3767-
3768-
if fscale <= 0:
3769-
raise ValueError("scale <= 0")
3756+
if np.signbit(fscale):
3757+
raise ValueError("scale < 0")
37703758
return cont1_array_sc(self.internal_state, rk_rayleigh, size,
37713759
fscale, self.lock)
37723760

3773-
if np.any(np.less_equal(oscale, 0.0)):
3774-
raise ValueError("scale <= 0.0")
3761+
if np.any(np.signbit(oscale)):
3762+
raise ValueError("scale < 0.0")
37753763
return cont1_array(self.internal_state, rk_rayleigh, size, oscale,
37763764
self.lock)
37773765

numpy/random/tests/test_random.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,10 @@ def test_exponential(self):
485485
[0.68717433461363442, 1.69175666993575979]])
486486
assert_array_almost_equal(actual, desired, decimal=15)
487487

488+
def test_exponential_0(self):
489+
assert_equal(np.random.exponential(scale=0), 0)
490+
assert_raises(ValueError, np.random.exponential, scale=-0.)
491+
488492
def test_f(self):
489493
np.random.seed(self.seed)
490494
actual = np.random.f(12, 77, size=(3, 2))
@@ -501,6 +505,10 @@ def test_gamma(self):
501505
[31.71863275789960568, 33.30143302795922011]])
502506
assert_array_almost_equal(actual, desired, decimal=14)
503507

508+
def test_gamma_0(self):
509+
assert_equal(np.random.gamma(shape=0, scale=0), 0)
510+
assert_raises(ValueError, np.random.gamma, shape=-0., scale=-0.)
511+
504512
def test_geometric(self):
505513
np.random.seed(self.seed)
506514
actual = np.random.geometric(.123456789, size=(3, 2))
@@ -517,6 +525,10 @@ def test_gumbel(self):
517525
[1.10651090478803416, -0.69535848626236174]])
518526
assert_array_almost_equal(actual, desired, decimal=15)
519527

528+
def test_gumbel_0(self):
529+
assert_equal(np.random.gumbel(scale=0), 0)
530+
assert_raises(ValueError, np.random.gumbel, scale=-0.)
531+
520532
def test_hypergeometric(self):
521533
np.random.seed(self.seed)
522534
actual = np.random.hypergeometric(10.1, 5.5, 14, size=(3, 2))
@@ -551,6 +563,10 @@ def test_laplace(self):
551563
[-0.05391065675859356, 1.74901336242837324]])
552564
assert_array_almost_equal(actual, desired, decimal=15)
553565

566+
def test_laplace_0(self):
567+
assert_equal(np.random.laplace(scale=0), 0)
568+
assert_raises(ValueError, np.random.laplace, scale=-0.)
569+
554570
def test_logistic(self):
555571
np.random.seed(self.seed)
556572
actual = np.random.logistic(loc=.123456789, scale=2.0, size=(3, 2))
@@ -559,6 +575,10 @@ def test_logistic(self):
559575
[-0.21682183359214885, 2.63373365386060332]])
560576
assert_array_almost_equal(actual, desired, decimal=15)
561577

578+
def test_laplace_0(self):
579+
assert_(np.random.laplace(scale=0) in [0, 1])
580+
assert_raises(ValueError, np.random.laplace, scale=-0.)
581+
562582
def test_lognormal(self):
563583
np.random.seed(self.seed)
564584
actual = np.random.lognormal(mean=.123456789, sigma=2.0, size=(3, 2))
@@ -567,6 +587,10 @@ def test_lognormal(self):
567587
[65.72798501792723869, 86.84341601437161273]])
568588
assert_array_almost_equal(actual, desired, decimal=13)
569589

590+
def test_lognormal_0(self):
591+
assert_equal(np.random.lognormal(sigma=0), 1)
592+
assert_raises(ValueError, np.random.lognormal, sigma=-0.)
593+
570594
def test_logseries(self):
571595
np.random.seed(self.seed)
572596
actual = np.random.logseries(p=.923456789, size=(3, 2))
@@ -657,6 +681,10 @@ def test_normal(self):
657681
[4.18552478636557357, 4.46410668111310471]])
658682
assert_array_almost_equal(actual, desired, decimal=15)
659683

684+
def test_normal_0(self):
685+
assert_equal(np.random.normal(scale=0), 0)
686+
assert_raises(ValueError, np.random.normal, scale=-0.)
687+
660688
def test_pareto(self):
661689
np.random.seed(self.seed)
662690
actual = np.random.pareto(a=.123456789, size=(3, 2))
@@ -704,6 +732,10 @@ def test_rayleigh(self):
704732
[11.06066537006854311, 17.35468505778271009]])
705733
assert_array_almost_equal(actual, desired, decimal=14)
706734

735+
def test_rayleigh_0(self):
736+
assert_equal(np.random.rayleigh(scale=0), 0)
737+
assert_raises(ValueError, np.random.rayleigh, scale=-0.)
738+
707739
def test_standard_cauchy(self):
708740
np.random.seed(self.seed)
709741
actual = np.random.standard_cauchy(size=(3, 2))
@@ -728,6 +760,10 @@ def test_standard_gamma(self):
728760
[7.54838614231317084, 8.012756093271868]])
729761
assert_array_almost_equal(actual, desired, decimal=14)
730762

763+
def test_standard_gamma_0(self):
764+
assert_equal(np.random.standard_gamma(shape=0), 0)
765+
assert_raises(ValueError, np.random.standard_gamma, shape=-0.)
766+
731767
def test_standard_normal(self):
732768
np.random.seed(self.seed)
733769
actual = np.random.standard_normal(size=(3, 2))
@@ -803,6 +839,10 @@ def test_weibull(self):
803839
[0.67057783752390987, 1.39494046635066793]])
804840
assert_array_almost_equal(actual, desired, decimal=15)
805841

842+
def test_weibull_0(self):
843+
assert_equal(np.random.weibull(a=0), 0)
844+
assert_raises(ValueError, np.random.weibull, a=-0.)
845+
806846
def test_zipf(self):
807847
np.random.seed(self.seed)
808848
actual = np.random.zipf(a=1.23, size=(3, 2))

0 commit comments

Comments
 (0)
0