@@ -19,6 +19,7 @@ struct UCacheConfig {
1919 bool adaptive_threshold = true ;
2020 float early_step_multiplier = 0 .5f ;
2121 float late_step_multiplier = 1 .5f ;
22+ float relative_norm_gain = 1 .6f ;
2223 bool reset_error_on_compute = true ;
2324};
2425
@@ -45,14 +46,16 @@ struct UCacheState {
4546 bool has_output_prev_norm = false ;
4647 bool has_relative_transformation_rate = false ;
4748 float relative_transformation_rate = 0 .0f ;
48- float cumulative_change_rate = 0 .0f ;
4949 float last_input_change = 0 .0f ;
5050 bool has_last_input_change = false ;
51+ float output_change_ema = 0 .0f ;
52+ bool has_output_change_ema = false ;
5153 int total_steps_skipped = 0 ;
5254 int current_step_index = -1 ;
5355 int steps_computed_since_active = 0 ;
56+ int expected_total_steps = 0 ;
57+ int consecutive_skipped_steps = 0 ;
5458 float accumulated_error = 0 .0f ;
55- float reference_output_norm = 0 .0f ;
5659
5760 struct BlockMetrics {
5861 float sum_transformation_rate = 0 .0f ;
@@ -106,14 +109,16 @@ struct UCacheState {
106109 has_output_prev_norm = false ;
107110 has_relative_transformation_rate = false ;
108111 relative_transformation_rate = 0 .0f ;
109- cumulative_change_rate = 0 .0f ;
110112 last_input_change = 0 .0f ;
111113 has_last_input_change = false ;
114+ output_change_ema = 0 .0f ;
115+ has_output_change_ema = false ;
112116 total_steps_skipped = 0 ;
113117 current_step_index = -1 ;
114118 steps_computed_since_active = 0 ;
119+ expected_total_steps = 0 ;
120+ consecutive_skipped_steps = 0 ;
115121 accumulated_error = 0 .0f ;
116- reference_output_norm = 0 .0f ;
117122 block_metrics.reset ();
118123 total_active_steps = 0 ;
119124 }
@@ -133,7 +138,8 @@ struct UCacheState {
133138 if (!initialized || sigmas.size () < 2 ) {
134139 return ;
135140 }
136- size_t n_steps = sigmas.size () - 1 ;
141+ size_t n_steps = sigmas.size () - 1 ;
142+ expected_total_steps = static_cast <int >(n_steps);
137143
138144 size_t start_step = static_cast <size_t >(config.start_percent * n_steps);
139145 size_t end_step = static_cast <size_t >(config.end_percent * n_steps);
@@ -207,11 +213,15 @@ struct UCacheState {
207213 }
208214
209215 int effective_total = estimated_total_steps;
216+ if (effective_total <= 0 ) {
217+ effective_total = expected_total_steps;
218+ }
210219 if (effective_total <= 0 ) {
211220 effective_total = std::max (20 , steps_computed_since_active * 2 );
212221 }
213222
214223 float progress = (effective_total > 0 ) ? (static_cast <float >(steps_computed_since_active) / effective_total) : 0 .0f ;
224+ progress = std::max (0 .0f , std::min (1 .0f , progress));
215225
216226 float multiplier = 1 .0f ;
217227 if (progress < 0 .2f ) {
@@ -309,17 +319,31 @@ struct UCacheState {
309319
310320 if (has_output_prev_norm && has_relative_transformation_rate &&
311321 last_input_change > 0 .0f && output_prev_norm > 0 .0f ) {
312- float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
313- accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
322+ float approx_output_change = relative_transformation_rate * last_input_change;
323+ float approx_output_change_rate;
324+ if (config.use_relative_threshold ) {
325+ float base_scale = std::max (output_prev_norm, 1e-6f );
326+ float dyn_scale = has_output_change_ema
327+ ? std::max (output_change_ema * std::max (1 .0f , config.relative_norm_gain ), 1e-6f )
328+ : base_scale;
329+ float scale = std::sqrt (base_scale * dyn_scale);
330+ approx_output_change_rate = approx_output_change / scale;
331+ } else {
332+ approx_output_change_rate = approx_output_change;
333+ }
334+ // Increase estimated error with skip horizon to avoid long extrapolation streaks
335+ approx_output_change_rate *= (1 .0f + 0 .50f * consecutive_skipped_steps);
336+ accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
314337
315338 float effective_threshold = get_adaptive_threshold ();
316- if (config.use_relative_threshold && reference_output_norm > 0 .0f ) {
317- effective_threshold = effective_threshold * reference_output_norm ;
339+ if (! config.use_relative_threshold && output_prev_norm > 0 .0f ) {
340+ effective_threshold = effective_threshold * output_prev_norm ;
318341 }
319342
320343 if (accumulated_error < effective_threshold) {
321344 skip_current_step = true ;
322345 total_steps_skipped++;
346+ consecutive_skipped_steps++;
323347 apply_cache (cond, input, output);
324348 return true ;
325349 } else if (config.reset_error_on_compute ) {
@@ -340,6 +364,8 @@ struct UCacheState {
340364 if (cond != anchor_condition) {
341365 return ;
342366 }
367+ steps_computed_since_active++;
368+ consecutive_skipped_steps = 0 ;
343369
344370 size_t ne = static_cast <size_t >(ggml_nelements (input));
345371 float * in_data = (float *)input->data ;
@@ -359,6 +385,14 @@ struct UCacheState {
359385 output_change /= static_cast <float >(ne);
360386 }
361387 }
388+ if (std::isfinite (output_change) && output_change > 0 .0f ) {
389+ if (!has_output_change_ema) {
390+ output_change_ema = output_change;
391+ has_output_change_ema = true ;
392+ } else {
393+ output_change_ema = 0 .8f * output_change_ema + 0 .2f * output_change;
394+ }
395+ }
362396
363397 prev_output.resize (ne);
364398 for (size_t i = 0 ; i < ne; ++i) {
@@ -373,10 +407,6 @@ struct UCacheState {
373407 output_prev_norm = (ne > 0 ) ? (mean_abs / static_cast <float >(ne)) : 0 .0f ;
374408 has_output_prev_norm = output_prev_norm > 0 .0f ;
375409
376- if (reference_output_norm == 0 .0f ) {
377- reference_output_norm = output_prev_norm;
378- }
379-
380410 if (has_last_input_change && last_input_change > 0 .0f && output_change > 0 .0f ) {
381411 float rate = output_change / last_input_change;
382412 if (std::isfinite (rate)) {
0 commit comments