@alfieroddanintel alfieroddanintel commented Nov 11, 2025

📝 Description

Implementation of DINOv2

The current Dinomaly implementation contains some custom layers like:

I have taken the official DINOv2 implementation and adapted it to run without xformers.

I have placed this in anomalib.models.components.dinov2 (import it with from anomalib.models.components.dinov2 import vision_transformer). It produces exactly the same output as the original:

import torch
from anomalib.models.components.dinov2.dinov2_loader import DinoV2Loader
torch.manual_seed(0)

# Official model
official = torch.hub.load(
    "facebookresearch/dinov2",
    "dinov2_vits14",
    pretrained=True,
).eval()

mine = DinoV2Loader.from_name("dinov2_vit_small_14").eval()

# Input
x = torch.randn(1, 3, 224, 224)

# Forward (two outputs only)
with torch.no_grad():
    o1 = official(x)
    o2 = mine(x)

# Compare
diff = (o1 - o2).abs()
print("Max diff:", diff.max().item())
print("Mean diff:", diff.mean().item())
# Compare the official output against the reimplementation (o1 vs o2)
print(f"Same: {torch.allclose(o1, o2)}")

DINOv2 Loader

I refactored the DINOv2 loader into a factory to make it easier to load DINOv2 weights for multiple implementations. I plan to use it as follows:

from anomalib.models.components.dinov2.dinov2_loader import DinoV2Loader

# load simply from string
dinov2_vit_small_14 = DinoV2Loader.from_name("dinov2_vit_small_14").eval()

# init without a factory method, defaults to the official implementation
dinov2_vit_small_14 = DinoV2Loader().load(model_name="dinov2_vit_small_14").eval()

# use custom ViT implementation
from anomalib.models.image.dinomaly.components import vision_transformer as dinomaly_vision_transformer
dinomaly_vit_small_14 = DinoV2Loader(vit_factory=dinomaly_vision_transformer).load(model_name="dinov2_vit_small_14").eval()

Hopefully this is allowed with semgrep and bandit.
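
For readers who want to see the shape of the factory idea without opening the diff, here is a minimal sketch. It is not the code in this PR: LoaderSketch, parse_model_name, OfficialViTs and CustomViTs are hypothetical stand-ins, and the builders return strings instead of real models. It only assumes that a name such as dinov2_vit_small_14 splits into a model type, an architecture and a patch size, and that a ViT factory exposes one builder per architecture.

```python
from typing import Any

# Hypothetical mapping from the size token in a model name to a builder name.
_BUILDERS = {"small": "vit_small", "base": "vit_base", "large": "vit_large"}


def parse_model_name(name: str) -> tuple[str, str, int]:
    """Split e.g. 'dinov2_vit_small_14' into (model_type, architecture, patch_size)."""
    parts = name.split("_")
    if len(parts) != 4 or parts[1] != "vit" or parts[2] not in _BUILDERS:
        msg = f"Unsupported model name: {name!r}"
        raise ValueError(msg)
    return parts[0], parts[2], int(parts[3])


class OfficialViTs:
    """Stand-in for a module of ViT builders (returns strings, not models)."""

    @staticmethod
    def vit_small(patch_size: int) -> str:
        return f"official ViT-S/{patch_size}"


class CustomViTs:
    """Stand-in for an alternative ViT implementation with the same builders."""

    @staticmethod
    def vit_small(patch_size: int) -> str:
        return f"custom ViT-S/{patch_size}"


class LoaderSketch:
    """Hypothetical loader: the ViT factory is injected, with a default."""

    def __init__(self, vit_factory: Any = None) -> None:
        self.vit_factory = OfficialViTs if vit_factory is None else vit_factory

    @classmethod
    def from_name(cls, model_name: str) -> Any:
        # Convenience constructor that uses the default factory.
        return cls().load(model_name)

    def load(self, model_name: str) -> Any:
        _, architecture, patch_size = parse_model_name(model_name)
        builder = getattr(self.vit_factory, _BUILDERS[architecture])
        return builder(patch_size=patch_size)


default_model = LoaderSketch.from_name("dinov2_vit_small_14")
custom_model = LoaderSketch(vit_factory=CustomViTs).load("dinov2_vit_small_14")
```

The point of injecting the factory is that weight loading and name parsing live in one place while the architecture code can be swapped, which is what the Dinomaly example above relies on.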

Implementation of AnomalyDINO

AnomalyDINO uses DINOv2-Small features to improve the scores of memory-bank models in the few-shot setting. The paper itself uses custom augmentations and masking depending on the category type. This is optional in my implementation.
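
As a rough illustration of what memory-bank scoring means here, the sketch below scores each test patch by its cosine distance to the nearest patch in a bank of normal features (the num_neighbours=1 case). It is written in plain Python for readability and is not the implementation in this PR, which would operate on batched tensors. The function names are hypothetical, and the image-level aggregation (mean of the top 1% of patch distances) is an assumption rather than something stated in this description.

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 - cosine similarity between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def patch_scores(test_patches: list[list[float]], memory_bank: list[list[float]]) -> list[float]:
    """Distance from each test patch to its nearest memory-bank patch (k=1)."""
    return [min(cosine_distance(p, m) for m in memory_bank) for p in test_patches]


def image_score(scores: list[float], top_fraction: float = 0.01) -> float:
    """Mean of the highest `top_fraction` of patch scores (at least one patch).

    The 1% default is an assumption made for this sketch.
    """
    k = max(1, math.ceil(len(scores) * top_fraction))
    return sum(sorted(scores, reverse=True)[:k]) / k


# Toy 2-D "features": two normal patches in the bank, three test patches.
memory_bank = [[1.0, 0.0], [0.0, 1.0]]
test_patches = [[1.0, 0.0], [1.0, 1.0], [-1.0, 0.0]]

scores = patch_scores(test_patches, memory_bank)
score = image_score(scores)
```

The per-patch scores play the role of the anomaly map and the aggregated value plays the role of the image-level prediction score.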

To replicate Table 12 of the paper (full-shot with masking on MVTecAD), run:

from anomalib.data import MVTecAD
from anomalib.engine import Engine
from anomalib.metrics import AUROC, Evaluator, F1Max
from anomalib.models import AnomalyDINO
from anomalib.post_processing import PostProcessor

MVTEC_CATEGORIES = [
    "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather",
    "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper"
]
MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"]

TEST_METRICS = [
    # image
    AUROC(fields=["pred_score", "gt_label"], prefix="image_"),
    F1Max(fields=["pred_score", "gt_label"], prefix="image_"),
    # pixel
    AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_"),
    F1Max(fields=["anomaly_map", "gt_mask"], prefix="pixel_"),
]


for category in MVTEC_CATEGORIES:
    mask = category in MASKED_CATEGORIES
    print(f"\n--- Running category: {category} | masking={mask} ---")

    # Initialize data module
    # Memory-bank models do not do well with a large eval_batch_size
    datamodule = MVTecAD(category=category, eval_batch_size=1)

    # Initialize Evaluator
    evaluator = Evaluator(test_metrics=TEST_METRICS)

    # post processor
    post_processor = PostProcessor(enable_normalization=False, enable_thresholding=False, enable_threshold_matching=False)

    # Preprocessor
    preprocessor = AnomalyDINO.configure_pre_processor(image_size=(448))

    # Initialize model with or without masking
    model = AnomalyDINO(
        num_neighbours=1,
        encoder_name="dinov2_vit_small_14",
        masking=mask,
        coreset_subsampling=False,
        pre_processor=preprocessor,
        evaluator=evaluator,
    )
    engine = Engine()
    engine.fit(model=model, datamodule=datamodule)
    engine.test(model=model, datamodule=datamodule)

print("\n✅ All categories processed.")

📄 Reported Results from the Paper

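| Dataset | Resolution | Detection AUROC (%) | Detection F1-max (%) | Detection AP (%) | Segmentation AUROC (%) | Segmentation F1-max (%) | Segmentation PRO (%) |
|---|---|---|---|---|---|---|---|
| MVTec-AD | 448 | 99.3 | 98.8 | 99.7 | 97.9 | 61.8 | 93.9 |
| MVTec-AD | 672 | 99.5 | 99.0 | 99.8 | 98.2 | 64.3 | 95.0 |
| VisA | 448 | 97.2 | 93.7 | 97.6 | 98.7 | 50.5 | 95.0 |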

Our Implementation

We don't use Faiss.

📊 Category Metrics (%)

| Category | image_AUROC (%) | image_F1Max (%) | pixel_AUROC (%) | pixel_F1Max (%) |
|---|---|---|---|---|
| hazelnut | 99.9 | 99.3 | 99.6 | 77.6 |
| grid | 100.0 | 100.0 | 99.1 | 43.9 |
| carpet | 100.0 | 100.0 | 98.9 | 61.7 |
| bottle | 100.0 | 100.0 | 99.1 | 81.3 |
| cable | 98.4 | 96.7 | 97.5 | 64.1 |
| capsule | 96.9 | 98.2 | 99.1 | 55.4 |
| leather | 100.0 | 100.0 | 99.0 | 38.0 |
| metal_nut | 100.0 | 100.0 | 98.0 | 84.2 |
| pill | 99.5 | 98.9 | 98.1 | 71.4 |
| screw | 93.8 | 93.0 | 99.2 | 49.0 |
| tile | 100.0 | 100.0 | 94.9 | 61.7 |
| toothbrush | 100.0 | 100.0 | 99.3 | 64.3 |
| transistor | 99.5 | 97.5 | 95.2 | 55.6 |
| wood | 99.6 | 98.3 | 94.9 | 63.5 |
| zipper | 99.5 | 99.6 | 94.5 | 48.0 |

📈 Mean Metrics (%)

| Metric | Mean (%) |
|---|---|
| image_AUROC | 98.9 |
| image_F1Max | 98.5 |
| pixel_AUROC | 97.5 |
| pixel_F1Max | 61.5 |

✨ Changes

Select what type of change your PR is:

  • 🚀 New feature (non-breaking change which adds functionality)
  • 🐞 Bug fix (non-breaking change which fixes an issue)
  • 🔄 Refactor (non-breaking change which refactors the code base)
  • ⚡ Performance improvements
  • 🎨 Style changes (code style/formatting)
  • 🧪 Tests (adding/modifying tests)
  • 📚 Documentation update
  • 📦 Build system changes
  • 🚧 CI/CD configuration
  • 🔧 Chore (general maintenance)
  • 🔒 Security update
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)

✅ Checklist

Before you submit your pull request, please make sure you have completed the following steps:

  • 📚 I have made the necessary updates to the documentation (if applicable).
  • 🧪 I have written tests that support my changes and prove that my fix is effective or my feature works (if applicable).
  • 🏷️ My PR title follows conventional commit format.

For more information about code review checklists, see the Code Review Checklist.

@alfieroddanintel alfieroddanintel changed the title feat/model/AnomalyDINO feat(model): add AnomalyDINO Nov 11, 2025

@samet-akcay samet-akcay left a comment

Clean! Thanks a lot for the contribution!

Only have super minor comments..

@alfieroddanintel alfieroddanintel force-pushed the feat/model/AnomalyDINO branch 2 times, most recently from f720362 to a608e5e Compare November 11, 2025 16:42
@alfieroddanintel alfieroddanintel marked this pull request as ready for review November 11, 2025 17:49
self._download_weights(model_type, architecture, patch_size)

# Using weights_only=True for safety mitigation (see Anomalib PR #2729)
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True) # nosec B614

Check failure

Code scanning / Semgrep OSS

Semgrep Finding: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch Error

Functions reliant on pickle can result in arbitrary code execution. Consider loading from state_dict, using fickling, or switching to a safer serialization method like ONNX
Comment on lines +235 to +243
urlretrieve( # noqa: S310 # nosec B310
url=url,
filename=weight_path,
reporthook=progress_bar.update_to,
)

Check warning

Code scanning / Semgrep OSS

Semgrep Finding: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected Warning

Detected a dynamic value being used with urllib. urllib supports 'file://' schemes, so a dynamic value controlled by a malicious actor may allow them to read arbitrary files. Audit uses of urllib calls to ensure user data cannot control the URLs, or consider using the 'requests' library instead.
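
One common way to address this class of finding is to validate the URL before it reaches urllib, so that only HTTPS URLs on known hosts are ever fetched. The sketch below is an illustration under assumptions, not the fix applied in this PR: validate_weight_url and ALLOWED_HOSTS are hypothetical names, and the host and path shown are placeholders for wherever the weights are actually published.

```python
from urllib.parse import urlparse

# Assumption: weights are served from a small, fixed set of hosts.
ALLOWED_HOSTS = frozenset({"dl.fbaipublicfiles.com"})


def validate_weight_url(url: str) -> str:
    """Return `url` unchanged if it is HTTPS on an allowed host, else raise."""
    parsed = urlparse(url)
    if parsed.scheme != "https":
        # Rejects file://, ftp://, http:// and scheme-less strings.
        msg = f"Refusing non-HTTPS URL scheme: {parsed.scheme!r}"
        raise ValueError(msg)
    if parsed.hostname not in ALLOWED_HOSTS:
        msg = f"Refusing download from unexpected host: {parsed.hostname!r}"
        raise ValueError(msg)
    return url


# Placeholder path, for illustration only.
checked = validate_weight_url("https://dl.fbaipublicfiles.com/example/weights.pth")
```

Calling a check like this immediately before urlretrieve keeps the download call itself unchanged while ruling out the file:// case that the warning describes.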
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
)

dpr = [drop_path_rate] * depth if drop_path_uniform else np.linspace(0, drop_path_rate, depth).tolist()

Check warning

Code scanning / Semgrep OSS

Semgrep Finding: trailofbits.python.numpy-in-pytorch-modules.numpy-in-pytorch-modules Warning

Usage of NumPy library inside PyTorch DinoVisionTransformer module was found. Avoid mixing these libraries for efficiency and proper ONNX loading
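
The flagged line only builds a linear stochastic-depth schedule, so NumPy is not strictly needed for it. A minimal sketch of a dependency-free equivalent follows; drop_path_rates is a hypothetical helper name and this is not necessarily how the PR resolved the warning. Using torch.linspace(0, drop_path_rate, depth).tolist() would be another way to stay within PyTorch.

```python
def drop_path_rates(drop_path_rate: float, depth: int, uniform: bool = False) -> list[float]:
    """Per-block drop-path rates: constant, or linearly increasing from 0."""
    if uniform:
        return [drop_path_rate] * depth
    if depth == 1:
        # A single block gets rate 0, matching a one-point linspace from 0.
        return [0.0]
    step = drop_path_rate / (depth - 1)
    return [i * step for i in range(depth)]


rates = drop_path_rates(0.3, 4)
uniform_rates = drop_path_rates(0.1, 3, uniform=True)
```

The schedule is computed once at construction time, so the choice between plain Python and torch.linspace has no effect on the forward pass.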
@alfieroddanintel alfieroddanintel marked this pull request as draft November 13, 2025 12:21
@alfieroddanintel alfieroddanintel changed the title feat(model): add AnomalyDINO feat(model): add DINOv2 official implementation and AnomalyDINO Nov 13, 2025
@alfieroddanintel alfieroddanintel force-pushed the feat/model/AnomalyDINO branch 2 times, most recently from 7936d17 to 3419086 Compare November 13, 2025 13:56
@alfieroddanintel alfieroddanintel marked this pull request as ready for review November 13, 2025 13:58

@samet-akcay samet-akcay left a comment

thanks, some super minor comments

Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
…mprove comments

Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
…t to matmul, work with half tensors

Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
…r generating dinov2. update anomaly_dino to use factory method

Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
…rom dinomaly

Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
alfieroddanintel and others added 4 commits November 14, 2025 03:18
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
…tform#3108)

* Update xpu.py regarind PR open-edge-platform#3092

Added the name method to fix an issue related to a newly added feature in lightning 2.5.6

Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>

* Update xpu.py

Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>

* Update xpu.py

Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>

* Update xpu.py with docstring

Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>

* Update xpu.py with correct docstring

Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>

* added name method for XPUAccelerator

Signed-off-by: waschsalz <niclas.zschach@icloud.com>

---------

Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>
Signed-off-by: waschsalz <niclas.zschach@icloud.com>
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>

3 participants