Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 59 additions & 69 deletions src/codeas/ui/pages/5_💬_Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from codeas.ui.components import metadata_ui, repo_ui
from codeas.ui.utils import read_prompts

# Define constants for duplicated strings
ALL_FILES = "All files"
FULL_CONTENT = "Full content"


def chat():
st.subheader("💬 Chat")
Expand Down Expand Up @@ -51,8 +55,8 @@ def display_config_section():

retriever = ContextRetriever(**get_retriever_args())
if (
st.session_state.get("file_types") != "All files"
or st.session_state.get("content_types") != "Full content"
st.session_state.get("file_types") != ALL_FILES
or st.session_state.get("content_types") != FULL_CONTENT
):
files_missing_metadata = metadata_ui.display()
if not any(files_missing_metadata):
Expand All @@ -76,14 +80,13 @@ def display_config_section():
st.caption(f"{num_selected_files:,} files | {selected_tokens:,} tokens")
repo_ui.display_files_editor()

if not any(files_missing_metadata):
if st.button("Show context"):
context = retriever.retrieve(
files_paths=state.repo.included_files_paths,
files_tokens=state.repo.included_files_tokens,
metadata=state.repo_metadata,
)
st.text_area("Context", context, height=300)
if not any(files_missing_metadata) and st.button("Show context"):
context = retriever.retrieve(
files_paths=state.repo.included_files_paths,
files_tokens=state.repo.included_files_tokens,
metadata=state.repo_metadata,
)
st.text_area("Context", context, height=300)

if not any(files_missing_metadata):
st.caption(f"{num_selected_files:,} files | {selected_tokens:,} tokens")
Expand All @@ -95,7 +98,7 @@ def display_file_options():
st.selectbox(
"File types",
options=[
"All files",
ALL_FILES,
"Code files",
"Testing files",
"Config files",
Expand All @@ -109,7 +112,7 @@ def display_file_options():
with col2:
st.selectbox(
"Content types",
options=["Full content", "Descriptions", "Details"],
options=[FULL_CONTENT, "Descriptions", "Details"],
key="content_types",
)

Expand Down Expand Up @@ -159,21 +162,26 @@ def display_chat_history():
if entry["role"] == "user":
with st.expander(f"USER {template_label}", icon="👤", expanded=False):
st.write(entry["content"])
else: # assistant
_display_assistant_message(entry, i, template_label)


def _display_assistant_message(entry, index, template_label):
"""Helper function to display an assistant message."""
with st.expander(
f"ASSISTANT [{entry['model']}] {template_label}",
expanded=True,
icon="🤖",
):
if entry.get("content") is None:
with st.spinner("Running agent..."):
content, cost = run_agent(entry["model"])
st.write(f"💸 ${cost['total_cost']:.4f}")
st.session_state.chat_history[index]["content"] = content
st.session_state.chat_history[index]["cost"] = cost
else:
with st.expander(
f"ASSISTANT [{entry['model']}] {template_label}",
expanded=True,
icon="🤖",
):
if entry.get("content") is None:
with st.spinner("Running agent..."):
content, cost = run_agent(entry["model"])
st.write(f"💰 ${cost['total_cost']:.4f}")
st.session_state.chat_history[i]["content"] = content
st.session_state.chat_history[i]["cost"] = cost
else:
st.write(entry["content"])
st.write(f"💰 ${entry['cost']['total_cost']:.4f}")
st.write(entry["content"])
st.write(f"💸 ${entry['cost']['total_cost']:.4f}")


def display_user_input():
Expand All @@ -200,30 +208,6 @@ def display_template_options():
index=0 if st.session_state.input_reset else None,
)

# remaining_options = [
# opt for opt in prompt_options if opt != st.session_state.template1
# ]
# with col2:
# st.selectbox(
# "Template 2",
# options=remaining_options,
# key="template2",
# index=0 if st.session_state.input_reset else None,
# disabled=not st.session_state.template1,
# )

# final_options = [
# opt for opt in remaining_options if opt != st.session_state.template2
# ]
# with col3:
# st.selectbox(
# "Template 3",
# options=final_options,
# key="template3",
# index=0 if st.session_state.input_reset else None,
# disabled=not st.session_state.template2,
# )


def display_input_areas():
prompts = read_prompts()
Expand Down Expand Up @@ -283,22 +267,27 @@ def handle_send_button():
if st.session_state.get(f"template{i}")
]

user_inputs_with_templates = []
if len(selected_templates) > 1:
user_inputs = [
st.session_state.get(f"instructions{i}").strip()
for i in range(1, len(selected_templates) + 1)
]
for i in range(len(selected_templates)):
user_input = st.session_state.get(f"instructions{i+1}").strip()
template = selected_templates[i]
if user_input:
user_inputs_with_templates.append((user_input, template))
else:
user_inputs = [st.session_state.instructions.strip()]
user_input = st.session_state.instructions.strip()
template = selected_templates[0] if selected_templates else ""
if user_input:
user_inputs_with_templates.append((user_input, template))

if any(user_inputs):
for i, user_input in enumerate(user_inputs):
if user_input:
template = selected_templates[i] if len(selected_templates) > 1 else ""
st.session_state.chat_history.append(
{"role": "user", "content": user_input, "template": template}
)
for i, user_input in enumerate(user_inputs):

if user_inputs_with_templates:
for user_input, template in user_inputs_with_templates:
# Add user message
st.session_state.chat_history.append(
{"role": "user", "content": user_input, "template": template}
)
# Add assistant message(s) for this user input
for model in get_selected_models():
st.session_state.chat_history.append(
{
Expand All @@ -309,7 +298,7 @@ def handle_send_button():
}
)
st.session_state.input_reset = True
st.rerun()
st.rerun()


def handle_preview_button():
Expand Down Expand Up @@ -343,7 +332,7 @@ def handle_preview_button():
llm_client = LLMClients(model=model)
cost = llm_client.calculate_cost(messages)
st.write(
f"💰 ${cost['input_cost']:.4f} [input] ({cost['input_tokens']:,} tokens) "
f"💸 ${cost['input_cost']:.4f} [input] ({cost['input_tokens']:,} tokens) "
)


Expand Down Expand Up @@ -387,10 +376,10 @@ def get_history_messages(model):


def get_retriever_args():
file_types = st.session_state.get("file_types", "All files")
content_types = st.session_state.get("content_types", "Full content")
file_types = st.session_state.get("file_types", ALL_FILES)
content_types = st.session_state.get("content_types", FULL_CONTENT)
return {
"include_all_files": file_types == "All files",
"include_all_files": file_types == ALL_FILES,
"include_code_files": file_types == "Code files",
"include_testing_files": file_types == "Testing files",
"include_config_files": file_types == "Config files",
Expand All @@ -408,6 +397,7 @@ def log_agent_execution(model, messages, cost):
if "conversation_id" not in st.session_state:
st.session_state.conversation_id = str(uuid.uuid4())
# Get the content of the last message
# Check if messages list is not empty before accessing the last element
prompt = messages[-1]["content"] if messages else ""
# Get template information
selected_templates = [
Expand All @@ -431,4 +421,4 @@ def log_agent_execution(model, messages, cost):
)


chat()
chat()