# construct the framework of agent-v1 #5
`@@ -0,0 +1,246 @@` — new file (the agent evaluation runner):

```python
import argparse
import copy
import os
import os.path as osp
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from scieval.agents.records import EvalRecord, TrajectoryStore
from scieval.agents.smolagents import SmolAgentsAgent
from scieval.dataset import build_dataset
from scieval.smp import dump, get_logger, load, timestr, githash, ls


def _build_dataset_from_config(cfg: Dict[str, Any], dataset_name: str):
    # Instantiate a dataset from its config entry; an explicit "class" key
    # selects a dataset class exported by scieval.dataset.
    import inspect
    import scieval.dataset as dataset_mod

    config = copy.deepcopy(cfg[dataset_name])
    if config == {}:
        return build_dataset(dataset_name)
    if "class" not in config:
        return build_dataset(dataset_name, **config)
    cls_name = config.pop("class")
    if hasattr(dataset_mod, cls_name):
        cls = getattr(dataset_mod, cls_name)
        sig = inspect.signature(cls.__init__)
        valid_params = {k: v for k, v in config.items() if k in sig.parameters}
        return cls(**valid_params)
    raise ValueError(f"Dataset class {cls_name} is not supported in scieval.dataset")


def _build_agent_from_config(cfg: Dict[str, Any], agent_name: str):
    config = copy.deepcopy(cfg[agent_name])
    cls_name = config.pop("class", "SmolAgentsAgent")
    if cls_name not in ["SmolAgentsAgent", "smolagents"]:
        raise ValueError(f"Unsupported agent class: {cls_name}")
    return SmolAgentsAgent(**config)


def _run_one_sample(
    idx: int,
    agent,
    dataset,
    store: TrajectoryStore,
    judge_kwargs: Dict[str, Any],
    reuse: bool,
    do_infer: bool,
    do_eval: bool,
) -> Tuple[int, Dict[str, Any], str]:
    # Run inference and/or evaluation for one sample, reusing cached
    # trajectories and eval records when requested.
    final_answer = ""
    traj = store.load_traj(idx) if reuse else None
    if do_infer:
        if traj and traj.get("success"):
            final_answer = traj.get("final_answer", "")
        else:
            sample = dataset.build_agent_sample(idx)
            result = agent.run(sample)
            store.save_traj(idx, result)
            final_answer = result.final_answer
    elif traj:
        final_answer = traj.get("final_answer", "")

    if not do_eval:
        return idx, {}, final_answer

    eval_cached = store.load_eval(idx) if reuse else None
    if eval_cached is not None:
        cached_score = eval_cached.get("score", eval_cached)
```
> **Review comment:** The fallback value in `eval_cached.get("score", eval_cached)` returns the entire cached record as the score when no `"score"` key is present, which is likely not the intended payload.
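A minimal standalone sketch of the behavior being flagged, using a toy record; the empty-dict fallback in the last line is one possible guard, not part of this PR:

```python
# Cached record without a "score" key: the current fallback returns the
# whole record, so downstream aggregation would see "final_answer" as if
# it were a metric field.
eval_cached = {"final_answer": "42"}
print(eval_cached.get("score", eval_cached))  # -> {'final_answer': '42'}

# A defensive alternative (assumption): fall back to an empty score so the
# sample contributes nothing to the summary instead of unrelated keys.
print(eval_cached.get("score", {}))  # -> {}
```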
```python
        cached_final = eval_cached.get("final_answer", final_answer)
        return idx, cached_score, cached_final

    score = dataset.score_agent_sample(idx, final_answer, **judge_kwargs)
    metadata = {}
    if "question" in score:
        metadata["question"] = score["question"]
    if "answer" in score:
        metadata["answer"] = score["answer"]
    record = EvalRecord(index=idx, final_answer=final_answer, score=score, metadata=metadata)
    store.save_eval(idx, record)
    return idx, score, final_answer


def _is_number(value: Any) -> bool:
    # bool is a subclass of int, so exclude it explicitly.
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def run_agent_eval(
    agent,
    dataset,
    work_dir: str,
    nproc: int = 1,
    reuse: bool = False,
    mode: str = "all",
    judge_kwargs: Optional[Dict[str, Any]] = None,
):
    logger = get_logger("AGENT_EVAL")
    judge_kwargs = judge_kwargs or {}
    dataset_name = getattr(dataset, "dataset_name", dataset.__class__.__name__)
    root_dir = osp.join(work_dir, "agent_eval", dataset_name, agent.name, agent.model_version)
    eval_id = f"T{timestr('day')}_G{githash(digits=8)}"
    log_dir = osp.join(root_dir, eval_id)
    if reuse and osp.exists(root_dir):
        # Resume into the most recent previous run directory.
        prev_runs = ls(root_dir, mode="dir")
        if prev_runs:
            prev_runs.sort()
            log_dir = prev_runs[-1]
    store = TrajectoryStore(log_dir)
    logger.info(f"Logging directory: {log_dir}")

    do_infer = mode in ["all", "infer"]
    do_eval = mode in ["all", "eval"]

    results: List[Tuple[int, Dict[str, Any], str]] = []
    tasks = list(range(len(dataset)))
    tasks_to_run = tasks
    if reuse:
        # Pre-scan the cache so finished samples are collected up front and
        # only the remainder is scheduled.
        tasks_to_run = []
        for idx in tasks:
            if do_eval:
                eval_cached = store.load_eval(idx)
                if eval_cached is not None:
                    cached_score = eval_cached.get("score", eval_cached)
                    cached_final = eval_cached.get("final_answer", "")
                    if not cached_final:
                        traj = store.load_traj(idx)
                        if traj is not None:
                            cached_final = traj.get("final_answer", "")
                    results.append((idx, cached_score, cached_final))
                    continue
                tasks_to_run.append(idx)
                continue

            if do_infer:
                traj = store.load_traj(idx)
                if traj and traj.get("success"):
                    results.append((idx, {}, traj.get("final_answer", "")))
                else:
                    tasks_to_run.append(idx)
            else:
                tasks_to_run.append(idx)

    if nproc > 1:
        with ThreadPoolExecutor(max_workers=nproc) as executor:
            futures = [
                executor.submit(
                    _run_one_sample,
                    idx,
                    agent,
                    dataset,
                    store,
                    judge_kwargs,
                    reuse,
                    do_infer,
                    do_eval,
                )
                for idx in tasks_to_run
            ]
            with tqdm(total=len(tasks_to_run), desc="Agent Eval", unit="sample") as pbar:
                for fut in as_completed(futures):
                    results.append(fut.result())
                    pbar.update(1)
    else:
        with tqdm(total=len(tasks_to_run), desc="Agent Eval", unit="sample") as pbar:
            for idx in tasks_to_run:
                results.append(
                    _run_one_sample(
                        idx, agent, dataset, store, judge_kwargs, reuse, do_infer, do_eval
                    )
                )
                pbar.update(1)

    results.sort(key=lambda x: x[0])
    predictions = [{"index": idx, "prediction": final_answer} for idx, _, final_answer in results]
    pred_file = osp.join(log_dir, f"{agent.name}_{dataset_name}.json")
    dump(predictions, pred_file)

    # Average every numeric score field across samples.
    agg: Dict[str, List[float]] = {}
    for _, score, _ in results:
        for k, v in score.items():
            if _is_number(v):
                agg.setdefault(k, []).append(float(v))

    summary = {k: (sum(v) / len(v) if v else 0.0) for k, v in agg.items()}
    summary_file = osp.join(log_dir, "summary.json")
    dump(summary, summary_file)
    return summary


def run_agent_eval_from_config(cfg: Dict[str, Any], args) -> Dict[str, Any]:
    logger = get_logger("AGENT_RUN")
    agent_cfg = cfg.get("agent") or cfg.get("agents")
    data_cfg = cfg.get("data")
    if not agent_cfg or not data_cfg:
        raise ValueError("Config must include 'agent' and 'data' sections for agent evaluation.")

    # A bare agent config (with a top-level "class" key) is wrapped into a
    # single-entry mapping so both layouts are handled uniformly.
    if isinstance(agent_cfg, dict) and "class" in agent_cfg:
        agents_cfg = {"agent": agent_cfg}
    else:
        agents_cfg = agent_cfg

    results = {}
    for agent_name in agents_cfg:
        agent = _build_agent_from_config(agents_cfg, agent_name)
        for dataset_name in data_cfg:
            dataset = _build_dataset_from_config(data_cfg, dataset_name)
            if dataset is None:
                logger.error(f"Dataset {dataset_name} is not valid, skipping.")
                continue
            summary = run_agent_eval(
                agent,
                dataset,
                work_dir=args.work_dir,
                nproc=args.api_nproc,
                reuse=args.reuse,
                mode=args.mode,
                judge_kwargs={
                    "model": getattr(args, "judge", None),
                    "api_key": os.environ.get("OPENAI_API_KEY", ""),
                    "api_base": os.environ.get("OPENAI_API_BASE", ""),
                },
```
> **Review comment** on lines +218 to +222 (the `judge_kwargs` dict above).
```python
            )
            results[f"{agent_name}:{dataset_name}"] = summary
    return results


def parse_args():
    parser = argparse.ArgumentParser(description="Agent evaluation runner")
    parser.add_argument("--config", type=str, required=True, help="Path to agent eval config JSON")
    parser.add_argument("--work-dir", type=str, default="./outputs", help="Output directory")
    parser.add_argument("--mode", type=str, default="all", choices=["all", "infer", "eval"])
    parser.add_argument("--api-nproc", type=int, default=1, help="Parallel agent calls")
    parser.add_argument("--reuse", action="store_true")
    parser.add_argument("--judge", type=str, default=None)
    return parser.parse_args()


def main():
    args = parse_args()
    cfg = load(args.config)
    run_agent_eval_from_config(cfg, args)


if __name__ == "__main__":
    main()
```
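For reference, a hedged sketch of a config this runner would accept, based on the parsing logic in `run_agent_eval_from_config` and `_build_*_from_config` above. The agent/dataset names, the `model_version` value, and the runner filename are illustrative assumptions, not taken from this PR:

```python
# Writes a minimal JSON config: an "agent" section mapping agent names to
# constructor kwargs, and a "data" section mapping dataset names to kwargs
# (an empty dict means build_dataset(name) with defaults).
import json

config = {
    "agent": {
        # hypothetical entry; assumes SmolAgentsAgent forwards kwargs to AgentBase
        "smol": {"class": "SmolAgentsAgent", "model_version": "demo-model"},
    },
    "data": {
        "my_dataset": {},  # hypothetical dataset name
    },
}

with open("agent_eval.json", "w") as f:
    json.dump(config, f, indent=2)

# Then (runner filename assumed, not shown in the diff):
#   python run_agent_eval.py --config agent_eval.json --work-dir ./outputs --api-nproc 4 --reuse
```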
`@@ -0,0 +1,13 @@` — new file (the agents package `__init__`):

```python
from .base import AgentBase, EvalSample
from .records import EvalResult, StepResult, ToolCalling, TrajectoryStore
from .smolagents import SmolAgentsAgent

__all__ = [
    "AgentBase",
    "EvalSample",
    "EvalResult",
    "StepResult",
    "ToolCalling",
    "TrajectoryStore",
    "SmolAgentsAgent",
]
```
`@@ -0,0 +1,28 @@` — new file (agent base classes):

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class EvalSample:
    def __init__(
        self,
        prompt: str,
        images: Optional[List[str]] = None,
        files: Optional[Dict[str, str]] = None,
        metadata: Optional[Dict[str, str]] = None,
    ):
        self.prompt = prompt
        self.images = images or []
        self.files = files or {}
        self.metadata = metadata or {}


class AgentBase(ABC):
    name = "agent"

    def __init__(self, name: Optional[str] = None, model_version: Optional[str] = None, **kwargs):
        self.name = name or getattr(self, "name", self.__class__.__name__.lower())
        self.model_version = model_version or "default"

    @abstractmethod
    def run(self, sample: EvalSample):
        pass
```
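As a usage sketch of the `AgentBase` contract: the runner in this PR reads `result.final_answer` from whatever `run` returns, so a conforming toy subclass might look like this (`EchoAgent` and `EchoResult` are hypothetical names, not part of the PR):

```python
from dataclasses import dataclass

# Assumes AgentBase and EvalSample from the module above.


@dataclass
class EchoResult:
    final_answer: str


class EchoAgent(AgentBase):
    name = "echo"

    def run(self, sample: EvalSample) -> EchoResult:
        # A real agent would execute a tool-use loop here; this one just
        # echoes the prompt back as the final answer.
        return EchoResult(final_answer=sample.prompt)


agent = EchoAgent(model_version="v0")
print(agent.run(EvalSample(prompt="2 + 2 = ?")).final_answer)
```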
> **Review comment:** Imports should generally be at the top of the file as per PEP 8 guidelines. While local imports can be used to avoid circular dependencies (which might be the case for `scieval.dataset`), the `inspect` module is a standard library and can be safely moved to the top of the file to improve code organization.
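A sketch of the suggested reordering, with the branch-heavy body of `_build_dataset_from_config` trimmed for brevity; the local `scieval.dataset` import stays inside the function only on the assumption that it guards an import cycle:

```python
import copy
import inspect  # standard library: safe to hoist to module level
from typing import Any, Dict


def _build_dataset_from_config(cfg: Dict[str, Any], dataset_name: str):
    # Kept local on the assumption that a module-level import of
    # scieval.dataset would create a circular dependency with this runner.
    import scieval.dataset as dataset_mod

    config = copy.deepcopy(cfg[dataset_name])
    cls = getattr(dataset_mod, config.pop("class"))
    sig = inspect.signature(cls.__init__)  # uses the module-level import
    valid_params = {k: v for k, v in config.items() if k in sig.parameters}
    return cls(**valid_params)
```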