
Conversation

@raulchen
Contributor

Summary

  • Add GPU memory benchmark script for measuring peak GRAM usage across different batch sizes and sequence lengths
  • Fix API server to exit when the background engine subprocess crashes (prevents hanging benchmarks)

Changes

New: benchmarks/benchmark_memory.py

A comprehensive benchmark tool for profiling GPU memory consumption:

  • Test modes: sampling (inference), training (forward-backward), or both
  • Sweep parameters: configurable batch sizes and sequence lengths
  • Early termination: skips remaining batch sizes if OOM occurs
  • GPU monitoring: polls nvidia-smi for peak memory usage (see the sketch after this list)
  • Server lifecycle: automatic start/stop per test with log capture
  • Output: CSV results, per-test server logs, optional XLA HLO dumps
  • Server-only mode: launch server for manual testing/debugging
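
The GPU monitoring above is implemented as a background polling thread. Below is a minimal sketch of that pattern, assuming a GPUMonitor class with the _stop_event, _peak_memory, and _monitor_loop names that appear in the review excerpts further down; the exact nvidia-smi query flags and the poll interval are assumptions, not taken from this PR.

import subprocess
import threading

class GPUMonitor:
    """Track peak GPU memory by polling nvidia-smi in a daemon thread (sketch)."""

    def __init__(self, poll_interval: float = 0.5) -> None:
        self._poll_interval = poll_interval
        self._peak_memory = 0  # MiB
        self._stop_event = threading.Event()
        self._thread = None

    def _monitor_loop(self) -> None:
        while not self._stop_event.is_set():
            out = subprocess.run(
                ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
                capture_output=True, text=True,
            )
            if out.returncode == 0 and out.stdout.strip():
                used = max(int(x) for x in out.stdout.split())
                self._peak_memory = max(self._peak_memory, used)
            self._stop_event.wait(self._poll_interval)

    def start(self) -> None:
        """Start background GPU monitoring thread."""
        self._stop_event.clear()
        self._peak_memory = 0
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop(self) -> int:
        """Stop monitoring and return the observed peak memory in MiB."""
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=5.0)
        return self._peak_memory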

Fix: tx/tinker/api.py

The API server spawns a background engine subprocess. Previously, if the engine crashed, the API server stayed alive but couldn't process requests, causing benchmarks to hang indefinitely.

Fixed by adding a monitor task (sketched after this list) that:

  1. Waits on the engine subprocess in a background asyncio task
  2. Exits the API server immediately if the engine crashes unexpectedly
  3. Distinguishes crash vs graceful shutdown via a flag
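
A minimal sketch of that monitor task, assuming the background_engine subprocess handle and shutting_down flag names that appear in the review excerpts below, and folding in the 10-second force-exit fallback from a later commit in this PR. It would typically be launched with asyncio.create_task() during app startup.

import asyncio
import os
import signal
import threading
import time

shutting_down = False  # set True by the normal shutdown path

async def monitor_engine(background_engine) -> None:
    """Exit the API server if the engine subprocess dies unexpectedly (sketch)."""
    exit_code = await asyncio.to_thread(background_engine.wait)
    if not shutting_down:
        print(f"Background engine exited with code {exit_code}; shutting down API server")

        def force_exit() -> None:
            # Fallback: hard-exit if graceful shutdown hangs on in-flight requests.
            time.sleep(10)
            os._exit(1)

        threading.Thread(target=force_exit, daemon=True).start()
        # SIGTERM lets uvicorn run its normal shutdown handlers.
        os.kill(os.getpid(), signal.SIGTERM)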

Test Plan

Will be used for benchmarking various optimizations for issue #891.

raulchen and others added 14 commits January 16, 2026 13:16
Add comprehensive benchmark tool for measuring GPU memory usage during
sampling and training operations. Features include:

- Configurable batch sizes and sequence lengths (comma-separated)
- Support for sample, train, or both test modes
- Early termination on failure (skips remaining batch sizes per seq_len)
- GPU memory monitoring via nvidia-smi polling
- Per-test server logs and optional XLA HLO graph dumps
- Configurable JAX/XLA environment (preallocate, allocator, etc.)
- JSON config and CSV results output to timestamped directory
- JIT compilation log capture in final report

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Move batch_size and mode to ServerManager constructor
- Add start_and_wait_ready() method for unified server startup
- Add kill_existing_servers() static method
- Remove train_micro_batch_size/sample_max_num_sequences from BenchmarkConfig
- Remove _make_server_config() and _kill_existing_processes() from BenchmarkRunner
- Server-only mode now uses same code path as regular mode

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove unused running() context manager
- Remove dead signal handler code (globals never assigned)
- Remove unused contextmanager import
- Fix potential file handle leak in start() with try/except
- Default gpu_allocator to empty string with comment

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Monitor the engine subprocess in a background task. If the engine exits
unexpectedly, exit the API server immediately to prevent hanging benchmarks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive GPU memory benchmark script and an important fix to prevent the API server from hanging when its background engine crashes. The new benchmark script is well-designed and feature-rich. The fix in the API server correctly addresses the hanging issue. My review focuses on making the server shutdown process more graceful, improving the robustness and error reporting of the new benchmark script, and suggesting a minor refactoring for better maintainability.

Comment on lines +166 to +171
def start(self) -> None:
    """Start background GPU monitoring thread."""
    self._stop_event.clear()
    self._peak_memory = 0
    self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
    self._thread.start()
Contributor


high

If nvidia-smi is not found, the script silently fails to monitor GPU memory and reports 0 MiB, which can be misleading for a memory benchmark. It would be more robust to check for nvidia-smi's availability at the start of monitoring. If it's not present or fails to run, a warning should be printed to stderr and the monitoring thread should not be started. This makes the script's behavior more transparent to the user.

Suggested change
def start(self) -> None:
    """Start background GPU monitoring thread."""
    self._stop_event.clear()
    self._peak_memory = 0
    self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
    self._thread.start()
def start(self) -> None:
    """Start background GPU monitoring thread."""
    # Check for nvidia-smi once at the beginning
    try:
        subprocess.run(["nvidia-smi"], capture_output=True, check=True, timeout=5.0)
    except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        print("WARNING: `nvidia-smi` not found or failed to run. GPU memory usage will be reported as 0.", file=sys.stderr)
        # Don't start the thread if nvidia-smi is not available.
        # The stop() method will then return the initial peak_memory of 0.
        return
    self._stop_event.clear()
    self._peak_memory = 0
    self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
    self._thread.start()


except Exception as e:
    result.error_message = str(e)
    result.status = "ERROR"
Contributor


high

When an exception occurs during a test (e.g., an out-of-memory error), the peak GPU memory is not recorded. The gpu_monitor has been running and likely captured the peak memory usage right before the failure, which is valuable information for debugging. You should record this peak memory value in the except block.

Suggested change
    result.status = "ERROR"
    result.peak_gpu_mem_mib = gpu_monitor.stop()

exit_code = await loop.run_in_executor(None, background_engine.wait)
if not shutting_down:
    logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server")
    os._exit(1)
Contributor


high

Using os._exit(1) causes a hard exit, bypassing cleanup handlers (like finally blocks) and potentially leading to resource leaks or data corruption. A more graceful shutdown can be achieved by sending a SIGTERM signal to the process. uvicorn will catch this signal and shut down cleanly. Note that this will require importing the signal module.

Suggested change
    os._exit(1)
    os.kill(os.getpid(), signal.SIGTERM)

Comment on lines 305 to 308
def kill_existing_servers() -> None:
    """Kill any existing server processes."""
    subprocess.run(["pkill", "-f", "tx.tinker.api"], capture_output=True)
    time.sleep(2)
Contributor


medium

Using pkill -f is a powerful command that can be risky on a shared machine. It could unintentionally terminate processes belonging to other users if their command line also contains tx.tinker.api. For a more targeted and safer approach, consider managing the server's process ID (PID) using a PID file. The server could write its PID to a file on startup, and this script could then read that file to kill the specific process. While the current implementation is likely fine for a controlled benchmark environment, this change would make it more robust for shared systems.
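
A hypothetical sketch of that PID-file approach; the file location and helper names below are illustrative only and not part of this PR:

import os
import signal
import time
from pathlib import Path

PID_FILE = Path("/tmp/tx_tinker_api.pid")  # illustrative location

def write_pid_file() -> None:
    # The server would call this on startup.
    PID_FILE.write_text(str(os.getpid()))

def kill_existing_server() -> None:
    # The benchmark script kills only the PID the server recorded.
    if not PID_FILE.exists():
        return
    pid = int(PID_FILE.read_text().strip())
    try:
        os.kill(pid, signal.SIGTERM)
        time.sleep(2)
    except ProcessLookupError:
        pass  # process already gone
    PID_FILE.unlink(missing_ok=True)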

Comment on lines 776 to 794
# Save config to output directory
config_dict = {
    "base_model": config.base_model,
    "tp_size": config.tp_size,
    "max_lora_adapters": config.max_lora_adapters,
    "gradient_checkpointing": config.gradient_checkpointing,
    "extra_backend_config": config.extra_backend_config,
    "test_mode": config.test_mode,
    "batch_sizes": config.batch_sizes,
    "seq_lens": config.seq_lens,
    "server_only": config.server_only,
    "output_root": str(config.output_root),
    "xla_preallocate": config.xla_preallocate,
    "gpu_allocator": config.gpu_allocator,
    "jax_log_compiles": config.jax_log_compiles,
    "dump_xla": config.dump_xla,
    "timestamp": datetime.now().isoformat(),
}
with open(config.output_dir / "config.json", "w") as f:
Contributor


medium

The manual creation of config_dict for JSON serialization is verbose and duplicates the structure of the BenchmarkConfig dataclass. This can lead to inconsistencies if BenchmarkConfig is updated but this dictionary is not. A more maintainable approach is to use dataclasses.asdict to convert the config object to a dictionary, then filter and format it for saving. This ensures the saved configuration always reflects the BenchmarkConfig definition.

Suggested change
# Save config to output directory
config_dict = {
    "base_model": config.base_model,
    "tp_size": config.tp_size,
    "max_lora_adapters": config.max_lora_adapters,
    "gradient_checkpointing": config.gradient_checkpointing,
    "extra_backend_config": config.extra_backend_config,
    "test_mode": config.test_mode,
    "batch_sizes": config.batch_sizes,
    "seq_lens": config.seq_lens,
    "server_only": config.server_only,
    "output_root": str(config.output_root),
    "xla_preallocate": config.xla_preallocate,
    "gpu_allocator": config.gpu_allocator,
    "jax_log_compiles": config.jax_log_compiles,
    "dump_xla": config.dump_xla,
    "timestamp": datetime.now().isoformat(),
}
with open(config.output_dir / "config.json", "w") as f:
# Save config to output directory
import dataclasses
config_dict = dataclasses.asdict(config)
# Filter and format for JSON serialization
keys_to_save = [
    "base_model", "tp_size", "max_lora_adapters", "gradient_checkpointing",
    "extra_backend_config", "test_mode", "batch_sizes", "seq_lens",
    "server_only", "output_root", "xla_preallocate", "gpu_allocator",
    "jax_log_compiles", "dump_xla"
]
json_config = {k: config_dict[k] for k in keys_to_save if k in config_dict}
json_config["output_root"] = str(json_config["output_root"])
json_config["timestamp"] = datetime.now().isoformat()
with open(config.output_dir / "config.json", "w") as f:
    json.dump(json_config, f, indent=2)

raulchen and others added 4 commits January 16, 2026 13:35
Add --train-micro-batch-size and --sample-max-num-sequences CLI args
to allow independent control of train and sample batch sizes. If not
set, both default to the current --batch-sizes value being tested.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add is_alive() method to check if server process is running
- Check server aliveness in wait_ready() and start_and_wait_ready()
- Distinguish "crashed" vs "timed out" errors during startup
- Print last 30 lines of server log when server crashes

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
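
A minimal sketch of the is_alive() check described in the commit above; the self.process attribute name is an assumption:

def is_alive(self) -> bool:
    """True while the server subprocess is still running (sketch)."""
    return self.process is not None and self.process.poll() is None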
@raulchen
Contributor Author

/gemini review

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive GPU memory benchmark script and a crucial fix to the API server to prevent it from hanging when the background engine crashes. The benchmark script is well-structured and feature-rich, covering various test modes and parameters. The fix in the API server correctly uses an asyncio task to monitor the engine subprocess and gracefully shut down the server on a crash.

My review includes a few suggestions to improve maintainability and adhere to modern Python best practices. Specifically, I've recommended moving a local import to the top level in the benchmark script, simplifying configuration object creation for better maintainability, and using the more recent asyncio.to_thread function in the API server for cleaner asynchronous code.


def wait_ready(self, timeout: float = 120.0) -> bool:
    """Wait for server to respond to health check."""
    import httpx
Contributor


medium

For better code organization and readability, it's recommended to place all imports at the top of the file. Please move import httpx to the top-level imports section of this file (e.g., around line 71). This makes dependencies clearer and avoids potential issues with delayed import errors.

Comment on lines 816 to 840
config = BenchmarkConfig(
    base_model=args.base_model,
    tp_size=args.tp_size,
    max_lora_adapters=args.max_lora_adapters,
    gradient_checkpointing=args.gradient_checkpointing,
    train_micro_batch_size=args.train_micro_batch_size,
    sample_max_num_sequences=args.sample_max_num_sequences,
    extra_backend_config=args.backend_config,
    test_mode=args.mode,
    batch_sizes=args.batch_sizes,
    seq_lens=args.seq_lens,
    server_only=args.server_only,
    host=args.host,
    port=args.port,
    experiment_name=args.experiment_name,
    output_root=args.output_root,
    gpu_poll_interval=args.gpu_poll_interval,
    xla_preallocate=args.xla_preallocate,
    gpu_allocator=args.gpu_allocator,
    jax_log_compiles=args.jax_log_compiles,
    dump_xla=args.dump_xla,
    output_dir=output_dir,
    db_path=output_dir / "tinker.db",
    csv_path=output_dir / "results.csv",
)
Contributor


medium

The instantiation of BenchmarkConfig is verbose as it manually lists all arguments from args. This can be simplified and made more maintainable by creating a dictionary from vars(args) and unpacking it. This way, if you add new arguments to parse_args and BenchmarkConfig, they will be handled automatically without needing to update this section.

    # Build configuration with derived paths
    config_data = {k: v for k, v in vars(args).items() if hasattr(BenchmarkConfig, k)}
    config = BenchmarkConfig(
        **config_data,
        output_dir=output_dir,
        db_path=output_dir / "tinker.db",
        csv_path=output_dir / "results.csv",
    )

Comment on lines 71 to 72
loop = asyncio.get_event_loop()
exit_code = await loop.run_in_executor(None, background_engine.wait)
Contributor


medium

Using loop.run_in_executor is correct for running blocking code in an async application. However, since Python 3.9, asyncio.to_thread provides a higher-level and more convenient API for this purpose. It's generally recommended to use asyncio.to_thread for better readability and conciseness.

Suggested change
loop = asyncio.get_event_loop()
exit_code = await loop.run_in_executor(None, background_engine.wait)
exit_code = await asyncio.to_thread(background_engine.wait)

raulchen and others added 6 commits January 16, 2026 14:36
- Move httpx import to top-level
- Use vars(args) with arg_renames for config creation
- Use asyncio.to_thread() instead of loop.run_in_executor()
- Print warning when killing existing servers (shows PIDs)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of running tests in a thread pool, poll result(timeout=5)
directly and check server.is_alive() on each timeout. This avoids
orphaned threads and properly detects server crashes since the SDK
retries connection errors indefinitely.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
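
A minimal sketch of the polling pattern described in the commit above, assuming a concurrent.futures-style future with a blocking result(timeout=...) method and the is_alive() helper from an earlier commit; the actual SDK future and its timeout exception type may differ:

import concurrent.futures

def wait_for_result(future, server, poll_timeout: float = 5.0):
    """Poll the future, bailing out early if the server process dies (sketch)."""
    while True:
        try:
            return future.result(timeout=poll_timeout)
        except concurrent.futures.TimeoutError:
            if not server.is_alive():
                raise RuntimeError("Server crashed while waiting for test result")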
When the engine crashes, SIGTERM requests graceful shutdown but may
hang if active requests are blocked. Add a background thread that
force-exits after 10s timeout as a fallback.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Results are now written to CSV as soon as each test finishes, rather
than waiting until all tests complete. This allows monitoring progress
in real-time.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
pcmoritz added the tx label Jan 17, 2026