Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from tqdm import tqdm
from triton.testing import do_bench

from torch.nn.functional import scaled_mm, ScalingType


from torchao.prototype.blockwise_fp8_training.kernels import (
triton_fp8_blockwise_act_quant_lhs,
triton_fp8_blockwise_weight_quant_transposed_rhs,
Expand Down Expand Up @@ -112,25 +115,31 @@ def warmup(func, *args, **kwargs):
)

# Warm up then run torch bench
# scaled_mm requires A_s and B_t_s be in column-major format
A_s = A_s.t().contiguous().t()

scale_recipe_a = ScalingType.BlockWise1x128
scale_recipe_b = ScalingType.BlockWise128x128

warmup(
torch._scaled_mm,
scaled_mm,
A_q,
B_t_q,
1.0 / A_s,
scale_recipe_a,
1.0 / B_t_s,
out_dtype=config.out_dtype,
scale_recipe_b,
output_dtype=config.out_dtype,
)

fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds(
torch._scaled_mm,
scaled_mm,
A_q,
B_t_q,
1.0 / A_s,
scale_recipe_a,
1.0 / B_t_s,
out_dtype=config.out_dtype,
scale_recipe_b,
output_dtype=config.out_dtype,
)

return ExperimentResult(
Expand All @@ -157,8 +166,10 @@ def print_results(experiments: List[Experiment]):
for experiment in experiments:
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
flops = 2 * m * n * k
bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6)
triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6)
bf16_mm_tflops_per_sec = (flops / 1e12) / \
(experiment.result.bf16_mm_us / 1e6)
triton_tflops_per_sec = (flops / 1e12) / \
(experiment.result.fp8_triton_us / 1e6)
scaled_mm_tflops_per_sec = (flops / 1e12) / (
experiment.result.fp8_scaled_mm_us / 1e6
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from tqdm import tqdm
from triton.testing import do_bench

from torch.nn.functional import scaled_mm, ScalingType


from torchao.prototype.blockwise_fp8_training.kernels import (
triton_fp8_blockwise_act_quant_rhs,
triton_fp8_blockwise_act_quant_transposed_lhs,
Expand Down Expand Up @@ -112,22 +115,30 @@ def warmup(func, *args, **kwargs):
)

# Warm up then run torch bench
scale_recipe_a = ScalingType.BlockWise1x128
scale_recipe_b = ScalingType.BlockWise1x128
B_s_t = B_s.t()

warmup(
torch._scaled_mm,
scaled_mm,
A_t_q,
B_q,
1.0 / A_t_s,
1.0 / B_s,
out_dtype=config.out_dtype,
scale_recipe_a,
1.0 / B_s_t,
scale_recipe_b,
output_dtype=config.out_dtype,
)

fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds(
torch._scaled_mm,
scaled_mm,
A_t_q,
B_q,
1.0 / A_t_s,
1.0 / B_s,
out_dtype=config.out_dtype,
scale_recipe_a,
1.0 / B_s_t,
scale_recipe_b,
output_dtype=config.out_dtype,
)

return ExperimentResult(
Expand All @@ -154,8 +165,10 @@ def print_results(experiments: List[Experiment]):
for experiment in experiments:
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
flops = 2 * m * n * k
bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6)
triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6)
bf16_mm_tflops_per_sec = (flops / 1e12) / \
(experiment.result.bf16_mm_us / 1e6)
triton_tflops_per_sec = (flops / 1e12) / \
(experiment.result.fp8_triton_us / 1e6)
scaled_mm_tflops_per_sec = (flops / 1e12) / (
experiment.result.fp8_scaled_mm_us / 1e6
)
Expand Down