Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
bd0f893
Replace control-token KV hiding with token-exchange by default (#8)
AlonMalach May 12, 2026
d51c967
Add token-exchange parity eval harness (#8)
AlonMalach May 13, 2026
6f7ee71
Use <|start_of_role|> instead of BOS for LoRA/builtin substitute (#8)
AlonMalach May 13, 2026
52b357e
Template: drop the first role marker after each adapter control token…
AlonMalach May 13, 2026
96fe271
Template: drop first char of in-message ALoRA invocation (#8)
AlonMalach May 13, 2026
3384540
Derive LoRA substitute from the tokenizer's chat template (#8)
AlonMalach May 13, 2026
49a13a8
Move token-exchange rewrite into the switch (#8)
AlonMalach May 14, 2026
60920bb
Remove the legacy KV-hiding code path (#8)
AlonMalach May 14, 2026
2176e99
Fix base-weight validator rejecting control_to_substitute_lut buffer …
AlonMalach May 14, 2026
b41b2fb
Remove dead hiding-constant report from compose output (#8)
AlonMalach May 14, 2026
2b1fd8e
Update tests for removed legacy hiding fields (#8) — partial
AlonMalach May 14, 2026
116d21b
Strip remaining legacy hiding-field references from tests/docs (#8)
AlonMalach May 14, 2026
6d2d463
Drop tensor.any() gate from switch token-exchange rewrite (#8)
AlonMalach May 14, 2026
3f4d842
Fix vLLM test runners after switch tuple-return + dead-class purge (#8)
AlonMalach May 14, 2026
935259c
Document the LoRA substitute probe's data-independence assumption (#8)
AlonMalach May 16, 2026
284fad0
Remove residual position-correction references (#8)
AlonMalach May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,9 @@ TEST_STATUS_REPORT.md
htmlcov/

# pyenv version file (local dev preference)
.python-version
.python-version

# Local design/planning doc (keep on disk, do not version)
docs/KV_CACHE_OVERHEAD_REMOVAL.md
docs/KV_CACHE_OVERHEAD_REMOVAL.html
docs/KV_CACHE_OVERHEAD_REMOVAL*.html
4 changes: 1 addition & 3 deletions src/granite_switch/composer/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ class ArchDescriptor:
default_factory=lambda: [
"adapter_token_ids",
"adapter_scalings",
"token_to_group_mask",
"adapter_hiding_matrix",
"all_hiding_group_token_ids",
"control_to_substitute_lut",
]
)

Expand Down
116 changes: 85 additions & 31 deletions src/granite_switch/composer/compose_granite_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from granite_switch.composer.tokenizer_setup import (
add_control_tokens,
configure_chat_template,
get_alora_first_invocation_token_id,
)
from granite_switch.composer.reporting import generate_compose_report, write_build_doc

Expand All @@ -76,6 +77,70 @@ def _load_tokenizer(model_name_or_path):
return AutoTokenizer.from_pretrained(model_name_or_path)


def _probe_lora_substitute_token_id(tokenizer) -> int:
"""Ask the tokenizer which token naturally appears at sequence position 0
of a rendered no-adapter chat.

The LoRA prefix insertion places the adapter control token at sequence
position 0 of the rendered output. Whatever token would otherwise have
occupied position 0 (in a no-adapter render) is the right substitute
whose embedding should land at the swap site so the post-swap sequence
is indistinguishable from a no-adapter render.

Assumption (Granite 4.x): the chat template emits a constant
``input_ids[0]`` regardless of message content, system prompt presence,
or generation-prompt flag. Empirically verified — every realistic render
of the Granite 4.1 template yields ``<|start_of_role|>`` (id 100264) at
position 0. The probe renders a single minimal chat to read that
constant out of the template.

A future model whose chat template branches on inputs at position 0
(e.g. emits BOS only when no system message is present) would break
this assumption: the probe would still return *some* valid id, but it
might not match position 0 in another render mode at runtime, leaving
the LoRA control token swapped to an embedding the model doesn't
expect at that position. ``tests/composer/test_lora_substitute_probe.py``
pins the Granite 4.x behavior; if you port to another base model with
a more dynamic template, extend the probe to render multiple shapes
and verify they all agree.

By deriving the substitute from the tokenizer's own chat template at
compose time we avoid hard-coding a Granite-specific token string.

Raises ``ValueError`` if the template is missing, fails to render, or
emits an unknown token.
"""
if tokenizer.chat_template is None:
raise ValueError(
"Tokenizer has no chat_template; cannot probe the LoRA "
"substitute token."
)
try:
probe_text = tokenizer.apply_chat_template(
[{"role": "user", "content": "probe"}],
tokenize=False,
add_generation_prompt=False,
)
except Exception as e:
raise ValueError(
"Failed to render a probe chat via tokenizer.apply_chat_template "
f"while detecting the LoRA substitute token: {e!r}."
) from e
ids = tokenizer(probe_text, add_special_tokens=False).input_ids
if not ids:
raise ValueError(
"Probe chat tokenized to an empty id list; cannot determine the "
"LoRA substitute token."
)
sub_id = ids[0]
if sub_id == tokenizer.unk_token_id:
raise ValueError(
"First token of the rendered probe chat is <unk>; the template "
"appears to emit content outside the tokenizer's vocabulary."
)
return sub_id


def _get_directory_size(directory):
"""Return ``(total_size in GBs, file_count)`` for *directory*."""
if Path(directory).exists():
Expand Down Expand Up @@ -449,12 +514,6 @@ def _compose_argparser():
default=None,
help="Dimension of Q/K/V vectors in switch attention",
)
parser.add_argument(
"--control-dims",
type=int,
default=None,
help="Extra dims for K/V to mask control tokens in decoder layers",
)
parser.add_argument(
"--built-in-adapters",
type=str,
Expand Down Expand Up @@ -678,9 +737,8 @@ def build():
has_external = len(external_discovered) > 0
has_built_in = len(built_in_discovered) > 0

# Mode detection:
# Mode A (native): built-in only → no hiding, control_dims=0
# Mode B (third-party): externals present → full hiding
# Mode detection (informational only — token-exchange handles both
# native and third-party adapter builds uniformly).
if has_built_in and not has_external:
build_mode = "native"
elif has_external:
Expand All @@ -692,7 +750,6 @@ def build():
# Extract fields from 4-tuples (path, name, tech, source)
adapter_paths = [t[0] for t in all_discovered if t[0] is not None]
adapter_names = [t[1] for t in all_discovered]
external_names = [t[1] for t in external_discovered]
built_in_names = [name for name in (args.built_in_adapters or [])]

print(f"\nBuild mode: {build_mode}")
Expand Down Expand Up @@ -747,33 +804,30 @@ def build():
optional_kwargs = {}
if args.switch_head_dim is not None:
optional_kwargs["switch_head_dim"] = args.switch_head_dim
if args.control_dims is not None:
optional_kwargs["control_dims"] = args.control_dims

# Per-mode hiding configuration
if build_mode == "native":
# Mode A (native): no hiding, control_dims=0 (unless overridden)
hiding_groups = None
hiding_policy = None
adapter_third_party = None
if "control_dims" not in optional_kwargs:
optional_kwargs["control_dims"] = 0
else:
# Mode B (third-party): full hiding for external adapters
hiding_groups = {"all_controls": list(adapter_names)}
hiding_policy = {name: ["all_controls"] for name in adapter_names}
hiding_policy["base"] = ["all_controls"]
# Only external adapters are third-party
adapter_third_party = list(external_names)

# Token-exchange substitute choice (must mirror the token that appears
# right after the control token in the rendered chat prompt, so the
# swap keeps the residual stream in-distribution):
# - ALoRA: first token of the adapter's alora_invocation_tokens.
# - LoRA/builtin: whatever the tokenizer's chat template emits at
# the very start of a no-adapter user turn. For Granite 4.x that's
# <|start_of_role|>; the probe derives this from the template at
# compose time so other base models work by construction.
lora_sub_id = _probe_lora_substitute_token_id(tokenizer)
adapter_substitute_token_ids = []
for adapter_path, _name, technology, _source in all_discovered:
if technology == "alora":
sub_id = get_alora_first_invocation_token_id(adapter_path)
else:
sub_id = lora_sub_id
adapter_substitute_token_ids.append(sub_id)

model = GraniteSwitchComposer.from_base_and_adapters(
base_model_name_or_path=base_model_local_path,
adapter_paths=adapter_paths,
adapter_token_ids=adapter_token_ids,
adapter_substitute_token_ids=adapter_substitute_token_ids,
adapter_names=adapter_names,
hiding_groups=hiding_groups,
hiding_policy=hiding_policy,
adapter_third_party=adapter_third_party,
built_in_adapter_names=built_in_names,
built_in_lora_rank=args.lora_rank,
built_in_lora_alpha=args.lora_alpha if args.lora_alpha is not None else float(args.lora_rank),
Expand Down
14 changes: 6 additions & 8 deletions src/granite_switch/composer/compose_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def from_base_and_adapters(
base_model_name_or_path: str,
adapter_paths: Optional[List[str]] = None,
adapter_token_ids: Optional[List[int]] = None,
adapter_substitute_token_ids: Optional[List[int]] = None,
adapter_names: Optional[List[str]] = None,
built_in_adapter_names: Optional[List[str]] = None,
built_in_lora_rank: int = 8,
Expand All @@ -48,6 +49,9 @@ def from_base_and_adapters(
empty for zero-adapter skinning (base model only).
adapter_token_ids: Token IDs for adapter control. Required when
``adapter_paths`` is non-empty.
adapter_substitute_token_ids: Token IDs whose embeddings replace
control-token embeddings at the switch. Required when
``adapter_paths`` is non-empty; one per adapter.
adapter_names: Display names for each adapter (external + built-in).
When ``None``, derived from the directory structure.
built_in_adapter_names: Names for built-in (empty LoRA) adapter slots.
Expand Down Expand Up @@ -112,10 +116,6 @@ def from_base_and_adapters(
source_analysis = {}

# --- Step 4: Build switch config from arch descriptor ---
hiding_groups = kwargs.pop("hiding_groups", None)
hiding_policy = kwargs.pop("hiding_policy", None)
adapter_third_party = kwargs.pop("adapter_third_party", None)

# Copy config fields driven by architecture descriptor
config_kwargs: Dict = {}

Expand Down Expand Up @@ -151,17 +151,15 @@ def from_base_and_adapters(
{
"num_adapters": num_total,
"adapter_token_ids": adapter_token_ids,
"adapter_substitute_token_ids": adapter_substitute_token_ids,
"adapter_names": adapter_names,
"hiding_groups": hiding_groups,
"hiding_policy": hiding_policy,
"adapter_third_party": adapter_third_party,
"max_lora_rank": lora_rank,
"adapter_ranks": adapter_ranks,
"lora_target_modules": lora_target_modules,
}
)

# Merge caller-provided overrides (switch_head_dim, control_dims, etc.)
# Merge caller-provided overrides (switch_head_dim, etc.)
config_kwargs.update(kwargs)

switch_config = GraniteSwitchConfig(**config_kwargs)
Expand Down
3 changes: 0 additions & 3 deletions src/granite_switch/composer/reporting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
from .population_table import generate_adapter_population_table, print_adapter_population_table
from .compose_report import generate_compose_report
from .adapter_analysis import print_source_adapter_analysis
from .hiding_constant_report import compute_hiding_constant_safety, print_hiding_constant_safety
from .model_card import render_model_card, write_model_card, write_build_doc

__all__ = [
'generate_adapter_population_table',
'print_adapter_population_table',
'generate_compose_report',
'print_source_adapter_analysis',
'compute_hiding_constant_safety',
'print_hiding_constant_safety',
'render_model_card',
'write_model_card',
'write_build_doc',
Expand Down
5 changes: 0 additions & 5 deletions src/granite_switch/composer/reporting/compose_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,6 @@ def _print_summary(
if len(base_source_not_connected) > 10:
print(f" ... and {len(base_source_not_connected) - 10} more")

# Hiding constant safety margin
if model is not None:
from .hiding_constant_report import print_hiding_constant_safety
print_hiding_constant_safety(model.dtype)

print(f"\nDetailed report saved to: {report_path}")
print("="*80)

Expand Down
58 changes: 0 additions & 58 deletions src/granite_switch/composer/reporting/hiding_constant_report.py

This file was deleted.

4 changes: 3 additions & 1 deletion src/granite_switch/composer/reporting/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,9 @@ def _short_source(source):
"lora_rank": getattr(args, "lora_rank", None) if built_in else None,
"lora_alpha": getattr(args, "lora_alpha", None) if built_in else None,
"switch_head_dim": getattr(args, "switch_head_dim", None),
"control_dims": getattr(args, "control_dims", None),
"adapter_substitute_token_ids": getattr(
model.config, "adapter_substitute_token_ids", None
),
"target_model": getattr(args, "target_model", None),
}
# Parameter counts: base is captured during transfer (see
Expand Down
Loading