Skip to content

Commit 2b7a5a2

Browse files
authored
Add LLM performance regression instrumentation tests (pytorch#19700)
Summary: Adds `LlmPerformanceTest`, an Android instrumentation test that measures inference performance metrics (TPS, TPS stability, TTFT) for ExecuTorch LLM on the stories110M fixture and asserts they meet minimum thresholds. This enables OKR 3.3 (Performance Testing: TPS/latency regression detection) using the same zero-infra approach as D105741356 — same fixture, same CI paths, no new dependencies.

Three performance aspects are tested:

1. `testTpsAboveThreshold` — decode speed regression gate. A warm-up run is excluded from measurement. The threshold is configurable via an instrumentation argument (`minTps`), so the same APK works on an emulator (1.0 TPS) and on a device (10+ TPS).
2. `testTpsStability` — checks that the coefficient of variation across 3 runs is at most 0.5. Catches thread contention, GC pressure, or scheduling instability that causes an inconsistent user experience.
3. `testTimeToFirstToken` — measures prompt evaluation latency (prefill time). Asserts that TTFT is at most 30 s. Catches regressions in the prefill/KV-cache-fill path that make the app feel unresponsive before generation starts.

All metrics are reported via `InstrumentationRegistry.getInstrumentation().sendStatus()` for CI metric capture and future dashboarding.

Differential Revision: D105840841
Pulled By: psiddh
1 parent 576ed77 commit 2b7a5a2

1 file changed

Lines changed: 298 additions & 0 deletions

File tree

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import android.os.Bundle
11+
import androidx.test.ext.junit.runners.AndroidJUnit4
12+
import androidx.test.platform.app.InstrumentationRegistry
13+
import java.io.File
14+
import java.io.IOException
15+
import java.util.Collections
16+
import org.apache.commons.io.FileUtils
17+
import org.json.JSONException
18+
import org.json.JSONObject
19+
import org.junit.After
20+
import org.junit.Assert.assertTrue
21+
import org.junit.Assert.fail
22+
import org.junit.Before
23+
import org.junit.Test
24+
import org.junit.runner.RunWith
25+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
26+
import org.pytorch.executorch.extension.llm.LlmCallback
27+
import org.pytorch.executorch.extension.llm.LlmModule
28+
29+
/**
 * Performance regression tests for LLM inference on ExecuTorch Android.
 *
 * Measures tokens-per-second (TPS), TPS stability, and time-to-first-token (TTFT). Results are
 * reported via [InstrumentationRegistry] so CI systems can capture and trend metrics over time.
 *
 * Uses the same TinyStories-110M fixture as [LlmModuleConversationHistoryTest], so no additional
 * test infrastructure is needed. Works on both OSS (GitHub Actions) and internal (Sandcastle) CI.
 *
 * To run locally:
 * ```
 * ./gradlew :executorch_android:connectedAndroidTest \
 *   -Pandroid.testInstrumentationRunnerArguments.class=org.pytorch.executorch.LlmPerformanceTest
 * ```
 *
 * To override the TPS threshold for physical devices:
 * ```
 * -Pandroid.testInstrumentationRunnerArguments.minTps=10.0
 * ```
 */
@RunWith(AndroidJUnit4::class)
class LlmPerformanceTest : LlmCallback {

  /** Module under test; created in [setUp] and closed in [tearDown]. */
  private lateinit var llmModule: LlmModule

  /** Results delivered to [onResult] since the last [resetState]. */
  private val generatedTokens: MutableList<String> =
    Collections.synchronizedList(mutableListOf<String>())

  /** TPS values derived in [onStats] since the last [resetState]. */
  private val tpsResults: MutableList<Float> = Collections.synchronizedList(mutableListOf<Float>())

  /** Raw payload of the latest [onStats] call, or `null` if none arrived since [resetState]. */
  @Volatile private var lastStatsJson: String? = null

  /**
   * Copies the model and tokenizer resources to the test file location and constructs the
   * [LlmModule] from those two paths.
   *
   * @throws IllegalArgumentException if either resource is missing from the test APK.
   */
  @Before
  @Throws(IOException::class)
  fun setUp() {
    val modelPath = stageResource(TEST_FILE_NAME)
    val tokenizerPath = stageResource(TOKENIZER_FILE_NAME)

    // NOTE(review): the third argument is presumably the sampling temperature — confirm against
    // the LlmModule constructor.
    llmModule = LlmModule(modelPath, tokenizerPath, 0.0f)
  }

  /** Closes the module if [setUp] got far enough to create it. */
  @After
  fun tearDown() {
    if (::llmModule.isInitialized) {
      llmModule.close()
    }
  }

  /**
   * Measures TPS after a warm-up run and asserts it exceeds a minimum threshold.
   *
   * The warm-up is necessary because the first inference includes one-time costs (memory
   * allocation, kernel compilation on some backends) that would unfairly penalize the measurement.
   *
   * Default threshold is conservative (1.0 TPS) for emulator CI. Override with the `minTps`
   * instrumentation argument for physical device runs where 10-30+ TPS is expected.
   */
  @Test(timeout = MAX_TEST_TIMEOUT_MS)
  fun testTpsAboveThreshold() {
    llmModule.load()

    // Warm-up: first inference includes one-time overhead
    runInference()
    assertTrue("Warm-up produced no tokens — model may be broken", generatedTokens.isNotEmpty())
    reportMetric("warmup_tps", tpsResults.lastOrNull() ?: 0f)

    // Measured run
    runInference()
    assertTrue("Measured run produced no tokens", generatedTokens.isNotEmpty())
    // onStats drops unparsable payloads and non-positive decode times, so include the raw payload
    // to make this failure diagnosable.
    assertTrue(
      "No TPS stats received from onStats callback. Raw stats: $lastStatsJson",
      tpsResults.isNotEmpty(),
    )

    val measuredTps = tpsResults.last()
    val minTps = getMinTpsThreshold()
    // Informational only: -1 is reported when the token count cannot be read.
    val statsTokenCount =
      lastStatsJson?.let { raw ->
        try {
          JSONObject(raw).getInt("generated_tokens")
        } catch (_: JSONException) {
          -1
        }
      } ?: -1

    reportMetric("measured_tps", measuredTps)
    reportMetric("measured_tokens", statsTokenCount.toFloat())
    reportMetric("min_tps_threshold", minTps)

    assertTrue(
      "TPS regression detected! measured=${formatDecimal(measuredTps, 2)} " +
        "< threshold=${formatDecimal(minTps, 2)}. Raw stats: $lastStatsJson",
      measuredTps >= minTps,
    )
  }

  /**
   * Validates that TPS is stable across multiple consecutive runs.
   *
   * Large variance in TPS (high coefficient of variation) may indicate thread contention, GC
   * pressure, thermal throttling, or non-deterministic scheduling — all of which degrade the user
   * experience even if average TPS is acceptable.
   */
  @Test(timeout = MAX_TEST_TIMEOUT_MS)
  fun testTpsStability() {
    llmModule.load()

    // Warm-up
    runInference()

    // Collect one TPS sample per run; a run that yields no TPS contributes nothing.
    val measurements = mutableListOf<Float>()
    repeat(STABILITY_ITERATIONS) {
      runInference()
      tpsResults.lastOrNull()?.let { measurements.add(it) }
    }

    assertTrue(
      "Not enough TPS measurements (${measurements.size}/$STABILITY_ITERATIONS). " +
        "Last raw stats: $lastStatsJson",
      measurements.size >= STABILITY_ITERATIONS,
    )

    // Population statistics, kept in Double until the values are reported.
    val mean = measurements.average()
    val variance = measurements.map { (it - mean) * (it - mean) }.average()
    val stddev = kotlin.math.sqrt(variance)
    val cv = if (mean > 0.0) (stddev / mean).toFloat() else Float.MAX_VALUE

    reportMetric("stability_mean_tps", mean.toFloat())
    reportMetric("stability_stddev", stddev.toFloat())
    reportMetric("stability_cv", cv)
    // The size assertion above guarantees the list is non-empty here.
    reportMetric("stability_min", checkNotNull(measurements.minOrNull()))
    reportMetric("stability_max", checkNotNull(measurements.maxOrNull()))

    assertTrue(
      "TPS too unstable! CV=${formatDecimal(cv, 3)} exceeds max $MAX_CV. " +
        "Measurements: $measurements",
      cv <= MAX_CV,
    )
  }

  /**
   * Measures time-to-first-token (TTFT) — the delay from calling generate() until the first token
   * is produced (i.e., prompt evaluation / prefill time).
   *
   * High TTFT directly impacts perceived responsiveness: the user types a message and sees nothing
   * happen until prefill completes.
   *
   * TTFT is computed as `first_token_ms - inference_start_ms` from the stats payload. A negative
   * value is treated as a failure rather than a fast result.
   */
  @Test(timeout = MAX_TEST_TIMEOUT_MS)
  fun testTimeToFirstToken() {
    llmModule.load()

    // Warm-up
    runInference()

    // Measured TTFT
    runInference()

    // A TTFT reading is meaningless if the run never produced a token.
    assertTrue("Measured run produced no tokens", generatedTokens.isNotEmpty())

    val statsJson = lastStatsJson
    if (statsJson == null) {
      fail("No stats JSON received from onStats callback")
      return // fail() always throws; the return only enables the smart cast below
    }

    val ttftMs =
      try {
        val json = JSONObject(statsJson)
        val inferenceStartMs = json.getLong("inference_start_ms")
        val firstTokenMs = json.getLong("first_token_ms")
        firstTokenMs - inferenceStartMs
      } catch (e: JSONException) {
        fail("Failed to parse onStats JSON for TTFT: $statsJson. Error: ${e.message}")
        return // fail() always throws
      }

    reportMetric("ttft_ms", ttftMs.toFloat())

    // Without this check a missing or zero first_token_ms yields a negative TTFT that would
    // pass the upper-bound assertion below.
    assertTrue(
      "Invalid TTFT: ${ttftMs}ms is negative (first_token_ms precedes inference_start_ms). " +
        "Raw stats: $statsJson",
      ttftMs >= 0,
    )
    assertTrue(
      "TTFT too slow: ${ttftMs}ms exceeds max ${MAX_TTFT_MS}ms. " +
        "First token latency is too high.",
      ttftMs <= MAX_TTFT_MS,
    )
  }

  // ─── LlmCallback ──────────────────────────────────────────────────────────────────

  /** Records each result delivered by the module for the current run. */
  override fun onResult(result: String) {
    generatedTokens.add(result)
  }

  /**
   * Stores the raw stats payload and, when it parses, derives decode TPS as
   * `generated_tokens / (inference_end_ms - prompt_eval_end_ms) * 1000`.
   *
   * No TPS is recorded when a field is missing or the decode time is not positive; the tests
   * detect that through their own assertions on [tpsResults].
   */
  override fun onStats(stats: String) {
    lastStatsJson = stats
    try {
      val json = JSONObject(stats)
      val numTokens = json.getInt("generated_tokens")
      val decodeTimeMs = json.getLong("inference_end_ms") - json.getLong("prompt_eval_end_ms")
      if (decodeTimeMs > 0) {
        tpsResults.add(numTokens.toFloat() / decodeTimeMs.toFloat() * 1000f)
      }
    } catch (_: JSONException) {
      // Parsing failure — test will fail on assertion
    }
  }

  // ─── Helpers ─────────────────────────────────────────────────────────────────────

  /**
   * Copies the classpath resource [resourceName] to its test file path and returns that path.
   *
   * @throws IllegalArgumentException if the resource is not found.
   */
  @Throws(IOException::class)
  private fun stageResource(resourceName: String): String {
    val targetPath = getTestFilePath(resourceName)
    requireNotNull(javaClass.getResourceAsStream(resourceName)) {
        "Test resource $resourceName not found; did android_test_setup.sh run?"
      }
      .use { stream -> FileUtils.copyInputStreamToFile(stream, File(targetPath)) }
    return targetPath
  }

  /** Clears per-run state, then runs one generation of [TEST_PROMPT] with this callback. */
  private fun runInference() {
    resetState()
    llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
  }

  /** Clears everything collected by the callbacks and resets the module's context. */
  private fun resetState() {
    generatedTokens.clear()
    tpsResults.clear()
    lastStatsJson = null
    llmModule.resetContext()
  }

  /**
   * Returns the minimum TPS threshold. Overridable via instrumentation arg `minTps` so the same
   * test binary can gate at different levels for emulator vs physical device CI.
   *
   * @throws IllegalArgumentException if `minTps` is present but not a finite, positive float.
   */
  private fun getMinTpsThreshold(): Float {
    val rawOverride =
      InstrumentationRegistry.getArguments().getString("minTps") ?: return DEFAULT_MIN_TPS
    val parsed = rawOverride.toFloatOrNull()
    require(parsed != null && parsed.isFinite() && parsed > 0f) {
      "Invalid instrumentation arg minTps='$rawOverride'. Expected a finite, positive float."
    }
    return parsed
  }

  /** Sends a single float metric to the instrumentation as a status bundle with result code 0. */
  private fun reportMetric(key: String, value: Float) {
    val bundle = Bundle().apply { putFloat(key, value) }
    InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle)
  }

  /**
   * Formats [value] with [decimals] fraction digits using [java.util.Locale.US], so failure
   * messages use the same decimal separator regardless of the device locale.
   */
  private fun formatDecimal(value: Float, decimals: Int): String =
    "%.${decimals}f".format(java.util.Locale.US, value)

  companion object {
    private const val TEST_FILE_NAME = "/stories.pte"
    private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"

    /** Prompt for inference. Kept short to minimize test wall-time. */
    private const val TEST_PROMPT = "Once upon a time"
    private const val SEQ_LEN = 64

    /**
     * Minimum TPS for the test to pass. Conservative for x86_64 emulator (API 34). For physical
     * devices, override via: -Pandroid.testInstrumentationRunnerArguments.minTps=10.0
     */
    private const val DEFAULT_MIN_TPS = 1.0f

    /** Maximum time-to-first-token in milliseconds. 30s is generous for emulator. */
    private const val MAX_TTFT_MS = 30_000

    /**
     * Maximum coefficient of variation (stddev/mean) for TPS across runs. 0.5 = up to 50% relative
     * variance, which is generous for noisy emulator environments. Tighten for dedicated devices.
     */
    private const val MAX_CV = 0.5f

    /** Number of runs for the stability test. */
    private const val STABILITY_ITERATIONS = 3

    /** Per-test timeout: 5 minutes to accommodate slow emulator environments. */
    private const val MAX_TEST_TIMEOUT_MS = 300_000L
  }
}

0 commit comments

Comments
 (0)