Skip to content

Commit 6745047

Browse files
authored
Add XNNPACK, MobileNetV2, MobileBERT, Llama, ResNet18 to RISC-V testing matrix (pytorch#19617)
### Summary
This is the continuation of pytorch#19399 and pytorch#19521, to deliver on Phase 2 of pytorch#18991.

### Test plan
This change touches test code only. Everything works out of the box, and CI will validate it.
1 parent c0cbc74 commit 6745047

8 files changed

Lines changed: 330 additions & 22 deletions

File tree

.ci/scripts/test_riscv_qemu.sh

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# CI wrapper: install RISC-V cross-compile + qemu-user tooling, then run the
8-
# RISC-V Phase 1 smoke test (export, cross-compile, qemu-user execution) via
8+
# RISC-V smoke test (export, cross-compile, qemu-user execution) via
99
# examples/riscv/run.sh. The bundled-IO comparison and Test_result: PASS
1010
# check are done by run.sh.
1111

@@ -14,5 +14,43 @@ set -eu
1414
script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")")
1515
et_root_dir=$(realpath "${script_dir}/../..")
1616

17+
model="add"
18+
xnnpack=false
19+
quantize=false
20+
verbose=false
21+
22+
# Print the command-line help for this wrapper to stdout.
# Lists every option the argument parser accepts, including --verbose, which
# the parser handles but the help previously omitted.
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]
Options:
--model=<NAME> Which model to export and run (default: add)
--xnnpack Enable the XNNPACK backend (AOT partitioner + runtime)
--quantize Produce an 8-bit quantized model
--verbose Pass --verbose through to examples/riscv/run.sh
-h, --help Show this help
EOF
}
32+
33+
# Parse the command-line options into the model/xnnpack/quantize/verbose
# variables. -h/--help prints the help and exits 0. Any unrecognised option
# is reported together with the help text on stderr (so stdout stays clean
# for callers capturing output) and the script exits 1.
for arg in "$@"; do
  case "${arg}" in
    --model=*) model="${arg#*=}" ;;
    --xnnpack) xnnpack=true ;;
    --quantize) quantize=true ;;
    --verbose) verbose=true ;;
    -h | --help)
      usage
      exit 0
      ;;
    *)
      echo "Unknown option: $arg" >&2
      usage >&2
      exit 1
      ;;
  esac
done
43+
44+
# Translate each enabled boolean switch into the matching long option that is
# forwarded to examples/riscv/run.sh. The switch variables hold the literal
# commands `true` / `false`, so expanding one and running it yields the test.
run_extra_args=()
for flag in xnnpack quantize verbose; do
  if ${!flag}; then
    run_extra_args+=("--${flag}")
  fi
done
54+
1755
bash "${et_root_dir}/examples/riscv/setup.sh"
18-
bash "${et_root_dir}/examples/riscv/run.sh"
56+
bash "${et_root_dir}/examples/riscv/run.sh" --model="${model}" "${run_extra_args[@]}"

.github/workflows/_test_riscv.yml

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,36 @@ on:
1212
required: false
1313
type: number
1414
default: 30
15+
model:
16+
description: 'Which model to run. Possible values are: add, mv2 (mobilenetv2), mobilebert, llama2, resnet18'
17+
required: false
18+
type: string
19+
default: 'add'
20+
xnnpack:
21+
description: 'Whether to enable XNNPACK'
22+
required: false
23+
type: boolean
24+
default: false
25+
quantize:
26+
description: 'Produce an 8-bit quantized model'
27+
required: false
28+
type: boolean
29+
default: false
30+
gcc-version:
31+
description: 'The version of GCC to use'
32+
required: false
33+
type: number
34+
docker-image:
35+
description: 'The docker image to use for this job'
36+
required: false
37+
type: string
1538

