Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
01eb74b
feat: Enhance Java testing framework with stdout capture and sorting …
HeshamHM28 Feb 25, 2026
036caf4
fix falling tests
HeshamHM28 Feb 25, 2026
463a064
feat: Add tests for void function comparison and instrumentation
HeshamHM28 Feb 25, 2026
c8d4fd3
chore: merge origin/omni-java and resolve test_instrumentation.py con…
HeshamHM28 Feb 25, 2026
8028174
feat: Enhance wrap_target_calls_with_treesitter and _add_behavior_ins…
HeshamHM28 Feb 26, 2026
188fb69
fix: sort test class names for ConsoleLauncher and rebuild runtime JAR
Feb 26, 2026
0f9321c
fix: use globally unique iteration_id for Java behavioral test compar…
Feb 26, 2026
ebe1ace
style: apply ruff formatting to instrumentation files
Feb 26, 2026
7bea0c6
fix: resolve merge conflicts between feat/add/void/func and omni-java
Mar 4, 2026
8330afa
fix: use Object type for assertTrue/assertFalse target call capture
Mar 4, 2026
e5d384d
Optimize JavaAssertTransformer._infer_return_type
codeflash-ai[bot] Mar 4, 2026
5f32ec1
revert: restore original Fibonacci.java in code_to_optimize
Mar 4, 2026
1cc0336
fix: resolve merge conflicts with latest omni-java
Mar 4, 2026
9b4ba30
fix: add SuppressWarnings annotation to void function test expectations
Mar 4, 2026
a57236a
chore: apply ruff formatting to PR-changed files
Mar 4, 2026
39546f2
fix: update test_java_assertion_removal.py for Object type in assertT…
Mar 4, 2026
953ef50
Merge pull request #1757 from codeflash-ai/codeflash/optimize-pr1655-…
claude[bot] Mar 4, 2026
f3916f0
fix: resolve merge conflict with omni-java, unify target_return_type …
Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions code_to_optimize/java/src/main/java/com/example/Fibonacci.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,77 @@ public static boolean areConsecutiveFibonacci(long a, long b) {

return Math.abs(indexA - indexB) == 1;
}

/**
* Sort an array in-place using bubble sort.
* Intentionally naive O(n^2) implementation for optimization testing.
*
* @param arr Array to sort (modified in-place)
*/
public static void sortArray(long[] arr) {
if (arr == null) {
throw new IllegalArgumentException("Array must not be null");
}
for (int i = 0; i < arr.length; i++) {
for (int j = 0; j < arr.length - 1 - i; j++) {
if (arr[j] > arr[j + 1]) {
long temp = arr[j];
arr[j] = arr[j + 1];
arr[j + 1] = temp;
}
}
}
}

/**
* Append Fibonacci numbers up to a limit into the provided list.
* Clears the list first, then fills it with Fibonacci numbers less than limit.
* Uses repeated naive recursion — intentionally slow for optimization testing.
*
* @param output List to populate (cleared first)
* @param limit Upper bound (exclusive)
*/
public static void collectFibonacciInto(List<Long> output, long limit) {
if (output == null) {
throw new IllegalArgumentException("Output list must not be null");
}
output.clear();

if (limit <= 0) {
return;
}

int index = 0;
while (true) {
long fib = fibonacci(index);
if (fib >= limit) {
break;
}
output.add(fib);
index++;
if (index > 50) {
break;
}
}
}

/**
* Compute running Fibonacci sums in-place.
* result[i] = sum of fibonacci(0) through fibonacci(i).
* Uses repeated naive recursion — intentionally O(n * 2^n).
*
* @param result Array to fill with running sums (must be pre-allocated)
*/
public static void fillFibonacciRunningSums(long[] result) {
if (result == null) {
throw new IllegalArgumentException("Array must not be null");
}
for (int i = 0; i < result.length; i++) {
long sum = 0;
for (int j = 0; j <= i; j++) {
sum += fibonacci(j);
}
result[i] = sum;
}
}
}
102 changes: 102 additions & 0 deletions code_to_optimize/java/src/test/java/com/example/FibonacciTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.example;

import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -136,4 +137,105 @@ void testAreConsecutiveFibonacci() {
// Non-Fibonacci number
assertFalse(Fibonacci.areConsecutiveFibonacci(4, 5)); // 4 is not Fibonacci
}

@Test
void testSortArray() {
long[] arr = {5, 3, 8, 1, 2, 7, 4, 6};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{1, 2, 3, 4, 5, 6, 7, 8}, arr);
}

@Test
void testSortArrayAlreadySorted() {
long[] arr = {1, 2, 3, 4, 5};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{1, 2, 3, 4, 5}, arr);
}

@Test
void testSortArrayReversed() {
long[] arr = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, arr);
}

