Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions backends/arm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ List of model specific and optional passes:
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py

- ToDevicePass
- This is a utility for moving an already-quantized or already-decomposed GraphModule to another device.
  - It is intended to be used immediately before re-running or re-tracing the module with `torch.export.export(...)`.
- Functionalities:
- Calls `.to(device)` on the GraphModule and rewrites explicit `device=` kwargs on `call_function` nodes to a user-specified device.
- Useful when manually moving an already-quantized or already-decomposed graph module to another device for validation, since some constant-producing nodes may still carry an export-time device kwarg.
- Example usage:
- `from executorch.backends.arm._passes import ToDevicePass`
- `graph_module = ToDevicePass("cpu")(graph_module)`
- backends/arm/test/misc/test_post_quant_device_switch.py

## Help & Improvements

If you have problems or questions, or have suggestions for ways to improve the Arm backend, please reach out
Expand Down
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
from .rewrite_upsample import RewriteUpsamplePass # noqa
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
from .size_adjust_input_pass import SizeAdjustInputPass # noqa
from .to_device_pass import ToDevicePass # noqa
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
Expand Down
49 changes: 49 additions & 0 deletions backends/arm/_passes/to_device_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.pass_base import ExportPass


class ToDevicePass(ArmPass):
    """Move a GraphModule to a target device.

    Applies ``.to(device)`` to the module's parameters/buffers and rewrites
    any explicit ``device=`` kwarg found on ``call_function`` nodes so that
    constant-producing ops also target the requested device.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def __init__(self, device: str | torch.device, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Normalize to torch.device so kwarg comparisons below are uniform.
        self.device = torch.device(device)

    def call(  # type: ignore[override]
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        # Move parameters and buffers first; node kwargs are handled below.
        graph_module = graph_module.to(self.device)
        needs_recompile = False

        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if "device" not in node.kwargs:
                continue
            if node.kwargs["device"] != self.device:
                node.update_kwarg("device", self.device)
                needs_recompile = True

        # Only pay the recompile cost when a kwarg actually changed.
        if needs_recompile:
            graph_module.recompile()

        return graph_module

    def __call__(  # type: ignore[override]
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        # Bypass the ExportPass machinery; this pass edits the graph directly.
        return self.call(graph_module)
232 changes: 232 additions & 0 deletions backends/arm/test/misc/test_post_quant_device_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from dataclasses import dataclass
from typing import Callable

import pytest
import torch
import torch.nn.functional as F
from executorch.backends.arm._passes import ToDevicePass
from executorch.backends.arm.quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification
from torch._subclasses.fake_tensor import FakeTensor
from torchao.quantization.pt2e import move_exported_model_to_eval
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e


class AddAlpha(torch.nn.Module):
    """Computes ``x + 2.0 * y`` via the ``alpha`` kwarg of ``torch.add``."""

    def forward(self, x, y):
        return torch.add(x, y, alpha=2.0)


class SubAlpha(torch.nn.Module):
    """Computes ``x - 2.0 * y`` via the ``alpha`` kwarg of ``torch.sub``."""

    def forward(self, x, y):
        return torch.sub(x, y, alpha=2.0)


class SliceScatter(torch.nn.Module):
    """Scatters ``src`` into every other column (0, 2) of ``x``."""

    def forward(self, x, src):
        return torch.slice_scatter(x, src, dim=1, start=0, end=4, step=2)


class MeanDim(torch.nn.Module):
    """Mean over dim 1, keeping the reduced dimension."""

    def forward(self, x):
        return torch.mean(x, dim=(1,), keepdim=True)


class MeanDefault(torch.nn.Module):
    """Mean over all elements, returning a scalar tensor."""

    def forward(self, x):
        return torch.mean(x)


class VarCorrection(torch.nn.Module):
    """Variance over dims (2, 3) with Bessel's correction, keeping dims."""

    def forward(self, x):
        return torch.var(x, dim=(2, 3), correction=1, keepdim=True)


class VarDim(torch.nn.Module):
    """Variance via the raw ``aten.var.dim`` overload (dims [2, 3], unbiased, keepdim)."""

    def forward(self, x):
        return torch.ops.aten.var.dim(x, [2, 3], 1, True)


class DivTensorMode(torch.nn.Module):
    """Division with truncation toward zero (``rounding_mode="trunc"``)."""

    def forward(self, x, y):
        return torch.div(x, y, rounding_mode="trunc")


class LeakyRelu(torch.nn.Module):
    """Leaky ReLU with a negative slope of 0.2."""

    def forward(self, x):
        return F.leaky_relu(x, negative_slope=0.2)


class AvgPool2d(torch.nn.Module):
    """2x2 average pool with stride 1 and padding 1."""

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=2, stride=1, padding=1)


class LayerNorm(torch.nn.Module):
    """LayerNorm over a trailing dimension of size 4, without affine params."""

    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(
            normalized_shape=4, elementwise_affine=False
        )

    def forward(self, x):
        return self.layer_norm(x)


class GroupNorm(torch.nn.Module):
    """GroupNorm with 2 groups over 4 channels, without affine params."""

    def __init__(self):
        super().__init__()
        self.group_norm = torch.nn.GroupNorm(
            num_groups=2, num_channels=4, affine=False
        )

    def forward(self, x):
        return self.group_norm(x)


