@@ -657,6 +657,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
657657 from flashinfer .fused_moe .core import (
658658 convert_to_block_layout ,
659659 get_w2_permute_indices_with_cache ,
660+ _maybe_get_cached_w3_w1_permute_indices ,
660661 )
661662
662663 # Swap halves to arrange as [w3; w1] (kernel expectation)
@@ -668,25 +669,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
668669 # Reorder rows of W1 for fused gated activation
669670 w13_weights_bf16_shuffled = []
670671 w2_weights_bf16_shuffled = []
671- for i in range (self . moe . num_experts ):
672- permute_indices = get_w2_permute_indices_with_cache (
672+ for i in range (layer . local_num_experts ):
673+ permute_indices = _maybe_get_cached_w3_w1_permute_indices (
673674 self ._cache_permute_indices ,
674- layer .w13_weight .data [i ].clone (). view (torch .uint8 ),
675+ layer .w13_weight .data [i ].view (torch .uint8 ),
675676 epilogue_tile_m ,
676677 )
677678 tmp_weights1 = (
678- layer .w13_weight .data [i ]
679+ layer .w13_weight .data [i ]. clone ()
679680 .view (torch .uint8 )[permute_indices .to (layer .w13_weight .data .device )]
680681 .contiguous ()
681682 )
682683
683684 permute_indices = get_w2_permute_indices_with_cache (
684685 self ._cache_permute_indices ,
685- layer .w2_weight .data [i ].clone (). view (torch .uint8 ),
686+ layer .w2_weight .data [i ].view (torch .uint8 ),
686687 epilogue_tile_m ,
687688 )
688689 tmp_weights2 = (
689- layer .w2_weight .data [i ]
690+ layer .w2_weight .data [i ]. clone ()
690691 .view (torch .uint8 )[permute_indices .to (layer .w2_weight .data .device )]
691692 .contiguous ()
692693 )
@@ -1508,7 +1509,6 @@ def __init__(
15081509 )
15091510 else :
15101511 self .routing_method_type = RoutingMethodType .TopK
1511-
15121512 self .moe_config : FusedMoEConfig = FusedMoEConfig (
15131513 num_experts = self .global_num_experts ,
15141514 experts_per_token = top_k ,
0 commit comments