83 changes: 35 additions & 48 deletions codeflash/code_utils/instrument_existing_tests.py
@@ -768,35 +768,33 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |

     precompute_statements: list[ast.stmt] = []
 
-    # PyTorch: pre-compute whether to sync CUDA or MPS
+    load_ctx = ast.Load()
+    store_ctx = ast.Store()
+
+    # PyTorch: pre-compute whether to sync CUDA or MPS
     if "torch" in used_frameworks:
         torch_alias = used_frameworks["torch"]
-        # _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
+        torch_name = ast.Name(id=torch_alias, ctx=load_ctx)
+
+        # Helper to create torch.cuda attribute access
+        torch_cuda = ast.Attribute(value=torch_name, attr="cuda", ctx=load_ctx)
+
+        # _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
         precompute_statements.append(
             ast.Assign(
-                targets=[ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Store())],
+                targets=[ast.Name(id="_codeflash_should_sync_cuda", ctx=store_ctx)],
                 value=ast.BoolOp(
                     op=ast.And(),
                     values=[
                         ast.Call(
-                            func=ast.Attribute(
-                                value=ast.Attribute(
-                                    value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()
-                                ),
-                                attr="is_available",
-                                ctx=ast.Load(),
-                            ),
+                            func=ast.Attribute(value=torch_cuda, attr="is_available", ctx=load_ctx),
                             args=[],
                             keywords=[],
                         ),
                         ast.Call(
-                            func=ast.Attribute(
-                                value=ast.Attribute(
-                                    value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()
-                                ),
-                                attr="is_initialized",
-                                ctx=ast.Load(),
-                            ),
+                            func=ast.Attribute(value=torch_cuda, attr="is_initialized", ctx=load_ctx),
                             args=[],
                             keywords=[],
                         ),
@@ -808,46 +806,37 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
-        # _codeflash_should_sync_mps = (not _codeflash_should_sync_cuda and
-        # hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and
-        # hasattr(torch.mps, 'synchronize'))
+
+        # _codeflash_should_sync_mps = (not _codeflash_should_sync_cuda and
+        # hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and
+        # hasattr(torch.mps, 'synchronize'))
+        torch_backends = ast.Attribute(value=torch_name, attr="backends", ctx=load_ctx)
+        torch_mps_attr = ast.Attribute(value=torch_name, attr="mps", ctx=load_ctx)
+
         precompute_statements.append(
             ast.Assign(
-                targets=[ast.Name(id="_codeflash_should_sync_mps", ctx=ast.Store())],
+                targets=[ast.Name(id="_codeflash_should_sync_mps", ctx=store_ctx)],
                 value=ast.BoolOp(
                     op=ast.And(),
                     values=[
-                        ast.UnaryOp(op=ast.Not(), operand=ast.Name(id="_codeflash_should_sync_cuda", ctx=ast.Load())),
+                        ast.UnaryOp(op=ast.Not(), operand=ast.Name(id="_codeflash_should_sync_cuda", ctx=load_ctx)),
                         ast.Call(
-                            func=ast.Name(id="hasattr", ctx=ast.Load()),
-                            args=[
-                                ast.Attribute(
-                                    value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="backends", ctx=ast.Load()
-                                ),
-                                ast.Constant(value="mps"),
-                            ],
+                            func=ast.Name(id="hasattr", ctx=load_ctx),
+                            args=[torch_backends, ast.Constant(value="mps")],
                             keywords=[],
                         ),
                         ast.Call(
                             func=ast.Attribute(
-                                value=ast.Attribute(
-                                    value=ast.Attribute(
-                                        value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="backends", ctx=ast.Load()
-                                    ),
-                                    attr="mps",
-                                    ctx=ast.Load(),
-                                ),
+                                value=ast.Attribute(value=torch_backends, attr="mps", ctx=load_ctx),
                                 attr="is_available",
-                                ctx=ast.Load(),
+                                ctx=load_ctx,
                             ),
                             args=[],
                             keywords=[],
                         ),
                         ast.Call(
-                            func=ast.Name(id="hasattr", ctx=ast.Load()),
-                            args=[
-                                ast.Attribute(
-                                    value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="mps", ctx=ast.Load()
-                                ),
-                                ast.Constant(value="synchronize"),
-                            ],
+                            func=ast.Name(id="hasattr", ctx=load_ctx),
+                            args=[torch_mps_attr, ast.Constant(value="synchronize")],
                             keywords=[],
                         ),
                     ],
@@ -862,10 +851,10 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
         # _codeflash_should_sync_jax = hasattr(jax, 'block_until_ready')
         precompute_statements.append(
             ast.Assign(
-                targets=[ast.Name(id="_codeflash_should_sync_jax", ctx=ast.Store())],
+                targets=[ast.Name(id="_codeflash_should_sync_jax", ctx=store_ctx)],
                 value=ast.Call(
-                    func=ast.Name(id="hasattr", ctx=ast.Load()),
-                    args=[ast.Name(id=jax_alias, ctx=ast.Load()), ast.Constant(value="block_until_ready")],
+                    func=ast.Name(id="hasattr", ctx=load_ctx),
+                    args=[ast.Name(id=jax_alias, ctx=load_ctx), ast.Constant(value="block_until_ready")],
                     keywords=[],
                 ),
                 lineno=1,
@@ -878,16 +867,14 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
         # _codeflash_should_sync_tf = hasattr(tf.test.experimental, 'sync_devices')
         precompute_statements.append(
             ast.Assign(
-                targets=[ast.Name(id="_codeflash_should_sync_tf", ctx=ast.Store())],
+                targets=[ast.Name(id="_codeflash_should_sync_tf", ctx=store_ctx)],
                 value=ast.Call(
-                    func=ast.Name(id="hasattr", ctx=ast.Load()),
+                    func=ast.Name(id="hasattr", ctx=load_ctx),
                     args=[
                         ast.Attribute(
-                            value=ast.Attribute(
-                                value=ast.Name(id=tf_alias, ctx=ast.Load()), attr="test", ctx=ast.Load()
-                            ),
+                            value=ast.Attribute(value=ast.Name(id=tf_alias, ctx=load_ctx), attr="test", ctx=load_ctx),
                             attr="experimental",
-                            ctx=ast.Load(),
+                            ctx=load_ctx,
                         ),
                         ast.Constant(value="sync_devices"),
                     ],
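The hoisted load_ctx / store_ctx objects are safe to share because the compiler only inspects the type of a Name or Attribute node's ctx, never its identity. A minimal standalone sketch of that pattern (not part of this diff; the x = 1 / print(x) tree is a made-up example):

import ast

# Shared context objects, mirroring the load_ctx / store_ctx refactor above.
load_ctx = ast.Load()
store_ctx = ast.Store()

# Build:  x = 1; print(x)  -- every Name node reuses the same ctx instances.
tree = ast.Module(
    body=[
        ast.Assign(
            targets=[ast.Name(id="x", ctx=store_ctx)],
            value=ast.Constant(value=1),
            lineno=1,
        ),
        ast.Expr(
            value=ast.Call(
                func=ast.Name(id="print", ctx=load_ctx),
                args=[ast.Name(id="x", ctx=load_ctx)],  # load_ctx reused
                keywords=[],
            )
        ),
    ],
    type_ignores=[],
)
ast.fix_missing_locations(tree)
exec(compile(tree, "<generated>", "exec"))  # prints 1

The same reasoning covers reusing the hoisted torch_cuda, torch_backends, and torch_mps_attr nodes at several points in the tree: code generation only reads them, so a shared subtree compiles the same as a duplicated one.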