1639
jobs:
1740
run:
1841
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
1942
with:
2043
runner: linux.2xlarge
21-
docker-image: ci-image:executorch-ubuntu-22.04-gcc11
44+
docker-image: ${{ inputs.docker-image || 'ci-image:executorch-ubuntu-22.04-gcc11' }}
2245
submodules: 'recursive'
2346
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
2447
timeout: ${{ inputs.timeout }}
@@ -29,4 +52,5 @@ jobs:
2952
source .ci/scripts/utils.sh
3053
install_executorch "--use-pt-pinned-commit"
3154
32-
bash .ci/scripts/test_riscv_qemu.sh
55+
export GCC_VERSION=${{ inputs.gcc-version }}
56+
bash .ci/scripts/test_riscv_qemu.sh --model="${{ inputs.model }}" ${{ inputs.xnnpack && '--xnnpack' || '' }} ${{ inputs.quantize && '--quantize' || '' }}

.github/workflows/riscv64.yml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@ jobs:
2525
test-riscv:
2626
name: test-riscv
2727
uses: ./.github/workflows/_test_riscv.yml
28+
strategy:
29+
fail-fast: false
30+
matrix:
31+
include:
32+
- { model: add, xnnpack: false, quantize: false }
33+
- { model: add, xnnpack: true, quantize: false }
34+
- { model: mv2, xnnpack: false, quantize: false }
35+
- { model: mv2, xnnpack: true, quantize: false }
36+
- { model: mv2, xnnpack: true, quantize: true }
37+
- { model: mobilebert, xnnpack: false, quantize: false }
38+
- { model: mobilebert, xnnpack: true, quantize: false }
39+
- { model: mobilebert, xnnpack: true, quantize: true }
40+
- { model: llama2, xnnpack: false, quantize: false }
41+
- { model: llama2, xnnpack: true, quantize: false }
42+
- { model: llama2, xnnpack: true, quantize: true }
43+
- { model: resnet18, xnnpack: false, quantize: false }
44+
- { model: resnet18, xnnpack: true, quantize: false }
45+
- { model: resnet18, xnnpack: true, quantize: true }
2846
permissions:
2947
id-token: write
3048
contents: read
49+
with:
50+
model: ${{ matrix.model }}
51+
xnnpack: ${{ matrix.xnnpack }}
52+
quantize: ${{ matrix.quantize }}
53+
# XNNPACK requires GCC 14+
54+
gcc-version: ${{ matrix.xnnpack && 14 || 11 }}
55+
docker-image: ${{ matrix.xnnpack && 'ci-image:executorch-ubuntu-24.04-gcc14' || 'ci-image:executorch-ubuntu-22.04-gcc11' }}

examples/riscv/aot_riscv.py

