Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions ajet/tuner_lib/experimental/swarm_overwatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@
from pydantic import BaseModel


class RewardHistoryEntry(BaseModel):
    """A single entry in the reward history.

    One entry summarizes the rewards accumulated during one training
    global step (mean and standard deviation over that step's rewards).
    """
    global_step: int  # training step these statistics belong to
    mean_reward: float  # mean of the rewards recorded for this step
    std_reward: float  # standard deviation of the rewards recorded for this step
    timestamp: float  # Unix timestamp when this entry was recorded


class RewardHistoryResponse(BaseModel):
    """Response containing the reward history for visualization."""
    # Entries in the order the server recorded them. The mutable default is
    # safe here: pydantic copies field defaults per instance.
    history: List[RewardHistoryEntry] = []


class CurrentBatchRolloutPoolInformation(BaseModel):
sample_collection_method: str = ""
completed_episodes: int = 0
Expand Down
81 changes: 80 additions & 1 deletion ajet/tuner_lib/experimental/swarm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from multiprocessing.managers import DictProxy
from typing import Coroutine, Optional, Tuple, List
from ajet.utils.process_killer import kill_process_tree
from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
CurrentBatchRolloutPoolInformation,
RewardHistoryEntry,
RewardHistoryResponse,
)
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE
from ajet.tuner_lib.experimental.interchange_utils import (
SyncTrainConfigRequest,
Expand Down Expand Up @@ -63,6 +67,14 @@ def register_enable_swarm_mode_routes(
if "current_batch_rollout_pool_information" not in shared_mem_dict:
shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation()

# Initialize reward history storage for visualization
if "reward_history" not in shared_mem_dict:
shared_mem_dict["reward_history"] = [] # List of RewardHistoryEntry dicts

# Initialize reward accumulator for collecting rewards of current global step
if "current_rewards" not in shared_mem_dict:
shared_mem_dict["current_rewards"] = [] # [rewards...]

# ------------------------------------------------------------------------------------------------
# ------ Recycle claimed episodes that client failed to complete in (promised) time --------------
# --------------------------------- claimed -> unclaimed ----------------------------------------
Expand Down Expand Up @@ -166,6 +178,35 @@ def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_l
if episode_uuid in shared_mem_dict["unclaimed_episodes"]:
shared_mem_dict["unclaimed_episodes"].remove(episode_uuid)

# --------------------------------------------------------------------------------------
# -------------------------- reward history management ---------------------------------
# --------------------------------------------------------------------------------------

def _finalize_reward_history_for_step(global_step, shared_mem_dict, shared_mem_dict_lock):
    """Finalize reward statistics for a given global step and add to reward_history.

    Computes mean/std over ``shared_mem_dict["current_rewards"]``, appends a
    serialized RewardHistoryEntry to ``shared_mem_dict["reward_history"]``,
    and clears the accumulator for the next step. When no rewards were
    accumulated, the history is left untouched.

    Parameters
    ----------
    global_step : int
        The training step the accumulated rewards belong to.
    shared_mem_dict : dict-like (possibly a multiprocessing DictProxy)
        Shared state holding "current_rewards" and "reward_history".
    shared_mem_dict_lock :
        Lock guarding shared_mem_dict. NOTE(review): this function does not
        acquire it — presumably the caller is expected to hold it; confirm
        against update_engine_status before relying on this in new callers.
    """
    # Local import keeps numpy off the import path for servers that never
    # record rewards; repeat calls are cheap (module cache). Consider
    # hoisting to the top of the file with the other imports.
    import numpy as np

    rewards = shared_mem_dict.get("current_rewards", [])
    if not rewards:
        # Nothing accumulated for this step: do not emit a history entry,
        # and the accumulator is already empty, so there is nothing to clear.
        return

    rewards = list(rewards)  # materialize: may be a manager ListProxy
    entry = RewardHistoryEntry(
        global_step=global_step,
        mean_reward=float(np.mean(rewards)),
        std_reward=float(np.std(rewards)),
        timestamp=time.time(),
    )

    history = list(shared_mem_dict.get("reward_history", []))  # proxy -> plain list
    history.append(entry.model_dump())
    # Reassign whole values: in-place mutation of values nested inside a
    # DictProxy does not propagate to other processes.
    shared_mem_dict["reward_history"] = history

    # Clear current rewards for next step
    shared_mem_dict["current_rewards"] = []

# --------------------------------------------------------------------------------------
# -------------------------- return workflow output ------------------------------------
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -272,6 +313,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict):
shared_mem_dict["unclaimed_episodes"] = []
logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes")

# clear reward tracking
shared_mem_dict["current_rewards"] = []
shared_mem_dict["reward_history"] = []

# --------------------------------------------------------------------------------------
# -------------------------- fastapi routes --------------------------------------------
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -446,7 +491,12 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
engine_status_detail = req.engine_status_detail
global_step = req.global_step
if global_step is not None:
previous_global_step = shared_mem_dict.get("global_step", None)
shared_mem_dict["global_step"] = global_step
# When global_step changes, finalize reward statistics for the previous step
if previous_global_step is not None and previous_global_step != global_step:
_finalize_reward_history_for_step(previous_global_step, shared_mem_dict, shared_mem_dict_lock)

if engine_status_detail is not None:
shared_mem_dict["engine_status_detail"] = engine_status_detail
logger.info(f"[update_engine_status] Engine status set to {req.engine_status}")
Expand Down Expand Up @@ -636,6 +686,21 @@ async def end_episode(req: EndEpisodeRequest):
shared_mem_dict_lock,
)

