diff --git a/.github/workflows/post-release.yml b/.github/workflows/post-release.yml
new file mode 100644
index 0000000..5526a27
--- /dev/null
+++ b/.github/workflows/post-release.yml
@@ -0,0 +1,19 @@
+name: Post-release
+on:
+ release:
+ types: [published, released]
+ workflow_dispatch:
+
+jobs:
+ changelog:
+ name: Update changelog
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ ref: main
+ - uses: rhysd/changelog-from-release/action@v3
+ with:
+ file: CHANGELOG.md
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ commit_summary_template: 'update changelog for %s changes'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8274275..7bd307d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,21 +12,21 @@ ci:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.5.7
+ rev: v0.11.11
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.11.1
+ rev: v1.15.0
hooks:
- id: mypy
args: [--ignore-missing-imports]
files: ^pymc_bart/
additional_dependencies: [numpy, pandas-stubs]
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.6.0
+ rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 6e5cef0..691fce7 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -1,11 +1,14 @@
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
+sphinx:
+ # Path to your Sphinx configuration file.
+ configuration: docs/conf.py
build:
- os: ubuntu-20.04
+ os: ubuntu-24.04
tools:
- python: "3.10"
+ python: "3.12"
python:
install:
@@ -13,3 +16,15 @@ python:
- requirements: requirements.txt
- method: pip
path: .
+
+search:
+ ranking:
+ _sources/*: -10
+ _modules/*: -5
+ genindex.html: -9
+
+ ignore:
+ - 404.html
+ - search.html
+ - index.html
+ - 'examples/*'
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..99d410f
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,465 @@
+
+# [0.9.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.1) - 2025-05-19
+
+## What's Changed
+* misc doc improvements and theme update by [@OriolAbril](https://github.com/OriolAbril) in [#225](https://github.com/pymc-devs/pymc-bart/pull/225)
+* Use last pymc version by [@aloctavodia](https://github.com/aloctavodia) in [#227](https://github.com/pymc-devs/pymc-bart/pull/227)
+
+## New Contributors
+* [@OriolAbril](https://github.com/OriolAbril) made their first contribution in [#225](https://github.com/pymc-devs/pymc-bart/pull/225)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1
+
+[Changes][0.9.1]
+
+
+
+# [0.9.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.9.0) - 2025-03-10
+
+## What's Changed
+* Update MyPy 14 by [@juanitorduz](https://github.com/juanitorduz) in [#210](https://github.com/pymc-devs/pymc-bart/pull/210)
+* Automatic Changelog by [@aloctavodia](https://github.com/aloctavodia) in [#213](https://github.com/pymc-devs/pymc-bart/pull/213)
+* Adds get_variable_inclusion function by [@aloctavodia](https://github.com/aloctavodia) in [#214](https://github.com/pymc-devs/pymc-bart/pull/214)
+* Refactor rng_fn method by [@aloctavodia](https://github.com/aloctavodia) in [#212](https://github.com/pymc-devs/pymc-bart/pull/212)
+* Fix docs by adding path of config by [@juanitorduz](https://github.com/juanitorduz) in [#217](https://github.com/pymc-devs/pymc-bart/pull/217)
+* Enhance `plot_pdp` and fix `plot_scatter_submodels` by [@AlexAndorra](https://github.com/AlexAndorra) in [#218](https://github.com/pymc-devs/pymc-bart/pull/218)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0
+
+[Changes][0.9.0]
+
+
+
+# [0.8.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.2) - 2024-12-23
+
+## What's Changed
+* Compute_variable_importance: fix bug with non-default shapes by [@aloctavodia](https://github.com/aloctavodia) in [#208](https://github.com/pymc-devs/pymc-bart/pull/208)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2
+
+[Changes][0.8.2]
+
+
+
+# [0.8.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.1) - 2024-12-20
+
+## What's Changed
+* Patch for case when Y is a TensorVariable by [@AlexAndorra](https://github.com/AlexAndorra) in [#206](https://github.com/pymc-devs/pymc-bart/pull/206)
+* Fix bug with labels in variable importance, add reference line, remove deprecation warning by [@aloctavodia](https://github.com/aloctavodia) in [#207](https://github.com/pymc-devs/pymc-bart/pull/207)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1
+
+[Changes][0.8.1]
+
+
+
+# [0.8.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.8.0) - 2024-12-17
+
+## What's Changed
+
+* Add new vi plots by [@aloctavodia](https://github.com/aloctavodia) in [#196](https://github.com/pymc-devs/pymc-bart/pull/196)
+* Allows plotting a subset of the variables once the variable's importance has been computed by [@aloctavodia](https://github.com/aloctavodia) in [#200](https://github.com/pymc-devs/pymc-bart/pull/200)
+* Enable passing `Y` as a `SharedVariable` to `pm.Bart` by [@AlexAndorra](https://github.com/AlexAndorra) in [#202](https://github.com/pymc-devs/pymc-bart/pull/202)
+* Improve docs, aesthetics and functionality by [@aloctavodia](https://github.com/aloctavodia) in [#198](https://github.com/pymc-devs/pymc-bart/pull/198)
+
+
+## New Contributors
+* [@AlexAndorra](https://github.com/AlexAndorra) made their first contribution in [#202](https://github.com/pymc-devs/pymc-bart/pull/202)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.7.1...0.8.0
+
+[Changes][0.8.0]
+
+
+
+# [0.7.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.7.1) - 2024-11-07
+
+## What's Changed
+* Conform to recent changes in pymc by [@aloctavodia](https://github.com/aloctavodia) in [#194](https://github.com/pymc-devs/pymc-bart/pull/194)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.7.0...0.7.1
+
+[Changes][0.7.1]
+
+
+
+# [0.7.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.7.0) - 2024-09-05
+
+## What's Changed
+* Allow Y to be a tensor by [@aloctavodia](https://github.com/aloctavodia) in [#180](https://github.com/pymc-devs/pymc-bart/pull/180)
+* improve plot_variable_importance by [@aloctavodia](https://github.com/aloctavodia) in [#182](https://github.com/pymc-devs/pymc-bart/pull/182)
+* move x_angle to plot_kwargs by [@aloctavodia](https://github.com/aloctavodia) in [#185](https://github.com/pymc-devs/pymc-bart/pull/185)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.6.0...0.7.0
+
+[Changes][0.7.0]
+
+
+
+# [0.6.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.6.0) - 2024-08-16
+
+## What's Changed
+* Add categorical example by [@PabloGGaray](https://github.com/PabloGGaray) in [#167](https://github.com/pymc-devs/pymc-bart/pull/167)
+* Fix np.float_ type by [@juanitorduz](https://github.com/juanitorduz) in [#171](https://github.com/pymc-devs/pymc-bart/pull/171)
+* Support Polars by [@aloctavodia](https://github.com/aloctavodia) in [#179](https://github.com/pymc-devs/pymc-bart/pull/179)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.14...0.6.0
+
+[Changes][0.6.0]
+
+
+
+# [0.5.14](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.14) - 2024-05-14
+
+## What's Changed
+* Less than equal PyMC Version by [@juanitorduz](https://github.com/juanitorduz) in [#164](https://github.com/pymc-devs/pymc-bart/pull/164)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.13...0.5.14
+
+[Changes][0.5.14]
+
+
+
+# [0.5.13](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.13) - 2024-05-13
+
+## What's Changed
+* Update pymc version requirements.txt by [@juanitorduz](https://github.com/juanitorduz) in [#163](https://github.com/pymc-devs/pymc-bart/pull/163)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.12...0.5.13
+
+[Changes][0.5.13]
+
+
+
+# [0.5.12](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.12) - 2024-04-18
+
+## What's Changed
+* Unpin numpy by [@maresb](https://github.com/maresb) in [#156](https://github.com/pymc-devs/pymc-bart/pull/156)
+* Resolve deprecation warning for `pytensor`'s `Variable` by [@RyanAugust](https://github.com/RyanAugust) in [#159](https://github.com/pymc-devs/pymc-bart/pull/159)
+
+## New Contributors
+* [@RyanAugust](https://github.com/RyanAugust) made their first contribution in [#159](https://github.com/pymc-devs/pymc-bart/pull/159)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.11...0.5.12
+
+[Changes][0.5.12]
+
+
+
+# [0.5.11](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.11) - 2024-03-15
+
+## What's Changed
+* Add citation file by [@PabloGGaray](https://github.com/PabloGGaray) in [#151](https://github.com/pymc-devs/pymc-bart/pull/151)
+* Rename moment to support_point by [@PabloGGaray](https://github.com/PabloGGaray) in [#154](https://github.com/pymc-devs/pymc-bart/pull/154)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.10...0.5.11
+
+[Changes][0.5.11]
+
+
+
+# [0.5.10](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.10) - 2024-03-14
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.8...0.5.10
+
+[Changes][0.5.10]
+
+
+
+# [0.5.9](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.9) - 2024-03-14
+
+## What's Changed
+* Ruff linter + pre-commit integration by [@juanitorduz](https://github.com/juanitorduz) in [#140](https://github.com/pymc-devs/pymc-bart/pull/140)
+* Improve CONTRIBUTING guidelines by [@juanitorduz](https://github.com/juanitorduz) in [#141](https://github.com/pymc-devs/pymc-bart/pull/141)
+* Add Usage and Table of Contents, to the README file, enhance Installation section, and fix top header by [@NicholasLindner](https://github.com/NicholasLindner) in [#143](https://github.com/pymc-devs/pymc-bart/pull/143)
+
+## New Contributors
+* [@NicholasLindner](https://github.com/NicholasLindner) made their first contribution in [#143](https://github.com/pymc-devs/pymc-bart/pull/143)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.9
+
+[Changes][0.5.9]
+
+
+
+# [0.5.8](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.8) - 2024-03-14
+
+## What's Changed
+* Ruff linter + pre-commit integration by [@juanitorduz](https://github.com/juanitorduz) in [#140](https://github.com/pymc-devs/pymc-bart/pull/140)
+* Improve CONTRIBUTING guidelines by [@juanitorduz](https://github.com/juanitorduz) in [#141](https://github.com/pymc-devs/pymc-bart/pull/141)
+* Add Usage and Table of Contents, to the README file, enhance Installation section, and fix top header by [@NicholasLindner](https://github.com/NicholasLindner) in [#143](https://github.com/pymc-devs/pymc-bart/pull/143)
+
+
+## New Contributors
+* [@NicholasLindner](https://github.com/NicholasLindner) made their first contribution in [#143](https://github.com/pymc-devs/pymc-bart/pull/143)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.8
+
+[Changes][0.5.8]
+
+
+
+# [0.5.7](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.7) - 2023-12-29
+
+## What's Changed
+* Properly handle nans when jittering by [@aloctavodia](https://github.com/aloctavodia) in [#136](https://github.com/pymc-devs/pymc-bart/pull/136)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.6...0.5.7
+
+[Changes][0.5.7]
+
+
+
+# [0.5.6](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.6) - 2023-12-23
+
+## What's Changed
+* Fix bug in plot_ice, and clean docstring of plot_ice and plot_pdp by [@aloctavodia](https://github.com/aloctavodia) in [#135](https://github.com/pymc-devs/pymc-bart/pull/135)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.5...0.5.6
+
+[Changes][0.5.6]
+
+
+
+# [0.5.5](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.5) - 2023-12-22
+
+## What's Changed
+* add jitter to duplicated values for continuous splitting rule by [@aloctavodia](https://github.com/aloctavodia) in [#129](https://github.com/pymc-devs/pymc-bart/pull/129)
+* link GitHub icon to pymc-bart repo by [@aloctavodia](https://github.com/aloctavodia) in [#131](https://github.com/pymc-devs/pymc-bart/pull/131)
+* VI remove unnecessary evaluations for the backward method by [@aloctavodia](https://github.com/aloctavodia) in [#132](https://github.com/pymc-devs/pymc-bart/pull/132)
+* jitter only arrays of whole numbers by [@aloctavodia](https://github.com/aloctavodia) in [#133](https://github.com/pymc-devs/pymc-bart/pull/133)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.5
+
+[Changes][0.5.5]
+
+
+
+# [0.5.4](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.4) - 2023-11-21
+
+## What's Changed
+* add jitter to duplicated values for continuous splitting rule by [@aloctavodia](https://github.com/aloctavodia) in [#129](https://github.com/pymc-devs/pymc-bart/pull/129)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.4
+
+[Changes][0.5.4]
+
+
+
+# [0.5.3](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.3) - 2023-11-18
+
+## What's Changed
+* improve variable importance computation by adding backward method by [@aloctavodia](https://github.com/aloctavodia) in [#125](https://github.com/pymc-devs/pymc-bart/pull/125)
+* set new paths to notebooks by [@aloctavodia](https://github.com/aloctavodia) in [#126](https://github.com/pymc-devs/pymc-bart/pull/126)
+* fix case examples by [@aloctavodia](https://github.com/aloctavodia) in [#127](https://github.com/pymc-devs/pymc-bart/pull/127)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.2...0.5.3
+
+[Changes][0.5.3]
+
+
+
+# [0.5.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.2) - 2023-10-27
+
+## What's Changed
+* Minor doctrings and types improvements by [@juanitorduz](https://github.com/juanitorduz) in [#108](https://github.com/pymc-devs/pymc-bart/pull/108)
+* Fix ICE plot when there is a discrete variable by [@juanitorduz](https://github.com/juanitorduz) in [#107](https://github.com/pymc-devs/pymc-bart/pull/107)
+* Add support python 3.11 by [@juanitorduz](https://github.com/juanitorduz) in [#109](https://github.com/pymc-devs/pymc-bart/pull/109)
+* Add issue templates by [@PabloGGaray](https://github.com/PabloGGaray) in [#113](https://github.com/pymc-devs/pymc-bart/pull/113)
+* Add conda option by [@PabloGGaray](https://github.com/PabloGGaray) in [#114](https://github.com/pymc-devs/pymc-bart/pull/114)
+* fix split_prior bug by [@aloctavodia](https://github.com/aloctavodia) in [#115](https://github.com/pymc-devs/pymc-bart/pull/115)
+* Add logo by [@aloctavodia](https://github.com/aloctavodia) in [#116](https://github.com/pymc-devs/pymc-bart/pull/116)
+* clean logo by [@aloctavodia](https://github.com/aloctavodia) in [#117](https://github.com/pymc-devs/pymc-bart/pull/117)
+* Add plot_ice to API description on the webpage by [@PabloGGaray](https://github.com/PabloGGaray) in [#119](https://github.com/pymc-devs/pymc-bart/pull/119)
+* Better handling of discrete variables and other minor fixes by [@aloctavodia](https://github.com/aloctavodia) in [#121](https://github.com/pymc-devs/pymc-bart/pull/121)
+
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...0.5.2
+
+[Changes][0.5.2]
+
+
+
+# [O.5.1](https://github.com/pymc-devs/pymc-bart/releases/tag/O.5.1) - 2023-07-12
+
+## What's Changed
+* Minor doctrings and types improvements by [@juanitorduz](https://github.com/juanitorduz) in [#108](https://github.com/pymc-devs/pymc-bart/pull/108)
+* Fix ICE plot when there is a discrete variable by [@juanitorduz](https://github.com/juanitorduz) in [#107](https://github.com/pymc-devs/pymc-bart/pull/107)
+* Add support python 3.11 by [@juanitorduz](https://github.com/juanitorduz) in [#109](https://github.com/pymc-devs/pymc-bart/pull/109)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...O.5.1
+
+[Changes][O.5.1]
+
+
+
+# [0.5.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.5.0) - 2023-07-10
+
+## What's Changed
+* Add pre-commit hooks by [@juanitorduz](https://github.com/juanitorduz) in [#75](https://github.com/pymc-devs/pymc-bart/pull/75)
+* Add mypy init by [@juanitorduz](https://github.com/juanitorduz) in [#78](https://github.com/pymc-devs/pymc-bart/pull/78)
+* Do not store index at each node. by [@howsiyu](https://github.com/howsiyu) in [#80](https://github.com/pymc-devs/pymc-bart/pull/80)
+* Add linear response [@juanitorduz](https://github.com/juanitorduz) in [#79](https://github.com/pymc-devs/pymc-bart/pull/79)
+* Do weighted mean when pruning by [@aloctavodia](https://github.com/aloctavodia) in [#83](https://github.com/pymc-devs/pymc-bart/pull/83)
+* Implement fast version of pdp by [@aloctavodia](https://github.com/aloctavodia) in [#85](https://github.com/pymc-devs/pymc-bart/pull/85)
+* Add error bars to variable importance by [@aloctavodia](https://github.com/aloctavodia) in [#90](https://github.com/pymc-devs/pymc-bart/pull/90)
+* Compute running variance for leaf nodes by [@aloctavodia](https://github.com/aloctavodia) in [#91](https://github.com/pymc-devs/pymc-bart/pull/91)
+* Improve doc style and add missing examples by [@aloctavodia](https://github.com/aloctavodia) in [#92](https://github.com/pymc-devs/pymc-bart/pull/92)
+* Make the Repo more welcoming with a clear title by [@juanitorduz](https://github.com/juanitorduz) in [#94](https://github.com/pymc-devs/pymc-bart/pull/94)
+* Improve docstrings new alpha and beta parameters by [@juanitorduz](https://github.com/juanitorduz) in [#95](https://github.com/pymc-devs/pymc-bart/pull/95)
+* Allow different splitting rules by [@velochy](https://github.com/velochy) in [#96](https://github.com/pymc-devs/pymc-bart/pull/96)
+* Allow training separate tree structures if training multiple trees by [@velochy](https://github.com/velochy) in [#98](https://github.com/pymc-devs/pymc-bart/pull/98)
+
+## New Contributors
+* [@howsiyu](https://github.com/howsiyu) made their first contribution in [#80](https://github.com/pymc-devs/pymc-bart/pull/80)
+* [@velochy](https://github.com/velochy) made their first contribution in [#96](https://github.com/pymc-devs/pymc-bart/pull/96)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.4.0...0.5.0
+
+[Changes][0.5.0]
+
+
+
+# [0.4.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.4.0) - 2023-03-17
+
+## What's Changed
+* fig bug systematic resampling and add func argument by [@aloctavodia](https://github.com/aloctavodia) in [#61](https://github.com/pymc-devs/pymc-bart/pull/61) and [#66](https://github.com/pymc-devs/pymc-bart/pull/66)
+* add tests for individual functions/methods in PGBART by [@aloctavodia](https://github.com/aloctavodia) in [#64](https://github.com/pymc-devs/pymc-bart/pull/64)
+* Modify resampling schema and refactor by [@aloctavodia](https://github.com/aloctavodia) in [#65](https://github.com/pymc-devs/pymc-bart/pull/65)
+* add plot_convergence by [@aloctavodia](https://github.com/aloctavodia) in [#67](https://github.com/pymc-devs/pymc-bart/pull/67) and [@aloctavodia](https://github.com/aloctavodia) in [#68](https://github.com/pymc-devs/pymc-bart/pull/68)
+* Improve plot_dependence by [@PabloGGaray](https://github.com/PabloGGaray) in [#70](https://github.com/pymc-devs/pymc-bart/pull/70) and [@aloctavodia](https://github.com/aloctavodia) in [#71](https://github.com/pymc-devs/pymc-bart/pull/71) and in [#73](https://github.com/pymc-devs/pymc-bart/pull/73)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.2...0.4.0
+
+[Changes][0.4.0]
+
+
+
+# [0.3.2](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.2) - 2023-02-03
+
+## What's Changed
+* Refactor and [@njit](https://github.com/njit) on methods by [@fjloyola](https://github.com/fjloyola) in [#54](https://github.com/pymc-devs/pymc-bart/pull/54)
+* Fix shape error [@aloctavodia](https://github.com/aloctavodia) in [#57](https://github.com/pymc-devs/pymc-bart/pull/57) and [#59](https://github.com/pymc-devs/pymc-bart/pull/59)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.1...0.3.2
+
+[Changes][0.3.2]
+
+
+
+# [0.3.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.1) - 2023-01-26
+
+## What's Changed
+* Fix Url pymc-bart on documentation by [@fjloyola](https://github.com/fjloyola) in [#34](https://github.com/pymc-devs/pymc-bart/pull/34)
+* Fixing issue ThemeError for read the docs by [@fjloyola](https://github.com/fjloyola) in [#37](https://github.com/pymc-devs/pymc-bart/pull/37)
+* Refactor to avoid inheritance in BaseNode by [@fjloyola](https://github.com/fjloyola) in [#35](https://github.com/pymc-devs/pymc-bart/pull/35)
+* Add link to license by [@PabloGGaray](https://github.com/PabloGGaray) in [#39](https://github.com/pymc-devs/pymc-bart/pull/39)
+* Improvements over Tree implementation by [@fjloyola](https://github.com/fjloyola) in [#40](https://github.com/pymc-devs/pymc-bart/pull/40)
+* fix import error from pymc 5.0.2 by [@juanitorduz](https://github.com/juanitorduz) in [#43](https://github.com/pymc-devs/pymc-bart/pull/43)
+* Update pymc minimum version by [@aloctavodia](https://github.com/aloctavodia) in [#45](https://github.com/pymc-devs/pymc-bart/pull/45)
+* Avoid Deepcopy on Tree and ParticleTree by [@fjloyola](https://github.com/fjloyola) in [#47](https://github.com/pymc-devs/pymc-bart/pull/47)
+
+## New Contributors
+* [@fjloyola](https://github.com/fjloyola) made their first contribution in [#34](https://github.com/pymc-devs/pymc-bart/pull/34)
+* [@juanitorduz](https://github.com/juanitorduz) made their first contribution in [#43](https://github.com/pymc-devs/pymc-bart/pull/43)
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.3.0...0.3.1
+
+[Changes][0.3.1]
+
+
+
+# [0.3.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.3.0) - 2022-12-22
+
+## What's Changed
+* Update README with conda installation by [@maresb](https://github.com/maresb) in [#26](https://github.com/pymc-devs/pymc-bart/pull/26)
+* Fix broken URL by [@maresb](https://github.com/maresb) in [#27](https://github.com/pymc-devs/pymc-bart/pull/27)
+* Update to PyMC 5 and PyTensor by [@aloctavodia](https://github.com/aloctavodia) in [#29](https://github.com/pymc-devs/pymc-bart/pull/29)
+
+
+**Full Changelog**: https://github.com/pymc-devs/pymc-bart/compare/0.2.1...0.3.0
+
+[Changes][0.3.0]
+
+
+
+# [0.2.1](https://github.com/pymc-devs/pymc-bart/releases/tag/0.2.1) - 2022-11-07
+
+
+
+[Changes][0.2.1]
+
+
+
+# [0.2.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.2.0) - 2022-11-03
+
+
+
+[Changes][0.2.0]
+
+
+
+# [0.1.0](https://github.com/pymc-devs/pymc-bart/releases/tag/0.1.0) - 2022-10-26
+
+
+
+[Changes][0.1.0]
+
+
+
+# [0.0.3](https://github.com/pymc-devs/pymc-bart/releases/tag/0.0.3) - 2022-09-13
+
+
+
+[Changes][0.0.3]
+
+
+[0.9.1]: https://github.com/pymc-devs/pymc-bart/compare/0.9.0...0.9.1
+[0.9.0]: https://github.com/pymc-devs/pymc-bart/compare/0.8.2...0.9.0
+[0.8.2]: https://github.com/pymc-devs/pymc-bart/compare/0.8.1...0.8.2
+[0.8.1]: https://github.com/pymc-devs/pymc-bart/compare/0.8.0...0.8.1
+[0.8.0]: https://github.com/pymc-devs/pymc-bart/compare/0.7.1...0.8.0
+[0.7.1]: https://github.com/pymc-devs/pymc-bart/compare/0.7.0...0.7.1
+[0.7.0]: https://github.com/pymc-devs/pymc-bart/compare/0.6.0...0.7.0
+[0.6.0]: https://github.com/pymc-devs/pymc-bart/compare/0.5.14...0.6.0
+[0.5.14]: https://github.com/pymc-devs/pymc-bart/compare/0.5.13...0.5.14
+[0.5.13]: https://github.com/pymc-devs/pymc-bart/compare/0.5.12...0.5.13
+[0.5.12]: https://github.com/pymc-devs/pymc-bart/compare/0.5.11...0.5.12
+[0.5.11]: https://github.com/pymc-devs/pymc-bart/compare/0.5.10...0.5.11
+[0.5.10]: https://github.com/pymc-devs/pymc-bart/compare/0.5.9...0.5.10
+[0.5.9]: https://github.com/pymc-devs/pymc-bart/compare/0.5.8...0.5.9
+[0.5.8]: https://github.com/pymc-devs/pymc-bart/compare/0.5.7...0.5.8
+[0.5.7]: https://github.com/pymc-devs/pymc-bart/compare/0.5.6...0.5.7
+[0.5.6]: https://github.com/pymc-devs/pymc-bart/compare/0.5.5...0.5.6
+[0.5.5]: https://github.com/pymc-devs/pymc-bart/compare/0.5.4...0.5.5
+[0.5.4]: https://github.com/pymc-devs/pymc-bart/compare/0.5.3...0.5.4
+[0.5.3]: https://github.com/pymc-devs/pymc-bart/compare/0.5.2...0.5.3
+[0.5.2]: https://github.com/pymc-devs/pymc-bart/compare/O.5.1...0.5.2
+[O.5.1]: https://github.com/pymc-devs/pymc-bart/compare/0.5.0...O.5.1
+[0.5.0]: https://github.com/pymc-devs/pymc-bart/compare/0.4.0...0.5.0
+[0.4.0]: https://github.com/pymc-devs/pymc-bart/compare/0.3.2...0.4.0
+[0.3.2]: https://github.com/pymc-devs/pymc-bart/compare/0.3.1...0.3.2
+[0.3.1]: https://github.com/pymc-devs/pymc-bart/compare/0.3.0...0.3.1
+[0.3.0]: https://github.com/pymc-devs/pymc-bart/compare/0.2.1...0.3.0
+[0.2.1]: https://github.com/pymc-devs/pymc-bart/compare/0.2.0...0.2.1
+[0.2.0]: https://github.com/pymc-devs/pymc-bart/compare/0.1.0...0.2.0
+[0.1.0]: https://github.com/pymc-devs/pymc-bart/compare/0.0.3...0.1.0
+[0.0.3]: https://github.com/pymc-devs/pymc-bart/tree/0.0.3
+
+
diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 93afde1..b6fb8a5 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -13,4 +13,4 @@ methods in the current release of PyMC-BART.
=============================
.. automodule:: pymc_bart
- :members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
+ :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
diff --git a/docs/changelog.rst b/docs/changelog.rst
new file mode 100644
index 0000000..f83d445
--- /dev/null
+++ b/docs/changelog.rst
@@ -0,0 +1,5 @@
+Changelog
+*********
+
+.. include:: ../CHANGELOG.md
+ :parser: myst_parser.sphinx_
diff --git a/docs/conf.py b/docs/conf.py
index ba89cb1..8945cef 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -21,7 +21,6 @@
"sphinx_design",
"sphinxcontrib.bibtex",
"sphinx_codeautolink",
- "sphinx_remove_toctrees",
]
# List of patterns, relative to source directory, that match files and
@@ -73,6 +72,7 @@
html_theme = "pymc_sphinx_theme"
html_theme_options = {
"secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink", "donate"],
+ "search_bar_text": "Search within PyMC-BART...",
"navbar_start": ["navbar-logo"],
"icon_links": [
{
@@ -80,17 +80,6 @@
"icon": "fa-brands fa-github",
"name": "GitHub",
},
- {
- "url": "https://twitter.com/pymc_devs/",
- "icon": "fa-brands fa-twitter",
- "name": "Twitter",
- },
- {
- "url": "https://www.youtube.com/c/PyMCDevelopers",
- "icon": "fa-brands fa-youtube",
- "name": "YouTube",
- },
- {"url": "https://discourse.pymc.io", "icon": "fa-brands fa-discourse", "name": "Discourse"},
],
}
@@ -144,23 +133,6 @@
nb_execution_mode = "off"
-remove_from_toctrees = [
- "BART/*",
- "case_studies/*",
- "causal_inference/*",
- "diagnostics_and_criticism/*",
- "gaussian_processes/*",
- "generalized_linear_models/*",
- "mixture_models/*",
- "ode_models/*",
- "howto/*",
- "samplers/*",
- "splines/*",
- "survival_analysis/*",
- "time_series/*",
- "variational_inference/*",
-]
-
# bibtex config
bibtex_bibfiles = ["references.bib"]
bibtex_default_style = "unsrt"
diff --git a/docs/index.rst b/docs/index.rst
index 4b1dd0e..78a59fb 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -29,7 +29,7 @@ interpretation of those models and perform variable selection.
Installation
============
-PyMC-BART requires a working Python interpreter (3.8+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms.
+PyMC-BART requires a working Python interpreter (3.10+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms.
Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge.
@@ -93,10 +93,12 @@ Contents
:maxdepth: 2
examples
- api_reference
-Indices
-=======
+References
+==========
+
+.. toctree::
+ :maxdepth: 1
-* :ref:`genindex`
-* :ref:`modindex`
+ api_reference
+ changelog
diff --git a/env-dev.yml b/env-dev.yml
new file mode 100644
index 0000000..1e28429
--- /dev/null
+++ b/env-dev.yml
@@ -0,0 +1,23 @@
+name: pymc-bart-dev
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - pymc>=5.16.2,<=5.19.1
+ - arviz>=0.18.0
+ - numba
+ - matplotlib
+ - numpy
+ - pytensor
+ # Development dependencies
+ - pytest>=4.4.0
+ - pytest-cov>=2.6.1
+ - click==8.0.4
+ - pylint==2.17.4
+ - pre-commit
+ - black
+ - isort
+ - flake8
+ - pip
+ - pip:
+ - -e .
diff --git a/env.yml b/env.yml
new file mode 100644
index 0000000..bd814ae
--- /dev/null
+++ b/env.yml
@@ -0,0 +1,14 @@
+name: pymc-bart
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - pymc>=5.16.2,<=5.19.1
+ - arviz>=0.18.0
+ - numba
+ - matplotlib
+ - numpy
+ - pytensor
+ - pip
+ - pip:
+ - pymc-bart
diff --git a/mypy.ini b/mypy.ini
deleted file mode 100644
index 56088d7..0000000
--- a/mypy.ini
+++ /dev/null
@@ -1,15 +0,0 @@
-[mypy]
-files = pymc_bart/*.py
-plugins = numpy.typing.mypy_plugin
-
-[mypy-matplotlib.*]
-ignore_missing_imports = True
-
-[mypy-numba.*]
-ignore_missing_imports = True
-
-[mypy-pymc.*]
-ignore_missing_imports = True
-
-[mypy-scipy.*]
-ignore_missing_imports = True
diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py
index a7b4cb5..36972b3 100644
--- a/pymc_bart/__init__.py
+++ b/pymc_bart/__init__.py
@@ -17,11 +17,14 @@
from pymc_bart.pgbart import PGBART
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
from pymc_bart.utils import (
+ compute_variable_importance,
+ get_variable_inclusion,
plot_convergence,
- plot_dependence,
plot_ice,
plot_pdp,
+ plot_scatter_submodels,
plot_variable_importance,
+ plot_variable_inclusion,
)
__all__ = [
@@ -30,13 +33,16 @@
"ContinuousSplitRule",
"OneHotSplitRule",
"SubsetSplitRule",
+ "compute_variable_importance",
+ "get_variable_inclusion",
"plot_convergence",
- "plot_dependence",
"plot_ice",
"plot_pdp",
+ "plot_scatter_submodels",
"plot_variable_importance",
+ "plot_variable_inclusion",
]
-__version__ = "0.6.0"
+__version__ = "0.9.1"
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py
index 969baf4..5114b6e 100644
--- a/pymc_bart/bart.py
+++ b/pymc_bart/bart.py
@@ -16,7 +16,7 @@
import warnings
from multiprocessing import Manager
-from typing import List, Optional, Tuple
+from typing import Optional
import numpy as np
import numpy.typing as npt
@@ -25,6 +25,8 @@
from pymc.distributions.distribution import Distribution, _support_point
from pymc.logprob.abstract import _logprob
from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.sharedvar import TensorSharedVariable
+from pytensor.tensor.variable import TensorVariable
from .split_rules import SplitRule
from .tree import Tree
@@ -37,24 +39,32 @@ class BARTRV(RandomVariable):
"""Base class for BART."""
name: str = "BART"
- ndim_supp = 1
- ndims_params: List[int] = [2, 1, 0, 0, 0, 1]
+ signature = "(m,n),(m),(),(),() -> (m)"
dtype: str = "floatX"
- _print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
- all_trees = List[List[List[Tree]]]
+ _print_name: tuple[str, str] = ("BART", "\\operatorname{BART}")
+ all_trees = list[list[list[Tree]]]
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
- return dist_params[0].shape[:1]
+ idx = dist_params[0].ndim - 2
+ return [dist_params[0].shape[idx]]
@classmethod
def rng_fn( # pylint: disable=W0237
- cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, split_prior=None, size=None
+ cls, rng=None, X=None, Y=None, m=None, alpha=None, beta=None, size=None
):
+ if not size:
+ size = None
+
if not cls.all_trees:
+ if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
+ Y = cls.Y.eval()
+ else:
+ Y = cls.Y
+
if size is not None:
- return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
+ return np.full((size[0], Y.shape[0]), Y.mean())
else:
- return np.full(cls.Y.shape[0], cls.Y.mean())
+ return np.full(Y.shape[0], Y.mean())
else:
if size is not None:
shape = size[0]
@@ -89,16 +99,13 @@ class BART(Distribution):
beta : float
Controls the prior probability over the number of leaves of the trees.
Should be positive.
- split_prior : Optional[List[float]], default None.
+ split_prior : Optional[list[float]], default None.
List of positive numbers, one per column in input data.
Defaults to None, all covariates have the same prior probability to be selected.
- split_rules : Optional[List[SplitRule]], default None
+ split_rules : Optional[list[SplitRule]], default None
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
- shape: : Optional[Tuple], default None
- Specify the output shape. If shape is different from (len(X)) (the default), train a
- separate tree for each value in other dimensions.
separate_trees : Optional[bool], default False
When training multiple trees (by setting a shape parameter), the default behavior is to
learn a joint tree structure and only have different leaf values for each.
@@ -125,8 +132,8 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
- split_prior: Optional[npt.NDArray[np.float64]] = None,
- split_rules: Optional[List[SplitRule]] = None,
+ split_prior: Optional[npt.NDArray] = None,
+ split_rules: Optional[list[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
):
@@ -169,7 +176,7 @@ def get_moment(rv, size, *rv_inputs):
return cls.get_moment(rv, size, *rv_inputs)
cls.rv_op = bart_op
- params = [X, Y, m, alpha, beta, split_prior]
+ params = [X, Y, m, alpha, beta]
return super().__new__(cls, name, *params, **kwargs)
@classmethod
@@ -196,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
return mean
-def preprocess_xy(
- X: TensorLike, Y: TensorLike
-) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py
index be4a8e8..014313a 100644
--- a/pymc_bart/pgbart.py
+++ b/pymc_bart/pgbart.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
import numpy as np
import numpy.typing as npt
+import pymc as pm
+import pytensor.tensor as pt
from numba import njit
+from pymc.initial_point import PointType
from pymc.model import Model, modelcontext
from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc.step_methods.arraystep import ArrayStepShared
@@ -43,7 +46,7 @@ class ParticleTree:
def __init__(self, tree: Tree):
self.tree: Tree = tree.copy()
- self.expansion_nodes: List[int] = [0]
+ self.expansion_nodes: list[int] = [0]
self.log_weight: float = 0
def copy(self) -> "ParticleTree":
@@ -114,22 +117,32 @@ class PGBART(ArrayStepShared):
name = "pgbart"
default_blocked = False
generates_stats = True
- stats_dtypes = [{"variable_inclusion": object, "tune": bool}]
+ stats_dtypes_shapes: dict[str, tuple[type, list]] = {
+ "variable_inclusion": (object, []),
+ "tune": (bool, []),
+ }
- def __init__( # noqa: PLR0915
+ def __init__( # noqa: PLR0912, PLR0915
self,
- vars=None, # pylint: disable=redefined-builtin
+ vars: list[pm.Distribution] | None = None,
num_particles: int = 10,
- batch: Tuple[float, float] = (0.1, 0.1),
+ batch: tuple[float, float] = (0.1, 0.1),
model: Optional[Model] = None,
- ):
+ initial_point: PointType | None = None,
+ compile_kwargs: dict | None = None,
+ ) -> None:
model = modelcontext(model)
- initial_values = model.initial_point()
+ if initial_point is None:
+ initial_point = model.initial_point()
if vars is None:
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = inputvars(vars)
+
+ if vars is None:
+ raise ValueError("Unable to find variables to sample")
+
value_bart = vars[0]
self.bart = model.values_to_rvs[value_bart].owner.op
@@ -138,11 +151,16 @@ def __init__( # noqa: PLR0915
else:
self.X = self.bart.X
+ if isinstance(self.bart.Y, Variable):
+ self.Y = self.bart.Y.eval()
+ else:
+ self.Y = self.bart.Y
+
self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.response = self.bart.response
- shape = initial_values[value_bart.name].shape
+ shape = initial_point[value_bart.name].shape
self.shape = 1 if len(shape) == 1 else shape[0]
@@ -166,7 +184,7 @@ def __init__( # noqa: PLR0915
if rule is ContinuousSplitRule:
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))
- init_mean = self.bart.Y.mean()
+ init_mean = self.Y.mean()
self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
self.available_predictors = list(range(self.num_variates))
@@ -174,18 +192,18 @@ def __init__( # noqa: PLR0915
# if data is binary
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
- y_unique = np.unique(self.bart.Y)
+ y_unique = np.unique(self.Y)
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
self.leaf_sd *= 3 / self.m**0.5
else:
- self.leaf_sd *= self.bart.Y.std() / self.m**0.5
+ self.leaf_sd *= self.Y.std() / self.m**0.5
self.running_sd = [
RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
]
self.sum_trees = np.full(
- (self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
+ (self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean
).astype(config.floatX)
self.sum_trees_noi = self.sum_trees - init_mean
self.a_tree = Tree.new_tree(
@@ -209,8 +227,8 @@ def __init__( # noqa: PLR0915
self.num_particles = num_particles
self.indices = list(range(1, num_particles))
- shared = make_shared_replacements(initial_values, vars, model)
- self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
+ shared = make_shared_replacements(initial_point, vars, model)
+ self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared)
self.all_particles = [
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
]
@@ -222,7 +240,7 @@ def __init__( # noqa: PLR0915
def astep(self, _):
variable_inclusion = np.zeros(self.num_variates, dtype="int")
- upper = min(self.lower + self.batch[~self.tune], self.m)
+ upper = min(self.lower + self.batch[not self.tune], self.m)
tree_ids = range(self.lower, upper)
self.lower = upper if upper < self.m else 0
@@ -302,7 +320,7 @@ def astep(self, _):
stats = {"variable_inclusion": variable_inclusion, "tune": self.tune}
return self.sum_trees, [stats]
- def normalize(self, particles: List[ParticleTree]) -> float:
+ def normalize(self, particles: list[ParticleTree]) -> float:
"""
Use softmax to get normalized_weights.
"""
@@ -313,16 +331,16 @@ def normalize(self, particles: List[ParticleTree]) -> float:
return wei / wei.sum()
def resample(
- self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
- ) -> List[ParticleTree]:
+ self, particles: list[ParticleTree], normalized_weights: npt.NDArray
+ ) -> list[ParticleTree]:
"""
Use systematic resample for all but the first particle
Ensure particles are copied only if needed.
"""
new_indices = self.systematic(normalized_weights) + 1
- seen: List[int] = []
- new_particles: List[ParticleTree] = []
+ seen: list[int] = []
+ new_particles: list[ParticleTree] = []
for idx in new_indices:
if idx in seen:
new_particles.append(particles[idx].copy())
@@ -335,8 +353,8 @@ def resample(
return particles
def get_particle_tree(
- self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
- ) -> Tuple[ParticleTree, Tree]:
+ self, particles: list[ParticleTree], normalized_weights: npt.NDArray
+ ) -> tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
"""
@@ -347,7 +365,7 @@ def get_particle_tree(
return new_particle, new_particle.tree
- def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
+ def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
"""
Systematic resampling.
@@ -359,12 +377,12 @@ def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray
single_uniform = (self.uniform.rvs() + np.arange(lnw)) / lnw
return inverse_cdf(single_uniform, normalized_weights)
- def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]:
+ def init_particles(self, tree_id: int, odim: int) -> list[ParticleTree]:
"""Initialize particles."""
p0: ParticleTree = self.all_particles[odim][tree_id]
# The old tree does not grow so we update the weight only once
self.update_weight(p0, odim)
- particles: List[ParticleTree] = [p0]
+ particles: list[ParticleTree] = [p0]
particles.extend(ParticleTree(self.a_tree) for _ in self.indices)
return particles
@@ -383,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
particle.log_weight = new_likelihood
@staticmethod
- def competence(var, has_grad):
+ def competence(var: pm.Distribution, has_grad: bool) -> Competence:
"""PGBART is only suitable for BART distributions."""
dist = getattr(var.owner, "op", None)
if isinstance(dist, BARTRV):
@@ -394,12 +412,12 @@ def competence(var, has_grad):
class RunningSd:
"""Welford's online algorithm for computing the variance/standard deviation"""
- def __init__(self, shape: tuple) -> None:
+ def __init__(self, shape: tuple[int, ...]) -> None:
self.count = 0 # number of data points
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment
- def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+ def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
@@ -408,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
@njit
def _update(
count: int,
- mean: npt.NDArray[np.float64],
- m_2: npt.NDArray[np.float64],
- new_value: npt.NDArray[np.float64],
-) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
+ mean: npt.NDArray,
+ m_2: npt.NDArray,
+ new_value: npt.NDArray,
+) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
@@ -422,7 +440,7 @@ def _update(
class SampleSplittingVariable:
- def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
+ def __init__(self, alpha_vec: npt.NDArray) -> None:
"""
Sample splitting variables proportional to `alpha_vec`.
@@ -431,7 +449,7 @@ def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
"""
self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
- def rvs(self) -> Union[int, Tuple[int, float]]:
+ def rvs(self) -> Union[int, tuple[int, float]]:
rnd: float = np.random.random()
for i, val in self.enu:
if rnd <= val:
@@ -439,7 +457,7 @@ def rvs(self) -> Union[int, Tuple[int, float]]:
return self.enu[-1]
-def compute_prior_probability(alpha: int, beta: int) -> List[float]:
+def compute_prior_probability(alpha: int, beta: int) -> list[float]:
"""
Calculate the probability of the node being a leaf node (1 - p(being split node)).
@@ -452,7 +470,7 @@ def compute_prior_probability(alpha: int, beta: int) -> List[float]:
-------
list with probabilities for leaf nodes
"""
- prior_leaf_prob: List[float] = [0]
+ prior_leaf_prob: list[float] = [0]
depth = 0
while prior_leaf_prob[-1] < 0.9999:
prior_leaf_prob.append(1 - (alpha * ((1 + depth) ** (-beta))))
@@ -535,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
def draw_leaf_value(
- y_mu_pred: npt.NDArray[np.float64],
- x_mu: npt.NDArray[np.float64],
+ y_mu_pred: npt.NDArray,
+ x_mu: npt.NDArray,
m: int,
- norm: npt.NDArray[np.float64],
+ norm: npt.NDArray,
shape: int,
response: str,
-) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
+) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
- mu_mean = np.empty(shape)
+ mu_mean: npt.NDArray
if y_mu_pred.size == 0:
return np.zeros(shape), linear_params
@@ -559,7 +577,7 @@ def draw_leaf_value(
@njit
-def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
@@ -578,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
@njit
def fast_linear_fit(
- x: npt.NDArray[np.float64],
- y: npt.NDArray[np.float64],
+ x: npt.NDArray,
+ y: npt.NDArray,
m: int,
- norm: npt.NDArray[np.float64],
-) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
+ norm: npt.NDArray,
+) -> tuple[npt.NDArray, list[npt.NDArray]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)
@@ -666,17 +684,17 @@ def update(self):
@njit
def inverse_cdf(
- single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
+ single_uniform: npt.NDArray, normalized_weights: npt.NDArray
) -> npt.NDArray[np.int_]:
"""
Inverse CDF algorithm for a finite distribution.
Parameters
----------
- single_uniform: npt.NDArray[np.float64]
+ single_uniform: npt.NDArray
Ordered points in [0,1]
- normalized_weights: npt.NDArray[np.float64])
+ normalized_weights: npt.NDArray)
Normalized weights
Returns
@@ -699,7 +717,7 @@ def inverse_cdf(
@njit
-def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
+def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
"""
Jitter duplicated values.
"""
@@ -715,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
@njit
-def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
+def are_whole_number(array: npt.NDArray) -> np.bool_:
"""Check if all values in array are whole numbers"""
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
-def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
+def logp(
+ point,
+ out_vars: list[pm.Distribution],
+ vars: list[pm.Distribution],
+ shared: list[pt.TensorVariable],
+):
"""Compile PyTensor function of the model and the input and output variables.
Parameters
diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py
index 0e0a35c..61e5050 100644
--- a/pymc_bart/tree.py
+++ b/pymc_bart/tree.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections.abc import Generator
from functools import lru_cache
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Optional, Union
import numpy as np
import numpy.typing as npt
@@ -27,21 +28,21 @@ class Node:
Attributes
----------
- value : npt.NDArray[np.float64]
+ value : npt.NDArray
idx_data_points : Optional[npt.NDArray[np.int_]]
idx_split_variable : int
- linear_params: Optional[List[float]] = None
+ linear_params: Optional[list[float]] = None
"""
__slots__ = "value", "nvalue", "idx_split_variable", "idx_data_points", "linear_params"
def __init__(
self,
- value: npt.NDArray[np.float64] = np.array([-1.0]),
+ value: npt.NDArray = np.array([-1.0]),
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
- linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
+ linear_params: Optional[list[npt.NDArray]] = None,
) -> None:
self.value = value
self.nvalue = nvalue
@@ -52,11 +53,11 @@ def __init__(
@classmethod
def new_leaf_node(
cls,
- value: npt.NDArray[np.float64],
+ value: npt.NDArray,
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
- linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
+ linear_params: Optional[list[npt.NDArray]] = None,
) -> "Node":
return cls(
value=value,
@@ -94,19 +95,19 @@ class Tree:
Attributes
----------
- tree_structure : Dict[int, Node]
+ tree_structure : dict[int, Node]
A dictionary that represents the nodes stored in breadth-first order, based in the array
method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
The dictionary's keys are integers that represent the nodes position.
The dictionary's values are objects of type Node that represent the split and leaf nodes
of the tree itself.
- output: Optional[npt.NDArray[np.float64]]
+ output: Optional[npt.NDArray]
Array of shape number of observations, shape
- split_rules : List[SplitRule]
+ split_rules : list[SplitRule]
List of SplitRule objects, one per column in input data.
Allows using different split rules for different columns. Default is ContinuousSplitRule.
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
- idx_leaf_nodes : Optional[List[int]], by default None.
+ idx_leaf_nodes : Optional[list[int]], by default None.
Array with the index of the leaf nodes of the tree.
Parameters
@@ -120,10 +121,10 @@ class Tree:
def __init__(
self,
- tree_structure: Dict[int, Node],
- output: npt.NDArray[np.float64],
- split_rules: List[SplitRule],
- idx_leaf_nodes: Optional[List[int]] = None,
+ tree_structure: dict[int, Node],
+ output: npt.NDArray,
+ split_rules: list[SplitRule],
+ idx_leaf_nodes: Optional[list[int]] = None,
) -> None:
self.tree_structure = tree_structure
self.idx_leaf_nodes = idx_leaf_nodes
@@ -133,11 +134,11 @@ def __init__(
@classmethod
def new_tree(
cls,
- leaf_node_value: npt.NDArray[np.float64],
+ leaf_node_value: npt.NDArray,
idx_data_points: Optional[npt.NDArray[np.int_]],
num_observations: int,
shape: int,
- split_rules: List[SplitRule],
+ split_rules: list[SplitRule],
) -> "Tree":
return cls(
tree_structure={
@@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None:
self.set_node(index, node)
def copy(self) -> "Tree":
- tree: Dict[int, Node] = {
+ tree: dict[int, Node] = {
k: Node(
value=v.value,
nvalue=v.nvalue,
@@ -189,7 +190,7 @@ def grow_leaf_node(
self,
current_node: Node,
selected_predictor: int,
- split_value: npt.NDArray[np.float64],
+ split_value: npt.NDArray,
index_leaf_node: int,
) -> None:
current_node.value = split_value
@@ -199,7 +200,7 @@ def grow_leaf_node(
self.idx_leaf_nodes.remove(index_leaf_node)
def trim(self) -> "Tree":
- tree: Dict[int, Node] = {
+ tree: dict[int, Node] = {
k: Node(
value=v.value,
nvalue=v.nvalue,
@@ -221,7 +222,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
if node.is_split_node():
yield node.idx_split_variable
- def _predict(self) -> npt.NDArray[np.float64]:
+ def _predict(self) -> npt.NDArray:
output = self.output
if self.idx_leaf_nodes is not None:
@@ -232,23 +233,23 @@ def _predict(self) -> npt.NDArray[np.float64]:
def predict(
self,
- x: npt.NDArray[np.float64],
- excluded: Optional[List[int]] = None,
+ x: npt.NDArray,
+ excluded: Optional[list[int]] = None,
shape: int = 1,
- ) -> npt.NDArray[np.float64]:
+ ) -> npt.NDArray:
"""
Predict output of tree for an (un)observed point x.
Parameters
----------
- x : npt.NDArray[np.float64]
+ x : npt.NDArray
Unobserved point
- excluded: Optional[List[int]]
+ excluded: Optional[list[int]]
Indexes of the variables to exclude when computing predictions
Returns
-------
- npt.NDArray[np.float64]
+ npt.NDArray
Value of the leaf value where the unobserved point lies.
"""
if excluded is None:
@@ -258,34 +259,36 @@ def predict(
def _traverse_tree(
self,
- X: npt.NDArray[np.float64],
- excluded: Optional[List[int]] = None,
- shape: Union[int, Tuple[int, ...]] = 1,
- ) -> npt.NDArray[np.float64]:
+ X: npt.NDArray,
+ excluded: Optional[list[int]] = None,
+ shape: Union[int, tuple[int, ...]] = 1,
+ ) -> npt.NDArray:
"""
Traverse the tree starting from the root node given an (un)observed point.
Parameters
----------
- X : npt.NDArray[np.float64]
+ X : npt.NDArray
(Un)observed point(s)
node_index : int
Index of the node to start the traversal from
split_variable : int
Index of the variable used to split the node
- excluded: Optional[List[int]]
+ excluded: Optional[list[int]]
Indexes of the variables to exclude when computing predictions
Returns
-------
- npt.NDArray[np.float64]
+ npt.NDArray
Leaf node value or mean of leaf node values
"""
x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1]
nd_dims = (...,) + (None,) * len(x_shape)
- stack = [(0, np.ones(x_shape), 0)] # (node_index, weight, idx_split_variable) initial state
+ stack: list[tuple[int, npt.NDArray, int]] = [
+ (0, np.ones(x_shape), 0)
+ ] # (node_index, weight, idx_split_variable) initial state
p_d = (
np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape)
)
@@ -308,9 +311,19 @@ def _traverse_tree(
)
if excluded is not None and idx_split_variable in excluded:
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
- stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
stack.append(
- (right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable)
+ (
+ left_node_index,
+ weights * prop_nvalue_left,
+ idx_split_variable,
+ )
+ )
+ stack.append(
+ (
+ right_node_index,
+ weights * (1 - prop_nvalue_left),
+ idx_split_variable,
+ )
)
else:
to_left = (
@@ -327,14 +340,14 @@ def _traverse_tree(
return p_d
def _traverse_leaf_values(
- self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int
+ self, leaf_values: list[npt.NDArray], leaf_n_values: list[int], node_index: int
) -> None:
"""
Traverse the tree appending leaf values starting from a particular node.
Parameters
----------
- leaf_values : List[npt.NDArray[np.float64]]
+ leaf_values : list[npt.NDArray]
node_index : int
"""
node = self.get_node(node_index)
diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py
index 9eee3b4..3ba6e58 100644
--- a/pymc_bart/utils.py
+++ b/pymc_bart/utils.py
@@ -1,31 +1,33 @@
+# pylint: disable=too-many-branches
"""Utility function for variable selection and bart interpretability."""
import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
+from numba import jit
from pytensor.tensor.variable import Variable
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
-from scipy.stats import norm, pearsonr
+from scipy.stats import norm
from .tree import Tree
-TensorLike = Union[npt.NDArray[np.float64], pt.TensorVariable]
+TensorLike = Union[npt.NDArray, pt.TensorVariable]
def _sample_posterior(
- all_trees: List[List[Tree]],
+ all_trees: list[list[Tree]],
X: TensorLike,
rng: np.random.Generator,
- size: Optional[Union[int, Tuple[int, ...]]] = None,
- excluded: Optional[List[int]] = None,
+ size: Optional[Union[int, tuple[int, ...]]] = None,
+ excluded: Optional[list[int]] = None,
shape: int = 1,
-) -> npt.NDArray[np.float64]:
+) -> npt.NDArray:
"""
Generate samples from the BART-posterior.
@@ -48,7 +50,7 @@ def _sample_posterior(
X = X.eval()
if size is None:
- size_iter: Union[List, Tuple] = (1,)
+ size_iter: Union[list, tuple] = (1,)
elif isinstance(size, int):
size_iter = [size]
else:
@@ -77,9 +79,9 @@ def plot_convergence(
idata: az.InferenceData,
var_name: Optional[str] = None,
kind: str = "ecdf",
- figsize: Optional[Tuple[float, float]] = None,
+ figsize: Optional[tuple[float, float]] = None,
ax=None,
-) -> List[plt.Axes]:
+) -> list[plt.Axes]:
"""
Plot convergence diagnostics.
@@ -91,14 +93,14 @@ def plot_convergence(
Name of the BART variable to plot. Defaults to None.
kind : str
Type of plot to display. Options are "ecdf" (default) and "kde".
- figsize : Optional[Tuple[float, float]], by default None.
+ figsize : Optional[tuple[float, float]], by default None.
Figure size. Defaults to None.
ax : matplotlib axes
Axes on which to plot. Defaults to None.
Returns
-------
- List[ax] : matplotlib axes
+ list[ax] : matplotlib axes
"""
ess_threshold = idata["posterior"]["chain"].size * 100
ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values)
@@ -135,28 +137,12 @@ def plot_convergence(
return ax
-def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument
- """
- Partial dependence or individual conditional expectation plot.
- """
- if kind == "pdp":
- warnings.warn(
- "This function has been deprecated. Use plot_pdp instead.",
- FutureWarning,
- )
- elif kind == "ice":
- warnings.warn(
- "This function has been deprecated. Use plot_ice instead.",
- FutureWarning,
- )
-
-
def plot_ice(
bartrv: Variable,
- X: npt.NDArray[np.float64],
- Y: Optional[npt.NDArray[np.float64]] = None,
- var_idx: Optional[List[int]] = None,
- var_discrete: Optional[List[int]] = None,
+ X: npt.NDArray,
+ Y: Optional[npt.NDArray] = None,
+ var_idx: Optional[list[int]] = None,
+ var_discrete: Optional[list[int]] = None,
func: Optional[Callable] = None,
centered: Optional[bool] = True,
samples: int = 100,
@@ -168,10 +154,10 @@ def plot_ice(
color="C0",
color_mean: str = "C0",
alpha: float = 0.1,
- figsize: Optional[Tuple[float, float]] = None,
- smooth_kwargs: Optional[Dict[str, Any]] = None,
+ figsize: Optional[tuple[float, float]] = None,
+ smooth_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
-) -> List[plt.Axes]:
+) -> list[plt.Axes]:
"""
Individual conditional expectation plot.
@@ -179,13 +165,13 @@ def plot_ice(
----------
bartrv : BART Random Variable
BART variable once the model that include it has been fitted.
- X : npt.NDArray[np.float64]
+ X : npt.NDArray
The covariate matrix.
- Y : Optional[npt.NDArray[np.float64]], by default None.
+ Y : Optional[npt.NDArray], by default None.
The response vector.
- var_idx : Optional[List[int]], by default None.
+ var_idx : Optional[list[int]], by default None.
List of the indices of the covariate for which to compute the pdp or ice.
- var_discrete : Optional[List[int]], by default None.
+ var_discrete : Optional[list[int]], by default None.
List of the indices of the covariate treated as discrete.
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
@@ -247,7 +233,7 @@ def identity(x):
_,
) = _prepare_plot_data(X, Y, "linear", None, var_idx, var_discrete)
- fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax)
+ fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax)
instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances)
idx_s = list(range(X.shape[0]))
@@ -268,14 +254,13 @@ def identity(x):
)
new_x = fake_X[:, var]
- p_d = np.array(y_pred)
- print(p_d.shape)
+ p_d = func(np.array(y_pred))
for s_i in range(shape):
if centered:
- p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None])
+ p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None]
else:
- p_di = func(p_d[:, :, s_i])
+ p_di = p_d[:, :, s_i]
if var in var_discrete:
axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean)
axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha)
@@ -298,14 +283,15 @@ def identity(x):
def plot_pdp(
bartrv: Variable,
- X: npt.NDArray[np.float64],
- Y: Optional[npt.NDArray[np.float64]] = None,
+ X: npt.NDArray,
+ Y: Optional[npt.NDArray] = None,
xs_interval: str = "quantiles",
- xs_values: Optional[Union[int, List[float]]] = None,
- var_idx: Optional[List[int]] = None,
- var_discrete: Optional[List[int]] = None,
+ xs_values: Optional[Union[int, list[float]]] = None,
+ var_idx: Optional[list[int]] = None,
+ var_discrete: Optional[list[int]] = None,
func: Optional[Callable] = None,
samples: int = 200,
+ ref_line: bool = True,
random_seed: Optional[int] = None,
sharey: bool = True,
smooth: bool = True,
@@ -313,10 +299,10 @@ def plot_pdp(
color="C0",
color_mean: str = "C0",
alpha: float = 0.1,
- figsize: Optional[Tuple[float, float]] = None,
- smooth_kwargs: Optional[Dict[str, Any]] = None,
+ figsize: Optional[tuple[float, float]] = None,
+ smooth_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
-) -> List[plt.Axes]:
+) -> list[plt.Axes]:
"""
Partial dependence plot.
@@ -324,28 +310,30 @@ def plot_pdp(
----------
bartrv : BART Random Variable
BART variable once the model that include it has been fitted.
- X : npt.NDArray[np.float64]
+ X : npt.NDArray
The covariate matrix.
- Y : Optional[npt.NDArray[np.float64]], by default None.
+ Y : Optional[npt.NDArray], by default None.
The response vector.
xs_interval : str
Method used to compute the values X used to evaluate the predicted function. "linear",
evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
quantiles of X. "insample", the evaluation is done at the values of X.
For discrete variables these options are ommited.
- xs_values : Optional[Union[int, List[float]]], by default None.
+ xs_values : Optional[Union[int, list[float]]], by default None.
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of
quantiles to compute, which must be between 0 and 1 inclusive.
Ignored when ``xs_interval="insample"``.
- var_idx : Optional[List[int]], by default None.
+ var_idx : Optional[list[int]], by default None.
List of the indices of the covariate for which to compute the pdp or ice.
- var_discrete : Optional[List[int]], by default None.
+ var_discrete : Optional[list[int]], by default None.
List of the indices of the covariate treated as discrete.
func : Optional[Callable], by default None.
Arbitrary function to apply to the predictions. Defaults to the identity function.
samples : int
Number of posterior samples used in the predictions. Defaults to 200
+ ref_line : bool
+ If True a reference line is plotted at the mean of the partial dependence. Defaults to True.
random_seed : Optional[int], by default None.
Seed used to sample from the posterior. Defaults to None.
sharey : bool
@@ -397,21 +385,26 @@ def identity(x):
xs_values,
) = _prepare_plot_data(X, Y, xs_interval, xs_values, var_idx, var_discrete)
- fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax)
+ fig, axes, shape = _create_figure_axes(bartrv, var_idx, grid, sharey, figsize, ax)
count = 0
fake_X = _create_pdp_data(X, xs_interval, xs_values)
+ null_pd = []
for var in range(len(var_idx)):
excluded = indices[:]
excluded.remove(var)
- p_d = _sample_posterior(
- all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
+ p_d = func(
+ _sample_posterior(
+ all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
+ )
)
+
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
new_x = fake_X[:, var]
for s_i in range(shape):
- p_di = func(p_d[:, :, s_i])
+ p_di = p_d[:, :, s_i]
+ null_pd.append(p_di.mean())
if var in var_discrete:
_, idx_uni = np.unique(new_x, return_index=True)
y_means = p_di.mean(0)[idx_uni]
@@ -441,19 +434,24 @@ def identity(x):
count += 1
+ if ref_line:
+ ref_val = sum(null_pd) / len(null_pd)
+ for ax_ in np.ravel(axes):
+ ax_.axhline(ref_val, color="0.7", linestyle="--")
+
fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15)
return axes
-def _get_axes(
+def _create_figure_axes(
bartrv: Variable,
- var_idx: List[int],
+ var_idx: list[int],
grid: str = "long",
sharey: bool = True,
- figsize: Optional[Tuple[float, float]] = None,
+ figsize: Optional[tuple[float, float]] = None,
ax: Optional[plt.Axes] = None,
-) -> Tuple[plt.Figure, List[plt.Axes], int]:
+) -> tuple[plt.Figure, list[plt.Axes], int]:
"""
Create and return the figure and axes objects for plotting the variables.
@@ -463,9 +461,9 @@ def _get_axes(
----------
bartrv : BART Random Variable
BART variable once the model that include it has been fitted.
- var_idx : Optional[List[int]], by default None.
+ var_idx : Optional[list[int]], by default None.
List of the indices of the covariate for which to compute the pdp or ice.
- var_discrete : Optional[List[int]], by default None.
+ var_discrete : Optional[list[int]], by default None.
grid : str or tuple
How to arrange the subplots. Defaults to "long", one subplot below the other.
Other options are "wide", one subplot next to each other or a tuple indicating the number of
@@ -480,7 +478,7 @@ def _get_axes(
Returns
-------
- Tuple[plt.Figure, List[plt.Axes], int]
+ tuple[plt.Figure, list[plt.Axes], int]
A tuple containing the figure object, list of axes objects, and the shape value.
"""
if bartrv.ndim == 1: # type: ignore
@@ -491,29 +489,8 @@ def _get_axes(
n_plots = len(var_idx) * shape
if ax is None:
- if grid == "long":
- fig, axes = plt.subplots(n_plots, sharey=sharey, figsize=figsize)
- if n_plots == 1:
- axes = [axes]
- elif grid == "wide":
- fig, axes = plt.subplots(1, n_plots, sharey=sharey, figsize=figsize)
- if n_plots == 1:
- axes = [axes]
- elif isinstance(grid, tuple):
- grid_size = grid[0] * grid[1]
- if n_plots > grid_size:
- warnings.warn(
- """The grid is smaller than the number of available variables to plot.
- Automatically adjusting the grid size."""
- )
- grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1])
+ fig, axes = _get_axes(grid, n_plots, False, sharey, figsize)
- fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize)
- axes = np.ravel(axes)
-
- for i in range(n_plots, len(axes)):
- fig.delaxes(axes[i])
- axes = axes[:n_plots]
elif isinstance(ax, np.ndarray):
axes = ax
fig = ax[0].get_figure()
@@ -524,22 +501,49 @@ def _get_axes(
return fig, axes, shape
+def _get_axes(grid, n_plots, sharex, sharey, figsize):
+ if grid == "long":
+ fig, axes = plt.subplots(n_plots, sharex=sharex, sharey=sharey, figsize=figsize)
+ if n_plots == 1:
+ axes = [axes]
+ elif grid == "wide":
+ fig, axes = plt.subplots(1, n_plots, sharex=sharex, sharey=sharey, figsize=figsize)
+ if n_plots == 1:
+ axes = [axes]
+ elif isinstance(grid, tuple):
+ grid_size = grid[0] * grid[1]
+ if n_plots > grid_size:
+ warnings.warn(
+ """The grid is smaller than the number of available variables to plot.
+ Automatically adjusting the grid size."""
+ )
+ grid = (n_plots // grid[1] + (n_plots % grid[1] > 0), grid[1])
+
+ fig, axes = plt.subplots(*grid, sharey=sharey, figsize=figsize)
+ axes = np.ravel(axes)
+
+ for i in range(n_plots, len(axes)):
+ fig.delaxes(axes[i])
+ axes = axes[:n_plots]
+ return fig, axes
+
+
def _prepare_plot_data(
- X: npt.NDArray[np.float64],
- Y: Optional[npt.NDArray[np.float64]] = None,
+ X: npt.NDArray,
+ Y: Optional[npt.NDArray] = None,
xs_interval: str = "quantiles",
- xs_values: Optional[Union[int, List[float]]] = None,
- var_idx: Optional[List[int]] = None,
- var_discrete: Optional[List[int]] = None,
-) -> Tuple[
- npt.NDArray[np.float64],
- List[str],
+ xs_values: Optional[Union[int, list[float]]] = None,
+ var_idx: Optional[list[int]] = None,
+ var_discrete: Optional[list[int]] = None,
+) -> tuple[
+ npt.NDArray,
+ list[str],
str,
- List[int],
- List[int],
- List[int],
+ list[int],
+ list[int],
+ list[int],
str,
- Union[int, None, List[float]],
+ Union[int, None, list[float]],
]:
"""
Prepare data for plotting.
@@ -618,10 +622,10 @@ def _prepare_plot_data(
def _create_pdp_data(
- X: npt.NDArray[np.float64],
+ X: npt.NDArray,
xs_interval: str,
- xs_values: Optional[Union[int, List[float]]] = None,
-) -> npt.NDArray[np.float64]:
+ xs_values: Optional[Union[int, list[float]]] = None,
+) -> npt.NDArray:
"""
Create data for partial dependence plot.
@@ -636,7 +640,7 @@ def _create_pdp_data(
Returns
-------
- npt.NDArray[np.float64]
+ npt.NDArray
A 2D array for the fake_X data.
"""
if xs_interval == "insample":
@@ -653,11 +657,11 @@ def _create_pdp_data(
def _smooth_mean(
- new_x: npt.NDArray[np.float64],
- p_di: npt.NDArray[np.float64],
+ new_x: npt.NDArray,
+ p_di: npt.NDArray,
kind: str = "pdp",
- smooth_kwargs: Optional[Dict[str, Any]] = None,
-) -> Tuple[np.ndarray, np.ndarray]:
+ smooth_kwargs: Optional[dict[str, Any]] = None,
+) -> tuple[np.ndarray, np.ndarray]:
"""
Smooth the mean data for plotting.
@@ -669,12 +673,12 @@ def _smooth_mean(
The distribution of partial dependence from which to comptue the smoothed mean.
kind : str, optional
The type of plot. Possible values are "pdp" or "ice".
- smooth_kwargs : Optional[Dict[str, Any]], optional
+ smooth_kwargs : Optional[dict[str, Any]], optional
Additional keyword arguments for the smoothing function. Defaults to None.
Returns
-------
- Tuple[np.ndarray, np.ndarray]
+ tuple[np.ndarray, np.ndarray]
A tuple containing a grid for the x-axis data and the corresponding smoothed y-axis data.
"""
@@ -692,55 +696,154 @@ def _smooth_mean(
return x_data, y_data
-def plot_variable_importance( # noqa: PLR0915
+def get_variable_inclusion(idata, X, labels=None, to_kulprit=False):
+ """
+ Get the normalized variable inclusion from BART model.
+
+ Parameters
+ ----------
+ idata : InferenceData
+ InferenceData containing a collection of BART_trees in sample_stats group
+ X : npt.NDArray
+ The covariate matrix.
+ labels : Optional[list[str]]
+ List of the names of the covariates. If X is a DataFrame the names of the covariables will
+ be taken from it and this argument will be ignored.
+ to_kulprit : bool
+ If True, the function will return a list of list with the variables names.
+ This list can be passed as a path to Kulprit's project method. Defaults to False.
+ Returns
+ -------
+ VI_norm : npt.NDArray
+ Normalized variable inclusion.
+ labels : list[str]
+ List of the names of the covariates.
+ """
+ VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
+ VI_norm = VIs / VIs.sum()
+ idxs = np.argsort(VI_norm)
+
+ indices = idxs[::-1]
+ n_vars = len(indices)
+
+ if hasattr(X, "columns") and hasattr(X, "to_numpy"):
+ labels = X.columns
+
+ if labels is None:
+ labels = np.arange(n_vars).astype(str)
+
+ label_list = labels.to_list()
+
+ if to_kulprit:
+ return [label_list[:idx] for idx in range(n_vars)]
+ else:
+ return VI_norm[indices], label_list
+
+
+def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None):
+ """
+ Plot normalized variable inclusion from BART model.
+
+ Parameters
+ ----------
+ idata : InferenceData
+ InferenceData containing a collection of BART_trees in sample_stats group
+ X : npt.NDArray
+ The covariate matrix.
+ labels : Optional[list[str]]
+ List of the names of the covariates. If X is a DataFrame the names of the covariables will
+ be taken from it and this argument will be ignored.
+ figsize : tuple
+ Figure size. If None it will be defined automatically.
+ plot_kwargs : dict
+ Additional keyword arguments for the plot. Defaults to None.
+ Valid keys are:
+ - color: matplotlib valid color for VI
+ - marker: matplotlib valid marker for VI
+ - ls: matplotlib valid linestyle for the VI line
+ - rotation: float, rotation of the x-axis labels
+ ax : axes
+ Matplotlib axes.
+
+ Returns
+ -------
+ axes: matplotlib axes
+ """
+ if plot_kwargs is None:
+ plot_kwargs = {}
+
+ VI_norm, labels = get_variable_inclusion(idata, X, labels)
+ n_vars = len(labels)
+
+ new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
+
+ ticks = np.arange(n_vars, dtype=int)
+
+ if figsize is None:
+ figsize = (8, 3)
+
+ if ax is None:
+ _, ax = plt.subplots(1, 1, figsize=figsize)
+
+ ax.axhline(1 / n_vars, color="0.5", linestyle="--")
+ ax.plot(
+ VI_norm,
+ color=plot_kwargs.get("color", "k"),
+ marker=plot_kwargs.get("marker", "o"),
+ ls=plot_kwargs.get("ls", "-"),
+ )
+
+ ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0))
+ ax.set_ylim(0, 1)
+
+ return ax
+
+
+def compute_variable_importance( # noqa: PLR0915 PLR0912
idata: az.InferenceData,
bartrv: Variable,
- X: npt.NDArray[np.float64],
- labels: Optional[List[str]] = None,
+ X: npt.NDArray,
method: str = "VI",
- figsize: Optional[Tuple[float, float]] = None,
- xlabel_angle: float = 0,
- samples: int = 100,
+ fixed: int = 0,
+ samples: int = 50,
random_seed: Optional[int] = None,
- ax: Optional[plt.Axes] = None,
-) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
+) -> dict[str, object]:
"""
Estimates variable importance from the BART-posterior.
Parameters
----------
- idata: InferenceData
+ idata : InferenceData
InferenceData containing a collection of BART_trees in sample_stats group
bartrv : BART Random Variable
BART variable once the model that include it has been fitted.
- X : npt.NDArray[np.float64]
+ X : npt.NDArray
The covariate matrix.
- labels : Optional[List[str]]
- List of the names of the covariates. If X is a DataFrame the names of the covariables will
- be taken from it and this argument will be ignored.
method : str
- Method used to rank variables. Available options are "VI" (default) and "backward".
+ Method used to rank variables. Available options are "VI" (default), "backward"
+ and "backward_VI".
The R squared will be computed following this ranking.
"VI" counts how many times each variable is included in the posterior distribution
of trees. "backward" uses a backward search based on the R squared.
- VI requieres less computation time.
- figsize : tuple
- Figure size. If None it will be defined automatically.
- xlabel_angle : float
- rotation angle of the x-axis labels. Defaults to 0. Use values like 45 for
- long labels and/or many variables.
+ "backward_VI" combines both methods with the backward search excluding
+ the ``fixed`` number of variables with the lowest variable inclusion.
+ "VI" is the fastest method, while "backward" is the slowest.
+ fixed : Optional[int]
+ Number of variables to fix in the backward search. Defaults to None.
+ Must be greater than 0 and less than the number of variables.
+ Ignored if method is "VI" or "backward".
samples : int
- Number of predictions used to compute correlation for subsets of variables. Defaults to 100
+ Number of predictions used to compute correlation for subsets of variables. Defaults to 50
random_seed : Optional[int]
random_seed used to sample from the posterior. Defaults to None.
- ax : axes
- Matplotlib axes.
Returns
-------
- idxs: indexes of the covariates from higher to lower relative importance
- axes: matplotlib axes
+ vi_results: dictionary
"""
+ if method not in ["VI", "backward", "backward_VI"]:
+ raise ValueError("method must be 'VI', 'backward' or 'backward_VI'")
+
rng = np.random.default_rng(random_seed)
all_trees = bartrv.owner.op.all_trees
@@ -750,40 +853,44 @@ def plot_variable_importance( # noqa: PLR0915
else:
shape = bartrv.eval().shape[0]
+ n_vars = X.shape[1]
+
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
X = X.to_numpy()
-
- n_vars = X.shape[1]
-
- if figsize is None:
- figsize = (8, 3)
-
- if ax is None:
- _, ax = plt.subplots(1, 1, figsize=figsize)
-
- if labels is None:
- labels_ary = np.arange(n_vars).astype(str)
else:
- labels_ary = np.array(labels)
-
- ticks = np.arange(n_vars, dtype=int)
+ labels = np.arange(n_vars).astype(str)
+
+ r2_mean: npt.NDArray = np.zeros(n_vars)
+ r2_hdi: npt.NDArray = np.zeros((n_vars, 2))
+ preds: npt.NDArray = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
+
+ if method == "backward_VI":
+ if fixed >= n_vars:
+ raise ValueError("fixed must be less than the number of variables")
+ elif fixed < 1:
+ raise ValueError("fixed must be greater than 0")
+ init = fixed + 1
+ else:
+ fixed = 0
+ init = 0
predicted_all = _sample_posterior(
all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
)
- if method == "VI":
+ if method in ["VI", "backward_VI"]:
idxs = np.argsort(
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
)
- subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
+ subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))]
subsets.append(None) # type: ignore
- indices: List[int] = list(idxs[::-1])
+ if method == "backward_VI":
+ subsets = subsets[-init:]
+
+ indices: list[int] = list(idxs[::-1])
- r2_mean = np.zeros(n_vars)
- r2_hdi = np.zeros((n_vars, 2))
for idx, subset in enumerate(subsets):
predicted_subset = _sample_posterior(
all_trees=all_trees,
@@ -794,26 +901,28 @@ def plot_variable_importance( # noqa: PLR0915
shape=shape,
)
r_2 = np.array(
- [
- pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
- for j in range(samples)
- ]
+ [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
)
r2_mean[idx] = np.mean(r_2)
r2_hdi[idx] = az.hdi(r_2)
-
- elif method == "backward":
- r2_mean = np.zeros(n_vars)
- r2_hdi = np.zeros((n_vars, 2))
-
- variables = set(range(n_vars))
- least_important_vars: List[int] = []
- indices = []
+ preds[idx] = predicted_subset.squeeze()
+
+ if method in ["backward", "backward_VI"]:
+ if method == "backward_VI":
+ least_important_vars: list[int] = indices[-fixed:]
+ r2_mean_vi = r2_mean[:init]
+ r2_hdi_vi = r2_hdi[:init]
+ preds_vi = preds[:init]
+ r2_mean = np.zeros(n_vars - fixed - 1)
+ r2_hdi = np.zeros((n_vars - fixed - 1, 2))
+ preds = np.zeros((n_vars - fixed - 1, samples, bartrv.eval().shape[0]))
+ else:
+ least_important_vars = []
# Iterate over each variable to determine its contribution
# least_important_vars tracks the variable with the lowest contribution
- # at the current stage. One new varible is added at each iteration.
- for i_var in range(n_vars):
+ # at the current stage. One new variable is added at each iteration.
+ for i_var in range(init, n_vars):
# Generate all possible subsets by adding one variable at a time to
# least_important_vars
subsets = generate_sequences(n_vars, i_var, least_important_vars)
@@ -833,10 +942,7 @@ def plot_variable_importance( # noqa: PLR0915
# Calculate Pearson correlation for each sample and find the mean
r_2 = np.zeros(samples)
for j in range(samples):
- r_2[j] = (
- (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0])
- ** 2
- )
+ r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j])
mean_r_2 = np.mean(r_2, dtype=float)
# Identify the least important combination of variables
# based on the maximum mean squared Pearson correlation
@@ -844,43 +950,266 @@ def plot_variable_importance( # noqa: PLR0915
max_r_2 = mean_r_2
least_important_subset = subset
r_2_without_least_important_vars = r_2
+ least_important_samples = predicted_subset
# Save values for plotting later
- r2_mean[i_var] = max_r_2
- r2_hdi[i_var] = az.hdi(r_2_without_least_important_vars)
+ r2_mean[i_var - init] = max_r_2
+ r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars)
+ preds[i_var - init] = least_important_samples.squeeze()
# extend current list of least important variable
- least_important_vars += least_important_subset
+ for var_i in least_important_subset:
+ if var_i not in least_important_vars:
+ least_important_vars.append(var_i)
+
+ # Add the remaining variables to the list of least important variables
+ for var_i in range(n_vars):
+ if var_i not in least_important_vars:
+ least_important_vars.append(var_i)
+
+ if method == "backward_VI":
+ r2_mean = np.concatenate((r2_mean[::-1], r2_mean_vi))
+ r2_hdi = np.concatenate((r2_hdi[::-1], r2_hdi_vi))
+ preds = np.concatenate((preds[::-1], preds_vi))
+ else:
+ r2_mean = r2_mean[::-1]
+ r2_hdi = r2_hdi[::-1]
+ preds = preds[::-1]
+
+ indices = least_important_vars[::-1]
+
+ labels = np.array(
+ ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
+ )
- # add index of removed variable
- indices += list(set(least_important_subset) - set(indices))
+ vi_results = {
+ "indices": np.asarray(indices),
+ "labels": labels,
+ "r2_mean": r2_mean,
+ "r2_hdi": r2_hdi,
+ "preds": preds,
+ "preds_all": predicted_all.squeeze(),
+ }
+ return vi_results
+
+
+def plot_variable_importance(
+ vi_results: dict,
+ submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,
+ labels: Optional[list[str]] = None,
+ figsize: Optional[tuple[float, float]] = None,
+ plot_kwargs: Optional[dict[str, Any]] = None,
+ ax: Optional[plt.Axes] = None,
+):
+ """
+ Estimates variable importance from the BART-posterior.
+
+ Parameters
+ ----------
+ vi_results: Dictionary
+ Dictionary computed with `compute_variable_importance`
+ submodels : Optional[Union[list[int], np.ndarray]]
+ List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
+ The indices correspond to order computed by `compute_variable_importance`.
+ For example `submodels=[0,1]` will plot the two most important variables.
+ `submodels=[1,0]` is equivalent as values are sorted before use.
+ labels : Optional[list[str]]
+ List of the names of the covariates. If X is a DataFrame the names of the covariables will
+ be taken from it and this argument will be ignored.
+ plot_kwargs : dict
+ Additional keyword arguments for the plot. Defaults to None.
+ Valid keys are:
+ - color_r2: matplotlib valid color for error bars
+ - marker_r2: matplotlib valid marker for the mean R squared
+ - marker_fc_r2: matplotlib valid marker face color for the mean R squared
+ - ls_ref: matplotlib valid linestyle for the reference line
+ - color_ref: matplotlib valid color for the reference line
+ - rotation: float, rotation angle of the x-axis labels. Defaults to 0.
+ ax : axes
+ Matplotlib axes.
+
+ Returns
+ -------
+ axes: matplotlib axes
+ """
+ if submodels is None:
+ submodels = np.sort(vi_results["indices"])
+ else:
+ submodels = np.sort(submodels)
+
+ indices = vi_results["indices"][submodels]
+ r2_mean = vi_results["r2_mean"][submodels]
+ r2_hdi = vi_results["r2_hdi"][submodels]
+ preds = vi_results["preds"][submodels]
+ preds_all = vi_results["preds_all"]
+ samples = preds.shape[1]
+
+ n_vars = len(indices)
+ ticks = np.arange(n_vars, dtype=int)
+
+ if plot_kwargs is None:
+ plot_kwargs = {}
+
+ if figsize is None:
+ figsize = (8, 3)
- # add remaining index
- indices += list(set(variables) - set(least_important_vars))
+ if ax is None:
+ _, ax = plt.subplots(1, 1, figsize=figsize)
- indices = indices[::-1]
- r2_mean = r2_mean[::-1]
- r2_hdi = r2_hdi[::-1]
+ if labels is None:
+ labels = vi_results["labels"][submodels]
- new_labels = [
- "+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices])
- ]
+ r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])
r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None)
r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None)
+
ax.errorbar(
ticks,
r2_mean,
np.array((r2_yerr_min, r2_yerr_max)),
- color="C0",
+ color=plot_kwargs.get("color_r2", "k"),
+ fmt=plot_kwargs.get("marker_r2", "o"),
+ mfc=plot_kwargs.get("marker_fc_r2", "white"),
+ )
+ ax.axhline(
+ np.mean(r_2_ref),
+ ls=plot_kwargs.get("ls_ref", "--"),
+ color=plot_kwargs.get("color_ref", "grey"),
+ )
+ ax.fill_between(
+ [-0.5, n_vars - 0.5],
+ *az.hdi(r_2_ref),
+ alpha=0.1,
+ color=plot_kwargs.get("color_ref", "grey"),
+ )
+ ax.set_xticks(
+ ticks,
+ labels,
+ rotation=plot_kwargs.get("rotation", 0),
)
- ax.axhline(r2_mean[-1], ls="--", color="0.5")
- ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
ax.set_ylabel("R²", rotation=0, labelpad=12)
ax.set_ylim(0, 1)
ax.set_xlim(-0.5, n_vars - 0.5)
- return indices, ax
+ return ax
+
+
+def plot_scatter_submodels(
+ vi_results: dict,
+ func: Optional[Callable] = None,
+ submodels: Optional[Union[list[int], np.ndarray]] = None,
+ grid: str = "long",
+ labels: Optional[list[str]] = None,
+ figsize: Optional[tuple[float, float]] = None,
+ plot_kwargs: Optional[dict[str, Any]] = None,
+ ax: Optional[plt.Axes] = None,
+) -> list[plt.Axes]:
+ """
+ Plot submodel's predictions against reference-model's predictions.
+
+ Parameters
+ ----------
+ vi_results : Dictionary
+ Dictionary computed with `compute_variable_importance`
+ func : Optional[Callable], by default None.
+ Arbitrary function to apply to the predictions. Defaults to the identity function.
+ submodels : Optional[Union[list[int], np.ndarray]]
+ List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
+ The indices correspond to order computed by `compute_variable_importance`.
+ For example `submodels=[0,1]` will plot the two most important variables.
+ `submodels=[1,0]` is equivalent as values are sorted before use.
+ grid : str or tuple
+ How to arrange the subplots. Defaults to "long", one subplot below the other.
+ Other options are "wide", one subplot next to each other or a tuple indicating the number
+ of rows and columns.
+ labels : Optional[list[str]]
+ List of the names of the covariates.
+ plot_kwargs : dict
+ Additional keyword arguments for the plot. Defaults to None.
+ Valid keys are:
+ - marker_scatter: matplotlib valid marker for the scatter plot
+ - color_scatter: matplotlib valid color for the scatter plot
+ - alpha_scatter: matplotlib valid alpha for the scatter plot
+ - color_ref: matplotlib valid color for the 45 degree line
+ - ls_ref: matplotlib valid linestyle for the reference line
+ axes : axes
+ Matplotlib axes.
+
+ Returns
+ -------
+ axes: matplotlib axes
+ """
+ if submodels is None:
+ submodels = np.sort(vi_results["indices"])
+ else:
+ submodels = np.sort(submodels)
+
+ indices = vi_results["indices"][submodels]
+ preds_sub = vi_results["preds"][submodels]
+ preds_all = vi_results["preds_all"]
+
+ if labels is None:
+ labels = vi_results["labels"][submodels]
+
+ # handle categorical regression case:
+ n_cats = None
+ if preds_all.ndim > 2:
+ n_cats = preds_all.shape[-1]
+ indices = np.tile(indices, n_cats)
+
+ if ax is None:
+ _, ax = _get_axes(grid, len(indices), True, True, figsize)
+
+ if plot_kwargs is None:
+ plot_kwargs = {}
+
+ if func is not None:
+ preds_sub = func(preds_sub)
+ preds_all = func(preds_all)
+
+ min_ = min(np.min(preds_sub), np.min(preds_all))
+ max_ = max(np.max(preds_sub), np.max(preds_all))
+
+ # handle categorical regression case:
+ if n_cats is not None:
+ i = 0
+ for cat in range(n_cats):
+ for pred_sub, x_label in zip(preds_sub, labels):
+ ax[i].plot(
+ pred_sub[..., cat],
+ preds_all[..., cat],
+ marker=plot_kwargs.get("marker_scatter", "."),
+ ls="",
+ color=plot_kwargs.get("color_scatter", f"C{cat}"),
+ alpha=plot_kwargs.get("alpha_scatter", 0.1),
+ )
+ ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}")
+ ax[i].axline(
+ [min_, min_],
+ [max_, max_],
+ color=plot_kwargs.get("color_ref", "0.5"),
+ ls=plot_kwargs.get("ls_ref", "--"),
+ )
+ i += 1
+ else:
+ for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()):
+ axi.plot(
+ pred_sub,
+ preds_all,
+ marker=plot_kwargs.get("marker_scatter", "."),
+ ls="",
+ color=plot_kwargs.get("color_scatter", "C0"),
+ alpha=plot_kwargs.get("alpha_scatter", 0.1),
+ )
+ axi.set(xlabel=x_label, ylabel="ref model")
+ axi.axline(
+ [min_, min_],
+ [max_, max_],
+ color=plot_kwargs.get("color_ref", "0.5"),
+ ls=plot_kwargs.get("ls_ref", "--"),
+ )
+ return ax
def generate_sequences(n_vars, i_var, include):
@@ -890,3 +1219,13 @@ def generate_sequences(n_vars, i_var, include):
else:
sequences = [()]
return sequences
+
+
+@jit(nopython=True)
+def pearsonr2(A, B):
+ """Compute the squared Pearson correlation coefficient"""
+ A = A.flatten()
+ B = B.flatten()
+ am = A - np.mean(A)
+ bm = B - np.mean(B)
+ return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
diff --git a/pyproject.toml b/pyproject.toml
index 165ed67..4a2273d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,16 +8,17 @@ line-length = 100
[tool.ruff.lint]
select = ["E", "F", "I", "PL", "UP", "W"]
-ignore-init-module-imports = true
ignore = [
"PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons.
+ "PLR0913", #Too many arguments in function definition
+
]
[tool.ruff.lint.pylint]
max-args = 19
max-branches = 15
-[tool.ruff.extend-per-file-ignores]
+[tool.ruff.lint.extend-per-file-ignores]
"docs/conf.py" = ["E501", "F541"]
"tests/test_*.py" = ["F841"]
@@ -32,3 +33,20 @@ exclude_lines = [
isort = 1
black = 1
pyupgrade = 1
+
+
+[tool.mypy]
+files = "pymc_bart/*.py"
+plugins = "numpy.typing.mypy_plugin"
+
+[tool.mypy-matplotlib]
+ignore_missing_imports = true
+
+[tool.mypy-numba]
+ignore_missing_imports = true
+
+[tool.mypy-pymc]
+ignore_missing_imports = true
+
+[tool.mypy-scipy]
+ignore_missing_imports = true
diff --git a/requirements-docs.txt b/requirements-docs.txt
index 5074a06..214c399 100644
--- a/requirements-docs.txt
+++ b/requirements-docs.txt
@@ -1,8 +1,6 @@
myst-nb
-sphinx==5.0.2 # see https://github.com/pymc-devs/pymc-examples/issues/409
-git+https://github.com/pymc-devs/pymc-sphinx-theme
+sphinx
+pymc-sphinx-theme>=0.16
sphinxcontrib-bibtex
-nbsphinx
sphinx_design
sphinx_codeautolink
-sphinx_remove_toctrees
diff --git a/requirements.txt b/requirements.txt
index 23641cb..e3a38da 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-pymc<5.16.0
+pymc>=5.16.2, <=5.22.0
arviz>=0.18.0
numba
matplotlib
diff --git a/tests/test_bart.py b/tests/test_bart.py
index dfbd86f..226d938 100644
--- a/tests/test_bart.py
+++ b/tests/test_bart.py
@@ -3,7 +3,7 @@
import pytest
from numpy.testing import assert_almost_equal, assert_array_equal
from pymc.initial_point import make_initial_point_fn
-from pymc.logprob.basic import joint_logp
+from pymc.logprob.basic import transformed_conditional_logp
import pymc_bart as pmb
@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
fn = make_initial_point_fn(
model=model,
return_transformed=False,
- default_strategy="moment",
+ default_strategy="support_point",
)
moment = fn(0)["x"]
expected = np.asarray(expected)
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
if check_finite_logp:
logp_moment = (
- joint_logp(
+ transformed_conditional_logp(
(model["x"],),
rvs_to_values={model["x"]: pm.math.constant(moment)},
rvs_to_transforms={},
@@ -53,7 +53,7 @@ def test_bart_vi(response):
mu = pmb.BART("mu", X, Y, m=10, response=response)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
- idata = pm.sample(random_seed=3415)
+ idata = pm.sample(tune=200, draws=200, random_seed=3415)
var_imp = (
idata.sample_stats["variable_inclusion"]
.stack(samples=("chain", "draw"))
@@ -77,8 +77,8 @@ def test_missing_data(response):
with pm.Model() as model:
mu = pmb.BART("mu", X, Y, m=10, response=response)
sigma = pm.HalfNormal("sigma", 1)
- y = pm.Normal("y", mu, sigma, observed=Y)
- idata = pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
+ pm.Normal("y", mu, sigma, observed=Y)
+ pm.sample(tune=100, draws=100, chains=1, random_seed=3415)
@pytest.mark.parametrize(
@@ -91,7 +91,7 @@ def test_shared_variable(response):
Y = np.random.normal(0, 1, size=50)
with pm.Model() as model:
- data_X = pm.MutableData("data_X", X)
+ data_X = pm.Data("data_X", X)
mu = pmb.BART("mu", data_X, Y, m=2, response=response)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
@@ -116,7 +116,7 @@ def test_shape(response):
with pm.Model() as model:
w = pmb.BART("w", X, Y, m=2, response=response, shape=(2, 250))
y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y)
- idata = pm.sample(random_seed=3415)
+ idata = pm.sample(tune=50, draws=10, random_seed=3415)
assert model.initial_point()["w"].shape == (2, 250)
assert idata.posterior.coords["w_dim_0"].data.size == 2
@@ -133,7 +133,7 @@ class TestUtils:
mu = pmb.BART("mu", X, Y, m=10)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
- idata = pm.sample(random_seed=3415)
+ idata = pm.sample(tune=200, draws=200, random_seed=3415)
def test_sample_posterior(self):
all_trees = self.mu.owner.op.all_trees
@@ -184,12 +184,17 @@ def test_pdp(self, kwargs):
@pytest.mark.parametrize(
"kwargs",
[
- {},
+ {"samples": 50},
{"labels": ["A", "B", "C"], "samples": 2, "figsize": (6, 6)},
],
)
def test_vi(self, kwargs):
- pmb.plot_variable_importance(self.idata, X=self.X, bartrv=self.mu, **kwargs)
+ samples = kwargs.pop("samples")
+ vi_results = pmb.compute_variable_importance(
+ self.idata, bartrv=self.mu, X=self.X, samples=samples
+ )
+ pmb.plot_variable_importance(vi_results, **kwargs)
+ pmb.plot_scatter_submodels(vi_results, **kwargs)
def test_pdp_pandas_labels(self):
pd = pytest.importorskip("pandas")
@@ -243,8 +248,11 @@ def test_categorical_model(separate_trees, split_rule):
separate_trees=separate_trees,
)
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
- idata = pm.sample(random_seed=3415, tune=300, draws=300)
- idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
+ idata = pm.sample(tune=300, draws=300, random_seed=3415)
+ idata = pm.sample_posterior_predictive(
+ idata, predictions=True, extend_inferencedata=True, random_seed=3415
+ )
# Fit should be good enough so right category is selected over 50% of time
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
+ assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3)