Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,71 @@
from maxtext.utils.globals import EPS


def _apply_rope(x, cos, sin, interleave=True):
"""Applies rotary positional embedding to the input.

Args:
x: Input tensor [B, S, N, H] or [B, S, H].
cos: Cosine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
sin: Sine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
interleave: Whether to use interleaved or concatenated layout.

Returns:
Rotated input.
"""
if interleave:
x1, x2 = x[..., ::2], x[..., 1::2]
else:
x1, x2 = jnp.split(x, 2, axis=-1)

# Handle cases with or without heads dimension
if x.ndim == 4:
cos = cos[:, :, None, :] if cos.ndim == 3 else cos
sin = sin[:, :, None, :] if sin.ndim == 3 else sin
elif x.ndim == 3:
cos = cos[:, :, :] if cos.ndim == 3 else cos
sin = sin[:, :, :] if sin.ndim == 3 else sin

y1 = x1 * cos - x2 * sin
y2 = x1 * sin + x2 * cos

if interleave:
rotated = jnp.stack([y1, y2], axis=-1)
return rotated.reshape(x.shape)
else:
return jnp.concatenate([y1, y2], axis=-1)


def _compute_rope(head_dim, positions, theta, dtype):
"""Computes RoPE frequencies on the fly for given positions."""
freqs = 1.0 / (
theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
)
# positions shape [B, S], freqs shape [D/2] -> angles shape [B, S, D/2]
angles = positions[..., None] * freqs
return jnp.cos(angles).astype(dtype), jnp.sin(angles).astype(dtype)


def _get_cos_sin(rotary_embedding, positions, dtype):
  """Derives broadcast-ready cos/sin RoPE terms from a rotary_embedding module."""
  # Compute the rotation terms on the fly from the module's configuration
  # instead of looking them up from a precomputed table.
  cos, sin = _compute_rope(
      rotary_embedding.embedding_dims, positions, rotary_embedding.rope_theta, dtype
  )

  # Insert a singleton heads axis for broadcasting: [B, S, D/2] -> [B, S, 1, D/2].
  cos, sin = cos[:, :, jnp.newaxis, :], sin[:, :, jnp.newaxis, :]

  # Optional magnitude correction applied by some long-context RoPE variants;
  # a rope_factor <= 1 leaves the terms unchanged (scale of exactly 1.0).
  if getattr(rotary_embedding, "attention_scaling", False):
    factor = getattr(rotary_embedding, "rope_factor", 1.0)
    scale = 1.0 if factor <= 1 else (0.1 * math.log(factor) + 1.0)
    cos, sin = cos * scale, sin * scale
  return cos, sin


class Indexer(nnx.Module):
"""Indexer for DeepSeek Sparse Attention (DSA).

Expand Down Expand Up @@ -189,9 +254,10 @@ def apply_partial_rope(
# indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1)
# x_pe [B, S, H, rope_head_dim], positions [B, S]
x_pe = self.rotary_embedding(x_pe, position=inputs_positions)
cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype)
x_pe = _apply_rope(x_pe, cos, sin, interleave=self.rotary_embedding.interleave)
x = jnp.concatenate([x_pe, x_nope], axis=-1)
return x
return checkpoint_name(x, "indexer_partial_rope")

def generate_mask(self, topk_indices, s):
"""
Expand Down Expand Up @@ -478,6 +544,17 @@ def mla_as_linen(
class MLA(Attention):
"""Multi-Head Latent Attention (MLA) layer."""

def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array] = None, rope_kwargs: dict | None = None):
  """Overrides RoPE with an on-the-fly (table-free) implementation for MLA.

  Args:
    inputs: Tensor to rotate; the sequence axis is assumed to be axis 1.
    inputs_positions: Token positions, [B, S]. Defaults to a contiguous
      0..S-1 arange (broadcast over the batch) when not provided.
    rope_kwargs: Accepted for signature compatibility with the base class;
      unused by this override.

  Returns:
    The rotated tensor, tagged with the "mla_rope" checkpoint name.
  """
  del rope_kwargs  # Unused; kept only for interface parity with Attention.
  with jax.named_scope("mla_rope"):
    if inputs_positions is None:
      # Assume contiguous positions starting at 0 for every batch row.
      seq_length = inputs.shape[1]
      inputs_positions = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :]

    cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype)
    x_out = _apply_rope(inputs, cos, sin, interleave=self.rotary_embedding.interleave)
    return checkpoint_name(x_out, "mla_rope")

def __init__(
self,
config: Config,
Expand Down
Loading