Skip to content

Crash using torch.compile with torch.inference_mode. #33

@marcelroed

Description

@marcelroed

This seems to break all my models using einx with this kind of rearrange under inference mode.

# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "einx==0.4.3",
#     "torch==2.11.0",
# ]
# ///

import einx
import torch


@torch.compile
def f(x):
    """Minimal repro: a compiled einx rearrange that crashes under inference_mode."""
    # Broadcasts the 1-D tensor x over new leading axes b... of shape [1, 1].
    # NOTE(review): the traceback below was produced with b=[1, 2] — confirm
    # which value the failing run actually used; the crash is independent of it.
    return einx.id("i -> b... i", x, b=[1, 1])


# Running the compiled function inside inference_mode triggers the guard
# failure (dispatch key set mismatch on a tensor created via np.asarray
# inside einx's expression solver — see traceback below).
with torch.inference_mode():
    x = torch.randn(2)
    print(f(x))
$ uv run scripts/einx_inference_mode.py
Traceback (most recent call last):
  File "/home/marcel/projects/cs336/assignments/systems/scripts/einx_inference_mode.py", line 20, in <module>
    print(f(x))
          ~^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 1024, in compile_wrapper
    return fn(*args, **kwargs)
  File "/home/marcel/projects/cs336/assignments/systems/scripts/einx_inference_mode.py", line 15, in f
    return einx.id("i -> b... i", x, b=[1, 2])
           ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/frontend/api.py", line 185, in inner
    function, code = construct_graph_with_cache(args=args, kwargs=kwargs | {"backend": backend})
                     ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/util/lru_cache.py", line 37, in func_frozen
    return func(*args, **kwargs)
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/frontend/api.py", line 131, in _construct_graph
    output_tracer = func(*args, **kwargs)
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/frontend/ops.py", line 29, in id
    return backend.id(description, *tensors, **parameters)
           ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/frontend/backend.py", line 54, in op
    return self.ops[name](*args, **kwargs)
           ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/adapter/einx_from_namedtensor.py", line 591, in inner
    exprs_in, exprs_out = solve(
                          ~~~~~^
        exprs_in,
        ^^^^^^^^^
    ...<5 lines>...
        equations_stage3=partial(equations_stage3, invocation=invocation) if equations_stage3 is not None else None,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/einx/_src/namedtensor/solve.py", line 86, in solve
    desc2=f"constraint ({_arr_to_str(np.asarray(v))})",
                         ~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 2316, in __call__
    result = self._torchdynamo_orig_backend(
        frame, cache_entry, self.hooks, frame_state, skip=1
    )
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 2052, in __call__
    result = self._inner_convert(
        frame, cache_entry, hooks, frame_state, skip=skip + 1
    )
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 729, in __call__
    result = _compile(
        frame.f_code,
    ...<16 lines>...
        convert_frame_box=self._box,
    )
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1827, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
                                  ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1500, in compile_inner
    return _compile_inner(code, one_graph, hooks)
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1634, in _compile_inner
    check_fn = dynamo_output.build_guards(
        code,
    ...<2 lines>...
        cache_entry=cache_entry,
    )
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 904, in build_guards
    return CheckFunctionManager(
        code,
    ...<5 lines>...
        strict_error=strict_error,
    )
  File "/home/marcel/.cache/uv/environments-v2/einx-inference-mode-1655de02df5016a3/lib/python3.13/site-packages/torch/_dynamo/guards.py", line 3928, in __init__
    raise AssertionError(
    ...<2 lines>...
    )
AssertionError: Guard failed on the same frame it was created. This is a bug - please create an issue.Guard fail reason: 23/0: tensor '___from_numpy(l)' dispatch key set mismatch. expected DispatchKeySet(CPU, BackendSelect, ADInplaceOrView), actual DispatchKeySet(CPU, BackendSelect)

This can be reproduced by running `uv run <script_name>.py`.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions