-
-
Notifications
You must be signed in to change notification settings - Fork 69
Closed
Description
When I use `pm.Minibatch` in the following model:
# Minimal reproduction: marginalizing a discrete RV (`b`) whose dependent RV
# (`y`) is observed through `pm.Minibatch` fails with NotImplementedError.
import pymc as pm
import numpy as np
from pymc_extras.model.marginal.marginal_model import marginalize

data = np.random.normal(size=10_000)

# Used both for the minibatch and for the shape of `b`, so that `b * x`
# lines up element-wise with the observed minibatch. Keep them in sync.
batch_size = 100

with pm.Model() as model:
    d = pm.Data("data", data)
    batched_data = pm.Minibatch(d, batch_size=batch_size)
    x = pm.Normal("x", 0., 1.)
    b = pm.Bernoulli("b", 0.5, shape=(batch_size,))
    # total_size passes the full dataset length alongside the minibatch.
    y = pm.Normal("y", b * x, total_size=len(data), observed=batched_data)

# Run outside the model context (a separate notebook cell in the report).
# Currently raises NotImplementedError: the marginalization graph analysis
# does not support the MinibatchRandomVariable wrapping `y`.
model2 = marginalize(model, [b])
Calling `marginalize(model, [b])` fails with the following error:
NotImplementedError Traceback (most recent call last)
File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:561, in replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs)
560 try:
--> 561 dependent_rvs_dim_connections = subgraph_batch_dim_connection(
562 rv_to_marginalize, dependent_rvs
563 )
564 except (ValueError, NotImplementedError) as e:
565 # For the perspective of the user this is a NotImplementedError
File ~/upstream/pymc-extras/pymc_extras/model/marginal/graph_analysis.py:365, in subgraph_batch_dim_connection(input_var, output_vars)
364 var_dims = {input_var: tuple(range(input_var.type.ndim))}
--> 365 var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars)
366 ret = []
File ~/upstream/pymc-extras/pymc_extras/model/marginal/graph_analysis.py:317, in _subgraph_batch_dim_connection(var_dims, input_vars, output_vars)
316 else:
--> 317 raise NotImplementedError(f"Marginalization through operation {node} not supported.")
319 return var_dims
NotImplementedError: Marginalization through operation MinibatchRandomVariable(y, 10000) not supported.
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
Cell In[21], line 2
1 from pymc_extras.model.marginal.marginal_model import marginalize
----> 2 model2 = marginalize(model, [b])
File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:244, in marginalize(model, rvs_to_marginalize)
237 other_direct_rv_ancestors = [
238 rv
239 for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
240 if rv is not rv_to_marginalize
241 ]
242 input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
--> 244 replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
246 return model_from_fgraph(fg, mutate_fgraph=True)
File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:566, in replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs)
561 dependent_rvs_dim_connections = subgraph_batch_dim_connection(
562 rv_to_marginalize, dependent_rvs
563 )
564 except (ValueError, NotImplementedError) as e:
565 # For the perspective of the user this is a NotImplementedError
--> 566 raise NotImplementedError(
567 "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
568 "You can try splitting the marginalized RV into separate components and marginalizing them separately."
569 ) from e
571 output_rvs = [rv_to_marginalize, *dependent_rvs]
572 rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
NotImplementedError: The graph between the marginalized and dependent RVs cannot be marginalized efficiently. You can try splitting the marginalized RV into separate components and marginalizing them separately.
Metadata
Metadata
Assignees
Labels
No labels