# Record reward to current_rewards
if workflow_output.reward is not None:
reward_value = workflow_output.reward
# Handle both single reward and list of rewards
if isinstance(reward_value, list):
rewards_to_record = reward_value
else:
rewards_to_record = [reward_value]

with shared_mem_dict_lock:
current_rewards = shared_mem_dict.get("current_rewards", [])
current_rewards = list(current_rewards) # Convert proxy to list if needed
current_rewards.extend(rewards_to_record)
shared_mem_dict["current_rewards"] = current_rewards

elif episode_type == "eval":
if engine_status in ["ENGINE.ROLLING"]:
await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock)
Expand Down Expand Up @@ -779,6 +844,20 @@ async def get_current_batch_rollout_pool_information():
logger.error(f"Error getting current batch rollout pool information: {e}")
return CurrentBatchRolloutPoolInformation()

# --------------------------------------------------------------------
# ------------ get reward history for visualization ------------------
# --------------------------------------------------------------------
@app.get("/get_reward_history", response_model=RewardHistoryResponse)
async def get_reward_history():
"""Get the reward history for visualization (reward curves)."""
try:
history = shared_mem_dict.get("reward_history", [])
entries = [RewardHistoryEntry(**entry) for entry in history]
return RewardHistoryResponse(history=entries)
except Exception as e:
logger.error(f"Error getting reward history: {e}")
return RewardHistoryResponse(history=[])

# --------------------------------------------------------------------
# ------------ bring engine back to ENGINE.OFFLINE -------------------
# --------------------------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions ajet/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _dive_to_set_value(config, dotted_key, value):
sub_config[keys[-1]] = value


