8000 Patch for case when Y is a TensorVariable (#206) · pymc-devs/pymc-bart@77116d1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 77116d1

Browse files
Patch for case when Y is a TensorVariable (#206)
* add case tensor var for Y * Improve `isinstance` statement Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com> --------- Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
1 parent d4e8cad commit 77116d1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc_bart/bart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc.logprob.abstract import _logprob
2727
from pytensor.tensor.random.op import RandomVariable
2828
from pytensor.tensor.sharedvar import TensorSharedVariable
29+
from pytensor.tensor.variable import TensorVariable
2930

3031
from .split_rules import SplitRule
3132
from .tree import Tree
@@ -54,7 +55,7 @@ def rng_fn( # pylint: disable=W0237
5455
if not size:
5556
size = None
5657

57-
if isinstance(cls.Y, TensorSharedVariable):
58+
if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
5859
Y = cls.Y.eval()
5960
else:
6061
Y = cls.Y

0 commit comments

Comments
 (0)
0