Skip to content

Commit 8422125

Browse files
pytorchbot and Github Executorch authored
Make multimethod generic in llm_config (pytorch#18537)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: pytorch#18213 by @lucylq
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/140/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/140/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/140/orig

Differential Revision: [D96822523](https://our.internmc.facebook.com/intern/diff/D96822523/)
@diff-train-skip-merge

Co-authored-by: Github Executorch <github_executorch@arm.com>
1 parent 93960e5 commit 8422125

3 files changed

Lines changed: 55 additions & 41 deletions

File tree

examples/models/llama/export_llama_lib.py

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -897,15 +897,15 @@ def _validate_args(llm_config):
897897
"Shared embedding is only supported with torchao quantization."
898898
)
899899

900-
if llm_config.multimethod_lora.enabled:
900+
if llm_config.multimethod.enabled:
901901
if llm_config.base.lora_config is not None:
902902
raise ValueError(
903-
"Cannot use both base.lora_config and multimethod_lora.methods. "
904-
"Use multimethod_lora.methods for all LoRA variants."
903+
"Cannot use both base.lora_config and multimethod.methods. "
904+
"Use multimethod.methods for all LoRA variants."
905905
)
906906
if llm_config.quantization.pt2e_quantize is not None:
907907
raise ValueError(
908-
"PT2E quantization is not supported with multimethod_lora export."
908+
"PT2E quantization is not supported with multimethod export."
909909
)
910910
if (
911911
llm_config.backend.coreml.enabled
@@ -915,7 +915,7 @@ def _validate_args(llm_config):
915915
or llm_config.backend.openvino.enabled
916916
):
917917
raise ValueError(
918-
"multimethod_lora export only supports XNNPACK backend or portable ops"
918+
"multimethod export only supports XNNPACK backend or portable ops. "
919919
"Please disable other backends (coreml, vulkan, qnn, mps, openvino)."
920920
)
921921

@@ -1230,7 +1230,7 @@ def _to_edge_and_lower_llama( # noqa: C901
12301230

12311231

12321232
def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List[Partitioner]]:
1233-
"""Get XNNPACK partitioners for multimethod_lora export."""
1233+
"""Get XNNPACK partitioners for multimethod export."""
12341234
partitioners = []
12351235

12361236
# Order matters here, dynamic quantization should be applied first when
@@ -1268,20 +1268,20 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
12681268
"""
12691269
Export multiple methods (base + LoRA variants) to a single .pte file.
12701270
1271-
For each method in llm_config.multimethod_lora.methods:
1271+
For each method in llm_config.multimethod.methods:
12721272
- If LoraConfig is None: use base model
12731273
- If LoraConfig is provided: create model with LoRA weights
12741274
12751275
Limitations:
1276-
- Only XNNPACK backend is supported for multimethod_lora export.
1276+
- Only XNNPACK backend is supported for multimethod export.
12771277
- PT2E quantization is not supported.
12781278
- Each method is exported separately; export time scales linearly
12791279
with the number of methods.
12801280
- The final .pte file deduplicates shared weights automatically.
12811281
"""
1282-
num_methods = len(llm_config.multimethod_lora.methods)
1282+
num_methods = len(llm_config.multimethod.methods)
12831283
logging.info(
1284-
f"multimethod_lora export: exporting {num_methods} method(s). "
1284+
f"multimethod export: exporting {num_methods} method(s). "
12851285
"Each method requires separate model instantiation and export."
12861286
)
12871287

@@ -1293,14 +1293,14 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
12931293
method_to_program: Dict[str, ExportedProgram] = {}
12941294
first_builder = None
12951295

1296-
for method_name, lora_config in llm_config.multimethod_lora.methods.items():
1297-
logging.info(f"Exporting method: {method_name}")
1296+
for method in llm_config.multimethod.methods:
1297+
logging.info(f"Exporting method: {method.method_name}")
12981298

12991299
# Create a copy of config with this method's LoRA setting
13001300
method_config = copy.deepcopy(llm_config)
1301-
method_config.base.lora_config = lora_config
1302-
# Disable multimethod_lora to avoid infinite recursion
1303-
method_config.multimethod_lora.methods = {}
1301+
method_config.base.lora_config = method.lora_config
1302+
# Disable multimethod to avoid infinite recursion
1303+
method_config.multimethod.methods = []
13041304

13051305
# Load and prepare model for this method
13061306
builder = _prepare_for_llama_export(method_config)
@@ -1309,7 +1309,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
13091309

13101310
# Get the exported program
13111311
exported_program = builder._export(builder.pre_autograd_graph_module)
1312-
method_to_program[method_name] = exported_program
1312+
method_to_program[method.method_name] = exported_program
13131313

13141314
if first_builder is None:
13151315
first_builder = builder
@@ -1319,7 +1319,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
13191319
# Get partitioners based on backend config
13201320
partitioners = _get_xnnpack_partitioners(llm_config)
13211321

1322-
# Lower all methods together using multimethod_lora API
1322+
# Lower all methods together using multimethod API
13231323
edge_config = first_builder._get_edge_config()
13241324
edge_manager = to_edge_transform_and_lower(
13251325
method_to_program,
@@ -1333,7 +1333,7 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
13331333
first_builder.edge_manager = edge_manager
13341334
first_builder = first_builder.to_executorch(
13351335
passes=additional_passes,
1336-
share_mutable_buffers=llm_config.multimethod_lora.share_mutable_buffers,
1336+
share_mutable_buffers=llm_config.multimethod.share_mutable_buffers,
13371337
)
13381338

13391339
output_file = _get_output_filename(
@@ -1350,8 +1350,8 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
13501350
def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
13511351
_validate_args(llm_config)
13521352

1353-
# Check for multimethod_lora export
1354-
if llm_config.multimethod_lora.enabled:
1353+
# Check for multimethod export
1354+
if llm_config.multimethod.enabled:
13551355
return _export_llama_multimethod(llm_config)
13561356

13571357
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(

examples/models/qwen3/config/qwen3_multimethod.yaml

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,13 @@ quantization:
2222
qmode: "8da4w"
2323
group_size: 32
2424

25-
multimethod_lora:
25+
multimethod:
2626
methods:
2727
# LoRA method - adapter paths from environment variables
28-
lora_forward:
29-
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
30-
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
28+
- method_name: lora_forward
29+
lora_config:
30+
adapter_checkpoint: ${oc.env:LORA_ADAPTER_CHECKPOINT}
31+
adapter_config: ${oc.env:LORA_ADAPTER_CONFIG}
3132
# Base method - no LoRA
32-
base_forward: null
33+
- method_name: base_forward
3334
share_mutable_buffers: True

extension/llm/export/config/llm_config.py

Lines changed: 29 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
import re
2323
from dataclasses import dataclass, field
2424
from enum import Enum
25-
from typing import ClassVar, Dict, List, Optional
25+
from typing import ClassVar, List, Optional
2626

2727

2828
################################################################################
@@ -293,37 +293,52 @@ class DebugConfig:
293293

294294

295295
################################################################################
296-
############################## MultimethodLoraConfig ###########################
296+
############################## MultimethodConfig ###########################
297297
################################################################################
298298

299299

300300
@dataclass
301-
class MultimethodLoraConfig:
301+
class MethodConfig:
302+
"""Configuration for exporting a single method to a .pte file.
303+
By default, all other fields fall back to the default configs in
304+
the yaml file.
305+
306+
Attributes:
307+
method_name: Name of the method to export.
308+
lora_config: Optional LoRA configuration.
309+
"""
310+
311+
method_name: str
312+
lora_config: Optional[LoraConfig] = None
313+
314+
315+
@dataclass
316+
class MultimethodConfig:
302317
"""Configuration for exporting multiple methods to a single .pte file.
303318
304-
Maps method names to optional LoRA configurations. A None value means
305-
the method uses base model weights.
319+
Holds a list of method configs, as well as global options that apply
320+
across all methods.
306321
307322
Attributes:
308-
methods: Dict mapping method names to optional LoRA configs.
309-
Empty dict disables multimethod_lora export.
323+
methods: List of MethodConfig objects with method name and config
324+
for each method.
310325
share_mutable_buffers: Whether to share mutable buffers across methods.
311326
If True, sets all mutable buffers to mem_id=2. Mutable buffers with
312327
the same FQN (fully qualified name) will have the same offset.
313328
314329
Example:
315-
MultimethodLoraConfig(methods={
316-
"forward": None, # base model
317-
"lora_forward": lora_config, # LoRA variant
318-
})
330+
MultimethodConfig(methods=[
331+
MethodConfig("forward", lora_config=None), # base model
332+
MethodConfig("lora_forward", lora_config=lora_config), # LoRA variant
333+
])
319334
"""
320335

321-
methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict)
336+
methods: List[MethodConfig] = field(default_factory=list)
322337
share_mutable_buffers: bool = False
323338

324339
@property
325340
def enabled(self) -> bool:
326-
"""Returns True if multimethod_lora export is configured."""
341+
"""Returns True if multimethod export is configured."""
327342
return len(self.methods) > 0
328343

329344

@@ -611,9 +626,7 @@ class LlmConfig:
611626
model: ModelConfig = field(default_factory=ModelConfig)
612627
export: ExportConfig = field(default_factory=ExportConfig)
613628
debug: DebugConfig = field(default_factory=DebugConfig)
614-
multimethod_lora: MultimethodLoraConfig = field(
615-
default_factory=MultimethodLoraConfig
616-
)
629+
multimethod: MultimethodConfig = field(default_factory=MultimethodConfig)
617630
quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
618631
backend: BackendConfig = field(default_factory=BackendConfig)
619632

0 commit comments

Comments
 (0)