@Test
void testSortArrayDuplicates() {
long[] arr = {3, 1, 4, 1, 5, 9, 2, 6, 5, 3};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{1, 1, 2, 3, 3, 4, 5, 5, 6, 9}, arr);
}

@Test
void testSortArrayEmpty() {
long[] arr = {};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{}, arr);
}

@Test
void testSortArraySingle() {
long[] arr = {42};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{42}, arr);
}

@Test
void testSortArrayNegatives() {
long[] arr = {-3, -1, -4, -1, -5};
Fibonacci.sortArray(arr);
assertArrayEquals(new long[]{-5, -4, -3, -1, -1}, arr);
}

@Test
void testSortArrayNull() {
assertThrows(IllegalArgumentException.class, () -> Fibonacci.sortArray(null));
}

@Test
void testCollectFibonacciInto() {
List<Long> output = new ArrayList<>();
Fibonacci.collectFibonacciInto(output, 10);
assertEquals(7, output.size());
assertEquals(List.of(0L, 1L, 1L, 2L, 3L, 5L, 8L), output);
}

@Test
void testCollectFibonacciIntoZeroLimit() {
List<Long> output = new ArrayList<>();
Fibonacci.collectFibonacciInto(output, 0);
assertTrue(output.isEmpty());
}

@Test
void testCollectFibonacciIntoClearsExisting() {
List<Long> output = new ArrayList<>(List.of(99L, 100L));
Fibonacci.collectFibonacciInto(output, 5);
assertEquals(List.of(0L, 1L, 1L, 2L, 3L), output);
}

@Test
void testCollectFibonacciIntoNull() {
assertThrows(IllegalArgumentException.class, () -> Fibonacci.collectFibonacciInto(null, 10));
}

@Test
void testFillFibonacciRunningSums() {
long[] result = new long[6];
Fibonacci.fillFibonacciRunningSums(result);
// sums: fib(0)=0, 0+1=1, 0+1+1=2, 0+1+1+2=4, 0+1+1+2+3=7, 0+1+1+2+3+5=12
assertArrayEquals(new long[]{0, 1, 2, 4, 7, 12}, result);
}

@Test
void testFillFibonacciRunningSumsEmpty() {
long[] result = new long[0];
Fibonacci.fillFibonacciRunningSums(result);
assertArrayEquals(new long[]{}, result);
}

@Test
void testFillFibonacciRunningSumsNull() {
assertThrows(IllegalArgumentException.class, () -> Fibonacci.fillFibonacciRunningSums(null));
}
}
127 changes: 79 additions & 48 deletions codeflash-java-runtime/src/main/java/com/codeflash/Comparator.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ public static void main(String[] args) {
}

