10000 improve docs, aesthetics and functionality by aloctavodia · Pull Request #198 · pymc-devs/pymc-bart · GitHub
[go: up one dir, main page]

Skip to content

improve docs, aesthetics and functionality #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
else:
shape = bartrv.eval().shape[0]

n_vars = X.shape[1]

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
X = X.to_numpy()
else:
labels = np.arange(n_vars).astype(str)

n_vars = X.shape[1]
r2_mean = np.zeros(n_vars)
r2_hdi = np.zeros((n_vars, 2))
preds = np.zeros((n_vars, samples, bartrv.eval().shape[0]))
Expand Down Expand Up @@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

vi_results = {
"indices": indices,
"labels": labels[indices],
"r2_mean": r2_mean,
"r2_hdi": r2_hdi,
"preds": preds,
Expand All @@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

def plot_variable_importance(
vi_results: dict,
X: npt.NDArray[np.float64],
labels=None,
figsize=None,
plot_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -1008,19 +1012,13 @@ def plot_variable_importance(
if figsize is None:
figsize = (8, 3)

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
X = X.to_numpy()

if ax is None:
_, ax = plt.subplots(1, 1, figsize=figsize)

if labels is None:
labels = np.arange(n_vars).astype(str)
else:
labels = np.asarray(labels)
labels = vi_results["labels"]

new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]

r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])

Expand Down Expand Up @@ -1048,7 +1046,7 @@ def plot_variable_importance(
)
ax.set_xticks(
ticks,
new_labels,
labels,
rotation=plot_kwargs.get("rotation", 0),
)
ax.set_ylabel("R²", rotation=0, labelpad=12)
Expand All @@ -1058,25 +1056,80 @@ def plot_variable_importance(
return ax


def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None):
def plot_scatter_submodels(
vi_results: dict,
func: Optional[Callable] = None,
grid: str = "long",
labels=None,
figsize: Optional[Tuple[float, float]] = None,
plot_kwargs: Optional[Dict[str, Any]] = None,
axes: Optional[plt.Axes] = None,
):
"""
Plot submodel's predictions against reference-model's predictions.

Parameters
----------
vi_results: Dictionary
Dictionary computed with `compute_variable_importance`
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
grid : str or tuple
How to arrange the subplots. Defaults to "long", one subplot below the other.
Other options are "wide", one subplot next to each other or a tuple indicating the number
of rows and columns.
labels : Optional[List[str]]
List of the names of the covariates.
plot_kwargs : dict
Additional keyword arguments for the plot. Defaults to None.
Valid keys are:
- color_ref: matplotlib valid color for the 45 degree line
- color_scatter: matplotlib valid color for the scatter plot
axes : axes
Matplotlib axes.

Returns
-------
axes: matplotlib axes
"""
indices = vi_results["indices"]
preds = vi_results["preds"]
preds_all = vi_results["preds_all"]

if axes is None:
_, axes = _get_axes(grid, len(indices), False, True, None)
_, axes = _get_axes(grid, len(indices), True, True, figsize)

if plot_kwargs is None:
plot_kwargs = {}

if labels is None:
labels = vi_results["labels"]

labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]

func = None
if func is not None:
preds = func(preds)
preds_all = func(preds_all)

min_ = min(np.min(preds), np.min(preds_all))
max_ = max(np.max(preds), np.max(preds_all))

for pred, ax in zip(preds, axes.ravel()):
ax.plot(pred, preds_all, ".", color="C0", alpha=0.1)
ax.axline([min_, min_], [max_, max_], color="0.5")
for pred, x_label, ax in zip(preds, labels, axes.ravel()):
ax.plot(
pred,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
ax.set_xlabel(x_label)
ax.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)


def generate_sequences(n_vars, i_var, include):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def test_vi(self, kwargs):
vi_results = pmb.compute_variable_importance(
self.idata, bartrv=self.mu, X=self.X, samples=samples
)
pmb.plot_variable_importance(vi_results, X=self.X, **kwargs)
pmb.plot_scatter_submodels(vi_results)
pmb.plot_variable_importance(vi_results, **kwargs)
pmb.plot_scatter_submodels(vi_results, **kwargs)

def test_pdp_pandas_labels(self):
pd = pytest.importorskip("pandas")
Expand Down
0