
Commit 3c5d17b

Mark input tokens to routed experts as dynamic to avoid a recompile
1 parent 61c25f8 commit 3c5d17b

File tree: 3 files changed, +20 -4 lines changed


torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
     )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 18 additions & 2 deletions

@@ -129,7 +129,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

@@ -506,7 +506,7 @@ def apply_moe_ep_tp(
     )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).

@@ -577,6 +577,22 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
                 fullgraph=True,
             )
 
+            if ep_enabled:
+                compiled_fn = moe_module._run_experts_grouped_mm
+
+                def _run_experts_grouped_mm_dynamic(
+                    w1: torch.Tensor,
+                    w2: torch.Tensor,
+                    w3: torch.Tensor,
+                    x: torch.Tensor,
+                    num_tokens_per_expert: torch.Tensor,
+                ) -> torch.Tensor:
+                    # dynamic number of tokens in expert parallel
+                    torch._dynamo.mark_dynamic(x, 0)
+                    return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+                moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+
             # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
             # https://github.com/pytorch/pytorch/issues/166460

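The llama4 hunk above carries the substantive change: when expert parallelism is enabled, the compiled `_run_experts_grouped_mm` is wrapped so the routed-token input `x` is marked dynamic along dim 0 on every call. A minimal, self-contained sketch of the same idea follows; the function name `grouped_mm`, the weight shapes, and the loop are illustrative stand-ins, not code from this commit.

# Sketch: with expert parallelism the number of tokens routed to the local
# experts varies from step to step, so a compiled function that specialized
# on x.shape[0] could recompile whenever that size changes. Marking dim 0 as
# dynamic before the call asks Dynamo for a size-generic graph up front.
import torch


@torch.compile(fullgraph=True)
def grouped_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real grouped expert matmul.
    return x @ w


w = torch.randn(64, 64)
for num_tokens in (8, 16, 32):  # varying routed-token counts
    x = torch.randn(num_tokens, 64)
    # Mirrors the wrapper added in apply_compile: mark the token dimension
    # dynamic on the input before invoking the compiled function.
    torch._dynamo.mark_dynamic(x, 0)
    grouped_mm(x, w)

Without the marking, torch.compile's automatic dynamic-shape handling typically only makes a dimension dynamic after it has observed a second distinct size, which costs at least one extra recompile; marking it up front avoids that.
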
torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel

0 commit comments
