Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions contracts/apr-distill-teacher-vocab-alignment-v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
metadata:
version: 1.0.0
created: '2026-05-22'
author: PAIML Engineering
description: |
When the distillation teacher and student share a tokenizer base
but the teacher's vocabulary is a strict superset (e.g. Qwen2.5-Coder-7B
vocab=152064 vs Qwen2.5-Coder-0.5B vocab=151936 — the 7B adds 128
code-specific tokens), the teacher's logits must be truncated to the
student's vocab before KD loss is computed. The truncation point IS
the new logit support; softmax acts on the truncated logits to produce
a renormalized teacher distribution P_t' over the shared vocab.

Surfaced post-PMAT-701: with the memory blockers cleared, dispatching
the MODEL-1 7B teacher (paiml/qwen2.5-coder-7b-apache-q4k-v1) against
the 0.5B student hung in the first KD step on the dimension mismatch
that `kd_logit_gradient`'s assert_eq! would have rejected. This
contract codifies the alignment.
references:
- 'PMAT-703 (this contract): teacher vocab > student vocab alignment'
- 'PMAT-701 cuda-q4k-frozen-teacher-v1.yaml — prerequisite (memory fixes)'
- 'Hinton et al. 2015 §2 — KD loss derivation; assumes same support'
- 'Qwen2.5 model card — vocab=152064 for 7B Coder, vocab=151936 for 0.5B/1.5B Coder'
- 'crates/aprender-train-distill/src/kd_step.rs:103-107 (assert_eq! that the alignment must satisfy)'
- 'crates/apr-cli/src/commands/distill_q4k_teacher.rs (fix site)'

kind: KernelContract
name: apr-distill-teacher-vocab-alignment
version: 1.0.0
scope: aprender-train-distill teacher-provider boundary — vocab alignment between teacher and student

equations:
vocab_alignment_dispatch:
formula: |
effective_teacher_vocab(native_t, target_s) =
target_s if Some(target_s) AND target_s <= native_t
native_t if None
Err(VocabAlignment::TargetTooLarge) if Some(target_s) AND target_s > native_t

teacher.vocab_size() returns effective_teacher_vocab
teacher.logits_for_batch returns vectors of length effective_teacher_vocab
(truncating native_t entries to the first effective_teacher_vocab if needed)
domain: native_t in Z+; target_s in Option<Z+>
codomain: effective_teacher_vocab in Z+ or VocabAlignmentError
invariants:
- 'Truncation happens at the teacher-provider boundary, before any softmax/KL'
- 'Softmax post-truncation renormalizes the teacher distribution over the shared support'
- 'The first effective_teacher_vocab tokens of teacher and student MUST refer to the same tokens (shared tokenizer prefix)'
- 'For Qwen2.5: 7B vocab[0..151936] == 0.5B/1.5B vocab[0..151936] (verified against tokenizer.ggml.tokens)'
notes: |
The "shared tokenizer prefix" invariant is the operator's responsibility
to verify when building a distillation pair. The contract assumes the
invariant holds; runtime checks would require comparing tokenizer.json
between teacher and student, which is out of scope here.
preconditions:
- native_t >= 1
- target_s >= 1 (if Some)
lean_theorem: Theorems.Vocab_Alignment_Dispatch

kd_loss_invariance_under_truncation:
formula: |
KL(softmax(l_t[0..N] / T) || softmax(l_s[0..N] / T))
where N = effective_teacher_vocab = student_vocab_size, l_t is native teacher logits
length native_t, l_s is student logits length N.
domain: l_t in R^{native_t}; l_s in R^N; T > 0; N <= native_t
codomain: KL value in R >= 0
invariants:
- 'Truncating before softmax (NOT after) is mandatory — post-softmax truncation loses normalization'
- 'The dropped tail (l_t[N..native_t]) contributes mass only to tokens the student cannot produce, so dropping them aligns the supports correctly'
- 'No renormalization scaling is applied beyond what softmax provides intrinsically'
preconditions:
- student and teacher tokenizers share a prefix of length N
lean_theorem: Theorems.Kd_Loss_Invariance_Under_Truncation