def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone):
def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone):
"""Align configuration values based on a conversion map.

Parameters
Expand All @@ -107,7 +107,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
Source YAML path to read values from.
to_config_fp : str
Destination YAML path that is updated in place.
convertion_json_fg : str
convertion_json_fp : str
JSON path mapping dotted keys between configs.
backbone : str
Backbone identifier used for framework-specific alignment.
Expand All @@ -121,7 +121,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
# read convertion json
import json

with open(convertion_json_fg, "r", encoding="utf-8") as file:
with open(convertion_json_fp, "r", encoding="utf-8") as file:
convertion_json = json.load(file)

logger.success("----------------------------------------------------")
Expand Down
161 changes: 159 additions & 2 deletions ajet/utils/swarm_overwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from rich.text import Text
from loguru import logger

from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
CurrentBatchRolloutPoolInformation,
RewardHistoryResponse,
)


class SwarmOverwatch:
Expand Down Expand Up @@ -56,6 +59,20 @@ def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
# logger.error(f"Failed to fetch pool info: {e}")
return None

def fetch_reward_history(self) -> Optional[RewardHistoryResponse]:
    """Retrieve the server's recorded reward history for curve plotting.

    Returns the validated RewardHistoryResponse, or None when the HTTP
    request or payload validation fails (the failure is logged, not raised).
    """
    url = f"{self.server_url}/get_reward_history"
    try:
        response = self._httpx_client.get(url, timeout=5.0)
        response.raise_for_status()
        return RewardHistoryResponse.model_validate(response.json())
    except Exception as e:
        logger.error(f"Failed to fetch reward history: {e}")
        return None

def create_header(
self, info: Optional[CurrentBatchRolloutPoolInformation] = None
) -> Panel:
Expand Down Expand Up @@ -450,6 +467,141 @@ def create_dashboard(

return layout

def display_reward_curve(self):
    """Display ASCII reward curve in terminal.

    Fetches the server-side reward history (one mean/std entry per global
    step) and renders mean reward vs. global step as an ASCII chart sized
    to the current terminal: data points as '*', interpolated connecting
    segments as '.'. Blocks on input() until the user presses Enter.
    """
    self.console.clear()

    # Fetch reward history
    history = self.fetch_reward_history()
    if history is None or not history.history:
        # Nothing recorded yet (or the fetch failed) — inform the user and return.
        self.console.print("[bold yellow]No reward history available yet.[/bold yellow]")
        self.console.print("[dim]Reward history is recorded when training completes batches with rewards.[/dim]")
        self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
        input()
        return

    # Get terminal size (fall back to 80x24 when rich cannot detect it)
    terminal_width = self.console.width or 80
    terminal_height = self.console.height or 24

    # Reserve space for header, labels, and footer
    chart_width = min(terminal_width - 15, 120)  # Reserve space for y-axis labels
    chart_height = min(terminal_height - 10, 30)  # Reserve space for header and x-axis

    # Extract data
    global_steps = [entry.global_step for entry in history.history]
    mean_rewards = [entry.mean_reward for entry in history.history]

    # Calculate y-axis range with padding
    y_min = min(mean_rewards)
    y_max = max(mean_rewards)
    y_range = y_max - y_min
    if y_range == 0:
        # All rewards identical: widen the window so the flat line
        # renders mid-chart instead of dividing by zero.
        y_range = 1.0  # Avoid division by zero
        y_min -= 0.5
        y_max += 0.5
    else:
        # Add 10% padding
        y_min -= y_range * 0.1
        y_max += y_range * 0.1
        y_range = y_max - y_min

    # Calculate x-axis range
    x_min = min(global_steps)
    x_max = max(global_steps)
    x_range = x_max - x_min
    if x_range == 0:
        x_range = 1  # single data point: avoid division by zero below

    # Create the chart grid (row 0 is the TOP line of the chart)
    chart = [[' ' for _ in range(chart_width)] for _ in range(chart_height)]

    # Plot the data points
    # NOTE(review): the enumerate index i is unused in this loop.
    for i, (step, reward) in enumerate(zip(global_steps, mean_rewards)):
        # Map to chart coordinates
        x = int((step - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
        y = int((reward - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0

        # Invert y because terminal coordinates go top-down
        y = chart_height - 1 - y

        # Clamp to valid range
        x = max(0, min(chart_width - 1, x))
        y = max(0, min(chart_height - 1, y))

        # Draw point
        chart[y][x] = '*'

    # Connect points with lines if there are multiple points
    if len(global_steps) > 1:
        for i in range(len(global_steps) - 1):
            step1, reward1 = global_steps[i], mean_rewards[i]
            step2, reward2 = global_steps[i + 1], mean_rewards[i + 1]

            x1 = int((step1 - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
            y1 = int((reward1 - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
            x2 = int((step2 - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
            y2 = int((reward2 - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0

            y1 = chart_height - 1 - y1
            y2 = chart_height - 1 - y2

            # Simple line drawing between points: linear interpolation
            # stepped by the larger axis delta; '.' never overwrites a
            # '*' data point already drawn above.
            steps_between = max(abs(x2 - x1), abs(y2 - y1))
            if steps_between > 0:
                for s in range(1, steps_between):
                    t = s / steps_between
                    x = int(x1 + t * (x2 - x1))
                    y = int(y1 + t * (y2 - y1))
                    x = max(0, min(chart_width - 1, x))
                    y = max(0, min(chart_height - 1, y))
                    if chart[y][x] == ' ':
                        chart[y][x] = '.'

    # Build the output
    output = Text()
    output.append("\n Reward Curve (Mean Reward vs Global Step)\n", style="bold cyan")
    output.append(f" Server: {self.server_url}\n", style="dim")
    output.append(f" Data points: {len(global_steps)}\n\n", style="dim")

    # Draw y-axis labels and chart
    y_labels = []
    for i in range(chart_height):
        # Label for each row, top (y_max) to bottom (y_min).
        y_val = y_max - (i / (chart_height - 1)) * y_range if chart_height > 1 else y_max
        y_labels.append(y_val)

    for i, row in enumerate(chart):
        # Y-axis label (only show a few: top, middle, bottom rows)
        if i == 0 or i == chart_height - 1 or i == chart_height // 2:
            label = f"{y_labels[i]:8.3f} |"
        else:
            label = " |"
        output.append(label, style="dim")
        output.append(''.join(row), style="green")
        output.append("\n")

    # X-axis
    output.append(" +" + "-" * chart_width + "\n", style="dim")

    # X-axis labels: min left-aligned, midpoint centered, max right-aligned
    x_label_line = " "
    x_label_line += f"{x_min:<{chart_width // 3}}"
    mid_step = x_min + x_range // 2
    x_label_line += f"{mid_step:^{chart_width // 3}}"
    x_label_line += f"{x_max:>{chart_width // 3}}"
    output.append(x_label_line[:chart_width + 10] + "\n", style="dim")
    output.append(" " + " " * (chart_width // 2 - 5) + "Global Step\n", style="dim cyan")

    # Statistics
    output.append("\n Statistics:\n", style="bold yellow")
    output.append(f" Latest Global Step: {global_steps[-1]}\n", style="green")
    output.append(f" Latest Mean Reward: {mean_rewards[-1]:.4f}\n", style="green")
    output.append(f" Min Mean Reward: {min(mean_rewards):.4f} (step {global_steps[mean_rewards.index(min(mean_rewards))]})\n", style="cyan")
    output.append(f" Max Mean Reward: {max(mean_rewards):.4f} (step {global_steps[mean_rewards.index(max(mean_rewards))]})\n", style="cyan")

    self.console.print(output)
    self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
    input()

def display_latest_llm_call(self):
while True:
Expand Down Expand Up @@ -515,6 +667,7 @@ def choose_run(self) -> str:
self.console.print("\n[bold]Choose action:[/bold]")
self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch")
self.console.print(" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call")
self.console.print(" [bold cyan]c[/bold cyan] - Show reward curve")
self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit")
choice = input("\n> ").strip().lower()

Expand All @@ -526,8 +679,12 @@ def choose_run(self) -> str:
mode = "replay_latest_llm_call"
self.console.clear()
continue
elif choice == "c":
self.display_reward_curve()
self.console.clear()
continue
else:
self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]")
self.console.print("[yellow]Invalid choice. Please enter 'o', 't', or 'c'.[/yellow]")

def run(self):
"""Start the monitoring interface"""
Expand Down
Loading