Skip to content

Conversation

@yiming0416
Copy link
Contributor

@yiming0416 yiming0416 commented Nov 8, 2025

We should be able to control what passes to run in the compiler. This PR uses the config compile.passes to indicate in a list of graph passes to apply on the captured gm.

By default, no pass is applied. Users can specify what passes to apply.

Currently there are autobucketing_reordering_pass and regional_inductor_pass.

NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor

Also updated CI to include this new config

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Nov 8, 2025
@yiming0416 yiming0416 force-pushed the yiming/control_pass_with_config branch 2 times, most recently from 7a9bd50 to 1fdba5c Compare November 10, 2025 18:24
@yiming0416 yiming0416 marked this pull request as ready for review November 10, 2025 18:25
@yiming0416 yiming0416 force-pushed the yiming/control_pass_with_config branch from 1fdba5c to be49e8c Compare November 10, 2025 18:26

**SimpleFSDP + TP + auto-bucketing**
```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think they need to be exposed via command line args.

The key feature here should be passes can be specified via an API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SherlockNoMad If we don't expose via command line args, how do users/developers trigger runs with different config passes?
I would imagine the use case to be the following:

  1. Run CLI command with config 1 (no passes)
  2. Run CLI command with config 2 (some passes turned on)
  3. compare the loss curves / profiling traces with run1 and run2 to verify the effect of passes.

Copy link
Contributor

@SherlockNoMad SherlockNoMad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approve to unblock.

@yiming0416 yiming0416 merged commit f4514ef into main Nov 10, 2025
5 checks passed
ahoffman-aws pushed a commit to drcanchi-aws/torchtitan that referenced this pull request Nov 11, 2025
We should be able to control what passes to run in the compiler. This PR
uses the config compile.passes to indicate in a list of graph passes to
apply on the captured gm.

By default, no pass is applied. Users can specify what passes to apply.

Currently there are `autobucketing_reordering_pass` and
`regional_inductor_pass`.

```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
```

Also updated CI to include this new config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants