Skip to content

Commit 5fc929f

Browse files
Qualcomm AI Engine Direct - Refactor llama runner for dynamic IO dtypes (pytorch#19146)
### Summary

To enable GPU backend support in the Llama runner, a refactor is required: the dtypes of kv_cache, attention_mask, and logits are currently hardcoded, which prevents floating-point models from running. This PR focuses on removing those hardcoded dtypes.

#### Key changes

- Remove the template parameter `<typename T>` from KVManager, LhdTokenGenerator, MultimodalPromptProcessor, and related runner classes
- Detect the kv_cache and attention_mask dtypes dynamically from MethodMeta at construction time instead of relying on compile-time bitwidth detection
- Switch to `std::byte*` pointer arithmetic with `getDtypeSize()` for all buffer offsets; add a `fill_mask()` helper for multi-dtype attention mask filling
- Update the spec_prop pass for the custom llama op to handle the case where the sharding number is greater than 1

### Test plan

```
python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_llama_stories_110m --model SM8650 --build_folder /local/mnt/workspace/chenweng/executorch/executorch/build-android --device acfa9311 --executorch_root . --artifact_dir ./stories_110m_pte_size --llama_artifacts . --use_fp16
```

<img width="1977" height="468" alt="image" src="https://github.com/user-attachments/assets/8bf3bffa-9b9f-4655-9cbc-b20127c2468a" />

cc @cccclai @cbilgin @abhinaykukkadapu
1 parent b581615 commit 5fc929f

35 files changed

Lines changed: 820 additions & 706 deletions

backends/qualcomm/_passes/build_quant_io.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,10 @@
55
# LICENSE file in the root directory of this source tree.
66
import torch
77
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
8-
from executorch.exir.delegate import executorch_call_delegate
98

10-
from executorch.exir.pass_base import ExportPass, ProxyValue
9+
from executorch.exir.delegate import executorch_call_delegate
10+
from executorch.exir.pass_base import ExportPass, PassResult
1111
from executorch.exir.tensor import TensorSpec
12-
from torch.utils import _pytree as pytree
1312

1413

1514
class BuildQuantIo(ExportPass):
@@ -28,22 +27,27 @@ def _make_spec(self, x):
2827
else:
2928
return None
3029

31-
def placeholder(self, name: str, arg, meta):
32-
if quantized_dtype := meta.data.get(QCOM_QUANTIZED_IO, None):
33-
arg = arg.to(dtype=quantized_dtype)
34-
meta["spec"] = self._make_spec(arg)
35-
return super().placeholder(name, arg, meta)
36-
37-
def call_getitem(self, value, key: int, meta):
38-
meta["spec"] = value.node.meta["spec"][key]
39-
return super().call_getitem(value, key, meta)
40-
41-
def call_delegate(self, lowered_module, args, kwargs, meta):
42-
args_data, _ = pytree.tree_map_only(
43-
ProxyValue, lambda x: x.data, (args, kwargs)
44-
)
45-
meta["spec"] = pytree.tree_map(
46-
self._make_spec,
47-
executorch_call_delegate(lowered_module, *args_data),
48-
)
49-
return super().call_delegate(lowered_module, args, kwargs, meta)
30+
def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
31+
# Forcedly update delegate node's meta['spec'] to get correct output
32+
# tensor size in runtime
33+
call_delegates = [
34+
node
35+
for node in graph_module.graph.nodes
36+
if node.op == "call_function" and node.target == executorch_call_delegate
37+
]
38+
for n in graph_module.graph.nodes:
39+
if QCOM_QUANTIZED_IO in n.meta:
40+
n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
41+
n.meta["spec"] = self._make_spec(n.meta["val"])
42+
43+
for call_delegate in call_delegates:
44+
spec = []
45+
for user in list(call_delegate.users):
46+
spec.append(self._make_spec(user.meta["val"]))
47+
call_delegate.meta["spec"] = tuple(spec)
48+
49+
def call(self, graph_module: torch.fx.GraphModule):
50+
self._build(graph_module)
51+
graph_module.graph.eliminate_dead_code()
52+
graph_module.recompile()
53+
return PassResult(graph_module, True)

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7730,8 +7730,11 @@ def test_llama_stories_110m(self):
77307730
"--max_context_len",
77317731
"128",
77327732
]
7733+
if self.use_fp16:
7734+
cmds.append("--use_fp16")
77337735
self.add_default_cmds(cmds)
7734-
7736+
print(" ".join(cmds))
7737+
exit(0)
77357738
golden_start_with = "Once upon a time,"
77367739
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
77377740
with Listener((self.ip, self.port)) as listener:
@@ -7750,7 +7753,10 @@ def test_llama_stories_110m(self):
77507753
# x86 does not allow weight sharing, so we don't check pte size
77517754
if not self.enable_x86_64:
77527755
pte_size = msg["pte_size"]
7753-
self.assertLessEqual(pte_size, 135_000_000) # 135MB
7756+
if self.use_fp16:
7757+
self.assertLessEqual(pte_size, 275_000_000) # 275MB
7758+
else:
7759+
self.assertLessEqual(pte_size, 135_000_000) # 135MB
77547760
if not self.compile_only and not self.enable_x86_64:
77557761
self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai
77567762

@@ -10087,6 +10093,13 @@ def setup_environment():
1008710093
choices=["wikitext_ppl", "hellaswag_acc_norm", "sqnr"],
1008810094
type=str,
1008910095
)
10096+
parser.add_argument(
10097+
"-F",
10098+
"--use_fp16",
10099+
help="If specified, will run in fp16 precision and discard ptq setting",
10100+
action="store_true",
10101+
default=False,
10102+
)
1009010103

1009110104
args, ns_args = parser.parse_known_args(namespace=unittest)
1009210105
TestQNN.host = args.host
@@ -10114,6 +10127,7 @@ def setup_environment():
1011410127
TestQNN.backend = args.backend
1011510128
TestQNN.static_llm_eval_method = args.static_llm_eval_method
1011610129
TestQNN.direct_build_folder = args.direct_build_folder
10130+
TestQNN.use_fp16 = args.use_fp16
1011710131

1011810132
return sys.argv[:1] + ns_args
1011910133

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -221,6 +221,7 @@ class TestQNN(unittest.TestCase):
221221
static_llm_eval_method = ""
222222
direct_build_folder: str = ""
223223
dsp_heap_profile_filename = "htp_heap_usage.txt"
224+
use_fp16 = False
224225

225226
@classmethod
226227
def setUpClass(cls):
Binary file not shown.

examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _init_runner_base_cmd(self):
133133
base_cmd = " ".join(
134134
[
135135
f"export LD_LIBRARY_PATH={self.qnn_sdk}/lib/x86_64-linux-clang/:{args.build_folder}/lib &&",
136-
f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}",
136+
f"{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}",
137137
f"--decoder_model_version {DECODER_MODEL_VERSION[args.decoder_model]}",
138138
f"--tokenizer_path {self.runtime_tokenizer_path}",
139139
f"--output_path {self.device_output_response_path}",

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -317,13 +317,9 @@ def retrieve_info_from_pte(pte_path: str) -> dict:
317317
pte_max_context_len = pte_max_seq_len
318318

319319
# FP has no scale/zero_point, use following values, which is equivalent to not performing dequantize.
320-
if kv_io_bit_width == 32:
320+
if kv_io_bit_width == 32 or (logits_scale is None or logits_zero_point is None):
321321
logits_scale = 1
322322
logits_zero_point = 0
323-
elif logits_scale is None or logits_zero_point is None:
324-
raise RuntimeError(
325-
"Unable to find scale/offset. The .pte file might be deprecated. Please generate a new .pte file"
326-
)
327323
assert output_vocab_size is not None, "Couldn't find the vocab size"
328324
assert pte_max_seq_len is not None, "Couldn't find the max_seq_len from pte"
329325
meta_info = {

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 55 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
)
2222

2323
from executorch.backends.qualcomm.utils.utils import (
24+
generate_gpu_compiler_spec,
2425
generate_htp_compiler_spec,
2526
generate_qnn_executorch_compiler_spec,
2627
get_soc_to_chipset_map,
@@ -119,9 +120,15 @@ def compile(
119120
# because the encoder is quite sensitive and quantization can make it harder for the model to distinguish
120121
# between images within the same conversation.
121122
to_skip = len(args.image_path) > 1
122-
backend_options = generate_htp_compiler_spec(
123-
use_fp16=to_skip,
124-
)
123+
if args.backend == "htp":
124+
backend_options = generate_htp_compiler_spec(
125+
use_fp16=to_skip,
126+
)
127+
elif args.backend == "gpu":
128+
backend_options = generate_gpu_compiler_spec()
129+
else:
130+
raise ValueError(f"Unsupported backend {args.backend}")
131+
125132
encoder_compile_specs = generate_qnn_executorch_compiler_spec(
126133
soc_model=get_soc_to_chipset_map()[args.soc_model],
127134
backend_options=backend_options,
@@ -131,34 +138,48 @@ def compile(
131138
skip_quantize[modality] = to_skip
132139
compile_specs[modality] = encoder_compile_specs
133140
elif is_multimodal and modality == TOK_EMBEDDING:
134-
backend_options = generate_htp_compiler_spec(
135-
use_fp16=False,
136-
# x86 emulator does not support weight sharing
137-
use_weight_sharing=not args.enable_x86_64,
138-
)
141+
if args.backend == "htp":
142+
backend_options = generate_htp_compiler_spec(
143+
use_fp16=False,
144+
# x86 emulator does not support weight sharing
145+
use_weight_sharing=not args.enable_x86_64,
146+
)
147+
elif args.backend == "gpu":
148+
backend_options = generate_gpu_compiler_spec()
149+
else:
150+
raise ValueError(f"Unsupported backend {args.backend}")
151+
139152
compile_specs[modality] = [
140153
generate_qnn_executorch_compiler_spec(
141154
soc_model=get_soc_to_chipset_map()[args.soc_model],
142155
backend_options=backend_options,
143156
# x86 emulator does not support shared buffer
144157
shared_buffer=not args.enable_x86_64,
158+
online_prepare=args.online_prepare,
145159
)
146160
] * len(TOK_EMBEDDING_GRAPH_NAMES)
147161
elif modality == TEXT_DECODER:
148162
# compile spec for text decoder
149-
backend_options = generate_htp_compiler_spec(
150-
use_fp16=False,
151-
use_multi_contexts=decoder_model_config.num_sharding > 1,
152-
# x86 emulator does not support weight sharing
153-
use_weight_sharing=not args.enable_x86_64,
154-
)
163+
if args.backend == "htp":
164+
backend_options = generate_htp_compiler_spec(
165+
use_fp16=args.use_fp16,
166+
use_multi_contexts=decoder_model_config.num_sharding > 1,
167+
# x86 emulator does not support weight sharing
168+
use_weight_sharing=not args.enable_x86_64,
169+
)
170+
elif args.backend == "gpu":
171+
backend_options = generate_gpu_compiler_spec()
172+
else:
173+
raise ValueError(f"Unsupported backend {args.backend}")
174+
skip_quantize[modality] = args.use_fp16
155175
compile_specs[modality] = [
156176
generate_qnn_executorch_compiler_spec(
157177
soc_model=get_soc_to_chipset_map()[args.soc_model],
158178
backend_options=backend_options,
159179
# x86 emulator does not support shared buffer
160180
shared_buffer=not args.enable_x86_64,
161181
use_mha2sha=True,
182+
online_prepare=args.online_prepare,
162183
)
163184
] * len(DECODER_GRAPH_NAMES)
164185

@@ -172,7 +193,11 @@ def compile(
172193
)
173194

174195
# perform compilation
175-
multi_modal_mgr.compile(compile_specs=compile_specs, pte_filenames=pte_filenames)
196+
multi_modal_mgr.compile(
197+
compile_specs=compile_specs,
198+
pte_filenames=pte_filenames,
199+
skip_quantize=skip_quantize,
200+
)
176201

177202

178203
def inference(
@@ -529,6 +554,14 @@ def _build_parser():
529554
help="Number of examples in few-shot context",
530555
)
531556

557+
parser.add_argument(
558+
"-F",
559+
"--use_fp16",
560+
help="If specified, will run in fp16 precision and discard ptq setting",
561+
action="store_true",
562+
default=False,
563+
)
564+
532565
parser.add_argument("-v", "--verbose", action="store_true")
533566

534567
parser.add_argument(
@@ -592,6 +625,12 @@ def export_llama(args) -> None:
592625
pte_filename = "lookahead_llama_qnn"
593626
else:
594627
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
628+
629+
if args.model_mode == "hybrid" and args.online_prepare:
630+
raise RuntimeError(
631+
"Currently hybrid mode is not compatible with online_prepare."
632+
)
633+
595634
if args.decoder_model == "stories260k":
596635
pte_filename = f"{args.decoder_model}_" + pte_filename
597636
pte_filenames = {
@@ -740,6 +779,7 @@ def export_llama(args) -> None:
740779
def main():
741780
parser = _build_parser()
742781
args = parser.parse_args()
782+
args.build_folder = os.path.realpath(args.build_folder)
743783
try:
744784
export_llama(args)
745785
except Exception as e:

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 3 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,6 @@ std::string get_formatted_prompt(
210210
return formatted_prompt;
211211
}
212212

213-
template <typename T>
214213
void start_runner(
215214
std::unique_ptr<executorch::extension::Module> module,
216215
std::vector<std::string>& prompts,
@@ -219,7 +218,7 @@ void start_runner(
219218
gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false
220219
: true;
221220
// create llama runner
222-
example::Runner<T> runner(
221+
example::Runner runner(
223222
std::move(module),
224223
FLAGS_decoder_model_version.c_str(),
225224
FLAGS_model_path.c_str(),
@@ -298,26 +297,8 @@ int main(int argc, char** argv) {
298297
FLAGS_attention_sink_rope_path.c_str(),
299298
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
300299
}
301-
// Using 8bit as default since this meta is introduced with 16bit kv io
302-
// support and older models only have 8bit kv io.
303-
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
304-
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
305-
kv_bitwidth = static_cast<example::KvBitWidth>(
306-
module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
307-
}
308-
309-
if (kv_bitwidth == example::KvBitWidth::kWidth8) {
310-
start_runner<uint8_t>(
311-
std::move(module), prompts, std::move(attention_sink_rope_module));
312-
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
313-
start_runner<uint16_t>(
314-
std::move(module), prompts, std::move(attention_sink_rope_module));
315-
} else {
316-
ET_CHECK_MSG(
317-
false,
318-
"Unsupported kv bitwidth: %ld",
319-
static_cast<int64_t>(kv_bitwidth));
320-
}
300+
start_runner(
301+
std::move(module), prompts, std::move(attention_sink_rope_module));
321302

322303
return 0;
323304
}

examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp

Lines changed: 7 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,6 @@ std::vector<std::string> CollectPrompts(int argc, char** argv) {
137137
return prompts;
138138
}
139139

140-
template <typename T>
141140
void start_multimodal_runner(
142141
std::unique_ptr<executorch::extension::Module> encoder,
143142
std::unique_ptr<executorch::extension::Module> tok_embedding,
@@ -150,7 +149,7 @@ void start_multimodal_runner(
150149
: true;
151150

152151
// Create multimodal runner
153-
example::QNNMultimodalRunner<T> runner(
152+
example::QNNMultimodalRunner runner(
154153
std::move(encoder),
155154
std::move(tok_embedding),
156155
std::move(text_decoder),
@@ -289,35 +288,12 @@ int main(int argc, char** argv) {
289288
FLAGS_decoder_path.c_str(),
290289
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
291290

292-
// Using 8bit as default since this meta is introduced with 16bit kv io
293-
// support and older models only have 8bit kv io.
294-
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
295-
if (text_decoder->method_names()->count("get_kv_io_bit_width") > 0) {
296-
kv_bitwidth = static_cast<example::KvBitWidth>(
297-
text_decoder->get("get_kv_io_bit_width")
298-
.get()
299-
.toScalar()
300-
.to<int64_t>());
301-
}
302-
// Start runner with appropriate KV bitwidth
303-
if (kv_bitwidth == example::KvBitWidth::kWidth8) {
304-
start_multimodal_runner<uint8_t>(
305-
std::move(encoder),
306-
std::move(tok_embedding),
307-
std::move(text_decoder),
308-
prompts);
309-
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
310-
start_multimodal_runner<uint16_t>(
311-
std::move(encoder),
312-
std::move(tok_embedding),
313-
std::move(text_decoder),
314-
prompts);
315-
} else {
316-
ET_CHECK_MSG(
317-
false,
318-
"Unsupported kv bitwidth: %ld",
319-
static_cast<int64_t>(kv_bitwidth));
320-
}
291+
// Start runner
292+
start_multimodal_runner(
293+
std::move(encoder),
294+
std::move(tok_embedding),
295+
std::move(text_decoder),
296+
prompts);
321297

322298
return 0;
323299
}

0 commit comments

Comments
 (0)