|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | +package org.pytorch.executorch |
| 9 | + |
| 10 | +import androidx.test.ext.junit.runners.AndroidJUnit4 |
| 11 | +import java.io.File |
| 12 | +import java.io.IOException |
| 13 | +import org.apache.commons.io.FileUtils |
| 14 | +import org.junit.After |
| 15 | +import org.junit.Assert.assertEquals |
| 16 | +import org.junit.Assert.assertNotEquals |
| 17 | +import org.junit.Assert.assertTrue |
| 18 | +import org.junit.Before |
| 19 | +import org.junit.Test |
| 20 | +import org.junit.runner.RunWith |
| 21 | +import org.pytorch.executorch.TestFileUtils.getTestFilePath |
| 22 | +import org.pytorch.executorch.extension.llm.LlmCallback |
| 23 | +import org.pytorch.executorch.extension.llm.LlmModule |
| 24 | + |
/**
 * Instrumentation tests for the multi-turn (conversation-history) behavior of [LlmModule].
 *
 * The model under test is the TinyStories-110M fixture downloaded by `android_test_setup.sh`. It
 * is tiny and not instruction-tuned, so nothing here checks *what* the model says (for example,
 * whether it remembers a name given earlier). The tests instead pin down structural properties of
 * the KV-cache and reset plumbing that a conversation-history feature builds on:
 * 1. After [LlmModule.resetContext], greedy decoding (temperature=0) is deterministic.
 * 2. Consecutive [LlmModule.generate] calls without a reset share state, so output diverges.
 * 3. [LlmModule.prefillPrompt] changes what the following [LlmModule.generate] call produces.
 * 4. [LlmModule.resetContext] wipes prefilled state completely.
 *
 * The suite is meant to run on both the internal (fbsource Sandcastle) and the OSS (GitHub
 * Actions) Android CI: the fixture comes from the public `ossci-android` S3 bucket via
 * `android_test_setup.sh`, and only the public `LlmModule` API is exercised.
 */
@RunWith(AndroidJUnit4::class)
class LlmModuleConversationHistoryTest {

  private lateinit var llmModule: LlmModule

  /**
   * Copies the model and tokenizer resources to their test file paths, then constructs the module
   * (third constructor argument `0.0f`, i.e. the temperature=0 setup the tests rely on) and loads
   * it.
   */
  @Before
  @Throws(IOException::class)
  fun setUp() {
    stageResource(TEST_FILE_NAME)
    stageResource(TOKENIZER_FILE_NAME)

    val modelPath = getTestFilePath(TEST_FILE_NAME)
    val tokenizerPath = getTestFilePath(TOKENIZER_FILE_NAME)
    llmModule = LlmModule(modelPath, tokenizerPath, 0.0f)
    llmModule.load()
  }

  /** Closes the module, but only when [setUp] got far enough to assign it. */
  @After
  fun tearDown() {
    if (!::llmModule.isInitialized) return
    llmModule.close()
  }

  /**
   * Running the same prompt twice with a resetContext() in between must yield the same tokens
   * under greedy decoding (temperature=0). Everything else in this suite leans on this invariant:
   * a reset really does bring the model back to a clean slate.
   */
  @Test
  @Throws(IOException::class)
  fun testResetContextProducesDeterministicOutput() {
    val beforeReset = generateAndCollect(PROMPT_A)
    llmModule.resetContext()
    val afterReset = generateAndCollect(PROMPT_A)

    assertTrue("Expected non-empty generation on first run", beforeReset.isNotEmpty())
    assertTrue("Expected non-empty generation on second run", afterReset.isNotEmpty())
    assertEquals(
      "Greedy generation after resetContext() must be deterministic for the same prompt.",
      beforeReset,
      afterReset,
    )
  }

