@@ -59,27 +59,37 @@ def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) ->
5959 Uses a small-channel heuristic (<=32) typical for segmentation/classification. When ambiguous,
6060 prefers preserving the input layout or raises ValueError to avoid silent errors.
6161 """
62-
63-
62+ if not isinstance (x , torch .Tensor ):
63+ raise TypeError (f"Expected torch.Tensor, got { type (x )} " )
64+ if x .ndim < 3 :
65+ raise ValueError (f"Expected >=3 dims (N,C,spatial...), got shape={ tuple (x .shape )} " )
66+
67+ # Infer spatial dims if not provided (handles 1D/2D/3D uniformly).
6468 if spatial_ndim is None :
65- spatial_ndim = x .ndim - 2
69+ spatial_ndim = x .ndim - 2 # not directly used for logic; informative only
6670
67- threshold = 32
68- s1 , sl = int (x .shape [1 ]), int (x .shape [- 1 ])
71+ # Heuristic: channels are usually small (e.g., <=32) in segmentation/classification.
72+ threshold = 32
73+ s1 = int (x .shape [1 ]) # candidate channel at dim=1 (N, C, ...)
74+ sl = int (x .shape [- 1 ]) # candidate channel at last dim (..., C)
6975
76+ # Unambiguous cases first.
7077 if s1 <= threshold and sl > threshold :
78+ # Looks like NCHW/D already.
7179 return x , 1
7280 if sl <= threshold and s1 > threshold :
81+ # Looks like NHWC/D: move last dim to channel dim.
7382 return x .movedim (- 1 , 1 ), - 1
7483
84+ # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
7585 if s1 <= threshold and sl <= threshold :
7686 return x , 1
7787
7888 raise ValueError (
7989 f"cannot infer channel dim for shape={ tuple (x .shape )} ; expected [N,C,spatial...] or [N,spatial...,C]; "
8090 f"both dim1={ s1 } and dim-1={ sl } look like spatial dims"
81- )
82-
91+ )
92+
8393
8494def sliding_window_inference (
8595 inputs : torch .Tensor | MetaTensor ,
0 commit comments