-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
BUG, API: np.random.multivariate_normal behavior with bad covariance matrix #5726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
fc57915
f555826
4c93e28
fde2617
6d7f14f
c85d199
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
In np.random.multivariate_normal allow the covariance matrix to have small floating point errors. And allow control over what to do if the PSD check fails.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4240,7 +4240,8 @@ cdef class RandomState: | |
self.lock) | ||
|
||
# Multivariate distributions: | ||
def multivariate_normal(self, mean, cov, size=None): | ||
def multivariate_normal(self, mean, cov, size=None, check_valid='warn', | ||
tol=1e-8): | ||
""" | ||
multivariate_normal(mean, cov[, size]) | ||
|
||
|
@@ -4265,6 +4266,10 @@ cdef class RandomState: | |
generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because | ||
each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``. | ||
If no shape is specified, a single (`N`-D) sample is returned. | ||
check_valid : 'warn', 'raise', 'ignore' | ||
Behavior when the covariance matrix is not Positive Semi-definite. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style nitpick: use lower case for positive semidefinite |
||
tol : float | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
Tolerance of the singular values in covariance matrix. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. absolute or relative tolerance? |
||
|
||
Returns | ||
------- | ||
|
@@ -4349,11 +4354,11 @@ cdef class RandomState: | |
shape = size | ||
|
||
if len(mean.shape) != 1: | ||
raise ValueError("mean must be 1 dimensional") | ||
raise ValueError("mean must be 1 dimensional") | ||
if (len(cov.shape) != 2) or (cov.shape[0] != cov.shape[1]): | ||
raise ValueError("cov must be 2 dimensional and square") | ||
raise ValueError("cov must be 2 dimensional and square") | ||
if mean.shape[0] != cov.shape[0]: | ||
raise ValueError("mean and cov must have same length") | ||
raise ValueError("mean and cov must have same length") | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually, move it from line 4511 up here. It's misplaced there; only checked for when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that looks good now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note that the check is itself later in the code, when checking. However, it may go unchecked when the psd passes, so I have rearranged the logic so it is always checked. |
||
# Compute shape of output and create a matrix of independent | ||
# standard normally distributed random numbers. The matrix has rows | ||
|
@@ -4376,12 +4381,19 @@ cdef class RandomState: | |
# not zero. We continue to use the SVD rather than Cholesky in | ||
# order to preserve current outputs. Note that symmetry has not | ||
# been checked. | ||
|
||
(u, s, v) = svd(cov) | ||
neg = (np.sum(u.T * v, axis=1) < 0) & (s > 0) | ||
if np.any(neg): | ||
s[neg] = 0. | ||
warnings.warn("covariance is not positive-semidefinite.", | ||
RuntimeWarning) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would be the entire next part under no reason to check if we |
||
if check_valid != 'ignore': | ||
psd = np.allclose(np.dot(v.T * s, v), cov) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test with non-default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realize. I wonder if tol should be changed to atol and rtol and just forward them in here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You'll have to specify both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry my comment wasn't a reply to yours, non-updating interface.
|
||
if not psd: | ||
if check_valid == 'warn': | ||
warnings.warn("covariance is not positive-semidefinite.", | ||
RuntimeWarning) | ||
elif check_valid == 'raise': | ||
raise ValueError("covariance is not positive-semidefinite.") | ||
else: | ||
raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'") | ||
|
||
x = np.dot(x, np.sqrt(s)[:, None] * v) | ||
x += mean | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
formatting: change to
{'warn', 'raise', 'ignore'}, optional