
Conversation

@zewenli98 zewenli98 (Collaborator) commented Oct 28, 2025

Description

Weak typing behavior in TensorRT is deprecated. However, mixed precision remains a good way to maximize performance. Therefore, we want to create a similar PyTorch-native system for use with Torch-TensorRT that recovers some of this behavior.

Fixes #3869

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Oct 28, 2025
@meta-cla meta-cla bot added the cla signed label Oct 28, 2025
@github-actions github-actions bot added the component: lowering, component: conversion, component: core, component: api [Python], component: runtime, and component: dynamo labels Oct 28, 2025
@github-actions github-actions bot requested a review from apbose October 28, 2025 05:16
@zewenli98 zewenli98 removed the request for review from apbose October 28, 2025 05:16
@github-actions github-actions bot removed the component: conversion label Oct 29, 2025
Comment on lines 437 to 444
enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
low_precision_type: Optional[
    Union[torch.dtype, dtype]
] = _defaults.LOW_PRECISION_TYPE,
nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE,
targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE,
data_max: float = _defaults.DATA_MAX,
max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION,
zewenli98 (Collaborator, Author):

Before merging, these args should be added to other compile functions in this file.
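
For illustration, a minimal sketch of what that could look like once the settings are plumbed through the public torch_tensorrt.compile entry point; the flag names mirror the signature above, but the call, the model, and the chosen values are placeholders, not from this PR:

import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x @ x.T)


model = TinyModel().eval().cuda()
inputs = [torch.randn(8, 128).cuda()]

# Hypothetical call: autocast settings forwarded alongside strong typing.
trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    use_explicit_typing=True,          # autocast builds on strong typing
    enable_autocast=True,
    low_precision_type=torch.float16,
    nodes_to_exclude=set(),            # node names to keep at full precision
    data_max=512.0,                    # assumed overflow-guard threshold
    max_depth_of_reduction=None,       # no cap on reduction depth
)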

]:
    # GEMM: A (M, K) @ B (K, N) = C (M, N)
    self.reduction_depth = input_0_dims[-1]
# TODO: Add more reduction ops here
zewenli98 (Collaborator, Author):

Should any more reduction targets be added?
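
For discussion, a hedged sketch of how additional reduction targets could slot into the same rule; the convolution branch is a guess for illustration, not something this PR implements:

import torch
from torch.fx import Node


def estimate_reduction_depth(node: Node) -> int:
    """Illustrative only: deeper reductions accumulate more low-precision
    error, so nodes whose depth exceeds max_depth_of_reduction can be
    kept at full precision."""
    input_0_dims = node.args[0].meta["tensor_meta"].shape
    if node.target == torch.ops.aten.mm.default:
        # GEMM: A (M, K) @ B (K, N) = C (M, N) -> reduces over K
        return input_0_dims[-1]
    if node.target == torch.ops.aten.convolution.default:
        # Guess: reduces over C_in (times kernel elements, omitted here)
        return input_0_dims[1]
    return 0  # not a reduction op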

@peri044 peri044 (Collaborator) left a comment.

@narendasan narendasan (Collaborator) commented Nov 6, 2025

For tests:

  1. External PyTorch autocast combined with strong typing (a sketch follows after this list)
  2. Whole-graph autocast pass
  3. A test case that exercises the max_output_threshold fallback

L1 or L2 tests
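
A rough sketch of what test (1) might look like, assuming the enable_autocast flag lands as named in this PR and with an arbitrarily chosen tolerance:

import torch
import torch_tensorrt


class MixedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        # External PyTorch autocast region; Torch-TensorRT should respect
        # the precision chosen here rather than re-deciding it.
        with torch.autocast("cuda", dtype=torch.float16):
            return self.linear(x)


def test_external_autocast_with_strong_typing():
    model = MixedModel().eval().cuda()
    x = torch.randn(4, 64).cuda()
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[x],
        use_explicit_typing=True,  # strong typing stays on
        enable_autocast=False,     # only the external autocast applies
    )
    torch.testing.assert_close(trt_model(x), model(x), rtol=5e-2, atol=5e-2)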

@github-actions github-actions bot added the component: tests label and removed the component: core label Nov 8, 2025
@github-actions github-actions bot added the documentation label Nov 14, 2025
- If we compile the above model using Torch-TensorRT, layer profiling logs indicate that all the layers are
- run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance.
+ If we compile the above model using Torch-TensorRT with the following settings, layer profiling logs indicate that all the layers are
+ run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance (i.e., weak typing in TensorRT).
Collaborator:

We may want to reorient around strong typing first, and then weak typing as an optimization. Right now this is a bit confusing.

Collaborator:

So, like in the tutorial:

  1. Demonstrate strong typing and explain that it's going to be the default behavior
  2. Show the weak typing behavior and talk about how the TRT graph changed (and maybe why)
  3. Show how you can recover the weak typing behavior using autocast for TRT 11 and beyond

zewenli98 (Collaborator, Author):

Since TRT has deprecated weak typing, should we mention that weak typing is deprecated, so autocast needs to be used instead? Thus, we have only two modes (see the sketch below):

User-defined precision:        use_explicit_typing=True + enable_autocast=False
Autocast chooses precision:    use_explicit_typing=True + enable_autocast=True
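
A minimal sketch of the two modes side by side, assuming model and inputs are already defined and that enable_autocast lands as named in this PR:

import torch
import torch_tensorrt

# Mode 1: user-defined precision -- the dtypes in the graph are honored as-is.
trt_user = torch_tensorrt.compile(
    model,
    inputs=inputs,
    use_explicit_typing=True,
    enable_autocast=False,
)

# Mode 2: autocast chooses precision -- the rule-based pass downcasts
# eligible nodes before conversion.
trt_auto = torch_tensorrt.compile(
    model,
    inputs=inputs,
    use_explicit_typing=True,
    enable_autocast=True,
    low_precision_type=torch.float16,
)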

Autocast
---------------

Weak typing behavior in TensorRT is deprecated. However it is a good way to maximize performance. Therefore, in Torch-TensorRT,
Collaborator:

However mixed precision is a good way to maximize performance

reduced precision on the rest of the nodes. Torch-TensorRT Autocast also lets users specify which nodes to exclude from Autocast,
since some nodes might be more sensitive with respect to accuracy. In addition, Torch-TensorRT Autocast can cooperate with PyTorch
native Autocast, allowing users to use both PyTorch and Torch-TensorRT Autocast in the same model. Torch-TensorRT respects the precision
of the nodes within PyTorch Autocast regions.
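
For instance, a hedged sketch of excluding accuracy-sensitive nodes by name; the model, inputs, and pattern string here are illustrative, not from the PR:

import torch
import torch_tensorrt

# Keep accuracy-sensitive nodes (e.g., normalization layers) at full
# precision while autocast downcasts the rest.
trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    use_explicit_typing=True,
    enable_autocast=True,
    low_precision_type=torch.float16,
    nodes_to_exclude={"layer_norm"},  # illustrative node-name pattern
)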
Collaborator:

Can you explain the difference between PyTorch and Torch-TensorRT autocast?

@@ -0,0 +1,70 @@
import torch
Collaborator:

Can you add comments to this doc? Here is an example of what I'm looking for: https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/converter_overloading.html

return out


if __name__ == "__main__":
Collaborator:

I know it's not best practice, but let's just make them pure scripts so they render better.
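
For reference, a pure-script sketch in the sphinx-gallery style the rendered examples use; the `# %%` cell markers are one common convention, assumed here:

# %%
# Run the model
# ^^^^^^^^^^^^^
# Top-level statements execute directly, so the gallery renders their
# output inline; no ``if __name__ == "__main__":`` guard is needed.
import torch

out = torch.relu(torch.randn(2, 3))
print(out.shape)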

pre_lowering_pass_list = [
    remove_detach,
    remove_assert_nodes,
    rule_based_autocast,
Collaborator:

Should this pass be conditionally added to the pre_lowering_pass_list?

zewenli98 (Collaborator, Author):

There's a condition inside of rule_based_autocast (sketched below).
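
That is, something along these lines; a sketch of the guard, with the settings plumbing assumed:

from torch.fx import GraphModule


def rule_based_autocast(gm: GraphModule, settings) -> GraphModule:
    # The pass always sits in pre_lowering_pass_list, but it is a no-op
    # unless autocast was enabled in the compilation settings.
    if not getattr(settings, "enable_autocast", False):
        return gm
    # ... rule-based downcasting of eligible nodes would happen here ...
    return gm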

@narendasan narendasan (Collaborator) left a comment:

Nice, it's looking good. Some final polishing details, then I think it's good to go.


Labels

  • cla signed
  • component: api ([Python] Issues re: Python API)
  • component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
  • component: lowering (Issues re: The lowering / preprocessing passes)
  • component: runtime
  • component: tests (Issues re: Tests)
  • documentation (Improvements or additions to documentation)


4 participants