feat(utils): add register_pytree_node and pytree-aware compile wrapper#3500
Open
st-adam wants to merge 1 commit intoml-explore:mainfrom
Open
feat(utils): add register_pytree_node and pytree-aware compile wrapper#3500st-adam wants to merge 1 commit intoml-explore:mainfrom
st-adam wants to merge 1 commit intoml-explore:mainfrom
Conversation
Adds mlx.utils.register_pytree_node — a JAX-style API that lets
third-party classes participate in MLX tree utilities and mx.compile.
Problem
-------
mx.compile rejects any function argument that is not a plain array,
list, dict, tuple, or scalar constant:
ValueError: [compile] Function arguments must be trees of arrays or
constants (floats, ints, strings, or None), but received type
mlx_lm.models.cache.ArraysCache.
Any model whose forward pass receives custom cache objects (e.g. SSM
state caches in hybrid attention+SSM architectures like Qwen 3.5/3.6)
cannot be compiled with mx.compile even though the underlying
computation is fully expressible as MLX ops.
Changes
-------
python/mlx/utils.py:
1. _pytree_registry dict — module-level registry mapping type → (flatten_fn, unflatten_fn)
2. register_pytree_node(cls, flatten_fn, unflatten_fn) — public API
mirroring jax.tree_util.register_pytree_node:
flatten_fn(obj) -> (children: list, aux_data)
unflatten_fn(aux, children) -> obj
3. tree_map / tree_flatten extended to traverse registered types
(registered nodes are treated as interior nodes, not leaves)
4. _pytree_flatten_args / _pytree_unflatten_args — helpers that flatten
a call-argument tuple, expanding registered pytrees to their array
leaves, and reconstruct the originals from a flat list.
5. mlx.utils.compile(fun, inputs, outputs, shapeless) — Python-level
pytree-aware wrapper around mlx.core.compile:
- Wraps fun so registered-type args are flattened to arrays before
the C++ tracer sees them, then unflattened when calling the
original function.
- Falls through to mlx.core.compile directly when no pytrees are
registered (zero overhead for existing code).
- Automatically recompiles when the pytree structure changes between
calls.
Usage
-----
import mlx.utils as mlxu
from mlx_lm.models.cache import ArraysCache, KVCache
mlxu.register_pytree_node(
ArraysCache,
lambda c: (c.state, None), # flatten: expose state arrays
lambda _, s: ArraysCache.from_state(s), # unflatten
)
# Now mx.compile (via mlx.utils.compile) accepts ArraysCache args:
compiled_step = mlxu.compile(model_step)
Limitations
-----------
SSM caches with in-place state mutation (GatedDeltaNet, Mamba) update
their arrays after every step. These mutations are not captured by the
C++ compile graph unless the cache is also passed via the inputs/outputs
mechanism. This PR unblocks read-only structured arguments; mutable SSM
state requires a separate inputs/outputs-aware solution.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
5 tasks
zcbenz
requested changes
May 10, 2026
Collaborator
zcbenz
left a comment
There was a problem hiding this comment.
I think the register_pytree_node API is a good thing to have, but utils.compile would be a bad addition.
The register_pytree_node API should be implemented in C++ layer so mx.compile can be made aware of it directly. And we should probably remove the python versions of tree utils and expose the C++ ones instead so we don't have to duplicate the register_pytree_node implementation in 2 languages.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #3499.
Summary
mlx.utils.register_pytree_node(cls, flatten_fn, unflatten_fn)— a JAX-style API for registering custom classes as pytree nodestree_mapandtree_flattento traverse registered typesmlx.utils.compile()— a Python-level pytree-aware wrapper aroundmlx.core.compilethat flattens registered-type arguments to arrays before the C++ tracer sees themMotivation
mx.compilerejects custom objects as function arguments even when they are simple array wrappers. This blocks JIT compilation of hybrid SSM+attention models (Qwen 3.5/3.6, Llama 4, Gemma 3n, etc.) whose forward pass takesArraysCache(SSM state cache) arguments alongside regularKVCacheobjects.Usage
Test plan
register_pytree_node+tree_flattenon custom class: returns correct flat array listmlx.utils.compilewith registered pytree arg: correct result (3 + 4 + 10 = 17 verified)mlx.core.compile, zero overheadLimitations
SSM caches with in-place state mutation require the cache to also flow through
inputs/outputsfor the mutations to be captured by the tracer. This PR addresses argument-position structured inputs; mutable SSM state is a follow-on.Files changed
python/mlx/utils.py— all changes (pure Python, no C++ modifications)🤖 Generated with Claude Code