Hi,
Using `einx.{max,min}` with PyTorch tensors returns a namedtuple of type `torch.return_types.{min,max}`, because `torch.{min,max}` is called directly with a `dim` argument. This is unexpected, and it also breaks some otherwise valid reduce/rearrange combinations, such as:
```python
import einx
import torch
x = torch.randn(5, 3)
print(einx.max("a [b] -> a 1", x))
```
```
Traceback (most recent call last):
  File "/home/falkaer/projects/latent-features/einx_mwe.py", line 5, in <module>
    print(einx.max("a [b] -> a 1", x))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/falkaer/projects/latent-features/.devenv/state/venv/lib/python3.11/site-packages/einx/traceback_util.py", line 71, in func_with_reraise
    raise e.with_traceback(tb) from None
  File "<string>", line 4, in op0
TypeError: reshape(): argument 'input' (position 1) must be Tensor, not torch.return_types.max
```
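For reference, here is a minimal sketch of the underlying PyTorch behavior and one possible way around it. It uses only plain PyTorch, not einx internals: `torch.max(x, dim=...)` returns a `(values, indices)` namedtuple, while `torch.amax` (and `torch.amin`) reduce over one or more dims and return a plain tensor. The `reduce_max` helper is hypothetical and only illustrates what a backend wrapper could look like; it is not einx's actual implementation.

```python
import torch

x = torch.randn(5, 3)

# torch.max with a dim argument returns a namedtuple (values, indices), not a Tensor,
# so passing it straight to torch.reshape fails as in the traceback above.
result = torch.max(x, dim=1, keepdim=True)
values = result.values  # the reduced tensor, shape (5, 1)

# torch.amax reduces over the given dim(s) and returns a plain Tensor.
amax = torch.amax(x, dim=1, keepdim=True)


def reduce_max(tensor, dims, keepdim=False):
    # Hypothetical backend wrapper: always returns a Tensor, and accepts
    # a tuple of dims, which torch.max(..., dim=...) does not.
    return torch.amax(tensor, dim=dims, keepdim=keepdim)


full = reduce_max(x, (0, 1))  # 0-d tensor holding the global max
```

One caveat when swapping `max`/`min` for `amax`/`amin`: per the PyTorch docs, `amax`/`amin` distribute the gradient evenly between equal values, whereas `max(dim)`/`min(dim)` propagate it to a single index.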