Skip to content

Commit a06bf07

Browse files
林旻佑林旻佑
authored andcommitted
Refactor: English docstring and safer channel heuristic in ensure_channel_first (refs #8366)
Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
1 parent a81fce1 commit a06bf07

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

monai/inferers/utils.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,27 +59,37 @@ def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) ->
5959
Uses a small-channel heuristic (<=32) typical for segmentation/classification. When ambiguous,
6060
prefers preserving the input layout or raises ValueError to avoid silent errors.
6161
"""
62-
63-
62+
if not isinstance(x, torch.Tensor):
63+
raise TypeError(f"Expected torch.Tensor, got {type(x)}")
64+
if x.ndim < 3:
65+
raise ValueError(f"Expected >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")
66+
67+
# Infer spatial dims if not provided (handles 1D/2D/3D uniformly).
6468
if spatial_ndim is None:
65-
spatial_ndim = x.ndim - 2
69+
spatial_ndim = x.ndim - 2 # not directly used for logic; informative only
6670

67-
threshold = 32
68-
s1, sl = int(x.shape[1]), int(x.shape[-1])
71+
# Heuristic: channels are usually small (e.g., <=32) in segmentation/classification.
72+
threshold = 32
73+
s1 = int(x.shape[1]) # candidate channel at dim=1 (N, C, ...)
74+
sl = int(x.shape[-1]) # candidate channel at last dim (..., C)
6975

76+
# Unambiguous cases first.
7077
if s1 <= threshold and sl > threshold:
78+
# Looks like NCHW/D already.
7179
return x, 1
7280
if sl <= threshold and s1 > threshold:
81+
# Looks like NHWC/D: move last dim to channel dim.
7382
return x.movedim(-1, 1), -1
7483

84+
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
7585
if s1 <= threshold and sl <= threshold:
7686
return x, 1
7787

7888
raise ValueError(
7989
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
8090
f"both dim1={s1} and dim-1={sl} look like spatial dims"
81-
)
82-
91+
)
92+
8393

8494
def sliding_window_inference(
8595
inputs: torch.Tensor | MetaTensor,

0 commit comments

Comments
 (0)