Skip to content

Commit d66a37c

Browse files
1 parent 10c8958 commit d66a37c

17 files changed

Lines changed: 159 additions & 51 deletions

File tree

.ci/docker/build.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ esac
9797
TORCH_VERSION=$(cat ci_commit_pins/pytorch.txt)
9898
BUILD_DOCS=1
9999

100+
if [[ "${GCC_VERSION:-}" == "11" && -z "${SKIP_PYTORCH:-}" ]]; then
101+
PYTORCH_BUILD_MAX_JOBS=6
102+
fi
103+
100104
# Copy requirements-lintrunner.txt from root to here
101105
cp ../../requirements-lintrunner.txt ./
102106

@@ -109,6 +113,7 @@ docker build \
109113
--build-arg "PYTHON_VERSION=${PYTHON_VERSION}" \
110114
--build-arg "MINICONDA_VERSION=${MINICONDA_VERSION}" \
111115
--build-arg "TORCH_VERSION=${TORCH_VERSION}" \
116+
--build-arg "PYTORCH_BUILD_MAX_JOBS=${PYTORCH_BUILD_MAX_JOBS:-}" \
112117
--build-arg "BUCK2_VERSION=${BUCK2_VERSION}" \
113118
--build-arg "LINTRUNNER=${LINTRUNNER:-}" \
114119
--build-arg "BUILD_DOCS=${BUILD_DOCS}" \
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
release/2.11
1+
release/2.12

