Skip to content

Commit dc908f1

Browse files
add tests for vit and dinov2loader
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
1 parent 59f9d0f commit dc908f1

File tree

3 files changed

+293
-0
lines changed

3 files changed

+293
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for dinov2 implementation and loader."""
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for DinoV2Loader."""
5+
6+
from __future__ import annotations
7+
8+
import re
9+
from unittest.mock import MagicMock, patch
10+
11+
import pytest
12+
import torch
13+
from torch import nn
14+
15+
from anomalib.models.components.dinov2.dinov2_loader import DinoV2Loader
16+
17+
18+
@pytest.fixture()
def dummy_model() -> nn.Module:
    """Build a minimal ``nn.Module`` for fake ViT constructors to return."""

    class _StubNet(nn.Module):
        """Tiny stand-in network with a single linear layer."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(4, 4)

    return _StubNet()
28+
29+
30+
@pytest.fixture()
def loader() -> DinoV2Loader:
    """Provide a ``DinoV2Loader`` whose cache directory is never touched."""
    instance = DinoV2Loader(cache_dir="not_used_in_unit_tests")
    return instance
34+
35+
36+
@pytest.mark.parametrize(
    ("name", "expected"),
    [
        ("dinov2_vit_base_14", ("dinov2", "base", 14)),
        ("dinov2reg_vit_small_16", ("dinov2_reg", "small", 16)),
        ("dinomaly_vit_large_14", ("dinomaly", "large", 14)),
    ],
)
def test_parse_name_valid(
    loader: DinoV2Loader,
    name: str,
    expected: tuple[str, str, int],
) -> None:
    """Validate that supported model names parse correctly."""
    parsed = loader._parse_name(name)  # noqa: SLF001
    assert parsed == expected
51+
52+
53+
@pytest.mark.parametrize(
    ("name", "expected"),
    [
        ("foo_vit_base_14", "foo"),
        ("x_vit_small_16", "x"),
        ("wrongprefix_vit_large_14", "wrongprefix"),
    ],
)
def test_parse_name_invalid_prefix(loader: DinoV2Loader, name: str, expected: str) -> None:
    """Ensure invalid model prefixes raise ValueError.

    The expected message is wrapped in ``re.escape`` because
    ``pytest.raises(match=...)`` interprets the pattern as a regex and the
    message ends with a literal ``.`` (a regex metacharacter). This also
    matches the escaping style used by the architecture test below.
    """
    msg = f"Unknown model type prefix '{expected}'."
    with pytest.raises(ValueError, match=re.escape(msg)):
        loader._parse_name(name)  # noqa: SLF001
66+
67+
68+
def test_parse_name_invalid_architecture(loader: DinoV2Loader) -> None:
    """Ensure unknown architecture names raise ValueError."""
    message = f"Invalid architecture 'tiny'. Expected one of: {list(loader.MODEL_CONFIGS)}"
    pattern = re.escape(message)
    with pytest.raises(ValueError, match=pattern):
        loader._parse_name("dinov2_vit_tiny_14")  # noqa: SLF001
73+
74+
75+
def test_create_model_success(loader: DinoV2Loader, dummy_model: nn.Module) -> None:
    """Verify model creation succeeds when constructor exists."""
    # Stub module exposing the expected vit_small constructor.
    factory_stub = MagicMock()
    factory_stub.vit_small = MagicMock(return_value=dummy_model)
    loader.vit_factory = factory_stub

    created = loader.create_model("dinov2", "small", 14)

    factory_stub.vit_small.assert_called_once()
    assert created is dummy_model
85+
86+
87+
def test_create_model_missing_constructor(loader: DinoV2Loader) -> None:
    """Verify missing constructors cause ValueError.

    The expected message embeds the ``repr`` of the stand-in module, so it is
    wrapped in ``re.escape`` — ``pytest.raises(match=...)`` treats the string
    as a regex, and matching a dynamic repr unescaped is fragile.
    """
    loader.vit_factory = object()
    expected_msg = f"No constructor vit_base in module {loader.vit_factory}"
    with pytest.raises(ValueError, match=re.escape(expected_msg)):
        loader.create_model("dinov2", "base", 14)
93+
94+
95+
def test_get_weight_path_dinov2(loader: DinoV2Loader) -> None:
    """Check generated weight filename for default dinov2 models."""
    weight_path = loader._get_weight_path("dinov2", "base", 14)  # noqa: SLF001
    assert weight_path.name == "dinov2_vitb14_pretrain.pth"
99+
100+
101+
def test_get_weight_path_reg(loader: DinoV2Loader) -> None:
    """Check generated weight filename for register-token models."""
    weight_path = loader._get_weight_path("dinov2_reg", "large", 16)  # noqa: SLF001
    assert weight_path.name == "dinov2_vitl16_reg4_pretrain.pth"
105+
106+
107+
@patch("anomalib.models.components.dinov2.dinov2_loader.torch.load")
@patch("anomalib.models.components.dinov2.dinov2_loader.DinoV2Loader._download_weights")
def test_load_calls_weight_loading(
    mock_download: MagicMock,
    mock_torch_load: MagicMock,
    loader: DinoV2Loader,
    dummy_model: nn.Module,
) -> None:
    """Confirm load() uses existing weights without downloading."""
    # Stub out the ViT factory so no real model is constructed.
    factory_stub = MagicMock()
    factory_stub.vit_base = MagicMock(return_value=dummy_model)
    loader.vit_factory = factory_stub

    # Pretend the weight file already exists on disk.
    weight_path = MagicMock()
    weight_path.exists.return_value = True
    loader._get_weight_path = MagicMock(return_value=weight_path)  # noqa: SLF001

    mock_torch_load.return_value = {"layer": torch.zeros(1)}

    result = loader.load("dinov2_vit_base_14")

    factory_stub.vit_base.assert_called_once()
    mock_download.assert_not_called()
    mock_torch_load.assert_called_once()
    assert result is dummy_model
132+
133+
134+
@patch("anomalib.models.components.dinov2.dinov2_loader.torch.load")
@patch("anomalib.models.components.dinov2.dinov2_loader.DinoV2Loader._download_weights")
def test_load_triggers_download_when_missing(
    mock_download: MagicMock,
    mock_torch_load: MagicMock,
    loader: DinoV2Loader,
    dummy_model: nn.Module,
) -> None:
    """Confirm load() triggers weight download when file is missing."""
    # Stub out the ViT factory so no real model is constructed.
    factory_stub = MagicMock()
    factory_stub.vit_small = MagicMock(return_value=dummy_model)
    loader.vit_factory = factory_stub

    # Simulate a missing weight file to force the download path.
    weight_path = MagicMock()
    weight_path.exists.return_value = False
    loader._get_weight_path = MagicMock(return_value=weight_path)  # noqa: SLF001

    mock_torch_load.return_value = {"test": torch.zeros(1)}

    loader.load("dinov2_vit_small_14")

    mock_download.assert_called_once()
    mock_torch_load.assert_called_once()
    factory_stub.vit_small.assert_called_once()
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Unit tests for DINOv2 ViT / Loader."""
5+
6+
from __future__ import annotations

