8000 Conform to recent changes in pymc (#194) · pymc-devs/pymc-bart@b9f4567 · GitHub
[go: up one dir, main page]

Skip to content

Commit b9f4567

Browse files
authored
Conform to recent changes in pymc (#194)
* conform to recent changes in pymc * update version * fix shapes
1 parent 1741d7d commit b9f4567

File tree

5 files changed

+23
-20
lines changed

5 files changed

+23
-20
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"plot_pdp",
3737
"plot_variable_importance",
3838
]
39-
__version__ = "0.7.0"
39+
__version__ = "0.7.1"
4040

4141

4242
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_bart/bart.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,22 @@ class BARTRV(RandomVariable):
3737
"""Base class for BART."""
3838

3939
name: str = "BART"
40-
ndim_supp = 1
41-
ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
40+
signature = "(m,n),(m),(),(),() -> (m)"
4241
dtype: str = "floatX"
4342
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
4443
all_trees = List[List[List[Tree]]]
4544

4645
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
47-
return dist_params[0].shape[:1]
46+
idx = dist_params[0].ndim - 2
47+
return [dist_params[0].shape[idx]]
4848

4949
@classmethod
5050
def rng_fn( # pylint: disable=W0237
51-
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
51+
cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None
5252
):
53+
if not size:
54+
size = None
55+
5356
if not cls.all_trees:
5457
if size is not None:
5558
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
@@ -96,9 +99,6 @@ class BART(Distribution):
9699
List of SplitRule objects, one per column in input data.
97100
Allows using different split rules for different columns. Default is ContinuousSplitRule.
98101
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
99-
shape: : Optional[Tuple], default None
100-
Specify the output shape. If shape is different from (len(X)) (the default), train a
101-
separate tree for each value in other dimensions.
102102
separate_trees : Optional[bool], default False
103103
When training multiple trees (by setting a shape parameter), the default behavior is to
104104
learn a joint tree structure and only have different leaf values for each.

pymc_bart/pgbart.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ class PGBART(ArrayStepShared):
114114
name = "pgbart"
115115
default_blocked = False
116116
generates_stats = True
117-
stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
117+
stats_dtypes_shapes: dict[str, tuple[type, list]] = {
118+
"variable_inclusion": (object, []),
119+
"tune": (bool, []),
120+
}
118121

119122
def __init__( # noqa: PLR0915
120123
self,
@@ -227,7 +230,7 @@ def __init__( # noqa: PLR0915
227230
def astep(self, _):
228231
variable_inclusion = np.zeros(self.num_variates, dtype="int")
229232

230-
upper = min(self.lower + self.batch[~self.tune], self.m)
233+
upper = min(self.lower + self.batch[not self.tune], self.m)
231234
tree_ids = range(self.lower, upper)
232235
self.lower = upper if upper < self.m else 0
233236

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pymc<=5.16.2
1+
pymc>=5.16.2, <=5.18
22
arviz>=0.18.0
33
numba
44
matplotlib

tests/test_bart.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from numpy.testing import assert_almost_equal, assert_array_equal
55
from pymc.initial_point import make_initial_point_fn
6-
from pymc.logprob.basic import joint_logp
6+
from pymc.logprob.basic import transformed_conditional_logp
77

88
import pymc_bart as pmb
99

@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
1212
fn = make_initial_point_fn(
1313
model=model,
1414
return_transformed=False,
15-
default_strategy="moment",
15+
default_strategy="support_point",
1616
)
1717
moment = fn(0)["x"]
1818
expected = np.asarray(expected)
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
2727

2828
if check_finite_logp:
2929
logp_moment = (
30-
joint_logp(
30+
transformed_conditional_logp(
3131
(model["x"],),
3232
rvs_to_values={model["x"]: pm.math.constant(moment)},
3333
rvs_to_transforms={},
@@ -53,7 +53,7 @@ def test_bart_vi(response):
5353
mu = pmb.BART("mu", X, Y, m=10, response=response)
5454
sigma = pm.HalfNormal("sigma", 1)
5555
y = pm.Normal("y", mu, sigma, observed=Y)
56-
idata = pm.sample(random_seed=3415)
56+
idata = pm.sample(tune=200, draws=200, random_seed=3415)
5757
var_imp = (
5858
idata.sample_stats["variable_inclusion"]
5959
.stack(samples=("chain", "draw"))
@@ -77,8 +77,8 @@ def test_missing_data(response):
7777
with pm.Model() as model:
7878
mu = pmb.BART("mu", X, Y, m=10, response=response)
7979
sigma = pm.HalfNormal("sigma", 1)
80-
y = pm.Normal("y", mu, sigma, observed=Y)
81-
idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
80+
pm.Normal("y", mu, sigma, observed=Y)
81+
pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
8282

8383

8484
@pytest.mark.parametrize(
@@ -91,7 +91,7 @@ def test_shared_variable(response):
9191
Y = np.random.normal(0, 1, size=50)
9292

9393
with pm.Model() as model:
94-
data_X = pm.MutableData("data_X", X)
94+
data_X = pm.Data("data_X", X)
9595
mu = pmb.BART("mu", data_X, Y, m=2, response=response)
9696
sigma = pm.HalfNormal("sigma", 1)
9797
y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
@@ -116,7 +116,7 @@ def test_shape(response):
116116
with pm.Model() as model:
117117
w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250))
118118
y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
119-
idata = pm.sample(random_seed=3415)
119+
idata = pm.sample(tune=50, draws=10, random_seed=3415)
120120

121121
assert model.initial_point()["w"].shape == (2, 250)
122122
assert idata.posterior.coords["w_dim_0"].data.size == 2
@@ -133,7 +133,7 @@ class TestUtils:
133133
mu = pmb.BART("mu", X, Y, m=10)
134134
sigma = pm.HalfNormal("sigma", 1)
135135
y = pm.Normal("y", mu, sigma, observed=Y)
136-
idata = pm.sample(random_seed=3415)
136+
idata = pm.sample(tune=200, draws=200, random_seed=3415)
137137

138138
def test_sample_posterior(self):
139139
all_trees = self.mu.owner.op.all_trees

0 commit comments

Comments
 (0)
0