20 changes: 20 additions & 0 deletions einx/backend/_jax.py
@@ -4,6 +4,8 @@
import einx
import types
from functools import partial
import inspect
from frozendict import frozendict


def create():
@@ -70,6 +72,24 @@ def to_tensor(tensor, shape):
rsqrt = op.elementwise(tjax.lax.rsqrt)
square = op.elementwise(tjnp.square)

@staticmethod
def tracing_cache_key(args, kwargs):
    def shape_of(x):
        return x.shape if isinstance(x, jnp.ndarray) else x

    def process_arg(arg):
        if not jax_.tree_util.treedef_is_leaf(tree := jax_.tree_util.tree_structure(arg)):
            # A pytree, probably an equinox module. Key on the tree structure and
            # the shapes of any contained arrays instead of the arrays themselves.
            return "_EQUINOX_PYTREE", tree, tuple(map(shape_of, jax_.tree_util.tree_leaves(arg)))
        elif inspect.ismethod(arg) and not jax_.tree_util.treedef_is_leaf(
            tree := jax_.tree_util.tree_structure(arg.__self__)
        ):
            # Bound method of a pytree, probably an equinox module
            return "_EQUINOX_METHOD", arg.__func__, tree, tuple(
                map(shape_of, jax_.tree_util.tree_leaves(arg.__self__))
            )

        # Not a pytree or a bound method of one; key on the argument itself
        return arg

    return tuple(process_arg(arg) for arg in args), frozendict(
        {k: process_arg(v) for k, v in kwargs.items()}
    )

@staticmethod
@einx.trace
def get_at(tensor, coordinates):
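A note on the keying behavior above: two equinox modules whose arrays agree in shape (but not in value) should map to the same cache key, while a shape change forces a retrace. A minimal standalone sketch, assuming equinox is installed; `shape_key` is a hypothetical re-implementation of the pytree branch of `process_arg`, not part of this PR:

import jax
import jax.numpy as jnp
import equinox as eqx

def shape_key(arg):
    # Mirror process_arg: key pytrees by structure plus leaf shapes
    treedef = jax.tree_util.tree_structure(arg)
    if not jax.tree_util.treedef_is_leaf(treedef):
        leaves = jax.tree_util.tree_leaves(arg)
        return "_EQUINOX_PYTREE", treedef, tuple(
            x.shape if isinstance(x, jnp.ndarray) else x for x in leaves
        )
    return arg

class MLP(eqx.Module):
    w: jnp.ndarray

m1 = MLP(w=jnp.zeros((3, 4)))
m2 = MLP(w=jnp.ones((3, 4)))   # same shape, different values
m3 = MLP(w=jnp.zeros((3, 5)))  # different shape

assert shape_key(m1) == shape_key(m2)  # values don't affect the key
assert shape_key(m1) != shape_key(m3)  # shapes do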
4 changes: 4 additions & 0 deletions einx/backend/base.py
@@ -87,6 +87,10 @@ def __exit__(backend, *args):
def _decorate_construct_graph(f):
return f

@staticmethod
def tracing_cache_key(args, kwargs):
return args, kwargs

@classmethod
@einx.trace
def all_to_tensor(backend, tensors, convert_scalars=False):
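The default in `base.py` keeps the previous behavior: the graph cache keys on the arguments exactly as passed, so caching still requires hashable, equality-comparable inputs. A backend only needs to override `tracing_cache_key` when some arguments should be keyed by a summary rather than by value, as the JAX backend now does for pytrees. A hypothetical sketch of such an override for a NumPy-style backend (illustrative only, not part of the PR):

import numpy as np

class ShapeKeyedBackend:
    @staticmethod
    def tracing_cache_key(args, kwargs):
        # Key arrays by shape and dtype so equal-shaped inputs share one graph
        def summarize(a):
            return ("ARRAY", a.shape, a.dtype.str) if isinstance(a, np.ndarray) else a

        return (
            tuple(summarize(a) for a in args),
            tuple(sorted((k, summarize(v)) for k, v in kwargs.items())),
        )

# Equal shapes and dtypes produce equal keys regardless of array values:
assert ShapeKeyedBackend.tracing_cache_key((np.zeros((2, 3)),), {}) == \
       ShapeKeyedBackend.tracing_cache_key((np.ones((2, 3)),), {})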
36 changes: 29 additions & 7 deletions einx/tracer/decorator.py
@@ -1,5 +1,8 @@
import functools
import os

import cachetools

import einx
import threading
import frozendict
@@ -95,20 +98,39 @@ def func_with_warn(*args, **kwargs):
return func


def _cache_impl(func, cache_provider):
func = _with_retrace_warning(func)

max_cache_size = int(os.environ.get("EINX_CACHE_SIZE", -1))
func = cache_provider(func, max_cache_size)
func = freeze(func)

return func


def _lru_cache_provider(func, max_cache_size):
    if max_cache_size > 0:
        return functools.lru_cache(maxsize=max_cache_size)(func)
    elif max_cache_size < 0:
        # functools.cache was added in Python 3.9
        if "cache" in vars(functools):
            return functools.cache(func)
        else:
            return functools.lru_cache(maxsize=None)(func)

    return func

def lru_cache(func):
return _cache_impl(func, _lru_cache_provider)


def _tracing_cache_provider(func, max_cache_size):
    # Fall back to an unbounded dict cache when no positive size limit is set
    return cachetools.cached(
        cache=cachetools.LRUCache(max_cache_size) if max_cache_size > 0 else {},
        key=lambda args, kwargs, backend: cachetools.keys.hashkey(
            *backend.tracing_cache_key(args, kwargs), backend
        ),
    )(func)


def _tracing_cache(func):
return _cache_impl(func, _tracing_cache_provider)


_thread_local = threading.local()
@@ -146,7 +168,7 @@ def jit(func=None, trace=trace_all):
if func is None:
return partial(jit, trace=trace)

@_tracing_cache
def construct_graph(args, kwargs, backend):
with _trace_context(backend):
# Replace input keys with tracers and retrieve list of traced arguments
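For context on the `cachetools` pattern in `_tracing_cache_provider`: `cachetools.cached` consults the `key` function before calling the wrapped function, so two argument sets that map to the same key share one cache entry. A self-contained sketch; the traced function and key function here are illustrative stand-ins, not einx's:

import cachetools
import cachetools.keys

def type_key(args, kwargs):
    # Stand-in for backend.tracing_cache_key: key only on argument types
    return cachetools.keys.hashkey(*(type(a).__name__ for a in args))

traces = []

@cachetools.cached(cache=cachetools.LRUCache(maxsize=16), key=type_key)
def construct_graph(args, kwargs):
    traces.append(args)  # count actual (non-cached) invocations
    return ("graph", type_key(args, kwargs))

construct_graph((1,), {})
construct_graph((2,), {})    # same key (int argument): cache hit, no re-trace
assert len(traces) == 1
construct_graph(("x",), {})  # different key: re-traced
assert len(traces) == 2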
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"numpy",
"sympy",
"frozendict",
"cachetools",
]

[project.optional-dependencies]
32 changes: 32 additions & 0 deletions test/test_equinox_vmap.py
@@ -0,0 +1,32 @@
import importlib.util
from typing import Any

if importlib.util.find_spec("equinox"):
import equinox as eqx
import jax.numpy as jnp
import einx.nn.equinox as einn
import einx
import jax

def test_equinox_vmap():
class DummyModule(eqx.Module):
    w: jnp.ndarray
    linear: einn.Linear
    key_linear: Any = eqx.field(static=True)

    def __init__(self, size, key):
        kw, kl = jax.random.split(key, 2)
        self.key_linear = kl
        self.w = jax.random.normal(kw, (size, size))
        self.linear = einn.Linear("[c -> s]", s=size)

    def __call__(self, x):
        x = self.linear(x, rng=self.key_linear)
        return self.w @ x

dummy = DummyModule(100, jax.random.PRNGKey(0))

arr = einx.add("a, b -> a b", jnp.arange(10), 10 * jnp.arange(10))
dummy(jnp.arange(10))
ret = einx.vmap("a [b] -> a [s]", arr, s=100, op=dummy)
assert ret.shape == (10, 100)