from collections.abc import Callable

import pytest
import torch
from torch import Tensor

from anomalib.models.components.dinov2.vision_transformer import (
    DinoVisionTransformer,
    vit_base,
    vit_large,
    vit_small,
)
18+
19+
20+
@pytest.fixture()
def tiny_vit() -> DinoVisionTransformer:
    """Return a very small ViT model for unit testing."""
    config = {
        "img_size": 32,
        "patch_size": 8,
        "embed_dim": 64,
        "depth": 2,
        "num_heads": 4,
    }
    return DinoVisionTransformer(**config)
30+
31+
32+
@pytest.fixture()
def tiny_input() -> Tensor:
    """Return a small dummy input tensor shaped (B=2, C=3, H=W=32)."""
    batch, channels, side = 2, 3, 32
    return torch.randn(batch, channels, side, side)
36+
37+
38+
def test_model_initializes(tiny_vit: DinoVisionTransformer) -> None:
    """Model constructs and exposes expected attributes."""
    assert tiny_vit.embed_dim == 64
    assert tiny_vit.patch_size == 8
    assert tiny_vit.n_blocks == 2

    # Core ViT components must all be present.
    for attribute in ("patch_embed", "cls_token", "pos_embed", "blocks"):
        assert hasattr(tiny_vit, attribute)
