MaxText is a high-performance, highly scalable, open-source LLM written in pure Python/JAX and targeting Google Cloud TPUs and GPUs for training and inference. MaxText achieves high model FLOPs utilization (MFU) and scales from a single host to very large clusters while staying simple and "optimization-free" thanks to the power of JAX and the XLA compiler.
MaxText aims to be a launching point for ambitious LLM projects in both research and production. We encourage users to start by experimenting with MaxText out of the box, then fork and modify it to meet their needs.
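To make the "simple and optimization-free" claim concrete, here is a minimal sketch of the pattern this relies on (not MaxText code; the toy model, optimizer choice, and shapes are stand-ins): write a plain JAX training step and let `jax.jit` hand the whole thing to XLA.

```python
import jax
import jax.numpy as jnp
import optax  # gradient-processing library commonly paired with JAX

optimizer = optax.adam(1e-3)

def loss_fn(params, batch):
    # Toy linear model standing in for a real transformer.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit  # XLA fuses and optimizes the whole step; no hand-written kernels.
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params = {"w": jnp.zeros((16, 4))}
opt_state = optimizer.init(params)
batch = {"x": jnp.ones((8, 16)), "y": jnp.ones((8, 4))}
params, opt_state, loss = train_step(params, opt_state, batch)
```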
We have used MaxText to demonstrate high-performance, well-converging training in int8 and to scale training to ~51K chips.
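For intuition on the int8 claim, the sketch below shows the basic trick: scale tensors into int8 range, multiply with int32 accumulation, and rescale. This is not MaxText's actual int8 implementation, which is considerably more careful about where and how quantization is applied; it only illustrates why int8 helps, since TPUs run int8 matmuls at higher throughput than bf16.

```python
import jax
import jax.numpy as jnp

def quantize(x):
    # Symmetric per-tensor quantization to int8.
    scale = jnp.max(jnp.abs(x)) / 127.0
    return jnp.round(x / scale).astype(jnp.int8), scale

@jax.jit
def int8_matmul(a, b):
    a_q, a_scale = quantize(a)
    b_q, b_scale = quantize(b)
    # int8 x int8 with int32 accumulation, then rescale back to float.
    acc = jax.lax.dot(a_q, b_q, preferred_element_type=jnp.int32)
    return acc.astype(jnp.float32) * (a_scale * b_scale)

out = int8_matmul(jnp.ones((8, 16)), jnp.ones((16, 4)))
```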
Key supported features:
- TPUs and GPUs
- Training and Inference
- Models: Llama2, Mistral and Gemma
Contents:
- Getting Started
- Runtime Performance Results
- Comparison To Alternatives
- Development
- Features and Diagnostics
For your first time running MaxText, we provide specific instructions.
MaxText supports training and inference of various open models. Follow the user guides in the getting started folder to learn more.
Some extra helpful guides (a toy sketch of what "decode" involves follows this list):
- Gemma: a family of open-weights large language models (LLMs) by Google DeepMind, based on Gemini research and technology. You can run decode and finetuning using these instructions.
- Llama2: a family of open-weights LLMs by Meta. You can run decode and finetuning using these instructions.
- Mixtral: a family of open-weights sparse mixture-of-experts (MoE) models by Mistral AI. You can run decode and finetuning using these instructions.
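As a rough picture of what "decode" means in these guides, here is a hypothetical sketch of greedy autoregressive decoding with a jitted forward pass. `model_apply` and the parameter shapes are toy stand-ins, not MaxText's actual decode path; follow the per-model instructions above for the real thing.

```python
import jax
import jax.numpy as jnp

def model_apply(params, tokens):
    # Toy stand-in for a transformer forward pass: embed, then project to vocab.
    h = params["embed"][tokens]        # [batch, seq, dim]
    return h @ params["unembed"]       # [batch, seq, vocab]

@jax.jit
def next_token(params, tokens):
    logits = model_apply(params, tokens)
    return jnp.argmax(logits[:, -1, :], axis=-1)  # greedy pick at the last position

def greedy_decode(params, prompt, steps=8):
    tokens = prompt
    for _ in range(steps):
        nxt = next_token(params, tokens)
        tokens = jnp.concatenate([tokens, nxt[:, None]], axis=-1)
    return tokens

# Note: growing `tokens` retriggers jit compilation each step; a real decoder
# typically keeps shapes fixed with a KV cache.
params = {"embed": jnp.ones((32, 16)), "unembed": jnp.ones((16, 32))}
print(greedy_decode(params, jnp.zeros((1, 4), dtype=jnp.int32)))
```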
In addition to the getting started guides, other MaxText capabilities are constantly being added! The full suite of end-to-end tests is in end_to_end; we run them on a nightly cadence, and they are a good source for understanding MaxText. Alternatively, you can see the continuous unit tests, which are run almost continuously.
More details on reproducing the results below can be found in MaxText/configs/README.md.

TPU v5p:
No. of params | Accelerator type | TFLOP/sec/chip | Model FLOPs utilization (MFU) |
---|---|---|---|
32B | v5p-128 | 328 | 67.76% |
64B | v5p-128 | 323 | 70.31% |
128B | v5p-256 | 315 | 68.68% |
128B | v5p-512 | 315 | 68.53% |
256B | v5p-1024 | 316 | 68.82% |
512B | v5p-1024 | 294 | 63.99% |
1024B | v5p-2048 | 249 | 64.05% |
1024B | v5p-4096 | 297 | 64.80% |
1160B | v5p-7680 | 295 | 64.27% |
1160B | v5p-12288 | 304 | 66.23% |
TPU v5e: results for 16B, 32B, 64B, and 128B models. See the full run configs in MaxText/configs/v5e/ as 16b.sh, 32b.sh, 64b.sh, and 128b.sh.
Hardware | 16B TFLOP/sec/chip | 16B MFU | 32B TFLOP/sec/chip | 32B MFU | 64B TFLOP/sec/chip | 64B MFU | 128B TFLOP/sec/chip | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% |
2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% |
4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% |
8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% |
16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |
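To read the MFU columns: MFU is achieved model TFLOP/sec/chip divided by the chip's peak. Assuming the published ~197 bf16 TFLOP/sec peak of a TPU v5e chip, the v5e numbers above check out, for example:

```python
# Assumed published bf16 peak for a TPU v5e chip; MFU = achieved / peak.
PEAK_V5E_BF16_TFLOPS = 197.0

def mfu(achieved_tflops_per_chip, peak=PEAK_V5E_BF16_TFLOPS):
    return achieved_tflops_per_chip / peak

# 16B model on 1x v5e-256 from the table above: ~60.9%, matching the quoted
# 61.10% up to rounding of the achieved TFLOP/sec/chip figure.
print(f"{mfu(120):.2%}")
```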