Model marginalization doesn't work through Minibatch nodes · Issue #492 · pymc-devs/pymc-extras · GitHub

Model marginalization doesn't work through Minibatch nodes #492
@zaxtax

Description

When I use Minibatch in the following model:

import pymc as pm
import numpy as np
from pymc_extras.model.marginal.marginal_model import marginalize

data = np.random.normal(size=10_000)

with pm.Model() as model:
    d = pm.Data("data", data)
    batched_data = pm.Minibatch(d, batch_size=100)
    x = pm.Normal("x", 0., 1.)
    b = pm.Bernoulli("b", 0.5, shape=(100,))
    y = pm.Normal("y", b*x, total_size=len(data), observed=batched_data)

model2 = marginalize(model, [b])

I get the following error:

NotImplementedError                       Traceback (most recent call last)
File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:561, in replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs)
    560 try:
--> 561     dependent_rvs_dim_connections = subgraph_batch_dim_connection(
    562         rv_to_marginalize, dependent_rvs
    563     )
    564 except (ValueError, NotImplementedError) as e:
    565     # For the perspective of the user this is a NotImplementedError

File ~/upstream/pymc-extras/pymc_extras/model/marginal/graph_analysis.py:365, in subgraph_batch_dim_connection(input_var, output_vars)
    364 var_dims = {input_var: tuple(range(input_var.type.ndim))}
--> 365 var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars)
    366 ret = []

File ~/upstream/pymc-extras/pymc_extras/model/marginal/graph_analysis.py:317, in _subgraph_batch_dim_connection(var_dims, input_vars, output_vars)
    316     else:
--> 317         raise NotImplementedError(f"Marginalization through operation {node} not supported.")
    319 return var_dims

NotImplementedError: Marginalization through operation MinibatchRandomVariable(y, 10000) not supported.

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[21], line 2
      1 from pymc_extras.model.marginal.marginal_model import marginalize
----> 2 model2 = marginalize(model, [b])

File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:244, in marginalize(model, rvs_to_marginalize)
    237     other_direct_rv_ancestors = [
    238         rv
    239         for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
    240         if rv is not rv_to_marginalize
    241     ]
    242     input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
--> 244     replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
    246 return model_from_fgraph(fg, mutate_fgraph=True)

File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:566, in replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs)
    561     dependent_rvs_dim_connections = subgraph_batch_dim_connection(
    562         rv_to_marginalize, dependent_rvs
    563     )
    564 except (ValueError, NotImplementedError) as e:
    565     # For the perspective of the user this is a NotImplementedError
--> 566     raise NotImplementedError(
    567         "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
    568         "You can try splitting the marginalized RV into separate components and marginalizing them separately."
    569     ) from e
    571 output_rvs = [rv_to_marginalize, *dependent_rvs]
    572 rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)

NotImplementedError: The graph between the marginalized and dependent RVs cannot be marginalized efficiently. You can try splitting the marginalized RV into separate components and marginalizing them separately.
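
For reference, here is a minimal sketch of the quantity a marginalized version of this model would presumably need to evaluate for one minibatch. It assumes that each `b[i]` only affects `y[i]`, so `b` can be summed out element by element, and that the minibatch log-likelihood is rescaled by `total_size / batch_size`. The helper names (`normal_logpdf`, `logaddexp`, `marginal_minibatch_logp`) are hypothetical and use only the standard library; none of this is pymc-extras code.

```python
import math


def normal_logpdf(value, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at value."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def logaddexp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def marginal_minibatch_logp(batch, x, total_size, p=0.5):
    """Rescaled log-likelihood of one minibatch with each b[i] summed out.

    Assumes y[i] ~ Normal(b[i] * x, 1) and b[i] ~ Bernoulli(p), as in the
    model above, and a rescaling factor of total_size / len(batch).
    """
    per_point = [
        logaddexp(
            math.log1p(-p) + normal_logpdf(y_i, 0.0),  # b[i] = 0, so the mean is 0
            math.log(p) + normal_logpdf(y_i, x),       # b[i] = 1, so the mean is x
        )
        for y_i in batch
    ]
    return (total_size / len(batch)) * sum(per_point)


# Illustrative values only, not taken from the issue.
batch = [0.3, -1.2, 0.8, 2.1]
print(marginal_minibatch_logp(batch, x=1.0, total_size=10_000))
```

Under these assumptions the marginal likelihood is a two-component normal mixture per data point, which is the behaviour one would hope `marginalize` could recover once `MinibatchRandomVariable` is supported.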
