|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | +package org.pytorch.executorch |
| 9 | + |
| 10 | +import android.os.Bundle |
| 11 | +import androidx.test.ext.junit.runners.AndroidJUnit4 |
| 12 | +import androidx.test.platform.app.InstrumentationRegistry |
| 13 | +import java.io.File |
| 14 | +import java.io.IOException |
| 15 | +import java.util.Collections |
| 16 | +import org.apache.commons.io.FileUtils |
| 17 | +import org.json.JSONException |
| 18 | +import org.json.JSONObject |
| 19 | +import org.junit.After |
| 20 | +import org.junit.Assert.assertTrue |
| 21 | +import org.junit.Assert.fail |
| 22 | +import org.junit.Before |
| 23 | +import org.junit.Test |
| 24 | +import org.junit.runner.RunWith |
| 25 | +import org.pytorch.executorch.TestFileUtils.getTestFilePath |
| 26 | +import org.pytorch.executorch.extension.llm.LlmCallback |
| 27 | +import org.pytorch.executorch.extension.llm.LlmModule |
| 28 | + |
| 29 | +/** |
| 30 | + * Performance regression tests for LLM inference on ExecuTorch Android. |
| 31 | + * |
| 32 | + * Measures tokens-per-second (TPS), TPS stability, and time-to-first-token (TTFT). Results are |
| 33 | + * reported via [InstrumentationRegistry] so CI systems can capture and trend metrics over time. |
| 34 | + * |
| 35 | + * Uses the same TinyStories-110M fixture as [LlmModuleConversationHistoryTest], so no additional |
| 36 | + * test infrastructure is needed. Works on both OSS (GitHub Actions) and internal (Sandcastle) CI. |
| 37 | + * |
| 38 | + * To run locally: |
| 39 | + * ``` |
| 40 | + * ./gradlew :executorch_android:connectedAndroidTest \ |
| 41 | + * -Pandroid.testInstrumentationRunnerArguments.class=org.pytorch.executorch.LlmPerformanceTest |
| 42 | + * ``` |
| 43 | + * |
| 44 | + * To override the TPS threshold for physical devices: |
| 45 | + * ``` |
| 46 | + * -Pandroid.testInstrumentationRunnerArguments.minTps=10.0 |
| 47 | + * ``` |
| 48 | + */ |
| 49 | +@RunWith(AndroidJUnit4::class) |
| 50 | +class LlmPerformanceTest : LlmCallback { |
| 51 | + |
| 52 | + private lateinit var llmModule: LlmModule |
| 53 | + private val generatedTokens: MutableList<String> = |
| 54 | + Collections.synchronizedList(mutableListOf<String>()) |
| 55 | + private val tpsResults: MutableList<Float> = Collections.synchronizedList(mutableListOf<Float>()) |
| 56 | + @Volatile private var lastStatsJson: String? = null |
| 57 | + |
| 58 | + @Before |
| 59 | + @Throws(IOException::class) |
| 60 | + fun setUp() { |
| 61 | + val pteFile = File(getTestFilePath(TEST_FILE_NAME)) |
| 62 | + requireNotNull(javaClass.getResourceAsStream(TEST_FILE_NAME)) { |
| 63 | + "Test resource $TEST_FILE_NAME not found; did android_test_setup.sh run?" |
| 64 | + } |
| 65 | + .use { pteStream -> FileUtils.copyInputStreamToFile(pteStream, pteFile) } |
| 66 | + |
| 67 | + val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) |
| 68 | + requireNotNull(javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)) { |
| 69 | + "Test resource $TOKENIZER_FILE_NAME not found; did android_test_setup.sh run?" |
| 70 | + } |
| 71 | + .use { tokenizerStream -> FileUtils.copyInputStreamToFile(tokenizerStream, tokenizerFile) } |
| 72 | + |
| 73 | + llmModule = |
| 74 | + LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) |
| 75 | + } |
| 76 | + |
| 77 | + @After |
| 78 | + fun tearDown() { |
| 79 | + if (::llmModule.isInitialized) { |
| 80 | + llmModule.close() |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + /** |
| 85 | + * Measures TPS after a warm-up run and asserts it exceeds a minimum threshold. |
| 86 | + * |
| 87 | + * The warm-up is necessary because the first inference includes one-time costs (memory |
| 88 | + * allocation, kernel compilation on some backends) that would unfairly penalize the measurement. |
| 89 | + * |
| 90 | + * Default threshold is conservative (1.0 TPS) for emulator CI. Override with the `minTps` |
| 91 | + * instrumentation argument for physical device runs where 10-30+ TPS is expected. |
| 92 | + */ |
| 93 | + @Test(timeout = MAX_TEST_TIMEOUT_MS) |
| 94 | + fun testTpsAboveThreshold() { |
| 95 | + llmModule.load() |
| 96 | + |
| 97 | + // Warm-up: first inference includes one-time overhead |
| 98 | + resetState() |
| 99 | + llmModule.generate(TEST_PROMPT, SEQ_LEN, this) |
| 100 | + assertTrue("Warm-up produced no tokens — model may be broken", generatedTokens.isNotEmpty()) |
| 101 | + val warmupTps = tpsResults.lastOrNull() ?: 0f |
| 102 | + reportMetric("warmup_tps", warmupTps) |
| 103 | + |
| 104 | + // Measured run |
| 105 | + resetState() |
| 106 | + llmModule.generate(TEST_PROMPT, SEQ_LEN, this) |
| 107 | + |
| 108 | + assertTrue("Measured run produced no tokens", generatedTokens.isNotEmpty()) |
| 109 | + assertTrue("No TPS stats received from onStats callback", tpsResults.isNotEmpty()) |
| 110 | + |
| 111 | + val measuredTps = tpsResults.last() |
| 112 | + val minTps = getMinTpsThreshold() |
| 113 | + val statsTokenCount = |
| 114 | + try { |
| 115 | + JSONObject(lastStatsJson!!).getInt("generated_tokens") |
| 116 | + } catch (_: Exception) { |
| 117 | + -1 |
| 118 | + } |
| 119 | + |
| 120 | + reportMetric("measured_tps", measuredTps) |
| 121 | + reportMetric("measured_tokens", statsTokenCount.toFloat()) |
| 122 | + reportMetric("min_tps_threshold", minTps) |
| 123 | + |
| 124 | + assertTrue( |
| 125 | + "TPS regression detected! measured=${"%.2f".format(measuredTps)} " + |
| 126 | + "< threshold=${"%.2f".format(minTps)}. Raw stats: $lastStatsJson", |
| 127 | + measuredTps >= minTps, |
| 128 | + ) |
| 129 | + } |
| 130 | + |
| 131 | + /** |
| 132 | + * Validates that TPS is stable across multiple consecutive runs. |
| 133 | + * |
| 134 | + * Large variance in TPS (high coefficient of variation) may indicate thread contention, GC |
| 135 | + * pressure, thermal throttling, or non-deterministic scheduling — all of which degrade the user |
| 136 | + * experience even if average TPS is acceptable. |
| 137 | + */ |
| 138 | + @Test(timeout = MAX_TEST_TIMEOUT_MS) |
| 139 | + fun testTpsStability() { |
| 140 | + llmModule.load() |
| 141 | + |
| 142 | + // Warm-up |
| 143 | + resetState() |
| 144 | + llmModule.generate(TEST_PROMPT, SEQ_LEN, this) |
| 145 | + |
| 146 | + // Collect TPS over multiple runs |
| 147 | + val measurements = mutableListOf<Float>() |
| 148 | + for (i in 1..STABILITY_ITERATIONS) { |
| 149 | + resetState() |
| 150 | + llmModule.generate(TEST_PROMPT, SEQ_LEN, this) |
| 151 | + if (tpsResults.isNotEmpty()) { |
| 152 | + measurements.add(tpsResults.last()) |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + assertTrue( |
| 157 | + "Not enough TPS measurements (${measurements.size}/$STABILITY_ITERATIONS)", |
| 158 | + measurements.size >= STABILITY_ITERATIONS, |
| 159 | + ) |
| 160 | + |
| 161 | + val mean = measurements.average().toFloat() |
| 162 | + val variance = measurements.map { (it - mean) * (it - mean) }.average().toFloat() |
| 163 | + val stddev = Math.sqrt(variance.toDouble()).toFloat() |
| 164 | + val cv = if (mean > 0f) stddev / mean else Float.MAX_VALUE |
| 165 | + |
| 166 | + reportMetric("stability_mean_tps", mean) |
| 167 | + reportMetric("stability_stddev", stddev) |
| 168 | + reportMetric("stability_cv", cv) |
| 169 | + reportMetric("stability_min", measurements.minOrNull()!!) |
| 170 | + reportMetric("stability_max", measurements.maxOrNull()!!) |
| 171 | + |
| 172 | + assertTrue( |
| 173 | + "TPS too unstable! CV=${"%.3f".format(cv)} exceeds max $MAX_CV. " + |
| 174 | + "Measurements: $measurements", |
| 175 | + cv <= MAX_CV, |
| 176 | + ) |
| 177 | + } |
| 178 | + |
| 179 | + /** |
| 180 | + * Measures time-to-first-token (TTFT) — the delay from calling generate() until the first token |
| 181 | + * is produced (i.e., prompt evaluation / prefill time). |
| 182 | + * |
| 183 | + * High TTFT directly impacts perceived responsiveness: the user types a message and sees nothing |
| 184 | + * happen until prefill completes. |
| 185 | + */ |
| 186 | + @Test(timeout = MAX_TEST_TIMEOUT_MS) |
| 187 | + fun testTimeToFirstToken() { |
| 188 | + llmModule.load() |
| 189 | + |
| 190 | + // Warm-up |
| 191 | + resetState() |
| 192 | + llmModule.generate(TEST_PROMPT, SEQ_LEN, this) |
| 193 | + |
| 194 | + // Measured TTFT |
| 195 | + resetState() |
| 196 | + llmModule.generate(TEST_PROMPT, SEQ_LEN, this) |
| 197 | + |
| 198 | + val statsJson = lastStatsJson |
| 199 | + assertTrue("No stats JSON received from onStats callback", statsJson != null) |
| 200 | + |
| 201 | + try { |
| 202 | + val json = JSONObject(statsJson!!) |
| 203 | + val inferenceStartMs = json.getLong("inference_start_ms") |
| 204 | + val firstTokenMs = json.getLong("first_token_ms") |
| 205 | + val ttftMs = firstTokenMs - inferenceStartMs |
| 206 | + |
| 207 | + reportMetric("ttft_ms", ttftMs.toFloat()) |
| 208 | + |
| 209 | + assertTrue( |
| 210 | + "TTFT too slow: ${ttftMs}ms exceeds max ${MAX_TTFT_MS}ms. " + |
| 211 | + "First token latency is too high.", |
| 212 | + ttftMs <= MAX_TTFT_MS, |
| 213 | + ) |
| 214 | + } catch (e: JSONException) { |
| 215 | + fail("Failed to parse onStats JSON for TTFT: $statsJson. Error: ${e.message}") |
| 216 | + } |
| 217 | + } |
| 218 | + |
| 219 | + // ─── LlmCallback ────────────────────────────────────────────────────────────────── |
| 220 | + |
| 221 | + override fun onResult(result: String) { |
| 222 | + generatedTokens.add(result) |
| 223 | + } |
| 224 | + |
| 225 | + override fun onStats(stats: String) { |
| 226 | + lastStatsJson = stats |
| 227 | + try { |
| 228 | + val json = JSONObject(stats) |
| 229 | + val numTokens = json.getInt("generated_tokens") |
| 230 | + val inferenceEndMs = json.getLong("inference_end_ms") |
| 231 | + val promptEvalEndMs = json.getLong("prompt_eval_end_ms") |
| 232 | + val decodeTimeMs = inferenceEndMs - promptEvalEndMs |
| 233 | + if (decodeTimeMs > 0) { |
| 234 | + tpsResults.add(numTokens.toFloat() / decodeTimeMs.toFloat() * 1000f) |
| 235 | + } |
| 236 | + } catch (_: JSONException) { |
| 237 | + // Parsing failure — test will fail on assertion |
| 238 | + } |
| 239 | + } |
| 240 | + |
| 241 | + // ─── Helpers ───────────────────────────────────────────────────────────────────── |
| 242 | + |
| 243 | + private fun resetState() { |
| 244 | + generatedTokens.clear() |
| 245 | + tpsResults.clear() |
| 246 | + lastStatsJson = null |
| 247 | + llmModule.resetContext() |
| 248 | + } |
| 249 | + |
| 250 | + /** |
| 251 | + * Returns the minimum TPS threshold. Overridable via instrumentation arg `minTps` so the same |
| 252 | + * test binary can gate at different levels for emulator vs physical device CI. |
| 253 | + */ |
| 254 | + private fun getMinTpsThreshold(): Float { |
| 255 | + val override = |
| 256 | + InstrumentationRegistry.getArguments().getString("minTps") ?: return DEFAULT_MIN_TPS |
| 257 | + val parsed = override.toFloatOrNull() |
| 258 | + require(parsed != null && parsed.isFinite() && parsed > 0f) { |
| 259 | + "Invalid instrumentation arg minTps='$override'. Expected a finite, positive float." |
| 260 | + } |
| 261 | + return parsed |
| 262 | + } |
| 263 | + |
| 264 | + private fun reportMetric(key: String, value: Float) { |
| 265 | + val bundle = Bundle().apply { putFloat(key, value) } |
| 266 | + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle) |
| 267 | + } |
| 268 | + |
| 269 | + companion object { |
| 270 | + private const val TEST_FILE_NAME = "/stories.pte" |
| 271 | + private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" |
| 272 | + |
| 273 | + /** Prompt for inference. Kept short to minimize test wall-time. */ |
| 274 | + private const val TEST_PROMPT = "Once upon a time" |
| 275 | + private const val SEQ_LEN = 64 |
| 276 | + |
| 277 | + /** |
| 278 | + * Minimum TPS for the test to pass. Conservative for x86_64 emulator (API 34). For physical |
| 279 | + * devices, override via: -Pandroid.testInstrumentationRunnerArguments.minTps=10.0 |
| 280 | + */ |
| 281 | + private const val DEFAULT_MIN_TPS = 1.0f |
| 282 | + |
| 283 | + /** Maximum time-to-first-token in milliseconds. 30s is generous for emulator. */ |
| 284 | + private const val MAX_TTFT_MS = 30_000 |
| 285 | + |
| 286 | + /** |
| 287 | + * Maximum coefficient of variation (stddev/mean) for TPS across runs. 0.5 = up to 50% relative |
| 288 | + * variance, which is generous for noisy emulator environments. Tighten for dedicated devices. |
| 289 | + */ |
| 290 | + private const val MAX_CV = 0.5f |
| 291 | + |
| 292 | + /** Number of runs for the stability test. */ |
| 293 | + private const val STABILITY_ITERATIONS = 3 |
| 294 | + |
| 295 | + /** Per-test timeout: 5 minutes to accommodate slow emulator environments. */ |
| 296 | + private const val MAX_TEST_TIMEOUT_MS = 300_000L |
| 297 | + } |
| 298 | +} |
0 commit comments