Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from codemodder.codetf import CodeTF
from codemodder.context import CodemodExecutionContext
from codemodder.dependency import Dependency
from codemodder.llm import MisconfiguredAIClient, TokenUsage, log_token_usage
from codemodder.llm import TokenUsage, log_token_usage
from codemodder.logging import configure_logger, log_list, log_section, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down Expand Up @@ -134,7 +134,6 @@ def run(
original_cli_args: list[str] | None = None,
codemod_registry: registry.CodemodRegistry | None = None,
sast_only: bool = False,
ai_client: bool = True,
log_matched_files: bool = False,
remediation: bool = False,
) -> tuple[CodeTF | None, int, TokenUsage]:
Expand Down Expand Up @@ -162,24 +161,18 @@ def run(

repo_manager = PythonRepoManager(Path(directory))

try:
context = CodemodExecutionContext(
Path(directory),
dry_run,
verbose,
codemod_registry,
provider_registry,
repo_manager,
path_include,
path_exclude,
tool_result_files_map,
max_workers,
ai_client,
)
except MisconfiguredAIClient as e:
logger.error(e)
# Codemodder instructions conflicted (according to spec)
return None, 3, token_usage
context = CodemodExecutionContext(
Path(directory),
dry_run,
verbose,
codemod_registry,
provider_registry,
repo_manager,
path_include,
path_exclude,
tool_result_files_map,
max_workers,
)

context.repo_manager.parse_project()

Expand Down
11 changes: 0 additions & 11 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
build_failed_dependency_notification,
)
from codemodder.file_context import FileContext
from codemodder.llm import setup_azure_llama_llm_client, setup_openai_llm_client
from codemodder.logging import log_list, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand All @@ -28,9 +27,6 @@
from codemodder.utils.update_finding_metadata import update_finding_metadata

if TYPE_CHECKING:
from azure.ai.inference import ChatCompletionsClient
from openai import OpenAI

from codemodder.codemods.base_codemod import BaseCodemod


Expand All @@ -51,8 +47,6 @@ class CodemodExecutionContext:
max_workers: int = 1
tool_result_files_map: dict[str, list[Path]]
semgrep_prefilter_results: ResultSet | None = None
openai_llm_client: OpenAI | None = None
azure_llama_llm_client: ChatCompletionsClient | None = None

def __init__(
self,
Expand All @@ -66,7 +60,6 @@ def __init__(
path_exclude: list[str] | None = None,
tool_result_files_map: dict[str, list[Path]] | None = None,
max_workers: int = 1,
ai_client: bool = True,
):
self.directory = directory
self.dry_run = dry_run
Expand All @@ -85,10 +78,6 @@ def __init__(
self.max_workers = max_workers
self.tool_result_files_map = tool_result_files_map or {}
self.semgrep_prefilter_results = None
self.openai_llm_client = setup_openai_llm_client() if ai_client else None
self.azure_llama_llm_client = (
setup_azure_llama_llm_client() if ai_client else None
)

def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
Expand Down
85 changes: 0 additions & 85 deletions src/codemodder/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,13 @@

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

from typing_extensions import Self

try:
from openai import AzureOpenAI, OpenAI
except ImportError:
OpenAI = None
AzureOpenAI = None

try:
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential
except ImportError:
ChatCompletionsClient = None
AzureKeyCredential = None

if TYPE_CHECKING:
from openai import OpenAI
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

from codemodder.logging import logger

__all__ = [
"MODELS",
"setup_openai_llm_client",
"setup_azure_llama_llm_client",
"MisconfiguredAIClient",
"TokenUsage",
"log_token_usage",
]
Expand All @@ -42,7 +20,6 @@
"o1-mini",
"o1",
]
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"


class ModelRegistry(dict):
Expand All @@ -66,68 +43,6 @@ def __getattr__(self, name):
MODELS = ModelRegistry(models)


def setup_openai_llm_client() -> OpenAI | None:
    """Build an LLM client, preferring Azure OpenAI over plain OpenAI.

    Returns None when no client library or no credentials are available.
    Raises MisconfiguredAIClient when the Azure key/endpoint pair is only
    partially configured.
    """
    if AzureOpenAI is None:
        logger.info("Azure OpenAI API client not available")
        return None

    key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
    endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
    # Exactly one of key/endpoint being set is a configuration error.
    if bool(key) != bool(endpoint):
        raise MisconfiguredAIClient(
            "Azure OpenAI API key and endpoint must both be set or unset"
        )

    if key and endpoint:
        logger.info("Using Azure OpenAI API client")
        api_version = os.getenv(
            "CODEMODDER_AZURE_OPENAI_API_VERSION",
            DEFAULT_AZURE_OPENAI_API_VERSION,
        )
        return AzureOpenAI(
            api_key=key,
            api_version=api_version,
            azure_endpoint=endpoint,
        )

    # Fall back to the plain OpenAI client when Azure is not configured.
    if OpenAI is None:
        logger.info("OpenAI API client not available")
        return None

    api_key = os.getenv("CODEMODDER_OPENAI_API_KEY")
    if not api_key:
        logger.info("OpenAI API key not found")
        return None

    logger.info("Using OpenAI API client")
    return OpenAI(api_key=api_key)


def setup_azure_llama_llm_client() -> ChatCompletionsClient | None:
    """Build the Azure Llama chat-completions client from environment config.

    Returns None when the client library is missing or credentials are not
    set. Raises MisconfiguredAIClient when only one of key/endpoint is set.
    """
    if ChatCompletionsClient is None:
        logger.info("Azure Llama client not available")
        return None

    key = os.getenv("CODEMODDER_AZURE_LLAMA_API_KEY")
    endpoint = os.getenv("CODEMODDER_AZURE_LLAMA_ENDPOINT")
    # Key and endpoint must be configured together, or not at all.
    if bool(key) != bool(endpoint):
        raise MisconfiguredAIClient(
            "Azure Llama API key and endpoint must both be set or unset"
        )

    if not (key and endpoint):
        return None

    logger.info("Using Azure Llama API client")
    return ChatCompletionsClient(
        credential=AzureKeyCredential(key),
        endpoint=endpoint,
    )


class MisconfiguredAIClient(ValueError):
    """Raised when LLM client environment configuration is inconsistent
    (e.g. an API key set without its matching endpoint, or vice versa)."""


@dataclass
class TokenUsage:
completion_tokens: int = 0
Expand Down
Loading
Loading