@@ -399,71 +399,54 @@ def forward(self, x, input_pos):
399399
400400
401401# ---------------------------------------------------------------------------
402- # MoE: stacked expert weights + index by top-k
402+ # MoE: expert weights for fused MoE Triton kernel
403403
404- # 16 experts per group keeps each nn.Linear under ~32K output features,
405- # within tinygemm int4 packing limits while keeping the graph small
406- # (32 matmul nodes per layer instead of 768 with per-expert linears).
407- _EXPERTS_PER_GROUP = 16
408404
405+ class FusedMoEExperts (nn .Module ):
406+ """Expert weights stored as stacked tensors for the fused MoE Triton kernel.
409407
410- class ConditionalFeedForward ( nn . Module ):
411- """Grouped expert weights as nn.Linear for quantization compatibility .
408+ Before quantization: w1_weight [E, 2*inter, hidden] and w2_weight [E, hidden, inter]
409+ are nn.Parameter tensors loaded from the checkpoint .
412410
413- Experts are split into groups of _EXPERTS_PER_GROUP. Each group has:
414- gate_up_projs[g]: nn.Linear(hidden_size, G * intermediate_size * 2)
415- down_projs[g]: nn.Linear(intermediate_size, G * hidden_size)
416- This keeps each nn.Linear small enough for tinygemm int4 packing while
417- allowing quantize_model_() to handle them automatically.
411+ After quantization (in export.py): replaced with packed INT4 buffers
412+ w1 [E, 2*inter, hidden//2], w1_scale, w2 [E, hidden, inter//2], w2_scale.
418413 """
419414
420- def __init__ (self , hidden_size , intermediate_size , num_experts ):
415+ def __init__ (self , config ):
421416 super ().__init__ ()
422- self .num_experts = num_experts
423- self .intermediate_size = intermediate_size
424- self .hidden_size = hidden_size
425- G = _EXPERTS_PER_GROUP
426- assert num_experts % G == 0
427- num_groups = num_experts // G
428-
429- self .gate_up_projs = nn .ModuleList (
430- [
431- nn .Linear (hidden_size , G * intermediate_size * 2 , bias = False )
432- for _ in range (num_groups )
433- ]
417+ self .num_experts = config .num_experts
418+ self .intermediate_size = config .moe_intermediate_size
419+ self .hidden_size = config .hidden_size
420+ self .group_size = 32
421+
422+ self .w1_weight = nn .Parameter (
423+ torch .empty (
424+ config .num_experts ,
425+ 2 * config .moe_intermediate_size ,
426+ config .hidden_size ,
427+ )
434428 )
435- self .down_projs = nn .ModuleList (
436- [
437- nn .Linear (intermediate_size , G * hidden_size , bias = False )
438- for _ in range (num_groups )
439- ]
429+ self .w2_weight = nn .Parameter (
430+ torch .empty (
431+ config .num_experts ,
432+ config .hidden_size ,
433+ config .moe_intermediate_size ,
434+ )
440435 )
441436
442- def forward (self , x , expert_indices ):
443- # x: (T, D), expert_indices: (T, top_k)
444- T = x .size (0 )
445- top_k = expert_indices .size (1 )
446- G = _EXPERTS_PER_GROUP
447- H = self .intermediate_size
448- D = self .hidden_size
449-
450- # Gate + Up: compute per-group, cat, gather top-k
451- gate_up_parts = [proj (x ).view (T , G , 2 , H ) for proj in self .gate_up_projs ]
452- gate_up = torch .cat (gate_up_parts , dim = 1 ) # (T, E, 2, H)
453-
454- idx = expert_indices .unsqueeze (- 1 ).unsqueeze (- 1 ).expand (- 1 , - 1 , 2 , H )
455- gate_up_sel = gate_up .gather (1 , idx ) # (T, top_k, 2, H)
456- intermediate = F .silu (gate_up_sel [:, :, 0 , :]) * gate_up_sel [:, :, 1 , :]
457-
458- # Down: compute per-group, cat, gather correct expert per slot
459- intermediate_flat = intermediate .reshape (T * top_k , H )
460- down_parts = [
461- proj (intermediate_flat ).view (T , top_k , G , D ) for proj in self .down_projs
462- ]
463- all_down = torch .cat (down_parts , dim = 2 ) # (T, top_k, E, D)
464-
465- eidx = expert_indices .unsqueeze (- 1 ).unsqueeze (- 1 ).expand (- 1 , - 1 , 1 , D )
466- return all_down .gather (2 , eidx ).squeeze (2 ) # (T, top_k, D)
437+ def forward (self , x , expert_weights , expert_indices , top_k ):
438+ return torch .ops .triton .fused_moe (
439+ x ,
440+ self .w1 ,
441+ self .w1_scale ,
442+ self .w2 ,
443+ self .w2_scale ,
444+ expert_weights ,
445+ expert_indices ,
446+ top_k ,
447+ self .num_experts ,
448+ self .group_size ,
449+ )
467450
468451
469452class SwiGLU (nn .Module ):
@@ -484,12 +467,9 @@ class SparseMoE(nn.Module):
484467 def __init__ (self , config ):
485468 super ().__init__ ()
486469 self .top_k = config .num_experts_per_tok
470+ self .num_experts = config .num_experts
487471 self .gate = nn .Linear (config .hidden_size , config .num_experts , bias = False )
488- self .cond_ffn = ConditionalFeedForward (
489- config .hidden_size ,
490- config .moe_intermediate_size ,
491- config .num_experts ,
492- )
472+ self .experts = FusedMoEExperts (config )
493473 self .shared_expert = SwiGLU (
494474 config .hidden_size , config .shared_expert_intermediate_size
495475 )
@@ -503,8 +483,9 @@ def forward(self, x):
503483 expert_weights , expert_indices = torch .topk (scores , self .top_k , dim = - 1 )
504484 expert_weights = expert_weights .softmax (dim = - 1 )
505485
506- expert_outs = self .cond_ffn (x_flat , expert_indices )
507- routed_out = torch .einsum ("tai,ta->ti" , expert_outs , expert_weights )
486+ routed_out = self .experts (
487+ x_flat , expert_weights .float (), expert_indices , self .top_k
488+ )
508489
509490 shared_out = self .shared_expert (x_flat )
510491 shared_gate = torch .sigmoid (self .shared_expert_gate (x_flat ))
@@ -641,9 +622,8 @@ def _load_and_remap_checkpoint(model_dir, config):
641622 expert_weights ,
642623 )
643624
644- # Stack per-expert weights, split into groups, reshape for nn.Linear
625+ # Stack per-expert weights into [E, N, K] tensors for FusedMoEExperts
645626 if expert_weights :
646- G = _EXPERTS_PER_GROUP
647627 for layer_idx in range (config .num_hidden_layers ):
648628 gate_list = [
649629 expert_weights .get ((layer_idx , "gate" , e ))
@@ -661,21 +641,13 @@ def _load_and_remap_checkpoint(model_dir, config):
661641 if gate_list [0 ] is not None :
662642 w_gate = torch .stack (gate_list , dim = 0 ) # (E, H, D)
663643 w_up = torch .stack (up_list , dim = 0 )
664- fused = torch .cat ([w_gate , w_up ], dim = 1 ) # (E, 2*H, D)
665- num_groups = config .num_experts // G
666- for g in range (num_groups ):
667- chunk = fused [g * G : (g + 1 ) * G ]
668- state_dict [
669- f"layers.{ layer_idx } .mlp.cond_ffn.gate_up_projs.{ g } .weight"
670- ] = chunk .reshape (- 1 , chunk .size (- 1 ))
644+ state_dict [f"layers.{ layer_idx } .mlp.experts.w1_weight" ] = torch .cat (
645+ [w_gate , w_up ], dim = 1
646+ ) # (E, 2*H, D)
671647 if down_list [0 ] is not None :
672- w_down = torch .stack (down_list , dim = 0 ) # (E, D, H)
673- num_groups = config .num_experts // G
674- for g in range (num_groups ):
675- chunk = w_down [g * G : (g + 1 ) * G ]
676- state_dict [
677- f"layers.{ layer_idx } .mlp.cond_ffn.down_projs.{ g } .weight"
678- ] = chunk .reshape (- 1 , chunk .size (- 1 ))
648+ state_dict [f"layers.{ layer_idx } .mlp.experts.w2_weight" ] = torch .stack (
649+ down_list , dim = 0
650+ ) # (E, D, H)
679651 del expert_weights
680652
681653 # Handle tied embeddings
@@ -697,27 +669,15 @@ def _process_checkpoint_key(ckpt_key, tensor, state_dict, expert_weights):
697669 if norm_key .startswith (("model.visual." , "model.mtp_" )):
698670 return
699671
700- # Fused expert weights: split into groups of _EXPERTS_PER_GROUP
672+ # Fused expert weights: store directly as [E, N, K] for FusedMoEExperts
701673 m = _FUSED_EXPERT_RE .match (norm_key )
702674 if m :
703675 layer_idx = int (m .group (1 ))
704676 proj_name = m .group (2 )
705- G = _EXPERTS_PER_GROUP
706- num_groups = tensor .size (0 ) // G
707677 if proj_name == "gate_up_proj" :
708- # (E, 2*H, D) → groups of (G, 2*H, D) → each (G*2*H, D)
709- for g in range (num_groups ):
710- chunk = tensor [g * G : (g + 1 ) * G ]
711- state_dict [
712- f"layers.{ layer_idx } .mlp.cond_ffn.gate_up_projs.{ g } .weight"
713- ] = chunk .reshape (- 1 , chunk .size (- 1 )).contiguous ()
678+ state_dict [f"layers.{ layer_idx } .mlp.experts.w1_weight" ] = tensor
714679 else :
715- # down_proj: (E, D, H) → groups of (G, D, H) → each (G*D, H)
716- for g in range (num_groups ):
717- chunk = tensor [g * G : (g + 1 ) * G ]
718- state_dict [f"layers.{ layer_idx } .mlp.cond_ffn.down_projs.{ g } .weight" ] = (
719- chunk .reshape (- 1 , chunk .size (- 1 )).contiguous ()
720- )
680+ state_dict [f"layers.{ layer_idx } .mlp.experts.w2_weight" ] = tensor
721681 return
722682
723683 # Per-expert weights
0 commit comments