Introducing Metrax: performant, efficient, and robust model evaluation metrics in JAX

November 13, 2025
Yufeng Guo, Developer Advocate, Core ML Frameworks
Jiwon Shin, Software Engineer, Core ML Frameworks
Jeff Carpenter, Software Engineer, Core ML Frameworks

At Google, as teams migrated from TensorFlow to JAX, they had to manually reimplement metrics that TensorFlow had previously provided, because JAX had no built-in metrics library. As a result, each team using JAX was writing its own version of accuracy, F1, RMS error, and so on. Creating metrics may seem like a simple, straightforward task, but it becomes far less trivial once you consider large-scale training and evaluation across datacenter-sized distributed compute environments.

And thus the idea for Metrax was born: to bring a high-performance library for efficient and robust model evaluation metrics in JAX. Metrax currently provides predefined metrics used to evaluate various types of machine learning models (classification, regression, recommendation, vision, audio, and language), and provides compatibility and consistency in distributed and scaled training environments. This allows you to focus on the model evaluation results, rather than (re)implementing various metrics definitions. Metrax adds to the ever-evolving ecosystem of JAX-based tooling, integrating well with the JAX AI Stack, a suite of tools that are designed to work together to power your AI tooling needs. Today, Metrax is already used by some of the largest software stacks at Google, including teams in Google Search, YouTube, and Google’s own post-training library, Tunix.


Strengths of Metrax

Particularly noteworthy is the ability to compute "at K" metrics for multiple values of K in parallel, which lets you evaluate model performance more comprehensively and more quickly. For example, you can use PrecisionAtK to determine your model's precision at several values of K (say, K=1, K=8, and K=20) in a single pass, rather than calling PrecisionAtK separately for each value. Several other "at K" metrics are available to try, including RecallAtK and NDCGAtK. All the metrics, along with their definitions, are listed in the Metrax documentation.
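To illustrate why several cutoffs can share one computation, here is a minimal sketch in plain JAX. It is not Metrax's implementation, and its definition of precision@K may differ from Metrax's in edge cases. It ranks items by score once, keeps a running count of relevant items, and reads off precision for every K with a single gather. The function name `precision_at_ks` and the `[batch, num_items]` input layout are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def precision_at_ks(scores, relevance, ks):
    """Hypothetical helper: mean precision@k over a batch for every k in `ks`.

    scores:    [batch, num_items] model scores (higher means ranked earlier)
    relevance: [batch, num_items] 1 if the item is relevant, else 0
    ks:        [num_ks] integer cutoffs, each between 1 and num_items
    """
    # Rank items by descending score and reorder the relevance labels to match.
    order = jnp.argsort(-scores, axis=-1)
    ranked_relevance = jnp.take_along_axis(relevance, order, axis=-1)
    # hits[:, i] is the number of relevant items among the top (i + 1).
    hits = jnp.cumsum(ranked_relevance, axis=-1)
    # One gather reads off every cutoff at once: shape [batch, num_ks].
    hits_at_k = hits[:, ks - 1]
    # Precision@k = (relevant items in top k) / k, averaged over the batch.
    return jnp.mean(hits_at_k / ks, axis=0)


scores = jnp.array([[0.9, 0.8, 0.1, 0.3]])
relevance = jnp.array([[1, 0, 0, 1]])
print(precision_at_ks(scores, relevance, jnp.array([1, 2, 4])))
# precision@1 = 1.0, precision@2 = 0.5, precision@4 = 0.5
```

Because the ranking and the running count are shared, each additional cutoff costs only one more gathered element rather than another full evaluation.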

The last thing you want to worry about when working on a machine learning research project is whether your metrics are implemented correctly across your system. A well-tested metrics library helps the community write less error-prone code and produce more reliable model evaluations.

Performance

Metrax leverages some of JAX's core strengths, including vmap and jit, to perform operations such as multiple "at K" computations efficiently. Not every metric can be jit-compiled, due to the nature of the metric, but the goal is for all metrics to be well written and to demonstrate best practices. Beyond classic metrics such as accuracy, precision, and recall, the library features a robust set of NLP-related metrics, including Perplexity, BLEU, and ROUGE, as well as metrics for vision and audio models, such as Intersection over Union (IoU), Signal-to-Noise Ratio (SNR), and Structural Similarity Index (SSIM). There's no need to vibe code your metric implementations anymore; just use Metrax!
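As a small illustration of how vmap and jit combine, the sketch below computes IoU for a batch of binary segmentation masks: `binary_iou` handles a single example, `jax.vmap` maps it over the leading batch axis, and `jax.jit` compiles the result. This is a simplified stand-in written for illustration, not Metrax's IoU metric (which may support multiple classes and treat empty masks differently); scoring two empty masks as IoU 1.0 is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def binary_iou(pred_mask, true_mask):
    """IoU of two boolean masks with the same shape (one example)."""
    intersection = jnp.sum(pred_mask & true_mask)
    union = jnp.sum(pred_mask | true_mask)
    # Assumption for this sketch: two empty masks count as a perfect match.
    # jnp.maximum keeps the division well defined when union is 0.
    return jnp.where(union > 0, intersection / jnp.maximum(union, 1), 1.0)


# vmap maps the per-example function over the batch axis; jit compiles it.
batched_iou = jax.jit(jax.vmap(binary_iou))

pred = jnp.array([[[True, True], [False, False]],
                  [[False, False], [False, False]]])
true = jnp.array([[[True, False], [False, False]],
                  [[False, False], [False, False]]])
print(batched_iou(pred, true))
# first example: intersection 1 / union 2 = 0.5; second example: both empty = 1.0
```

Writing the metric for a single example and letting `vmap` add the batch dimension keeps the per-example logic simple while still compiling to one vectorized computation.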

Metrax in action

Let's see how to use Metrax in your code. Here is how to compute precision from your model's output. Notice that we pass in the predictions and labels, along with a threshold value, and then call compute() to get the metric's value.

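```python
import jax.numpy as jnp
import metrax

# Illustrative model outputs (probabilities) and binary ground-truth labels.
predictions = jnp.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1])
labels = jnp.array([1, 0, 1, 1, 0, 0])

# Directly compute the metric state.
metric_state = metrax.Precision.from_model_output(
    predictions=predictions,
    labels=labels,
    threshold=0.5
)

# The result is then readily available by calling compute().
result = metric_state.compute()
print(result)
```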

Often we run evaluations in batches, so we want to be able to iteratively add more information to our metrics. Metrax supports this workflow with a method called merge(), which is a great fit for an evaluation loop where you aggregate metrics over the course of a training run. We still call compute() when we're ready to get a final value.

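```python
import jax.numpy as jnp
import metrax

# Illustrative data: 2 batches of 4 examples each.
predictions_batched = jnp.array([[0.9, 0.2, 0.7, 0.4], [0.6, 0.1, 0.8, 0.3]])
labels_batched = jnp.array([[1, 0, 1, 1], [0, 0, 1, 0]])

# Iteratively merging precision metrics
metric_state = None
for labels_b, predictions_b in zip(labels_batched, predictions_batched):
    batch_metric_state = metrax.Precision.from_model_output(
        predictions=predictions_b,
        labels=labels_b,
        threshold=0.5
    )
    # Start from the first batch's state, then merge each later batch into it.
    metric_state = (batch_metric_state if metric_state is None
                    else metric_state.merge(batch_metric_state))

result = metric_state.compute()
print(result)
```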
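Why does merging work across batches? Here is a minimal sketch of the idea, using a hypothetical `PrecisionState` class rather than Metrax's actual classes: the state stores raw true-positive and false-positive counts instead of a precision value. Counts add up exactly, so merging is element-wise addition, and dividing only at compute time gives the same answer as evaluating all the data at once. Because a `NamedTuple` is a JAX pytree, a state like this can also be passed through `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PrecisionState(NamedTuple):
    """Hypothetical, simplified precision state: two running counts."""
    true_positives: jax.Array
    false_positives: jax.Array

    @classmethod
    def from_batch(cls, predictions, labels, threshold=0.5):
        predicted_pos = predictions >= threshold
        actual_pos = labels == 1
        return cls(
            true_positives=jnp.sum(predicted_pos & actual_pos).astype(jnp.float32),
            false_positives=jnp.sum(predicted_pos & ~actual_pos).astype(jnp.float32),
        )

    def merge(self, other):
        # Counts are additive, so merging is plain element-wise addition.
        return PrecisionState(self.true_positives + other.true_positives,
                              self.false_positives + other.false_positives)

    def compute(self):
        # Divide only at the end; report 0.0 if nothing was predicted positive.
        denom = self.true_positives + self.false_positives
        return jnp.where(denom > 0, self.true_positives / jnp.maximum(denom, 1.0), 0.0)


preds = jnp.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1])
labels = jnp.array([1, 0, 1, 1, 0, 0])

full = PrecisionState.from_batch(preds, labels)
merged = PrecisionState.from_batch(preds[:3], labels[:3]).merge(
    PrecisionState.from_batch(preds[3:], labels[3:]))
print(full.compute(), merged.compute())  # both 2/3: 2 true positives, 1 false positive
```

Averaging the two per-batch precisions instead (1.0 for the first batch, 0.0 for the second) would give 0.5, which is wrong; accumulating counts avoids that trap.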

For a full set of examples, check out this notebook, which demonstrates more ways to use Metrax, including scaling to multiple devices and integrating with Flax NNX, a modeling library that abstracts away some of the implementation details of building AI models.
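A metric like precision can be scaled across devices by summing raw counts rather than per-device precision values. Below is a minimal sketch of that idea in plain JAX, not how Metrax or the notebook implements it: it uses `jax.pmap` with `jax.lax.psum` to add per-device true- and false-positive counts over a named axis before dividing. The helper `device_precision_counts`, the axis name `"devices"`, and the fixed 0.5 threshold are illustrative choices. It runs on any number of local devices, including a single CPU.

```python
import jax
import jax.numpy as jnp


def device_precision_counts(predictions, labels):
    """Hypothetical helper: precision counts for one device's shard, summed across devices."""
    predicted_pos = predictions >= 0.5  # fixed threshold for this sketch
    actual_pos = labels == 1
    true_positives = jnp.sum(predicted_pos & actual_pos)
    false_positives = jnp.sum(predicted_pos & ~actual_pos)
    # psum adds the per-device counts over the mapped axis named "devices".
    return (jax.lax.psum(true_positives, "devices"),
            jax.lax.psum(false_positives, "devices"))


sharded_counts = jax.pmap(device_precision_counts, axis_name="devices")

# Give every local device its own copy of the same illustrative shard.
n_devices = jax.local_device_count()
predictions = jnp.tile(jnp.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1]), (n_devices, 1))
labels = jnp.tile(jnp.array([1, 0, 1, 1, 0, 0]), (n_devices, 1))

tp, fp = sharded_counts(predictions, labels)
# After psum every device holds the same global totals; read them from device 0.
precision = tp[0] / (tp[0] + fp[0])
print(precision)
```

In newer JAX code, `shard_map` or `jit` over sharded arrays is often preferred to `pmap`, but the principle of reducing additive counts before the final division is the same.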

Contribute

Metrax is developed on GitHub, and the team is happy to accept community contributions. In fact, some of the metrics available today were added by community contributors; a big shout-out to GitHub users @nikolasavic3 and @Mrigankkh for their efforts! If there are more metrics you'd like to see, submit a pull request and work with the development team to get them into Metrax. You can learn more at github.com/google/metrax.

Also, be sure to check out the other libraries in the JAX ecosystem at jaxstack.ai, where you can find more libraries that integrate well with Metrax, along with additional content about building machine learning models.