cli_dispatch_passes_student_vocab:
formula: |
run_cuda_backend reads student vocab_size from student.apr metadata.
For Q4K teachers: RealizarQ4KTeacher::from_apr_path_with_target_vocab(teacher_path, Some(student_vocab))
For F32 teachers: CudaTrainerTeacher path remains unaffected (no truncation supported yet)
domain: student_vocab in Z+
codomain: a constructed TeacherLogitsProvider with effective vocab matching student
invariants:
- 'The student vocab passed in MUST match the actual student logit output (verified by kd_step.rs:218 check)'
- 'When teacher native_vocab > student_vocab, truncation is automatic; no operator action needed'
- 'When teacher native_vocab == student_vocab, truncation is a no-op'
- 'When teacher native_vocab < student_vocab, construction fails (student cannot have MORE vocab than teacher in this design)'
preconditions:
- student.apr has stamped vocab_size metadata
lean_theorem: Theorems.Cli_Dispatch_Passes_Student_Vocab

proof_obligations:
- type: invariant
property: vocab_size() reports the effective (post-truncation) vocab
formal: |
For every RealizarQ4KTeacher t constructed with target_vocab = Some(N) where N <= native_t:
t.vocab_size() == N AND every Vec<f32> returned from t.logits_for_batch has length N.
- type: invariant
property: kd_step.rs assert_eq! always passes for vocab-aligned teacher+student
formal: |
For every (teacher = RealizarQ4KTeacher with target N, student emitting N logits):
kd_step.rs:103-107 assert_eq!(student_logits.len(), teacher_logits.len()) holds.
- type: bound
property: truncation never increases memory or compute beyond native
formal: |
For every native_t and N <= native_t:
truncated logit vector has length N <= native_t (memory bound).
Truncation is O(N) per logit vector (compute bound).
- type: classification
property: vocab-mismatch path is detectable from metadata alone
formal: |
For every (teacher.apr, student.apr) pair: comparing their metadata.vocab_size
fields suffices to decide whether truncation is needed. No tokenizer decode
or token-by-token comparison required at runtime.

falsification_tests:
- id: FT-VOCAB-ALIGN-001
rule: 7B teacher → 0.5B student dispatches without dimension mismatch
prediction: |
`apr distill <paiml/qwen2.5-coder-7b-apache-q4k-v1.apr> --student <0.5B.apr>
--epochs 1 --backend cuda` completes the first KD step without an assert
in kd_step.rs:103-107 firing. Pre-fix the same dispatch hangs or asserts
on dimension mismatch (teacher returns 152064 logits, student returns 151936).
if_fails: |
The CLI dispatch didn't pass the student vocab to RealizarQ4KTeacher.
Re-check run_cuda_backend's construction site. Make sure the student
metadata vocab_size is read BEFORE the teacher provider is built.
evidence: |
evidence/distill-7b-teacher-vocab-aligned/launch-after-fix.log shows
per-step loss output (not the silent hang seen pre-fix at Bug-B verification).
- id: FT-VOCAB-ALIGN-002
rule: vocab_size() reports the effective vocab, not native
prediction: |
For a 7B teacher constructed with target_vocab=151936:
teacher.vocab_size() == 151936 (not 152064).
if_fails: |
The vocab_size() impl returns native_t instead of effective. Likely a
`self.cuda_model.config().vocab_size` direct return that needs to be
replaced with `self.effective_vocab_size`.
evidence: cargo test -p apr-cli --features cuda,training,inference --lib distill_q4k_teacher::tests::vocab_size_reports_effective
- id: FT-VOCAB-ALIGN-003
rule: target_vocab > native_vocab returns Err at construction
prediction: |
Constructing a RealizarQ4KTeacher with target_vocab=200000 against the 7B
(native=152064) returns Err(EntrenarError::Internal) with a message that
names both the requested and native vocab sizes.
if_fails: |
The check is missing; truncation logic with a too-large target produces
unspecified behavior or silent index-out-of-bounds at runtime.
evidence: cargo test -p apr-cli --features cuda,training,inference --lib distill_q4k_teacher::tests::oversized_target_errors
- id: FT-VOCAB-ALIGN-004
rule: logits_for_batch returns vectors of effective_vocab_size length
prediction: |
A constructed RealizarQ4KTeacher with target=151936 against the 7B,
given a 4-element input_ids batch: returns Vec<Vec<f32>> where every inner
vec has len()==151936.
if_fails: |
The truncation step was skipped. Check that the `.take(effective_vocab_size).collect()`
or `logits.truncate(effective_vocab_size)` is in the forward path.
evidence: cargo test -p apr-cli --features cuda,training,inference --lib distill_q4k_teacher::tests::logits_truncated_to_target

