Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ def register_new_candidate(
def get_speedup_ratio(self, optimization_id: str) -> float | None:
    """Look up the speedup ratio stored under ``optimization_id``.

    Returns whatever ``self.speedup_ratios.get`` yields for that key, i.e.
    ``None`` when the mapping has no entry for it.
    """
    recorded_ratios = self.speedup_ratios
    return recorded_ratios.get(optimization_id)

def get_optimized_runtime(self, optimization_id: str) -> float | None:
    """Look up the optimized runtime stored under ``optimization_id``.

    Returns whatever ``self.optimized_runtimes.get`` yields for that key, i.e.
    ``None`` when the mapping has no entry for it.
    """
    recorded_runtimes = self.optimized_runtimes
    return recorded_runtimes.get(optimization_id)


@dataclass(frozen=True)
class TestsInFile:
Expand Down
146 changes: 79 additions & 67 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ def __init__(
self,
initial_candidates: list[OptimizedCandidate],
future_line_profile_results: concurrent.futures.Future,
all_refinements_data: list[AIServiceRefinerRequest],
ai_service_client: AiServiceClient,
executor: concurrent.futures.ThreadPoolExecutor,
eval_ctx: CandidateEvaluationContext,
original_markdown_code: str,
future_all_refinements: list[concurrent.futures.Future],
future_all_code_repair: list[concurrent.futures.Future],
future_adaptive_optimizations: list[concurrent.futures.Future],
) -> None:
Expand All @@ -198,16 +199,17 @@ def __init__(
self.refinement_done = False
self.candidate_len = len(initial_candidates)
self.ai_service_client = ai_service_client
self.executor = executor
self.refinement_calls_count = 0
self.original_markdown_code = original_markdown_code
self.eval_ctx = eval_ctx

# Initialize queue with initial candidates
for candidate in initial_candidates:
self.forest.add(candidate)
self.candidate_queue.put(candidate)

self.future_line_profile_results = future_line_profile_results
self.all_refinements_data = all_refinements_data
self.future_all_refinements = future_all_refinements
self.future_all_code_repair = future_all_code_repair
self.future_adaptive_optimizations = future_adaptive_optimizations

Expand Down Expand Up @@ -238,7 +240,13 @@ def _handle_empty_queue(self) -> CandidateNode | None:
lambda: self.future_all_code_repair.clear(),
)
if self.line_profiler_done and not self.refinement_done:
return self._process_refinement_results()
return self._process_candidates(
self.future_all_refinements,
"Refining generated code for improved quality and performance...",
"Added {0} candidates from refinement, total candidates now: {1}",
lambda: setattr(self, "refinement_done", True),
filter_candidates_func=self._filter_refined_candidates,
)
if len(self.future_adaptive_optimizations) > 0:
return self._process_candidates(
self.future_adaptive_optimizations,
Expand All @@ -254,6 +262,7 @@ def _process_candidates(
loading_msg: str,
success_msg: str,
callback: Callable[[], None],
filter_candidates_func: Callable[[list[OptimizedCandidate]], list[OptimizedCandidate]] | None = None,
) -> CandidateNode | None:
if len(future_candidates) == 0:
return None
Expand All @@ -272,6 +281,7 @@ def _process_candidates(
else:
candidates.append(candidate_result)

candidates = filter_candidates_func(candidates) if filter_candidates_func else candidates
for candidate in candidates:
self.forest.add(candidate)
self.candidate_queue.put(candidate)
Expand All @@ -283,49 +293,42 @@ def _process_candidates(
callback()
return self.get_next_candidate()

def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
    """Submit one refinement request batch to the executor and return its future.

    NOTE(review): this method and ``_process_refinement_results`` look like the
    pre-change refinement path (they rely on ``self.executor`` and
    ``self.all_refinements_data``); confirm whether they are still needed.
    """
    return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)


def _filter_refined_candidates(self, candidates: list[OptimizedCandidate]) -> list[OptimizedCandidate]:
    """Select which refined candidates should go on to be tested.

    Every received candidate is added to ``self.refinement_calls_count``. When
    there are at most ``REFINE_ALL_THRESHOLD`` candidates they are all kept.
    Otherwise each candidate is scored from its *parent's* measured runtime and
    the diff length between the original code and the parent's code (the refined
    candidate itself has not been benchmarked yet, so the parent is used as a
    proxy), and the best ``TOP_N_REFINEMENTS`` fraction (rounded half up) is
    returned. Lower scores are better.

    Args:
        candidates: Refined candidates collected from the refinement futures.

    Returns:
        The candidates to enqueue. Candidates whose parent has no recorded
        runtime or no forest node cannot be scored and are dropped, unless no
        candidate can be scored at all, in which case the input is returned
        unchanged.
    """
    self.refinement_calls_count += len(candidates)

    if len(candidates) <= REFINE_ALL_THRESHOLD:
        return candidates

    # Keep the scored candidates in a list parallel to the metric lists so that
    # the indices produced by the ranking map back to the right candidate even
    # when some candidates are skipped.
    rankable_candidates: list[OptimizedCandidate] = []
    diff_lens_list = []
    runtimes_list = []
    for candidate in candidates:
        parent_id = candidate.parent_id
        parent_candidate_node = self.forest.get_node(parent_id)
        parent_optimized_runtime = self.eval_ctx.get_optimized_runtime(parent_id)
        if not parent_optimized_runtime or not parent_candidate_node:
            continue
        rankable_candidates.append(candidate)
        diff_lens_list.append(
            diff_length(self.original_markdown_code, parent_candidate_node.candidate.source_code.markdown)
        )
        runtimes_list.append(parent_optimized_runtime)

    if not rankable_candidates:
        # should not happen
        logger.warning("No valid candidates for refinement while filtering")
        return candidates

    runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
    weights = choose_weights(runtime=runtime_w, diff=diff_w)

    runtime_norm = normalize_by_max(runtimes_list)
    diffs_norm = normalize_by_max(diff_lens_list)
    # the lower the better
    score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
    top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
    top_indices = sorted(score_dict, key=score_dict.get)[:top_n_candidates]

    return [rankable_candidates[idx] for idx in top_indices]


def _process_refinement_results(self) -> CandidateNode | None:
    """Launch refinement calls for the collected requests and queue their results.

    When there are at most ``REFINE_ALL_THRESHOLD`` requests, every one is
    submitted. Otherwise requests are ranked by a weighted score of optimized
    runtime and diff length (lower is better) and only the best
    ``TOP_N_REFINEMENTS`` fraction (rounded half up) is submitted. The number of
    submitted calls is stored in ``self.refinement_calls_count``.
    """
    future_refinements: list[concurrent.futures.Future] = []
    refinement_call_index = 0

    if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
        for data in self.all_refinements_data:
            refinement_call_index += 1
            future_refinements.append(self.refine_optimizations([data]))
    else:
        diff_lens_list = []
        runtimes_list = []
        for c in self.all_refinements_data:
            diff_lens_list.append(diff_length(c.original_source_code, c.optimized_source_code))
            runtimes_list.append(c.optimized_code_runtime)

        runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
        weights = choose_weights(runtime=runtime_w, diff=diff_w)

        runtime_norm = normalize_by_max(runtimes_list)
        diffs_norm = normalize_by_max(diff_lens_list)
        # the lower the better
        score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
        top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
        top_indecies = sorted(score_dict, key=score_dict.get)[:top_n_candidates]

        for idx in top_indecies:
            refinement_call_index += 1
            data = self.all_refinements_data[idx]
            future_refinements.append(self.refine_optimizations([data]))

    # Track total refinement calls made
    self.refinement_calls_count = refinement_call_index

    return self._process_candidates(
        future_refinements,
        "Refining generated code for improved quality and performance...",
        "Added {0} candidates from refinement, total candidates now: {1}",
        lambda: setattr(self, "refinement_done", True),
    )

def is_done(self) -> bool:
"""Check if processing is complete."""
Expand Down Expand Up @@ -386,6 +389,7 @@ def __init__(
)
self.optimization_review = ""
self.future_all_code_repair: list[concurrent.futures.Future] = []
self.future_all_refinements: list[concurrent.futures.Future] = []
self.future_adaptive_optimizations: list[concurrent.futures.Future] = []
self.repair_counter = 0 # track how many repairs we did for each function
self.adaptive_optimization_counter = 0 # track how many adaptive optimizations we did for each function
Expand Down Expand Up @@ -832,7 +836,6 @@ def process_single_candidate(
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
eval_ctx: CandidateEvaluationContext,
all_refinements_data: list[AIServiceRefinerRequest],
exp_type: str,
function_references: str,
) -> BestOptimization | None:
Expand Down Expand Up @@ -942,33 +945,40 @@ def process_single_candidate(
c.source == OptimizedCandidateSource.REFINE for c in current_tree_candidates
)

aiservice_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client

if is_candidate_refined_before:
future_adaptive_optimization = self.call_adaptive_optimize(
trace_id=self.get_trace_id(exp_type),
original_source_code=code_context.read_writable_code.markdown,
prev_candidates=current_tree_candidates,
eval_ctx=eval_ctx,
ai_service_client=self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client,
ai_service_client=aiservice_client,
)
if future_adaptive_optimization:
self.future_adaptive_optimizations.append(future_adaptive_optimization)
else:
all_refinements_data.append(
AIServiceRefinerRequest(
optimization_id=best_optimization.candidate.optimization_id,
original_source_code=code_context.read_writable_code.markdown,
read_only_dependency_code=code_context.read_only_context_code,
original_code_runtime=original_code_baseline.runtime,
optimized_source_code=best_optimization.candidate.source_code.markdown,
optimized_explanation=best_optimization.candidate.explanation,
optimized_code_runtime=best_optimization.runtime,
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
trace_id=self.get_trace_id(exp_type),
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
function_references=function_references,
)
future_refinement = self.executor.submit(
aiservice_client.optimize_python_code_refinement,
request=[
AIServiceRefinerRequest(
optimization_id=best_optimization.candidate.optimization_id,
original_source_code=code_context.read_writable_code.markdown,
read_only_dependency_code=code_context.read_only_context_code,
original_code_runtime=original_code_baseline.runtime,
optimized_source_code=best_optimization.candidate.source_code.markdown,
optimized_explanation=best_optimization.candidate.explanation,
optimized_code_runtime=best_optimization.runtime,
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
trace_id=self.get_trace_id(exp_type),
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
function_references=function_references,
)
],
)
self.future_all_refinements.append(future_refinement)

# Display runtime information
if is_LSP_enabled():
lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree)))
Expand Down Expand Up @@ -1000,9 +1010,11 @@ def determine_best_candidate(

# Initialize evaluation context and async tasks
eval_ctx = CandidateEvaluationContext()
all_refinements_data: list[AIServiceRefinerRequest] = []

self.future_all_refinements.clear()
self.future_all_code_repair.clear()
self.future_adaptive_optimizations.clear()

self.repair_counter = 0
self.adaptive_optimization_counter = 0

Expand All @@ -1025,9 +1037,10 @@ def determine_best_candidate(
processor = CandidateProcessor(
candidates,
future_line_profile_results,
all_refinements_data,
self.aiservice_client,
self.executor,
eval_ctx,
code_context.read_writable_code.markdown,
self.future_all_refinements,
self.future_all_code_repair,
self.future_adaptive_optimizations,
)
Expand All @@ -1051,7 +1064,6 @@ def determine_best_candidate(
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
eval_ctx=eval_ctx,
all_refinements_data=all_refinements_data,
exp_type=exp_type,
function_references=function_references,
)
Expand Down
Loading