Conversation

@Anri-Lombard (Contributor)

Summary

  • Document that MLX's mask="causal" uses lower-right alignment
  • Clarify the difference from PyTorch's default is_causal=True (upper-left)

When T_q != T_kv, this distinction matters (a small mask-construction sketch follows the list):

  • MLX (lower-right): Last query aligns with last key
  • PyTorch default (upper-left): First query aligns with first key
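
A minimal sketch of the two alignments in plain NumPy (a hypothetical helper, not MLX or PyTorch code; the lower-right offset follows the q_off = max(0, kL - qL) formula quoted further down in the thread):

import numpy as np

def causal_mask(T_q, T_kv, align="lower_right"):
    # lower_right: the last query lines up with the last key (offset = T_kv - T_q, clamped at 0)
    # upper_left: the first query lines up with the first key (offset = 0)
    offset = max(0, T_kv - T_q) if align == "lower_right" else 0
    q_idx = np.arange(T_q)[:, None]
    k_idx = np.arange(T_kv)[None, :]
    return k_idx <= q_idx + offset

print(causal_mask(2, 4, "lower_right").astype(int))  # [[1 1 1 0] [1 1 1 1]]
print(causal_mask(2, 4, "upper_left").astype(int))   # [[1 0 0 0] [1 1 0 0]]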

References:

Relates to #2835

@zcbenz (Collaborator)


I don't think PyTorch has a causal_lower_right option for SDPA and the description is not really right.

@Anri-Lombard (Contributor, Author)

Hey @zcbenz, PyTorch has had causal_lower_right since 2.3, and it can be used with SDPA via the attn_mask parameter. I ran a script with:

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bias = causal_lower_right(T_q, T_kv)
F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

to verify.

Here is the tutorial that documents this explicitly: https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html.
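
For reference, the MLX side of the same comparison can be run like this (a minimal sketch; the shapes and the D ** -0.5 scale are illustrative assumptions):

import mlx.core as mx

T_q, T_kv, D = 2, 4, 8
q = mx.random.normal((1, 1, T_q, D))
k = mx.random.normal((1, 1, T_kv, D))
v = mx.random.normal((1, 1, T_kv, D))
# mask="causal" is the lower-right-aligned mask this PR documents
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5, mask="causal")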

I also verified the masks are mathematically identical. For example, with T_q=2, T_kv=4:

  MLX's mask (using q_off = max(0, kL - qL)):
        k0  k1  k2  k3
  q0 [  1   1   1   0  ]
  q1 [  1   1   1   1  ]

  PyTorch's causal_lower_right(2, 4):
        k0  k1  k2  k3
  q0 [  1   1   1   0  ]
  q1 [  1   1   1   1  ]

  PyTorch's is_causal=True (upper_left):
        k0  k1  k2  k3
  q0 [  1   0   0   0  ]
  q1 [  1   1   0   0  ]
  

The first two are identical; the third is different. This is also consistent with MLX's CUDA backend, which uses cuDNN's set_causal_mask_bottom_right.
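
A short script along these lines confirms the equivalence numerically (a sketch; the explicit boolean mask encodes the same lower-right rule, and the shapes are illustrative):

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

T_q, T_kv, D = 2, 4, 8
q = torch.randn(1, 1, T_q, D)
k = torch.randn(1, 1, T_kv, D)
v = torch.randn(1, 1, T_kv, D)

# Explicit lower-right rule: query i sees key j iff j <= i + (T_kv - T_q)
i = torch.arange(T_q)[:, None]
j = torch.arange(T_kv)[None, :]
explicit = j <= i + (T_kv - T_q)

out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_lower_right(T_q, T_kv))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=explicit)
assert torch.allclose(out_bias, out_mask, atol=1e-6)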

Is there something specific about the description you think is incorrect? If your concern is that causal_lower_right isn't a direct SDPA parameter (like is_causal=True) but rather a separate utility class, I can clarify the wording to use the full module path torch.nn.attention.bias.causal_lower_right.

@zcbenz (Collaborator) commented Jan 18, 2026

Thanks for linking the docs, this is new to me. On the behavior: it actually depends on whether T_q is larger or smaller than T_kv:

if (q.shape(2) > k.shape(2)) {
  // More queries than keys: fall back to the standard (upper-left) causal mask.
  options.set_causal_mask(do_causal);
} else {
  // Otherwise request cuDNN's bottom-right (lower-right) aligned causal mask.
  options.set_causal_mask_bottom_right(do_causal);
}
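
Since the two alignments coincide when T_q == T_kv, this branch only changes observable behavior for T_q > T_kv, and it agrees with the q_off = max(0, kL - qL) formula quoted above, which also reduces to upper-left in that case. A hypothetical helper (my naming) mirrors the selection:

# Hypothetical helper: which alignment the CUDA backend requests from cuDNN.
# Assumption: set_causal_mask means upper-left and
# set_causal_mask_bottom_right means lower-right.
def cudnn_causal_alignment(T_q: int, T_kv: int) -> str:
    return "upper_left" if T_q > T_kv else "lower_right"

print(cudnn_causal_alignment(2, 4))  # lower_right (the documented case)
print(cudnn_causal_alignment(4, 4))  # lower_right (identical to upper_left here)
print(cudnn_causal_alignment(4, 2))  # upper_left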