static String compareDatabases(String originalDbPath, String candidateDbPath) throws Exception {
Map<String, byte[]> originalResults = readTestResults(originalDbPath);
Map<String, byte[]> candidateResults = readTestResults(candidateDbPath);
Map<String, TestResult> originalResults = readTestResults(originalDbPath);
Map<String, TestResult> candidateResults = readTestResults(candidateDbPath);

Set<String> allKeys = new LinkedHashSet<>();
allKeys.addAll(originalResults.keySet());
Expand All @@ -87,46 +87,50 @@ static String compareDatabases(String originalDbPath, String candidateDbPath) th
int skippedDeserializationErrors = 0;

for (String key : allKeys) {
byte[] origBytes = originalResults.get(key);
byte[] candBytes = candidateResults.get(key);
TestResult origResult = originalResults.get(key);
TestResult candResult = candidateResults.get(key);

if (origBytes == null && candBytes == null) {
// Both null (void methods) — a real comparison (void-to-void match)
actualComparisons++;
continue;
}
byte[] origBytes = origResult != null ? origResult.returnValue : null;
byte[] candBytes = candResult != null ? candResult.returnValue : null;

if (origBytes == null) {
if (origBytes == null && candBytes == null) {
// Both null (void methods) — check stdout still
} else if (origBytes == null) {
Object candObj = safeDeserialize(candBytes);
diffs.add(formatDiff("missing", key, 0, null, safeToString(candObj)));
actualComparisons++;
continue;
}

if (candBytes == null) {
} else if (candBytes == null) {
Object origObj = safeDeserialize(origBytes);
diffs.add(formatDiff("missing", key, 0, safeToString(origObj), null));
actualComparisons++;
continue;
}
} else {
Object origObj = safeDeserialize(origBytes);
Object candObj = safeDeserialize(candBytes);

Object origObj = safeDeserialize(origBytes);
Object candObj = safeDeserialize(candBytes);
if (isDeserializationError(origObj) || isDeserializationError(candObj)) {
skippedDeserializationErrors++;
continue;
}

if (isDeserializationError(origObj) || isDeserializationError(candObj)) {
skippedDeserializationErrors++;
continue;
try {
if (!compare(origObj, candObj)) {
diffs.add(formatDiff("return_value", key, 0, safeToString(origObj), safeToString(candObj)));
}
} catch (KryoPlaceholderAccessException e) {
skippedPlaceholders++;
continue;
}
}

try {
if (!compare(origObj, candObj)) {
diffs.add(formatDiff("return_value", key, 0, safeToString(origObj), safeToString(candObj)));
}
actualComparisons++;
} catch (KryoPlaceholderAccessException e) {
skippedPlaceholders++;
continue;
// Compare stdout (for void methods and side-effect verification)
String origStdout = origResult != null ? origResult.stdout : null;
String candStdout = candResult != null ? candResult.stdout : null;
if (origStdout != null && candStdout != null && !origStdout.equals(candStdout)) {
diffs.add(formatDiff("stdout", key, 0, truncate(origStdout, 200), truncate(candStdout, 200)));
}
actualComparisons++;
}

boolean equivalent = diffs.isEmpty() && actualComparisons > 0;
Expand Down Expand Up @@ -154,31 +158,53 @@ static String compareDatabases(String originalDbPath, String candidateDbPath) th
return json.toString();
}

private static Map<String, byte[]> readTestResults(String dbPath) throws Exception {
Map<String, byte[]> results = new LinkedHashMap<>();
private static class TestResult {
final byte[] returnValue;
final String stdout;

TestResult(byte[] returnValue, String stdout) {
this.returnValue = returnValue;
this.stdout = stdout;
}
}

private static Map<String, TestResult> readTestResults(String dbPath) throws Exception {
Map<String, TestResult> results = new LinkedHashMap<>();
String url = "jdbc:sqlite:" + dbPath;

try (Connection conn = DriverManager.getConnection(url);
Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery(
"SELECT test_module_path, test_class_name, test_function_name, iteration_id, return_value FROM test_results WHERE loop_index = 1")) {
while (rs.next()) {
String testModulePath = rs.getString("test_module_path");
String testClassName = rs.getString("test_class_name");
String testFunctionName = rs.getString("test_function_name");
String iterationId = rs.getString("iteration_id");
byte[] returnValue = rs.getBytes("return_value");
// Strip the CODEFLASH_TEST_ITERATION suffix (e.g. "7_0" -> "7")
// Original runs with _0, candidate with _1, but the test iteration
// counter before the underscore is what identifies the invocation.
int lastUnderscore = iterationId.lastIndexOf('_');
if (lastUnderscore > 0) {
iterationId = iterationId.substring(0, lastUnderscore);
Statement stmt = conn.createStatement()) {

// Check if stdout column exists (backward compatibility)
boolean hasStdout = false;
try (ResultSet columns = conn.getMetaData().getColumns(null, null, "test_results", "stdout")) {
hasStdout = columns.next();
}

String query = hasStdout
? "SELECT test_module_path, test_class_name, test_function_name, iteration_id, return_value, stdout FROM test_results WHERE loop_index = 1"
: "SELECT test_module_path, test_class_name, test_function_name, iteration_id, return_value FROM test_results WHERE loop_index = 1";

try (ResultSet rs = stmt.executeQuery(query)) {
while (rs.next()) {
String testModulePath = rs.getString("test_module_path");
String testClassName = rs.getString("test_class_name");
String testFunctionName = rs.getString("test_function_name");
String iterationId = rs.getString("iteration_id");
byte[] returnValue = rs.getBytes("return_value");
String stdout = hasStdout ? rs.getString("stdout") : null;
// Strip the CODEFLASH_TEST_ITERATION suffix (e.g. "7_0" -> "7")
// Original runs with _0, candidate with _1, but the test iteration
// counter before the underscore is what identifies the invocation.
int lastUnderscore = iterationId.lastIndexOf('_');
if (lastUnderscore > 0) {
iterationId = iterationId.substring(0, lastUnderscore);
}
// Use module:class:function:iteration as key to uniquely identify
// each invocation across different test files, classes, and methods
String key = testModulePath + ":" + testClassName + ":" + testFunctionName + "::" + iterationId;
results.put(key, new TestResult(returnValue, stdout));
}
// Use module:class:function:iteration as key to uniquely identify
// each invocation across different test files, classes, and methods
String key = testModulePath + ":" + testClassName + ":" + testFunctionName + "::" + iterationId;
results.put(key, returnValue);
}
}
return results;
Expand Down Expand Up @@ -214,6 +240,11 @@ private static String safeToString(Object obj) {
}
}

private static String truncate(String s, int maxLen) {
if (s == null || s.length() <= maxLen) return s;
return s.substring(0, maxLen) + "...";
}

private static String formatDiff(String scope, String methodId, int callId,
String originalValue, String candidateValue) {
StringBuilder sb = new StringBuilder();
Expand Down
Loading
Loading