Skip to content

Commit b24b6ab

Browse files
authored
Refactor: extract _transform_eager_model() from CoreML export main() (pytorch#18343)
Extract the model preparation logic (dtype conversion, linear splitting, quantization, graph breaks) into a reusable _transform_eager_model() helper. This enables applying the same transformations consistently to multiple models in a follow-up change. No functional change — pure refactor. Generated with Claude.
1 parent 36a1952 commit b24b6ab

1 file changed

Lines changed: 64 additions & 63 deletions

File tree

examples/apple/coreml/llama/export_static_llm_coreml.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,69 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
309309
}
310310

311311

312+
def _transform_eager_model(model, args, float_dtype):
313+
"""Apply splitting, quantization, and graph breaks to a model."""
314+
model = model.to(float_dtype).eval()
315+
316+
if args.target_split_size is not None:
317+
print(f"\nSplitting linear layers with target size {args.target_split_size}...")
318+
replace_linear_with_split_linear(
319+
model,
320+
out_target_split_size=args.target_split_size,
321+
out_max_splits=args.max_splits,
322+
in_target_split_size=1,
323+
in_max_splits=1,
324+
)
325+
326+
if args.embedding_quantize:
327+
bitwidth, group_size = args.embedding_quantize.split(",")
328+
bitwidth = int(bitwidth)
329+
group_size = int(group_size)
330+
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
331+
332+
print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...")
333+
if group_size == 0:
334+
granularity = PerAxis(0)
335+
else:
336+
granularity = PerGroup(group_size)
337+
weight_dtype = getattr(torch, f"int{bitwidth}")
338+
339+
quantize_(
340+
model,
341+
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
342+
lambda m, fqn: isinstance(m, torch.nn.Embedding),
343+
)
344+
345+
if args.linear_quantize == "b4w":
346+
print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...")
347+
quantize_(
348+
model,
349+
IntxWeightOnlyConfig(
350+
weight_dtype=torch.int4,
351+
granularity=PerGroup(32),
352+
),
353+
)
354+
elif args.linear_quantize == "c4w":
355+
print("\nQuantizing linear layers: 4-bit channelwise...")
356+
quantize_(
357+
model,
358+
IntxWeightOnlyConfig(
359+
weight_dtype=torch.int4,
360+
granularity=PerAxis(0),
361+
),
362+
)
363+
364+
if not args.no_graph_breaks:
365+
print("\nAdding graph breaks between before/after the transformer blocks...")
366+
n_layers = len(model.layers)
367+
model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True)
368+
model.layers[n_layers - 1] = BlockWithGraphBreak(
369+
model.layers[n_layers - 1], break_before=False
370+
)
371+
372+
return model
373+
374+
312375
def main():
313376
parser = argparse.ArgumentParser(
314377
description="Export static attention Llama model to CoreML"
@@ -441,70 +504,8 @@ def main():
441504
)
442505
print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim")
443506

444-
# Set dtype
445507
float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
446-
model = model.to(float_dtype).eval()
447-
448-
# Apply linear splitting (before quantization)
449-
if args.target_split_size is not None:
450-
print(f"\nSplitting linear layers with target size {args.target_split_size}...")
451-
replace_linear_with_split_linear(
452-
model,
453-
out_target_split_size=args.target_split_size,
454-
out_max_splits=args.max_splits,
455-
in_target_split_size=1,
456-
in_max_splits=1,
457-
)
458-
459-
# Apply embedding quantization
460-
if args.embedding_quantize:
461-
bitwidth, group_size = args.embedding_quantize.split(",")
462-
bitwidth = int(bitwidth)
463-
group_size = int(group_size)
464-
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
465-
466-
print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...")
467-
if group_size == 0:
468-
granularity = PerAxis(0)
469-
else:
470-
granularity = PerGroup(group_size)
471-
weight_dtype = getattr(torch, f"int{bitwidth}")
472-
473-
quantize_(
474-
model,
475-
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
476-
lambda m, fqn: isinstance(m, torch.nn.Embedding),
477-
)
478-
479-
# Apply linear quantization
480-
if args.linear_quantize == "b4w":
481-
print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...")
482-
quantize_(
483-
model,
484-
IntxWeightOnlyConfig(
485-
weight_dtype=torch.int4,
486-
granularity=PerGroup(32),
487-
),
488-
)
489-
elif args.linear_quantize == "c4w":
490-
print("\nQuantizing linear layers: 4-bit channelwise...")
491-
quantize_(
492-
model,
493-
IntxWeightOnlyConfig(
494-
weight_dtype=torch.int4,
495-
granularity=PerAxis(0),
496-
),
497-
)
498-
499-
# Add graph breaks between transformer blocks
500-
# Keeping model pieces smaller helps with ANE performance
501-
if not args.no_graph_breaks:
502-
print("\nAdding graph breaks between before/after the transformer blocks...")
503-
n_layers = len(model.layers)
504-
model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True)
505-
model.layers[n_layers - 1] = BlockWithGraphBreak(
506-
model.layers[n_layers - 1], break_before=False
507-
)
508+
model = _transform_eager_model(model, args, float_dtype)
508509

509510
if args.multifunction:
510511
# Multifunction mode: separate prefill and decode graphs with weight sharing

0 commit comments

Comments
 (0)