Commit 1fdba5c

add configs to specify passes in compiler

1 parent 4caa379 · commit 1fdba5c

8 files changed: +248 −66 lines
torchtitan/experiments/compiler_toolkit/README.md (11 additions, 0 deletions)

@@ -29,7 +29,18 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom
 NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
 ```

+**SimpleFSDP + TP + auto-bucketing**
+```shell
+NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering
+```
+
 **SimpleFSDP + TP + FlexAttention**
 ```shell
 NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
 ```
+
+**SimpleFSDP + TP + FlexAttention + auto-bucketing + regional-inductor**
+
+```shell
+NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
+```

torchtitan/experiments/compiler_toolkit/common_utils.py (10 additions, 0 deletions)

@@ -53,3 +53,13 @@ def register_blockmask_pytree_node():
         flatten_with_keys_fn=BlockMask._flatten_with_keys,
         serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
     )
+
+
+def validate_flex_attention_annotation(joint_with_descriptors):
+    """Verify user annotations show up in the graph."""
+    for node in joint_with_descriptors.graph_module.graph.nodes:
+        if node.target in {
+            torch.ops.higher_order.flex_attention,
+            torch.ops.higher_order.flex_attention_backward,
+        }:
+            assert "compile_with_inductor" in node.meta.get("custom", {})

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py (21 additions, 22 deletions)

@@ -17,37 +17,19 @@
     disable_compile,
     parallelize_inputs,
     register_blockmask_pytree_node,
+    validate_flex_attention_annotation,
 )

 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
+    get_compiler_passes_from_config,
     joint_graph_builder,
+    make_compiler_with_passes,
 )

 from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import (
     parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3,
 )
-from torchtitan.tools.logging import logger
-
-
-def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):
-    logger.info(f"{name} before compiler:")
-    logger.info(gm.print_readable(print_output=False))
-
-    # TODO: regional_inductor should work with deepseek_v3
-    # gm = regional_inductor(gm, example_inputs)
-
-    logger.info(f"{name} after compiler:")
-    logger.info(gm.print_readable(print_output=False))
-    return gm
-
-
-def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("fwd_gm", gm, example_inputs)
-
-
-def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("bwd_gm", gm, example_inputs)


 def annotate_deepseekv3() -> None:
@@ -75,7 +57,17 @@ def parallelize_deepseekv3(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ) -> CompiledModule:
+    """
+    Parallelize and compile a DeepSeek v3 model with optional custom compiler passes.
+
+    Args:
+        model: The model to parallelize
+        parallel_dims: Parallel dimensions configuration
+        job_config: Job configuration

+    Returns:
+        CompiledModule wrapping the parallelized and compiled model
+    """
     annotate_deepseekv3()

     register_blockmask_pytree_node()
@@ -84,11 +76,18 @@ def parallelize_deepseekv3(
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)

+    # Get compiler passes from config
+    compiler_passes = get_compiler_passes_from_config(job_config)
+
+    # Create compilers with specified passes (defaults to no passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+
+    # Create custom joint_graph_builder with deepseekv3-specific compilers
     deepseekv3_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
-        joint_custom_pass=None,
+        joint_custom_pass=validate_flex_attention_annotation,
     )

     # TODO: CompiledModule should take sample input as well, so that we can

torchtitan/experiments/compiler_toolkit/graph_utils.py (87 additions, 1 deletion)

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
-from typing import Callable, Optional
+from typing import Callable, List, Optional

 import torch
 from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -16,6 +16,7 @@
 )
 from torch._guards import tracing, TracingContext
 from torch.distributed.tensor import DTensor
+from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger

@@ -180,3 +181,88 @@ def forward(self, *args, **kwargs):
         # calling the line below returns control to torchtitan's runner
         # letting it call the backward, and optimizer.
         return self.joint_graph_module(args, kwargs)
+
+
+# Default compiler pass configuration - no passes by default
+DEFAULT_COMPILER_PASSES = []
+
+
+def compiler(
+    name: str,
+    gm: torch.fx.GraphModule,
+    example_inputs,
+    passes: List[Callable] = None,
+):
+    """
+    Compile a graph module by applying a sequence of compiler passes.
+
+    Args:
+        name: Name for logging purposes
+        gm: The graph module to compile
+        example_inputs: Example inputs for the graph module
+        passes: List of compiler pass functions to apply. Each function should take
+            (gm, example_inputs) and return a transformed gm. If None, uses
+            DEFAULT_COMPILER_PASSES.
+    """
+    if passes is None:
+        passes = DEFAULT_COMPILER_PASSES
+
+    logger.info(f"{name} before compiler:")
+    logger.info(gm.print_readable(print_output=False))
+
+    for pass_fn in passes:
+        logger.info(f"Applying pass: {pass_fn.__name__}")
+        gm = pass_fn(gm, example_inputs)
+
+    logger.info(f"{name} after compiler:")
+    logger.info(gm.print_readable(print_output=False))
+    return gm
+
+
+def make_compiler_with_passes(passes: List[Callable] = None):
+    """
+    Create forward and backward compilers with specified passes.
+
+    Args:
+        passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES.
+
+    Returns:
+        Tuple of (fw_compiler, bw_compiler) functions
+    """
+
+    def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
+        return compiler("fwd_gm", gm, example_inputs, passes=passes)
+
+    def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
+        return compiler("bwd_gm", gm, example_inputs, passes=passes)
+
+    return fw_compiler, bw_compiler
+
+
+def get_compiler_passes_from_config(job_config: JobConfig):
+    """
+    Extract and validate compiler passes from job config.
+
+    Args:
+        job_config: Job configuration containing compile.passes
+
+    Returns:
+        List of compiler pass functions
+    """
+    from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES
+
+    pass_names = getattr(job_config.compile, "passes", [])
+    compiler_passes = []
+
+    for pass_name in pass_names:
+        if pass_name not in AVAILABLE_PASSES:
+            raise ValueError(
+                f"Unknown compiler pass: {pass_name}. "
+                f"Available passes: {list(AVAILABLE_PASSES.keys())}"
+            )
+        compiler_passes.append(AVAILABLE_PASSES[pass_name])
+
+    if pass_names:
+        logger.info(f"Using compiler passes from config: {pass_names}")
+
+    return compiler_passes
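For reference, a minimal sketch of the calling convention these helpers expect, mirroring what the llama3 and deepseek_v3 `parallelize` functions below do; the no-op pass and the `my_joint_graph_builder` name are illustrative only and not part of this commit.

```python
# Minimal sketch of the (gm, example_inputs) -> gm pass contract and of wiring
# the resulting compilers into joint_graph_builder.
import functools

import torch

from torchtitan.experiments.compiler_toolkit.graph_utils import (
    joint_graph_builder,
    make_compiler_with_passes,
)


def noop_pass(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    # A pass takes (gm, example_inputs) and returns a (possibly transformed) gm.
    return gm


# Compilers that apply the pass to both forward and backward graphs.
fw_compiler, bw_compiler = make_compiler_with_passes([noop_pass])

# Model-specific joint graph builder, as the parallelize functions construct it.
my_joint_graph_builder = functools.partial(
    joint_graph_builder,
    fw_compiler=fw_compiler,
    bw_compiler=bw_compiler,
)
```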
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass, field
8+
9+
10+
@dataclass
11+
class Compile:
12+
"""
13+
List of compiler pass names to apply in the compiler toolkit workflow.
14+
By default, no passes are applied.
15+
Example: --compile.passes autobucketing_reordering,regional_inductor
16+
"""
17+
18+
passes: list[str] = field(default_factory=list)
19+
20+
21+
@dataclass
22+
class JobConfig:
23+
compile: Compile = field(default_factory=Compile)
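This custom `JobConfig` extension is activated via `--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config` (see the README above), after which `job_config.compile.passes` is the list that `get_compiler_passes_from_config` reads. A small sketch of the resulting defaults; constructing the extension dataclasses directly here is for illustration only, since torchtitan normally merges them into its own job config.

```python
# Illustration only: default and explicit values of the Compile extension.
from torchtitan.experiments.compiler_toolkit.job_config import Compile, JobConfig

cfg = JobConfig()
assert cfg.compile.passes == []  # no passes by default

cfg = JobConfig(
    compile=Compile(passes=["autobucketing_reordering", "regional_inductor"])
)
# get_compiler_passes_from_config(cfg) would map each name through AVAILABLE_PASSES.
```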

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py (19 additions, 43 deletions)

@@ -8,9 +8,6 @@
 import functools

 import torch
-from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
-
-from torch.fx.passes.regional_inductor import regional_inductor
 from torch.fx.traceback import annotate_fn

 from torchtitan.config import JobConfig
@@ -19,56 +16,19 @@
     disable_compile,
     parallelize_inputs,
     register_blockmask_pytree_node,
+    validate_flex_attention_annotation,
 )

 from torchtitan.experiments.compiler_toolkit.graph_utils import (
     CompiledModule,
+    get_compiler_passes_from_config,
     joint_graph_builder,
+    make_compiler_with_passes,
 )
 from torchtitan.experiments.simple_fsdp.llama3.parallelize import (
     parallelize_llama as simple_fsdp_parallelize_llama,
 )

-from torchtitan.tools.logging import logger
-
-
-# TODO: support passing configs into schedule_overlap_bucketing
-def autobucketing_reordering_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    schedule_overlap_bucketing(gm, collective_bucketing=True)
-    gm.recompile()
-    return gm
-
-
-def compiler(name: str, gm: torch.fx.GraphModule, example_inputs):
-    logger.info(f"{name} before compiler:")
-    logger.info(gm.print_readable(print_output=False))
-
-    gm = autobucketing_reordering_pass(gm)
-
-    gm = regional_inductor(gm, example_inputs)
-
-    logger.info(f"{name} after compiler:")
-    logger.info(gm.print_readable(print_output=False))
-    return gm
-
-
-def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("fwd_gm", gm, example_inputs)
-
-
-def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-    return compiler("bwd_gm", gm, example_inputs)
-
-
-def validate_flex_attention_annotation(joint_with_descriptors):
-    """Verify user annotations show up in the graph."""
-    for node in joint_with_descriptors.graph_module.graph.nodes:
-        if node.target in {
-            torch.ops.higher_order.flex_attention,
-            torch.ops.higher_order.flex_attention_backward,
-        }:
-            assert "compile_with_inductor" in node.meta.get("custom", {})
-

 def annotate_llama() -> None:
     from torchtitan.models.attention import FlexAttentionWrapper
@@ -84,7 +44,17 @@ def parallelize_llama(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ) -> CompiledModule:
+    """
+    Parallelize and compile a Llama model with optional custom compiler passes.
+
+    Args:
+        model: The model to parallelize
+        parallel_dims: Parallel dimensions configuration
+        job_config: Job configuration

+    Returns:
+        CompiledModule wrapping the parallelized and compiled model
+    """
     annotate_llama()

     register_blockmask_pytree_node()
@@ -93,6 +63,12 @@ def parallelize_llama(
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config)

+    # Get compiler passes from config
+    compiler_passes = get_compiler_passes_from_config(job_config)
+
+    # Create compilers with specified passes (defaults to no passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+
     # Create custom joint_graph_builder with llama-specific compilers and validation
     llama_joint_graph_builder = functools.partial(
         joint_graph_builder,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Compiler passes for the compiler toolkit.
9+
10+
This module provides various compiler passes that can be applied to graph modules
11+
during compilation. Passes can be selected and configured via job config.
12+
"""
13+
14+
import torch
15+
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
16+
from torch.fx.passes.regional_inductor import regional_inductor
17+
18+
19+
def autobucketing_reordering_pass(
20+
gm: torch.fx.GraphModule, example_inputs=None
21+
) -> torch.fx.GraphModule:
22+
"""
23+
Apply autobucketing and reordering optimization.
24+
25+
This pass applies schedule_overlap_bucketing with collective_bucketing enabled
26+
to optimize communication patterns in distributed training.
27+
"""
28+
schedule_overlap_bucketing(gm, collective_bucketing=True)
29+
gm.recompile()
30+
return gm
31+
32+
33+
def regional_inductor_pass(
34+
gm: torch.fx.GraphModule, example_inputs
35+
) -> torch.fx.GraphModule:
36+
"""
37+
Apply regional inductor compilation based on user annotation.
38+
"""
39+
return regional_inductor(gm, example_inputs)
40+
41+
42+
# Registry mapping pass names to pass functions
43+
AVAILABLE_PASSES = {
44+
"autobucketing_reordering": autobucketing_reordering_pass,
45+
"regional_inductor": regional_inductor_pass,
46+
}
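Since `AVAILABLE_PASSES` is a plain name-to-function registry and every entry follows the same `(gm, example_inputs) -> gm` contract, additional passes can be registered alongside the built-in ones. A sketch under that assumption; `my_debug_pass` and its behavior are hypothetical and not part of this commit.

```python
# Hypothetical extension (not part of this commit): register an extra pass so it
# could be selected with --compile.passes my_debug_pass.
import torch

from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_PASSES


def my_debug_pass(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    # Example-only pass: report the node count and return the graph unchanged.
    print(f"graph has {len(list(gm.graph.nodes))} nodes")
    return gm


AVAILABLE_PASSES["my_debug_pass"] = my_debug_pass
```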
