5353from torch import nn
5454from torchvision .transforms .v2 import Compose , InterpolationMode , Normalize , Resize
5555
56- from anomalib import LearningType
56+ from anomalib import LearningType , PrecisionType
5757from anomalib .data import Batch
5858from anomalib .metrics import Evaluator
5959from anomalib .models .components import AnomalibModule , MemoryBankMixin
@@ -91,6 +91,9 @@ class AnomalyDINO(MemoryBankMixin, AnomalibModule):
9191 to reduce the size of the memory bank. Defaults to ``False``.
9292 sampling_ratio (float, optional): If coreset subsampling, by what ratio
9393 should we subsample. Defaults to ``0.1``.
94+ precision (str, optional): Precision type for model computations.
95+ Supported values are defined in :class:`PrecisionType`.
96+ Defaults to ``PrecisionType.FLOAT32``.
9497 pre_processor (PreProcessor | bool, optional): Pre-processor instance or
9598 bool flag to enable default preprocessing. Defaults to ``True``.
9699 post_processor (PostProcessor | bool, optional): Post-processor instance or
@@ -152,6 +155,7 @@ def __init__(
152155 masking : bool = False ,
153156 coreset_subsampling : bool = False ,
154157 sampling_ratio : float = 0.1 ,
158+ precision : str = PrecisionType .FLOAT32 ,
155159 pre_processor : nn .Module | bool = True ,
156160 post_processor : nn .Module | bool = True ,
157161 evaluator : Evaluator | bool = True ,
@@ -171,6 +175,15 @@ def __init__(
171175 sampling_ratio = sampling_ratio ,
172176 )
173177
178+ if precision == PrecisionType .FLOAT16 :
179+ self .model = self .model .half ()
180+ elif precision == PrecisionType .FLOAT32 :
181+ self .model = self .model .float ()
182+ else :
183+ msg = f"""Unsupported precision type: { precision } .
184+ Supported types are: { PrecisionType .FLOAT16 } , { PrecisionType .FLOAT32 } ."""
185+ raise ValueError (msg )
186+
174187 @classmethod
175188 def configure_pre_processor (
176189 cls ,
0 commit comments