An Elemwise feeding dest[idx].set(...) / .inc(...) isn't fused into the write (extra temp + copy) #2192

@ricardoV94

Description

When an Elemwise result is written into a subtensor via .set(...)/.inc(...), the Elemwise output is materialized in its own buffer and then copied into the destination, even when the write itself is in-place. The Elemwise could instead write directly into the destination region.

Repro (Numba backend; the destination is an intermediate, so the write is in-place):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
o0 = pt.vector("o0")
o = o0 + 1.0  # intermediate, so SetSubtensor may overwrite it in place
res = o[1:].set(pt.exp(x))
fn = pytensor.function([x, o0], res, mode="NUMBA")
pytensor.dprint(fn, print_memory_map=True)

Resulting graph (annotated):

SetSubtensor{start:}  d={0: [0]}    # in-place on o
 ├─ Add (o0 + 1)
 ├─ Exp                             # own buffer (inplace_pattern={}, destroy_map={})
 │  └─ x
 └─ 1

So per call we allocate exp(x) in its own buffer and then copy it into o[1:]. The in-place SetSubtensor only avoids copying the whole o, not the Elemwise temp.

Elemwise inplace can only destroy one of its own inputs; the write destination o[1:] is not an input to Exp, so there's no path for the Elemwise to write its result directly into the destination.
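As a rough analogy only, the difference can be sketched in plain NumPy rather than PyTensor. The variable names (`tmp`, `two_step`, `fused`) are made up for illustration, and the use of a ufunc `out=` argument pointing at a view is an assumption about what a fused write would look like, not a description of how PyTensor would implement it.

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0])
o0 = np.zeros(4)

# Roughly what the compiled graph does today: Exp fills its own buffer,
# and SetSubtensor then copies that buffer into the destination slice.
o = o0 + 1.0          # intermediate destination, safe to overwrite
tmp = np.exp(x)       # temporary buffer holding the Elemwise result
o[1:] = tmp           # in-place write into o, but still a copy from tmp
two_step = o

# Roughly what a fused write would do: compute directly into the
# destination region, so no temporary is allocated for exp(x).
o = o0 + 1.0
np.exp(x, out=o[1:])  # o[1:] is a view of o; the ufunc writes through it
fused = o

# Both paths produce the same values; only the allocation pattern differs.
assert np.allclose(two_step, fused)
```

In the second path the output buffer is not one of the inputs to `exp`, which is exactly the case Elemwise inplace cannot express today.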

The advanced-indexing fusion (IndexedElemwise) already avoids this (after #2015): its inner graph is AdvancedIncSubtensor1{inplace,set}(buffer, Exp(...), idx), so the Elemwise is fused into the scatter and the result lands in the destination with no temp. The basic Subtensor/IncSubtensor path has no equivalent rewrite.

We should generalize the IndexedElemwise fuse-elemwise-into-write machinery to also absorb basic Subtensor/IncSubtensor, unifying the slice and advanced paths.
