|
| 1 | +# Copyright (C) 2025 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Tests for DinoV2Loader.""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import re |
| 9 | +from unittest.mock import MagicMock, patch |
| 10 | + |
| 11 | +import pytest |
| 12 | +import torch |
| 13 | +from torch import nn |
| 14 | + |
| 15 | +from anomalib.models.components.dinov2.dinov2_loader import DinoV2Loader |
| 16 | + |
| 17 | + |
| 18 | +@pytest.fixture() |
| 19 | +def dummy_model() -> nn.Module: |
| 20 | + """Return a simple dummy model used by fake constructors.""" |
| 21 | + |
| 22 | + class Dummy(nn.Module): |
| 23 | + def __init__(self) -> None: |
| 24 | + super().__init__() |
| 25 | + self.linear = nn.Linear(4, 4) |
| 26 | + |
| 27 | + return Dummy() |
| 28 | + |
| 29 | + |
| 30 | +@pytest.fixture() |
| 31 | +def loader() -> DinoV2Loader: |
| 32 | + """Return a loader instance with a non-functional cache path.""" |
| 33 | + return DinoV2Loader(cache_dir="not_used_in_unit_tests") |
| 34 | + |
| 35 | + |
| 36 | +@pytest.mark.parametrize( |
| 37 | + ("name", "expected"), |
| 38 | + [ |
| 39 | + ("dinov2_vit_base_14", ("dinov2", "base", 14)), |
| 40 | + ("dinov2reg_vit_small_16", ("dinov2_reg", "small", 16)), |
| 41 | + ("dinomaly_vit_large_14", ("dinomaly", "large", 14)), |
| 42 | + ], |
| 43 | +) |
| 44 | +def test_parse_name_valid( |
| 45 | + loader: DinoV2Loader, |
| 46 | + name: str, |
| 47 | + expected: tuple[str, str, int], |
| 48 | +) -> None: |
| 49 | + """Validate that supported model names parse correctly.""" |
| 50 | + assert loader._parse_name(name) == expected # noqa: SLF001 |
| 51 | + |
| 52 | + |
| 53 | +@pytest.mark.parametrize( |
| 54 | + ("name", "expected"), |
| 55 | + [ |
| 56 | + ("foo_vit_base_14", "foo"), |
| 57 | + ("x_vit_small_16", "x"), |
| 58 | + ("wrongprefix_vit_large_14", "wrongprefix"), |
| 59 | + ], |
| 60 | +) |
| 61 | +def test_parse_name_invalid_prefix(loader: DinoV2Loader, name: str, expected: str) -> None: |
| 62 | + """Ensure invalid model prefixes raise ValueError.""" |
| 63 | + msg = f"Unknown model type prefix '{expected}'." |
| 64 | + with pytest.raises(ValueError, match=msg): |
| 65 | + loader._parse_name(name) # noqa: SLF001 |
| 66 | + |
| 67 | + |
| 68 | +def test_parse_name_invalid_architecture(loader: DinoV2Loader) -> None: |
| 69 | + """Ensure unknown architecture names raise ValueError.""" |
| 70 | + expected_msg = f"Invalid architecture 'tiny'. Expected one of: {list(loader.MODEL_CONFIGS)}" |
| 71 | + with pytest.raises(ValueError, match=re.escape(expected_msg)): |
| 72 | + loader._parse_name("dinov2_vit_tiny_14") # noqa: SLF001 |
| 73 | + |
| 74 | + |
| 75 | +def test_create_model_success(loader: DinoV2Loader, dummy_model: nn.Module) -> None: |
| 76 | + """Verify model creation succeeds when constructor exists.""" |
| 77 | + fake_module = MagicMock() |
| 78 | + fake_module.vit_small = MagicMock(return_value=dummy_model) |
| 79 | + |
| 80 | + loader.vit_factory = fake_module |
| 81 | + model = loader.create_model("dinov2", "small", 14) |
| 82 | + |
| 83 | + fake_module.vit_small.assert_called_once() |
| 84 | + assert model is dummy_model |
| 85 | + |
| 86 | + |
| 87 | +def test_create_model_missing_constructor(loader: DinoV2Loader) -> None: |
| 88 | + """Verify missing constructors cause ValueError.""" |
| 89 | + loader.vit_factory = object() |
| 90 | + expected_msg = f"No constructor vit_base in module {loader.vit_factory}" |
| 91 | + with pytest.raises(ValueError, match=expected_msg): |
| 92 | + loader.create_model("dinov2", "base", 14) |
| 93 | + |
| 94 | + |
| 95 | +def test_get_weight_path_dinov2(loader: DinoV2Loader) -> None: |
| 96 | + """Check generated weight filename for default dinov2 models.""" |
| 97 | + path = loader._get_weight_path("dinov2", "base", 14) # noqa: SLF001 |
| 98 | + assert path.name == "dinov2_vitb14_pretrain.pth" |
| 99 | + |
| 100 | + |
| 101 | +def test_get_weight_path_reg(loader: DinoV2Loader) -> None: |
| 102 | + """Check generated weight filename for register-token models.""" |
| 103 | + path = loader._get_weight_path("dinov2_reg", "large", 16) # noqa: SLF001 |
| 104 | + assert path.name == "dinov2_vitl16_reg4_pretrain.pth" |
| 105 | + |
| 106 | + |
| 107 | +@patch("anomalib.models.components.dinov2.dinov2_loader.torch.load") |
| 108 | +@patch("anomalib.models.components.dinov2.dinov2_loader.DinoV2Loader._download_weights") |
| 109 | +def test_load_calls_weight_loading( |
| 110 | + mock_download: MagicMock, |
| 111 | + mock_torch_load: MagicMock, |
| 112 | + loader: DinoV2Loader, |
| 113 | + dummy_model: nn.Module, |
| 114 | +) -> None: |
| 115 | + """Confirm load() uses existing weights without downloading.""" |
| 116 | + fake_module = MagicMock() |
| 117 | + fake_module.vit_base = MagicMock(return_value=dummy_model) |
| 118 | + loader.vit_factory = fake_module |
| 119 | + |
| 120 | + fake_path = MagicMock() |
| 121 | + fake_path.exists.return_value = True |
| 122 | + loader._get_weight_path = MagicMock(return_value=fake_path) # noqa: SLF001 |
| 123 | + |
| 124 | + mock_torch_load.return_value = {"layer": torch.zeros(1)} |
| 125 | + |
| 126 | + loaded = loader.load("dinov2_vit_base_14") |
| 127 | + |
| 128 | + fake_module.vit_base.assert_called_once() |
| 129 | + mock_download.assert_not_called() |
| 130 | + mock_torch_load.assert_called_once() |
| 131 | + assert loaded is dummy_model |
| 132 | + |
| 133 | + |
| 134 | +@patch("anomalib.models.components.dinov2.dinov2_loader.torch.load") |
| 135 | +@patch("anomalib.models.components.dinov2.dinov2_loader.DinoV2Loader._download_weights") |
| 136 | +def test_load_triggers_download_when_missing( |
| 137 | + mock_download: MagicMock, |
| 138 | + mock_torch_load: MagicMock, |
| 139 | + loader: DinoV2Loader, |
| 140 | + dummy_model: nn.Module, |
| 141 | +) -> None: |
| 142 | + """Confirm load() triggers weight download when file is missing.""" |
| 143 | + fake_module = MagicMock() |
| 144 | + fake_module.vit_small = MagicMock(return_value=dummy_model) |
| 145 | + loader.vit_factory = fake_module |
| 146 | + |
| 147 | + fake_path = MagicMock() |
| 148 | + fake_path.exists.return_value = False |
| 149 | + loader._get_weight_path = MagicMock(return_value=fake_path) # noqa: SLF001 |
| 150 | + |
| 151 | + mock_torch_load.return_value = {"test": torch.zeros(1)} |
| 152 | + |
| 153 | + loader.load("dinov2_vit_small_14") |
| 154 | + |
| 155 | + mock_download.assert_called_once() |
| 156 | + mock_torch_load.assert_called_once() |
| 157 | + fake_module.vit_small.assert_called_once() |
0 commit comments