# Non-CPU device synchronization for accurate profiling #1015
⚡️ Codeflash found optimizations for this PR: 📄 25% (0.25x) speedup for …

@HeshamHM28 have a look, the staging link is not working for me

⚡️ Codeflash found optimizations for this PR: 📄 55% (0.55x) speedup for …

⚡️ Codeflash found optimizations for this PR: 📄 14% (0.14x) speedup for …

@claude review this
Claude finished @aseembits93's task in 1m 45s

## Code Review Complete

I've reviewed the PR that adds GPU/non-CPU device synchronization for accurate profiling. This is a significant enhancement that addresses an important gap in performance measurement for GPU-accelerated code.

### Overview

This PR adds device synchronization support for accurate profiling of GPU/accelerator operations across PyTorch (CUDA/MPS), TensorFlow, and JAX. The implementation is well designed, with a pre-computation optimization to minimize timing overhead.
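To see why synchronization matters here: CUDA (and MPS/XLA) kernels launch asynchronously, so a CPU-side timer can stop long before the device finishes its work. The following is a minimal sketch of the idea, not the PR's actual instrumentation; `_sync_device` and `timed_matmul` are illustrative names:

```python
import time

import torch

def _sync_device() -> None:
    # Illustrative helper: wait for queued GPU work so the timer below
    # measures kernel execution, not just the asynchronous launch.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def timed_matmul(n: int = 2048) -> float:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    _sync_device()                 # drain pending work before starting the clock
    start = time.perf_counter()
    _ = a @ b                      # returns immediately on CUDA
    _sync_device()                 # without this, only the launch is timed
    return time.perf_counter() - start
```

Without the second sync, the measured time on a GPU can be microseconds regardless of matrix size, which is exactly the inaccuracy this PR targets.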
### Strengths

### Issues and Concerns

#### 1. Missing Error Handling (Medium Priority)

**Location:** The device sync code has no error handling, and framework APIs can fail in a variety of scenarios.
**Recommendation:** Wrap sync calls in try-except blocks to prevent the test instrumentation from breaking tests:

```python
# In _create_device_sync_statements(), wrap sync calls:
try:
    torch.cuda.synchronize()
except Exception:
    pass  # Silently fail to avoid breaking tests
```

#### 2. JAX return_value Assumption (High Priority)

**Location:** The JAX sync code calls `jax.block_until_ready()` on the function's return value, assuming it is something JAX can wait on.
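For context, a small demonstration of the call under discussion, assuming a recent JAX version (the pytree-walking behavior is the library's documented handling, not something this PR adds):

```python
import jax
import jax.numpy as jnp

# block_until_ready walks the pytree and waits on its array leaves;
# non-array leaves such as the plain int are passed through untouched.
result = {"loss": jnp.ones((2, 2)).sum(), "step": 3}
jax.block_until_ready(result)
```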
**Recommendation:** Add type checking or error handling:

```python
# Before calling block_until_ready:
if _codeflash_should_sync_jax and hasattr(return_value, '__jax_array__'):
    jax.block_until_ready(return_value)
```

#### 3. Import Duplication Risk (Medium Priority)

**Location:** The code adds framework imports unconditionally, which can create duplicate imports:

```python
# If test already has: import torch
# This adds another: import torch
new_imports.append(ast.Import(names=[ast.alias(name=framework_name)]))
```

**Issue:** While Python caches modules, so a repeated `import torch` is functionally harmless, the duplicate statements clutter the instrumented test source.

**Recommendation:** Check existing imports before adding new ones; a sketch follows below.
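One way to do that check with the standard `ast` module. This is a minimal sketch, not the PR's implementation; `has_import` is a hypothetical helper:

```python
import ast

def has_import(tree: ast.Module, module_name: str) -> bool:
    # Return True if the module already imports `module_name`
    # via either `import x` or `from x import y`.
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(alias.name == module_name for alias in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            if node.module == module_name:
                return True
    return False

tree = ast.parse("import torch\nx = torch.ones(3)")
if not has_import(tree, "torch"):
    tree.body.insert(0, ast.Import(names=[ast.alias(name="torch")]))
```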
Example: import torch as th
from torch import nn # Won't update alias to 'torch'This is minor since the first alias is usually sufficient, but could cause issues in edge cases. 5. Missing Tests (High Priority)Location: PR description explicitly mentions "TODO: write more tests as it is crucial" The test files added ( Missing test coverage:
Recommendation: Add unit tests for: def test_detect_frameworks_with_aliases()
def test_device_sync_with_cuda_unavailable()
def test_jax_sync_with_non_array_return()
def test_framework_import_deduplication()6. MPS Sync Availability Check Incomplete (Medium Priority)Location: The MPS sync check is very thorough: hasattr(torch.backends, 'mps') and
torch.backends.mps.is_available() and
hasattr(torch.mps, 'synchronize')However, it doesn't check if there's actually an MPS device being used. The sync will be called even if the code only uses CPU tensors, which is wasteful. Recommendation: Consider adding a check like: torch.mps.current_device() is not None7. TensorFlow Sync API Instability (Low Priority)Location: Using
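Combining this with the error-handling point from issue 1, a defensive version might look like the following sketch (illustrative only; `_sync_mps_if_needed` is not the PR's code):

```python
import torch

def _sync_mps_if_needed() -> None:
    # getattr guards keep this safe on builds without MPS support.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (mps_backend is not None
            and mps_backend.is_available()
            and mps_module is not None):
        try:
            mps_module.synchronize()
        except Exception:
            pass  # never let instrumentation break the test run
```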
#### 7. TensorFlow Sync API Instability (Low Priority)

**Location:** The sync relies on an experimental TensorFlow API, which may change between releases.

**Note:** This is documented in the support matrix as "Future" for some devices, which is good.
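If the call in question is TensorFlow's benchmarking hook `tf.test.experimental.sync_devices()` (an assumption, since the exact call is truncated above; the API exists from TF 2.12), a guarded use would address both the instability here and the error handling from issue 1:

```python
import tensorflow as tf

def _sync_tf_devices() -> None:
    # Resolve the experimental API defensively: it may be absent in
    # older TF releases or renamed in future ones.
    sync = getattr(getattr(tf.test, "experimental", None), "sync_devices", None)
    if sync is not None:
        try:
            sync()  # blocks until all pending device work is finished
        except Exception:
            pass  # instrumentation must never fail the test
```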
### Code Quality Observations

### Test File Review

The added test files (…)
### Security Considerations

No major security issues, but note:
### Performance Considerations
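One performance point worth illustrating is the pre-computation optimization mentioned in the overview. The idea, sketched here with assumed names (the flag style mirrors `_codeflash_should_sync_jax` seen in issue 2), is to decide once at import time whether syncing is needed, so the timed region pays only a boolean check:

```python
import importlib.util

# Decided once at instrumentation/import time, outside any timed region.
_should_sync_torch = False
if importlib.util.find_spec("torch") is not None:
    import torch
    _should_sync_torch = torch.cuda.is_available()

def _sync_after_call() -> None:
    # Hot path: a single boolean test instead of repeated
    # hasattr/is_available probes inside the measured region.
    if _should_sync_torch:
        torch.cuda.synchronize()
```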
### Recommendations Summary

**Must Fix (before merge):**

**Should Fix (before or soon after merge):**

**Nice to Have:**

### Verdict

This is a valuable feature that addresses a real gap in GPU profiling accuracy. The core design is sound with good optimization patterns. However, there are some edge cases and error-handling gaps that should be addressed before merging, particularly around JAX return-value handling and general error resilience.

The PR is close to being merge-ready but needs the critical fixes listed above, especially adding proper error handling to prevent the instrumentation from breaking tests unexpectedly.
---

TODO: write more tests as it is crucial

## Support Matrix