@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool
220220 return torch .clone (x )
221221 return torch .amin (x , axis , keepdims = keepdims )
222222
# Torch-specialised versions of the generic implementations in
# common/_aliases.py: get_xp(torch) wraps each generic function so that the
# torch module is supplied as its array-namespace argument.
# NOTE(review): the former `clip = get_xp(torch)(_aliases.clip)` alias was
# removed in favour of the hand-written torch `clip` defined later in this file.
unstack = get_xp(torch)(_aliases.unstack)
cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
@@ -808,6 +807,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
808807 return torch .take_along_dim (x , indices , dim = axis )
809808
810809
def clip(
    x: Array,
    /,
    min: int | float | Array | None = None,
    max: int | float | Array | None = None,
    **kwargs
) -> Array:
    """Clamp each element of ``x`` to the interval [``min``, ``max``].

    Torch-specific counterpart of ``clip`` in common/_aliases.py. Either
    bound may be a Python scalar, a tensor, or ``None`` (meaning the side
    is unbounded). Extra keyword arguments are forwarded to ``torch.clamp``.
    """
    def _is_scalar_bound(bound: object):
        return bound is None or isinstance(bound, int | float)

    # cf clip in common/_aliases.py
    # For integer tensors an int bound at or beyond the dtype's own limit
    # constrains nothing, so treat it as absent.
    if not x.is_floating_point():
        info = torch.iinfo(x.dtype)
        if type(min) is int and min <= info.min:
            min = None
        if type(max) is int and max >= info.max:
            max = None

    # With no effective bound, return a copy rather than x itself.
    if max is None and min is None:
        return torch.clone(x)

    min_is_scalar = _is_scalar_bound(min)
    max_is_scalar = _is_scalar_bound(max)

    # torch.clamp dislikes mixing a scalar bound with a tensor bound, so when
    # both bounds are present and exactly one is a tensor, promote the scalar
    # one to a tensor matching x's dtype and device.
    if not (min is None or max is None):
        if min_is_scalar and not max_is_scalar:
            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
        if max_is_scalar and not min_is_scalar:
            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)

    return torch.clamp(x, min, max, **kwargs)
840+
841+
811842def sign (x : Array , / ) -> Array :
812843 # torch sign() does not support complex numbers and does not propagate
813844 # nans. See https://github.com/data-apis/array-api-compat/issues/136
0 commit comments