Fix Swift compiler warnings and refine MTP scatter logic #42
Open
solderzzc wants to merge 2 commits into
Pull request overview
This PR addresses Swift compiler warnings and simplifies the SSD/MTP scatter path in the MLX language model code, while also adding two standalone scatter/array repro scripts.
Changes:
- Removed an unused `matchedCandidate` variable in weight loading.
- Simplified the always-sorted empty-routing path in `SwitchGLU`.
- Changed a reshaped MLX output array binding from `var` to `let` in Gemma4 MTP scatter logic.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `Libraries/MLXLMCommon/Load.swift` | Removes unused candidate-tracking state during expert streaming setup. |
| `Libraries/MLXLMCommon/SwitchLayers.swift` | Simplifies empty-index handling in the stacked SSD fast path. |
| `Libraries/MLXLLM/Models/Gemma4Text.swift` | Uses an immutable binding for the reshaped scatter output array. |
| `test_array_init.swift` | Adds a standalone MLX array initialization repro script. |
| `test_scatter.swift` | Adds a standalone MLX scatter repro script. |
Comment on lines +1 to +7

    import Foundation
    import MLX
    MLX.GPU.set(cacheLimit: 10 * 1024 * 1024)

    let size: Int = 10
    let arr = MLXArray(0 ..< size).asType(.int32)
    print(arr)
Comment on lines +3 to +13

    MLX.GPU.set(cacheLimit: 10 * 1024 * 1024)

    var out = MLXArray.zeros([4, 10])
    let rows = MLXArray(0 ..< Int32(4)).reshaped([4, 1])
    let cols = MLXArray([1, 2, 0, 4, 3, 5, 2, 9]).reshaped([4, 2])
    let vals = MLXArray([10, 20, 30, 40, 50, 60, 70, 80]).reshaped([4, 2])

    out[rows, cols] = vals
    MLX.eval(out)
    print(out)
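
For readers without an MLX toolchain at hand, the indexing in the scatter repro can be approximated in plain Swift. This is only a sketch of the expected broadcasting behavior, assuming the `[4, 1]` row indices broadcast against the `[4, 2]` column indices; `scatterRows` is a hypothetical helper, not an MLX API, and it uses nested `Int` arrays in place of `MLXArray`.

```swift
// Plain-Swift emulation of `out[rows, cols] = vals` (no MLX dependency).
// `scatterRows` is a hypothetical helper written for illustration only.
func scatterRows(into out: inout [[Int]], cols: [[Int]], vals: [[Int]]) {
    // rows has shape [4, 1], so row index r pairs with every column index in cols[r].
    for r in 0..<cols.count {
        for j in 0..<cols[r].count {
            out[r][cols[r][j]] = vals[r][j]
        }
    }
}

// Same shapes and values as the repro script, written as nested arrays.
var out = Array(repeating: Array(repeating: 0, count: 10), count: 4)
let cols = [[1, 2], [0, 4], [3, 5], [2, 9]]
let vals = [[10, 20], [30, 40], [50, 60], [70, 80]]

scatterRows(into: &out, cols: cols, vals: vals)
for row in out { print(row) }
```

Under that assumption, each row receives exactly two non-zero entries (for example, row 0 gets 10 at column 1 and 20 at column 2), which is what the MLX repro would be expected to print if the scatter path is correct.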
- Add maxSharedKV=16 window in runMTPHead to limit cross-attention
to the most recent 16 backbone KV positions (was O(T), now O(16)).
Eliminates throughput regression at 40K-100K context lengths.
- Implement MTPPartialRollback protocol on Gemma4AssistantModel:
store lastBackboneHiddenStateAll for position-specific rollback
without re-running the main model on partial draft rejection.
- Add callMTPHeadOnly for re-seeding MTP head from cached backbone
state (rollback draft generation, no main-model forward pass).
- Add numMTPDraftTokens=2 to control assistant head depth per pass.
- Benchmarks (M5 Pro 64GB, gemma-4-26b-a4b-it-8bit):
8-bit + MTP at 40K: +20% TPS vs vanilla (38.8 vs 32.4)
8-bit + MTP at 100K: +51% TPS vs vanilla (22.5 vs 14.9)
4-bit MoE is compute-bound (FFN dominates); MTP neutral there.
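
The `maxSharedKV` window described above can be sketched as a simple index computation. This is a minimal illustration, not the code in `runMTPHead`: the helper name `sharedKVWindow` and its parameters are assumptions, and the real implementation presumably slices KV tensors rather than returning a range.

```swift
// Hypothetical helper: which backbone KV positions the MTP head would attend to
// when cross-attention is limited to the most recent `maxSharedKV` positions.
func sharedKVWindow(contextLength: Int, maxSharedKV: Int = 16) -> Range<Int> {
    let end = max(0, contextLength)        // guard against negative lengths
    let start = max(0, end - maxSharedKV)  // clamp when the context is shorter than the window
    return start..<end
}

let short = sharedKVWindow(contextLength: 10)      // context shorter than the window
let long = sharedKVWindow(contextLength: 40_000)   // long-context case from the benchmarks
print(short, long, long.count)
```

The window size stays constant once the context exceeds 16 positions, which is why the cross-attention cost no longer grows with the context length T.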
Description
- Removes the `matchedCandidate` unused variable warning in `Load.swift`.
- Simplifies the `if doSort` branch in `SwitchLayers.swift` (as `doSort` is constantly `true` in this SSD streaming context).
- Changes `var output2D` to `let output2D` in `Gemma4Text.swift` to address the Swift mutation warning, correctly reflecting that the MLXArray reference subscript mutates the underlying C++ buffer directly.
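
The `var` to `let` change relies on reference semantics: a `let` binding to a class instance fixes the reference, while subscript assignment still updates the shared contents. The sketch below illustrates this with a hypothetical `RefBuffer` class standing in for a reference-typed array wrapper; it is not the MLX API and assumes, as the description states, that `MLXArray` behaves as a reference type.

```swift
// Hypothetical stand-in for a reference-typed array wrapper; not the MLX API.
final class RefBuffer {
    private var storage: [Int]

    init(count: Int) {
        storage = Array(repeating: 0, count: count)
    }

    // On a class, the subscript setter is nonmutating with respect to the binding,
    // so it can be called through a `let` constant.
    subscript(index: Int) -> Int {
        get { storage[index] }
        set { storage[index] = newValue }
    }
}

let output = RefBuffer(count: 4)  // `let` fixes the reference, not the contents
output[2] = 7                     // compiles without `var`

let alias = output                // second reference to the same object
alias[0] = 1                      // visible through `output` as well
print(output[0], output[2])
```

With a struct in place of the class, the same subscript assignment would require `var`, which is why the compiler only flags the binding as never mutated in the reference-type case.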