-
-
Notifications
You must be signed in to change notification settings - Fork 11k
BUG, API: np.random.multivariate_normal behavior with bad covariance matrix #5726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
fc57915
BUG, API: Allow covariance matrix with small fp errors.
cowlicks f555826
TST: test multivariate_normal check_valid kw
cowlicks 4c93e28
Merge remote-tracking branch 'numpy-org/master' into mult-norm
fde2617
fixed merged test
6d7f14f
Documentation fix and proper handling of tolerance
c85d199
single too argument + mention in release docs.
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4355,9 +4355,10 @@ cdef class RandomState: | |
self.lock) | ||
|
||
# Multivariate distributions: | ||
def multivariate_normal(self, mean, cov, size=None): | ||
def multivariate_normal(self, mean, cov, size=None, check_valid='warn', | ||
tol=1e-8): | ||
""" | ||
multivariate_normal(mean, cov[, size]) | ||
multivariate_normal(mean, cov[, size, check_valid, tol]) | ||
|
||
Draw random samples from a multivariate normal distribution. | ||
|
||
|
@@ -4380,6 +4381,10 @@ cdef class RandomState: | |
generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because | ||
each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``. | ||
If no shape is specified, a single (`N`-D) sample is returned. | ||
check_valid : { 'warn', 'raise', 'ignore' }, optional | ||
Behavior when the covariance matrix is not positive semidefinite. | ||
tol : float, optional | ||
Tolerance when checking the singular values in covariance matrix. | ||
|
||
Returns | ||
------- | ||
|
@@ -4464,11 +4469,11 @@ cdef class RandomState: | |
shape = size | ||
|
||
if len(mean.shape) != 1: | ||
raise ValueError("mean must be 1 dimensional") | ||
raise ValueError("mean must be 1 dimensional") | ||
if (len(cov.shape) != 2) or (cov.shape[0] != cov.shape[1]): | ||
raise ValueError("cov must be 2 dimensional and square") | ||
raise ValueError("cov must be 2 dimensional and square") | ||
if mean.shape[0] != cov.shape[0]: | ||
raise ValueError("mean and cov must have same length") | ||
raise ValueError("mean and cov must have same length") | ||
|
||
# Compute shape of output and create a matrix of independent | ||
# standard normally distributed random numbers. The matrix has rows | ||
|
@@ -4491,12 +4496,20 @@ cdef class RandomState: | |
# not zero. We continue to use the SVD rather than Cholesky in | ||
# order to preserve current outputs. Note that symmetry has not | ||
# been checked. | ||
|
||
(u, s, v) = svd(cov) | ||
neg = (np.sum(u.T * v, axis=1) < 0) & (s > 0) | ||
if np.any(neg): | ||
s[neg] = 0. | ||
warnings.warn("covariance is not positive-semidefinite.", | ||
RuntimeWarning) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would be the entire next part under no reason to check if we |
||
if check_valid != 'ignore': | ||
if check_valid != 'warn' and check_valid != 'raise': | ||
raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'") | ||
|
||
psd = np.allclose(np.dot(v.T * s, v), cov, rtol=tol, atol=tol) | ||
if not psd: | ||
if check_valid == 'warn': | ||
warnings.warn("covariance is not positive-semidefinite.", | ||
RuntimeWarning) | ||
else: | ||
raise ValueError("covariance is not positive-semidefinite.") | ||
|
||
x = np.dot(x, np.sqrt(s)[:, None] * v) | ||
x += mean | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a check for
check_valid
being one of the 3 accepted strings?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, move it from line 4511 up here. It's misplaced there; only checked for when
not psd
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that looks good now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that the check is itself later in the code, when checking. However, it may go unchecked when the psd passes, so I have rearranged the logic so it is always checked.