Labels: enhancement (New feature or request)
Description
Replacing explicit tensor operations with `torch.einsum()` in the Zero-Order-Hold transformation can improve both performance and readability.
Replacing the original Zero-Order-Hold transformation at line 518 of `mamba_arch.py`:

```python
deltaA = torch.exp(delta.unsqueeze(-1) * A)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
BX = deltaB * x.unsqueeze(-1)
```

with:

```python
deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
BX = torch.einsum('bld,bld,bln->bldn', delta, x, B)
```

can improve execution time by up to ~40% while requiring the same number of FLOPs. (See attached plot.)
Moreover, vectorization of the loop does not further improve execution time.
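As a sanity check that the two formulations are numerically identical, here is a minimal sketch using NumPy's `einsum` in place of PyTorch's (the semantics of the subscript string are the same; the dimension sizes `b, l, d, n` are dummy values, not the model's):

```python
import numpy as np

rng = np.random.default_rng(0)
b, l, d, n = 2, 3, 4, 5  # batch, sequence, channel, state dims (dummies)
delta = rng.standard_normal((b, l, d))
A = rng.standard_normal((d, n))
B = rng.standard_normal((b, l, n))
x = rng.standard_normal((b, l, d))

# Original broadcast-based ZOH transformation
deltaA_ref = np.exp(delta[..., None] * A)                 # (b, l, d, n)
deltaB = delta[..., None] * B[:, :, None, :]              # (b, l, d, n)
BX_ref = deltaB * x[..., None]                            # (b, l, d, n)

# einsum-based variant: one fused contraction per output tensor
deltaA = np.exp(np.einsum('bld,dn->bldn', delta, A))
BX = np.einsum('bld,bld,bln->bldn', delta, x, B)

assert np.allclose(deltaA, deltaA_ref)
assert np.allclose(BX, BX_ref)
print("equivalent")
```

The fused `'bld,bld,bln->bldn'` contraction avoids materializing the intermediate `deltaB` tensor, which is where the observed speedup plausibly comes from.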