Skip to content

Commit a360570

Browse files
林旻佑林旻佑
authored and committed
Fix: update ensure_channel_first and DiceHelper channel-last handling (refs #8366)
Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
1 parent 1080987 commit a360570

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

monai/inferers/utils.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,24 @@
3838

3939
__all__ = ["sliding_window_inference"]
4040

41-
def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]:
41+
from typing import Optional, Tuple
42+
import torch
43+
44+
def ensure_channel_first(
    x: torch.Tensor,
    spatial_ndim: Optional[int] = None,
    channel_hint: Optional[int] = None,
    threshold: int = 32,
) -> Tuple[torch.Tensor, int]:
    """
    Normalize a tensor to channel-first layout (N, C, spatial...).

    Args:
        x: Tensor with shape (N, C, spatial...) or (N, spatial..., C).
        spatial_ndim: Number of spatial dimensions. Accepted for API compatibility only;
            the layout decision does not read it.
        channel_hint: If provided, the expected channel size (e.g., num_classes). When exactly
            one of dim=1 (channel-first) or dim=-1 (channel-last) has this size, that dim is
            taken as the channel dim. If both or neither match, the hint is ignored.
        threshold: Heuristic upper bound for typical channel counts to disambiguate layouts.

    Returns:
        A tuple (x_cf, orig_channel_dim):
            x_cf: ``x`` itself when it is treated as channel-first, otherwise
                ``x.movedim(-1, 1)``.
            orig_channel_dim: ``1`` when ``x`` was returned unchanged, ``-1`` when the last
                dim was moved to dim 1.

    Raises:
        TypeError: if x is not a torch.Tensor.
        ValueError: if x.ndim < 3.

    Notes:
        1. When channel_hint is provided and matches exactly one candidate dim, it decides
           the layout.
        2. Otherwise, uses a heuristic where channels are usually small (<= threshold).
        3. In ambiguous cases (both candidate dims small or both large), the input layout
           is preserved (assumed channel-first) to avoid silent mis-reordering. No error is
           raised for ambiguity.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(x)}")
    if x.ndim < 3:
        raise ValueError(f"Expected >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")

    # x.ndim >= 3 guarantees dim 1 and dim -1 are different axes.
    s1 = int(x.shape[1])  # candidate channel at dim=1
    sl = int(x.shape[-1])  # candidate channel at dim=-1

    # 1) Strong signal: the hint decides only when exactly one candidate matches it.
    if channel_hint is not None:
        first_matches = s1 == channel_hint
        last_matches = sl == channel_hint
        if first_matches != last_matches:
            if first_matches:
                return x, 1
            return x.movedim(-1, 1), -1
        # both match or both mismatch: fall through to the size heuristic

    # 2) Heuristic: only a small last dim paired with a large dim=1 is read as channel-last.
    if sl <= threshold and s1 > threshold:
        return x.movedim(-1, 1), -1

    # 3) Clearly channel-first, or ambiguous (both small / both large): preserve the layout.
    return x, 1
92108

93-
94109
def sliding_window_inference(
95110
inputs: torch.Tensor | MetaTensor,
96111
roi_size: Sequence[int] | int,

monai/metrics/meandice.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,16 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
309309
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
310310
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
311311
"""
312-
y_pred, _ = ensure_channel_first(y_pred)
313-
312+
# --- Normalize layout to channel-first (N, C, spatial...) ---
314313
n_ch = self.num_classes or y_pred.shape[1]
314+
315+
# Always normalize y_pred with hint
316+
y_pred, _ = ensure_channel_first(y_pred, channel_hint=n_ch)
317+
318+
# Normalize y if it looks like channel-last (last dim = 1 or n_ch)
315319
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
316-
y, _ = ensure_channel_first(y)
320+
y, _ = ensure_channel_first(y, channel_hint=n_ch)
321+
317322

318323

319324
_apply_argmax, _threshold = self.apply_argmax, self.threshold

0 commit comments

Comments
 (0)