Current state
Granite Switch appends control_dims extra dimensions to every Q, K, and V vector in every decoder layer to implement KV-hiding of control tokens. This creates two costs:
1. Excessive KV cache size
The KV cache stores vectors at expanded_head_dim = head_dim + control_dims per head, per layer.
| Model | head_dim | control_dims | expanded_head_dim | KV cache overhead |
| --- | --- | --- | --- | --- |
| Granite 4.1-3b | 64 | 32 | 96 | +50% |
| Granite 4.1-8b | 128 | 32 | 160 | +25% |
| Granite 4.1-30b | 128 | 32 | 160 | +25% |
For the 8B model, this means the KV cache is 25% larger than the base model's — directly reducing the number of concurrent sequences that fit in GPU memory.
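The overhead figures follow directly from expanded_head_dim = head_dim + control_dims. The sketch below reproduces them; the function names are illustrative only, and the capacity estimate assumes the KV cache is the only consumer of the memory budget in question.

```python
def kv_cache_overhead(head_dim, control_dims):
    """Relative KV cache growth from widening each head by control_dims."""
    expanded_head_dim = head_dim + control_dims
    return expanded_head_dim / head_dim - 1.0


def relative_capacity(head_dim, control_dims):
    """Fraction of the base model's concurrent sequences that still fit,
    assuming a fixed memory budget spent entirely on KV cache."""
    return head_dim / (head_dim + control_dims)


models = [("Granite 4.1-3b", 64), ("Granite 4.1-8b", 128), ("Granite 4.1-30b", 128)]
for name, head_dim in models:
    overhead = kv_cache_overhead(head_dim, control_dims=32)
    capacity = relative_capacity(head_dim, control_dims=32)
    print(f"{name}: +{overhead:.0%} cache, {capacity:.0%} of base capacity")
```

Under that assumption, a 25% larger cache per sequence leaves room for 80% as many sequences as the base model.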
2. Excessive compute
The attention dot product is computed over expanded_head_dim elements instead of head_dim in every layer. For the 8B model, FlashAttention pads 160→192 internally, meaning the kernel operates on 192-dim vectors when only 128 carry content.
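To put a number on the padding cost, the sketch below compares the dimension the kernel operates on with the dimension that carries content. It takes the padded size as an input (192, as stated above) rather than modelling FlashAttention's padding rule, and the helper name is hypothetical.

```python
def padding_waste(content_dim, kernel_dim):
    """Return (work multiplier relative to content-only attention,
    fraction of kernel lanes that carry no content)."""
    if kernel_dim < content_dim:
        raise ValueError("kernel_dim must be at least content_dim")
    multiplier = kernel_dim / content_dim
    wasted_fraction = (kernel_dim - content_dim) / kernel_dim
    return multiplier, wasted_fraction


# 8B model: 128 content dims, kernel operates on 192 after padding.
multiplier, wasted = padding_waste(content_dim=128, kernel_dim=192)
print(f"dot-product width: {multiplier:.2f}x, lanes without content: {wasted:.0%}")
```

For the 8B figures this gives a 1.5x wider dot product, with one third of the lanes carrying no content.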
Proposed Solutions
Solution A (Preferred): Token Exchange — Replace Embedding After Switch
Keep the control token in input_ids so the switch can still detect it via input_ids == adapter_token_ids, then replace its hidden-state representation with the embedding of the first token of the adapter's invocation sequence before the decoder layers process it. The KV cache entry at that position then holds an ordinary content token, so no hiding is needed and control_dims can be set to 0.
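A minimal sketch of the exchange, using NumPy arrays in place of model tensors. The function name, the adapter_first_token mapping, and the returned adapter_idx array are hypothetical and only illustrate the idea; this is not the Granite Switch implementation.

```python
import numpy as np


def exchange_control_tokens(input_ids, hidden, embed_table, adapter_first_token):
    """Swap control-token hidden states for ordinary token embeddings.

    input_ids:           (seq_len,) int array, control tokens left in place
    hidden:              (seq_len, hidden_dim) embeddings of input_ids
    embed_table:         (vocab_size, hidden_dim) embedding matrix
    adapter_first_token: hypothetical dict mapping each control token id to the
                         id of the first token of that adapter's invocation
                         sequence
    """
    out = hidden.copy()
    # -1 marks positions that do not select any adapter.
    adapter_idx = np.full(input_ids.shape, -1, dtype=np.int64)
    for idx, (control_id, first_id) in enumerate(adapter_first_token.items()):
        mask = input_ids == control_id      # detection still uses input_ids
        adapter_idx[mask] = idx             # the switch decision is recorded
        out[mask] = embed_table[first_id]   # decoder sees a content token
    return out, adapter_idx


embed_table = np.arange(24, dtype=np.float32).reshape(6, 4)
input_ids = np.array([1, 5, 2])             # token 5 plays the control token
hidden = embed_table[input_ids]
new_hidden, adapter_idx = exchange_control_tokens(
    input_ids, hidden, embed_table, adapter_first_token={5: 3}
)
```

Because input_ids is left untouched, detection is unaffected; only the hidden states passed to the decoder layers change.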
Solution steps:
Solution B (Fallback): Minimize control_dims from 32 to 1
The current control_dims=32 allocates 32 extra dimensions per head, but only num_hiding_groups of them (typically 1) are ever non-zero. Set control_dims = num_hiding_groups.
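One way to check the claim before shrinking control_dims is to count how many control dimensions are ever non-zero across a batch of captured vectors. The sketch below is a hypothetical diagnostic; it assumes the control slice has already been separated from the content dimensions and is passed in with control_dims as its last axis.

```python
import numpy as np


def active_control_dims(control_part):
    """Count control dimensions that are non-zero anywhere in the input.

    control_part: array of shape (..., control_dims) holding only the
                  control slice of Q, K, or V vectors.
    """
    flat = control_part.reshape(-1, control_part.shape[-1])
    used = np.any(flat != 0, axis=0)  # one flag per control dimension
    return int(np.count_nonzero(used))


# Synthetic example: 32 control dims, only the first one ever set.
sample = np.zeros((2, 3, 32), dtype=np.float32)
sample[..., 0] = 1.0
n_active = active_control_dims(sample)
```

If the count on real activations never exceeds num_hiding_groups, the remaining dimensions are pure padding.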
See the full design document in the issue discussion for detailed risk analysis and implementation plan.