In the rapidly evolving landscape of large language models (LLMs), pre-training is only the first step. To transform a base model into a specialized assistant or a high-performing reasoning engine, post-training is essential. Today, we are excited to announce new features in MaxText that streamline this process: Supervised Fine-Tuning (SFT) and Reinforcement Learning (RL) are now available on single-host TPU configurations (such as v5p-8 and v6e-8).
By leveraging the power of JAX and the efficiency of the Tunix library, MaxText provides a high-performance, scalable path for developers to refine their models using the latest post-training techniques. You can explore the full documentation for SFT and RL to start your post-training journey on TPUs today.
Supervised Fine-Tuning is the primary method for adapting a pre-trained model to follow specific instructions or excel at niche tasks. With the new single-host SFT support, users can now take an existing MaxText or Hugging Face checkpoint and fine-tune it on labeled datasets with minimal setup.
Key Highlights:
For tasks requiring complex logic and reasoning—such as math or coding—Reinforcement Learning is a game-changer. MaxText now supports several state-of-the-art RL algorithms on single-host TPUs, utilizing vLLM for high-throughput inference during the training loop. For example,
To begin using these new features, ensure you have the latest post-training dependencies installed:
uv pip install "maxtext[tpu-post-train]==0.2.1" --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
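After installation, it can be worth confirming that the active Python environment actually sees the package before launching a run. The following is a minimal sketch, not an official MaxText check: `check_maxtext` is a hypothetical helper name, and the only thing it assumes is that the package is importable as `maxtext`, as the `python3 -m maxtext...` launch commands imply.

```shell
# check_maxtext is a hypothetical helper, not part of MaxText.
# It reports whether the active python3 can import the maxtext package.
check_maxtext() {
  if python3 -c "import maxtext" >/dev/null 2>&1; then
    echo "maxtext: importable"
  else
    echo "maxtext: not importable (check the active environment)"
    return 1
  fi
}

# '|| true' keeps a failed check from aborting a script that uses 'set -e'.
check_maxtext || true
```

If the check reports that the package is not importable, the usual cause is that the install ran in a different virtual environment than the one currently active.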
You can launch an SFT run using the train_sft module, specifying your model, starting checkpoint, run name, and output directory:
python3 -m maxtext.trainers.post_train.sft.train_sft \
    model_name=${MODEL?} \
    load_parameters_path=${MAXTEXT_CKPT_PATH?} \
    run_name=${RUN_NAME?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?}
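The `${VAR?}` form used in the command is standard shell parameter expansion: if the variable is unset, the shell reports an error and stops instead of silently passing an empty value. A minimal sketch of defining the variables beforehand follows. Every value here (the model name, the bucket paths, the run-name pattern) is a placeholder assumed for illustration, so substitute your own.

```shell
# Placeholder values, assumed for illustration only; replace with your own.
export MODEL="llama3.1-8b"
export MAXTEXT_CKPT_PATH="gs://my-bucket/checkpoints/0/items"
export RUN_NAME="sft-$(date +%Y%m%d-%H%M%S)"
export BASE_OUTPUT_DIRECTORY="gs://my-bucket/maxtext-output"

# Demonstrate the guard in a subshell: with MODEL unset, the ${MODEL?...}
# expansion fails, the subshell exits non-zero, and the fallback message runs.
( unset MODEL; : "${MODEL?MODEL is not set}" ) 2>/dev/null \
  || echo "unset MODEL was caught"
```

Running the guard in a subshell keeps the failure from terminating your interactive session or the enclosing script.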
For RL, the train_rl module handles the loading of policy and reference models, executes the training, and provides automated evaluation on reasoning benchmarks:
python3 -m maxtext.trainers.post_train.rl.train_rl \
    model_name=${MODEL?} \
    load_parameters_path=${MAXTEXT_CKPT_PATH?} \
    run_name=${RUN_NAME?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    loss_algo=gspo-token \
    chips_per_vm=${CHIPS_PER_VM?}
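Before committing TPU time to an RL run, it can help to print the fully expanded command and review it. The sketch below wraps the command above in a small function with a dry-run switch. `launch_rl` and `DRY_RUN` are hypothetical names introduced here, not part of MaxText, and all variable values are placeholders.

```shell
# Placeholder values, assumed for illustration only; replace with your own.
MODEL="llama3.1-8b"
MAXTEXT_CKPT_PATH="gs://my-bucket/checkpoints/0/items"
RUN_NAME="rl-demo"
BASE_OUTPUT_DIRECTORY="gs://my-bucket/maxtext-output"
CHIPS_PER_VM=4   # placeholder; set this to match your TPU VM

# launch_rl and DRY_RUN are hypothetical, not part of MaxText.
launch_rl() {
  # Build the argument list once so the printed and executed commands match.
  set -- python3 -m maxtext.trainers.post_train.rl.train_rl \
    "model_name=${MODEL?}" \
    "load_parameters_path=${MAXTEXT_CKPT_PATH?}" \
    "run_name=${RUN_NAME?}" \
    "base_output_directory=${BASE_OUTPUT_DIRECTORY?}" \
    "loss_algo=gspo-token" \
    "chips_per_vm=${CHIPS_PER_VM?}"

  if [ "${DRY_RUN:-1}" = "1" ]; then
    printf '%s\n' "$*"   # dry run (the default): print the command only
  else
    "$@"                 # real run: execute the command
  fi
}

launch_rl
```

Defaulting to a dry run is a deliberate choice in this sketch: the training job starts only when `DRY_RUN=0` is set explicitly.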
While single-host support provides a powerful entry point for many developers, MaxText is built for scale. These same workflows are designed to transition seamlessly to multi-host configurations for those training larger models on massive datasets. Stay tuned for more updates on multi-host post-training.