3838
3939__all__ = ["sliding_window_inference" ]
4040
41- def ensure_channel_first (x : torch .Tensor , spatial_ndim : Optional [int ] = None ) -> Tuple [torch .Tensor , int ]:
41+ dfrom typing import Optional , Tuple
42+ import torch
43+
44+ def ensure_channel_first (
45+ x : torch .Tensor ,
46+ spatial_ndim : Optional [int ] = None ,
47+ channel_hint : Optional [int ] = None ,
48+ threshold : int = 32 ,
49+ ) -> Tuple [torch .Tensor , int ]:
4250 """
4351 Normalize a tensor to channel-first layout (N, C, spatial...).
4452
4553 Args:
4654 x: Tensor with shape (N, C, spatial...) or (N, spatial..., C).
4755 spatial_ndim: Number of spatial dimensions. If None, inferred as x.ndim - 2.
56+ channel_hint: If provided, the expected channel size (e.g., num_classes). When present,
57+ we prioritize matching this size at either dim=1 (channel-first) or dim=-1 (channel-last).
58+ threshold: Heuristic upper bound for typical channel counts to disambiguate layouts.
4859
4960 Returns:
5061 A tuple (x_cf, orig_channel_dim):
@@ -56,41 +67,45 @@ def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) ->
5667 ValueError: if x.ndim < 3 or the channel dimension cannot be inferred unambiguously.
5768
5869 Notes:
59- Uses a small-channel heuristic (<=32) typical for segmentation/classification. When ambiguous,
60- prefers preserving the input layout or raises ValueError to avoid silent errors.
70+ 1. When channel_hint is provided, it is used as a strong signal to decide layout.
71+ 2. Otherwise, uses a heuristic where channels are usually small (<= threshold).
72+ 3. In ambiguous cases (both candidate dims small or both large), the input layout
73+ is preserved (assumed channel-first) to avoid silent mis-reordering.
6174 """
6275 if not isinstance (x , torch .Tensor ):
6376 raise TypeError (f"Expected torch.Tensor, got { type (x )} " )
6477 if x .ndim < 3 :
6578 raise ValueError (f"Expected >=3 dims (N,C,spatial...), got shape={ tuple (x .shape )} " )
6679
67- # Infer spatial dims if not provided (handles 1D/2D/3D uniformly).
6880 if spatial_ndim is None :
69- spatial_ndim = x .ndim - 2 # not directly used for logic; informative only
81+ spatial_ndim = x .ndim - 2 # informative only
7082
71- # Heuristic: channels are usually small (e.g., <=32) in segmentation/classification.
72- threshold = 32
73- s1 = int (x .shape [1 ]) # candidate channel at dim=1 (N, C, ...)
74- sl = int (x .shape [- 1 ]) # candidate channel at last dim (..., C)
83+ s1 = int (x .shape [1 ]) # candidate channel at dim=1
84+ sl = int (x .shape [- 1 ]) # candidate channel at dim=-1
7585
76- # Unambiguous cases first.
86+ # 1) Strong signal: use channel_hint if provided
87+ if channel_hint is not None :
88+ if s1 == channel_hint and sl != channel_hint :
89+ return x , 1
90+ if sl == channel_hint and s1 != channel_hint :
91+ return x .movedim (- 1 , 1 ), - 1
92+ # if both match or both mismatch, fall back to heuristic
93+
94+ # 2) Heuristic: channels are usually small
7795 if s1 <= threshold and sl > threshold :
78- # Looks like NCHW/D already.
7996 return x , 1
8097 if sl <= threshold and s1 > threshold :
81- # Looks like NHWC/D: move last dim to channel dim.
8298 return x .movedim (- 1 , 1 ), - 1
8399
84- # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
85- if s1 <= threshold and sl <= threshold :
100+ # 3) Ambiguous: both sides small OR both sides large → preserve as channel-first
101+ if ( s1 <= threshold and sl <= threshold ) or ( s1 > threshold and sl > threshold ) :
86102 return x , 1
87103
104+ # 4) Should not reach here under normal cases
88105 raise ValueError (
89- f"cannot infer channel dim for shape={ tuple (x .shape )} ; expected [N,C,spatial...] or [N,spatial...,C]; "
90- f"both dim1={ s1 } and dim-1={ sl } look like spatial dims"
106+ f"cannot infer channel dim for shape={ tuple (x .shape )} ; expected [N,C,spatial...] or [N,spatial...,C]"
91107 )
92108
93-
94109def sliding_window_inference (
95110 inputs : torch .Tensor | MetaTensor ,
96111 roi_size : Sequence [int ] | int ,
0 commit comments