Liesel: A Probabilistic Programming Framework

Liesel is a probabilistic programming framework with a focus on semi-parametric regression. It includes:

  • Liesel, a library to express statistical models as Probabilistic Graphical Models (PGMs). Through the PGM representation, the user can build and update models in a natural way.
  • Goose, a library to build custom MCMC algorithms with several parameter blocks and MCMC kernels such as the No U-Turn Sampler (NUTS), the Iteratively Weighted Least Squares (IWLS) sampler, or different Gibbs samplers. Goose also takes care of the MCMC bookkeeping and the chain post-processing.
  • RLiesel, an R interface for Liesel which assists the user with the configuration of semi-parametric regression models such as Generalized Additive Models for Location, Scale and Shape (GAMLSS) with different response distributions, spline-based smooth terms and shrinkage priors.

The name “Liesel” is an homage to the Gänseliesel fountain, landmark of Liesel’s birth city Göttingen.

Resources

Usage

The following example shows how to build a simple i.i.d. normal model with Liesel. We set up two parameters and one observed variable, and combine them in a model.

import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
import liesel.model as lsl

loc = lsl.Var.new_param(0.0, name="loc")
scale = lsl.Var.new_param(1.0, name="scale")

y = lsl.Var.new_obs(
    value=jnp.array([1.314, 0.861, -1.813, 0.587, -1.408]),
    distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
    name="y",
)

model = lsl.Model([y])

The model allows us to evaluate the log-probability through a property, which is updated automatically if the value of a node is modified.

>>> model.log_prob
Array(-8.635652, dtype=float32)
>>> model.vars["loc"].value = -0.5
>>> model.log_prob
Array(-9.031153, dtype=float32)
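
Neither parameter has a prior attached in this example, so the log-probability should reduce to the sum of the normal log-densities of the five observations. The following sketch recomputes both values by hand using only the Python standard library. The helper normal_log_prob is hypothetical and not part of Liesel.

```python
import math

# Observations from the example above.
y = [1.314, 0.861, -1.813, 0.587, -1.408]


def normal_log_prob(values, loc, scale):
    """Sum of Normal(loc, scale) log-densities over `values` (hypothetical helper)."""
    log_norm = -math.log(scale) - 0.5 * math.log(2.0 * math.pi)
    return sum(log_norm - 0.5 * ((v - loc) / scale) ** 2 for v in values)


# Assumption: loc and scale carry no prior, so only the likelihood contributes.
log_prob_initial = normal_log_prob(y, loc=0.0, scale=1.0)
log_prob_updated = normal_log_prob(y, loc=-0.5, scale=1.0)

print(round(log_prob_initial, 4))  # compare with the first model.log_prob value
print(round(log_prob_updated, 4))  # compare with the value after setting loc = -0.5
```

Both results should agree with the values reported by model.log_prob up to float32 rounding.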

We can estimate the mean parameter with Goose and a NUTS sampler. Goose’s workhorse to run an MCMC algorithm is the Engine, which can be constructed with the EngineBuilder. The builder allows us to assign different MCMC kernels to one or more parameters. We also need to specify the model, the initial values, and the sampling duration, before we can run the sampler.

import liesel.goose as gs

builder = gs.EngineBuilder(seed=42, num_chains=4)

builder.add_kernel(gs.NUTSKernel(["loc"]))
builder.set_model(gs.LieselInterface(model))
builder.set_initial_values(model.state)

# we disable the progress bar for a nicer display here in the README
builder.show_progress = False

builder.set_duration(warmup_duration=1000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()

Finally, we can print a summary table and view some diagnostic plots.

results = engine.get_results()
gs.Summary(results)

Parameter summary:

                      kernel    mean     sd  q_0.05   q_0.5  q_0.95  sample_size  ess_bulk  ess_tail   rhat
parameter  index
loc        ()      kernel_00  -0.109  0.448  -0.829  -0.119   0.645         4000  1416.594  2268.975  1.006

Error summary:

                                                         count  relative
kernel     error_code  error_msg             phase
kernel_00  1           divergent transition  warmup         53     0.013
                                             posterior       0     0.000

gs.plot_param(results, param="loc")
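
As a rough plausibility check on the summary table above: with scale held at 1 and, as assumed here, no prior on loc (equivalent to a flat prior), the posterior of loc has a closed form. It is a normal distribution centered at the sample mean with standard deviation 1/sqrt(n). The sketch below uses only the Python standard library; the variable names are illustrative and not part of Goose.

```python
import math
from statistics import NormalDist, fmean

# Observations from the example above.
y = [1.314, 0.861, -1.813, 0.587, -1.408]
scale = 1.0  # held fixed; only "loc" is sampled in the example

# Assumption: flat prior on loc, so the posterior is Normal(mean(y), scale / sqrt(n)).
posterior = NormalDist(mu=fmean(y), sigma=scale / math.sqrt(len(y)))

print(f"mean   = {posterior.mean:.3f}")
print(f"sd     = {posterior.stdev:.3f}")
print(f"q_0.05 = {posterior.inv_cdf(0.05):.3f}")
print(f"q_0.95 = {posterior.inv_cdf(0.95):.3f}")
```

The analytic mean, standard deviation and quantiles should agree with the MCMC estimates in the table up to Monte Carlo error.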

Installation

Liesel requires Python ≥ 3.13. Create and activate a virtual environment, and install the latest release from PyPI:

pip install liesel

You can also install the development version from GitHub. Liesel uses uv for project management, so make sure uv is installed.

git clone https://github.com/liesel-devs/liesel.git
cd liesel
uv sync
# or `uv sync --dev` for an editable install including the dev dependencies
# or `uv sync --locked` for an installation using the dependencies locked in the uv.lock file

Liesel depends on JAX and jaxlib. As of now, there are no official jaxlib wheels for Windows. If you are on Windows, the JAX developers recommend using the Windows Subsystem for Linux. Alternatively, you can build jaxlib from source or try the unofficial jaxlib wheels from https://github.com/cloudhan/jax-windows-builder.

If you are using the lsl.plot_model() function, installing Graphviz will greatly improve the layout of the model graphs.

Development

Before committing your work, please

  1. run pre-commit run -a,
  2. make sure the tests don’t fail with pytest --run-mcmc, and
  3. make sure the examples in your docstrings are up-to-date with pytest --doctest-modules liesel.

If you are using uv to manage the project, you can run these commands in the virtual environment by prepending uv run to them, e.g., uv run pre-commit run -a.

Acknowledgements

Liesel is being developed by Paul Wiemann, Hannes Riebl, Johannes Brachem and Gianmarco Callegher with support from Thomas Kneib. Important contributions were made by Joel Beck and Alex Afanasev. We are grateful to the German Research Foundation (DFG) for funding the development through grant KN 922/11-1.
