feat(utils): add register_pytree_node and pytree-aware compile wrapper #3500

Open
st-adam wants to merge 1 commit into ml-explore:main from st-adam:feat/register-pytree-node

@st-adam st-adam commented May 8, 2026

Fixes #3499.

Summary

  • Adds mlx.utils.register_pytree_node(cls, flatten_fn, unflatten_fn) — a JAX-style API for registering custom classes as pytree nodes
  • Extends tree_map and tree_flatten to traverse registered types
  • Adds mlx.utils.compile() — a Python-level pytree-aware wrapper around mlx.core.compile that flattens registered-type arguments to arrays before the C++ tracer sees them

Motivation

mx.compile rejects custom objects as function arguments even when they are simple array wrappers. This blocks JIT compilation of hybrid SSM+attention models (Qwen 3.5/3.6, Llama 4, Gemma 3n, etc.) whose forward pass takes ArraysCache (SSM state cache) arguments alongside regular KVCache objects.

Usage

import mlx.utils as mlxu

# Register ArraysCache as a pytree
mlxu.register_pytree_node(
    ArraysCache,
    lambda c: (list(c.state), None),         # flatten → list of mx.array
    lambda _, s: ArraysCache.from_state(s),  # unflatten
)

# Use mlx.utils.compile (not mlx.core.compile) to get pytree support
compiled_forward = mlxu.compile(model.__call__)
output = compiled_forward(x, cache=cache)   # cache may contain ArraysCache
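
The flatten/unflatten contract used above can be illustrated without MLX installed. This is a minimal sketch, not the PR's code: `ToyCache` is a hypothetical stand-in for ArraysCache, and plain floats stand in for mx.array leaves.

```python
class ToyCache:
    """Hypothetical stand-in for ArraysCache: a thin wrapper around a list of arrays."""

    def __init__(self, size):
        self.state = [None] * size

    @classmethod
    def from_state(cls, state):
        obj = cls(len(state))
        obj.state = list(state)
        return obj


def flatten_cache(c):
    # children: the leaves a tracer would see; aux_data: static metadata (none here)
    return list(c.state), None


def unflatten_cache(aux_data, children):
    # Rebuild a new object from the leaves; aux_data is unused in this sketch
    return ToyCache.from_state(children)


cache = ToyCache(2)
cache.state = [1.0, 2.0]  # plain floats stand in for mx.array leaves

children, aux = flatten_cache(cache)
rebuilt = unflatten_cache(aux, children)
```

Note that `rebuilt` is a new object holding the same leaves, not the original `cache`. That matters for the mutation caveat under Limitations.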

Test plan

  • register_pytree_node + tree_flatten on custom class: returns correct flat array list
  • mlx.utils.compile with registered pytree arg: correct result (3 + 4 + 10 = 17 verified)
  • No-registry fast path: when no pytree types are registered, delegates directly to mlx.core.compile with no added flattening overhead
  • Structure-change detection: triggers recompile on next call
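
One way the structure-change check in the last item could work is to key compiled functions on a hashable description of the argument tree. The sketch below is an assumption about the approach, not the PR's implementation: `structure_of`, `fake_compile`, and `StructureKeyedCompile` are hypothetical names, and `fake_compile` stands in for mlx.core.compile so the example runs anywhere.

```python
def structure_of(tree):
    """Return a hashable description of a tree's shape, ignoring leaf values."""
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(structure_of(t) for t in tree))
    if isinstance(tree, dict):
        return ("dict", tuple((k, structure_of(tree[k])) for k in sorted(tree)))
    return "leaf"


def fake_compile(fun):
    # Stand-in for mlx.core.compile: returns fun unchanged
    return fun


class StructureKeyedCompile:
    """Compile once per distinct argument structure and count compilations."""

    def __init__(self, fun):
        self._fun = fun
        self._compiled = {}
        self.compile_count = 0

    def __call__(self, *args):
        key = structure_of(args)
        if key not in self._compiled:
            # New structure seen: compile again and remember the result
            self._compiled[key] = fake_compile(self._fun)
            self.compile_count += 1
        return self._compiled[key](*args)


step = StructureKeyedCompile(lambda xs: sum(xs))

first = step([1, 2])            # first structure: compiles
count_after_first = step.compile_count
second = step([3, 4])           # same structure: reuses the compiled function
count_after_second = step.compile_count
third = step([1, 2, 3])         # different structure: compiles again
count_after_third = step.compile_count
```

Keeping one entry per structure avoids recompiling when a caller alternates between two shapes; replacing a single cached entry would also satisfy the behavior described in the test plan.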

Limitations

SSM caches with in-place state mutation require the cache to also flow through inputs/outputs for the mutations to be captured by the tracer. This PR addresses argument-position structured inputs; mutable SSM state is a follow-on.

Files changed

  • python/mlx/utils.py — all changes (pure Python, no C++ modifications)

🤖 Generated with Claude Code

Adds mlx.utils.register_pytree_node — a JAX-style API that lets
third-party classes participate in MLX tree utilities and mx.compile.

Problem
-------
mx.compile rejects any function argument that is not a plain array,
list, dict, tuple, or scalar constant:

  ValueError: [compile] Function arguments must be trees of arrays or
  constants (floats, ints, strings, or None), but received type
  mlx_lm.models.cache.ArraysCache.

Any model whose forward pass receives custom cache objects (e.g. SSM
state caches in hybrid attention+SSM architectures like Qwen 3.5/3.6)
cannot be compiled with mx.compile even though the underlying
computation is fully expressible as MLX ops.

Changes
-------
python/mlx/utils.py:

1. _pytree_registry dict — module-level registry mapping type → (flatten_fn, unflatten_fn)

2. register_pytree_node(cls, flatten_fn, unflatten_fn) — public API
   mirroring jax.tree_util.register_pytree_node:
     flatten_fn(obj)          -> (children: list, aux_data)
     unflatten_fn(aux, children) -> obj

3. tree_map / tree_flatten extended to traverse registered types
   (registered nodes are treated as interior nodes, not leaves)

4. _pytree_flatten_args / _pytree_unflatten_args — helpers that flatten
   a call-argument tuple, expanding registered pytrees to their array
   leaves, and reconstruct the originals from a flat list.

5. mlx.utils.compile(fun, inputs, outputs, shapeless) — Python-level
   pytree-aware wrapper around mlx.core.compile:
   - Wraps fun so registered-type args are flattened to arrays before
     the C++ tracer sees them, then unflattened when calling the
     original function.
   - Falls through to mlx.core.compile directly when no pytrees are
     registered (zero overhead for existing code).
   - Automatically recompiles when the pytree structure changes between
     calls.
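
Items 1 to 3 can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not the code in python/mlx/utils.py: `register_node`, `toy_tree_flatten`, `toy_tree_map`, and `Pair` are hypothetical names, the flatten function returns a bare list of leaves rather than whatever mlx.utils.tree_flatten returns, and integers stand in for arrays.

```python
_registry = {}  # type -> (flatten_fn, unflatten_fn)


def register_node(cls, flatten_fn, unflatten_fn):
    """Record how to take cls apart and put it back together."""
    _registry[cls] = (flatten_fn, unflatten_fn)


def toy_tree_flatten(tree):
    """Return the leaves of tree in depth-first order."""
    if type(tree) in _registry:
        # Registered types are interior nodes: recurse into their children
        children, _aux = _registry[type(tree)][0](tree)
        return [leaf for c in children for leaf in toy_tree_flatten(c)]
    if isinstance(tree, (list, tuple)):
        return [leaf for c in tree for leaf in toy_tree_flatten(c)]
    if isinstance(tree, dict):
        return [leaf for c in tree.values() for leaf in toy_tree_flatten(c)]
    return [tree]


def toy_tree_map(fn, tree):
    """Apply fn to every leaf, preserving containers and registered types."""
    if type(tree) in _registry:
        flatten_fn, unflatten_fn = _registry[type(tree)]
        children, aux = flatten_fn(tree)
        return unflatten_fn(aux, [toy_tree_map(fn, c) for c in children])
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_tree_map(fn, c) for c in tree)
    if isinstance(tree, dict):
        return {k: toy_tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


class Pair:
    """Hypothetical custom class used only for this example."""

    def __init__(self, a, b):
        self.a, self.b = a, b


register_node(Pair, lambda p: ([p.a, p.b], None), lambda _aux, ch: Pair(*ch))

tree = {"x": 3, "p": Pair(4, [10])}
leaves = toy_tree_flatten(tree)
doubled = toy_tree_map(lambda v: v * 2, tree)
```

Because `Pair` is registered, its fields are traversed like list elements instead of the whole object being returned as a single leaf.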

Usage
-----
  import mlx.utils as mlxu
  from mlx_lm.models.cache import ArraysCache, KVCache

  mlxu.register_pytree_node(
      ArraysCache,
      lambda c: (c.state, None),          # flatten: expose state arrays
      lambda _, s: ArraysCache.from_state(s),  # unflatten
  )

  # Now mx.compile (via mlx.utils.compile) accepts ArraysCache args:
  compiled_step = mlxu.compile(model_step)

Limitations
-----------
SSM caches with in-place state mutation (GatedDeltaNet, Mamba) update
their arrays after every step. These mutations are not captured by the
C++ compile graph unless the cache is also passed via the inputs/outputs
mechanism. This PR unblocks read-only structured arguments; mutable SSM
state requires a separate inputs/outputs-aware solution.
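
A pure-Python analogy may help show why mutation is a separate problem. It is only an analogy: the limitation described above concerns the C++ compile graph, while this sketch shows the related wrapper-level effect that a function receiving a rebuilt copy cannot update the caller's object. All names (`ToyState`, `call_flattened`, `call_with_state_output`) are hypothetical.

```python
class ToyState:
    """Hypothetical SSM-style cache holding one running value."""

    def __init__(self, value):
        self.value = value


def flatten(s):
    return [s.value], None


def unflatten(_aux, children):
    return ToyState(children[0])


def step(x, state):
    state.value = state.value + x  # in-place update, as an SSM step would do
    return state.value


def call_flattened(fun, x, state):
    """Mimic a flatten-then-unflatten wrapper: fun sees a rebuilt copy of state."""
    children, aux = flatten(state)
    rebuilt = unflatten(aux, children)
    return fun(x, rebuilt)


def call_with_state_output(fun, x, state):
    """Variant that also returns the updated leaves so the caller can write them back."""
    children, aux = flatten(state)
    rebuilt = unflatten(aux, children)
    out = fun(x, rebuilt)
    new_children, _ = flatten(rebuilt)
    return out, new_children


state = ToyState(1)
out = call_flattened(step, 5, state)
value_after_plain_call = state.value  # the caller's object was not updated

out2, new_children = call_with_state_output(step, 5, state)
state.value = new_children[0]         # explicit write-back of the new state
```

The second variant corresponds loosely to routing the state through outputs as well as inputs, which is the kind of follow-on the text describes.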

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Collaborator

@zcbenz zcbenz left a comment

I think the register_pytree_node API is a good thing to have, but utils.compile would be a bad addition.

The register_pytree_node API should be implemented in the C++ layer so that mx.compile can be made aware of it directly. We should also probably remove the Python versions of the tree utils and expose the C++ ones instead, so we don't have to duplicate the register_pytree_node implementation in two languages.

Development

Successfully merging this pull request may close these issues.

mx.compile rejects custom cache objects — no pytree registration mechanism

2 participants