From 9f0f76d0b630e9808d25bb414b34116888bf134a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 15 Sep 2024 11:13:43 +0200 Subject: [PATCH 1/5] Do not autofix PRs --- .pre-commit-config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78c4a366b..66956b23b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +ci: + autofix_prs: false + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 From c0c493142a9fb5120929cbff691e58ce35dafb00 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 11 Jul 2024 13:32:47 +0200 Subject: [PATCH 2/5] Expose distributions at root level of pymc_experimental Also `marginalize` --- pymc_experimental/__init__.py | 15 ++------------- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index d519ff3db..77dd3b22e 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -13,7 +13,8 @@ # limitations under the License. import logging -from pymc_experimental import distributions, gp, statespace, utils +from pymc_experimental import gp, statespace, utils +from pymc_experimental.distributions import * from pymc_experimental.inference.fit import fit from pymc_experimental.model.marginal_model import MarginalModel from pymc_experimental.model.model_api import as_model @@ -26,15 +27,3 @@ if len(_log.handlers) == 0: handler = logging.StreamHandler() _log.addHandler(handler) - - -__all__ = [ - "distributions", - "gp", - "statespace", - "utils", - "fit", - "MarginalModel", - "as_model", - "__version__", -] diff --git a/pyproject.toml b/pyproject.toml index 1fb70104c..d9bee4d68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,3 +76,4 @@ lines-between-types = 1 'F401', # Unused import warning for test files -- this check removes imports of fixtures 'F811' # Redefine while unused -- this check fails on imported fixtures ] +'pymc_experimental/__init__.py' = ['F401', 'F403'] From f61e161754ff5b9ce9751ba51079340fced0ad76 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 11 Jul 2024 13:33:13 +0200 Subject: [PATCH 3/5] Implement step method sampler for DiscreteMarkovChain --- pymc_experimental/distributions/timeseries.py | 86 +++++++++++++++++-- .../test_discrete_markov_chain.py | 40 ++++++++- 2 files changed, 120 insertions(+), 6 deletions(-) diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py index 0e8659915..d4cd94356 100644 --- a/pymc_experimental/distributions/timeseries.py +++ b/pymc_experimental/distributions/timeseries.py @@ -20,7 +20,12 @@ from pymc.logprob.abstract import _logprob from pymc.logprob.basic import logp from pymc.pytensorf import constant_fold, intX -from pymc.util import check_dist_not_registered +from pymc.step_methods import STEP_METHODS +from pymc.step_methods.arraystep import ArrayStep +from pymc.step_methods.compound import Competence +from pymc.step_methods.metropolis import CategoricalGibbsMetropolis +from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars +from pytensor import Mode from pytensor.graph.basic import Node from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable @@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution): Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P, 3 in this case. - >>> with pm.Model() as markov_chain: - >>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,)) - >>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3)) - >>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,)) + .. code-block:: python + + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as markov_chain: + P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,)) + init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3)) + markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,)) """ @@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs): "P must sum to 1 along the last axis, " "First dimension of init_dist must be n_lags", ) + + +class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis): + name = "discrete_markov_chain_gibbs_metropolis" + + def __init__(self, vars, proposal="uniform", order="random", model=None): + model = pm.modelcontext(model) + vars = get_value_vars_from_user_vars(vars, model) + initial_point = model.initial_point() + + dimcats = [] + # The above variable is a list of pairs (aggregate dimension, number + # of categories). For example, if vars = [x, y] with x being a 2-D + # variable with M categories and y being a 3-D variable with N + # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)]. + for v in vars: + v_init_val = initial_point[v.name] + rv_var = model.values_to_rvs[v] + rv_op = rv_var.owner.op + + if not isinstance(rv_op, DiscreteMarkovChainRV): + raise TypeError("All variables must be DiscreteMarkovChainRV") + + k_graph = rv_var.owner.inputs[0].shape[-1] + (k_graph,) = model.replace_rvs_by_values((k_graph,)) + k = model.compile_fn( + k_graph, + inputs=model.value_vars, + on_unused_input="ignore", + mode=Mode(linker="py", optimizer=None), + )(initial_point) + start = len(dimcats) + dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)] + + if order == "random": + self.shuffle_dims = True + self.dimcats = dimcats + else: + if sorted(order) != list(range(len(dimcats))): + raise ValueError("Argument 'order' has to be a permutation") + self.shuffle_dims = False + self.dimcats = [dimcats[j] for j in order] + + if proposal == "uniform": + self.astep = self.astep_unif + elif proposal == "proportional": + # Use the optimized "Metropolized Gibbs Sampler" described in Liu96. + self.astep = self.astep_prop + else: + raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") + + # Doesn't actually tune, but it's required to emit a sampler stat + # that indicates whether a draw was done in a tuning phase. + self.tune = True + + # We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic + ArrayStep.__init__(self, vars, [model.compile_logp()]) + + @staticmethod + def competence(var): + if isinstance(var.owner.op, DiscreteMarkovChainRV): + return Competence.IDEAL + return Competence.INCOMPATIBLE + + +STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis) diff --git a/tests/distributions/test_discrete_markov_chain.py b/tests/distributions/test_discrete_markov_chain.py index b2b1d9796..0d855ef44 100644 --- a/tests/distributions/test_discrete_markov_chain.py +++ b/tests/distributions/test_discrete_markov_chain.py @@ -5,10 +5,15 @@ import pytensor.tensor as pt import pytest +from pymc.distributions import Categorical from pymc.distributions.shape_utils import change_dist_size from pymc.logprob.utils import ParameterValueError +from pymc.sampling.mcmc import assign_step_methods -from pymc_experimental.distributions.timeseries import DiscreteMarkovChain +from pymc_experimental.distributions.timeseries import ( + DiscreteMarkovChain, + DiscreteMarkovChainGibbsMetropolis, +) def transition_probability_tests(steps, n_states, n_lags, n_draws, atol): @@ -216,3 +221,36 @@ def test_change_size_univariate(self): new_rw = change_dist_size(chain, new_size=(4, 3), expand=True) assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5) + + def test_mcmc_sampling(self): + with pm.Model(coords={"step": range(100)}) as model: + init_dist = Categorical.dist(p=[0.5, 0.5]) + DiscreteMarkovChain( + "markov_chain", + P=[[0.1, 0.9], [0.1, 0.9]], + init_dist=init_dist, + shape=(100,), + dims="step", + ) + + step_method = assign_step_methods(model) + assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis) + + # Sampler needs no tuning + idata = pm.sample( + tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False + ) + + np.testing.assert_allclose( + idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")), + 0.5, + atol=0.05, + ) + + np.testing.assert_allclose( + idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")), + 0.9, + atol=0.05, + ) + + assert pm.stats.ess(idata, method="tail").min() > 950 From 57e3e5dd95dd4af26527a43a5da7318c742aea90 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 4 Oct 2024 09:19:26 +0200 Subject: [PATCH 4/5] Pin micromamba < 2.0 in Windows job --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 37616dd38..f96e5b45e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,6 +73,7 @@ jobs: - uses: mamba-org/setup-micromamba@v1 with: environment-file: conda-envs/windows-environment-test.yml + micromamba-version: "1.5.10-0" # Until https://github.com/mamba-org/mamba/issues/3467 is not fixed create-args: >- python=${{matrix.python-version}} environment-name: pymc-experimental-test From db36290f971ed73ed694ac16b74eaf6fb8419c79 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Date: Thu, 3 Oct 2024 12:09:13 +0200 Subject: [PATCH 5/5] Update version.txt --- pymc_experimental/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/version.txt b/pymc_experimental/version.txt index d917d3e26..b1e80bb24 100644 --- a/pymc_experimental/version.txt +++ b/pymc_experimental/version.txt @@ -1 +1 @@ -0.1.2 +0.1.3