10000 Expose multivariate normal `method` argument in post-estimation tasks by jessegrabowski · Pull Request #484 · pymc-devs/pymc-extras · GitHub
[go: up one dir, main page]

Skip to content

Expose multivariate normal method argument in post-estimation tasks #484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Expose method argument to MvNormals used in statespace distribution…
…s when doing post-estimation tasks
  • Loading branch information
jessegrabowski committed May 28, 2025
commit 4d7dabae5ecab097afff9a7e2bc8b1af93a1f4b0
64 changes: 57 additions & 7 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,7 @@ def _sample_conditional(
group: str,
random_seed: RandomState | None = None,
data: pt.TensorLike | None = None,
method: str = "svd",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to type hint the options and/or mention them on the docstrings? Potentially mention the reason for choosing some over the other?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed this without having read this comment, so yes :)

**kwargs,
):
"""
Expand All @@ -1130,6 +1131,11 @@ def _sample_conditional(
Observed data on which to condition the model. If not provided, the function will use the data that was
provided when the model was built.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand Down Expand Up @@ -1181,6 +1187,7 @@ def _sample_conditional(
covs=cov,
logp=dummy_ll,
dims=state_dims,
method=method,
)

obs_mu = (Z @ mu[..., None]).squeeze(-1)
Expand All @@ -1192,6 +1199,7 @@ def _sample_conditional(
covs=obs_cov,
logp=dummy_ll,
dims=obs_dims,
method=method,
)

# TODO: Remove this after pm.Flat initial values are fixed
Expand Down Expand Up @@ -1222,6 +1230,7 @@ def _sample_unconditional(
steps: int | None = None,
use_data_time_dim: bool = False,
random_seed: RandomState | None = None,
method: str = "svd",
**kwargs,
):
"""
Expand Down Expand Up @@ -1251,6 +1260,11 @@ def _sample_unconditional(
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand Down Expand Up @@ -1309,6 +1323,7 @@ def _sample_unconditional(
steps=steps,
dims=dims,
mode=self._fit_mode,
method=method,
sequence_names=self.kalman_filter.seq_names,
k_endog=self.k_endog,
)
Expand All @@ -1331,7 +1346,7 @@ def _sample_unconditional(
return idata_unconditional.posterior_predictive

def sample_conditional_prior(
self, idata: InferenceData, random_seed: RandomState | None = None, **kwargs
self, idata: InferenceData, random_seed: RandomState | None = None, method="svd", **kwargs
) -> InferenceData:
"""
Sample from the conditional prior; that is, given parameter draws from the prior distribution,
Expand All @@ -1347,6 +1362,11 @@ def sample_conditional_prior(
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand All @@ -1358,10 +1378,10 @@ def sample_conditional_prior(
"predicted_prior", and "smoothed_prior".
"""

return self._sample_conditional(idata, "prior", random_seed, **kwargs)
return self._sample_conditional(idata, "prior", random_seed, method, **kwargs)

def sample_conditional_posterior(
self, idata: InferenceData, random_seed: RandomState | None = None, **kwargs
self, idata: InferenceData, random_seed: RandomState | None = None, method="svd", **kwargs
):
"""
Sample from the conditional posterior; that is, given parameter draws from the posterior distribution,
Expand All @@ -1376,6 +1396,11 @@ def sample_conditional_posterior(
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand All @@ -1387,14 +1412,15 @@ def sample_conditional_posterior(
"predicted_posterior", and "smoothed_posterior".
"""

return self._sample_conditional(idata, "posterior", random_seed, **kwargs)
return self._sample_conditional(idata, "posterior", random_seed, method, **kwargs)

def sample_unconditional_prior(
self,
idata: InferenceData,
steps: int | None = None,
use_data_time_dim: bool = False,
random_seed: RandomState | None = None,
method="svd",
**kwargs,
) -> InferenceData:
"""
Expand Down Expand Up @@ -1423,6 +1449,11 @@ def sample_unconditional_prior(
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand All @@ -1439,7 +1470,7 @@ def sample_unconditional_prior(
"""

return self._sample_unconditional(
idata, "prior", steps, use_data_time_dim, random_seed, **kwargs
idata, "prior", steps, use_data_time_dim, random_seed, method, **kwargs
)

def sample_unconditional_posterior(
Expand All @@ -1448,6 +1479,7 @@ def sample_unconditional_posterior(
steps: int | None = None,
use_data_time_dim: bool = False,
random_seed: RandomState | None = None,
method="svd",
**kwargs,
) -> InferenceData:
"""
Expand Down Expand Up @@ -1477,6 +1509,11 @@ def sample_unconditional_posterior(
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

Returns
-------
InferenceData
Expand All @@ -1490,7 +1527,7 @@ def sample_unconditional_posterior(
"""

return self._sample_unconditional(
idata, "posterior", steps, use_data_time_dim, random_seed, **kwargs
idata, "posterior", steps, use_data_time_dim, random_seed, method, **kwargs
)

def sample_statespace_matrices(
Expand Down Expand Up @@ -1933,6 +1970,7 @@ def forecast(
filter_output="smoothed",
random_seed: RandomState | None = None,
verbose: bool = True,
method: str = "svd",
**kwargs,
) -> InferenceData:
"""
Expand Down Expand Up @@ -1989,6 +2027,11 @@ def forecast(
verbose: bool, default=True
Whether to print diagnostic information about forecasting.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand Down Expand Up @@ -2098,6 +2141,7 @@ def forecast(
sequence_names=self.kalman_filter.seq_names,
k_endog=self.k_endog,
append_x0=False,
method=method,
)

forecast_model.rvs_to_initial_values = {
Expand Down Expand Up @@ -2126,6 +2170,7 @@ def impulse_response_function(
shock_trajectory: np.ndarray | None = None,
orthogonalize_shocks: bool = False,
random_seed: RandomState | None = None,
method="svd",
**kwargs,
):
"""
Expand Down Expand Up @@ -2177,6 +2222,11 @@ def impulse_response_function(
random_seed : int, RandomState or Generator, optional
Seed for the random number generator.

method: str
Method used to compute draws from multivariate normal. One of "cholesky", "eig", or "svd". "cholesky" is
fastest, but least robust to ill-conditioned matrices, while "svd" is slow but extremely robust. Default
is "svd".

kwargs:
Additional keyword arguments are passed to pymc.sample_posterior_predictive

Expand Down Expand Up @@ -2236,7 +2286,7 @@ def impulse_response_function(
shock_trajectory = pt.zeros((n_steps, self.k_posdef))
if Q is not None:
init_shock = pm.MvNormal(
"initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
"initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method=method
)
else:
init_shock = pm.Deterministic(
Expand Down
23 changes: 15 additions & 8 deletions pymc_extras/statespace/filters/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __new__(
mode=None,
sequence_names=None,
append_x0=True,
method="svd",
**kwargs,
):
# Ignore dims in support shape because they are just passed along to the "observed" and "latent" distributions
Expand Down Expand Up @@ -100,6 +101,7 @@ def __new__(
mode=mode,
sequence_names=sequence_names,
append_x0=append_x0,
method=method,
**kwargs,
)

Expand All @@ -119,6 +121,7 @@ def dist(
mode=None,
sequence_names=None,
append_x0=True,
method="svd",
**kwargs,
):
steps = get_support_shape_1d(
Expand All @@ -135,6 +138,7 @@ def dist(
mode=mode,
sequence_names=sequence_names,
append_x0=append_x0,
method=method,
**kwargs,
)

Expand All @@ -155,6 +159,7 @@ def rv_op(
mode=None,
sequence_names=None,
append_x0=True,
method="svd",
):
if sequence_names is None:
sequence_names = []
Expand Down Expand Up @@ -205,10 +210,10 @@ def step_fn(*args):
a = state[:k]

middle_rng, a_innovation = pm.MvNormal.dist(
mu=0, cov=Q, rng=rng, method="svd"
mu=0, cov=Q, rng=rng, method=method
).owner.outputs
next_rng, y_innovation = pm.MvNormal.dist(
mu=0, cov=H, rng=middle_rng, method="svd"
mu=0, cov=H, rng=middle_rng, method=method
).owner.outputs

a_mu = c + T @ a
Expand All @@ -224,8 +229,8 @@ def step_fn(*args):
Z_init = Z_ if Z_ in non_sequences else Z_[0]
H_init = H_ if H_ in non_sequences else H_[0]

init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method)
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method)

init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

Expand Down Expand Up @@ -281,6 +286,7 @@ def __new__(
sequence_names=None,
mode=None,
append_x0=True,
method="svd",
**kwargs,
):
dims = kwargs.pop("dims", None)
Expand Down Expand Up @@ -310,6 +316,7 @@ def __new__(
mode=mode,
sequence_names=sequence_names,
append_x0=append_x0,
method=method,
**kwargs,
)
latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None))
Expand Down Expand Up @@ -368,11 +375,11 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)

@classmethod
def dist(cls, mus, covs, logp, **kwargs):
return super().dist([mus, covs, logp], **kwargs)
def dist(cls, mus, covs, logp, method="svd", **kwargs):
return super().dist([mus, covs, logp], method=method, **kwargs)

@classmethod
def rv_op(cls, mus, covs, logp, size=None):
def rv_op(cls, mus, covs, logp, method="svd", size=None):
# Batch dimensions (if any) will be on the far left, but scan requires time to be there instead
if mus.ndim > 2:
mus = pt.moveaxis(mus, -2, 0)
Expand All @@ -385,7 +392,7 @@ def rv_op(cls, mus, covs, logp, size=None):
rng = pytensor.shared(np.random.default_rng())

def step(mu, cov, rng):
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
return mvn, {rng: new_rng}

mvn_seq, updates = pytensor.scan(
Expand Down
0