49+
50+
51+
def test_patch_embedding_shape(
    tiny_vit: DinoVisionTransformer,
    tiny_input: Tensor,
) -> None:
    """Patch embedding output has correct (B, N, C) shape."""
    embedded = tiny_vit.patch_embed(tiny_input)

    # 32x32 input with patch_size=8 yields a 4x4 grid, i.e. 16 patches.
    assert embedded.shape == (2, 16, tiny_vit.embed_dim)
62+
63+
64+
def test_prepare_tokens_output_shape(
    tiny_vit: DinoVisionTransformer,
    tiny_input: Tensor,
) -> None:
    """prepare_tokens_with_masks adds CLS and keeps correct embedding dims."""
    token_sequence = tiny_vit.prepare_tokens_with_masks(tiny_input)

    # One CLS token is prepended to the patch tokens.
    n_tokens = tiny_vit.patch_embed.num_patches + 1
    assert token_sequence.shape == (2, n_tokens, tiny_vit.embed_dim)
73+
74+
75+
def test_forward_features_training_output_shapes(
    tiny_vit: DinoVisionTransformer,
    tiny_input: Tensor,
) -> None:
    """forward(is_training=True) returns a dict with expected shapes."""
    features = tiny_vit(tiny_input, is_training=True)

    assert isinstance(features, dict)
    cls_token = features["x_norm_clstoken"]
    patch_tokens = features["x_norm_patchtokens"]
    assert cls_token is not None
    assert patch_tokens is not None

    assert cls_token.shape == (2, tiny_vit.embed_dim)
    assert patch_tokens.shape[1] == tiny_vit.patch_embed.num_patches
91+
92+
93+
def test_forward_inference_output_shape(
    tiny_vit: DinoVisionTransformer,
    tiny_input: Tensor,
) -> None:
    """Inference mode returns class-token output only."""
    # Default call path: is_training=False.
    output = tiny_vit(tiny_input)

    assert isinstance(output, Tensor)
    assert output.shape == (2, tiny_vit.embed_dim)
102+
103+
104+
def test_get_intermediate_layers_shapes(
    tiny_vit: DinoVisionTransformer,
    tiny_input: Tensor,
) -> None:
    """Intermediate layer extraction returns tensors shaped (B, tokens, C)."""
    layers = tiny_vit.get_intermediate_layers(tiny_input, n=1)

    assert len(layers) == 1

    first = layers[0]
    assert first.shape[0] == 2  # batch dimension
    assert first.shape[2] == tiny_vit.embed_dim
119+
120+
121+
@pytest.mark.parametrize(
    "factory",
    [vit_small, vit_base, vit_large],
)
def test_vit_factories_create_models(factory: Callable[..., DinoVisionTransformer]) -> None:
    """vit_small/base/large should return valid models.

    The ``factory`` parameter is properly annotated instead of suppressing
    the missing-annotation lint with ``# noqa: ANN001``.
    """
    model = factory()

    assert isinstance(model, DinoVisionTransformer)
    assert model.embed_dim > 0
    assert model.n_blocks > 0
    assert model.num_heads > 0

0 commit comments

Comments
 (0)