kani_harnesses:
- id: KANI-VOCAB-ALIGN-001
obligation: vocab_alignment_dispatch — effective_vocab is min(native, target_if_set)
property: |
For all (native in [1..200000], target in [None, Some(s) for s in [1..200000]]):
effective_vocab(native, target) is either Ok(N) with N == native (if target None) or N == target (if target<=native), or Err otherwise.
bound: 4
strategy: bounded_int
solver: cadical
harness: verify_effective_vocab_computation
- id: KANI-VOCAB-ALIGN-002
obligation: truncation preserves rank ordering of top-K logits
property: |
For every (l_t in R^native, N <= native): the top-K argmax over l_t[0..N]
is the intersection of the top-K argmax over l_t and the set [0..N).
(i.e., truncation doesn't perturb the relative ordering of in-range logits.)
bound: 8
strategy: bounded_int
solver: cadical
harness: verify_truncation_preserves_top_k

qa_gate:
id: F-VOCAB-ALIGN-001
name: apr-distill-teacher-vocab-alignment-v1 Contract
description: Teacher logits must be truncated to the student's vocab size when teacher_vocab > student_vocab, before any KD/softmax computation.
checks:
- validation
- falsification
33 changes: 32 additions & 1 deletion crates/apr-cli/src/commands/distill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,11 +783,42 @@ fn run_cuda_backend(
.is_some_and(|t| matches!(t.dtype, TensorDType::Q4K | TensorDType::Q6K))
});
let teacher_provider: Box<dyn TeacherLogitsProvider> = if teacher_uses_quantized_weights {
// PMAT-703: if teacher's native vocab is larger than the student's,
// we must truncate teacher logits to the student vocab before KD loss
// (see contracts/apr-distill-teacher-vocab-alignment-v1.yaml). The
// canonical example is Qwen2.5-Coder-7B (vocab=152064) → 0.5B
// (vocab=151936). Pass the student vocab as target so the teacher
// truncates at the boundary.
let student_vocab = student_meta.vocab_size.ok_or_else(|| {
CliError::ValidationFailed(
"student .apr metadata missing vocab_size — required for PMAT-703 vocab alignment"
.into(),
)
})?;
let teacher_native_vocab = teacher_meta.vocab_size.ok_or_else(|| {
CliError::ValidationFailed(
"teacher .apr metadata missing vocab_size — required for PMAT-703 vocab alignment"
.into(),
)
})?;
let target_vocab = if teacher_native_vocab > student_vocab {
eprintln!(
"[PMAT-703] vocab alignment: teacher native={teacher_native_vocab}, student={student_vocab} → truncating teacher logits to student vocab"
);
Some(student_vocab)
} else if teacher_native_vocab < student_vocab {
return Err(CliError::ValidationFailed(format!(
"vocab alignment: teacher vocab ({teacher_native_vocab}) < student vocab ({student_vocab}); \
student would need to predict tokens the teacher has no embeddings for"
)));
} else {
None
};
eprintln!(
"[PMAT-701] Q4K/Q6K teacher detected → RealizarQ4KTeacher (Q4K-native forward, no F32 dequant)"
);
Box::new(
RealizarQ4KTeacher::from_apr_path(teacher_path)
RealizarQ4KTeacher::from_apr_path_with_target_vocab(teacher_path, target_vocab)
.map_err(|e| CliError::ValidationFailed(format!("RealizarQ4KTeacher load: {e}")))?,
)
} else {
Expand Down
Loading
Loading