Skip to content

Commit d87890d

Browse files
psiddhCopilotclaude
authored
Add conversation-history instrumentation tests for LlmModule (pytorch#19679)
Summary: Adds `LlmModuleConversationHistoryTest`, an Android instrumentation test that exercises the multi-turn / KV-cache plumbing on `LlmModule`. The OKR theme this enables is "Feature testing → conversation history" (3.2), which depends on `prefillPrompt` + `resetContext` semantics being correct. The test runs on the existing TinyStories-110M fixture pulled by `android_test_setup.sh` from the public `ossci-android` S3 bucket, so it works on **both** internal fbsource Android CI and OSS GitHub Actions Android CI without any new fixture infrastructure. Because TinyStories is too small and not instruction-tuned, content-level assertions (e.g. "did the model recall the user's name") are not reliable. Instead, the test asserts four behavioral invariants of the conversation-history surface that any production multi-turn flow depends on: 1. `testResetContextProducesDeterministicOutput` — at temperature=0 (greedy decode), running the same prompt twice with `resetContext()` in between yields identical token streams. This is the foundational invariant: clearing the KV cache truly returns the model to a clean state. 2. `testKvCacheStatePersistsAcrossGenerateCalls` — without `resetContext()` between calls, two `generate()` calls with the same prompt diverge (or the second call throws `ExecutorchRuntimeException`), proving the KV cache is preserved across turns. If this ever fails, multi-turn conversation is silently broken. 3. `testPrefillPromptInfluencesNextGeneration` — `prefillPrompt(history)` followed by `generate(prompt)` differs from a clean-context `generate(prompt)`, proving the prefilled context actually reaches the decoder. 4. `testResetContextClearsPrefilledHistory` — `prefillPrompt + resetContext + generate` matches a clean-slate `generate`, proving reset fully clears prefilled state. Reviewed By: GregoryComer, kirklandsign Differential Revision: D105741356 Pulled By: psiddh --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent 7724fd7 commit d87890d

1 file changed

Lines changed: 196 additions & 0 deletions

File tree

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import androidx.test.ext.junit.runners.AndroidJUnit4
11+
import java.io.File
12+
import java.io.IOException
13+
import org.apache.commons.io.FileUtils
14+
import org.junit.After
15+
import org.junit.Assert.assertEquals
16+
import org.junit.Assert.assertNotEquals
17+
import org.junit.Assert.assertTrue
18+
import org.junit.Before
19+
import org.junit.Test
20+
import org.junit.runner.RunWith
21+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
22+
import org.pytorch.executorch.extension.llm.LlmCallback
23+
import org.pytorch.executorch.extension.llm.LlmModule
24+
25+
/**
 * Instrumentation tests covering the conversation-history (multi-turn) surface of [LlmModule].
 *
 * The fixture is the TinyStories-110M model that `android_test_setup.sh` pulls down. That model is
 * too small and is not instruction-tuned, so nothing here checks *what* the model says (for
 * example, whether it "recalls the user's name"). The tests instead pin down structural properties
 * of the KV-cache and reset plumbing that every conversation-history feature builds on:
 * 1. After [LlmModule.resetContext], greedy decoding (temperature=0) is deterministic.
 * 2. Successive [LlmModule.generate] calls share state (without a reset the output diverges).
 * 3. [LlmModule.prefillPrompt] changes what the following [LlmModule.generate] call produces.
 * 4. [LlmModule.resetContext] wipes prefilled state completely.
 *
 * The suite runs on internal (fbsource Sandcastle) and OSS (GitHub Actions) Android CI alike: the
 * fixture comes from the public `ossci-android` S3 bucket via `android_test_setup.sh`, and only
 * the public `LlmModule` API is exercised.
 */
@RunWith(AndroidJUnit4::class)
class LlmModuleConversationHistoryTest {

  private lateinit var llmModule: LlmModule

  /**
   * Stages the model and tokenizer resources on disk, then constructs the module with a
   * temperature argument of 0 and loads it.
   */
  @Before
  @Throws(IOException::class)
  fun setUp() {
    stageResource(TEST_FILE_NAME)
    stageResource(TOKENIZER_FILE_NAME)

    val modelPath = getTestFilePath(TEST_FILE_NAME)
    val tokenizerPath = getTestFilePath(TOKENIZER_FILE_NAME)
    // Assign before load() so tearDown() can still close the module if loading throws.
    llmModule = LlmModule(modelPath, tokenizerPath, 0.0f)
    llmModule.load()
  }

  /** Closes the module, but only when [setUp] got far enough to create it. */
  @After
  fun tearDown() {
    if (!::llmModule.isInitialized) return
    llmModule.close()
  }

  /**
   * Two greedy (temperature=0) generations of the same prompt, separated by resetContext(), have
   * to yield the same output. Every conversation-history feature leans on this: a reset must
   * genuinely put the model back into a clean state.
   */
  @Test
  @Throws(IOException::class)
  fun testResetContextProducesDeterministicOutput() {
    val beforeReset = generateAndCollect(PROMPT_A)
    llmModule.resetContext()
    val afterReset = generateAndCollect(PROMPT_A)

    assertTrue("Expected non-empty generation on first run", beforeReset.isNotEmpty())
    assertTrue("Expected non-empty generation on second run", afterReset.isNotEmpty())
    assertEquals(
      "Greedy generation after resetContext() must be deterministic for the same prompt.",
      beforeReset,
      afterReset,
    )
  }

  /**
   * When no resetContext() happens between calls, the KV-cache state is expected to carry over
   * and affect the next generation. Issuing the same prompt twice back to back should therefore
   * give a different result the second time (the cache is no longer empty and the start position
   * is non-zero), or the second call may throw because the runtime notices the stale KV state.
   *
   * Both outcomes demonstrate that state persists. Should this test begin to fail (both calls
   * succeed and return equal output), the runtime is silently discarding state between
   * generate() calls, which would break multi-turn conversations.
   */
  @Test
  @Throws(IOException::class)
  fun testKvCacheStatePersistsAcrossGenerateCalls() {
    val initialOutput = generateAndCollect(PROMPT_A)
    assertTrue("Expected non-empty generation on first run", initialOutput.isNotEmpty())

    val repeatedOutput =
      try {
        generateAndCollect(PROMPT_A)
      } catch (_: ExecutorchRuntimeException) {
        // The second generate() threw because KV-cache state left by the first call
        // affected execution; that is accepted as evidence of state persistence too.
        return
      }

    assertNotEquals(
      "Without resetContext(), repeated generate() calls must reflect persisted KV state.",
      initialOutput,
      repeatedOutput,
    )
  }

  /**
   * Tokens supplied through prefillPrompt() have to become part of the conversation history and
   * thus change the next generate(). If prefilling were a no-op, multi-turn flows that inject
   * earlier turns through prefill would be broken.
   */
  @Test
  @Throws(IOException::class)
  fun testPrefillPromptInfluencesNextGeneration() {
    val withoutHistory = generateAndCollect(PROMPT_A)

    llmModule.resetContext()
    llmModule.prefillPrompt(PREFILL_HISTORY)
    val withHistory = generateAndCollect(PROMPT_A)

    assertTrue("Expected non-empty baseline generation", withoutHistory.isNotEmpty())
    assertTrue("Expected non-empty post-prefill generation", withHistory.isNotEmpty())
    assertNotEquals(
      "prefillPrompt() must alter the KV state seen by the next generate() call.",
      withoutHistory,
      withHistory,
    )
  }

  /**
   * A resetContext() issued after prefillPrompt() has to discard the prefilled state entirely:
   * prefill, reset, generate must give the same result as generating the prompt from a clean
   * slate.
   */
  @Test
  @Throws(IOException::class)
  fun testResetContextClearsPrefilledHistory() {
    val cleanSlateOutput = generateAndCollect(PROMPT_A)

    llmModule.resetContext()
    llmModule.prefillPrompt(PREFILL_HISTORY)
    llmModule.resetContext()
    val afterPrefillAndReset = generateAndCollect(PROMPT_A)

    assertTrue("Expected non-empty clean run", cleanSlateOutput.isNotEmpty())
    assertTrue("Expected non-empty post-reset run", afterPrefillAndReset.isNotEmpty())
    assertEquals(
      "resetContext() after a prefillPrompt() must fully clear KV state.",
      cleanSlateOutput,
      afterPrefillAndReset,
    )
  }

  /**
   * Copies the classpath resource [resourceName] to the file at `getTestFilePath(resourceName)`.
   *
   * Fails with [IllegalArgumentException] when the resource is not on the classpath.
   */
  private fun stageResource(resourceName: String) {
    val destination = File(getTestFilePath(resourceName))
    val source =
      requireNotNull(javaClass.getResourceAsStream(resourceName)) {
        "Test resource $resourceName not found; did android_test_setup.sh run?"
      }
    source.use { FileUtils.copyInputStreamToFile(it, destination) }
  }

  /**
   * Runs one generate() call for [prompt] with [SEQ_LEN] and returns every string the callback
   * received, in arrival order.
   */
  private fun generateAndCollect(prompt: String): List<String> =
    CollectingCallback().also { llmModule.generate(prompt, SEQ_LEN, it) }.tokens()

  /** [LlmCallback] that records each result string and ignores stats. */
  private class CollectingCallback : LlmCallback {
    private val collected = mutableListOf<String>()

    override fun onResult(result: String) {
      collected += result
    }

    override fun onStats(stats: String) {}

    /** Returns a snapshot copy of the strings received so far. */
    fun tokens(): List<String> = collected.toList()
  }

  companion object {
    private const val TEST_FILE_NAME = "/stories.pte"
    private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"

    /** Short prompt; SEQ_LEN is kept small so the test stays fast on CI emulators/devices. */
    private const val PROMPT_A = "Once"
    private const val PREFILL_HISTORY = "Long ago, in a small village by the sea, "
    private const val SEQ_LEN = 24
  }
}

0 commit comments

Comments
 (0)