@dataclass(frozen=True)
class MetaRetraceCase:
    """One parameterized case for the post-quantization device-switch test."""

    # Human-readable id used for pytest parameterization.
    name: str
    # Zero-argument factory producing a fresh module under test.
    module_factory: Callable[[], torch.nn.Module]
    # Zero-argument factory producing example inputs for export/calibration.
    inputs_factory: Callable[[], tuple[torch.Tensor, ...]]
    # ATen op name associated with the case. NOTE(review): not referenced
    # by the test in this file — confirm whether it is still needed.
    aten_op: str


# Ops exercised by the device-switch test below; each entry pairs a module
# factory with a matching example-input factory (and an ATen op label).
_TEST_CASES = [
    MetaRetraceCase(
        "add_alpha",
        AddAlpha,
        lambda: (torch.randn(2, 3), torch.randn(2, 3)),
        "aten.add.Tensor",
    ),
    MetaRetraceCase(
        "sub_alpha",
        SubAlpha,
        lambda: (torch.randn(2, 3), torch.randn(2, 3)),
        "aten.sub.Tensor",
    ),
    MetaRetraceCase(
        "slice_scatter",
        SliceScatter,
        lambda: (torch.randn(2, 4), torch.randn(2, 2)),
        "aten.slice_scatter.default",
    ),
    MetaRetraceCase(
        "mean_dim",
        MeanDim,
        lambda: (torch.randn(2, 3, 4),),
        "aten.mean.dim",
    ),
    MetaRetraceCase(
        "mean_default",
        MeanDefault,
        lambda: (torch.randn(2, 3, 4),),
        "aten.mean.default",
    ),
    MetaRetraceCase(
        "var_correction",
        VarCorrection,
        lambda: (torch.randn(2, 3, 4, 4),),
        "aten.var.correction",
    ),
    MetaRetraceCase(
        "var_dim",
        VarDim,
        lambda: (torch.randn(2, 3, 4, 4),),
        "aten.var.dim",
    ),
    MetaRetraceCase(
        "div_tensor_mode",
        DivTensorMode,
        lambda: (torch.randn(2, 3), torch.randn(2, 3) + 1.0),
        "aten.div.Tensor_mode",
    ),
    MetaRetraceCase(
        "leaky_relu",
        LeakyRelu,
        lambda: (torch.randn(2, 3),),
        "aten.leaky_relu.default",
    ),
    MetaRetraceCase(
        "avg_pool2d",
        AvgPool2d,
        lambda: (torch.randn(1, 3, 4, 4),),
        "aten.avg_pool2d.default",
    ),
    MetaRetraceCase(
        "layer_norm",
        LayerNorm,
        lambda: (torch.randn(2, 3, 4),),
        "aten.layer_norm.default",
    ),
    MetaRetraceCase(
        "group_norm",
        GroupNorm,
        lambda: (torch.randn(2, 4, 3, 3),),
        "aten.group_norm.default",
    ),
]


def _make_quantizer() -> TOSAQuantizer:
    """Build a TOSA-1.0 INT quantizer with a global symmetric per-tensor config."""
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    tosa_quantizer = TOSAQuantizer(tosa_spec)
    tosa_quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
    return tosa_quantizer


def _iter_fake_tensors(meta_val):
if isinstance(meta_val, FakeTensor):
yield meta_val
return

if isinstance(meta_val, (list, tuple)):
for item in meta_val:
yield from _iter_fake_tensors(item)


def _to_meta_inputs(
example_inputs: tuple[torch.Tensor, ...],
) -> tuple[torch.Tensor, ...]:
return tuple(inp.to(device="meta") for inp in example_inputs)


@pytest.mark.parametrize("case", _TEST_CASES, ids=[case.name for case in _TEST_CASES])
def test_post_quant_device_switch_no_target(case: MetaRetraceCase) -> None:
    """Verify that a quantized model can be moved to another device.

    Quantizes the module via QAT prepare/calibrate/convert, moves the
    quantized graph module to the meta device with ``ToDevicePass``, runs
    it, then retraces it and checks that every FakeTensor in the traced
    graph's metadata lives on the meta device.
    """
    model = case.module_factory().train()
    sample_inputs = case.inputs_factory()

    # Quantize: export, QAT-prepare, calibrate once, switch to eval, convert.
    exported_program = torch.export.export(model, sample_inputs, strict=True)
    prepared_gm = prepare_qat_pt2e(
        copy.deepcopy(exported_program.graph_module), _make_quantizer()
    )
    prepared_gm(*sample_inputs)
    prepared_gm = move_exported_model_to_eval(prepared_gm)
    quantized_gm = convert_pt2e(prepared_gm)

    # Move the quantized module to the meta device and check it still runs.
    meta_inputs = _to_meta_inputs(sample_inputs)
    meta_gm = ToDevicePass("meta")(quantized_gm)
    meta_gm(*meta_inputs)

    # Retrace on the meta device so FakeTensor metadata is regenerated.
    retraced = torch.export.export(meta_gm, meta_inputs, strict=True)

    # Collect (device, node) pairs for every FakeTensor in node metadata.
    fake_tensor_devices = []
    for node in retraced.graph.nodes:
        for fake_tensor in _iter_fake_tensors(node.meta.get("val")):
            fake_tensor_devices.append((str(fake_tensor.device), str(node)))

    assert fake_tensor_devices, "Expected traced graph to contain FakeTensor metadata"
    assert all(device == "meta" for device, _ in fake_tensor_devices), (
        "Expected all traced FakeTensors to use the meta device, got "
        f"{fake_tensor_devices}"
    )
Loading