Skip to content

Commit d950627

Browse files
authored
fix: ucache: normalize reuse error (#1313)
1 parent 7c880f8 commit d950627

1 file changed

Lines changed: 43 additions & 13 deletions

File tree

src/ucache.hpp

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct UCacheConfig {
1919
bool adaptive_threshold = true;
2020
float early_step_multiplier = 0.5f;
2121
float late_step_multiplier = 1.5f;
22+
float relative_norm_gain = 1.6f;
2223
bool reset_error_on_compute = true;
2324
};
2425

@@ -45,14 +46,16 @@ struct UCacheState {
4546
bool has_output_prev_norm = false;
4647
bool has_relative_transformation_rate = false;
4748
float relative_transformation_rate = 0.0f;
48-
float cumulative_change_rate = 0.0f;
4949
float last_input_change = 0.0f;
5050
bool has_last_input_change = false;
51+
float output_change_ema = 0.0f;
52+
bool has_output_change_ema = false;
5153
int total_steps_skipped = 0;
5254
int current_step_index = -1;
5355
int steps_computed_since_active = 0;
56+
int expected_total_steps = 0;
57+
int consecutive_skipped_steps = 0;
5458
float accumulated_error = 0.0f;
55-
float reference_output_norm = 0.0f;
5659

5760
struct BlockMetrics {
5861
float sum_transformation_rate = 0.0f;
@@ -106,14 +109,16 @@ struct UCacheState {
106109
has_output_prev_norm = false;
107110
has_relative_transformation_rate = false;
108111
relative_transformation_rate = 0.0f;
109-
cumulative_change_rate = 0.0f;
110112
last_input_change = 0.0f;
111113
has_last_input_change = false;
114+
output_change_ema = 0.0f;
115+
has_output_change_ema = false;
112116
total_steps_skipped = 0;
113117
current_step_index = -1;
114118
steps_computed_since_active = 0;
119+
expected_total_steps = 0;
120+
consecutive_skipped_steps = 0;
115121
accumulated_error = 0.0f;
116-
reference_output_norm = 0.0f;
117122
block_metrics.reset();
118123
total_active_steps = 0;
119124
}
@@ -133,7 +138,8 @@ struct UCacheState {
133138
if (!initialized || sigmas.size() < 2) {
134139
return;
135140
}
136-
size_t n_steps = sigmas.size() - 1;
141+
size_t n_steps = sigmas.size() - 1;
142+
expected_total_steps = static_cast<int>(n_steps);
137143

138144
size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
139145
size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
@@ -207,11 +213,15 @@ struct UCacheState {
207213
}
208214

209215
int effective_total = estimated_total_steps;
216+
if (effective_total <= 0) {
217+
effective_total = expected_total_steps;
218+
}
210219
if (effective_total <= 0) {
211220
effective_total = std::max(20, steps_computed_since_active * 2);
212221
}
213222

214223
float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
224+
progress = std::max(0.0f, std::min(1.0f, progress));
215225

216226
float multiplier = 1.0f;
217227
if (progress < 0.2f) {
@@ -309,17 +319,31 @@ struct UCacheState {
309319

310320
if (has_output_prev_norm && has_relative_transformation_rate &&
311321
last_input_change > 0.0f && output_prev_norm > 0.0f) {
312-
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
313-
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
322+
float approx_output_change = relative_transformation_rate * last_input_change;
323+
float approx_output_change_rate;
324+
if (config.use_relative_threshold) {
325+
float base_scale = std::max(output_prev_norm, 1e-6f);
326+
float dyn_scale = has_output_change_ema
327+
? std::max(output_change_ema * std::max(1.0f, config.relative_norm_gain), 1e-6f)
328+
: base_scale;
329+
float scale = std::sqrt(base_scale * dyn_scale);
330+
approx_output_change_rate = approx_output_change / scale;
331+
} else {
332+
approx_output_change_rate = approx_output_change;
333+
}
334+
// Increase estimated error with skip horizon to avoid long extrapolation streaks
335+
approx_output_change_rate *= (1.0f + 0.50f * consecutive_skipped_steps);
336+
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
314337

315338
float effective_threshold = get_adaptive_threshold();
316-
if (config.use_relative_threshold && reference_output_norm > 0.0f) {
317-
effective_threshold = effective_threshold * reference_output_norm;
339+
if (!config.use_relative_threshold && output_prev_norm > 0.0f) {
340+
effective_threshold = effective_threshold * output_prev_norm;
318341
}
319342

320343
if (accumulated_error < effective_threshold) {
321344
skip_current_step = true;
322345
total_steps_skipped++;
346+
consecutive_skipped_steps++;
323347
apply_cache(cond, input, output);
324348
return true;
325349
} else if (config.reset_error_on_compute) {
@@ -340,6 +364,8 @@ struct UCacheState {
340364
if (cond != anchor_condition) {
341365
return;
342366
}
367+
steps_computed_since_active++;
368+
consecutive_skipped_steps = 0;
343369

344370
size_t ne = static_cast<size_t>(ggml_nelements(input));
345371
float* in_data = (float*)input->data;
@@ -359,6 +385,14 @@ struct UCacheState {
359385
output_change /= static_cast<float>(ne);
360386
}
361387
}
388+
if (std::isfinite(output_change) && output_change > 0.0f) {
389+
if (!has_output_change_ema) {
390+
output_change_ema = output_change;
391+
has_output_change_ema = true;
392+
} else {
393+
output_change_ema = 0.8f * output_change_ema + 0.2f * output_change;
394+
}
395+
}
362396

363397
prev_output.resize(ne);
364398
for (size_t i = 0; i < ne; ++i) {
@@ -373,10 +407,6 @@ struct UCacheState {
373407
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
374408
has_output_prev_norm = output_prev_norm > 0.0f;
375409

376-
if (reference_output_norm == 0.0f) {
377-
reference_output_norm = output_prev_norm;
378-
}
379-
380410
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
381411
float rate = output_change / last_input_change;
382412
if (std::isfinite(rate)) {

0 commit comments

Comments
 (0)