Skip to content

Commit 92f7be4

Browse files
Fix retained x_attn capture without enabling Qwen QKV Freivalds
- Populate RetainedLayerState.x_attn_i8 with the captured GPU x_attn (it was always None — the prover used the captured x_attn for QKV accumulators, but the verifier fell back to bridge-derived x_attn, causing a mismatch)
- Keep Qwen supports_qkv_freivalds=false until GPU validation confirms the fix makes Freivalds pass
- Add a skipped field to V4VerifyReport for explicit unsupported reporting
- Expose skipped in the Python verify dict and the Display impl
- Update roadmap: QKV Freivalds gated by profile
1 parent f397405 commit 92f7be4

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

crates/verilm-core/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ impl VerificationProfile {
9191
max_validated_context: 1164,
9292
requires_score_anchoring: false,
9393
score_anchor_threshold: None, // anchor gap ~14, too loose for strong tier
94-
supports_qkv_freivalds: false, // bridge replay can't match GPU INT8 GEMM for Qwen
94+
supports_qkv_freivalds: false, // pending: prover now populates x_attn_i8 in retained state, needs GPU validation
9595
}
9696
}
9797

crates/verilm-prover/src/lib.rs

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -148,16 +148,20 @@ pub fn build_retained_from_captures(
148148
let a_dim = entry.a_i8.len() / batch_size;
149149
let a = entry.a_i8[b * a_dim..(b + 1) * a_dim].to_vec();
150150

151-
if let Some(ref xa) = entry.x_attn_i8 {
151+
let (retained_xa, retained_scale_xa) = if let Some(ref xa) = entry.x_attn_i8 {
152152
let x_dim = xa.len() / batch_size;
153-
token_x_attn.push(xa[b * x_dim..(b + 1) * x_dim].to_vec());
154-
}
153+
let slice = xa[b * x_dim..(b + 1) * x_dim].to_vec();
154+
token_x_attn.push(slice.clone());
155+
(Some(slice), Some(entry.scale_x_attn[b]))
156+
} else {
157+
(None, None)
158+
};
155159

156160
layers.push(RetainedLayerState {
157161
a,
158162
scale_a: entry.scale_a[b],
159-
x_attn_i8: None,
160-
scale_x_attn: None,
163+
x_attn_i8: retained_xa,
164+
scale_x_attn: retained_scale_xa,
161165
});
162166
token_scales.push(CapturedLayerScales {
163167
scale_x_attn: entry.scale_x_attn[b],

0 commit comments

Comments (0)