.ci/docker/common/install_cache.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ init_sccache() {
7676
# This is the remote cache bucket
7777
export SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2
7878
export SCCACHE_S3_KEY_PREFIX=executorch
79+
export SCCACHE_REGION=us-east-1
80+
export AWS_REGION=us-east-1
81+
export AWS_DEFAULT_REGION=us-east-1
7982
export SCCACHE_IDLE_TIMEOUT=0
8083
export SCCACHE_ERROR_LOG=/tmp/sccache_error.log
8184
export RUST_LOG=sccache::server=error

.ci/docker/common/install_pytorch.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,20 @@ install_pytorch_and_domains() {
2727
chown -R ci-user .
2828

2929
export _GLIBCXX_USE_CXX11_ABI=1
30+
if [[ "$(uname -m)" == "aarch64" ]]; then
31+
export BUILD_IGNORE_SVE_UNAVAILABLE=1
32+
fi
33+
if [[ -n "${PYTORCH_BUILD_MAX_JOBS:-}" ]]; then
34+
export MAX_JOBS="${PYTORCH_BUILD_MAX_JOBS}"
35+
fi
3036
# Then build and install PyTorch
3137
conda_run python setup.py bdist_wheel
3238
pip_install "$(echo dist/*.whl)"
3339

3440
# Grab the pinned audio and vision commits from PyTorch
3541
TORCHAUDIO_VERSION=release/2.11
3642
export TORCHAUDIO_VERSION
37-
TORCHVISION_VERSION=release/0.26
43+
TORCHVISION_VERSION=release/0.27
3844
export TORCHVISION_VERSION
3945

4046
install_domains

.ci/docker/ubuntu/Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ RUN bash ./install_cache.sh && rm install_cache.sh utils.sh
6262
ENV SCCACHE_BUCKET ossci-compiler-cache-circleci-v2
6363
ENV SCCACHE_S3_KEY_PREFIX executorch
6464
ENV SCCACHE_REGION us-east-1
65+
ENV AWS_REGION us-east-1
66+
ENV AWS_DEFAULT_REGION us-east-1
6567

6668
ARG TORCH_VERSION
6769
ARG SKIP_PYTORCH
70+
ARG PYTORCH_BUILD_MAX_JOBS
6871
COPY ./common/install_pytorch.sh install_pytorch.sh
6972
COPY ./common/utils.sh utils.sh
7073
RUN if [ -z "${SKIP_PYTORCH}" ]; then bash ./install_pytorch.sh; fi && rm install_pytorch.sh utils.sh

.ci/scripts/utils.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ install_pytorch_and_domains() {
107107
local torch_release=$(cat version.txt)
108108
# Download key must match the upload key below (basename of dist/*.whl,
109109
# which always carries setup.py's resolved +gitHASH). Branch-ref pins
110-
# like `release/2.11` would otherwise produce `+gitrelease` here and
110+
# like `release/2.12` would otherwise produce `+gitrelease` here and
111111
# never hit the cache.
112112
local torch_short_hash=$(git rev-parse --short=7 HEAD)
113113
local torch_wheel_path="cached_artifacts/pytorch/executorch/pytorch_wheels/${system_name}/${python_version}"
@@ -132,6 +132,9 @@ install_pytorch_and_domains() {
132132
# (e.g. executorch's requirements-ci.txt).
133133
pip install -r requirements-build.txt
134134
git submodule update --init --recursive
135+
if [[ "$(uname -m)" == "aarch64" ]]; then
136+
export BUILD_IGNORE_SVE_UNAVAILABLE=1
137+
fi
135138
USE_DISTRIBUTED=1 python setup.py bdist_wheel
136139
pip install "$(echo dist/*.whl)"
137140

@@ -175,7 +178,7 @@ install_pytorch_and_domains() {
175178
# Grab the pinned audio and vision commits from PyTorch
176179
TORCHAUDIO_VERSION=release/2.11
177180
export TORCHAUDIO_VERSION
178-
TORCHVISION_VERSION=release/0.26
181+
TORCHVISION_VERSION=release/0.27
179182
export TORCHVISION_VERSION
180183

181184
install_domains

.github/workflows/mlx.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ jobs:
120120
--prompt-len 4 \
121121
--max-new-tokens 5 2>&1)
122122
echo "$OUTPUT"
123-
if echo "$OUTPUT" | grep -q "Generated token ids: \[167, 167, 81, 167, 81\]"; then
123+
if echo "$OUTPUT" | grep -q "Generated token ids: \[167, 94, 253, 88, 227\]"; then
124124
echo "Success: Qwen 3.5 MoE MLX export + inference completed with expected output"
125125
else
126-
echo "Failed: unexpected output (expected [167, 167, 81, 167, 81])"
126+
echo "Failed: unexpected output (expected [167, 94, 253, 88, 227])"
127127
exit 1
128128
fi
129129
echo "::endgroup::"

backends/arm/_passes/arm_pass.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from abc import abstractmethod
1010
from typing import Any, List, Optional, Set, Type
1111

12+
import torch
1213
from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
1314
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1415
from executorch.exir.dialects._ops import ops as exir_ops
1516
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
1617
from torch.fx import GraphModule
1718
from torch.fx.passes.infra.pass_base import PassResult
19+
from torch.utils import _pytree as pytree
1820

1921

2022
class ArmPass(ExportPass):
@@ -79,6 +81,13 @@ def get_name(pass_) -> str:
7981
)
8082

8183
def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
84+
if (
85+
op == exir_ops.edge.aten.bmm.default
86+
and isinstance(meta, NodeMetadata)
87+
and len(meta.data.get("input_qparams", {})) > 0
88+
):
89+
return self._call_quantized_bmm_without_fake_kernel(op, args, kwargs, meta)
90+
8291
if not updated:
8392
return super().call_operator(op, args, kwargs, meta)
8493

@@ -91,6 +100,35 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False)
91100
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
92101
return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
93102

103+
def _call_quantized_bmm_without_fake_kernel(
104+
self,
105+
op,
106+
args: tuple[ProxyValue, ...],
107+
kwargs: dict[str, Any],
108+
meta: NodeMetadata,
109+
) -> ProxyValue:
110+
old_val = meta.data["val"]
111+
output_qparams = meta.data.get("output_qparams", {})
112+
dtype = (
113+
next(iter(output_qparams.values())).dtype
114+
if len(output_qparams) > 0
115+
else old_val.dtype
116+
)
117+
res_data = torch.empty_like(old_val, dtype=dtype)
118+
119+
args_proxy, kwargs_proxy = pytree.tree_map_only(
120+
ProxyValue, lambda x: x.proxy, (args, kwargs)
121+
)
122+
res_proxy = self.tracer.create_proxy(
123+
"call_function",
124+
op,
125+
args_proxy,
126+
kwargs_proxy,
127+
)
128+
res_proxy.node.meta.update(meta.data)
129+
self.tracer.set_metadata(res_proxy.node, res_data)
130+
return ProxyValue(res_data, res_proxy)
131+
94132
def call_submodule(
95133
self, graph_module: GraphModule, inputs: tuple[Any, ...]
96134
) -> PassResult:

backends/nxp/tests/generic_tests/test_per_channel_conversion.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,19 @@ def test_per_channel_convolution(self, _, use_qat: bool):
169169
atol=1.0,
170170
)
171171

172-
nodes = list(exported_program.graph.nodes)
173-
172+
conv_nodes = [
173+
node
174+
for node in exported_program.graph.nodes
175+
if node.target == exir_ops.edge.aten.convolution.default
176+
]
177+
assert len(conv_nodes) == 1
178+
179+
conv_node = conv_nodes[0]
174180
assert (
175-
nodes[8].target
181+
conv_node.args[1].target
176182
== exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
177183
)
178184
assert (
179-
nodes[9].target
185+
conv_node.args[2].target
180186
== exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
181187
)
182-
assert nodes[10].target == exir_ops.edge.aten.convolution.default

examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,19 +181,19 @@ def get_example_kwarg_inputs(self):
181181
return None
182182

183183
def get_dynamic_shapes(self):
184-
batch_size = 1
184+
static = torch.export.Dim.STATIC
185185
dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
186186
# Hardcoding # of tiles to be 2. image tokens per tile is 1601.
187187
if self.use_kv_cache:
188188
dynamic_shapes = {
189-
"tokens": {0: batch_size, 1: dim_seq_len},
190-
"encoder_input": None,
191-
"encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
192-
"mask": {0: batch_size, 1: dim_seq_len, 2: None},
193-
"input_pos": {0: batch_size, 1: dim_seq_len},
189+
"tokens": {0: static, 1: dim_seq_len},
190+
"encoder_input": {0: static, 1: static, 2: static},
191+
"encoder_mask": {0: static, 1: dim_seq_len, 2: static},
192+
"mask": {0: static, 1: dim_seq_len, 2: static},
193+
"input_pos": {0: static, 1: dim_seq_len},
194194
}
195195
else:
196196
dynamic_shapes = {
197-
"tokens": {0: batch_size, 1: dim_seq_len},
197+
"tokens": {0: static, 1: dim_seq_len},
198198
}
199199
return dynamic_shapes

0 commit comments

Comments
 (0)