
Commit 93df141

add precision modifier
1 parent 6d4ce00 commit 93df141

File tree: 2 files changed, +18 -1 lines changed

src/anomalib/models/image/anomaly_dino/lightning_model.py

Lines changed: 14 additions & 1 deletion
@@ -53,7 +53,7 @@
 from torch import nn
 from torchvision.transforms.v2 import Compose, InterpolationMode, Normalize, Resize
 
-from anomalib import LearningType
+from anomalib import LearningType, PrecisionType
 from anomalib.data import Batch
 from anomalib.metrics import Evaluator
 from anomalib.models.components import AnomalibModule, MemoryBankMixin
@@ -91,6 +91,9 @@ class AnomalyDINO(MemoryBankMixin, AnomalibModule):
             to reduce the size of the memory bank. Defaults to ``False``.
         sampling ratio(float, optional): If coreset subsampling, by what ratio
             should we subsample. Defaults to ``0.1``
+        precision (str, optional): Precision type for model computations.
+            Supported values are defined in :class:`PrecisionType`.
+            Defaults to ``PrecisionType.FLOAT32``.
         pre_processor (PreProcessor | bool, optional): Pre-processor instance or
             bool flag to enable default preprocessing. Defaults to ``True``.
         post_processor (PostProcessor | bool, optional): Post-processor instance or
@@ -152,6 +155,7 @@ def __init__(
         masking: bool = False,
         coreset_subsampling: bool = False,
         sampling_ratio: float = 0.1,
+        precision: str = PrecisionType.FLOAT32,
         pre_processor: nn.Module | bool = True,
         post_processor: nn.Module | bool = True,
         evaluator: Evaluator | bool = True,
@@ -171,6 +175,15 @@
             sampling_ratio=sampling_ratio,
         )
 
+        if precision == PrecisionType.FLOAT16:
+            self.model = self.model.half()
+        elif precision == PrecisionType.FLOAT32:
+            self.model = self.model.float()
+        else:
+            msg = f"""Unsupported precision type: {precision}.
+            Supported types are: {PrecisionType.FLOAT16}, {PrecisionType.FLOAT32}."""
+            raise ValueError(msg)
+
     @classmethod
     def configure_pre_processor(
         cls,
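
Taken together, the lightning_model.py changes add a single constructor knob. A minimal usage sketch follows; the "from anomalib.models import AnomalyDINO" export path is an assumption based on anomalib's usual layout, since only "from anomalib import PrecisionType" is confirmed by this diff:

from anomalib import PrecisionType
from anomalib.models import AnomalyDINO  # assumed export path, not shown in this commit

# precision is resolved once in __init__: .half() for FLOAT16,
# .float() for FLOAT32, and a ValueError for anything else.
fp16_model = AnomalyDINO(precision=PrecisionType.FLOAT16)
fp32_model = AnomalyDINO()  # defaults to PrecisionType.FLOAT32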

src/anomalib/models/image/anomaly_dino/torch_model.py

Lines changed: 4 additions & 0 deletions
@@ -220,6 +220,10 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
             * ``pred_score``: Image-level anomaly score ``(B, 1)``
             * ``anomaly_map``: Pixel-level anomaly heatmap ``(B, 1, H, W)``
         """
+        # set precision
+        input_tensor = input_tensor.type(self.memory_bank.dtype)
+
+        # work out sizing
         b, _, w, h = input_tensor.shape
         cropped_width = w - w % self.feature_encoder.patch_size
         cropped_height = h - h % self.feature_encoder.patch_size
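
The added cast is the companion to the .half() conversion in __init__: once the memory bank is stored in half precision, incoming fp32 batches must be converted, because PyTorch matmul- and distance-style kernels generally require both operands to share a dtype rather than promoting one of them. A standalone sketch of the pattern (shapes and variable names are illustrative, not taken from the model):

import torch

bank = torch.randn(100, 384).half()  # memory bank converted once at construction
feats = torch.randn(4, 384)          # a new batch of features arrives as fp32

# Mirror the forward() fix: follow the bank's dtype so downstream
# similarity/distance computations see one consistent precision.
feats = feats.type(bank.dtype)
assert feats.dtype == bank.dtype == torch.float16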
