Commit 951f1a2

use cache
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent 66eb14f commit 951f1a2

File tree

1 file changed: +7 -7 lines changed
  • vllm/model_executor/layers/fused_moe


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 7 deletions
@@ -657,6 +657,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         from flashinfer.fused_moe.core import (
             convert_to_block_layout,
             get_w2_permute_indices_with_cache,
+            _maybe_get_cached_w3_w1_permute_indices,
         )

         # Swap halves to arrange as [w3; w1] (kernel expectation)
@@ -668,25 +669,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Reorder rows of W1 for fused gated activation
         w13_weights_bf16_shuffled = []
         w2_weights_bf16_shuffled = []
-        for i in range(self.moe.num_experts):
-            permute_indices = get_w2_permute_indices_with_cache(
+        for i in range(layer.local_num_experts):
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
                 self._cache_permute_indices,
-                layer.w13_weight.data[i].clone().view(torch.uint8),
+                layer.w13_weight.data[i].view(torch.uint8),
                 epilogue_tile_m,
             )
             tmp_weights1 = (
-                layer.w13_weight.data[i]
+                layer.w13_weight.data[i].clone()
                 .view(torch.uint8)[permute_indices.to(layer.w13_weight.data.device)]
                 .contiguous()
             )

             permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
-                layer.w2_weight.data[i].clone().view(torch.uint8),
+                layer.w2_weight.data[i].view(torch.uint8),
                 epilogue_tile_m,
             )
             tmp_weights2 = (
-                layer.w2_weight.data[i]
+                layer.w2_weight.data[i].clone()
                 .view(torch.uint8)[permute_indices.to(layer.w2_weight.data.device)]
                 .contiguous()
             )
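For context, this hunk switches the w13 lookup to the cached permute-index helper, iterates over the layer's local experts, and moves the clone() out of the cache lookup so the weight is copied only when the permuted tensor is materialized. Below is a minimal, self-contained sketch of that pattern; the _compute_permute_indices stand-in and the plain dict cache are made up for illustration (the real helpers are _maybe_get_cached_w3_w1_permute_indices and get_w2_permute_indices_with_cache in flashinfer.fused_moe.core).

import torch


def _compute_permute_indices(weight: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
    # Stand-in for the real index computation; it simply reverses rows within
    # each epilogue tile so the example runs without flashinfer installed.
    rows = weight.shape[0]
    idx = torch.arange(rows)
    for start in range(0, rows, epilogue_tile_m):
        end = min(start + epilogue_tile_m, rows)
        idx[start:end] = torch.flip(idx[start:end], dims=[0])
    return idx


def maybe_get_cached_permute_indices(cache, weight, epilogue_tile_m):
    # The indices depend only on the weight's shape and the tile size, so one
    # cache entry can be shared by every expert with the same shape.
    key = (tuple(weight.shape), epilogue_tile_m)
    if key not in cache:
        cache[key] = _compute_permute_indices(weight, epilogue_tile_m)
    return cache[key]


cache = {}
num_local_experts, rows, cols = 4, 8, 16
w13 = torch.randn(num_local_experts, rows, cols)

shuffled = []
for i in range(num_local_experts):
    # Look up (or build) the indices without cloning the weight first...
    permute_indices = maybe_get_cached_permute_indices(cache, w13[i], epilogue_tile_m=4)
    # ...and clone only when building the permuted copy, leaving w13 untouched.
    shuffled.append(w13[i].clone()[permute_indices].contiguous())

assert len(cache) == 1  # one shared entry, reused by all local experts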
@@ -1508,7 +1509,6 @@ def __init__(
             )
         else:
             self.routing_method_type = RoutingMethodType.TopK
-
         self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
