feat: Autocast #3878
@@ -32,18 +32,15 @@ Consider the following PyTorch model which explicitly casts intermediate layer t

        return x

-If we compile the above model using Torch-TensorRT, layer profiling logs indicate that all the layers are
-run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance.
+If we compile the above model using Torch-TensorRT with the following settings, layer profiling logs indicate that all the layers are
+run in FP32. This is because TensorRT picks the kernels for layers which result in the best performance (i.e., weak typing in TensorRT).

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
-    with torch_tensorrt.logging.debug():
-        trt_gm = torch_tensorrt.dynamo.compile(ep,
-            inputs=inputs,
-            debug=True)
+    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs)

    # Debug log info
    # Layers:
@@ -53,31 +50,50 @@ run in FP32. This is because TensorRT picks the kernels for layers which result

In order to respect the types specified by the user in the model (e.g., in this case, the ``linear2`` layer to run in FP16), users can enable
-the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs
-
-.. note:: If you enable ``use_explicit_typing=True``, only torch.float32 is supported in the enabled_precisions.
+the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs:

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
-    with torch_tensorrt.logging.debug():
-        trt_gm = torch_tensorrt.dynamo.compile(ep,
-            inputs=inputs,
-            use_explicit_typing=True,
-            debug=True)
+    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs, use_explicit_typing=True)

    # Debug log info
    # Layers:
    # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:

Now the ``linear2`` layer runs in FP16 as shown in the above logs.

Autocast
---------------

Weak typing behavior in TensorRT is deprecated; however, it is still a good way to maximize performance. Therefore, Torch-TensorRT
provides a way to enable weak-typing-like behavior, which is called `Autocast`.

Torch-TensorRT Autocast intelligently selects nodes to keep in FP32 precision to maintain model accuracy while benefiting from
reduced precision on the rest of the nodes. Torch-TensorRT Autocast also lets users specify which nodes to exclude from Autocast,
since some nodes may be more sensitive to reduced precision and have a larger impact on accuracy. In addition, Torch-TensorRT Autocast
can cooperate with PyTorch native Autocast, allowing users to use both PyTorch and Torch-TensorRT Autocast in the same model.
Torch-TensorRT respects the precision of the nodes within PyTorch Autocast regions.

Collaborator
Can you explain the difference between PyTorch and Torch-TensorRT autocast?

To enable Torch-TensorRT Autocast, users need to set both ``enable_autocast=True`` and ``use_explicit_typing=True``. For example:

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs, enable_autocast=True, use_explicit_typing=True)

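As noted above, Torch-TensorRT Autocast cooperates with PyTorch native autocast. The following is a rough sketch of that combination, using only the flags introduced in this PR; ``ToyModel`` is an illustrative stand-in, and the new example file later in this PR shows a complete, runnable version of the same pattern:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch_tensorrt

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(16, 32)
            self.fc2 = nn.Linear(32, 8)

        def forward(self, x):
            x = self.fc1(x)  # precision decided by Torch-TensorRT Autocast
            with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
                x = self.fc2(x)  # precision decided by PyTorch Autocast and respected by Torch-TensorRT
            return x

    inputs = [torch.randn((4, 16), dtype=torch.float32).cuda()]
    ep = torch.export.export(ToyModel().eval().cuda(), tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(
        ep, inputs=inputs, use_explicit_typing=True, enable_autocast=True
    )
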
Users can also choose the reduced precision via ``autocast_low_precision_type``, or use ``autocast_excluded_nodes`` / ``autocast_excluded_ops``
to exclude certain nodes/ops from Autocast, as sketched below.

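A minimal sketch of these options, reusing the ``MyModule`` setup from the examples above; the specific node and op patterns are illustrative placeholders, not recommendations:

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=inputs,
        use_explicit_typing=True,
        enable_autocast=True,
        autocast_low_precision_type=torch.float16,            # reduced precision Autocast may use
        autocast_excluded_nodes={"^linear2$"},                 # node-name patterns kept out of Autocast (illustrative)
        autocast_excluded_ops={"torch.ops.aten.add.Tensor"},   # ops kept out of Autocast (illustrative)
    )
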
In summary, there are three ways in Torch-TensorRT to enable mixed precision (a sketch of all three follows this list):

1. TRT chooses precision (weak typing): ``use_explicit_typing=False`` + ``enable_autocast=False``
2. User specifies precision (strong typing): ``use_explicit_typing=True`` + ``enable_autocast=False``
3. Autocast chooses precision (autocast + strong typing): ``use_explicit_typing=True`` + ``enable_autocast=True``

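The sketch below shows the three configurations side by side, assuming the same ``MyModule``, ``inputs``, and ``ep`` as in the examples above; the FP16 entry in ``enabled_precisions`` mirrors the commented-out weak-typing option in the example file added by this PR:

.. code-block:: python

    # 1. Weak typing: TensorRT is free to pick kernel precisions from the allowed set
    trt_weak = torch_tensorrt.dynamo.compile(
        ep, inputs=inputs,
        use_explicit_typing=False,
        enable_autocast=False,
        enabled_precisions={torch.float32, torch.float16},
    )

    # 2. Strong typing: the dtypes written in the model are respected as-is
    trt_strong = torch_tensorrt.dynamo.compile(
        ep, inputs=inputs,
        use_explicit_typing=True,
        enable_autocast=False,
    )

    # 3. Autocast + strong typing: Torch-TensorRT decides which nodes may run in low precision
    trt_autocast = torch_tensorrt.dynamo.compile(
        ep, inputs=inputs,
        use_explicit_typing=True,
        enable_autocast=True,
    )
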
FP32 Accumulation
-----------------

@@ -93,14 +109,12 @@ When ``use_fp32_acc=True`` is set, Torch-TensorRT will attempt to use FP32 accum

    inputs = [torch.randn((1, 10), dtype=torch.float16).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
-    with torch_tensorrt.logging.debug():
-        trt_gm = torch_tensorrt.dynamo.compile(
-            ep,
-            inputs=inputs,
-            use_fp32_acc=True,
-            use_explicit_typing=True,  # Explicit typing must be enabled
-            debug=True
-        )
+    trt_gm = torch_tensorrt.dynamo.compile(
+        ep,
+        inputs=inputs,
+        use_fp32_acc=True,
+        use_explicit_typing=True,  # Explicit typing must be enabled
+    )

    # Debug log info
    # Layers:
@@ -0,0 +1,70 @@
import torch

Collaborator
Can you add comments to this doc? Here is an example of what I'm looking for: https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/converter_overloading.html

import torch.nn as nn
import torch_tensorrt

class MixedPytorchAutocastModel(nn.Module):
    def __init__(self):
        super(MixedPytorchAutocastModel, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        # The ops below run under PyTorch native autocast; Torch-TensorRT Autocast
        # respects the precisions PyTorch assigns inside this region.
        with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
            x = self.fc1(x)
            out = torch.log(
                torch.abs(x) + 1
            )  # log is fp32 due to Pytorch Autocast requirements
        return out

if __name__ == "__main__":

Collaborator
I know it's not best practice but let's just make them pure scripts so they render better

    model = MixedPytorchAutocastModel().cuda().eval()
    inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),)
    ep = torch.export.export(model, inputs)
    calibration_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False
    )

    with torch_tensorrt.dynamo.Debugger(
        "graphs",
        logging_dir=".",
        engine_builder_monitor=False,
    ):
        trt_autocast_mod = torch_tensorrt.compile(
            ep.module(),
            arg_inputs=inputs,
            min_block_size=1,
            use_python_runtime=True,
            ##### weak typing #####
            # use_explicit_typing=False,
            # enabled_precisions={torch.float16},
            ##### strong typing + autocast #####
            use_explicit_typing=True,
            enable_autocast=True,
            autocast_low_precision_type=torch.float16,
            autocast_excluded_nodes={"^conv1$", "relu"},  # node-name regex patterns kept out of Autocast
            autocast_excluded_ops={"torch.ops.aten.flatten.using_ints"},  # ops kept out of Autocast
            autocast_max_output_threshold=512,
            autocast_max_depth_of_reduction=None,
            autocast_calibration_dataloader=calibration_dataloader,
        )

    autocast_outs = trt_autocast_mod(*inputs)

We may want to reorient around strong typing first and then weak typing as an optimization. Right now this is a bit confusing.

So like in the tutorial.

Since TRT has deprecated weak typing, should we mention that weak typing is deprecated, so users need to use Autocast instead? Thus, we have only two modes: