src/maxtext/layers/moe.py (175 additions, 0 deletions)
@@ -306,6 +306,181 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
return output, pre_bias_logits


def _sqrtsoftplus(x: jax.Array) -> jax.Array:
"""Computes sqrtsoftplus activation: sqrt(softplus(x))."""
# [Any] -> [Any]
return jnp.sqrt(jax.nn.softplus(x))
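
# Numeric sanity check for _sqrtsoftplus (illustrative comment, not executed):
# softplus is strictly positive, so the score is always well-defined; it decays
# smoothly toward 0 for very negative x and approaches sqrt(x) for large x.
#
#   x = jnp.array([-10.0, 0.0, 10.0])
#   _sqrtsoftplus(x)
#   # -> approximately [0.0067, 0.8326, 3.1623]  (softplus(0) = ln 2 ~= 0.693)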


class DeepSeekV4TopKRouter(nnx.Module):
"""Top-K Router for DeepSeek-V4 MoE routing.

  Computes router logits, routing weights (normalized per token, then scaled by
  routed_scaling_factor), and top-k expert indices.
"""

def __init__(
self,
config: ctypes.Config,
mesh: jax.sharding.Mesh,
rngs: nnx.Rngs,
kernel_axes: Tuple[Optional[str], ...] = (),
):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.hidden_dim = config.emb_dim if config.moe_expert_input_dim <= 0 else config.moe_expert_input_dim
self.routed_scaling_factor = config.routed_scaling_factor

# Initialize gate weight matrix.
# Shape: [hidden_dim, num_experts]
kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal")
kernel_shape = (self.hidden_dim, self.num_experts)
kernel_in_axis = np.arange(1)
kernel_out_axis = np.arange(1, 2)

self.kernel = nnx.Param(
kernel_init(
rngs.params(),
kernel_shape,
config.weight_dtype,
kernel_in_axis,
kernel_out_axis,
),
out_sharding=kernel_axes,
)

# Load-balancing expert score correction bias.
# Shape: [num_experts]
self.e_score_correction_bias = nnx.Param(
jnp.zeros((self.num_experts,), dtype=jnp.float32),
out_sharding=(kernel_axes[-1] if kernel_axes else None,),
)

def __call__(self, hidden_states: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
# input hidden_states shape: [batch, seq_len, hidden_dim] or [tokens, hidden_dim]
inputs = jnp.asarray(hidden_states, dtype=self.config.dtype)
# [batch, seq_len, hidden_dim] -> [tokens, hidden_dim]
flat = inputs.reshape(-1, self.hidden_dim)

# Compute raw logits in float32.
# [tokens, hidden_dim] x [hidden_dim, num_experts] -> [tokens, num_experts]
kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32)
logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32)

# Apply custom scoring function (sqrtsoftplus).
# [tokens, num_experts] -> [tokens, num_experts]
scores = _sqrtsoftplus(logits)

    # Add the expert score correction bias and select top-k indices. The bias
    # influences expert selection only; the routing weights are gathered from
    # the unbiased scores below.
    # [tokens, num_experts] + [num_experts] -> [tokens, num_experts]
    scores_biased = scores + jnp.asarray(self.e_score_correction_bias[...], dtype=jnp.float32)
    # [tokens, num_experts] -> [tokens, top_k]
    _, indices = jax.lax.top_k(scores_biased, self.top_k)

# Gather corresponding scores for the selected top-k indices.
# [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k]
weights = jnp.take_along_axis(scores, indices, axis=-1)

# Normalize weights to sum to 1.0 per token.
# [tokens, top_k] -> [tokens, top_k]
weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20)

# Scale weights by routed scaling factor.
# [tokens, top_k] -> [tokens, top_k]
scaled_weights = weights * self.routed_scaling_factor

return (
logits.astype(self.config.dtype),
scaled_weights.astype(self.config.dtype),
indices,
)
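
# Illustrative usage sketch (not part of the module): how the top-k router is
# expected to be wired up. The `cfg` values below are hypothetical placeholders,
# e.g. num_experts=8, num_experts_per_tok=2, emb_dim=16, moe_expert_input_dim=0,
# routed_scaling_factor=1.0, dtype=weight_dtype=jnp.float32.
#
#   router = DeepSeekV4TopKRouter(cfg, mesh, rngs=nnx.Rngs(params=jax.random.key(0)))
#   x = jnp.ones((2, 4, 16))              # [batch=2, seq_len=4, hidden_dim=16]
#   logits, weights, indices = router(x)
#   # logits:  [8, 8] float  per-token expert logits (tokens = 2 * 4)
#   # weights: [8, 2] float  top-k weights; each row sums to ~routed_scaling_factor
#   # indices: [8, 2] int    expert ids, ordered by descending biased score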


class DeepSeekV4HashRouter(nnx.Module):
"""Hash Router for DeepSeek-V4 MoE routing.

  Computes router logits, normalized and scaled routing weights, and expert
  indices looked up statically from a frozen token-ID-to-expert table.
"""

def __init__(
self,
config: ctypes.Config,
mesh: jax.sharding.Mesh,
rngs: nnx.Rngs,
kernel_axes: Tuple[Optional[str], ...] = (),
):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.hidden_dim = config.emb_dim if config.moe_expert_input_dim <= 0 else config.moe_expert_input_dim
self.routed_scaling_factor = config.routed_scaling_factor

# Initialize gate weight matrix.
# Shape: [hidden_dim, num_experts]
kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal")
kernel_shape = (self.hidden_dim, self.num_experts)
kernel_in_axis = np.arange(1)
kernel_out_axis = np.arange(1, 2)

self.kernel = nnx.Param(
kernel_init(
rngs.params(),
kernel_shape,
config.weight_dtype,
kernel_in_axis,
kernel_out_axis,
),
out_sharding=kernel_axes,
)

    # Static token-to-expert mapping table (initialized to zeros; expected to
    # be populated externally, e.g. loaded from a checkpoint).
    # Shape: [vocab_size, top_k]
self.tid2eid = nnx.Param(
jnp.zeros((config.vocab_size, self.top_k), dtype=jnp.int32),
)

def __call__(self, hidden_states: jax.Array, input_ids: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
# input hidden_states shape: [batch, seq_len, hidden_dim] or [tokens, hidden_dim]
inputs = jnp.asarray(hidden_states, dtype=self.config.dtype)
# [batch, seq_len, hidden_dim] -> [tokens, hidden_dim]
flat = inputs.reshape(-1, self.hidden_dim)

# Compute raw logits in float32.
# [tokens, hidden_dim] x [hidden_dim, num_experts] -> [tokens, num_experts]
kernel_f32 = jnp.asarray(self.kernel[...], dtype=jnp.float32)
logits = jnp.matmul(flat.astype(jnp.float32), kernel_f32)

# Apply custom scoring function (sqrtsoftplus).
# [tokens, num_experts] -> [tokens, num_experts]
scores = _sqrtsoftplus(logits)

# Look up frozen expert routing indices from input_ids.
# [batch, seq_len] -> [tokens]
flat_input_ids = input_ids.reshape(-1)
    # [vocab_size, top_k] indexed by [tokens] -> [tokens, top_k]
indices = self.tid2eid[...][flat_input_ids]

# Gather corresponding scores for the statically selected expert indices.
# [tokens, num_experts] gathered with [tokens, top_k] -> [tokens, top_k]
weights = jnp.take_along_axis(scores, indices, axis=-1)

# Normalize weights to sum to 1.0 per token.
# [tokens, top_k] -> [tokens, top_k]
weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20)

# Scale weights by routed scaling factor.
# [tokens, top_k] -> [tokens, top_k]
scaled_weights = weights * self.routed_scaling_factor

return (
logits.astype(self.config.dtype),
scaled_weights.astype(self.config.dtype),
indices,
)
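
# Illustrative sketch of populating the frozen tid2eid table (the assignment
# scheme is not specified in this diff; a plain modulo hash is assumed here
# purely for demonstration, giving each token top_k distinct experts as long
# as top_k <= num_experts):
#
#   router = DeepSeekV4HashRouter(cfg, mesh, rngs=nnx.Rngs(params=jax.random.key(0)))
#   tids = jnp.arange(cfg.vocab_size)[:, None]            # [vocab_size, 1]
#   slot = jnp.arange(cfg.num_experts_per_tok)[None, :]   # [1, top_k]
#   router.tid2eid.value = ((tids + slot) % cfg.num_experts).astype(jnp.int32)
#   logits, weights, indices = router(hidden_states, input_ids)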


class RoutedMoE(nnx.Module):
"""Implements a routed MoE block."""
