Replace control-token KV hiding with token-exchange (#8) #34

Open
AlonMalach wants to merge 16 commits into main from feature/reduce-control-token-kv-overhead

@AlonMalach
Collaborator

@AlonMalach AlonMalach commented May 15, 2026

Summary

Closes #8. Replaces the legacy KV-hiding scheme (which padded every Q/K/V in every decoder layer with control_dims=32 extra dimensions and branded the control token's K with finfo.min) with token exchange: at the switch layer, each control token's id is rewritten in input_ids to a substitute id whose embedding the model was already trained to handle. The decoder embeds the rewritten ids once and runs natively — no Q/K/V padding, no KV cache expansion, no hidden-count position correction.

Outcome on Granite 4.1-8b: KV cache head_dim returns to native 128 (from expanded 160, padded to 192). FlashAttention runs on native vectors. ~20% less KV memory, ~33% less attention compute per layer. No retraining required.
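
The two percentages can be reproduced from the head_dim figures alone. The sketch below is one reading of the numbers, assuming KV memory scales linearly with the stored width (160) and attention compute with the padded width (192); the function name is illustrative and not part of the repository.

```python
# Head dimensions quoted in the summary for Granite 4.1-8b.
NATIVE_HEAD_DIM = 128
EXPANDED_HEAD_DIM = 160  # native 128 + control_dims=32
PADDED_HEAD_DIM = 192    # width the expanded vectors were padded to


def relative_saving(old_dim: int, new_dim: int) -> float:
    """Fraction saved when a cost that scales linearly with head_dim shrinks."""
    return 1.0 - new_dim / old_dim


# Assumption: KV memory tracks the expanded width, compute tracks the padded width.
kv_memory_saving = relative_saving(EXPANDED_HEAD_DIM, NATIVE_HEAD_DIM)
attention_compute_saving = relative_saving(PADDED_HEAD_DIM, NATIVE_HEAD_DIM)

print(f"KV memory saving:         {kv_memory_saving:.0%}")
print(f"attention compute saving: {attention_compute_saving:.0%}")
```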

How it works

Two paired layers:

  1. Chat template (compose-time, tokenizer_setup.py): a skip-once Jinja flag (ns.skip_next_start_of_role) suppresses the role-marker token that would normally follow each control token, and alora_pass2 drops the first character of in-message ALoRA invocation text (BPE-equivalent to dropping one tokenized piece). The rendered sequence is one token shorter than before.
  2. Switch (runtime, hf/switch/single.py + vllm/switch/single.py): when emitting adapter_indices, also rewrite each control token's id via a control_to_substitute_lut buffer. Returns (adapter_indices, modified_input_ids). The decoder unpacks and embeds modified_input_ids.
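
The runtime half can be sketched in plain Python. This is a toy model of the switch, not the code in hf/switch/single.py: the real switch works on tensors and a control_to_substitute_lut buffer, while this sketch uses lists and a dict. The function name, the control-token id 200001, and the convention that adapter index 0 means "no adapter" are assumptions made for illustration. The sketch only marks the control-token position itself and does not model how the adapter index applies to later tokens.

```python
from typing import Dict, List, Tuple


def switch_forward(
    input_ids: List[int],
    adapter_token_ids: List[int],
    control_to_substitute: Dict[int, int],
) -> Tuple[List[int], List[int]]:
    """Toy switch returning (adapter_indices, modified_input_ids).

    Assumption: adapter k (1-based) is selected by adapter_token_ids[k - 1],
    and index 0 means no adapter.
    """
    control_index = {tok: k + 1 for k, tok in enumerate(adapter_token_ids)}
    # Half 1: adapter selection by matching input ids against control ids.
    adapter_indices = [control_index.get(tok, 0) for tok in input_ids]
    # Half 2: token rewrite; non-control ids pass through unchanged.
    modified_input_ids = [control_to_substitute.get(tok, tok) for tok in input_ids]
    return adapter_indices, modified_input_ids


# 200001 is a made-up control-token id; 100264 is <|start_of_role|> per this PR.
indices, rewritten = switch_forward(
    input_ids=[200001, 7, 8, 9],
    adapter_token_ids=[200001],
    control_to_substitute={200001: 100264},
)
print(indices)    # [1, 0, 0, 0]
print(rewritten)  # [100264, 7, 8, 9]
```

The decoder would then embed `rewritten` once, with no knowledge that a control token was ever present.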

Substitute derivation:

  • ALoRA → first token of alora_invocation_tokens (read from adapter's adapter_config.json).
  • LoRA / built-in → whatever the tokenizer's chat template emits at the start of a no-adapter user turn. Derived at compose time by _probe_lora_substitute_token_id(tokenizer) — render a minimal probe chat, tokenize, take input_ids[0]. On Granite 4.x this resolves to <|start_of_role|> (id 100264).
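
A minimal sketch of the two derivation rules, assuming each adapter is described by a small dict. The dict shape, the `kind` field, and the function name are hypothetical; only the alora_invocation_tokens key and the id 100264 come from this PR.

```python
from typing import Any, Dict, List


def derive_substitute_ids(
    adapters: List[Dict[str, Any]],
    lora_substitute_id: int,
) -> List[int]:
    """Return one substitute id per adapter, in adapter order.

    Hypothetical input shape:
    {"kind": "alora" | "lora" | "builtin", "adapter_config": {...}}
    """
    substitutes: List[int] = []
    for adapter in adapters:
        if adapter["kind"] == "alora":
            invocation = adapter["adapter_config"].get("alora_invocation_tokens", [])
            if not invocation:
                raise ValueError("ALoRA adapter has no alora_invocation_tokens")
            # ALoRA: first token of the invocation sequence.
            substitutes.append(int(invocation[0]))
        else:
            # LoRA / built-in: the id obtained from the compose-time probe.
            substitutes.append(lora_substitute_id)
    return substitutes


adapters = [
    {"kind": "lora", "adapter_config": {}},
    # 27 and 999 are made-up ids standing in for an in-message invocation.
    {"kind": "alora", "adapter_config": {"alora_invocation_tokens": [27, 999]}},
]
substitute_ids = derive_substitute_ids(adapters, lora_substitute_id=100264)
print(substitute_ids)  # [100264, 27]
```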

What's removed

The legacy KV-hiding path is gone, not gated. _expand_with_control_dimensions deleted from both backends. control_dims, hiding_groups, hiding_policy, adapter_third_party, expanded_head_dim, num_hiding_groups, get_hiding_group_token_ids, get_third_party_adapter_mask, get_adapter_hiding_policy_matrix deleted from config. adapter_substitute_token_ids is required when num_adapters > 0. ~3000 LoC deleted net.

Backward compatibility

Breaking. Checkpoints composed under the legacy scheme cannot load — from_pretrained raises ValueError from the new adapter_substitute_token_ids is required validator. Users must recompose with the current compose_granite_switch.py against the same base + adapter sources (minutes per checkpoint).

AlonMalach added 14 commits May 15, 2026 10:35
Adapter control tokens were padded into every Q/K/V in every decoder layer
via `control_dims=32` and masked with `finfo.min`. This bloated the KV
cache head_dim by 25-50% and forced FlashAttention onto padded 160/192-wide
vectors when only `num_hiding_groups` (typically 1) of the 32 extra dims
were ever non-zero.

Switch to token-exchange: after the switch reads `input_ids` and detects
which adapter to activate, replace each control token's embedding with a
substitute real-token embedding before the decoder runs. Control tokens
become ordinary content tokens in the residual stream and `control_dims`
collapses to 0, dropping the expansion entirely.

Substitute ids are computed at compose time:
  - ALoRA adapters -> first token of alora_invocation_tokens
  - LoRA/builtin adapters -> tokenizer.bos_token_id

New config field `adapter_substitute_token_ids` is persisted in config.json
and drives a `use_token_exchange` property read by both backends. Default
`control_dims` flips from 32 to 0.

The legacy KV-hiding path is preserved as an opt-in escape hatch via the
new `--legacy-hiding` composer flag; any adapter that regresses under
token-exchange can be composed with the old semantics unchanged.

Key validation:
  - Reject num_adapters>0 with neither hiding nor substitute ids (would
    leak raw control-token embeddings into attention).
  - Reject duplicate adapter_token_ids (LUT collision).
  - Reject negative / wrong-length substitute ids.
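
The three checks above can be sketched as a standalone function. This reflects the rules as stated in this commit message, when hiding was still available; the function name and signature are illustrative, and the real validator lives in the config class.

```python
from typing import List, Optional


def validate_adapter_tokens(
    num_adapters: int,
    adapter_token_ids: List[int],
    adapter_substitute_token_ids: Optional[List[int]],
    control_dims: int = 0,
) -> None:
    """Raise ValueError if the adapter token configuration is inconsistent."""
    if num_adapters == 0:
        return
    hiding_enabled = control_dims > 0
    if not hiding_enabled and not adapter_substitute_token_ids:
        # Neither mechanism active: raw control-token embeddings would leak.
        raise ValueError("num_adapters > 0 needs hiding or substitute ids")
    if len(set(adapter_token_ids)) != len(adapter_token_ids):
        # Two adapters sharing a control id would collide in the LUT.
        raise ValueError("duplicate adapter_token_ids")
    if adapter_substitute_token_ids is not None:
        if len(adapter_substitute_token_ids) != num_adapters:
            raise ValueError("substitute ids must have one entry per adapter")
        if any(sub < 0 for sub in adapter_substitute_token_ids):
            raise ValueError("substitute ids must be non-negative")


# A consistent token-exchange configuration passes silently.
validate_adapter_tokens(2, [10, 11], [5, 6])
```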

Position correction via `hidden_count` is skipped in token-exchange mode
since control tokens are real positions.

Design: docs/KV_CACHE_OVERHEAD_REMOVAL.md
Tracks issue #8.

Measures four metrics per position, teacher-forced, across a list of
prompts to compare legacy KV-hiding (control_dims>0) vs. token-exchange
(control_dims=0 + substitute ids):

  1. KL(p_old || p_new) per position  (log_softmax based to avoid underflow)
  2. Top-1 agreement                    (tagged "(noisy)" on wide nuclei)
  3. Nucleus (top-p=0.9) Jaccard        (sampling-set overlap)
  4. Mass under old nucleus by new      (the actionable gate)
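
The four metrics for a single position can be sketched in plain Python as follows. This is an illustrative re-implementation over lists of logits, not the harness code; the function names are assumptions, and the real test operates on model outputs.

```python
import math
from typing import Dict, List, Set


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax (subtract the max before exponentiating)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def nucleus(logp: List[float], top_p: float = 0.9) -> Set[int]:
    """Smallest set of highest-probability ids whose mass reaches top_p."""
    order = sorted(range(len(logp)), key=lambda i: logp[i], reverse=True)
    chosen: Set[int] = set()
    mass = 0.0
    for i in order:
        chosen.add(i)
        mass += math.exp(logp[i])
        if mass >= top_p:
            break
    return chosen


def position_metrics(
    logits_old: List[float], logits_new: List[float], top_p: float = 0.9
) -> Dict[str, float]:
    logp_old, logp_new = log_softmax(logits_old), log_softmax(logits_new)
    nuc_old, nuc_new = nucleus(logp_old, top_p), nucleus(logp_new, top_p)
    top1_old = max(range(len(logp_old)), key=lambda i: logp_old[i])
    top1_new = max(range(len(logp_new)), key=lambda i: logp_new[i])
    return {
        # 1. KL(p_old || p_new), computed from log-probabilities.
        "kl": sum(math.exp(lo) * (lo - ln) for lo, ln in zip(logp_old, logp_new)),
        # 2. Top-1 agreement for this position.
        "top1_agree": float(top1_old == top1_new),
        # 3. Jaccard overlap of the two nucleus sets.
        "nucleus_jaccard": len(nuc_old & nuc_new) / len(nuc_old | nuc_new),
        # 4. Mass the new distribution places on the old nucleus.
        "mass_under_old_nucleus": sum(math.exp(logp_new[i]) for i in nuc_old),
    }


same = position_metrics([2.0, 1.0, 0.1], [2.0, 1.0, 0.1])
different = position_metrics([2.0, 1.0, 0.1], [1.0, 2.0, 0.1])
```

With identical logits the KL term is exactly zero and top-1 agreement is 1.0, which is the condition the pre-control bucket is expected to meet.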

Results are partitioned into overall / pre-control / adapter-active buckets.
The pre-control bucket must be bit-for-bit identical (KL max == 0, top-1
agree == 1.0); any drift there signals a bug in the embedding-swap gating
rather than a mode trade-off.

Two modes:
  - Synthetic (CPU): builds two HF models with identical base weights, one
    in legacy hiding and one in token-exchange. Useful as a plumbing check
    and regression guard. Runs as a standard pytest.
  - Real-model (GPU, opt-in): set GRANITE_SWITCH_PARITY_MODELS='{"old":...,
    "new":...}'. Loads composed checkpoints and uses demo-script prompts
    (14 adapter-specific prompts from run_adapter_generation_direct.py)
    rendered through the composed tokenizer's chat template. Thresholds:
    top-1 >= 0.95, mean KL <= 0.02, mean mass-under-old-nucleus >= 0.88.

Also exposes build_demo_prompts() in the demo script. Short-circuits
_generate via a module-level capture flag so prompt text is collected
without touching model.generate. Used by the parity eval to pull
realistic adapter inputs without duplicating the demo prompt data.

CLI usage:
    python -m tests.integration.test_token_exchange_parity \
        --old /path/to/legacy_build --new /path/to/te_build --json-out report.json

Granite tokenizers alias bos_token_id to <|end_of_text|> (EOS), so the
previous BOS-based substitute for LoRA/builtin adapters would have
injected an end-of-text signal mid-prompt — a stop-generation marker
in a place the model was not trained to see it.

The chat template places the LoRA control token at sequence start,
immediately followed by <|start_of_role|>user<|end_of_role|>... — so
<|start_of_role|> is the deterministic "token that naturally follows"
for every LoRA adapter, and its embedding is well-trained in the base
model (part of the base vocab on Granite 4.0 and 4.1).

Parallels the ALoRA path (substitute = first invocation token).
Both paths now pick "the token that comes right after the control
token in the rendered chat prompt" — single principle, two sources.

Validated:
  - tokenizer.convert_tokens_to_ids('<|start_of_role|>') == 100264
    on ibm-granite/granite-4.1-3b and granite-4.0-micro (part of
    base vocab, not composer-added).
  - bos_token_id == eos_token_id == 100257 ('<|end_of_text|>') on
    all three Granite tokenizers tested — confirming the prior
    default was semantically wrong.
…#8)

The runtime embedding swap replaces each adapter control token's
embedding with a substitute token's embedding — for LoRA adapters this
is <|start_of_role|>, for assistant-boundary ALoRA adapters it's also
<|start_of_role|> (the first token of their invocation sequence). But
the chat template then emits a *real* <|start_of_role|> at the next
position: the user or assistant role marker that naturally follows the
control-token prefix.

Result before this change: two consecutive positions carrying
<|start_of_role|>'s embedding. The model has never seen that pattern
during pretraining — a duplicate-embedding OOD right at the start of
the decoder's residual stream.

Fix: add a skip-once Jinja flag (ns.skip_next_start_of_role). Arm it
when lora_prefix_insertion emits the LoRA control token, or when
alora_insertion fires the fallback path for assistant-boundary ALoRAs.
Wrap every <|start_of_role|> emission in the base Granite template
with a skip-once block that consumes the flag. The flag is single-shot
— only the very first <|start_of_role|> after the control token is
suppressed; all later role markers emit normally.
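
The single-shot behaviour of the flag can be simulated in plain Python. This is a toy renderer, not the Jinja template: it emits a list of token strings, omits everything the real Granite template adds beyond role markers, and uses a local variable in place of ns.skip_next_start_of_role.

```python
from typing import Dict, List, Optional

START_OF_ROLE = "<|start_of_role|>"
END_OF_ROLE = "<|end_of_role|>"


def render(
    messages: List[Dict[str, str]], lora_control_token: Optional[str] = None
) -> List[str]:
    """Toy chat renderer with a skip-once flag for the first role marker."""
    out: List[str] = []
    skip_next_start_of_role = False  # stands in for ns.skip_next_start_of_role

    if lora_control_token is not None:
        out.append(lora_control_token)
        skip_next_start_of_role = True  # arm the flag

    for message in messages:
        if skip_next_start_of_role:
            skip_next_start_of_role = False  # consume: suppress exactly one marker
        else:
            out.append(START_OF_ROLE)
        out.extend([message["role"], END_OF_ROLE, message["content"]])
    return out


chat = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]
plain = render(chat)
with_adapter = render(chat, lora_control_token="<|citations|>")
```

In this toy model the adapter render has the same length as the plain render and differs only at index 0, where the control token takes the place of the first role marker.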

Not addressed in this PR: ALoRAs whose invocation text is in-message
content text (<requirements>, <guardian>, <certainty>). The first
token of these invocations is the single character '<', and the rest
of the invocation text cannot be cleanly sliced at the template level
without changing what 'requirements>' (or 'guardian>', etc.)
tokenizes to. Those adapters retain the duplicate-embedding pattern
until a runtime-level drop lands in a follow-up.

Backward compatibility: old checkpoints (composed before this change)
load unchanged — the template edit only runs at compose time and
affects only newly-composed models. Their rendered output for LoRA
and assistant-boundary ALoRA is now one token shorter than before
(the suppressed <|start_of_role|>). Update the three
test_chat_template tests whose assertions encoded the old contract.

Closes the remaining duplicate-embedding OOD at the swap site. Complements
the skip-once <|start_of_role|> edit from the previous commit by extending
the same principle to ALoRA adapters whose invocation text lives inside a
user message (<requirements>, <certainty>, <guardian>, <context>, etc.).

Change: in alora_pass2, after inserting the control token before the
invocation text, also drop the first CHARACTER of the invocation text.
Example: "Please <|req_check|><requirements>" becomes
"Please <|req_check|>requirements>". At runtime the embedding-swap
replaces the control token's embedding with the first invocation token's
embedding — the embedding of '<'. The decoder then sees
[<|req_check|>→e_<, requirements, >] — exactly what "<requirements>"
tokenizes to in isolation, with no duplicate.
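
The string-level edit can be sketched as a small helper. This is a Python illustration of the splice, not the Jinja code in alora_pass2; the function name is hypothetical, and matching the first occurrence of the invocation text is an assumption of this sketch.

```python
def splice_control_token(message: str, invocation_text: str, control_token: str) -> str:
    """Insert control_token before invocation_text and drop the invocation's first character."""
    position = message.find(invocation_text)  # assumption: first occurrence
    if position < 0:
        # No match: the real template takes a separate fallback path here.
        return message
    tail = message[position + len(invocation_text):]
    return message[:position] + control_token + invocation_text[1:] + tail


rendered = splice_control_token(
    "Please <requirements>", "<requirements>", "<|req_check|>"
)
print(rendered)  # Please <|req_check|>requirements>
```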

Why this is safe on the Granite tokenizer: verified empirically via a
new property test (test_first_char_drop_equals_first_token_drop). For
every ALoRA invocation in the standard library, tokenizing the full
invocation and dropping the first token ID yields the same sequence as
tokenizing the string with its first character removed. BPE's greedy-
merge would break this property if the second-byte merges depended on
the leading '<'; it doesn't, because '<' tokenizes as its own single-
character token in every case.

The accompanying test test_first_token_is_single_character asserts the
complementary invariant: the first token of each invocation decodes to
exactly one character. If a future invocation text starts with a
multi-character first token, that test catches it — the Jinja edit
(invocation_text[1:] drops one character) would otherwise silently
produce a wrong-length drop.
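
The two invariants can be illustrated with a toy tokenizer. The regex splitter below is a stand-in that happens to treat '<' and '>' as single-character pieces; it is not the Granite BPE tokenizer, and the real tests run against the actual tokenizer and adapter library.

```python
import re
from typing import List

# Toy splitter: '<' and '>' are always their own piece (assumption of this sketch).
_PIECE = re.compile(r"[<>]|[^<>\s]+|\s+")


def toy_tokenize(text: str) -> List[str]:
    return _PIECE.findall(text)


def first_token_is_single_character(invocation: str) -> bool:
    """Invariant 1: the first tokenized piece is exactly one character long."""
    return len(toy_tokenize(invocation)[0]) == 1


def first_char_drop_equals_first_token_drop(invocation: str) -> bool:
    """Invariant 2: dropping one character equals dropping one tokenized piece."""
    return toy_tokenize(invocation)[1:] == toy_tokenize(invocation[1:])


INVOCATIONS = ["<requirements>", "<certainty>", "<guardian>", "<context>"]
results = {
    text: (
        first_token_is_single_character(text),
        first_char_drop_equals_first_token_drop(text),
    )
    for text in INVOCATIONS
}
```

An invocation that starts with a multi-character piece, such as "requirements>" in this toy model, fails both checks, which is the situation the guard test is meant to catch.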

Combined with the previous commit (skip-once <|start_of_role|>), the
duplicate-embedding pattern is now eliminated across all adapter types
in the Granite adapter library: LoRA, assistant-boundary ALoRA, and
in-message ALoRA.

Previously the composer hardcoded _LORA_SUBSTITUTE_TOKEN =
"<|start_of_role|>". That's the right answer for Granite 4.x but it
ties the default-path composer to a Granite-specific token name. Any
base model with a different chat template (different role marker,
different turn-open convention) would silently get the wrong
substitute — a token the base model knows, but not the one sitting
at position 1 of its rendered prompt.

Replace the hardcode with a compose-time probe: render a minimal
no-adapter user turn through tokenizer.apply_chat_template, tokenize,
and read input_ids[0]. That's by construction whatever the template
emits at the start of a normal turn, which is exactly what sits at
position 1 after a LoRA-prepended control token. The substitute and
the template's own behavior are now derived from the same source of
truth.

Verified: the probe returns 100264 (<|start_of_role|>) on
granite-4.1-3b, granite-4.0-micro, and granite-switch-4.1-3b-preview
— identical to the previous hardcoded value. Behavior on Granite is
unchanged; the door is open for non-Granite base models.

Error paths give actionable messages:
  - Tokenizer has no chat_template → suggest --legacy-hiding
  - Template render fails → report the Jinja error, suggest
    --legacy-hiding
  - First token is <unk> → report that the template emits something
    outside the vocab
  - Probe returns an empty id list → same
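
A minimal sketch of the probe and its four error paths, assuming a tokenizer object that exposes chat_template, unk_token_id, and apply_chat_template in the style of Hugging Face tokenizers. The function name, error wording, and StubTokenizer are illustrative; the stub exists only to make the sketch runnable, and all of its ids except 100264 are made up.

```python
from typing import Any, Dict, List


def probe_lora_substitute_token_id(tokenizer: Any) -> int:
    """Render a minimal no-adapter user turn and return its first token id."""
    if not getattr(tokenizer, "chat_template", None):
        raise ValueError("tokenizer has no chat_template; cannot probe substitute id")
    probe_chat: List[Dict[str, str]] = [{"role": "user", "content": "hi"}]
    try:
        input_ids = tokenizer.apply_chat_template(probe_chat, tokenize=True)
    except Exception as exc:
        raise ValueError(f"chat template failed to render: {exc}") from exc
    if not input_ids:
        raise ValueError("probe returned an empty id list")
    first_id = int(input_ids[0])
    if first_id == getattr(tokenizer, "unk_token_id", None):
        raise ValueError("first probe token is <unk>; template emits out-of-vocab text")
    return first_id


class StubTokenizer:
    """Hypothetical stand-in for a real tokenizer."""

    chat_template = "stub"
    unk_token_id = 0

    def apply_chat_template(self, conversation, tokenize=True, **kwargs):
        # Pretend the template opens every prompt with <|start_of_role|>.
        return [100264, 11, 12, 13]


probed_id = probe_lora_substitute_token_id(StubTokenizer())
print(probed_id)  # 100264
```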

Tests:
  - tests/composer/test_lora_substitute_probe.py (7 cases):
    * Real tokenizer round-trip on granite-4.1-3b and 4.0-micro
    * Synthetic tokenizer with a non-Granite template returns
      the custom template's first-token id
    * All four error paths raise ValueError with matching messages

Refactor: the runtime substitution LUT and the embedding-swap step
move out of each backend's decoder and into SingleSwitch (HF + vLLM).
The switch now performs both halves of token-exchange:

  1. Adapter selection — read input_ids, detect control tokens via
     input_ids == adapter_token_ids, emit per-token adapter_indices
     (unchanged).
  2. Token rewrite — replace each control token's id in input_ids
     with its substitute id (from a switch-owned LUT). New.

The switch's forward signature changes from
  -> adapter_indices
to
  -> (adapter_indices, modified_input_ids)

The decoder consumes both: adapter_indices feeds the LoRA layers as
before, modified_input_ids feeds embed_tokens / get_input_embeddings
exactly once. There is no longer a decoder-side LUT, no scatter, no
clone-guard, no use_token_exchange branch in the embedding path.

Why this is cleaner:

- Single source of truth for the substitution. The switch already
  knows which positions are control tokens; rewriting input_ids at
  those positions is a natural extension of "decide which adapter is
  active." The decoder is genuinely token-exchange-agnostic — it
  just embeds whatever input_ids it receives.

- HF and vLLM converge to the same control flow. Both backends now
  call switch(...), unpack two outputs, embed once. Previously each
  backend had a near-identical but layout-specific (B,S,H vs T,H)
  embedding-swap block + clone-guard that needed to be maintained
  separately.

- Smaller diff for any future change to the substitution logic.
  Whether to ship a different substitute strategy (e.g. learned
  embedding, per-adapter rules) becomes a one-place change in the
  switch instead of a two-place change across both decoders.

HF model forward also reorders slightly: switch runs before
embed_tokens, so we embed exactly once on modified_input_ids.
create_causal_mask now receives a stub embedding tensor of the right
shape and dtype (it only uses the tensor for batch/query/dtype
inference per the upstream docstring), since the real embedding
hasn't been computed yet.

Tests:
- tests/hf/test_single_switch.py: _run helper unpacks the new tuple
  return; TestBatchProcessing similarly.
- tests/hf/test_token_exchange.py: LUT presence assertion now reads
  model.model.switch.control_to_substitute_lut instead of
  model.model.control_to_substitute_lut.

No behavior change, verified by 756 passing tests (= same count as
before the refactor; +0 -0 after fixture updates).

Token-exchange has been the default for several commits. This change
deletes the dead-but-still-callable KV-hiding code path entirely:

Config:
- Drop control_dims, hiding_groups, hiding_policy, adapter_third_party
  parameters and the corresponding state.
- Drop expanded_head_dim, num_hiding_groups, hiding_group_names,
  use_token_exchange properties (token-exchange is now always on when
  num_adapters > 0).
- Drop get_hiding_group_token_ids, get_third_party_adapter_mask,
  get_adapter_hiding_policy_matrix methods.
- adapter_substitute_token_ids becomes required when num_adapters > 0.
- Net: -150 LoC (config.py 345 → 195).

Models:
- HF and vLLM both drop token_to_group_mask / adapter_hiding_matrix
  buffers, hidden_count / adjusted_position_ids logic, and the
  token_group_membership / query_group_suppression plumbing through
  decoder layers.
- The HF decoder layer's forward signature drops two kwargs.

Attention layers (hf/core/lora.py, vllm/core/decoder.py):
- Drop expand_control_dims / control_dims / expanded_head_dim fields.
- Delete _expand_with_control_dimensions method entirely (~85 LoC each).
- Delete the expansion / trim-back branches in forward.
- vllm/core/decoder.py: attn_head_dim is unconditionally head_dim.

Switches:
- Drop config.expanded_head_dim references; head_dim is
  config.projection_head_dim everywhere.

vllm/__init__.py:
- ModelArchConfigConvertor.get_head_size() returns
  config.projection_head_dim (no expansion logic).

Composer:
- compose_granite_switch.py: drop --control-dims and --legacy-hiding
  CLI flags. Delete the legacy-hiding branch in build(); always
  token-exchange.
- compose_utils.py: drop hiding_groups / hiding_policy /
  adapter_third_party kwargs.
- model_card.py: drop control_dims / legacy_hiding /
  use_token_exchange reporting fields.

Tests deleted entirely:
- tests/unit/test_hiding_constant.py
- tests/hf/test_kv_hiding_gap_equivalence.py
- tests/vllm/test_kv_hiding_gap_equivalence.py
- tests/vllm/_kv_hiding_gap_tests.py
- tests/hf/test_position_zero_nan.py
- tests/vllm/_position_zero_nan_tests.py
- tests/integration/test_token_exchange_parity.py (compared old vs new
  modes; with no old mode, nothing to compare).
- tests/composer/test_built_in_adapters.py (entire file tested removed
  Mode A / Mode B distinction).

Tests rewritten:
- tests/conftest.py, tests/unit/test_config{,_edge_cases}.py,
  tests/unit/test_token_exchange.py, tests/hf/test_model_forward.py,
  tests/hf/test_token_exchange.py, tests/hf/test_qk_norm.py,
  tests/shared/granite4_equivalence.py, tests/shared/generation_models.py:
  fixtures and assertions updated for the simpler config surface.

Net diff: ~3000 LoC deleted, ~200 LoC added (test rewrites). 643
tests pass on CPU after the refactor (was 756; the difference is
parameterized hiding-equivalence tests + the parity harness, all
deleted).

Breaking change for any externally-composed checkpoint that was using
control_dims > 0: those checkpoints are unloadable under this version.
The token-exchange path has been the documented default since #8 and
the only path that received the chat-template drops, so any in-flight
build should already be on it.
)

The new switch buffer was failing compose-pipeline validation because
buffer_keywords still listed the deleted legacy buffer names instead of
the new one. Replace token_to_group_mask / adapter_hiding_matrix /
all_hiding_group_token_ids with control_to_substitute_lut in arch.py
and in the two test_granite4_mini parameter-allowlist assertions.

The report described safety margins for the finfo.min K-side hiding
constant. Hiding is gone, so the section is meaningless. Drop the
module, the call site in compose_report.py, and the package re-exports.

Replace control_dims / hiding_groups / hiding_policy / adapter_third_party
references with adapter_substitute_token_ids in test fixtures, and drop
TestControlTokenKVInvisibility (tested the deleted hiding mechanism).

This is a partial sweep — vLLM workers, hf/test_single_switch_e2e.py,
and shared/granite4_equivalence.py still need follow-up edits.

- tests/hf/test_single_switch_e2e.py: drop CONTROL_DIMS_MODES axis; one
  parametrization on attention_multiplier only. Fixture returns a 3-tuple.
- tests/vllm/_generation_equivalence_worker.py and _tp_integration_worker.py:
  remove control_dims/hiding_groups/hiding_policy/adapter_third_party from
  composer calls; pass adapter_substitute_token_ids instead.
- tests/vllm/_single_switch_worker.py: mock_config uses projection_head_dim.
- tests/vllm/test_generation_equivalence.py: docstring updated.
- tests/shared/granite4_equivalence.py: rationale comments updated for
  token-exchange (no behavior change).
- src/granite_switch/composer/compose_utils.py: docstring/comment cleanup.

The vLLM decoder is wrapped in @support_torch_compile; Dynamo cannot
trace data-dependent branching like ``if is_control.any()``. The gate
broke engine init on GPU runs.

Replace it with an unconditional torch.where in both backends — keeps
HF and vLLM symmetric, costs one indexed gather + one elementwise
select per forward, and makes the switch compile-safe.

Three fixes uncovered by GPU run:

1. tests/vllm/_single_switch_worker.py: switch.forward now returns
   (adapter_indices, modified_input_ids); unpack and return only the
   indices. Worker was calling .cpu() on a tuple → every parametrized
   test in tests/vllm/test_single_switch.py failed at the same point.

2. tests/vllm/test_model_forward.py: drop the TestControlTokenKVInvisibility
   class stub. The inner class was deleted with the legacy hiding tests
   in 0ddaf0e, but the parametrized runner still referenced it.

3. tests/vllm/test_position_zero_nan.py: deleted. The inner
   _position_zero_nan_tests.py was removed (only existed for the legacy
   hiding path); the runner became orphaned and pytest reported "file or
   directory not found" on every parametrized variant.

The flash_api.cpp:697 "no kernel image" failures in test_model_forward
are pre-existing GPU/FlashAttention environment issues, not branch bugs.
@AlonMalach
Collaborator Author

A guide to the changes in this PR:
KV_CACHE_OVERHEAD_REMOVAL.html

@lastras
Contributor

lastras commented May 16, 2026

The comment

LoRA / built-in → whatever the tokenizer's chat template emits at the start of a no-adapter user turn

appears to be incorrect even though the subsequent algorithm is "more correct" (looking at position 0 of a tokenized sequence). Note nonetheless that in principle a chat template can render a different input_ids[0] depending on the data passed to it, so technically speaking, an actual analysis of the chat template needs to be done to figure out whether a single token id is used in all circumstances (which I believe is the case in Granite 4.1). So the code is probably doing the right thing, and this will work in most circumstances, but if the first rendered token id depends on the data then this wouldn't work. At the least, we should have a comment that this is specific to Granite 4.1.

Per review feedback: tighten the probe's docstring to state the
assumption it relies on — that the chat template emits a constant
input_ids[0] regardless of message content, system-prompt presence,
or generation-prompt flag — and call out that this is verified
empirically for Granite 4.x (every realistic render shape produces
<|start_of_role|>). Note what would need to change if a future base
model's template breaks the assumption.

No behavior change.
@AlonMalach
Collaborator Author

AlonMalach commented May 16, 2026

You're right on both counts — apologies for the loose wording. Two clarifications:

Position, not turn. The LoRA control token sits at sequence position 0 of the rendered prompt, not at "the start of a user turn." Verified empirically against the composed model:

no-adapter:    [<|start_of_role|>, user, <|end_of_role|>, hi, ...]
LoRA active:   [<|citations|>,     user, <|end_of_role|>, hi, ...]   ← <|citations|> at index 0

The skip-once flag drops the <|start_of_role|> that would otherwise follow at index 1, and the switch's runtime LUT places <|start_of_role|>'s embedding at index 0, so the post-swap sequence has the same length and same content as the no-adapter render — by construction.

Data-dependence concern. Real and worth documenting. The probe assumes the chat template emits a constant input_ids[0] regardless of inputs. I checked this empirically for Granite 4.1 across several render shapes (plain user, with system prompt, with generation prompt, multi-turn, etc.) — all return <|start_of_role|> (100264). A future base model whose template branches at position 0 (e.g. emits BOS only when no system message is present) would need a multi-shape probe.
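
A multi-shape probe of the kind mentioned above could look roughly like this. It is a sketch, not code from this PR: the function name and both stub tokenizers are hypothetical, the stubs stand in for a real tokenizer's apply_chat_template, and the id 1 used for a BOS-like token is made up.

```python
from typing import Any, Dict, List


def probe_constant_first_id(tokenizer: Any, shapes: List[Dict[str, Any]]) -> int:
    """Render several chat shapes and require a single, constant input_ids[0]."""
    first_ids = set()
    for shape in shapes:
        ids = tokenizer.apply_chat_template(
            shape["messages"],
            tokenize=True,
            add_generation_prompt=shape.get("add_generation_prompt", False),
        )
        if not ids:
            raise ValueError("a probe shape rendered to an empty id list")
        first_ids.add(int(ids[0]))
    if len(first_ids) != 1:
        raise ValueError(f"template is data-dependent at position 0: {sorted(first_ids)}")
    return first_ids.pop()


USER = {"role": "user", "content": "hi"}
SYSTEM = {"role": "system", "content": "be brief"}
SHAPES = [
    {"messages": [USER]},
    {"messages": [SYSTEM, USER]},
    {"messages": [USER], "add_generation_prompt": True},
    {"messages": [USER, {"role": "assistant", "content": "hello"}, USER]},
]


class ConstantStub:
    """Hypothetical template that always opens with <|start_of_role|>."""

    def apply_chat_template(self, conversation, tokenize=True, **kwargs):
        return [100264] + [7] * len(conversation)


class BranchingStub:
    """Hypothetical template that emits a BOS-like id only without a system message."""

    def apply_chat_template(self, conversation, tokenize=True, **kwargs):
        has_system = conversation[0]["role"] == "system"
        return ([] if has_system else [1]) + [100264] + [7] * len(conversation)


constant_id = probe_constant_first_id(ConstantStub(), SHAPES)
print(constant_id)  # 100264
```

The branching stub fails the check with a ValueError, which is the signal that a single substitute id would not be safe for that template.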

@lastras
Contributor

lastras commented May 16, 2026

I think static analysis of the chat template would reveal which situation we are in. I agree that for Granite 4.1 we have it down, and a comment would warn anyone doing further work in this area ahead of time.

@lastras
Contributor

lastras commented May 16, 2026

I am also aware that we have (currently) two different kinds of aLoRA adapters: ones that activate with the rightmost appearance of the assistant role token (I think all the RAG ones are like this), and ones that activate with a custom token sequence in a user turn preceding the assistant reply (I think Guardian and Core are like this). Please confirm you dealt with these two cases explicitly.

@AlonMalach
Collaborator Author

Yes — both ALoRA styles are handled, with the right path selected per adapter at compose time.

Style A — assistant-role-boundary (RAG: answerability, query_rewrite, query_clarification). The chat template scans user messages for the adapter's invocation text; when nothing matches, it falls back to placing the control token right before the generation prompt, in the slot the assistant's <|start_of_role|> would normally occupy, and suppresses that next <|start_of_role|> so we don't end up with two identical embeddings back-to-back after the swap. The substitute id is alora_invocation_tokens[0] (= <|start_of_role|>) read from the adapter's adapter_config.json.

Style B — in-user-message (Guardian/Core: uncertainty, requirement-check, factuality-{detection,correction}, guardian-core, policy-guardrails). The template finds the invocation text (<certainty>, <requirements>, <guardian>, <context>, …) in the last user message, splices the control token in immediately before it, and drops the first character of the invocation text — the BPE leading < is its own token, and dropping it on the string side is equivalent to dropping exactly one tokenized piece, which keeps the rendered sequence the same length. The substitute id is alora_invocation_tokens[0] (= <) read from the adapter's adapter_config.json.

Comment thread tests/hf/test_token_exchange.py Outdated
assert adapter_indices[0, 4].item() == 1


class TestPositionCorrectionSkipped:
Collaborator


Position correction should be fully removed, including comments and tests

Collaborator Author


Done in 284fad0. Audited the codebase end to end: grep -rni 'position correction|position_correction|hidden_count|position[ _]offset|position[ _]shift' now returns zero hits. Removed:

  • tests/hf/test_token_exchange.py — entire TestPositionCorrectionSkipped class (lines 112–126) gone.
  • src/granite_switch/hf/modeling_granite_switch.py:295 — dropped the (BEFORE RoPE for position correction) parenthetical from the switch-call comment; kept the rest of the comment that still describes what the switch returns.
  • src/granite_switch/hf/modeling_granite_switch.py:323–325 — removed the "control tokens count as real positions, no hidden_count subtraction" comment block above position_embeddings = None. The line stands on its own without the rationale for absent code.

No behavioural change.

@antonpibm
Collaborator

@AlonMalach What was the result of running the tests on vllm and composition?

@AlonMalach
Collaborator Author

AlonMalach commented May 18, 2026

@antonpibm Composer ✅, vLLM ❌ but for an environment reason, not a code one. Details:

Composer (make test-composer) — all green.

vLLM (make test-vllm) — every test fails on this branch and on main with the same crash on the first forward pass:

INFO [cuda.py:334] Using FLASH_ATTN attention backend out of potential backends:
                  ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN', 'FLEX_ATTENTION'].
INFO [flash_attn.py:596] Using FlashAttention version 2
CUDA error (.../vllm-flash-attn-src/hopper/flash_api.cpp:697):
  no kernel image is available for execution on the device

The vllm-flash-attn build in this environment ships only Hopper (SM 9.0) kernels, but the GPU here is a different arch — the worker dies on its first forward, and every subsequent test in the same worker cascades into BrokenPipeError. Running make test-all:

| Branch | Result |
| --- | --- |
| main | 377 failed, 50 passed |
| feature/reduce-control-token-kv-overhead | 374 failed, 47 passed, 1 error |

The 50-vs-47 / 377-vs-374 deltas are entirely test-collection differences (this branch deletes ~734 lines of KV-hiding tests that no longer have a code path to test — test_position_zero_nan.py, test_kv_hiding_gap_equivalence.py, TestControlTokenKVInvisibility). They are not tests that pass on main and fail on this branch. No vLLM forward pass actually executes on either branch, so the run can't establish or rule out behavioural regressions either way.

Tracked in #36 with full details.

The position-correction code path was removed when token-exchange
became the default. Drop the dead test class and the two stale
parenthetical comments that still mentioned it. No behavioural change.

Development

Successfully merging this pull request may close these issues.

Reduce KV Cache Overhead from Control Dimension Expansion

3 participants