  /**
   * When no resetContext() separates two generate() calls, the KV cache carries over. Issuing the
   * same prompt a second time is therefore expected either to produce different output (the cache
   * is no longer empty and the start position is non-zero) or to throw because the runtime notices
   * the stale KV state.
   *
   * Both outcomes demonstrate persistence. Should this test begin to fail (both calls succeed and
   * return equal output), the runtime is silently discarding state between generate() calls,
   * which would break multi-turn conversations.
   */
  @Test
  @Throws(IOException::class)
  fun testKvCacheStatePersistsAcrossGenerateCalls() {
    val initialOutput = generateAndCollect(PROMPT_A)
    assertTrue("Expected non-empty generation on first run", initialOutput.isNotEmpty())

    val repeatedOutput =
      try {
        generateAndCollect(PROMPT_A)
      } catch (_: ExecutorchRuntimeException) {
        // A throw from the second generate() means KV-cache state left by the first call
        // affected execution, which is itself evidence of persistence. Nothing more to check.
        return
      }

    assertNotEquals(
      "Without resetContext(), repeated generate() calls must reflect persisted KV state.",
      initialOutput,
      repeatedOutput,
    )
  }

  /**
   * Tokens supplied through prefillPrompt() must become part of the conversation history and
   * therefore change the next generate(). A prefill with no observable effect would break
   * multi-turn flows that inject earlier turns this way.
   */
  @Test
  @Throws(IOException::class)
  fun testPrefillPromptInfluencesNextGeneration() {
    val withoutHistory = generateAndCollect(PROMPT_A)

    llmModule.resetContext()
    llmModule.prefillPrompt(PREFILL_HISTORY)
    val withHistory = generateAndCollect(PROMPT_A)

    assertTrue("Expected non-empty baseline generation", withoutHistory.isNotEmpty())
    assertTrue("Expected non-empty post-prefill generation", withHistory.isNotEmpty())
    assertNotEquals(
      "prefillPrompt() must alter the KV state seen by the next generate() call.",
      withoutHistory,
      withHistory,
    )
  }

  /**
   * A resetContext() issued after prefillPrompt() must discard the prefilled state entirely:
   * prefill, reset, generate has to match a clean-slate generation of the same prompt.
   */
  @Test
  @Throws(IOException::class)
  fun testResetContextClearsPrefilledHistory() {
    val reference = generateAndCollect(PROMPT_A)

    llmModule.resetContext()
    llmModule.prefillPrompt(PREFILL_HISTORY)
    llmModule.resetContext()
    val afterPrefillAndReset = generateAndCollect(PROMPT_A)

    assertTrue("Expected non-empty clean run", reference.isNotEmpty())
    assertTrue("Expected non-empty post-reset run", afterPrefillAndReset.isNotEmpty())
    assertEquals(
      "resetContext() after a prefillPrompt() must fully clear KV state.",
      reference,
      afterPrefillAndReset,
    )
  }

  /**
   * Copies the classpath resource [resourceName] to the file path that [getTestFilePath] returns
   * for the same name.
   *
   * @throws IllegalArgumentException if the resource is not on the classpath.
   */
  private fun stageResource(resourceName: String) {
    val destination = File(getTestFilePath(resourceName))
    val source =
      requireNotNull(javaClass.getResourceAsStream(resourceName)) {
        "Test resource $resourceName not found; did android_test_setup.sh run?"
      }
    // use {} closes the resource stream whether or not the copy succeeds.
    source.use { FileUtils.copyInputStreamToFile(it, destination) }
  }

  /**
   * Calls generate() for [prompt] with [SEQ_LEN] and returns every string the callback received,
   * in arrival order.
   */
  private fun generateAndCollect(prompt: String): List<String> =
    CollectingCallback().also { llmModule.generate(prompt, SEQ_LEN, it) }.tokens()

  /** [LlmCallback] that records each onResult() payload and ignores stats. */
  private class CollectingCallback : LlmCallback {
    private val pieces = mutableListOf<String>()

    override fun onResult(result: String) {
      pieces += result
    }

    override fun onStats(stats: String) {}

    /** Returns a copy of what has been recorded so far, so later callbacks cannot alter it. */
    fun tokens(): List<String> = pieces.toList()
  }

  companion object {
    private const val TEST_FILE_NAME = "/stories.pte"
    private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"

    /** Deliberately short prompt; with the small SEQ_LEN it keeps runs quick on CI devices. */
    private const val PROMPT_A = "Once"
    private const val PREFILL_HISTORY = "Long ago, in a small village by the sea, "
    private const val SEQ_LEN = 24
  }
}
0 commit comments