Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c4f0d0d
Moved `litellm` dependency to a separate package.
blkt Jan 23, 2025
0ce355b
Refactoring anthropic client.
blkt Jan 28, 2025
de09f23
Better iterator abstraction around Anthropic streaming response.
blkt Feb 3, 2025
ccc3b05
Added methods to ChatCompletionRequest.
blkt Feb 5, 2025
fda8577
Refactored response parser.
blkt Feb 6, 2025
8170730
Refactored pipeline to use internal structs.
blkt Feb 6, 2025
82a5bf5
Added native support for openapi and ollama.
blkt Feb 11, 2025
ebd0da8
Fixing base url usage plus other bugs.
blkt Feb 11, 2025
1be90bd
Relax OpenAI models to align with upstream spec (#1023)
jhrozek Feb 12, 2025
bea40d1
Add set_text to Anthropic UserMessage to enable raw string messages (…
jhrozek Feb 12, 2025
ac2f5b4
Start using OpenAI native types in copilot (#1024)
jhrozek Feb 12, 2025
172c4a5
Fixed FIM cache and db access.
blkt Feb 12, 2025
ddf9978
Methods for openai reply types (#1031)
jhrozek Feb 12, 2025
23d193e
Fixed openai-style endpoints from ollama.
blkt Feb 12, 2025
3b061c0
Fixed PII.
blkt Feb 12, 2025
49c15e9
Fixed `get_text`/`set_text` interface in ollama and anthropic.
blkt Feb 13, 2025
aa67d8f
Fix PII output pipeline returning redacted chunks.
blkt Feb 13, 2025
368ecc4
Fixed PII.
blkt Feb 13, 2025
3796f92
A couple of small fixes related to copilot and openai request/reply t…
jhrozek Feb 13, 2025
d229657
Added tests of types module.
blkt Feb 13, 2025
5ab479f
Fixed tool calls for Ollama.
blkt Feb 14, 2025
654db51
Ported OpenAI and OpenRouter providers.
blkt Feb 14, 2025
2d4833c
Changed default values for base urls to just hostnames.
blkt Feb 14, 2025
b1ab764
Fix codegate version (#1062)
jhrozek Feb 15, 2025
6ce1820
Only append user messages with text to the user_messages array, not N…
jhrozek Feb 15, 2025
d52f045
Removed some leftovers.
blkt Feb 17, 2025
87c20e2
Secrets step simplification (#1074)
jhrozek Feb 17, 2025
c0e37ae
Improvements to get_last_user_message_block
jhrozek Feb 17, 2025
7bca26a
Fix test_messages_block
jhrozek Feb 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions prompts/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pii_redacted: |
The context files contain redacted personally identifiable information (PII) that is represented by a UUID encased within <>. For example:
- <123e4567-e89b-12d3-a456-426614174000>
- <2d040296-98e9-4350-84be-fda4336057eb>
If you encounter any PII redacted with a UUID, DO NOT WARN the user about it. Simplt respond to the user request and keep the PII redacted and intact, using the same UUID.
If you encounter any PII redacted with a UUID, DO NOT WARN the user about it. Simply respond to the user request and keep the PII redacted and intact, using the same UUID.
# Security-focused prompts
security_audit: "You are a security expert conducting a thorough code review. Identify potential security vulnerabilities, suggest improvements, and explain security best practices."

Expand All @@ -56,6 +56,6 @@ red_team: "You are a red team member conducting a security assessment. Identify
# BlueTeam prompts
blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities."

# Per client prompts
# Per client prompts
client_prompts:
kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in <attempt_completion><result> tags"
6 changes: 3 additions & 3 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

# Default provider URLs
DEFAULT_PROVIDER_URLS = {
"openai": "https://api.openai.com/v1",
"openrouter": "https://openrouter.ai/api/v1",
"anthropic": "https://api.anthropic.com/v1",
"openai": "https://api.openai.com",
"openrouter": "https://openrouter.ai/api",
"anthropic": "https://api.anthropic.com",
"vllm": "http://localhost:8000", # Base URL without /v1 path
"ollama": "http://localhost:11434", # Default Ollama server URL
"lm_studio": "http://localhost:1234",
Expand Down
26 changes: 21 additions & 5 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ def does_db_exist(self):
return self._db_path.is_file()


def row_from_model(model: BaseModel) -> dict:
    """Build the parameter row used to insert or update a prompt record.

    Copies ``id``, ``timestamp``, ``provider``, ``type`` and ``workspace_id``
    from *model* unchanged, and replaces ``request`` with the value returned
    by ``model.request.json(exclude_defaults=True, exclude_unset=True)``.

    Args:
        model: Object exposing the attributes ``id``, ``timestamp``,
            ``provider``, ``request``, ``type`` and ``workspace_id``.
            ``request`` must provide a ``json()`` method accepting the
            ``exclude_defaults`` and ``exclude_unset`` keyword arguments.

    Returns:
        A dict with the keys ``id``, ``timestamp``, ``provider``,
        ``request``, ``type`` and ``workspace_id``, in that order.
    """
    # NOTE(review): `.json()` is the pydantic v1 spelling; if `request` is a
    # pydantic v2 model, `model_dump_json()` is presumably the non-deprecated
    # equivalent -- confirm against the request types before switching.
    return {
        "id": model.id,
        "timestamp": model.timestamp,
        "provider": model.provider,
        "request": model.request.json(exclude_defaults=True, exclude_unset=True),
        "type": model.type,
        "workspace_id": model.workspace_id,
    }


class DbRecorder(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)
Expand All @@ -99,7 +110,10 @@ async def _execute_update_pydantic_model(
"""Execute an update or insert command for a Pydantic model."""
try:
async with self._async_db_engine.begin() as conn:
result = await conn.execute(sql_command, model.model_dump())
row = model
if isinstance(model, BaseModel):
row = model.model_dump()
result = await conn.execute(sql_command, row)
row = result.first()
if row is None:
return None
Expand Down Expand Up @@ -140,7 +154,8 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
RETURNING *
"""
)
recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
row = row_from_model(prompt_params)
recorded_request = await self._execute_update_pydantic_model(row, sql)
# Uncomment to debug the recorded request
# logger.debug(f"Recorded request: {recorded_request}")
return recorded_request # type: ignore
Expand All @@ -159,7 +174,8 @@ async def update_request(
RETURNING *
"""
)
updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
row = row_from_model(prompt_params)
updated_request = await self._execute_update_pydantic_model(row, sql)
# Uncomment to debug the recorded request
# logger.debug(f"Recorded request: {recorded_request}")
return updated_request # type: ignore
Expand All @@ -182,7 +198,7 @@ async def record_outputs(
output=first_output.output,
)
full_outputs = []
# Just store the model respnses in the list of JSON objects.
# Just store the model responses in the list of JSON objects.
for output in outputs:
full_outputs.append(output.output)

Expand Down Expand Up @@ -306,7 +322,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
f"Alerts: {len(context.alerts_raised)}."
)
except Exception as e:
logger.error(f"Failed to record context: {context}.", error=str(e))
logger.error(f"Failed to record context: {context}.", error=str(e), exc_info=e)

async def add_workspace(self, workspace_name: str) -> WorkspaceRow:
"""Add a new workspace to the DB.
Expand Down
12 changes: 12 additions & 0 deletions src/codegate/db/fim_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def __init__(self):

def _extract_message_from_fim_request(self, request: str) -> Optional[str]:
"""Extract the user message from the FIM request"""
### NEW CODE PATH ###
if not isinstance(request, str):
content_message = None
for message in request.get_messages():
for content in message.get_content():
if content_message is None:
content_message = content.get_text()
else:
logger.warning("Expected one user message, found multiple.")
return None
return content_message

try:
parsed_request = json.loads(request)
except Exception as e:
Expand Down
8 changes: 7 additions & 1 deletion src/codegate/extract_snippets/message_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,16 @@ def extract_snippets(self, message: str, require_filepath: bool = False) -> List
"""
regexes = self._choose_regex(require_filepath)
# Find all code block matches
if isinstance(message, str):
return [
self._get_snippet_for_match(match)
for regex in regexes
for match in regex.finditer(message)
]
return [
self._get_snippet_for_match(match)
for regex in regexes
for match in regex.finditer(message)
for match in regex.finditer(message.get_text())
]

def extract_unique_snippets(self, message: str) -> Dict[str, CodeSnippet]:
Expand Down
4 changes: 2 additions & 2 deletions src/codegate/llm_utils/llmclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import litellm
import structlog
from litellm import acompletion
from ollama import Client as OllamaClient

from codegate.config import Config
from codegate.inference import LlamaCppInferenceEngine
from codegate.types.generators import legacy_acompletion

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -133,7 +133,7 @@ async def _complete_litellm(
)
content = response.message.content
else:
response = await acompletion(
response = await legacy_acompletion(
model=model,
messages=request["messages"],
api_key=api_key,
Expand Down
65 changes: 19 additions & 46 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing import Any, Dict, List, Optional

import structlog
from litellm import ChatCompletionRequest, ModelResponse
from pydantic import BaseModel

from codegate.clients.clients import ClientType
from codegate.db.models import Alert, AlertSeverity, Output, Prompt
from codegate.extract_snippets.message_extractor import CodeSnippet
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.types.common import ChatCompletionRequest, ModelResponse

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -96,14 +96,12 @@ def add_input_request(
if self.prompt_id is None:
self.prompt_id = str(uuid.uuid4())

request_str = json.dumps(normalized_request)

self.input_request = Prompt(
id=self.prompt_id,
timestamp=datetime.datetime.now(datetime.timezone.utc),
provider=provider,
type="fim" if is_fim_request else "chat",
request=request_str,
request=normalized_request,
workspace_id=None,
)
# Uncomment the below to debug the input
Expand Down Expand Up @@ -197,70 +195,45 @@ def get_last_user_message(
Optional[tuple[str, int]]: A tuple containing the message content and
its index, or None if no user message is found
"""
if request.get("messages") is None:
msg = request.last_user_message()

if msg is None:
return None
for i in reversed(range(len(request["messages"]))):
if request["messages"][i]["role"] == "user":
content = request["messages"][i]["content"] # type: ignore
return str(content), i

return None
# unpack the tuple
msg, idx = msg
return "".join([content.get_text() for content in msg.get_content()]), idx


@staticmethod
def get_last_user_message_block(
request: ChatCompletionRequest,
client: ClientType = ClientType.GENERIC,
) -> Optional[tuple[str, int]]:
"""
Get the last block of consecutive 'user' messages from the request.

Args:
request (ChatCompletionRequest): The chat completion request to process
client (ClientType): The client type to consider when processing the request

Returns:
Optional[tuple[str, int]]: A string containing all consecutive user messages in the
last user message block, separated by newlines, or None if
no user message block is found.
Index of the first message detected in the block.
"""
if request.get("messages") is None:
return None

user_messages = []
messages = request["messages"]
block_start_index = None

accepted_roles = ["user", "assistant"]
if client == ClientType.OPEN_INTERPRETER:
# open interpreter also uses the role "tool"
accepted_roles.append("tool")

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] in accepted_roles:
content_str = messages[i].get("content")
if content_str is None:
last_idx = -1
for msg, idx in request.last_user_block():
for content in msg.get_content():
txt = content.get_text()
if not txt:
continue
user_messages.append(txt)
last_idx = idx

if messages[i]["role"] in ["user", "tool"]:
user_messages.append(content_str)
block_start_index = i

# Specifically for Aider, when "Ok." block is found, stop
if content_str == "Ok." and messages[i]["role"] == "assistant":
break
else:
# Stop when a message with a different role is encountered
if user_messages:
break

# Reverse the collected user messages to preserve the original order
if user_messages and block_start_index is not None:
content = "\n".join(reversed(user_messages))
return content, block_start_index

return None
if user_messages == []:
return None
return "\n".join(reversed(user_messages)), last_idx

@abstractmethod
async def process(
Expand Down
4 changes: 2 additions & 2 deletions src/codegate/pipeline/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import shlex
from typing import Optional

from litellm import ChatCompletionRequest

from codegate.clients.clients import ClientType
from codegate.pipeline.base import (
Expand All @@ -12,6 +11,7 @@
PipelineStep,
)
from codegate.pipeline.cli.commands import CustomInstructions, Version, Workspace
from codegate.types.common import ChatCompletionRequest

codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE)

Expand Down Expand Up @@ -158,7 +158,7 @@ async def process(
response=PipelineResponse(
step_name=self.name,
content=cmd_out,
model=request["model"],
model=request.get_model()
),
context=context,
)
Expand Down
24 changes: 14 additions & 10 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import itertools
import json
import re

import structlog
from litellm import ChatCompletionRequest

from codegate.clients.clients import ClientType
from codegate.db.models import AlertSeverity
Expand All @@ -13,6 +13,7 @@
PipelineStep,
)
from codegate.storage.storage_engine import StorageEngine
from codegate.types.common import ChatCompletionRequest
from codegate.utils.package_extractor import PackageExtractor
from codegate.utils.utils import generate_vector_string

Expand Down Expand Up @@ -61,7 +62,7 @@ async def process( # noqa: C901
Use RAG DB to add context to the user request
"""
# Get the latest user message
last_message = self.get_last_user_message_block(request, context.client)
last_message = self.get_last_user_message_block(request)
if not last_message:
return PipelineResult(request=request)
user_message, last_user_idx = last_message
Expand Down Expand Up @@ -126,14 +127,16 @@ async def process( # noqa: C901
context_str = self.generate_context_str(all_bad_packages, context)
context.bad_packages_found = True

# Make a copy of the request
new_request = request.copy()

# perform replacement in all the messages starting from this index
messages = request.get_messages()
filtered = itertools.dropwhile(lambda x: x[0] < last_user_idx, enumerate(messages))
if context.client != ClientType.OPEN_INTERPRETER:
for i in range(last_user_idx, len(new_request["messages"])):
message = new_request["messages"][i]
message_str = str(message["content"]) # type: ignore
for i, message in filtered:
message_str = "".join([
txt
for content in message.get_content()
for txt in content.get_text()
])
context_msg = message_str
# Add the context to the last user message
if context.client in [ClientType.CLINE, ClientType.KODU]:
Expand All @@ -154,7 +157,8 @@ async def process( # noqa: C901
context_msg = updated_task_content + rest_of_message
else:
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
new_request["messages"][i]["content"] = context_msg
content = next(message.get_content())
content.set_text(context_msg)
logger.debug("Final context message", context_message=context_msg)
else:
#  just add a message in the end
Expand All @@ -164,4 +168,4 @@ async def process( # noqa: C901
"role": "assistant",
}
)
return PipelineResult(request=new_request, context=context)
return PipelineResult(request=request, context=context)
Loading
Loading