Update optimization docs and add TPU v7x guide #3857
Conversation
> *For a comprehensive overview of how to apply these strategies in MaxText, refer to the [Sharding on TPUs](sharding.md) guide. Below are Ironwood-specific considerations:*
> * **Fully Sharded Data Parallelism (FSDP):** This is the preferred strategy for large-scale training where the model exceeds the memory capacity of a single chip. FSDP shards the model's weights, gradients, and optimizer states. Increasing the per-device batch size improves efficiency by adding compute, which can hide the latency of the All-Gather operations FSDP introduces.
> * **Tensor Parallelism (TP):** TP shards individual tensors. Ironwood's high arithmetic intensity (AI ≈ 11.5k) requires an MLP dimension greater than 46k (for TP degree 4\) to be viable over ICI. Most open-source models, such as Llama3 70B (MLP dimension 28,672) and Qwen 2.5 7B (MLP dimension 18,944), fall short, and using TP here would make the system communication-bound.
Is the backslash in "for TP degree 4\)" intentional, or is it a formatting artifact?
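The TP viability threshold in the quoted bullet follows from simple arithmetic: the per-shard MLP dimension must exceed the chip's arithmetic intensity. A minimal sketch of that check (the helper name and structure are illustrative, not a MaxText API):

```python
# Back-of-envelope check: TP over ICI stays compute-bound only if
# mlp_dim > arithmetic_intensity * tp_degree.
def min_viable_mlp_dim(arithmetic_intensity: int, tp_degree: int) -> int:
    return arithmetic_intensity * tp_degree

IRONWOOD_AI = 11_500  # FLOPs/byte figure quoted in the bullet above

threshold = min_viable_mlp_dim(IRONWOOD_AI, 4)
print(threshold)                 # 46000
print(28_672 > threshold)        # Llama3 70B MLP dim: False -> communication-bound
```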
> * **Splash Attention:** Used as the primary attention implementation; it eliminates the HBM bottleneck of standard attention and is the most efficient attention implementation on TPUs.
> * **Megablox Grouped Matrix Multiplication (GMM):** For MoE workloads, Megablox efficiently handles grouped matrix multiplications by computing directly over the ragged activation representation.
> * **Empirical tuning with tune-jax:** The Tokamax library has [utilities](https://github.com/openxla/tokamax/blob/main/tokamax/experimental/utils/tuning/tpu/README.md) that use `tune-jax` to perform empirical searches for optimal block sizes. Default kernel tile sizes are often suboptimal; tuning allows choosing hardware-friendly VMEM tile sizes (and other hyperparameters) to maximize hardware utilization.
Note that the current vision for Tokamax is an auto-tune cache.
We can have someone from the Tokamax team review just this section.
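The idea behind an empirical block-size search can be sketched without the real tuning utilities. The brute-force sweep below (plain NumPy, not the `tune-jax` API) times a blocked matmul for several candidate tile sizes and keeps the fastest:

```python
import time
import numpy as np

# Illustrative blocked matmul, a stand-in for a real Pallas kernel
# parameterized by tile size.
def blocked_matmul(a, b, block):
    n = a.shape[0]
    out = np.zeros((n, n), dtype=a.dtype)
    for i in range(0, n, block):
        for j in range(0, n, block):
            for k in range(0, n, block):
                out[i:i+block, j:j+block] += (
                    a[i:i+block, k:k+block] @ b[k:k+block, j:j+block]
                )
    return out

a = b = np.ones((256, 256), dtype=np.float32)

def bench(block, iters=3):
    t0 = time.perf_counter()
    for _ in range(iters):
        blocked_matmul(a, b, block)
    return time.perf_counter() - t0

candidates = [32, 64, 128, 256]
best = min(candidates, key=bench)   # empirically fastest tile size on this host
```

Real tuners do the same sweep over kernel hyperparameters, but compile and benchmark on the target TPU and cache the winning configuration.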
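The "ragged activation representation" mentioned for Megablox GMM can be made concrete with a small sketch: tokens routed to each expert are stored contiguously, and each contiguous group multiplies its own expert's weights. (Megablox fuses this into one kernel; the loop below only shows the semantics.)

```python
import numpy as np

def grouped_matmul(x, weights, group_sizes):
    # x: (total_tokens, d) with tokens for expert i stored contiguously;
    # weights: (num_experts, d, h); group_sizes: tokens per expert (ragged).
    out, start = np.empty((x.shape[0], weights.shape[-1])), 0
    for w, g in zip(weights, group_sizes):
        out[start:start + g] = x[start:start + g] @ w
        start += g
    return out

rng = np.random.default_rng(0)
group_sizes = [3, 1, 2]                        # ragged token counts per expert
x = rng.standard_normal((sum(group_sizes), 4))
weights = rng.standard_normal((3, 4, 5))       # one (4, 5) matrix per expert
y = grouped_matmul(x, weights, group_sizes)    # shape (6, 5)
```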
> #### Memory pipeline tuning
I might just call this "block size tuning" and make tune-jax a subsection of it.
> In this regime, the workload is primarily compute-bound. The objective is to keep the MXUs fully saturated and minimize TensorCore idle time.
> * **SparseCore offload:** By offloading communication collectives to the SparseCore, we freed the TensorCores to focus on MXU operations and achieved near-perfect overlap between communication and computation. *Result: 22% decrease in step time.*
I think it's strange to give such a precise result here without more context. I would cite e2e MFU and/or more details, with a specific model like Llama2-70B, for which we can get ~50% MFU.
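For context on what a step-time percentage implies, the conversion to a throughput speedup is straightforward, since throughput scales with 1 / step_time:

```python
# A 22% step-time decrease corresponds to ~1.28x training throughput.
baseline, optimized = 1.00, 1.00 - 0.22
speedup = baseline / optimized
print(round(speedup, 2))  # 1.28
```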
> At a context length of 128k, activation memory grows with sequence length, making out-of-memory (OOM) errors the primary hurdle.
> * **SparseCore offload:** Offloading All-Gather and Reduce-Scatter operations ensured that the communication required for TP and CP did not stall the MXUs. *Result: 5% reduction in step time.*
An optimal config probably doesn't use TP.
> * **Hybrid Parallelism (FSDP16 + TP2 + CP2):** To handle a full batch, we used a hybrid approach of CP2 and TP2. We chose TP2 specifically to align the workload with Ironwood's dual-chiplet architecture, which lets frequent communications occur over the internal die-to-die (D2D) interface, 6x faster than the standard ICI. *Result: 4% performance improvement compared to using CP4 alone.*
Can you link this run? Using both TP=2 and CP=2 in the same run is strange, since at most one parallelism can be across cores.
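The hybrid layout factorizes the device mesh as FSDP × TP × CP. A quick sanity sketch (the 64-chip slice size is an assumption; the quoted bullet does not state the topology):

```python
# Mesh factorization for the FSDP16 + TP2 + CP2 layout described above.
fsdp, tp, cp = 16, 2, 2
total_chips = fsdp * tp * cp
print(total_chips)  # 64  (assumed slice size, not stated in the original run)

# The stated rationale: TP=2 maps the most frequent collectives onto the
# intra-chip D2D link, described above as ~6x faster than ICI.
```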
> ### Case study: Dense LLM (< 20B parameters) – long context (128k)
> At a context length of 128k, activation memory grows with sequence length, making out-of-memory (OOM) errors the primary hurdle.
I would say that Splash becomes the majority of FLOPs, so the e2e MFU is roughly the same as the single Splash kernel's MFU. Memory does become a problem, so we need CP for most 128k-sequence configs, but the performance of Splash is critical for overall performance.
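The quadratic memory growth that makes 128k context OOM with naive attention is easy to quantify: the attention score matrix alone is O(s²). A back-of-envelope calculation:

```python
# Naive attention materializes an (s, s) score matrix per head, per layer.
seq_len, bytes_per_elem = 128_000, 2           # bf16
scores_gib = seq_len * seq_len * bytes_per_elem / 2**30
print(round(scores_gib, 1))  # 30.5 GiB for a single head

# Multiplied across heads and layers this dwarfs HBM, which is why
# Splash Attention (which never materializes the full score matrix)
# and CP (which shards the sequence) are both needed at this length.
```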
Description
Turning https://discuss.google.dev/t/optimizing-frontier-model-training-on-tpu-v7x-ironwood/336983 into a how-to guide.
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.