Commit dea4980

feat: add spectrum caching method (#1322)
1 parent c8fb3d2 commit dea4980

6 files changed

Lines changed: 325 additions & 7 deletions

docs/caching.md

Lines changed: 23 additions & 0 deletions
````diff
@@ -11,6 +11,7 @@ Caching methods accelerate diffusion inference by reusing intermediate computati
 | `dbcache`    | DiT models  | Block-level L1 residual threshold     |
 | `taylorseer` | DiT models  | Taylor series approximation           |
 | `cache-dit`  | DiT models  | Combined DBCache + TaylorSeer         |
+| `spectrum`   | UNET models | Chebyshev + Taylor output forecasting |
 
 ### UCache (UNET Models)
 
@@ -118,6 +119,26 @@ Mask values: `1` = compute, `0` = can cache.
 --scm-policy dynamic
 ```
 
+### Spectrum (UNET Models)
+
+Spectrum uses Chebyshev polynomial fitting blended with Taylor extrapolation to predict denoised outputs, skipping entire UNet forward passes. Based on the paper [Spectrum: Adaptive Spectral Feature Forecasting for Efficient Diffusion Sampling](https://github.com/tingyu215/Spectrum).
+
+```bash
+sd-cli -m model.safetensors -p "a cat" --cache-mode spectrum
+```
+
+#### Parameters
+
+| Parameter | Description                                                  | Default |
+|-----------|--------------------------------------------------------------|---------|
+| `w`       | Chebyshev vs Taylor blend weight (0 = Taylor, 1 = Chebyshev) | 0.40    |
+| `m`       | Chebyshev polynomial degree                                  | 3       |
+| `lam`     | Ridge regression regularization                              | 1.0     |
+| `window`  | Initial window size (compute every N steps)                  | 2       |
+| `flex`    | Window growth per computed step after warmup                 | 0.50    |
+| `warmup`  | Steps to always compute before caching starts                | 4       |
+| `stop`    | Stop caching at this fraction of total steps                 | 0.9     |
+
 ### Performance Tips
 
 - Start with default thresholds and adjust based on output quality
````
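The spectrum parameters ride the same `--cache-option` flag as the other modes. A hedged invocation example — the values below are just the documented defaults spelled out, not a tuned recipe:

```bash
# Illustrative: these are the defaults. Raise `flex` for more skipping
# (faster, lossier); lower `stop` to protect more of the final steps.
sd-cli -m model.safetensors -p "a cat" \
    --cache-mode spectrum \
    --cache-option "w=0.4,m=3,lam=1.0,window=2,flex=0.5,warmup=4,stop=0.9"
```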

examples/cli/README.md

Lines changed: 5 additions & 3 deletions
```diff
@@ -138,10 +138,12 @@ Generation Options:
   --skip-layers              layers to skip for SLG steps (default: [7,8,9])
   --high-noise-skip-layers   (high noise) layers to skip for SLG steps (default: [7,8,9])
   -r, --ref-image            reference image for Flux Kontext models (can be used multiple times)
-  --cache-mode               caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
+  --cache-mode               caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
+                             'spectrum' (UNET Chebyshev+Taylor forecasting)
   --cache-option             named cache params (key=value format, comma-separated). easycache/ucache:
-                             threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
-                             "threshold=0.25" or "threshold=1.5,reset=0"
+                             threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
+                             spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
+                             "threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
   --cache-preset             cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
   --scm-mask                 SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
   --scm-policy               SCM policy: 'dynamic' (default) or 'static'
```

examples/common/common.hpp

Lines changed: 28 additions & 3 deletions
```diff
@@ -1422,8 +1422,8 @@ struct SDGenerationParams {
         }
         cache_mode = argv_to_utf8(index, argv);
         if (cache_mode != "easycache" && cache_mode != "ucache" &&
-            cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit") {
-            fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str());
+            cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit" && cache_mode != "spectrum") {
+            fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', 'cache-dit', or 'spectrum'\n", cache_mode.c_str());
             return -1;
         }
         return 1;
@@ -1779,7 +1779,23 @@ struct SDGenerationParams {
             } else if (key == "Bn" || key == "bn") {
                 cache_params.Bn_compute_blocks = std::stoi(val);
             } else if (key == "warmup") {
-                cache_params.max_warmup_steps = std::stoi(val);
+                if (cache_mode == "spectrum") {
+                    cache_params.spectrum_warmup_steps = std::stoi(val);
+                } else {
+                    cache_params.max_warmup_steps = std::stoi(val);
+                }
+            } else if (key == "w") {
+                cache_params.spectrum_w = std::stof(val);
+            } else if (key == "m") {
+                cache_params.spectrum_m = std::stoi(val);
+            } else if (key == "lam") {
+                cache_params.spectrum_lam = std::stof(val);
+            } else if (key == "window") {
+                cache_params.spectrum_window_size = std::stoi(val);
+            } else if (key == "flex") {
+                cache_params.spectrum_flex_window = std::stof(val);
+            } else if (key == "stop") {
+                cache_params.spectrum_stop_percent = std::stof(val);
             } else {
                 LOG_ERROR("error: unknown cache parameter '%s'", key.c_str());
                 return false;
@@ -1827,6 +1843,15 @@ struct SDGenerationParams {
             cache_params.Bn_compute_blocks = 0;
             cache_params.residual_diff_threshold = 0.08f;
             cache_params.max_warmup_steps = 8;
+        } else if (cache_mode == "spectrum") {
+            cache_params.mode = SD_CACHE_SPECTRUM;
+            cache_params.spectrum_w = 0.40f;
+            cache_params.spectrum_m = 3;
+            cache_params.spectrum_lam = 1.0f;
+            cache_params.spectrum_window_size = 2;
+            cache_params.spectrum_flex_window = 0.50f;
+            cache_params.spectrum_warmup_steps = 4;
+            cache_params.spectrum_stop_percent = 0.9f;
         }
 
         if (!cache_option.empty()) {
```
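One subtlety in the parsing above: the `warmup` key is shared across modes, and it is routed to `spectrum_warmup_steps` only when the selected cache mode is `spectrum`. Since the per-mode defaults are applied first and `--cache-option` overrides are parsed afterwards, flag order on the command line should not matter. An illustrative pair of invocations:

```bash
# warmup=6 lands in spectrum_warmup_steps because cache_mode == "spectrum"
sd-cli -m model.safetensors -p "a cat" --cache-mode spectrum --cache-option "warmup=6"

# warmup=8 lands in max_warmup_steps for the DiT block-level caches
sd-cli -m model.safetensors -p "a cat" --cache-mode dbcache --cache-option "warmup=8"
```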

include/stable-diffusion.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -251,6 +251,7 @@ enum sd_cache_mode_t {
     SD_CACHE_DBCACHE,
     SD_CACHE_TAYLORSEER,
     SD_CACHE_CACHE_DIT,
+    SD_CACHE_SPECTRUM,
 };
 
 typedef struct {
@@ -271,6 +272,13 @@ typedef struct {
     int taylorseer_skip_interval;
     const char* scm_mask;
     bool scm_policy_dynamic;
+    float spectrum_w;
+    int spectrum_m;
+    float spectrum_lam;
+    int spectrum_window_size;
+    float spectrum_flex_window;
+    int spectrum_warmup_steps;
+    float spectrum_stop_percent;
 } sd_cache_params_t;
 
 typedef struct {
```
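For library consumers, a minimal sketch of filling the new fields. Only `sd_cache_params_t`, its `mode` field, and `SD_CACHE_SPECTRUM` are confirmed by this diff — how the struct is handed to the generation call is not shown in this commit:

```cpp
#include "stable-diffusion.h"

// Builds spectrum cache params mirroring the CLI defaults from common.hpp.
sd_cache_params_t make_spectrum_params() {
    sd_cache_params_t p     = {};    // zero all other fields first
    p.mode                  = SD_CACHE_SPECTRUM;
    p.spectrum_w            = 0.40f; // blend: 0 = pure Taylor, 1 = pure Chebyshev
    p.spectrum_m            = 3;     // Chebyshev degree
    p.spectrum_lam          = 1.0f;  // ridge regularization
    p.spectrum_window_size  = 2;     // compute 1 of every 2 steps initially
    p.spectrum_flex_window  = 0.50f; // window grows by 0.5 per computed step
    p.spectrum_warmup_steps = 4;     // always compute the first 4 steps
    p.spectrum_stop_percent = 0.9f;  // no caching in the last 10% of steps
    return p;
}
```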

src/spectrum.hpp

Lines changed: 195 additions & 0 deletions
```cpp
#ifndef __SPECTRUM_HPP__
#define __SPECTRUM_HPP__

#include <cmath>
#include <cstring>
#include <vector>

#include "ggml_extend.hpp"

struct SpectrumConfig {
    float w = 0.40f;           // Chebyshev vs Taylor blend weight (0 = Taylor, 1 = Chebyshev)
    int m = 3;                 // Chebyshev polynomial degree
    float lam = 1.0f;          // ridge regression regularization
    int window_size = 2;       // initial window size (compute every N steps)
    float flex_window = 0.50f; // window growth per computed step after warmup
    int warmup_steps = 4;      // steps to always compute before caching starts
    float stop_percent = 0.9f; // stop caching at this fraction of total steps
};

struct SpectrumState {
    SpectrumConfig config;
    int cnt = 0;                 // current sampling step index
    int num_cached = 0;          // predictions made since the last real compute
    float curr_ws = 2.0f;        // current (fractional) window size
    int K = 6;                   // history capacity for the Chebyshev fit
    int stop_step = 0;           // step at which caching is disabled
    int total_steps_skipped = 0;

    std::vector<std::vector<float>> H_buf; // history of flattened denoised outputs
    std::vector<float> T_buf;              // matching tau values in [-1, 1]

    void init(const SpectrumConfig& cfg, size_t total_steps) {
        config = cfg;
        cnt = 0;
        num_cached = 0;
        curr_ws = (float)cfg.window_size;
        K = std::max(cfg.m + 1, 6);
        stop_step = (int)(cfg.stop_percent * (float)total_steps);
        total_steps_skipped = 0;
        H_buf.clear();
        T_buf.clear();
    }

    // Map a step index onto [-1, 1] for the Chebyshev basis.
    // Note: normalizes against a hardcoded nominal 50-step schedule.
    float taus(int step_cnt) const {
        return (step_cnt / 50.0f) * 2.0f - 1.0f;
    }

    bool should_predict() {
        if (cnt < config.warmup_steps)
            return false;
        if (stop_step > 0 && cnt >= stop_step)
            return false;
        if ((int)H_buf.size() < 2)
            return false;

        int ws = std::max(1, (int)std::floor(curr_ws));
        return (num_cached + 1) % ws != 0;
    }

    // Record a freshly computed denoised output (assumes contiguous F32 host data).
    void update(const struct ggml_tensor* denoised) {
        int64_t ne = ggml_nelements(denoised);
        const float* data = (const float*)denoised->data;

        H_buf.emplace_back(data, data + ne);
        T_buf.push_back(taus(cnt));

        while ((int)H_buf.size() > K) {
            H_buf.erase(H_buf.begin());
            T_buf.erase(T_buf.begin());
        }

        if (cnt >= config.warmup_steps)
            curr_ws += config.flex_window;

        num_cached = 0;
        cnt++;
    }

    // Forecast the denoised output in place from history instead of running the model.
    void predict(struct ggml_tensor* denoised) {
        int64_t F = (int64_t)H_buf[0].size();
        int K_curr = (int)H_buf.size();
        int M1 = config.m + 1;
        float tau_at = taus(cnt);

        // Design matrix X: K_curr x M1, Chebyshev basis via the recurrence
        // T_j(t) = 2 t T_{j-1}(t) - T_{j-2}(t)
        std::vector<float> X(K_curr * M1);
        for (int i = 0; i < K_curr; i++) {
            X[i * M1] = 1.0f;
            if (M1 > 1)
                X[i * M1 + 1] = T_buf[i];
            for (int j = 2; j < M1; j++)
                X[i * M1 + j] = 2.0f * T_buf[i] * X[i * M1 + j - 1] - X[i * M1 + j - 2];
        }

        // x_star: Chebyshev basis at current tau
        std::vector<float> x_star(M1);
        x_star[0] = 1.0f;
        if (M1 > 1)
            x_star[1] = tau_at;
        for (int j = 2; j < M1; j++)
            x_star[j] = 2.0f * tau_at * x_star[j - 1] - x_star[j - 2];

        // XtX = X^T X + lambda I
        std::vector<float> XtX(M1 * M1, 0.0f);
        for (int i = 0; i < M1; i++) {
            for (int j = 0; j < M1; j++) {
                float sum = 0.0f;
                for (int k = 0; k < K_curr; k++)
                    sum += X[k * M1 + i] * X[k * M1 + j];
                XtX[i * M1 + j] = sum + (i == j ? config.lam : 0.0f);
            }
        }

        // Cholesky decomposition; on failure, add trace-scaled jitter and retry once
        std::vector<float> L(M1 * M1, 0.0f);
        if (!cholesky_decompose(XtX.data(), L.data(), M1)) {
            float trace = 0.0f;
            for (int i = 0; i < M1; i++)
                trace += XtX[i * M1 + i];
            for (int i = 0; i < M1; i++)
                XtX[i * M1 + i] += 1e-4f * trace / M1;
            cholesky_decompose(XtX.data(), L.data(), M1);
        }

        // Solve XtX v = x_star
        std::vector<float> v(M1);
        cholesky_solve(L.data(), x_star.data(), v.data(), M1);

        // Prediction weights per history entry: weights = X v
        std::vector<float> weights(K_curr, 0.0f);
        for (int k = 0; k < K_curr; k++)
            for (int j = 0; j < M1; j++)
                weights[k] += X[k * M1 + j] * v[j];

        // Blend Chebyshev and Taylor predictions
        float* out = (float*)denoised->data;
        float w_cheb = config.w;
        float w_taylor = 1.0f - w_cheb;
        const float* h_last = H_buf.back().data();
        const float* h_prev = H_buf[H_buf.size() - 2].data();

        for (int64_t f = 0; f < F; f++) {
            float pred_cheb = 0.0f;
            for (int k = 0; k < K_curr; k++)
                pred_cheb += weights[k] * H_buf[k][f];

            // damped first-order extrapolation from the last two computed outputs
            float pred_taylor = h_last[f] + 0.5f * (h_last[f] - h_prev[f]);

            out[f] = w_taylor * pred_taylor + w_cheb * pred_cheb;
        }

        num_cached++;
        total_steps_skipped++;
        cnt++;
    }

private:
    static bool cholesky_decompose(const float* A, float* L, int n) {
        std::memset(L, 0, n * n * sizeof(float));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                float sum = 0.0f;
                for (int k = 0; k < j; k++)
                    sum += L[i * n + k] * L[j * n + k];
                if (i == j) {
                    float diag = A[i * n + i] - sum;
                    if (diag <= 0.0f)
                        return false;
                    L[i * n + j] = std::sqrt(diag);
                } else {
                    L[i * n + j] = (A[i * n + j] - sum) / L[j * n + j];
                }
            }
        }
        return true;
    }

    // Solve (L L^T) x = b via forward then backward substitution.
    static void cholesky_solve(const float* L, const float* b, float* x, int n) {
        std::vector<float> y(n);
        for (int i = 0; i < n; i++) {
            float sum = 0.0f;
            for (int j = 0; j < i; j++)
                sum += L[i * n + j] * y[j];
            y[i] = (b[i] - sum) / L[i * n + i];
        }
        for (int i = n - 1; i >= 0; i--) {
            float sum = 0.0f;
            for (int j = i + 1; j < n; j++)
                sum += L[j * n + i] * x[j];
            x[i] = (y[i] - sum) / L[i * n + i];
        }
    }
};

#endif // __SPECTRUM_HPP__
```
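What `predict()` computes, in brief: given history $h_1,\dots,h_K$ at Chebyshev-mapped times $\tau_k$ and design matrix $X_{kj} = T_j(\tau_k)$, the per-history ridge weights are $\alpha = X\,(X^\top X + \lambda I)^{-1} x_\star$ with $x_{\star j} = T_j(\tau_{\text{now}})$, giving $\hat h_{\text{cheb}} = \sum_k \alpha_k h_k$ per feature. This is blended with the damped extrapolation $\hat h_{\text{taylor}} = h_K + \tfrac{1}{2}(h_K - h_{K-1})$ as $\hat h = (1-w)\,\hat h_{\text{taylor}} + w\,\hat h_{\text{cheb}}$.

The sampler hookup lives outside this file and is not part of this commit excerpt; a minimal sketch of the intended call pattern, assuming a hypothetical `compute_denoised()` helper and a contiguous F32 host tensor:

```cpp
// Hypothetical driver loop -- the real integration point is elsewhere in the
// library. `compute_denoised` is an assumed helper, not part of this commit.
SpectrumConfig cfg; // defaults: w=0.4, m=3, lam=1.0, window=2, flex=0.5, ...
SpectrumState spectrum;
spectrum.init(cfg, total_steps);

for (size_t step = 0; step < total_steps; step++) {
    if (spectrum.should_predict()) {
        spectrum.predict(denoised);       // cheap: forecast in place, skip the UNet
    } else {
        compute_denoised(denoised, step); // expensive: full UNet forward pass
        spectrum.update(denoised);        // record the fresh output in history
    }
}
```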
