Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions effectful/handlers/jax/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
deffn,
defop,
syntactic_eq,
syntactic_hash,
)
from effectful.ops.types import Expr, NotHandled, Operation, Term

Expand Down Expand Up @@ -86,6 +87,12 @@ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]:
if not sized_fvs:
return t

# if any dimension is zero sized, the result is empty
if any(size == 0 for size in sized_fvs.values()):
key = tuple(sized_fvs.keys())
shape = tuple(sized_fvs[k] for k in key)
return jax_getitem(jnp.empty(shape), key)

def _is_eager(t):
return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t)

Expand Down Expand Up @@ -277,3 +284,9 @@ def _(x: jax.Array, other) -> bool:
and x.shape == other.shape
and bool((jnp.asarray(x) == jnp.asarray(other)).all())
)


@syntactic_hash.register(jax.Array)
def _(x: jax.Array) -> int:
# Concrete arrays aren't hashable; hash by shape, dtype, and bytes.
return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x))))
20 changes: 3 additions & 17 deletions effectful/handlers/jax/_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import effectful.handlers.jax.numpy as jnp
from effectful.handlers.jax._handlers import (
IndexElement,
_partial_eval,
_register_jax_op,
bind_dims,
jax_getitem,
Expand Down Expand Up @@ -451,28 +450,15 @@ def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array
>>> bind_dims(t, b, a).shape
(3, 2)
"""

def _evaluate(expr):
if isinstance(expr, Term):
(args, kwargs) = jax.tree.map(_evaluate, (expr.args, expr.kwargs))
return _partial_eval(expr)
if not jax.tree_util.treedef_is_leaf(jax.tree.structure(expr)):
return jax.tree.map(_evaluate, expr)
return expr

if not isinstance(t, Term):
return t

result = _evaluate(t)
if not isinstance(result, Term) or not args:
return result

# ensure that the result is a jax_getitem with an array as the first argument
if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)):
if not (t.op is jax_getitem and isinstance(t.args[0], jax.Array)):
raise NotHandled

array = result.args[0]
dims = result.args[1]
array = t.args[0]
dims = t.args[1]
assert isinstance(dims, Sequence)

# ensure that the order is a subset of the named dimensions
Expand Down
Loading
Loading