33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- """AOT export for the RISC-V Phase 1.0 smoke test.
6+ """AOT export for the RISC-V smoke test.
77
8- Exports a trivial ``torch.add`` module to a BundledProgram (.bpte) that the
9- portable executor_runner can load on a riscv64 target and verify against the
10- embedded reference output, emitting ``Test_result: PASS`` on success.
8+ Exports a small model to a BundledProgram (.bpte) that the portable
9+ executor_runner can load on a riscv64 target and verify against the embedded
10+ reference output, emitting ``Test_result: PASS`` on success.
1111"""
1212
1313import argparse
14+ import logging
1415from pathlib import Path
1516
1617import torch
@@ -28,26 +29,186 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2829 return x + y
2930
3031
def build_add():
    """Build the two-input add model with its export and bundled test inputs.

    Returns:
        A 4-tuple ``(model, example_inputs, test_inputs, strict)``:
        the ``AddModule`` in eval mode, the ``(x, y)`` pair used for export,
        a list of two ``(x, y)`` pairs used as bundled test cases, and
        ``True`` as the flag ``main`` forwards to ``export(..., strict=...)``.
    """

    def ones_plus_twos():
        # Fresh tensors on every call so the export inputs and the first
        # test case never share storage.
        return torch.ones(1, 4), torch.full((1, 4), 2.0)

    def threes_plus_fours():
        return torch.full((1, 4), 3.0), torch.full((1, 4), 4.0)

    net = AddModule().eval()
    export_args = ones_plus_twos()
    cases = [ones_plus_twos(), threes_plus_fours()]
    return net, export_args, cases, True
40+
41+
def build_mv2():
    """Build torchvision's MobileNetV2 with a single random image input.

    Returns:
        A 4-tuple ``(model, example_inputs, test_inputs, strict)``. The model
        is loaded with ``MobileNet_V2_Weights.DEFAULT`` and put in eval mode;
        ``example_inputs`` is a 1-tuple holding a ``(1, 3, 224, 224)`` tensor,
        ``test_inputs`` is a list containing that same tuple, and ``strict``
        is ``False``.
    """
    # Imported lazily so torchvision is only required when this model is used.
    from torchvision.models import MobileNet_V2_Weights, mobilenet_v2

    net = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
    net = net.eval()

    # Fix the global seed so the random input is the same on every run.
    torch.manual_seed(0)
    image = torch.randn(1, 3, 224, 224)
    sample = (image,)
    return net, sample, [sample], False
50+
51+
def build_mobilebert():
    """Build a tiny, config-initialised MobileBERT encoder with one token input.

    The model is constructed from a small ``MobileBertConfig`` rather than
    from a checkpoint, so its weights come from the RNG. The global seed is
    fixed before construction so that repeated exports produce the same
    weights and the same embedded reference output, matching what
    ``build_llama2`` does.

    Returns:
        A 4-tuple ``(model, example_inputs, test_inputs, strict)``. The model
        is a wrapper returning ``last_hidden_state``; ``example_inputs`` is a
        1-tuple holding a ``(1, 8)`` tensor of token ids, ``test_inputs`` is a
        list containing that same tuple, and ``strict`` is ``False``.
    """
    # Imported lazily so transformers is only required when this model is used.
    from transformers import MobileBertConfig, MobileBertModel

    config = MobileBertConfig(
        vocab_size=1024,
        hidden_size=128,
        embedding_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        intra_bottleneck_size=32,
    )

    class Wrapper(torch.nn.Module):
        """Expose only ``last_hidden_state`` so forward returns a plain tensor."""

        def __init__(self):
            super().__init__()
            self.model = MobileBertModel(config).eval()

        def forward(self, input_ids):
            return self.model(input_ids).last_hidden_state

    # Seed before instantiating the wrapper: that is where the weights are
    # drawn, and without it every export run yields a different program.
    torch.manual_seed(0)
    model = Wrapper().eval()
    example_inputs = (torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]),)
    test_inputs = [example_inputs]
    return model, example_inputs, test_inputs, False
77+
78+
def build_llama2():
    """Build a tiny Llama-style transformer from the ExecuTorch example model.

    The ExecuTorch native Transformer is used (the same model family as
    ``MODEL_NAME_TO_MODEL["llama2"]`` in ``examples/models/__init__.py``)
    instead of the Hugging Face ``LlamaModel``. Its RoPE frequencies are
    precomputed buffers that are only sliced at forward time, so no
    ``torch.arange()``/Long causal mask is created per forward, which is
    what the PT2E XNNPACK quantizer trips over on HF Llama.

    Returns:
        A 4-tuple ``(model, example_inputs, test_inputs, strict)``. The model
        is in eval mode; ``example_inputs`` is a 1-tuple holding a ``(1, 8)``
        int64 tensor of token ids, ``test_inputs`` is a list containing that
        same tuple, and ``strict`` is ``False``.
    """
    # Imported lazily so these modules are only required for this model.
    from executorch.examples.models.llama.llama_transformer import (
        construct_transformer,
    )
    from executorch.examples.models.llama.model_args import ModelArgs

    context = 8
    hyperparams = dict(
        dim=128,
        n_layers=2,
        n_heads=4,
        # GQA: fewer kv heads than query heads exercises the GQA path.
        n_kv_heads=2,
        vocab_size=1024,
        # SwiGLU FFN: gate and up projections at this width.
        hidden_dim=256,
        max_seq_len=context,
        max_context_len=context,
        rope_theta=10000.0,
    )
    model_args = ModelArgs(**hyperparams)

    # Seed before construction so the weights are reproducible.
    torch.manual_seed(0)
    net = construct_transformer(model_args).eval()

    token_ids = torch.tensor([list(range(1, 9))], dtype=torch.long)
    sample = (token_ids,)
    return net, sample, [sample], False
105+
106+
def build_resnet18():
    """Build torchvision's ResNet-18 with a single random image input.

    Returns:
        A 4-tuple ``(model, example_inputs, test_inputs, strict)``. The model
        is loaded with ``ResNet18_Weights.DEFAULT`` and put in eval mode;
        ``example_inputs`` is a 1-tuple holding a ``(1, 3, 224, 224)`` tensor,
        ``test_inputs`` is a list containing that same tuple, and ``strict``
        is ``False``.
    """
    # Imported lazily so torchvision is only required when this model is used.
    from torchvision.models import ResNet18_Weights, resnet18

    net = resnet18(weights=ResNet18_Weights.DEFAULT)
    net = net.eval()

    # Fix the global seed so the random input is the same on every run.
    torch.manual_seed(0)
    image = torch.randn(1, 3, 224, 224)
    sample = (image,)
    return net, sample, [sample], False
115+
116+
# Registry of exportable models. Each key is a name accepted by ``--model``
# and each value is a zero-argument builder returning
# ``(model, example_inputs, test_inputs, strict)``.
MODELS = dict(
    add=build_add,
    mv2=build_mv2,
    mobilebert=build_mobilebert,
    llama2=build_llama2,
    resnet18=build_resnet18,
)
124+
125+
31126def main () -> None :
32127 parser = argparse .ArgumentParser (description = __doc__ )
128+ parser .add_argument (
129+ "--model" ,
130+ choices = sorted (MODELS ),
131+ default = "add" ,
132+ help = "Which model to export" ,
133+ )
33134 parser .add_argument (
34135 "--output" ,
35136 type = Path ,
36- default = Path ("add_riscv.bpte" ),
37- help = "Output .bpte path" ,
137+ default = None ,
138+ help = "Output .bpte path (default: <model>_riscv.bpte)" ,
139+ )
140+ parser .add_argument (
141+ "--xnnpack" ,
142+ action = "store_true" ,
143+ help = "Lower through the XNNPACK partitioner" ,
144+ )
145+ parser .add_argument (
146+ "--quantize" ,
147+ action = "store_true" ,
148+ help = "Produce an 8-bit quantized model" ,
149+ )
150+ parser .add_argument (
151+ "--verbose" ,
152+ action = "store_true" ,
153+ help = "Enable XNNPACK partitioner DEBUG logging and dump the lowered graph" ,
38154 )
39155 args = parser .parse_args ()
40156
41- model = AddModule (). eval ()
42- example_inputs = ( torch . ones ( 1 , 4 ), torch . full (( 1 , 4 ), 2.0 ) )
157+ if args . verbose :
158+ logging . basicConfig ( level = logging . DEBUG )
43159
44- exported = export (model , example_inputs )
45- et_program = to_edge_transform_and_lower (exported ).to_executorch ()
160+ if args .output is None :
161+ args .output = Path (f"{ args .model } _riscv.bpte" )
162+
163+ model , example_inputs , test_inputs , strict = MODELS [args .model ]()
164+
165+ if args .quantize :
166+ from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
167+ from executorch .examples .xnnpack .quantization .utils import quantize
168+
169+ if args .model not in MODEL_NAME_TO_OPTIONS :
170+ parser .error (f"No XNNPACK quantization recipe for model { args .model !r} " )
171+ quant_type = MODEL_NAME_TO_OPTIONS [args .model ].quantization
172+ if quant_type == QuantType .NONE :
173+ parser .error (f"Quantization recipe for { args .model !r} is NONE" )
174+ ep = export (model , example_inputs , strict = strict )
175+ model = quantize (ep .module (), example_inputs , quant_type )
176+
177+ exported = export (model , example_inputs , strict = strict )
178+ partitioners = []
179+ if args .xnnpack :
180+ from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
181+ XnnpackPartitioner ,
182+ )
183+
184+ partitioners .append (XnnpackPartitioner (verbose = args .verbose ))
185+
186+ compile_config = None
187+ if args .quantize :
188+ from executorch .exir import EdgeCompileConfig
189+
190+ compile_config = EdgeCompileConfig (_check_ir_validity = False )
191+
192+ edge = to_edge_transform_and_lower (
193+ exported , partitioner = partitioners , compile_config = compile_config
194+ )
195+ delegated = sum (
196+ 1
197+ for n in edge .exported_program ().graph .nodes
198+ if n .op == "call_function" and "call_delegate" in str (n .target )
199+ )
200+ print (
201+ f"[aot_riscv] model={ args .model } xnnpack={ args .xnnpack } "
202+ f"quantize={ args .quantize } delegated_nodes={ delegated } "
203+ )
204+
205+ if args .verbose :
206+ from executorch .exir .backend .utils import print_delegated_graph
207+
208+ print_delegated_graph (edge .exported_program ().graph_module )
209+
210+ et_program = edge .to_executorch ()
46211
47- test_inputs = [
48- (torch .ones (1 , 4 ), torch .full ((1 , 4 ), 2.0 )),
49- (torch .full ((1 , 4 ), 3.0 ), torch .full ((1 , 4 ), 4.0 )),
50- ]
51212 test_suite = MethodTestSuite (
52213 method_name = "forward" ,
53214 test_cases = [
0 commit comments