diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8274275..1a21dcf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,14 +12,14 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.7
+    rev: v0.6.3
     hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]
      - id: ruff-format
        args: ["--line-length=100"]
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.1
+    rev: v1.11.2
    hooks:
      - id: mypy
        args: [--ignore-missing-imports]
diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py
index a7b4cb5..c10b8f8 100644
--- a/pymc_bart/__init__.py
+++ b/pymc_bart/__init__.py
@@ -36,7 +36,7 @@
     "plot_pdp",
     "plot_variable_importance",
 ]
-__version__ = "0.6.0"
+__version__ = "0.7.0"
 
 
 pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py
index be4a8e8..91a9beb 100644
--- a/pymc_bart/pgbart.py
+++ b/pymc_bart/pgbart.py
@@ -138,6 +138,11 @@ def __init__(  # noqa: PLR0915
         else:
             self.X = self.bart.X
 
+        if isinstance(self.bart.Y, Variable):
+            self.Y = self.bart.Y.eval()
+        else:
+            self.Y = self.bart.Y
+
         self.missing_data = np.any(np.isnan(self.X))
         self.m = self.bart.m
         self.response = self.bart.response
@@ -166,7 +171,7 @@
             if rule is ContinuousSplitRule:
                 self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))
 
-        init_mean = self.bart.Y.mean()
+        init_mean = self.Y.mean()
         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
         self.available_predictors = list(range(self.num_variates))
@@ -174,18 +179,18 @@
 
         # if data is binary
         self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
-        y_unique = np.unique(self.bart.Y)
+        y_unique = np.unique(self.Y)
         if y_unique.size == 2 and np.all(y_unique == [0, 1]):
             self.leaf_sd *= 3 / self.m**0.5
         else:
-            self.leaf_sd *= self.bart.Y.std() / self.m**0.5
+            self.leaf_sd *= self.Y.std() / self.m**0.5
 
         self.running_sd = [
             RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
         ]
 
         self.sum_trees = np.full(
-            (self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
+            (self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean
         ).astype(config.floatX)
         self.sum_trees_noi = self.sum_trees - init_mean
         self.a_tree = Tree.new_tree(
diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py
index 9eee3b4..a50f2d9 100644
--- a/pymc_bart/utils.py
+++ b/pymc_bart/utils.py
@@ -8,10 +8,11 @@
 import numpy as np
 import numpy.typing as npt
 import pytensor.tensor as pt
+from numba import jit
 from pytensor.tensor.variable import Variable
 from scipy.interpolate import griddata
 from scipy.signal import savgol_filter
-from scipy.stats import norm, pearsonr
+from scipy.stats import norm
 
 from .tree import Tree
 
@@ -699,9 +700,9 @@
     labels: Optional[List[str]] = None,
     method: str = "VI",
     figsize: Optional[Tuple[float, float]] = None,
-    xlabel_angle: float = 0,
-    samples: int = 100,
+    samples: int = 50,
     random_seed: Optional[int] = None,
+    plot_kwargs: Optional[Dict[str, Any]] = None,
     ax: Optional[plt.Axes] = None,
 ) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
     """
@@ -726,13 +727,19 @@
-        VI requieres less computation time.
+        VI requires less computation time.
     figsize : tuple
         Figure size. If None it will be defined automatically.
-    xlabel_angle : float
-        rotation angle of the x-axis labels. Defaults to 0. Use values like 45 for
-        long labels and/or many variables.
     samples : int
-        Number of predictions used to compute correlation for subsets of variables. Defaults to 100
+        Number of predictions used to compute correlation for subsets of variables. Defaults to 50.
     random_seed : Optional[int]
         random_seed used to sample from the posterior. Defaults to None.
+    plot_kwargs : dict
+        Additional keyword arguments for the plot. Defaults to None.
+        Valid keys are:
+        - color_r2: matplotlib valid color for error bars
+        - marker_r2: matplotlib valid marker for the mean R squared
+        - marker_fc_r2: matplotlib valid marker face color for the mean R squared
+        - ls_ref: matplotlib valid linestyle for the reference line
+        - color_ref: matplotlib valid color for the reference line
+        - rotation: rotation angle of the x-axis labels. Defaults to 0.
     ax : axes
         Matplotlib axes.
 
@@ -745,6 +752,9 @@
 
     all_trees = bartrv.owner.op.all_trees
 
+    if plot_kwargs is None:
+        plot_kwargs = {}
+
     if bartrv.ndim == 1:  # type: ignore
         shape = 1
     else:
@@ -773,6 +783,10 @@
         all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
     )
 
+    r_2_ref = np.array(
+        [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)]
+    )
+
     if method == "VI":
         idxs = np.argsort(
             idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
@@ -794,10 +808,7 @@
                 shape=shape,
             )
             r_2 = np.array(
-                [
-                    pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
-                    for j in range(samples)
-                ]
+                [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
             )
             r2_mean[idx] = np.mean(r_2)
             r2_hdi[idx] = az.hdi(r_2)
@@ -833,10 +844,7 @@
             # Calculate Pearson correlation for each sample and find the mean
             r_2 = np.zeros(samples)
             for j in range(samples):
-                r_2[j] = (
-                    (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0])
-                    ** 2
-                )
+                r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j])
             mean_r_2 = np.mean(r_2, dtype=float)
             # Identify the least important combination of variables
             # based on the maximum mean squared Pearson correlation
@@ -872,10 +880,26 @@
         ticks,
         r2_mean,
         np.array((r2_yerr_min, r2_yerr_max)),
-        color="C0",
+        color=plot_kwargs.get("color_r2", "k"),
+        fmt=plot_kwargs.get("marker_r2", "o"),
+        mfc=plot_kwargs.get("marker_fc_r2", "white"),
+    )
+    ax.axhline(
+        np.mean(r_2_ref),
+        ls=plot_kwargs.get("ls_ref", "--"),
+        color=plot_kwargs.get("color_ref", "grey"),
+    )
+    ax.fill_between(
+        [-0.5, n_vars - 0.5],
+        *az.hdi(r_2_ref),
+        alpha=0.1,
+        color=plot_kwargs.get("color_ref", "grey"),
+    )
+    ax.set_xticks(
+        ticks,
+        new_labels,
+        rotation=plot_kwargs.get("rotation", 0),
     )
-    ax.axhline(r2_mean[-1], ls="--", color="0.5")
-    ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
     ax.set_ylabel("R²", rotation=0, labelpad=12)
     ax.set_ylim(0, 1)
     ax.set_xlim(-0.5, n_vars - 0.5)
@@ -890,3 +914,13 @@ def generate_sequences(n_vars, i_var, include):
     else:
         sequences = [()]
     return sequences
+
+
+@jit(nopython=True)
+def pearsonr2(A, B):
+    """Compute the squared Pearson correlation coefficient."""
+    A = A.flatten()
+    B = B.flatten()
+    am = A - np.mean(A)
+    bm = B - np.mean(B)
+    return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
diff --git a/requirements.txt b/requirements.txt
index 23641cb..e741cef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-pymc<5.16.0
+pymc<=5.16.2
 arviz>=0.18.0
 numba
 matplotlib
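
Note: the new numba-jitted pearsonr2 helper replaces the per-sample scipy.stats.pearsonr(...)[0] ** 2 expression, which keeps the R² loops cheap now that a reference distribution (r_2_ref) is computed as well. Below is a minimal sanity check of that equivalence, assuming the patched pymc_bart is installed; the test arrays are made up for illustration and are not part of this diff.

# Verify that pearsonr2 agrees with the scipy computation it replaces.
import numpy as np
from scipy.stats import pearsonr

from pymc_bart.utils import pearsonr2

rng = np.random.default_rng(42)
a = rng.normal(size=(4, 250))                # stand-in for one draw of predictions
b = a + rng.normal(scale=0.5, size=a.shape)  # a correlated counterpart

# pearsonr2 flattens its inputs internally, so 2-D arrays are fine.
expected = pearsonr(a.flatten(), b.flatten())[0] ** 2
assert np.isclose(pearsonr2(a, b), expected)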