diff --git a/tensornetwork/backends/numpy/numpy_backend.py b/tensornetwork/backends/numpy/numpy_backend.py index 2d3753911..9f1a9220a 100644 --- a/tensornetwork/backends/numpy/numpy_backend.py +++ b/tensornetwork/backends/numpy/numpy_backend.py @@ -604,7 +604,9 @@ def sum(self, tensor: Tensor, axis: Optional[Sequence[int]] = None, keepdims: bool = False) -> Tensor: - return np.sum(tensor, axis=tuple(axis), keepdims=keepdims) + if axis is not None and type(axis) == list: + return np.sum(tensor, axis=tuple(axis), keepdims=keepdims) + return np.sum(tensor, axis=axis, keepdims=keepdims) def matmul(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: if (tensor1.ndim <= 1) or (tensor2.ndim <= 1):