diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 37616dd38..f96e5b45e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,6 +73,7 @@ jobs: - uses: mamba-org/setup-micromamba@v1 with: environment-file: conda-envs/windows-environment-test.yml + micromamba-version: "1.5.10-0" # Until https://github.com/mamba-org/mamba/issues/3467 is not fixed create-args: >- python=${{matrix.python-version}} environment-name: pymc-experimental-test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78c4a366b..66956b23b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +ci: + autofix_prs: false + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index d519ff3db..77dd3b22e 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -13,7 +13,8 @@ # limitations under the License. import logging -from pymc_experimental import distributions, gp, statespace, utils +from pymc_experimental import gp, statespace, utils +from pymc_experimental.distributions import * from pymc_experimental.inference.fit import fit from pymc_experimental.model.marginal_model import MarginalModel from pymc_experimental.model.model_api import as_model @@ -26,15 +27,3 @@ if len(_log.handlers) == 0: handler = logging.StreamHandler() _log.addHandler(handler) - - -__all__ = [ - "distributions", - "gp", - "statespace", - "utils", - "fit", - "MarginalModel", - "as_model", - "__version__", -] diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py index 0e8659915..d4cd94356 100644 --- a/pymc_experimental/distributions/timeseries.py +++ b/pymc_experimental/distributions/timeseries.py @@ -20,7 +20,12 @@ from pymc.logprob.abstract import _logprob from pymc.logprob.basic import logp from pymc.pytensorf import constant_fold, intX -from pymc.util import check_dist_not_registered +from pymc.step_methods import STEP_METHODS +from pymc.step_methods.arraystep import ArrayStep +from pymc.step_methods.compound import Competence +from pymc.step_methods.metropolis import CategoricalGibbsMetropolis +from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars +from pytensor import Mode from pytensor.graph.basic import Node from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable @@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution): Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P, 3 in this case. - >>> with pm.Model() as markov_chain: - >>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,)) - >>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3)) - >>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,)) + .. code-block:: python + + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as markov_chain: + P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,)) + init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3)) + markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,)) """ @@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs): "P must sum to 1 along the last axis, " "First dimension of init_dist must be n_lags", ) + + +class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis): + name = "discrete_markov_chain_gibbs_metropolis" + + def __init__(self, vars, proposal="uniform", order="random", model=None): + model = pm.modelcontext(model) + vars = get_value_vars_from_user_vars(vars, model) + initial_point = model.initial_point() + + dimcats = [] + # The above variable is a list of pairs (aggregate dimension, number + # of categories). For example, if vars = [x, y] with x being a 2-D + # variable with M categories and y being a 3-D variable with N + # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)]. + for v in vars: + v_init_val = initial_point[v.name] + rv_var = model.values_to_rvs[v] + rv_op = rv_var.owner.op + + if not isinstance(rv_op, DiscreteMarkovChainRV): + raise TypeError("All variables must be DiscreteMarkovChainRV") + + k_graph = rv_var.owner.inputs[0].shape[-1] + (k_graph,) = model.replace_rvs_by_values((k_graph,)) + k = model.compile_fn( + k_graph, + inputs=model.value_vars, + on_unused_input="ignore", + mode=Mode(linker="py", optimizer=None), + )(initial_point) + start = len(dimcats) + dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)] + + if order == "random": + self.shuffle_dims = True + self.dimcats = dimcats + else: + if sorted(order) != list(range(len(dimcats))): + raise ValueError("Argument 'order' has to be a permutation") + self.shuffle_dims = False + self.dimcats = [dimcats[j] for j in order] + + if proposal == "uniform": + self.astep = self.astep_unif + elif proposal == "proportional": + # Use the optimized "Metropolized Gibbs Sampler" described in Liu96. + self.astep = self.astep_prop + else: + raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") + + # Doesn't actually tune, but it's required to emit a sampler stat + # that indicates whether a draw was done in a tuning phase. + self.tune = True + + # We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic + ArrayStep.__init__(self, vars, [model.compile_logp()]) + + @staticmethod + def competence(var): + if isinstance(var.owner.op, DiscreteMarkovChainRV): + return Competence.IDEAL + return Competence.INCOMPATIBLE + + +STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis) diff --git a/pymc_experimental/version.txt b/pymc_experimental/version.txt index d917d3e26..b1e80bb24 100644 --- a/pymc_experimental/version.txt +++ b/pymc_experimental/version.txt @@ -1 +1 @@ -0.1.2 +0.1.3 diff --git a/pyproject.toml b/pyproject.toml index 1fb70104c..d9bee4d68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,3 +76,4 @@ lines-between-types = 1 'F401', # Unused import warning for test files -- this check removes imports of fixtures 'F811' # Redefine while unused -- this check fails on imported fixtures ] +'pymc_experimental/__init__.py' = ['F401', 'F403'] diff --git a/tests/distributions/test_discrete_markov_chain.py b/tests/distributions/test_discrete_markov_chain.py index b2b1d9796..0d855ef44 100644 --- a/tests/distributions/test_discrete_markov_chain.py +++ b/tests/distributions/test_discrete_markov_chain.py @@ -5,10 +5,15 @@ import pytensor.tensor as pt import pytest +from pymc.distributions import Categorical from pymc.distributions.shape_utils import change_dist_size from pymc.logprob.utils import ParameterValueError +from pymc.sampling.mcmc import assign_step_methods -from pymc_experimental.distributions.timeseries import DiscreteMarkovChain +from pymc_experimental.distributions.timeseries import ( + DiscreteMarkovChain, + DiscreteMarkovChainGibbsMetropolis, +) def transition_probability_tests(steps, n_states, n_lags, n_draws, atol): @@ -216,3 +221,36 @@ def test_change_size_univariate(self): new_rw = change_dist_size(chain, new_size=(4, 3), expand=True) assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5) + + def test_mcmc_sampling(self): + with pm.Model(coords={"step": range(100)}) as model: + init_dist = Categorical.dist(p=[0.5, 0.5]) + DiscreteMarkovChain( + "markov_chain", + P=[[0.1, 0.9], [0.1, 0.9]], + init_dist=init_dist, + shape=(100,), + dims="step", + ) + + step_method = assign_step_methods(model) + assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis) + + # Sampler needs no tuning + idata = pm.sample( + tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False + ) + + np.testing.assert_allclose( + idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")), + 0.5, + atol=0.05, + ) + + np.testing.assert_allclose( + idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")), + 0.9, + atol=0.05, + ) + + assert pm.stats.ess(idata, method="tail").min() > 950