|
1 | 1 | import { describe, test, expect, beforeEach } from 'bun:test' |
| 2 | +import { readFileSync } from 'fs' |
| 3 | +import { join } from 'path' |
2 | 4 |
|
3 | 5 | import contextPruner from '../context-pruner' |
4 | 6 |
|
5 | 7 | import type { JSONValue, Message, ToolMessage } from '../types/util-types' |
| 8 | +import { AgentState } from 'types/agent-definition' |
6 | 9 | const createMessage = ( |
7 | 10 | role: 'user' | 'assistant', |
8 | 11 | content: string, |
@@ -62,11 +65,16 @@ describe('context-pruner handleSteps', () => { |
62 | 65 | output: string, |
63 | 66 | exitCode?: number, |
64 | 67 | ): [Message, ToolMessage] => |
65 | | - createToolCallPair(toolCallId, 'run_terminal_command', { command }, { |
66 | | - command, |
67 | | - stdout: output, |
68 | | - ...(exitCode !== undefined && { exitCode }), |
69 | | - }) |
| 68 | + createToolCallPair( |
| 69 | + toolCallId, |
| 70 | + 'run_terminal_command', |
| 71 | + { command }, |
| 72 | + { |
| 73 | + command, |
| 74 | + stdout: output, |
| 75 | + ...(exitCode !== undefined && { exitCode }), |
| 76 | + }, |
| 77 | + ) |
70 | 78 |
|
71 | 79 | const createLargeToolPair = ( |
72 | 80 | toolCallId: string, |
@@ -790,6 +798,164 @@ describe('context-pruner image token counting', () => { |
790 | 798 | }) |
791 | 799 | }) |
792 | 800 |
|
/**
 * Regression tests that feed a saved, oversized run state through the
 * context-pruner and check that the message history it emits fits the
 * requested budget.
 *
 * All token figures here are rough estimates (serialized JSON length / 3);
 * they are the tests' own yardstick, not an exact tokenizer count.
 */
describe('context-pruner saved run state overflow', () => {
  // Characters-per-token ratio used for every estimate in this suite.
  const CHARS_PER_TOKEN = 3

  /** Estimates the token count of any JSON-serializable value. */
  const approxTokens = (value: unknown): number =>
    Math.ceil(JSON.stringify(value).length / CHARS_PER_TOKEN)

  /** Sums the per-message estimates (each message is rounded up on its own). */
  const sumMessageTokens = (messages: unknown[]): number =>
    messages.reduce<number>((sum, msg) => sum + approxTokens(msg), 0)

  /**
   * Loads the message history from the saved run-state fixture
   * (described as holding ~194k tokens of messages).
   *
   * Asserts that the history is non-empty so a moved or reshaped fixture
   * fails here with a clear message instead of producing a confusing
   * failure further down.
   */
  const loadSavedMessages = (): unknown[] => {
    const runStatePath = join(
      __dirname,
      'data',
      'run-state-context-overflow.json',
    )
    const savedRunState = JSON.parse(readFileSync(runStatePath, 'utf-8'))
    const messages: unknown[] =
      savedRunState.sessionState?.mainAgentState?.messageHistory ?? []
    expect(messages.length).toBeGreaterThan(0)
    return messages
  }

  /** Logger stub that discards everything the pruner reports. */
  const createMockLogger = () => ({
    debug: () => {},
    info: () => {},
    warn: () => {},
    error: () => {},
  })

  /**
   * Runs the pruner's `handleSteps` generator to completion and returns every
   * non-null object it yields (non-object yields are ignored).
   *
   * @param agentState State handed to the pruner; only the fields set by the
   *   caller are populated.
   * @param params Params forwarded to `handleSteps` (e.g. `maxContextLength`).
   */
  const runPruner = (
    agentState: AgentState,
    params: Record<string, unknown>,
  ): any[] => {
    const generator = contextPruner.handleSteps!({
      agentState,
      logger: createMockLogger(),
      params,
    })

    const yielded: any[] = []
    let step = generator.next()
    while (!step.done) {
      // `typeof null === 'object'`, so exclude null explicitly.
      if (typeof step.value === 'object' && step.value !== null) {
        yielded.push(step.value)
      }
      step = generator.next()
    }
    return yielded
  }

  test('prunes message history from saved run state with large token count', () => {
    const initialMessages = loadSavedMessages()

    const initialTokens = sumMessageTokens(initialMessages)
    console.log('Initial message count:', initialMessages.length)
    console.log('Initial tokens (approx):', initialTokens)

    // Run the pruner with an explicit 190k context limit, passed via params.
    const maxContextLength = 190_000
    const mockAgentState = {
      messageHistory: initialMessages,
    } as AgentState

    const results = runPruner(mockAgentState, { maxContextLength })

    expect(results).toHaveLength(1)
    const prunedMessages = results[0].input.messages
    const finalTokens = sumMessageTokens(prunedMessages)

    console.log('Final message count:', prunedMessages.length)
    console.log('Final tokens (approx):', finalTokens)
    console.log('Token reduction:', initialTokens - finalTokens)

    // The pruner is expected to shrink the history to
    //   targetTokens = maxContextLength * shortenedMessageTokenFactor
    //                = 190k * 0.5 = 95k
    // NOTE(review): 0.5 mirrors the pruner's shortenedMessageTokenFactor;
    // keep the two in sync.
    const shortenedMessageTokenFactor = 0.5
    const targetTokens = maxContextLength * shortenedMessageTokenFactor
    // Allow 500 tokens of overhead on top of the target.
    const maxAllowedTokens = targetTokens + 500

    expect(finalTokens).toBeLessThan(maxAllowedTokens)
  })

  test('accounts for system prompt and tool definitions when pruning with default 200k limit', () => {
    const initialMessages = loadSavedMessages()

    // 30,000 characters => ~10k tokens under the chars/3 estimate.
    const hugeSystemPrompt = 'x'.repeat(30000)

    // 20 tools with a 1,000-character description each (~333 tokens apiece),
    // i.e. roughly 7k tokens once serialized.
    const toolDefinitions = Array.from({ length: 20 }, (_, i) => ({
      name: `tool_${i}`,
      description: 'A'.repeat(1000),
      parameters: { type: 'object', properties: {} },
    }))

    // Here the history is estimated from the serialized array as a whole.
    const systemPromptTokens = approxTokens(hugeSystemPrompt)
    const toolDefinitionTokens = approxTokens(toolDefinitions)
    const initialMessageTokens = approxTokens(initialMessages)
    const totalInitialTokens =
      systemPromptTokens + toolDefinitionTokens + initialMessageTokens

    console.log('System prompt tokens (approx):', systemPromptTokens)
    console.log('Tool definition tokens (approx):', toolDefinitionTokens)
    console.log('Initial message tokens (approx):', initialMessageTokens)
    console.log('Total initial tokens (approx):', totalInitialTokens)

    // The system prompt and tool definitions are supplied on agentState, where
    // the pruner is expected to pick them up. Typed loosely because the mock
    // is partial.
    const mockAgentState: any = {
      messageHistory: initialMessages,
      systemPrompt: hugeSystemPrompt,
      toolDefinitions,
    }

    // No maxContextLength param: the pruner's default (200k) applies.
    const results = runPruner(mockAgentState, {})

    expect(results).toHaveLength(1)
    const prunedMessages = results[0].input.messages
    const finalMessageTokens = approxTokens(prunedMessages)
    const finalTotalTokens =
      systemPromptTokens + toolDefinitionTokens + finalMessageTokens

    console.log('Final message tokens (approx):', finalMessageTokens)
    console.log('Final total tokens (approx):', finalTotalTokens)

    // Expected arithmetic: the prompt and tools reserve ~17k of the 200k
    // window, and the pruner targets about half of the remaining budget for
    // messages (~90k), so the total should land near 110k. The bound below
    // (60% of the window = 120k) leaves headroom for estimation error.
    const maxContextLength = 200_000
    const prunedContextLength = maxContextLength * 0.6
    expect(finalTotalTokens).toBeLessThan(prunedContextLength)

    // Also verify that the history actually got smaller.
    expect(finalMessageTokens).toBeLessThan(initialMessageTokens)
  })
})

| 958 | + |
793 | 959 | describe('context-pruner edge cases', () => { |
794 | 960 | let mockAgentState: any |
795 | 961 |
|
|
0 commit comments