@@ -309,6 +309,69 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
309309 }
310310
311311
312+ def _transform_eager_model (model , args , float_dtype ):
313+ """Apply splitting, quantization, and graph breaks to a model."""
314+ model = model .to (float_dtype ).eval ()
315+
316+ if args .target_split_size is not None :
317+ print (f"\n Splitting linear layers with target size { args .target_split_size } ..." )
318+ replace_linear_with_split_linear (
319+ model ,
320+ out_target_split_size = args .target_split_size ,
321+ out_max_splits = args .max_splits ,
322+ in_target_split_size = 1 ,
323+ in_max_splits = 1 ,
324+ )
325+
326+ if args .embedding_quantize :
327+ bitwidth , group_size = args .embedding_quantize .split ("," )
328+ bitwidth = int (bitwidth )
329+ group_size = int (group_size )
330+ assert bitwidth in [4 , 8 ], "CoreML only supports 4-bit and 8-bit quantization"
331+
332+ print (f"\n Quantizing embeddings: { bitwidth } -bit, group_size={ group_size } ..." )
333+ if group_size == 0 :
334+ granularity = PerAxis (0 )
335+ else :
336+ granularity = PerGroup (group_size )
337+ weight_dtype = getattr (torch , f"int{ bitwidth } " )
338+
339+ quantize_ (
340+ model ,
341+ IntxWeightOnlyConfig (weight_dtype = weight_dtype , granularity = granularity ),
342+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
343+ )
344+
345+ if args .linear_quantize == "b4w" :
346+ print ("\n Quantizing linear layers: 4-bit blockwise (group_size=32)..." )
347+ quantize_ (
348+ model ,
349+ IntxWeightOnlyConfig (
350+ weight_dtype = torch .int4 ,
351+ granularity = PerGroup (32 ),
352+ ),
353+ )
354+ elif args .linear_quantize == "c4w" :
355+ print ("\n Quantizing linear layers: 4-bit channelwise..." )
356+ quantize_ (
357+ model ,
358+ IntxWeightOnlyConfig (
359+ weight_dtype = torch .int4 ,
360+ granularity = PerAxis (0 ),
361+ ),
362+ )
363+
364+ if not args .no_graph_breaks :
365+ print ("\n Adding graph breaks between before/after the transformer blocks..." )
366+ n_layers = len (model .layers )
367+ model .layers [0 ] = BlockWithGraphBreak (model .layers [0 ], break_before = True )
368+ model .layers [n_layers - 1 ] = BlockWithGraphBreak (
369+ model .layers [n_layers - 1 ], break_before = False
370+ )
371+
372+ return model
373+
374+
312375def main ():
313376 parser = argparse .ArgumentParser (
314377 description = "Export static attention Llama model to CoreML"
@@ -441,70 +504,8 @@ def main():
441504 )
442505 print (f"Model loaded: { model_args .n_layers } layers, { model_args .dim } dim" )
443506
444- # Set dtype
445507 float_dtype = {"fp16" : torch .float16 , "fp32" : torch .float32 }[args .dtype ]
446- model = model .to (float_dtype ).eval ()
447-
448- # Apply linear splitting (before quantization)
449- if args .target_split_size is not None :
450- print (f"\n Splitting linear layers with target size { args .target_split_size } ..." )
451- replace_linear_with_split_linear (
452- model ,
453- out_target_split_size = args .target_split_size ,
454- out_max_splits = args .max_splits ,
455- in_target_split_size = 1 ,
456- in_max_splits = 1 ,
457- )
458-
459- # Apply embedding quantization
460- if args .embedding_quantize :
461- bitwidth , group_size = args .embedding_quantize .split ("," )
462- bitwidth = int (bitwidth )
463- group_size = int (group_size )
464- assert bitwidth in [4 , 8 ], "CoreML only supports 4-bit and 8-bit quantization"
465-
466- print (f"\n Quantizing embeddings: { bitwidth } -bit, group_size={ group_size } ..." )
467- if group_size == 0 :
468- granularity = PerAxis (0 )
469- else :
470- granularity = PerGroup (group_size )
471- weight_dtype = getattr (torch , f"int{ bitwidth } " )
472-
473- quantize_ (
474- model ,
475- IntxWeightOnlyConfig (weight_dtype = weight_dtype , granularity = granularity ),
476- lambda m , fqn : isinstance (m , torch .nn .Embedding ),
477- )
478-
479- # Apply linear quantization
480- if args .linear_quantize == "b4w" :
481- print ("\n Quantizing linear layers: 4-bit blockwise (group_size=32)..." )
482- quantize_ (
483- model ,
484- IntxWeightOnlyConfig (
485- weight_dtype = torch .int4 ,
486- granularity = PerGroup (32 ),
487- ),
488- )
489- elif args .linear_quantize == "c4w" :
490- print ("\n Quantizing linear layers: 4-bit channelwise..." )
491- quantize_ (
492- model ,
493- IntxWeightOnlyConfig (
494- weight_dtype = torch .int4 ,
495- granularity = PerAxis (0 ),
496- ),
497- )
498-
499- # Add graph breaks between transformer blocks
500- # Keeping model pieces smaller helps with ANE performance
501- if not args .no_graph_breaks :
502- print ("\n Adding graph breaks between before/after the transformer blocks..." )
503- n_layers = len (model .layers )
504- model .layers [0 ] = BlockWithGraphBreak (model .layers [0 ], break_before = True )
505- model .layers [n_layers - 1 ] = BlockWithGraphBreak (
506- model .layers [n_layers - 1 ], break_before = False
507- )
508+ model = _transform_eager_model (model , args , float_dtype )
508509
509510 if args .multifunction :
510511 # Multifunction mode: separate prefill and decode graphs with weight sharing
0 commit comments