-
Notifications
You must be signed in to change notification settings - Fork 53
Add agent cli #682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add agent cli #682
Changes from all commits
4ff544e
945a042
d26ab30
9ec566e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| def __getattr__(name): | ||
| if name == "Session": | ||
| from kwave.cli.session import Session | ||
|
|
||
| return Session | ||
| raise AttributeError(f"module {__name__!r} has no attribute {name!r}") | ||
|
|
||
|
|
||
| __all__ = ["Session"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| """Phantom generation and loading commands.""" | ||
|
|
||
| import click | ||
| import numpy as np | ||
|
|
||
| from kwave.cli.main import pass_session | ||
| from kwave.cli.schema import CLIError, CLIResponse, ValidationError, json_command | ||
|
|
||
|
|
||
| def _parse_int_tuple(s: str) -> tuple[int, ...]: | ||
| return tuple(int(x) for x in s.split(",")) | ||
|
|
||
|
|
||
| def _resolve_scalar_or_path(value: str, name: str, sess) -> dict: | ||
| """Parse a CLI value as scalar float or .npy path. Returns {name_scalar, name_path} dict.""" | ||
| if value.endswith(".npy"): | ||
| arr = np.load(value) | ||
| path = sess.save_array(name, arr) | ||
| return {f"{name}_scalar": None, f"{name}_path": path} | ||
| return {f"{name}_scalar": float(value), f"{name}_path": None} | ||
|
|
||
|
|
||
| @click.group("phantom") | ||
| def phantom(): | ||
| """Define the simulation phantom (medium + initial pressure).""" | ||
| pass | ||
|
|
||
|
|
||
| @phantom.command("load") | ||
| @click.option("--grid-size", required=True, help="Grid dimensions, e.g. 512 or 128,128") | ||
| @click.option("--spacing", required=True, type=float, help="Grid spacing in meters") | ||
| @click.option("--sound-speed", required=True, help="Scalar value (m/s) or path to .npy file") | ||
| @click.option("--density", default=None, help="Scalar value (kg/m^3) or path to .npy file") | ||
| @click.option("--cfl", type=float, default=None, help="CFL number for time step calculation") | ||
| @pass_session | ||
| @json_command("phantom.load") | ||
| def load(sess, grid_size, spacing, sound_speed, density, cfl): | ||
| """Load medium properties from scalar values or .npy files.""" | ||
| sess.load() | ||
|
|
||
| grid_n = _parse_int_tuple(grid_size) | ||
| ndim = len(grid_n) | ||
| grid_spacing = (spacing,) * ndim | ||
|
|
||
| medium_state = _resolve_scalar_or_path(sound_speed, "sound_speed", sess) | ||
| if density is not None: | ||
| medium_state.update(_resolve_scalar_or_path(density, "density", sess)) | ||
|
|
||
| grid_state = {"N": list(grid_n), "spacing": list(grid_spacing)} | ||
| if cfl is not None: | ||
| grid_state["cfl"] = cfl | ||
|
|
||
| sess.update_many({"grid": grid_state, "medium": medium_state}) | ||
|
|
||
| return CLIResponse( | ||
| result={"grid_size": list(grid_n), "spacing": list(grid_spacing), "medium": medium_state}, | ||
| derived={"ndim": ndim, "grid_points": int(np.prod(grid_n))}, | ||
| ) | ||
|
|
||
|
|
||
| @phantom.command() | ||
| @click.option("--type", "phantom_type", required=True, type=click.Choice(["disc", "spherical", "layered"])) | ||
| @click.option("--grid-size", required=True, help="Grid dimensions, e.g. 128,128") | ||
| @click.option("--spacing", required=True, type=float, help="Grid spacing in meters, e.g. 0.1e-3") | ||
| @click.option("--sound-speed", type=float, default=1500, help="Medium sound speed (m/s)") | ||
| @click.option("--density", type=float, default=1000, help="Medium density (kg/m^3)") | ||
| @click.option("--disc-center", default=None, help="Disc center, e.g. 64,64") | ||
| @click.option("--disc-radius", type=int, default=5, help="Disc radius in grid points") | ||
| @pass_session | ||
| @json_command("phantom.generate") | ||
| def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_center, disc_radius): | ||
| """Generate an analytical phantom.""" | ||
| sess.load() | ||
|
|
||
| grid_n = _parse_int_tuple(grid_size) | ||
| ndim = len(grid_n) | ||
| grid_spacing = (spacing,) * ndim | ||
|
|
||
| if phantom_type == "disc": | ||
| if ndim != 2: | ||
| raise ValidationError( | ||
| CLIError( | ||
| code="DISC_REQUIRES_2D", | ||
| field="grid_size", | ||
| value=grid_size, | ||
| constraint="disc phantom requires 2D grid", | ||
| suggestion="Use --grid-size Nx,Ny (two dimensions)", | ||
| ) | ||
| ) | ||
| from kwave.data import Vector | ||
| from kwave.utils.mapgen import make_disc | ||
|
|
||
| if disc_center is None: | ||
| center = Vector([n // 2 for n in grid_n]) | ||
| else: | ||
| center = Vector(_parse_int_tuple(disc_center)) | ||
|
|
||
| p0 = make_disc(Vector(list(grid_n)), center, disc_radius).astype(float) | ||
|
|
||
| elif phantom_type == "spherical": | ||
| center = np.array([n // 2 for n in grid_n]) | ||
| coords = np.mgrid[tuple(slice(0, n) for n in grid_n)] | ||
| dist = np.sqrt(sum((c - cn) ** 2 for c, cn in zip(coords, center))) | ||
| p0 = (dist <= disc_radius).astype(float) | ||
|
|
||
| elif phantom_type == "layered": | ||
| p0 = np.zeros(grid_n) | ||
| layer_pos = grid_n[0] // 4 | ||
| p0[layer_pos, ...] = 1.0 | ||
|
|
||
| p0_path = sess.save_array("p0", p0) | ||
|
|
||
| sess.update_many( | ||
| { | ||
| "grid": {"N": list(grid_n), "spacing": list(grid_spacing)}, | ||
| "medium": {"sound_speed_scalar": sound_speed, "sound_speed_path": None, "density_scalar": density, "density_path": None}, | ||
| "source": {"type": "initial-pressure", "p0_path": p0_path}, | ||
| } | ||
| ) | ||
|
|
||
| return CLIResponse( | ||
| result={ | ||
| "phantom_type": phantom_type, | ||
| "grid_size": list(grid_n), | ||
| "spacing": list(grid_spacing), | ||
| "p0_shape": list(p0.shape), | ||
| "p0_max": float(p0.max()), | ||
| "sound_speed": sound_speed, | ||
| "density": density, | ||
| }, | ||
| derived={"ndim": ndim, "grid_points": int(np.prod(grid_n))}, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| """Plan command: derive full simulation config, validate, estimate cost.""" | ||
|
|
||
| import click | ||
| import numpy as np | ||
|
|
||
| from kwave.cli.main import pass_session | ||
| from kwave.cli.schema import CLIResponse, json_command | ||
|
|
||
|
|
||
| @click.command("plan") | ||
| @pass_session | ||
| @json_command("plan") | ||
| def plan(sess): | ||
| """Derive full simulation config and validate before running.""" | ||
| sess.load() | ||
| sess.assert_ready("plan") | ||
|
|
||
| kgrid = sess.make_grid() | ||
| medium = sess.make_medium() | ||
|
|
||
| grid_n = tuple(int(n) for n in kgrid.N) | ||
| spacing = tuple(float(d) for d in kgrid.spacing) | ||
| ndim = len(grid_n) | ||
| grid_points = int(np.prod(grid_n)) | ||
| dt = float(kgrid.dt) | ||
| Nt = int(kgrid.Nt) | ||
|
|
||
| c_max = float(np.max(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) | ||
| c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) | ||
| cfl = c_max * dt / min(spacing) | ||
|
|
||
| n_fields = 3 + 2 * ndim | ||
| memory_mb = grid_points * n_fields * 8 / (1024 * 1024) | ||
| estimated_runtime_s = grid_points * Nt * 50e-9 # ~50ns per grid point per step on CPU | ||
|
|
||
| pml_size = 20 | ||
|
|
||
| warnings = [] | ||
| if cfl > 0.5: | ||
| warnings.append( | ||
| { | ||
| "code": "HIGH_CFL", | ||
| "detail": f"CFL={cfl:.3f} exceeds 0.5, simulation may be unstable", | ||
| "suggestion": "Reduce time step or increase grid spacing", | ||
| } | ||
| ) | ||
|
|
||
| result = { | ||
| "grid": { | ||
| "N": list(grid_n), | ||
| "spacing": list(spacing), | ||
| "ndim": ndim, | ||
| "dt": dt, | ||
| "Nt": Nt, | ||
| }, | ||
| "pml": {"size": pml_size}, | ||
| "medium": { | ||
| "sound_speed": c_min if c_min == c_max else f"{c_min}-{c_max}", | ||
| }, | ||
| "source": sess.state["source"], | ||
| "sensor": sess.state["sensor"], | ||
| "backend": "python", | ||
| "device": "cpu", | ||
| } | ||
|
|
||
| derived = { | ||
| "cfl": round(cfl, 4), | ||
| "grid_points": grid_points, | ||
| "estimated_memory_mb": round(memory_mb, 1), | ||
| "estimated_runtime_s": round(estimated_runtime_s, 1), | ||
| } | ||
|
|
||
| return CLIResponse( | ||
| result=result, | ||
| derived=derived, | ||
| warnings=warnings, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| """Run command: execute simulation with structured JSON progress.""" | ||
|
|
||
| import json | ||
| import time | ||
|
|
||
| import click | ||
| import numpy as np | ||
|
|
||
| from kwave.cli.main import pass_session | ||
| from kwave.cli.schema import CLIResponse, json_command | ||
|
|
||
|
|
||
| def _emit_event(event: dict): | ||
| """Write a JSON event to stdout and flush.""" | ||
| click.echo(json.dumps(event, default=str)) | ||
|
|
||
|
|
||
| @click.command("run") | ||
| @click.option("--backend", default="python", type=click.Choice(["python", "cpp"])) | ||
| @click.option("--device", default="cpu", type=click.Choice(["cpu", "gpu"])) | ||
| @pass_session | ||
| @json_command("run") | ||
| def run(sess, backend, device): | ||
| """Execute the simulation.""" | ||
| sess.load() | ||
| sess.assert_ready("run") | ||
|
|
||
| kgrid = sess.make_grid() | ||
| medium = sess.make_medium() | ||
| source = sess.make_source() | ||
| sensor = sess.make_sensor() | ||
|
|
||
| Nt = int(kgrid.Nt) | ||
|
|
||
| _emit_event({"event": "started", "backend": backend, "device": device, "Nt": Nt}) | ||
|
|
||
| t_start = time.time() | ||
| last_pct = -5 # emit at most every 5% | ||
|
|
||
| def progress_callback(step, total): | ||
| nonlocal last_pct | ||
| pct = round(100 * step / total, 1) | ||
| if pct - last_pct >= 5 or step == total: | ||
| last_pct = pct | ||
| _emit_event( | ||
| { | ||
| "event": "progress", | ||
| "step": step, | ||
| "total": total, | ||
| "pct": pct, | ||
| "elapsed_s": round(time.time() - t_start, 2), | ||
| } | ||
| ) | ||
|
|
||
| from kwave.kspaceFirstOrder import kspaceFirstOrder | ||
|
|
||
| result = kspaceFirstOrder( | ||
| kgrid, | ||
| medium, | ||
| source, | ||
| sensor, | ||
| backend=backend, | ||
| device=device, | ||
| quiet=True, | ||
| progress_callback=progress_callback, | ||
| ) | ||
|
Comment on lines
+57
to
+66
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Consider emitting a warning event when if backend != "python":
_emit_event({"event": "warning", "detail": "progress_callback is only supported with the python backend"}) |
||
|
|
||
| elapsed = round(time.time() - t_start, 2) | ||
|
|
||
| # Save results | ||
| result_info = {} | ||
| for key, val in result.items(): | ||
| if isinstance(val, np.ndarray): | ||
| path = sess.save_array(f"result_{key}", val) | ||
| result_info[key] = {"shape": list(val.shape), "path": path} | ||
| else: | ||
| result_info[key] = val | ||
|
|
||
| sess.update("result_path", str(sess.data_dir)) | ||
|
|
||
| _emit_event({"event": "completed", "elapsed_s": elapsed, "output_keys": list(result.keys())}) | ||
|
|
||
| return CLIResponse( | ||
| result={ | ||
| "elapsed_s": elapsed, | ||
| "outputs": result_info, | ||
| }, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| """Sensor definition command.""" | ||
|
|
||
| import click | ||
|
|
||
| from kwave.cli.main import pass_session | ||
| from kwave.cli.schema import CLIResponse, json_command | ||
|
|
||
|
|
||
| @click.group("sensor") | ||
| def sensor(): | ||
| """Define sensor configuration.""" | ||
| pass | ||
|
|
||
|
|
||
| @sensor.command() | ||
| @click.option("--mask", required=True, help="Sensor mask: 'full-grid' or path to .npy file") | ||
| @click.option("--record", default="p,p_final", help="Comma-separated fields to record, e.g. p,p_final,ux") | ||
| @pass_session | ||
| @json_command("sensor.define") | ||
| def define(sess, mask, record): | ||
| """Define what and where to record.""" | ||
| sess.load() | ||
|
|
||
| record_fields = [r.strip() for r in record.split(",")] | ||
|
|
||
| sensor_config = {"record": record_fields} | ||
| if mask == "full-grid": | ||
| sensor_config["mask_type"] = "full-grid" | ||
| else: | ||
| sensor_config["mask_type"] = "file" | ||
| sensor_config["mask_path"] = mask | ||
|
|
||
| sess.update("sensor", sensor_config) | ||
|
|
||
| return CLIResponse( | ||
| result={ | ||
| "mask_type": sensor_config["mask_type"], | ||
| "record": record_fields, | ||
| } | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| """Session management commands.""" | ||
|
|
||
| import click | ||
|
|
||
| from kwave.cli.main import pass_session | ||
| from kwave.cli.schema import CLIResponse, json_command | ||
|
|
||
|
|
||
| @click.group("session") | ||
| def session(): | ||
| """Manage simulation session.""" | ||
| pass | ||
|
|
||
|
|
||
| @session.command() | ||
| @pass_session | ||
| @json_command("session.init") | ||
| def init(sess): | ||
| """Create a new session.""" | ||
| info = sess.init() | ||
| return CLIResponse(result=info) | ||
|
|
||
|
|
||
| @session.command() | ||
| @pass_session | ||
| @json_command("session.status") | ||
| def status(sess): | ||
| """Return full current session state.""" | ||
| sess.load() | ||
| return CLIResponse(result=sess.status()) | ||
|
|
||
|
|
||
| @session.command() | ||
| @pass_session | ||
| @json_command("session.reset") | ||
| def reset(sess): | ||
| """Clear session state.""" | ||
| info = sess.reset() | ||
| return CLIResponse(result=info) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int()andfloat()conversions produce unformatted errors_parse_int_tupleand_resolve_scalar_or_pathcallint()/float()on raw user input without catchingValueError. An invalid value like--grid-size 128,abcor--sound-speed notanumberraises a bare Python traceback rather than a structuredValidationError/CLIErrorJSON response, breaking the agent-parseable contract that the rest of the CLI upholds.Consider wrapping these with explicit error handling: