
Conversation

@tanmaysachan commented Jan 16, 2026

[WIP]
Addresses #865

  • Model outline from PyTorch -> JAX
  • Parity checks
  • Inference/benchmarking


@gemini-code-assist bot left a comment


Code Review

This pull request introduces the JAX implementation for the DeepseekV3 model. The implementation is comprehensive and covers the model's unique features like Multi-Head Latent Attention and Mixture of Experts with shared experts. The code is well-structured.

My review focuses on a critical bug that will prevent the model from running, along with some suggestions to improve maintainability by reducing code duplication and avoiding magic numbers. Addressing these points will make the implementation more robust and easier to maintain.

Comment on lines +527 to +543
# Precompute RoPE frequencies
# qk_rope_head_dim = config.qk_rope_head_dim
# original_seq_len = getattr(config, "original_seq_len", config.max_position_embeddings)
# rope_factor = getattr(config, "rope_factor", 1.0)
# beta_fast = getattr(config, "beta_fast", 32)
# beta_slow = getattr(config, "beta_slow", 1)

# TODO: Swap out like llama's rope?
# self.freqs_cis = precompute_freqs_cis(
#     dim=qk_rope_head_dim,
#     max_seq_len=config.max_position_embeddings,
#     rope_theta=config.rope_theta,
#     original_seq_len=original_seq_len,
#     rope_factor=rope_factor,
#     beta_fast=beta_fast,
#     beta_slow=beta_slow,
# )

critical

This block for precomputing RoPE frequencies is commented out, but self.freqs_cis is used in DeepseekV3Model.__call__ at line 571. This will raise an AttributeError at runtime.

In the DeepseekV3MLA implementation, the freqs_cis parameter is never actually used; apply_rope is called instead, which computes the frequencies on the fly.

To fix this, you should remove the freqs_cis parameter from the entire call chain, as it appears to be unused. This involves:

  1. Removing freqs_cis: jax.Array from the signature of DeepseekV3MLA.__call__.
  2. Removing freqs_cis: jax.Array from the signature of DeepseekV3DecoderLayer.__call__.
  3. Removing the freqs_cis=self.freqs_cis argument from the layer() call within DeepseekV3Model.__call__.

This will resolve the crash and align the code with the current apply_rope implementation. You can then address the TODO about swapping the RoPE implementation in a separate change.
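
For illustration, here is a minimal sketch of the trimmed call chain under those three changes; everything except the removal of freqs_cis (parameter names, module internals) is an assumption, not the exact code in this PR:

import jax
from flax import nnx

class DeepseekV3MLA(nnx.Module):
    def __call__(self, hidden_states: jax.Array, attention_mask: jax.Array | None = None) -> jax.Array:
        # apply_rope computes the rotary frequencies on the fly, so no
        # precomputed freqs_cis argument is needed here.
        ...

class DeepseekV3DecoderLayer(nnx.Module):
    def __call__(self, hidden_states: jax.Array, attention_mask: jax.Array | None = None) -> jax.Array:
        # freqs_cis dropped from the signature; the layer forwards to its
        # attention and MLP submodules as before.
        ...

class DeepseekV3Model(nnx.Module):
    def __call__(self, input_ids: jax.Array, attention_mask: jax.Array | None = None) -> jax.Array:
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            # No freqs_cis=self.freqs_cis keyword argument anymore.
            hidden_states = layer(hidden_states, attention_mask=attention_mask)
        return self.norm(hidden_states)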

)

# Bias only for specific model sizes (7168 hidden_size in original)
self.use_bias = config.hidden_size == 7168

medium

Hardcoding the magic number 7168 to determine self.use_bias makes the code brittle and less maintainable. If a new model variant is introduced that also requires this bias, this line would need to be updated. A better approach would be to introduce a dedicated boolean flag in the DeepseekV3Config, such as use_router_bias, to control this behavior explicitly.
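
As a rough sketch of that suggestion (the use_router_bias field and the dataclass form are illustrative; neither is an existing part of DeepseekV3Config):

from dataclasses import dataclass

@dataclass
class DeepseekV3Config:
    hidden_size: int = 7168
    use_router_bias: bool = True  # set per checkpoint instead of inferring it from hidden_size
    # ... remaining fields unchanged ...

# In the gate/router __init__, the check then becomes:
#     self.use_bias = config.use_router_bias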

# Bias only for specific model sizes (7168 hidden_size in original)
self.use_bias = config.hidden_size == 7168
if self.use_bias:
    from tx.layers.util import Param

medium

This import statement sits inside a conditional block within __init__. According to PEP 8, imports should go at the top of the file; that keeps the module's dependencies visible in one place and improves readability. Please move from tx.layers.util import Param to the top of the file with the other imports.
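
A minimal sketch of the hoisted import (Param's constructor arguments are elided since its exact signature lives in tx.layers.util, and the attribute name below is illustrative):

from tx.layers.util import Param  # at module top with the other imports

# ... later, inside __init__, only the parameter creation stays conditional:
# if self.use_bias:
#     self.bias = Param(...)  # attribute name and arguments for illustration only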

Comment on lines 413 to 415
class DeepseekV3SharedMLP(nnx.Module):
"""Always active shared experts."""


medium

The DeepseekV3SharedMLP class is nearly identical to DeepseekV3MLP, with the only significant difference being the intermediate_size. This creates code duplication, which can make maintenance harder.

To improve this, consider refactoring them into a single, more generic MLP class (e.g., SwiGLU) that accepts intermediate_size as a parameter in its __init__ method. You could then instantiate this class with config.intermediate_size for the standard MLP and with the calculated shared_inter_dim for the shared MLP part.
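
For example, a single SwiGLU-style MLP parameterized by intermediate_size could serve both roles. This is a sketch only; field and argument names are assumptions rather than the exact ones in this PR:

import jax
from flax import nnx

class DeepseekV3MLP(nnx.Module):
    def __init__(self, config, intermediate_size: int, *, rngs: nnx.Rngs):
        self.gate_proj = nnx.Linear(config.hidden_size, intermediate_size, use_bias=False, rngs=rngs)
        self.up_proj = nnx.Linear(config.hidden_size, intermediate_size, use_bias=False, rngs=rngs)
        self.down_proj = nnx.Linear(intermediate_size, config.hidden_size, use_bias=False, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # SwiGLU: silu(gate_proj(x)) * up_proj(x), projected back to hidden_size.
        return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))

# Standard MLP:    DeepseekV3MLP(config, config.intermediate_size, rngs=rngs)
# Shared experts:  DeepseekV3MLP(config, shared_inter_dim, rngs=rngs)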

@pcmoritz added the tx label Jan 17, 2026