Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions src/codeas/ui/components/refactoring_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@
generate_proposed_changes,
)

USE_PREVIOUS_OUTPUTS_LABEL = "Use previous outputs"


def is_safe_path(base_dir, target_path):
try:
base_dir_realpath = os.path.realpath(base_dir)
target_path_realpath = os.path.realpath(target_path)

return target_path_realpath.startswith(base_dir_realpath + os.sep) or target_path_realpath == base_dir_realpath
except OSError:
return False


def display():
use_previous_outputs_groups = st.toggle(
"Use previous outputs", value=True, key="use_previous_outputs_groups"
USE_PREVIOUS_OUTPUTS_LABEL, value=True, key="use_previous_outputs_groups"
)

if st.button(
Expand Down Expand Up @@ -55,13 +67,10 @@ def display():
),
"cost": previous_output["cost"],
"tokens": previous_output["tokens"],
"messages": previous_output["messages"], # Add this line
"messages": previous_output["messages"],
},
)
except FileNotFoundError:
# st.warning(
# "No previous output found for refactoring groups. Running generation..."
# )
st.session_state.outputs[
"refactoring_groups"
] = define_refactoring_files()
Expand All @@ -76,7 +85,7 @@ def display():
].tokens,
"messages": st.session_state.outputs[
"refactoring_groups"
].messages, # Add this line
].messages,
},
"refactoring_groups.json",
)
Expand All @@ -93,7 +102,7 @@ def display():
"tokens": st.session_state.outputs["refactoring_groups"].tokens,
"messages": st.session_state.outputs[
"refactoring_groups"
].messages, # Add this line
].messages,
},
"refactoring_groups.json",
)
Expand All @@ -119,7 +128,6 @@ def display():
)
groups = output.response.choices[0].message.parsed

# Create a DataFrame for the data editor
data = [
{
"selected": True,
Expand Down Expand Up @@ -154,7 +162,7 @@ def display():

def display_generate_proposed_changes():
use_previous_outputs_changes = st.toggle(
"Use previous outputs", value=True, key="use_previous_outputs_changes"
USE_PREVIOUS_OUTPUTS_LABEL, value=True, key="use_previous_outputs_changes"
)

groups = (
Expand Down Expand Up @@ -203,13 +211,10 @@ def display_generate_proposed_changes():
},
"cost": previous_output["cost"],
"tokens": previous_output["tokens"],
"messages": previous_output["messages"], # Add this line
"messages": previous_output["messages"],
},
)
except FileNotFoundError:
# st.warning(
# "No previous output found for proposed changes. Running generation..."
# )
st.session_state.outputs[
"proposed_changes"
] = generate_proposed_changes(groups)
Expand All @@ -229,7 +234,7 @@ def display_generate_proposed_changes():
].tokens,
"messages": st.session_state.outputs[
"proposed_changes"
].messages, # Add this line
].messages,
},
"proposed_changes.json",
)
Expand All @@ -249,7 +254,7 @@ def display_generate_proposed_changes():
"tokens": st.session_state.outputs["proposed_changes"].tokens,
"messages": st.session_state.outputs[
"proposed_changes"
].messages, # Add this line
].messages,
},
"proposed_changes.json",
)
Expand Down Expand Up @@ -285,7 +290,7 @@ def display_generate_proposed_changes():

def display_apply_changes():
use_previous_outputs_diffs = st.toggle(
"Use previous outputs", value=True, key="use_previous_outputs_diffs"
USE_PREVIOUS_OUTPUTS_LABEL, value=True, key="use_previous_outputs_diffs"
)

if st.button("Apply changes", type="primary", key="apply_changes"):
Expand All @@ -296,7 +301,6 @@ def display_apply_changes():
].response.values()
]
with st.spinner("Generating and applying changes..."):
# Generate diffs
if use_previous_outputs_diffs:
try:
previous_output = state.read_output("generated_diffs.json")
Expand All @@ -311,9 +315,6 @@ def display_apply_changes():
},
)
except FileNotFoundError:
# st.warning(
# "No previous output found for generated diffs. Running generation..."
# )
st.session_state.outputs["generated_diffs"] = generate_diffs(
groups_changes
)
Expand Down Expand Up @@ -348,16 +349,44 @@ def display_apply_changes():
"generated_diffs.json",
)

# Apply diffs
generated_diffs_output = st.session_state.outputs["generated_diffs"]

project_root = os.getcwd()

for file_path, response in generated_diffs_output.response.items():
if not is_safe_path(project_root, file_path):
st.error(f"Error: Original file path '{file_path}' from generated changes is outside project root. Skipping.")
continue

directory, filename = os.path.split(file_path)
name, ext = os.path.splitext(filename)
new_file_path = os.path.join(directory, f"{name}_refactored{ext}")

with open(file_path, "r") as f:
original_content = f.read()
if not is_safe_path(project_root, new_file_path):
st.error(f"Error: Generated refactored path '{new_file_path}' is outside project root. Skipping.")
continue

new_file_dir = os.path.dirname(new_file_path)
if not is_safe_path(project_root, new_file_dir):
st.error(f"Error: Directory for refactored file '{new_file_dir}' is outside project root. Skipping.")
continue
if not os.path.exists(new_file_dir):
try:
os.makedirs(new_file_dir, exist_ok=True)
except OSError as e:
st.error(f"Error creating directory '{new_file_dir}': {e}. Skipping.")
continue

original_content = None
try:
with open(file_path, "r") as f:
original_content = f.read()
except OSError as e:
st.error(f"Error reading original file '{file_path}': {e}. Skipping.")
continue

if original_content is None:
continue

diff = (
f"```diff\n{response['content']}\n```"
Expand All @@ -368,14 +397,17 @@ def display_apply_changes():
try:
patched_content = apply_diffs(original_content, diff)
except Exception:
# st.error(f"Error applying diff to {file_path}")
st.error(f"Error applying diff to {file_path}. Skipping.")
continue

if not os.path.exists(os.path.dirname(new_file_path)):
os.makedirs(os.path.dirname(new_file_path), exist_ok=True)
with open(new_file_path, "w") as f:
f.write(patched_content)
try:
with open(new_file_path, "w") as f:
f.write(patched_content)
except OSError as e:
st.error(f"Error writing refactored file '{new_file_path}': {e}. Skipping.")
continue


st.success(f"{new_file_path} successfully written!")
with st.expander(f"Generated changes [{file_path}]"):
st.code(diff)
st.code(diff)