3
3
import pytest
4
4
from numpy .testing import assert_almost_equal , assert_array_equal
5
5
from pymc .initial_point import make_initial_point_fn
6
- from pymc .logprob .basic import joint_logp
6
+ from pymc .logprob .basic import transformed_conditional_logp
7
7
8
8
import pymc_bart as pmb
9
9
@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
12
12
fn = make_initial_point_fn (
13
13
model = model ,
14
14
return_transformed = False ,
15
- default_strategy = "moment " ,
15
+ default_strategy = "support_point " ,
16
16
)
17
17
moment = fn (0 )["x" ]
18
18
expected = np .asarray (expected )
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
27
27
28
28
if check_finite_logp :
29
29
logp_moment = (
30
- joint_logp (
30
+ transformed_conditional_logp (
31
31
(model ["x" ],),
32
32
rvs_to_values = {model ["x" ]: pm .math .constant (moment )},
33
33
rvs_to_transforms = {},
@@ -53,7 +53,7 @@ def test_bart_vi(response):
53
53
mu = pmb .BART ("mu" , X , Y , m = 10 , response = response )
54
54
sigma = pm .HalfNormal ("sigma" , 1 )
55
55
y = pm .Normal ("y" , mu , sigma , observed = Y )
56
- idata = pm .sample (random_seed = 3415 )
56
+ idata = pm .sample (tune = 200 , draws = 200 , random_seed = 3415 )
57
57
var_imp = (
58
58
idata .sample_stats ["variable_inclusion" ]
59
59
.stack (samples = ("chain" , "draw" ))
@@ -77,8 +77,8 @@ def test_missing_data(response):
77
77
with pm .Model () as model :
78
78
mu = pmb .BART ("mu" , X , Y , m = 10 , response = response )
79
79
sigma = pm .HalfNormal ("sigma" , 1 )
80
- y = pm .Normal ("y" , mu , sigma , observed = Y )
81
- idata = pm .sample (tune = 100 , draws = 100 , chains = 1 , random_seed = 3415 )
80
+ pm .Normal ("y" , mu , sigma , observed = Y )
81
+ pm .sample (tune = 100 , draws = 100 , chains = 1 , random_seed = 3415 )
82
82
83
83
84
84
@pytest .mark .parametrize (
@@ -91,7 +91,7 @@ def test_shared_variable(response):
91
91
Y = np .random .normal (0 , 1 , size = 50 )
92
92
93
93
with pm .Model () as model :
94
- data_X = pm .MutableData ("data_X" , X )
94
+ data_X = pm .Data ("data_X" , X )
95
95
mu = pmb .BART ("mu" , data_X , Y , m = 2 , response = response )
96
96
sigma = pm .HalfNormal ("sigma" , 1 )
97
97
y = pm .Normal ("y" , mu , sigma , observed = Y , shape = mu .shape )
@@ -116,7 +116,7 @@ def test_shape(response):
116
116
with pm .Model () as model :
117
117
w = pmb .BART ("w" , X , Y , m = 2 , response = response , shape = (2 , 250 ))
118
118
y = pm .Normal ("y" , w [0 ], pm .math .abs (w [1 ]), observed = Y )
119
- idata = pm .sample (random_seed = 3415 )
119
+ idata = pm .sample (tune = 50 , draws = 10 , random_seed = 3415 )
120
120
121
121
assert model .initial_point ()["w" ].shape == (2 , 250 )
122
122
assert idata .posterior .coords ["w_dim_0" ].data .size == 2
@@ -133,7 +133,7 @@ class TestUtils:
133
133
mu = pmb .BART ("mu" , X , Y , m = 10 )
134
134
sigma = pm .HalfNormal ("sigma" , 1 )
135
135
y = pm .Normal ("y" , mu , sigma , observed = Y )
136
- idata = pm .sample (random_seed = 3415 )
136
+ idata = pm .sample (tune = 200 , draws = 200 , random_seed = 3415 )
137
137
138
138
def test_sample_posterior (self ):
139
139
all_trees = self .mu .owner .op .all_trees
0 commit comments