[tx] Add GPU memory benchmark script #892
base: main
Conversation
Add comprehensive benchmark tool for measuring GPU memory usage during sampling and training operations. Features include:
- Configurable batch sizes and sequence lengths (comma-separated)
- Support for sample, train, or both test modes
- Early termination on failure (skips remaining batch sizes per seq_len)
- GPU memory monitoring via nvidia-smi polling
- Per-test server logs and optional XLA HLO graph dumps
- Configurable JAX/XLA environment (preallocate, allocator, etc.)
- JSON config and CSV results output to timestamped directory
- JIT compilation log capture in final report

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Move batch_size and mode to ServerManager constructor
- Add start_and_wait_ready() method for unified server startup
- Add kill_existing_servers() static method
- Remove train_micro_batch_size/sample_max_num_sequences from BenchmarkConfig
- Remove _make_server_config() and _kill_existing_processes() from BenchmarkRunner
- Server-only mode now uses same code path as regular mode

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove unused running() context manager
- Remove dead signal handler code (globals never assigned)
- Remove unused contextmanager import
- Fix potential file handle leak in start() with try/except
- Default gpu_allocator to empty string with comment

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
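The file-handle fix mentioned in this commit follows a common pattern; below is a minimal sketch of what it could look like, assuming the server is launched with subprocess.Popen and logs to a per-test file (the function and variable names are placeholders, not the script's actual code):

```python
import subprocess


def start_server(cmd: list[str], log_path: str) -> subprocess.Popen:
    """Launch the server, making sure the log file handle is not leaked if Popen fails."""
    log_file = open(log_path, "w")
    try:
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
    except Exception:
        # Popen raised, so nothing else will ever close this handle; close it here.
        log_file.close()
        raise
```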
Monitor the engine subprocess in a background task. If the engine exits unexpectedly, exit the API server immediately to prevent hanging benchmarks. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Code Review
This pull request introduces a comprehensive GPU memory benchmark script and an important fix to prevent the API server from hanging when its background engine crashes. The new benchmark script is well-designed and feature-rich. The fix in the API server correctly addresses the hanging issue. My review focuses on making the server shutdown process more graceful, improving the robustness and error reporting of the new benchmark script, and suggesting a minor refactoring for better maintainability.
```python
def start(self) -> None:
    """Start background GPU monitoring thread."""
    self._stop_event.clear()
    self._peak_memory = 0
    self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
    self._thread.start()
```
If nvidia-smi is not found, the script silently fails to monitor GPU memory and reports 0 MiB, which can be misleading for a memory benchmark. It would be more robust to check for nvidia-smi's availability at the start of monitoring. If it's not present or fails to run, a warning should be printed to stderr and the monitoring thread should not be started. This makes the script's behavior more transparent to the user.
Suggested change:

```python
def start(self) -> None:
    """Start background GPU monitoring thread."""
    # Check for nvidia-smi once at the beginning
    try:
        subprocess.run(["nvidia-smi"], capture_output=True, check=True, timeout=5.0)
    except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        print("WARNING: `nvidia-smi` not found or failed to run. GPU memory usage will be reported as 0.", file=sys.stderr)
        # Don't start the thread if nvidia-smi is not available.
        # The stop() method will then return the initial peak_memory of 0.
        return
    self._stop_event.clear()
    self._peak_memory = 0
    self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
    self._thread.start()
```
```python
except Exception as e:
    result.error_message = str(e)
    result.status = "ERROR"
```
When an exception occurs during a test (e.g., an out-of-memory error), the peak GPU memory is not recorded. The gpu_monitor has been running and likely captured the peak memory usage right before the failure, which is valuable information for debugging. You should record this peak memory value in the except block.
Suggested change:

```python
result.status = "ERROR"
result.peak_gpu_mem_mib = gpu_monitor.stop()
```
skyrl-tx/tx/tinker/api.py (outdated)
```python
exit_code = await loop.run_in_executor(None, background_engine.wait)
if not shutting_down:
    logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server")
    os._exit(1)
```
Using os._exit(1) causes a hard exit, bypassing cleanup handlers (like finally blocks) and potentially leading to resource leaks or data corruption. A more graceful shutdown can be achieved by sending a SIGTERM signal to the process. uvicorn will catch this signal and shut down cleanly. Note that this will require importing the signal module.
Suggested change:

```python
os.kill(os.getpid(), signal.SIGTERM)
```
```python
def kill_existing_servers() -> None:
    """Kill any existing server processes."""
    subprocess.run(["pkill", "-f", "tx.tinker.api"], capture_output=True)
    time.sleep(2)
```
Using pkill -f is a powerful command that can be risky on a shared machine. It could unintentionally terminate processes belonging to other users if their command line also contains tx.tinker.api. For a more targeted and safer approach, consider managing the server's process ID (PID) using a PID file. The server could write its PID to a file on startup, and this script could then read that file to kill the specific process. While the current implementation is likely fine for a controlled benchmark environment, this change would make it more robust for shared systems.
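A minimal sketch of the PID-file approach suggested above, assuming the server writes its own PID on startup; the file location and function names are hypothetical and not part of this PR:

```python
import os
import signal
from pathlib import Path

PID_FILE = Path("/tmp/tx_tinker_api.pid")  # assumed location


def write_pid_file() -> None:
    """Called by the server on startup to record its own PID."""
    PID_FILE.write_text(str(os.getpid()))


def kill_server_from_pid_file() -> None:
    """Terminate only the process recorded in the PID file, instead of pkill -f."""
    if not PID_FILE.exists():
        return
    pid = int(PID_FILE.read_text().strip())
    try:
        os.kill(pid, signal.SIGTERM)  # target the specific server process
    except ProcessLookupError:
        pass  # process already gone; the PID file was stale
    finally:
        PID_FILE.unlink(missing_ok=True)
```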
```python
# Save config to output directory
config_dict = {
    "base_model": config.base_model,
    "tp_size": config.tp_size,
    "max_lora_adapters": config.max_lora_adapters,
    "gradient_checkpointing": config.gradient_checkpointing,
    "extra_backend_config": config.extra_backend_config,
    "test_mode": config.test_mode,
    "batch_sizes": config.batch_sizes,
    "seq_lens": config.seq_lens,
    "server_only": config.server_only,
    "output_root": str(config.output_root),
    "xla_preallocate": config.xla_preallocate,
    "gpu_allocator": config.gpu_allocator,
    "jax_log_compiles": config.jax_log_compiles,
    "dump_xla": config.dump_xla,
    "timestamp": datetime.now().isoformat(),
}
with open(config.output_dir / "config.json", "w") as f:
```
The manual creation of config_dict for JSON serialization is verbose and duplicates the structure of the BenchmarkConfig dataclass. This can lead to inconsistencies if BenchmarkConfig is updated but this dictionary is not. A more maintainable approach is to use dataclasses.asdict to convert the config object to a dictionary, then filter and format it for saving. This ensures the saved configuration always reflects the BenchmarkConfig definition.
Suggested change:

```python
# Save config to output directory
import dataclasses
config_dict = dataclasses.asdict(config)
# Filter and format for JSON serialization
keys_to_save = [
    "base_model", "tp_size", "max_lora_adapters", "gradient_checkpointing",
    "extra_backend_config", "test_mode", "batch_sizes", "seq_lens",
    "server_only", "output_root", "xla_preallocate", "gpu_allocator",
    "jax_log_compiles", "dump_xla"
]
json_config = {k: config_dict[k] for k in keys_to_save if k in config_dict}
json_config["output_root"] = str(json_config["output_root"])
json_config["timestamp"] = datetime.now().isoformat()
with open(config.output_dir / "config.json", "w") as f:
    json.dump(json_config, f, indent=2)
```
Add --train-micro-batch-size and --sample-max-num-sequences CLI args to allow independent control of train and sample batch sizes. If not set, both default to the current --batch-sizes value being tested. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
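An illustrative argparse sketch of the two flags described in this commit; only the flag names and their default behavior come from the commit message, while the help text and surrounding parser setup are assumptions:

```python
import argparse

parser = argparse.ArgumentParser(description="GPU memory benchmark")
parser.add_argument("--batch-sizes", type=str, default="1", help="Comma-separated batch sizes to test")
parser.add_argument(
    "--train-micro-batch-size", type=int, default=None,
    help="Train micro-batch size; if unset, defaults to the --batch-sizes value being tested",
)
parser.add_argument(
    "--sample-max-num-sequences", type=int, default=None,
    help="Max sequences per sampling request; if unset, defaults to the --batch-sizes value being tested",
)
args = parser.parse_args()
```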
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add is_alive() method to check if server process is running
- Check server aliveness in wait_ready() and start_and_wait_ready()
- Distinguish "crashed" vs "timed out" errors during startup
- Print last 30 lines of server log when server crashes

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
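A rough, self-contained sketch of the aliveness check and the crash-vs-timeout distinction this commit describes; the health-check URL, function signatures, and log handling are illustrative placeholders rather than the actual ServerManager methods:

```python
import subprocess
import time
from pathlib import Path

import httpx


def is_alive(process: subprocess.Popen) -> bool:
    """True while the server subprocess is still running."""
    return process.poll() is None


def wait_ready(process: subprocess.Popen, health_url: str, log_path: Path, timeout: float = 120.0) -> bool:
    """Poll the health endpoint, distinguishing a crash from a startup timeout."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if not is_alive(process):
            # Surface the tail of the server log to help diagnose the crash.
            print("Server crashed during startup; last 30 log lines:")
            print("\n".join(log_path.read_text().splitlines()[-30:]))
            return False
        try:
            if httpx.get(health_url, timeout=2.0).status_code == 200:
                return True
        except httpx.HTTPError:
            pass
        time.sleep(1.0)
    print("Server did not become ready before the timeout")
    return False
```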
/gemini review
Code Review
This pull request introduces a comprehensive GPU memory benchmark script and a crucial fix to the API server to prevent it from hanging when the background engine crashes. The benchmark script is well-structured and feature-rich, covering various test modes and parameters. The fix in the API server correctly uses an asyncio task to monitor the engine subprocess and gracefully shut down the server on a crash.
My review includes a few suggestions to improve maintainability and adhere to modern Python best practices. Specifically, I've recommended moving a local import to the top level in the benchmark script, simplifying configuration object creation for better maintainability, and using the more recent asyncio.to_thread function in the API server for cleaner asynchronous code.
```python
def wait_ready(self, timeout: float = 120.0) -> bool:
    """Wait for server to respond to health check."""
    import httpx
```
The `import httpx` here is a function-local import; as noted in the review summary above, consider moving it to the top level of the module.
```python
config = BenchmarkConfig(
    base_model=args.base_model,
    tp_size=args.tp_size,
    max_lora_adapters=args.max_lora_adapters,
    gradient_checkpointing=args.gradient_checkpointing,
    train_micro_batch_size=args.train_micro_batch_size,
    sample_max_num_sequences=args.sample_max_num_sequences,
    extra_backend_config=args.backend_config,
    test_mode=args.mode,
    batch_sizes=args.batch_sizes,
    seq_lens=args.seq_lens,
    server_only=args.server_only,
    host=args.host,
    port=args.port,
    experiment_name=args.experiment_name,
    output_root=args.output_root,
    gpu_poll_interval=args.gpu_poll_interval,
    xla_preallocate=args.xla_preallocate,
    gpu_allocator=args.gpu_allocator,
    jax_log_compiles=args.jax_log_compiles,
    dump_xla=args.dump_xla,
    output_dir=output_dir,
    db_path=output_dir / "tinker.db",
    csv_path=output_dir / "results.csv",
)
```
The instantiation of BenchmarkConfig is verbose as it manually lists all arguments from args. This can be simplified and made more maintainable by creating a dictionary from vars(args) and unpacking it. This way, if you add new arguments to parse_args and BenchmarkConfig, they will be handled automatically without needing to update this section.
```python
# Build configuration with derived paths
config_data = {k: v for k, v in vars(args).items() if hasattr(BenchmarkConfig, k)}
config = BenchmarkConfig(
    **config_data,
    output_dir=output_dir,
    db_path=output_dir / "tinker.db",
    csv_path=output_dir / "results.csv",
)
```
skyrl-tx/tx/tinker/api.py (outdated)
```python
loop = asyncio.get_event_loop()
exit_code = await loop.run_in_executor(None, background_engine.wait)
```
Using loop.run_in_executor is correct for running blocking code in an async application. However, since Python 3.9, asyncio.to_thread provides a higher-level and more convenient API for this purpose. It's generally recommended to use asyncio.to_thread for better readability and conciseness.
Suggested change:

```python
exit_code = await asyncio.to_thread(background_engine.wait)
```
- Move httpx import to top-level
- Use vars(args) with arg_renames for config creation
- Use asyncio.to_thread() instead of loop.run_in_executor()
- Print warning when killing existing servers (shows PIDs)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of running tests in a thread pool, poll result(timeout=5) directly and check server.is_alive() on each timeout. This avoids orphaned threads and properly detects server crashes since the SDK retries connection errors indefinitely. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
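A small sketch of the polling pattern this commit describes, assuming `future` exposes a concurrent.futures-style `result(timeout=...)` and `server` exposes the `is_alive()` method added earlier; the names here are placeholders:

```python
import concurrent.futures


def wait_for_result(future, server, poll_timeout: float = 5.0):
    """Block on the pending result, bailing out if the server process dies."""
    while True:
        try:
            return future.result(timeout=poll_timeout)
        except concurrent.futures.TimeoutError:
            # No result yet: confirm the server is still alive before waiting again,
            # since the SDK would otherwise keep retrying connection errors forever.
            if not server.is_alive():
                raise RuntimeError("Server process died while waiting for a result")
```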
When the engine crashes, SIGTERM requests graceful shutdown but may hang if active requests are blocked. Add a background thread that force-exits after 10s timeout as a fallback. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
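A minimal sketch of the force-exit fallback described here, under the assumption that it runs inside the API server process alongside the engine monitor; only the 10-second grace period comes from the commit, the rest is illustrative:

```python
import os
import signal
import threading
import time


def shutdown_with_force_exit_fallback(grace_period: float = 10.0) -> None:
    """Request a graceful shutdown, but hard-exit if it does not finish in time."""

    def _force_exit_after_timeout() -> None:
        time.sleep(grace_period)
        # Still running after the grace period: give up on graceful shutdown.
        os._exit(1)

    # Daemon thread, so it never blocks a shutdown that completes in time.
    threading.Thread(target=_force_exit_after_timeout, daemon=True).start()
    # SIGTERM lets uvicorn run its normal shutdown handlers.
    os.kill(os.getpid(), signal.SIGTERM)
```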
Results are now written to CSV as soon as each test finishes, rather than waiting until all tests complete. This allows monitoring progress in real-time. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
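A tiny sketch of incremental CSV writing as described in this commit; the column names are placeholders, not the script's actual result schema:

```python
import csv
from pathlib import Path

FIELDS = ["batch_size", "seq_len", "status", "peak_gpu_mem_mib"]  # assumed columns


def append_result(csv_path: Path, row: dict) -> None:
    """Append one result row as soon as the test finishes, writing the header once."""
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
```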
Summary
Changes
New:
benchmarks/benchmark_memory.py: A comprehensive benchmark tool for profiling GPU memory consumption.
Fix:
tx/tinker/api.py: The API server spawns a background engine subprocess. Previously, if the engine crashed, the API server stayed alive but couldn't process requests, causing benchmarks to hang indefinitely.
Fixed by adding a monitor task that waits on the engine subprocess and, if it exits unexpectedly, requests a graceful server shutdown (with a 10-second force-exit fallback).
Test Plan
Will be used for benchmarking various optimizations for issue #891.