Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions codeflash/code_utils/config_java_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from pathlib import Path


def validate_java_module_resolution(source_file: Path, project_root: Path, module_root: Path) -> tuple[bool, str]:
"""Validate that a Java source file can be compiled and tested within the given module structure.

Checks:
- Source file exists
- Source file is within project root
- A build config (pom.xml, build.gradle, build.gradle.kts) exists in project root
- Source file is within module root
- Package declaration matches directory structure

Returns:
(True, "") if valid, (False, error_message) if invalid.

"""
source_file = source_file.resolve()
project_root = project_root.resolve()
module_root = module_root.resolve()

if not source_file.exists():
return False, f"Source file does not exist: {source_file}"

try:
source_file.relative_to(project_root)
except ValueError:
return False, f"Source file {source_file} is outside the project root {project_root}"

has_build_config = (
(project_root / "pom.xml").exists()
or (project_root / "build.gradle").exists()
or (project_root / "build.gradle.kts").exists()
)
if not has_build_config:
return False, f"No build configuration (pom.xml, build.gradle, build.gradle.kts) found in {project_root}"

try:
source_file.relative_to(module_root)
except ValueError:
return False, f"Source file {source_file} is outside the module root {module_root}"

# Validate package declaration matches directory structure
package_name = _parse_package_declaration(source_file)
if package_name is not None:
expected_dir = module_root / Path(*package_name.split("."))
actual_dir = source_file.parent
if actual_dir.resolve() != expected_dir.resolve():
return False, (
f"Package declaration '{package_name}' does not match directory structure. "
f"Expected file at {expected_dir}, but found at {actual_dir}"
)

return True, ""


def _parse_package_declaration(source_file: Path) -> str | None:
"""Extract the package name from a Java source file, or None if no package declaration."""
try:
content = source_file.read_text(encoding="utf-8")
except Exception:
return None

for line in content.split("\n"):
stripped = line.strip()
if stripped.startswith("package "):
return stripped[8:].rstrip(";").strip()
# Skip comments and blank lines at the top of the file
if (
stripped
and not stripped.startswith("//")
and not stripped.startswith("/*")
and not stripped.startswith("*")
):
break
return None


def infer_java_module_root(source_file: Path, project_root: Path | None = None) -> Path:
"""Infer the correct Java module root (source root) for a source file.

If project_root is None, walks up from source_file to find a build config.
Then uses find_source_root() from build_tools, falling back to
project_root/src/main/java, then project_root itself.
"""
source_file = source_file.resolve()

if project_root is None:
project_root = _find_project_root_from_file(source_file)
else:
project_root = project_root.resolve()

if project_root is None:
return source_file.parent

from codeflash.languages.java.build_tools import find_source_root

source_root = find_source_root(project_root)
if source_root is not None:
return source_root

# Fall back to standard Maven layout
standard_src = project_root / "src" / "main" / "java"
if standard_src.exists():
return standard_src

return project_root


def _find_project_root_from_file(source_file: Path) -> Path | None:
"""Walk up from source_file to find a directory with a build config."""
current = source_file.parent
while current != current.parent:
if (current / "pom.xml").exists():
return current
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
return current
current = current.parent
return None
151 changes: 151 additions & 0 deletions tests/code_utils/test_config_java_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from codeflash.code_utils.config_java_validation import infer_java_module_root, validate_java_module_resolution

if TYPE_CHECKING:
from pathlib import Path


class TestValidateJavaModuleResolution:
def test_valid_maven_project(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
pkg_dir = src_root / "com" / "example"
pkg_dir.mkdir(parents=True)
source_file = pkg_dir / "Foo.java"
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(source_file, project_root, src_root)
assert valid is True
assert error == ""

def test_source_does_not_exist(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
source_file = project_root / "src" / "main" / "java" / "Missing.java"

valid, error = validate_java_module_resolution(source_file, project_root, project_root)
assert valid is False
assert "does not exist" in error

def test_source_outside_project_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
outside_file = tmp_path / "outside" / "Foo.java"
outside_file.parent.mkdir(parents=True)
outside_file.write_text("public class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(outside_file, project_root, project_root)
assert valid is False
assert "outside the project root" in error

def test_no_build_config(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
source_file = project_root / "Foo.java"
source_file.write_text("public class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(source_file, project_root, project_root)
assert valid is False
assert "No build configuration" in error

def test_source_outside_module_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
module_root = project_root / "src" / "main" / "java"
module_root.mkdir(parents=True)
# Source file is in the project root, not in module root
source_file = project_root / "Foo.java"
source_file.write_text("public class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(source_file, project_root, module_root)
assert valid is False
assert "outside the module root" in error

def test_package_declaration_mismatch(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
wrong_dir = src_root / "com" / "bar"
wrong_dir.mkdir(parents=True)
source_file = wrong_dir / "Foo.java"
source_file.write_text("package com.foo;\npublic class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(source_file, project_root, src_root)
assert valid is False
assert "does not match directory structure" in error

def test_no_package_declaration(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
src_root.mkdir(parents=True)
source_file = src_root / "Foo.java"
source_file.write_text("public class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(source_file, project_root, src_root)
assert valid is True
assert error == ""

def test_gradle_project(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "build.gradle").write_text("apply plugin: 'java'", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
pkg_dir = src_root / "com" / "example"
pkg_dir.mkdir(parents=True)
source_file = pkg_dir / "Foo.java"
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

valid, error = validate_java_module_resolution(source_file, project_root, src_root)
assert valid is True
assert error == ""


class TestInferJavaModuleRoot:
def test_infers_standard_maven_layout(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
src_root.mkdir(parents=True)
source_file = src_root / "com" / "example" / "Foo.java"
source_file.parent.mkdir(parents=True)
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

result = infer_java_module_root(source_file, project_root)
assert result.resolve() == src_root.resolve()

def test_falls_back_to_project_root(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
# No src/main/java, no alternative source dirs with .java files
source_file = project_root / "Foo.java"
source_file.write_text("public class Foo {}", encoding="utf-8")

result = infer_java_module_root(source_file, project_root)
assert result.resolve() == project_root.resolve()

def test_detects_project_root_from_pom(self, tmp_path: Path) -> None:
project_root = tmp_path / "project"
project_root.mkdir()
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
src_root.mkdir(parents=True)
source_file = src_root / "com" / "example" / "Foo.java"
source_file.parent.mkdir(parents=True)
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

# Don't pass project_root — let it detect from pom.xml
result = infer_java_module_root(source_file, project_root=None)
assert result.resolve() == src_root.resolve()
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

from argparse import Namespace
from pathlib import Path
from unittest.mock import MagicMock, patch

from codeflash.either import is_successful
from codeflash.languages.java.function_optimizer import JavaFunctionOptimizer


def _make_optimizer(tmp_path: Path, source_file: Path, module_root: Path | None = None) -> JavaFunctionOptimizer:
"""Create a JavaFunctionOptimizer with minimal mocked dependencies."""
optimizer = object.__new__(JavaFunctionOptimizer)
optimizer.project_root = tmp_path.resolve()

fto = MagicMock()
fto.file_path = source_file.resolve()
fto.language = "java"
fto.function_name = "doSomething"
optimizer.function_to_optimize = fto

args = Namespace(module_root=module_root or tmp_path, project_root=tmp_path, no_gen_tests=False)
optimizer.args = args
return optimizer


class TestTryCorrectModuleRoot:
def test_noop_when_config_is_correct(self, tmp_path: Path) -> None:
project_root = tmp_path
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
pkg_dir = src_root / "com" / "example"
pkg_dir.mkdir(parents=True)
source_file = pkg_dir / "Foo.java"
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

optimizer = _make_optimizer(project_root, source_file, module_root=src_root)

result = optimizer.try_correct_module_root()
assert result is True
# Module root should remain unchanged
assert Path(optimizer.args.module_root).resolve() == src_root.resolve()

def test_corrects_module_root_and_updates_config(self, tmp_path: Path) -> None:
project_root = tmp_path
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
pkg_dir = src_root / "com" / "example"
pkg_dir.mkdir(parents=True)
source_file = pkg_dir / "Foo.java"
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

# Start with wrong module root (project root instead of src/main/java)
optimizer = _make_optimizer(project_root, source_file, module_root=project_root)

# Mock the config file update since we don't want to actually write config
with patch.object(optimizer, "_update_config_module_root"):
result = optimizer.try_correct_module_root()

assert result is True
assert Path(optimizer.args.module_root).resolve() == src_root.resolve()

def test_returns_false_when_inferred_root_doesnt_contain_source(self, tmp_path: Path) -> None:
project_root = tmp_path
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
# Source file is outside any reasonable source root
outside_dir = tmp_path / "outside"
outside_dir.mkdir()
source_file = outside_dir / "Foo.java"
source_file.write_text("public class Foo {}", encoding="utf-8")

# Module root is wrong and doesn't contain the source file
wrong_root = project_root / "src" / "main" / "java"
wrong_root.mkdir(parents=True)
optimizer = _make_optimizer(project_root, source_file, module_root=wrong_root)
# Also set the fto.file_path to a file outside the project root
# This ensures the validation fails and infer can't fix it
optimizer.function_to_optimize.file_path = source_file.resolve()

result = optimizer.try_correct_module_root()
assert result is False


class TestCanBeOptimized:
def test_returns_failure_when_source_outside_module_root(self, tmp_path: Path) -> None:
project_root = tmp_path
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
module_root = project_root / "src" / "main" / "java"
module_root.mkdir(parents=True)

# Source file is outside the module root
source_file = project_root / "Foo.java"
source_file.write_text("public class Foo {}", encoding="utf-8")

optimizer = _make_optimizer(project_root, source_file, module_root=module_root)

result = optimizer.can_be_optimized()
assert not is_successful(result)
assert "Java module validation failed" in result.failure()

def test_delegates_to_super_when_valid(self, tmp_path: Path) -> None:
project_root = tmp_path
(project_root / "pom.xml").write_text("<project/>", encoding="utf-8")
src_root = project_root / "src" / "main" / "java"
pkg_dir = src_root / "com" / "example"
pkg_dir.mkdir(parents=True)
source_file = pkg_dir / "Foo.java"
source_file.write_text("package com.example;\npublic class Foo {}", encoding="utf-8")

optimizer = _make_optimizer(project_root, source_file, module_root=src_root)

# Mock super().can_be_optimized() since it has many dependencies
with patch.object(
JavaFunctionOptimizer.__bases__[0], "can_be_optimized", return_value=MagicMock()
) as mock_super:
result = optimizer.can_be_optimized()
mock_super.assert_called_once()
Loading