Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,3 +1800,118 @@ def qwen3_omni_mrope_embedding_as_linen(
metadata_fn=variable_to_logically_partitioned,
name=name,
)


from typing import Any


class DeepSeekV4RotaryEmbedding(nnx.Module):
  """Partial rotary embedding for DeepSeek-V4 with interleaved channel pairing.

  Rather than splitting the rotary channels into global first/second halves,
  DeepSeek-V4 pairs each even channel 2i with its odd neighbour 2i + 1. As a
  consequence:
    1. Only (dim // 2) distinct theta angles are needed.
    2. The resulting sin/cos tables are duplicated consecutively
       (e.g. [f0, f0, f1, f1]) before being applied.
  """

  def __init__(
      self,
      head_dim: int,
      partial_rotary_factor: float = 64.0 / 512.0,
      rope_theta: float = 10000.0,
      dtype: Any = jnp.float32,
  ):
    self.head_dim = head_dim
    self.partial_rotary_factor = partial_rotary_factor
    self.rope_theta = rope_theta
    self.dtype = dtype

    # Number of trailing channels that actually receive rotation,
    # e.g. 512 * (64 / 512) = 64.
    self.dim = int(head_dim * partial_rotary_factor)

    # One base angle per channel pair: theta^-(2k / dim) for k in [0, dim/2).
    # Adjacent channels share a frequency, matching the interleaved layout.
    num_pairs = self.dim // 2
    exponents = 2 * jnp.arange(0, num_pairs, dtype=jnp.float32) / self.dim
    self.inv_freq = 1.0 / (self.rope_theta**exponents)

  def __call__(self, x: jnp.ndarray, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Builds cos/sin tables of shape [B, S, dim // 2] for the given positions.

    Args:
      x: Activations; only consulted for the output dtype.
      position_ids: [B, S] integer token positions.

    Returns:
      (cos, sin) tables, each [B, S, dim // 2], cast to x.dtype.
    """
    # Broadcast [B, S, 1] positions against [1, 1, dim/2] frequencies to get
    # the per-position angles [B, S, dim/2].
    positions = position_ids[:, :, jnp.newaxis].astype(jnp.float32)
    angles = positions * self.inv_freq[jnp.newaxis, jnp.newaxis, :]

    # Duplication to the full rotary width happens later, at application time.
    return jnp.cos(angles).astype(x.dtype), jnp.sin(angles).astype(x.dtype)


def _rotate_half(x: jax.Array) -> jax.Array:
"""Performs consecutive half-rotation to match DeepSeek-V4 interleaved layout.

Pairs adjacent elements: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].

Operations:
1. Slice even indices: x1 = x[..., 0::2]
2. Slice odd indices: x2 = x[..., 1::2]
3. Stack (-x2, x1) along a new trailing dimension: [..., D/2, 2]
4. Reshape back to the original dimension: [..., D]
"""
x1 = x[..., 0::2] # [B, S, H, D_rope/2]
x2 = x[..., 1::2] # [B, S, H, D_rope/2]

# Interleave consecutive components: [-x2_0, x1_0, -x2_1, x1_1, ...]
stacked = jnp.stack((-x2, x1), axis=-1) # [B, S, H, D_rope/2, 2]
return stacked.reshape(x.shape) # [B, S, H, D_rope]


def apply_rotary_pos_emb(
    x: jax.Array,
    cos: jax.Array,
    sin: jax.Array,
    unsqueeze_dim: int = 2,
) -> jax.Array:
  """Applies DeepSeek-V4 interleaved RoPE to the trailing rotary slice of x.

  Steps:
    1. Duplicate each cos/sin entry consecutively (jnp.repeat) so the tables
       cover the full rotary width D_rope.
    2. Rotate only the trailing D_rope channels of x; the leading "nope"
       channels pass through unmodified.
    3. Perform the rotation in float32 for numerical stability and cast the
       result back to x's dtype.

  Args:
    x: [B, S, H, D] activations; the last D_rope channels are rotated.
    cos: [B, S, D_rope / 2] cosine table.
    sin: [B, S, D_rope / 2] sine table.
    unsqueeze_dim: Axis inserted into cos/sin for head broadcasting.

  Returns:
    [B, S, H, D] tensor whose trailing rotary slice has been rotated.
  """
  # [B, S, D_rope/2] -> [B, S, D_rope] -> [B, S, 1, D_rope]: duplicate pairs
  # consecutively, then add a broadcast axis for the head dimension.
  cos_full = jnp.expand_dims(jnp.repeat(cos, 2, axis=-1), axis=unsqueeze_dim)
  sin_full = jnp.expand_dims(jnp.repeat(sin, 2, axis=-1), axis=unsqueeze_dim)

  rotary_width = cos_full.shape[-1]

  # Split the head dimension into unrotated (nope) and rotated (rope) slices.
  passthrough = x[..., :-rotary_width]                  # [B, S, H, D - D_rope]
  rotary = x[..., -rotary_width:].astype(jnp.float32)   # [B, S, H, D_rope]

  # Standard RoPE formula on the interleaved layout, computed in float32.
  rotated = (rotary * cos_full) + (_rotate_half(rotary) * sin_full)

  # Reassemble in original order and precision.
  return jnp.concatenate([passthrough, rotated.astype(x.dtype)], axis=-1)
63 changes: 63 additions & 0 deletions src/maxtext/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,66 @@ def mlp_block(
abstract_init=False,
)
return module


class DeepSeekGroupedLinear(nnx.Module):
  """Block-diagonal grouped linear projection.

  The trailing dimension of the input is pre-segmented into `n_groups`
  groups, and every group is projected independently through its own weight
  block. Relative to one dense matrix over the concatenated features, the
  block-diagonal form keeps parameter count and compute low, which is what
  the attention output projection relies on.
  """

  def __init__(
      self,
      in_features_per_group: int,
      out_features: int,
      n_groups: int,
      weight_dtype: DType = jnp.float32,
      dtype: DType = jnp.float32,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
      *,
      rngs: nnx.Rngs,
  ):
    self.in_features_per_group = in_features_per_group
    self.out_features = out_features
    self.n_groups = n_groups
    self.weight_dtype = weight_dtype
    self.dtype = dtype

    # Every group must own an equal slice of the output features.
    if out_features % n_groups != 0:
      raise ValueError(f"Output features ({out_features}) must be divisible by n_groups ({n_groups}).")
    self.out_features_per_group = out_features // n_groups

    # One [in, out] block per group, stored stacked as a single 3D kernel:
    # [n_groups, in_features_per_group, out_features_per_group].
    shape = (n_groups, in_features_per_group, self.out_features_per_group)
    self.weight = nnx.Param(
        kernel_init(
            rngs.params(),
            shape,
            self.weight_dtype,
            in_axis=1,
            out_axis=2,
        )
    )

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Projects each group of the input independently through its weight block.

    Args:
      x: [..., n_groups, in_features_per_group] pre-grouped activations.

    Returns:
      [..., n_groups, out_features_per_group] projected activations.
    """
    inputs = jnp.asarray(x, self.dtype)
    kernel = jnp.asarray(self.weight[...], self.dtype)

    # Batched per-group matmul: [..., g, i] x [g, i, o] -> [..., g, o].
    return jnp.einsum("...gi,gio->...go", inputs, kernel)
59 changes: 59 additions & 0 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,62 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:
scale_init=linen_initializers.zeros,
scale_offset=1.0,
)


class DeepSeekV4RMSNorm(nnx.Module):
  """RMS normalization for DeepSeek-V4 (equivalent to T5LayerNorm).

  Normalizes by the root-mean-square over the trailing feature axis (no mean
  subtraction) and applies a learned per-feature scale initialized to ones.
  """

  def __init__(
      self,
      hidden_size: int,
      eps: float = 1e-6,
      dtype: Any = jnp.float32,
      weight_dtype: Any = jnp.float32,
  ):
    # hidden_size: size of the trailing feature axis being normalized.
    # eps: stability offset added to the mean square before rsqrt.
    # dtype: activation dtype of the returned tensor.
    # weight_dtype: storage dtype of the learnable scale.
    self.hidden_size = hidden_size
    self.eps = eps
    self.dtype = dtype
    self.weight_dtype = weight_dtype

    # Initialize learnable scale weight to ones matching T5LayerNorm behavior
    self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype))

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies weighted RMS normalization over the trailing feature axis.

    Args:
      x: [B, S, D] input where D = hidden_size.

    Returns:
      [B, S, D] normalized, scaled activations cast to self.dtype.
    """
    # Convert inputs to float32 for numerical stability during variance pooling
    x_f32 = jnp.asarray(x, jnp.float32)  # [B, S, D] in float32

    # Mean of squares across the features axis (RMS^2)
    variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True)  # [B, S, 1]

    # x / sqrt(mean(x^2) + eps)
    normalized = x_f32 * lax.rsqrt(variance + self.eps)  # [B, S, D]

    # Cast back to active precision and apply the scaling weight. The Param
    # value is read with `[...]`, the accessor used by the other DeepSeek
    # layers (e.g. DeepSeekGroupedLinear), instead of `.get_value()`.
    y = jnp.asarray(normalized, self.dtype)  # [B, S, D]
    weight = jnp.asarray(self.weight[...], self.dtype)  # [D]
    return y * weight  # [B, S, D]


class DeepSeekV4UnweightedRMSNorm(nnx.Module):
  """Scale-free RMS normalization for DeepSeek-V4.

  Identical to a weighted RMSNorm except there is no learnable per-feature
  scale: the output is x / sqrt(mean(x^2) + eps), computed in float32 and
  cast back to the configured dtype.
  """

  def __init__(
      self,
      eps: float = 1e-6,
      dtype: Any = jnp.float32,
  ):
    # eps: stability offset added to the mean square before rsqrt.
    # dtype: activation dtype of the returned tensor.
    self.eps = eps
    self.dtype = dtype

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Normalizes the trailing feature axis of x by its root-mean-square."""
    # Do the reduction in float32 for numerical stability, whatever x's dtype.
    inputs = jnp.asarray(x, jnp.float32)  # [..., D] in float32

    # Mean of squares over the feature axis. [..., 1]
    mean_square = jnp.mean(lax.square(inputs), axis=-1, keepdims=True)

    # Scale by the reciprocal RMS, then return in the active precision.
    scaled = inputs * lax.rsqrt(mean_square + self.eps)  # [..., D]
    return jnp.asarray(scaled, self.dtype)
Loading
Loading