
Conversation

@aseembits93 (Contributor) commented Jan 7, 2026

TODO: write more tests as it is crucial

Support Matrix

| Framework / Device | CUDA (NVIDIA) | TPU (Google)    | MPS (Apple) | ROCm (AMD) |
|--------------------|---------------|-----------------|-------------|------------|
| PyTorch            | Yes           | Yes (extension) | Yes         | Yes        |
| TensorFlow         | Yes           | Yes             | Future      | Future     |
| JAX                | Yes           | Yes             | Future      | Yes        |
| MLX                | Future        | Future          | Future      | Future     |
| CuPy               | Future        | No              | No          | No         |

@codeflash-ai bot (Contributor) commented Jan 7, 2026

⚡️ Codeflash found optimizations for this PR

📄 25% (0.25x) speedup for _create_device_sync_statements in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 2.51 milliseconds → 2.00 milliseconds (best of 5 runs)

A new Optimization Review has been created.

🔗 Review here


@aseembits93 (Contributor, Author) commented:

@HeshamHM28, have a look; the staging link is not working for me.

@codeflash-ai bot (Contributor) commented Jan 7, 2026

⚡️ Codeflash found optimizations for this PR

📄 55% (0.55x) speedup for detect_frameworks_from_code in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 16.8 milliseconds → 10.8 milliseconds (best of 5 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch gpu-sync-instrumentation).


@codeflash-ai bot (Contributor) commented Jan 9, 2026

⚡️ Codeflash found optimizations for this PR

📄 14% (0.14x) speedup for _create_device_sync_precompute_statements in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 691 microseconds → 605 microseconds (best of 173 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch gpu-sync-instrumentation).


@codeflash-ai deleted a comment from claude bot Jan 10, 2026
@aseembits93 (Contributor, Author) commented:

@claude review this

@claude bot commented Jan 10, 2026

Claude finished @aseembits93's task in 1m 45s


Code Review Complete

I've reviewed the PR that adds GPU/non-CPU device synchronization for accurate profiling. This is a significant enhancement that addresses an important gap in performance measurement for GPU-accelerated code.


Overview

This PR adds device synchronization support for accurate profiling of GPU/accelerator operations across PyTorch (CUDA/MPS), TensorFlow, and JAX. The implementation is well-designed with a pre-computation optimization to minimize timing overhead.


Strengths

  1. Smart Pre-computation Pattern: The split between _create_device_sync_precompute_statements() and _create_device_sync_statements() is excellent. Pre-computing sync conditions outside the timing block avoids measurement bias from hasattr/is_available checks (a sketch of the pattern follows this list).

  2. Comprehensive Framework Coverage: Supports the major ML frameworks (PyTorch, TensorFlow, JAX) with proper device-specific handling (CUDA, MPS, TPU, ROCm).

  3. Proper Import Detection: The detect_frameworks_from_code() function handles both direct imports and aliased imports correctly.

  4. JAX Special Handling: Correctly applies JAX's block_until_ready() pattern only after function execution (not before), which matches JAX's asynchronous dispatch model.

  5. PyTorch Device Priority: Properly checks CUDA before MPS, avoiding unnecessary MPS checks on CUDA-enabled systems.
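
To make the point concrete, here is a minimal sketch of what the precompute/sync split looks like conceptually. The flag name follows the PR's _codeflash_should_sync_* convention, but the statements are illustrative, not the instrumentation's actual output:

import time

# Precomputed ONCE, outside the timed region, so hasattr/is_available
# checks do not bias the measurement.
try:
    import torch
    _codeflash_should_sync_cuda = torch.cuda.is_available()
except ImportError:
    _codeflash_should_sync_cuda = False

def timed_call(fn, *args, **kwargs):
    start = time.perf_counter_ns()
    result = fn(*args, **kwargs)
    if _codeflash_should_sync_cuda:
        # Only the cheap flag check and the sync sit inside the timed region,
        # so asynchronous CUDA kernels are included in the measured runtime.
        torch.cuda.synchronize()
    end = time.perf_counter_ns()
    return result, end - start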


Issues and Concerns

1. Missing Error Handling (Medium Priority)

Location: codeflash/code_utils/instrument_existing_tests.py:762-909

The device sync code has no error handling. Framework APIs can fail in various scenarios:

  • torch.cuda.synchronize() can raise if CUDA context is corrupted
  • tf.test.experimental.sync_devices() might not exist in all TF versions
  • jax.block_until_ready() can fail on non-array return values

Recommendation: Wrap sync calls in try-except blocks to prevent test instrumentation from breaking tests:

# In _create_device_sync_statements(), wrap sync calls:
try:
    torch.cuda.synchronize()
except Exception:
    pass  # Silently fail to avoid breaking tests
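
The same guard, extended to TensorFlow and JAX and sketched with optional imports so it degrades gracefully when a framework is absent (illustrative only; the PR's precomputed flags would replace the availability checks shown here):

try:
    import torch
except ImportError:
    torch = None
try:
    import tensorflow as tf
except ImportError:
    tf = None
try:
    import jax
except ImportError:
    jax = None

def _sync_devices_defensively(return_value):
    # Never let a sync failure break the instrumented test.
    if torch is not None and torch.cuda.is_available():
        try:
            torch.cuda.synchronize()
        except Exception:
            pass
    if tf is not None:
        try:
            tf.test.experimental.sync_devices()
        except Exception:
            pass
    if jax is not None:
        try:
            jax.block_until_ready(return_value)
        except Exception:
            pass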

2. JAX return_value Assumption (High Priority)

Location: codeflash/code_utils/instrument_existing_tests.py:985-1000

The JAX sync code calls jax.block_until_ready(return_value) but:

  • This only works if return_value is a JAX array/pytree
  • Fails for functions returning primitives, None, or non-JAX objects
  • Could crash the test wrapper unexpectedly

Recommendation: Add type checking or error handling:

# Before calling block_until_ready:
if _codeflash_should_sync_jax and hasattr(return_value, '__jax_array__'):
    jax.block_until_ready(return_value)
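
An alternative guard, sketched here rather than taken from the PR, checks for the block_until_ready method that JAX arrays expose; note it only covers a bare array return, not a pytree of arrays:

# Hypothetical alternative: only sync when the return value itself is a JAX array.
if _codeflash_should_sync_jax and hasattr(return_value, "block_until_ready"):
    return_value.block_until_ready()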

Fix this →

3. Import Duplication Risk (Medium Priority)

Location: codeflash/code_utils/instrument_existing_tests.py:747-755

The code adds framework imports unconditionally, which can create duplicate imports:

# If test already has: import torch
# This adds another: import torch
new_imports.append(ast.Import(names=[ast.alias(name=framework_name)]))

Issue: While sort_imports() might deduplicate, this creates noise and could cause issues with import side effects.

Recommendation: Check existing imports before adding new ones.
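
One way to do that, sketched with the standard ast module (tree, framework_name, and new_imports mirror the snippet above and are placeholders, not the PR's actual variables):

import ast

def _has_import(tree: ast.Module, module_name: str) -> bool:
    # Return True if the test module already imports the framework in any form.
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(a.name == module_name or a.name.startswith(module_name + ".") for a in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            mod = node.module or ""
            if mod == module_name or mod.startswith(module_name + "."):
                return True
    return False

# Only append the import when it is genuinely missing:
# if not _has_import(tree, framework_name):
#     new_imports.append(ast.Import(names=[ast.alias(name=framework_name)]))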

Fix this →

4. Incomplete Framework Detection (Low Priority)

Location: codeflash/code_utils/instrument_existing_tests.py:668-703

detect_frameworks_from_code() only detects the first occurrence of each framework and doesn't handle:

  • Multiple different aliases for the same framework
  • Submodule imports that use different patterns (e.g., from torch.cuda import synchronize)

Example:

import torch as th
from torch import nn  # Won't update alias to 'torch'

This is minor since the first alias is usually sufficient, but could cause issues in edge cases.
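
A small sketch of alias collection that would cover both lines in the example above (an illustration, not the PR's detect_frameworks_from_code):

import ast

def collect_torch_aliases(source: str) -> set:
    # Gather every name the test can use to reach torch, including
    # "import torch as th" and bare "from torch import ..." forms.
    aliases = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name == "torch" or alias.name.startswith("torch."):
                    aliases.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom) and (node.module or "").split(".")[0] == "torch":
            aliases.add("torch")  # the framework is in use even without a module alias
    return aliases

print(collect_torch_aliases("import torch as th\nfrom torch import nn"))  # {'th', 'torch'} (set order may vary)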

5. Missing Tests (High Priority)

Location: PR description explicitly mentions "TODO: write more tests as it is crucial"

The test files added (test_gridmake2.py, test_gridmake2_torch.py) are for test data, not for testing the instrumentation logic itself.

Missing test coverage:

  • Device sync statement generation
  • Framework detection with various import patterns
  • Error handling in sync operations
  • Pre-computation correctness
  • Behavior with missing frameworks

Recommendation: Add unit tests for:

def test_detect_frameworks_with_aliases()
def test_device_sync_with_cuda_unavailable()
def test_jax_sync_with_non_array_return()
def test_framework_import_deduplication()
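
A starting point for the first of these, hedged because the exact return shape of detect_frameworks_from_code is an assumption here and should be adjusted to the real API:

from codeflash.code_utils.instrument_existing_tests import detect_frameworks_from_code

def test_detect_frameworks_with_aliases():
    code = "import torch as th\nimport jax\n"
    detected = detect_frameworks_from_code(code)
    # Assumed shape: a mapping/collection keyed by framework name.
    assert "torch" in detected
    assert "jax" in detected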

6. MPS Sync Availability Check Incomplete (Medium Priority)

Location: codeflash/code_utils/instrument_existing_tests.py:820-866

The MPS sync check is very thorough:

hasattr(torch.backends, 'mps') and 
torch.backends.mps.is_available() and 
hasattr(torch.mps, 'synchronize')

However, it doesn't check if there's actually an MPS device being used. The sync will be called even if the code only uses CPU tensors, which is wasteful.

Recommendation: Consider adding a check like:

torch.mps.current_device() is not None

7. TensorFlow Sync API Instability (Low Priority)

Location: codeflash/code_utils/instrument_existing_tests.py:888-907

Using tf.test.experimental.sync_devices() from the "experimental" namespace is risky:

  • Experimental APIs can change or be removed
  • Not guaranteed to be stable across TF versions
  • May not work with TF 1.x (though probably not an issue)

Note: This is documented in the support matrix as "Future" for some devices, which is good.
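
If the experimental API does move or disappear, a defensive lookup keeps the generated code from raising (a sketch, not the PR's code):

import tensorflow as tf

# Resolve the sync function defensively so a missing or renamed experimental
# API degrades to a no-op instead of an AttributeError.
_tf_sync = getattr(getattr(tf.test, "experimental", None), "sync_devices", None)
if callable(_tf_sync):
    _tf_sync()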


Code Quality Observations

  1. Good Documentation: Functions have clear docstrings explaining their purpose and return values.

  2. Consistent Naming: The _codeflash_should_sync_* variable naming is consistent and clear.

  3. Complex AST Generation: The AST construction code (lines 762-1032) is quite dense and hard to read. Consider extracting helper functions for common patterns like creating ast.If nodes or ast.Call nodes (a sketch follows this list).

  4. Magic Numbers: Line number arguments (lineno=1, lineno=lineno + 10, etc.) are scattered throughout. These seem arbitrary and could be fragile.
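
For instance, a tiny helper along these lines (illustrative only; the nodes the PR actually builds may need additional fields, and ast.fix_missing_locations also sidesteps the hand-written lineno values noted in point 4):

import ast

def make_guarded_call(flag_name: str, func_dotted: str) -> ast.If:
    # Build:  if <flag_name>: <func_dotted>()
    call = ast.Expr(value=ast.Call(func=ast.parse(func_dotted, mode="eval").body, args=[], keywords=[]))
    node = ast.If(test=ast.Name(id=flag_name, ctx=ast.Load()), body=[call], orelse=[])
    return ast.fix_missing_locations(node)

print(ast.unparse(make_guarded_call("_codeflash_should_sync_cuda", "torch.cuda.synchronize")))
# if _codeflash_should_sync_cuda:
#     torch.cuda.synchronize()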


Test File Review

The added test files (test_gridmake2.py, test_gridmake2_torch.py, discrete_riccati.py) appear to be example/test data for GPU profiling scenarios:

  • Good: Comprehensive test coverage of the _gridmake2 functions
  • Good: Tests both CPU and CUDA execution paths
  • Note: These files seem to belong in a test_data or fixtures directory rather than code_to_optimize/

Security Considerations

No major security issues, but note:

  • Dynamic import detection could potentially be bypassed by obfuscated imports
  • No validation of framework version compatibility

Performance Considerations

  1. Pre-computation Optimization: Excellent! This avoids repeated checks during profiling.
  2. AST Walking: detect_frameworks_from_code() uses ast.walk() which is O(n) on tree size. This is fine for typical test files.

Recommendations Summary

Must Fix (before merge):

  1. Add error handling to device sync calls
  2. Fix JAX block_until_ready() to handle non-array return values
  3. Add unit tests for instrumentation logic

Should Fix (before or soon after merge):

  4. Prevent duplicate imports
  5. Improve MPS availability checking
  6. Add more comprehensive tests (as noted in PR description)

Nice to Have:

  7. Refactor dense AST generation code for readability
  8. Improve framework detection for complex import patterns


Verdict

This is a valuable feature that addresses a real gap in GPU profiling accuracy. The core design is sound with good optimization patterns. However, there are some edge cases and error handling gaps that should be addressed before merging, particularly around JAX return value handling and general error resilience.

The PR is close to being merge-ready but needs the critical fixes listed above, especially adding proper error handling to prevent instrumentation from breaking tests unexpectedly.


Branch: gpu-sync-instrumentation

@aseembits93 merged commit 8a09ee9 into main Jan 13, 2026
23 of 25 checks passed
@aseembits93 deleted the gpu-sync-instrumentation branch January 13, 2026 00:17