
einx.{max,min} are broken with pytorch #26

@falkaer


Hi,

Calling einx.{max,min} on PyTorch tensors returns a namedtuple of type torch.return_types.{max,min} instead of a tensor, because einx calls torch.{max,min} directly with a dim argument. This is unexpected, and it also breaks some otherwise valid reduce/rearrange combinations, such as:

```python
import einx
import torch

x = torch.randn(5, 3)
print(einx.max("a [b] -> a 1", x))
```

```
Traceback (most recent call last):
  File "/home/falkaer/projects/latent-features/einx_mwe.py", line 5, in <module>
    print(einx.max("a [b] -> a 1", x))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/falkaer/projects/latent-features/.devenv/state/venv/lib/python3.11/site-packages/einx/traceback_util.py", line 71, in func_with_reraise
    raise e.with_traceback(tb) from None
  File "<string>", line 4, in op0
TypeError: reshape(): argument 'input' (position 1) must be Tensor, not torch.return_types.max
```
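The root cause can be reproduced with PyTorch alone. The sketch below shows that torch.max with a dim argument yields a (values, indices) namedtuple, while torch.amax / torch.amin return a plain tensor. Switching the backend to torch.amax / torch.amin (or taking `.values`) is only one possible fix assumed here for illustration, not necessarily what einx adopted. A side benefit of amax/amin is that they accept a tuple of dims, whereas torch.max reduces over a single dim only.

```python
import torch

x = torch.randn(5, 3)

# torch.max with `dim` returns a torch.return_types.max namedtuple,
# which is the object einx ended up passing to reshape() in the traceback.
result = torch.max(x, dim=1, keepdim=True)
values, indices = result  # unpacks into two tensors

# Assumed fix: torch.amax / torch.amin reduce over `dim` and return a tensor.
max_values = torch.amax(x, dim=1, keepdim=True)
min_values = torch.amin(x, dim=1, keepdim=True)

# amax/amin also accept several dims at once, e.g. reducing "[a b] -> 1 1".
global_max = torch.amax(x, dim=(0, 1), keepdim=True)

print(type(result), type(max_values), max_values.shape, global_max.shape)
```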
