-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
ENH: scipy.stats: add multivariate hypergeometric distribution #12839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
I also tested the time taken by rvs method implemented vs numpy's generator: speed test
NumPy's generator is more than twice as fast! Some more tests
|
I wouldn't worry about the speed too much. IMO speeding up the RVS does not need to be in scope. You're already using the faster NumPy alternatives when the provided |
Yes. When the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tirthasheshpatel one (minor) comment in two places, how the version check for numpy Generator
is performed should be pep-440 compliant.
assert_allclose(rvs.mean(0), rv.mean(), rtol=1e-2) | ||
|
||
def test_rvs_numpy(self): | ||
if np.__version__ < '1.18': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment I made before also applies here about the pep-440 comparison and there is a skipif annotation if you want to make it a little cleaner and consistent with how we've handled this in other PRs IIRC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if I'm seeing a stale change but the test here checks numpy's version for generator support?
@tirthasheshpatel, there is a problem with the
A basic example, which uses
Now try with
I'm working on a fix for this particular problem, and for handling We currently don't have a written specification for exactly how the multivariate distribution are supposed to handle broadcasting, nor how broadcasting of the parameters interacts with the |
Done!
I have eliminated all except one which is necessary to handle a case in the previous comment. When the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tirthasheshpatel, I made a bunch more suggestions, mostly cosmetic, in-line. Sorry I missed these on my first pass through the code.
There are few more significant issues to be discussed.
- This example breaks because
M-1
is 0; the result ofcov
in this case should be an array of zeros with shape (3, 3)):
In [189]: multivariate_hypergeom.cov([1, 0, 0], 1)
/Users/warren/mc37scipymaster/lib/python3.7/site-packages/scipy-1.6.0.dev0+21d6fce-py3.7-macosx-10.9-x86_64.egg/scipy/stats/_multivariate.py:4620: RuntimeWarning: invalid value encountered in true_divide
output = (-n * (M-n)/(M-1) / (M**2) *
/Users/warren/mc37scipymaster/lib/python3.7/site-packages/scipy-1.6.0.dev0+21d6fce-py3.7-macosx-10.9-x86_64.egg/scipy/stats/_multivariate.py:4627: RuntimeWarning: invalid value encountered in long_scalars
m[..., i]*(M-m[..., i]) / (M**2))
Out[189]:
array([[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]])
- It would be nice if these edges cases were handled:
multivariate_hypergeom.mean([0, 0, 0], 0)
should returnnp.array([0.0, 0.0, 0.0])
, andmultivariate_hypergeometric.cov([0, 0, 0], 0)
should return an array with shape (3, 3) containing all zeros. (The PMF is 1 at [0, 0, 0] and 0 everywhere else, so the expected value of the distribution is [0, 0, 0].)multivariate_hypergeom.mean([], 0)
should returnnp.array([], dtype=np.float64)
, andmultivariate_hypergeom.cov([], 0)
should returnnp.array([], shape=(0, 0), dtype=np.float64)
-
I mentioned this in an in-line comment: I don't think we should be casting the inputs from whatever is given to integers. An input such as
multivariate_hypergeom.mean([1.2, 3.4, 5.6], 7.8)
should be rejected with an error (perhaps something likeTypeError: input must be integers
). I include the quantiles in this, somultivariate_hypergeom.pmf([2.5, 3.1], [3, 5], 5)
should also raise aTypeError
. Currently the code checks for non-integer quantiles and returns 0 for the PMF; instead, I think we can interpret the distribution as one whose domain is the integers (it is a discrete distribution, after all), and simply not accept floating point values. -
Broadcasting in
logpmf
,pmf
,mean
,var
andcov
works nicely! But there are some issues in thervs
method.
- The NumPy sampler for the multivariate hypergeometric distribution does not handle broadcasting. So there is an inconsistency in whether or not
multivariate_hypergeom.rvs
handles broadcasting, depending on the NumPy version and the parameterrandom_state
. - Reviewing this issue led to looking into the state of broadcasting for the multivariate distributions in NumPy. There is currently a pull request to add it to NumPy's
multinomial
sampler (numpy/numpy#16740). In numpy/numpy#17669, I started a discussion on how broadcasting and thesize
parameter interact for the multivariate distributions. To be consistent with the rules that I outlined there,size=None
is not the same assize=1
. Instead,size=None
means the size is determined by the broadcast of the distribution parameters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@WarrenWeckesser A lot of hefty changes below! I have added code to handle all the edge cases and I also added checks for integer x
, m
, and n
.
First of all, thank you so much for such a thorough look at this PR. I really missed all these edge cases and I am happy they are covered now.
I have tried my best to handle all the edge cases as concisely as possible. Below are some changes and their justifications. One big comment I had was whether or not it is OK to use masked_array
s to handle edge cases. I am not very experienced with masked_array
s and don't know the consequences of using them, but they do seem to resolve all the edge cases. Sorry, if I have introduced some breaking code below while trying to resolve the edge cases 😅
x = np.asarray(x) | ||
if not np.issubdtype(x.dtype, np.integer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment about only allowing integers has been addressed here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's great progress!
I made several minor suggestions inline, but there are a few bigger items still be be addressed:
-
These (very) edge cases don't work:
multivariate_hypergeom.mean([], 0) multivariate_hypergeom.cov([], 0)
(See my comment back in #12839 (review)) The problem is that when
np.asarray
is given an empty list, it usefloat64
as the default dtype. You could handle this in_process_parameters
by doing something like this:m = np.asarray(m) if m.size == 0: m = m.astype(int)
And do the same for
n
. Note, however, that I haven't checked if the rest of the code will "just work" with these emtpy arrays and return the expected results. If it turns out that it will require a lot of special code to handle these cases, we could defer handling it for now, and just raise an error stating that empty inputs are not allowed. -
Also noted in #12839 (review), the NumPy distribution sampler that is available in NumPy 1.18 or later does not handle broadcasting. So if the user passes parameters to the
rvs()
method that require broadcasting, we can't use the NumPy sampler. It might be simpler to just not attempt to use the NumPy sampler at all. Always use the version that you implemented here, and add a comment that we can start using the NumPy version when it is updated to handle broadcasting. -
The comment I made in item 4 of #12839 (review) still applies: if
size
isNone
, the size is not to be treated assize=1
. Instead, the size is determined by the broadcast shape of the input parameters. So an input such asmultivariate_hypergeom.rvs([2, 3, 5], [1, 2, 3, 4])
should act like
multivariate_hypergeom.rvs([2, 3, 5], [1, 2, 3, 4], size=(4,))
and generate a result with shape (4, 3), andmultivariate_hypergeom.rvs([[5, 10, 20], [10, 10, 15]], [[2], [3], [4], [5]])
(where the first argument has shape (2, 3) and the second has shape (4, 1)) should act like
multivariate_hypergeom.rvs([[5, 10, 20], [10, 10, 15]], [[2], [3], [4], [5]], size=(4, 2))
and generate a result with shape (4, 2, 3).
Thanks, @WarrenWeckesser for the explanation of the behavior when |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a few more tiny changes. I'll commit them as a batch, and let the tests run once more.
Whitespace tweaks and remove imports of `NumpyVersion`
I'm not sure what happened with appveyor and travisci when I pushed a batch of small changes. Closing and reopening to try again. |
@rlucas7, I think the change that you requested was related to the use of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good w/the skipif test on rvs, thanks, marking approved.
@WarrenWeckesser or @mdhaber is this ready to squash-and-merge then?
EDIT: It seems ready for merge, if no more comments (or someone else merges beforehand), I'll merge on Sunday (Nov 15th, in the morning EST).
@rlucas7 I haven't reviewed recently, but what i saw before was good, so with +2 LGTM! When you merge, please modify the line about
|
Thanks for the ping @mdhaber squash and merged-glad to have this in 1.6.0. |
Glad I could help. This was a lot of fun to do. Thanks, @mdhaber, @rlucas7, and @WarrenWeckesser for all the reviews and help! |
Reference issue
Closes #12585
What does this implement/fix?
This PR implements the multivariate hypergeometric distribution in scipy.stats module!
@mdhaber