@@ -129,7 +129,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -506,7 +506,7 @@ def apply_moe_ep_tp(
     )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
@@ -577,6 +577,22 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
                 fullgraph=True,
             )
 
+            if ep_enabled:
+                compiled_fn = moe_module._run_experts_grouped_mm
+
+                def _run_experts_grouped_mm_dynamic(
+                    w1: torch.Tensor,
+                    w2: torch.Tensor,
+                    w3: torch.Tensor,
+                    x: torch.Tensor,
+                    num_tokens_per_expert: torch.Tensor,
+                ) -> torch.Tensor:
+                    # dynamic number of tokens in expert parallel
+                    torch._dynamo.mark_dynamic(x, 0)
+                    return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+                moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+
             # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
             # https://github.com/pytorch/pytorch/issues/166460
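
For context, a minimal standalone sketch (not part of the diff above) of the idea the wrapper relies on: `torch._dynamo.mark_dynamic` marks a tensor dimension as dynamic before the compiled function sees it, so torch.compile keeps one graph instead of recompiling for every new token count routed to the experts. The `experts_fwd` function and shapes below are illustrative placeholders, not torchtitan code.

```python
import torch

# Illustrative stand-in for a compiled expert computation; the real change
# wraps moe_module._run_experts_grouped_mm instead.
@torch.compile(fullgraph=True)
def experts_fwd(w1: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return x @ w1

def experts_fwd_dynamic(w1: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Under expert parallelism the number of routed tokens (dim 0 of x)
    # changes every step, so mark it dynamic to avoid per-shape recompiles.
    torch._dynamo.mark_dynamic(x, 0)
    return experts_fwd(w1, x)

w1 = torch.randn(64, 64)
for num_tokens in (8, 16, 32):
    # all three token counts reuse the same compiled graph
    out = experts_fwd_dynamic(w1, torch.randn(num_tokens, 64))
```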