Lines changed: 175 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
"""AOT export for the RISC-V Phase 1.0 smoke test.
6+
"""AOT export for the RISC-V smoke test.
77
8-
Exports a trivial ``torch.add`` module to a BundledProgram (.bpte) that the
9-
portable executor_runner can load on a riscv64 target and verify against the
10-
embedded reference output, emitting ``Test_result: PASS`` on success.
8+
Exports a small model to a BundledProgram (.bpte) that the portable
9+
executor_runner can load on a riscv64 target and verify against the embedded
10+
reference output, emitting ``Test_result: PASS`` on success.
1111
"""
1212

1313
import argparse
14+
import logging
1415
from pathlib import Path
1516

1617
import torch
@@ -28,26 +29,186 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2829
return x + y
2930

3031

32+
def build_add():
    """Build the trivial two-input add model.

    Returns:
        A ``(model, example_inputs, test_inputs, strict)`` tuple: the
        eval-mode ``AddModule``, a pair of (1, 4) example tensors, two bundled
        test cases, and ``True`` for the ``strict`` flag that ``main`` passes
        to ``export``.
    """
    shape = (1, 4)
    module = AddModule().eval()
    sample = (torch.ones(*shape), torch.full(shape, 2.0))
    # The first case repeats the example values with fresh tensors; the second
    # uses different operands.
    cases = [
        (torch.ones(*shape), torch.full(shape, 2.0)),
        (torch.full(shape, 3.0), torch.full(shape, 4.0)),
    ]
    return module, sample, cases, True
40+
41+
42+
def build_mv2():
    """Build torchvision's MobileNetV2 with its DEFAULT weights.

    Returns:
        A ``(model, example_inputs, test_inputs, strict)`` tuple. The single
        test case reuses the example input, a seeded random (1, 3, 224, 224)
        tensor; ``strict`` is ``False``.
    """
    from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

    net = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
    # Fix the RNG so the random input image is the same on every run.
    torch.manual_seed(0)
    image = torch.randn(1, 3, 224, 224)
    sample = (image,)
    return net, sample, [sample], False
50+
51+
52+
def build_mobilebert():
    """Build a small MobileBERT encoder from an explicit config.

    The model is constructed from ``MobileBertConfig`` rather than from a
    checkpoint, and is wrapped so that ``forward`` takes only ``input_ids``
    and returns the ``last_hidden_state`` tensor.

    Returns:
        A ``(model, example_inputs, test_inputs, strict)`` tuple. The single
        test case reuses the example input, one sequence of eight token ids;
        ``strict`` is ``False``.
    """
    from transformers import MobileBertConfig, MobileBertModel

    config = MobileBertConfig(
        vocab_size=1024,
        hidden_size=128,
        embedding_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        intra_bottleneck_size=32,
    )

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.model = MobileBertModel(config).eval()

        def forward(self, input_ids):
            return self.model(input_ids).last_hidden_state

    # Seed before the weights are initialised so the exported program and its
    # bundled reference outputs are reproducible, as build_llama2 does.
    torch.manual_seed(0)
    model = Wrapper().eval()
    example_inputs = (torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]),)
    test_inputs = [example_inputs]
    return model, example_inputs, test_inputs, False
77+
78+
79+
def build_llama2():
    """Build a small Llama-style transformer from the executorch model code.

    Returns:
        A ``(model, example_inputs, test_inputs, strict)`` tuple. The single
        test case reuses the example input, one sequence of eight int64 token
        ids (equal to ``max_seq_len``); ``strict`` is ``False``.
    """
    # The executorch native Transformer is used here (it matches
    # MODEL_NAME_TO_MODEL["llama2"] in examples/models/__init__.py). Unlike
    # HF LlamaModel, its RoPE freqs are precomputed buffers that are only
    # sliced at forward time, so no torch.arange()/Long causal mask is built
    # per forward — which is what the PT2E XNNPACK quantizer trips over on
    # HF Llama.
    from executorch.examples.models.llama.llama_transformer import construct_transformer
    from executorch.examples.models.llama.model_args import ModelArgs

    context_len = 8
    model_args = ModelArgs(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=2,  # GQA: kv_heads < n_heads exercises the GQA path
        vocab_size=1024,
        hidden_dim=256,  # SwiGLU FFN: gate + up projections at this width
        max_seq_len=context_len,
        max_context_len=context_len,
        rope_theta=10000.0,
    )
    # Seed ahead of construction so the model is built from a fixed RNG state.
    torch.manual_seed(0)
    transformer = construct_transformer(model_args).eval()
    token_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
    sample = (token_ids,)
    return transformer, sample, [sample], False
105+
106+
107+
def build_resnet18():
    """Build torchvision's ResNet-18 with its DEFAULT weights.

    Returns:
        A ``(model, example_inputs, test_inputs, strict)`` tuple. The single
        test case reuses the example input, a seeded random (1, 3, 224, 224)
        tensor; ``strict`` is ``False``.
    """
    from torchvision.models import resnet18, ResNet18_Weights

    net = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
    # Fix the RNG so the random input image is the same on every run.
    torch.manual_seed(0)
    image = torch.randn(1, 3, 224, 224)
    sample = (image,)
    return net, sample, [sample], False
115+
116+
117+
# Registry of the supported --model names. Each value is a zero-argument
# builder returning (model, example_inputs, test_inputs, strict).
MODELS = dict(
    add=build_add,
    mv2=build_mv2,
    mobilebert=build_mobilebert,
    llama2=build_llama2,
    resnet18=build_resnet18,
)
124+
125+
31126
def main() -> None:
32127
parser = argparse.ArgumentParser(description=__doc__)
128+
parser.add_argument(
129+
"--model",
130+
choices=sorted(MODELS),
131+
default="add",
132+
help="Which model to export",
133+
)
33134
parser.add_argument(
34135
"--output",
35136
type=Path,
36-
default=Path("add_riscv.bpte"),
37-
help="Output .bpte path",
137+
default=None,
138+
help="Output .bpte path (default: <model>_riscv.bpte)",
139+
)
140+
parser.add_argument(
141+
"--xnnpack",
142+
action="store_true",
143+
help="Lower through the XNNPACK partitioner",
144+
)
145+
parser.add_argument(
146+
"--quantize",
147+
action="store_true",
148+
help="Produce an 8-bit quantized model",
149+
)
150+
parser.add_argument(
151+
"--verbose",
152+
action="store_true",
153+
help="Enable XNNPACK partitioner DEBUG logging and dump the lowered graph",
38154
)
39155
args = parser.parse_args()
40156

41-
model = AddModule().eval()
42-
example_inputs = (torch.ones(1, 4), torch.full((1, 4), 2.0))
157+
if args.verbose:
158+
logging.basicConfig(level=logging.DEBUG)
43159

44-
exported = export(model, example_inputs)
45-
et_program = to_edge_transform_and_lower(exported).to_executorch()
160+
if args.output is None:
161+
args.output = Path(f"{args.model}_riscv.bpte")
162+
163+
model, example_inputs, test_inputs, strict = MODELS[args.model]()
164+
165+
if args.quantize:
166+
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
167+
from executorch.examples.xnnpack.quantization.utils import quantize
168+
169+
if args.model not in MODEL_NAME_TO_OPTIONS:
170+
parser.error(f"No XNNPACK quantization recipe for model {args.model!r}")
171+
quant_type = MODEL_NAME_TO_OPTIONS[args.model].quantization
172+
if quant_type == QuantType.NONE:
173+
parser.error(f"Quantization recipe for {args.model!r} is NONE")
174+
ep = export(model, example_inputs, strict=strict)
175+
model = quantize(ep.module(), example_inputs, quant_type)
176+
177+
exported = export(model, example_inputs, strict=strict)
178+
partitioners = []
179+
if args.xnnpack:
180+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
181+
XnnpackPartitioner,
182+
)
183+
184+
partitioners.append(XnnpackPartitioner(verbose=args.verbose))
185+
186+
compile_config = None
187+
if args.quantize:
188+
from executorch.exir import EdgeCompileConfig
189+
190+
compile_config = EdgeCompileConfig(_check_ir_validity=False)
191+
192+
edge = to_edge_transform_and_lower(
193+
exported, partitioner=partitioners, compile_config=compile_config
194+
)
195+
delegated = sum(
196+
1
197+
for n in edge.exported_program().graph.nodes
198+
if n.op == "call_function" and "call_delegate" in str(n.target)
199+
)
200+
print(
201+
f"[aot_riscv] model={args.model} xnnpack={args.xnnpack} "
202+
f"quantize={args.quantize} delegated_nodes={delegated}"
203+
)
204+
205+
if args.verbose:
206+
from executorch.exir.backend.utils import print_delegated_graph
207+
208+
print_delegated_graph(edge.exported_program().graph_module)
209+
210+
et_program = edge.to_executorch()
46211

47-
test_inputs = [
48-
(torch.ones(1, 4), torch.full((1, 4), 2.0)),
49-
(torch.full((1, 4), 3.0), torch.full((1, 4), 4.0)),
50-
]
51212
test_suite = MethodTestSuite(
52213
method_name="forward",
53214
test_cases=[

examples/riscv/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torchvision
2+
transformers

0 commit comments

Comments
 (0)