
Update optimization docs and add TPU v7x guide #3857

Open
jacoguzo wants to merge 1 commit into main from jacoguzo_optimizing_v7x_blog

Conversation

@jacoguzo (Collaborator) commented May 8, 2026

Description

Turning https://discuss.google.dev/t/optimizing-frontier-model-training-on-tpu-v7x-ironwood/336983 into a How To Guide

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

*For a comprehensive overview of how to apply these strategies in MaxText, refer to the [Sharding on TPUs](sharding.md) guide. Below are Ironwood-specific considerations:*

* **Fully Sharded Data Parallelism (FSDP):** This is the preferred strategy for large-scale training when the model exceeds the memory capacity of a single chip. FSDP shards the model’s weights, gradients, and optimizer states. Increasing the per-device batch size improves efficiency: the extra compute helps hide the latency of the All-Gather operations that FSDP introduces.
* **Tensor Parallelism (TP):** TP shards individual tensors. Ironwood’s high arithmetic intensity (AI ≈ 11.5k) requires an MLP dimension greater than 46k (for TP degree 4) to be viable over ICI. Most open-source models, like Llama3 70B (MLP dimension 28,672) and Qwen 2.5 7B (MLP dimension 18,944), fall short, so using TP here would leave the system communication-bound. A back-of-the-envelope check of this threshold appears after the review comment below.
Collaborator commented:

Is the slash after "for TP degree 4" intentional, or is it a formatting artifact?
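The 46k threshold follows directly from the numbers in the bullet above: at TP degree 4, each shard sees MLP_dim / 4 elements, which must exceed the AI of ~11.5k. A minimal illustrative check (the helper name and the ~53k example dimension are hypothetical, not MaxText code):

```python
# Rule of thumb from the guide: TP over ICI stays compute-bound only while
# the per-shard MLP dimension exceeds the chip's arithmetic intensity.
IRONWOOD_AI = 11_500  # AI value quoted in the guide

def tp_is_compute_bound(mlp_dim: int, tp_degree: int) -> bool:
    return mlp_dim / tp_degree > IRONWOOD_AI

# Breakeven at TP degree 4 is an MLP dimension of 4 * 11.5k = 46k:
print(tp_is_compute_bound(28_672, 4))  # False -> Llama3 70B is communication-bound
print(tp_is_compute_bound(18_944, 4))  # False -> Qwen 2.5 7B is communication-bound
print(tp_is_compute_bound(53_248, 4))  # True  -> a hypothetical ~53k MLP clears the bar
```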


* **Splash Attention:** Use it as the primary attention implementation: it eliminates the HBM bottleneck of standard attention and is the most efficient attention kernel available on TPUs.
* **Megablox Grouped Matrix Multiplication (GMM):** For MoE workloads, Megablox efficiently handles grouped matrix multiplications by computing over the ragged activations representation.
* **Empirical tuning with tune-jax:** The Tokamax library has [utilities](https://github.com/openxla/tokamax/blob/main/tokamax/experimental/utils/tuning/tpu/README.md) that use `tune-jax` to perform empirical searches for optimal block sizes. Default kernel tile sizes are often suboptimal; tuning lets you choose hardware-friendly VMEM tile sizes (as well as other hyperparameters) to maximize hardware utilization. A sketch of this search loop follows the review comments below.
Collaborator commented:

note the current vision of tokamax is an auto-tune cache

Collaborator commented:

can have someone from tokamax team review just this section

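The tune-jax utilities linked above perform this search automatically; as a rough illustration of the idea (not the tune-jax API), here is a hand-rolled sweep over candidate block sizes for a toy K-blocked matmul:

```python
import functools
import time

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames="block_k")
def blocked_matmul(a, b, block_k):
    # Toy kernel: accumulate the contraction over K in tiles of block_k.
    acc = jnp.zeros((a.shape[0], b.shape[1]), jnp.float32)
    for start in range(0, a.shape[1], block_k):
        acc += a[:, start:start + block_k] @ b[start:start + block_k, :]
    return acc

a = jnp.ones((4096, 4096), jnp.bfloat16)
b = jnp.ones((4096, 4096), jnp.bfloat16)

def best_block_k(candidates=(256, 512, 1024, 2048)):
    timings = {}
    for blk in candidates:
        blocked_matmul(a, b, blk).block_until_ready()  # compile + warm up
        start = time.perf_counter()
        for _ in range(10):
            blocked_matmul(a, b, blk).block_until_ready()
        timings[blk] = (time.perf_counter() - start) / 10
    return min(timings, key=timings.get), timings
```

A real search would sweep the tile sizes of the Pallas kernel itself; the structure of the loop (compile, warm up, time, pick the argmin) is the same.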

#### Memory pipeline tuning
Collaborator commented:

I might just call this block size tuning and put the tune jax as a subsection of it


In this regime, the workload is primarily compute-bound. The objective is to keep the MXUs fully saturated and minimize TensorCore idle time.

* SparseCore offload: By offloading communication collectives to the SparseCore, we freed TensorCores to focus on MXU operations and achieved near-perfect overlap between communication and computation. *Result: 22% decrease in step time.*
Collaborator commented:

I think it's strange to give such a precise result here without more info. I would cite e2e MFU and/or more detail, with a specific model like llama2-70B, for which we can get ~50% MFU.
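The e2e MFU the reviewer asks for is straightforward to report alongside step time; a minimal sketch using the usual 6·N FLOPs-per-token dense-transformer estimate (every numeric argument below is a placeholder to be replaced with measured values and the chip's datasheet peak):

```python
def mfu(step_time_s: float, tokens_per_step: int, n_params: float,
        peak_flops_per_chip: float, n_chips: int) -> float:
    # ~6 * N flops per token (fwd + bwd), attention flops excluded.
    model_flops_per_step = 6 * n_params * tokens_per_step
    achieved_flops = model_flops_per_step / step_time_s
    return achieved_flops / (peak_flops_per_chip * n_chips)

# Placeholder inputs; substitute a real step time, token count, chip count,
# and the hardware's actual peak FLOP/s.
print(mfu(step_time_s=4.0, tokens_per_step=1_048_576,
          n_params=70e9, peak_flops_per_chip=9.0e14, n_chips=256))
```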


At a context length of 128k, activation memory grows with sequence length, making out-of-memory (OOM) errors the primary hurdle.

* SparseCore Offload: Offloading All-Gather and Reduce-Scatter operations ensured that the communication required for TP and CP did not stall the MXUs. *Result: 5% reduction in step time.*
Collaborator commented:

an optimal config probably doesn't use TP

* Hybrid Parallelism (FSDP16 + TP2 + CP2): To handle a full batch, we utilized a hybrid approach of CP2 and TP2. We chose TP2 specifically to align the workload with Ironwood’s dual-chiplet architecture. This allows frequent communications to occur over the internal die-to-die (D2D) interface, which is 6x faster than the standard ICI. *Result: 4% performance improvement compared to using CP4 alone.*
@gobbleturk commented May 9, 2026:

Can you link this run? Having both TP=2 and CP=2 in the same run is strange, since at most one parallelism can be across cores.
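For reference in this thread, a minimal JAX sketch of what an FSDP16 + TP2 + CP2 layout looks like as a device mesh, assuming a 64-chip slice (axis names and partition specs are illustrative, not MaxText's actual mesh construction):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 16 x 2 x 2 mesh for the FSDP16 + TP2 + CP2 run described above.
devices = np.array(jax.devices()[:64]).reshape(16, 2, 2)
mesh = Mesh(devices, axis_names=("fsdp", "tp", "cp"))

# Weights sharded over the fsdp and tp axes; activations split along the
# sequence dimension over cp (context parallelism), layout (batch, seq, hidden).
w_sharding = NamedSharding(mesh, P("fsdp", "tp"))
act_sharding = NamedSharding(mesh, P(None, "cp", None))
```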


### Case study: Dense LLM (< 20B parameters) – long context (128k)

Collaborator commented:

I would say that splash becomes the majority of the flops, so the e2e MFU is roughly the same as the single splash kernel's MFU.

Memory does become a problem, so we need CP for most 128k-seq configs, but splash performance is still critical.
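The claim that splash dominates at 128k follows from how the flops scale; a rough per-layer comparison using the standard dense-transformer estimates (the dimensions below are placeholders, loosely Llama3-8B-shaped):

```python
# Attention score/context matmuls scale as ~4 * L^2 * d per layer, while a
# gated MLP scales as ~6 * L * d * d_ff, so their ratio grows linearly in L.
def attn_flops(seq_len: int, d_model: int) -> int:
    return 4 * seq_len**2 * d_model

def mlp_flops(seq_len: int, d_model: int, d_ff: int) -> int:
    return 6 * seq_len * d_model * d_ff

d_model, d_ff = 4096, 14336        # placeholder dims
for seq_len in (8_192, 131_072):   # 8k vs 128k context
    ratio = attn_flops(seq_len, d_model) / mlp_flops(seq_len, d_model, d_ff)
    print(f"seq_len={seq_len:>7}: attention/MLP flop ratio ~ {ratio:.2f}")
```

In this toy count the attention matmuls are roughly 6x the MLP flops at 128k (versus ~0.4x at 8k), consistent with the reviewer's point that the splash kernel's efficiency dominates e2e MFU at long context.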
