
Commit 86ab64e

added preserve_dtype param for ExtensionArray
1 parent 7569baa commit 86ab64e

4 files changed: +41 −99 lines

doc/source/reference/extensions.rst

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ objects.
    api.extensions.ExtensionArray.isin
    api.extensions.ExtensionArray.isna
    api.extensions.ExtensionArray.ravel
+   api.extensions.ExtensionArray.map
    api.extensions.ExtensionArray.repeat
    api.extensions.ExtensionArray.searchsorted
    api.extensions.ExtensionArray.shift

pandas/core/arrays/arrow/array.py

Lines changed: 21 additions & 95 deletions
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import functools
-import math
 import operator
 import re
 import textwrap
@@ -17,15 +16,13 @@
 import warnings
 
 import numpy as np
-import pandas as pd
 
 from pandas._libs import lib
 from pandas._libs.tslibs import (
     Timedelta,
     Timestamp,
     timezones,
 )
-from sqlalchemy import values
 from pandas.compat import (
     HAS_PYARROW,
     pa_version_under12p1,
@@ -404,26 +401,16 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
         if len(values) == 0:
             # Retain our dtype
             return self[:0].copy()
-        if not isinstance(self.dtype, ArrowDtype):
-            return super()._cast_pointwise_result(values)
-
+
         try:
-            if (
-                isinstance(values, (np.ndarray, list, tuple))
-                and np.asarray(values).dtype == np.bool_
-            ):
-                arr = pa.array(values, from_pandas=True)
-            else:
-                arr = pa.array(values, type=self._pa_array.type, from_pandas=True)
+            arr = pa.array(values, from_pandas=True)
         except (ValueError, TypeError):
             # e.g. test_by_column_values_with_same_starting_value with nested
             # values, one entry of which is an ArrowStringArray
             # or test_agg_lambda_complex128_dtype_conversion for complex values
             return super()._cast_pointwise_result(values)
-        if pa.types.is_duration(self._pa_array.type) and pa.types.is_floating(arr.type):
-            # just return a numpy array / float Series
-            return np.asarray(values, dtype="float64")
-        elif pa.types.is_duration(arr.type):
+
+        if pa.types.is_duration(arr.type):
             # workaround for https://github.com/apache/arrow/issues/40620
             result = ArrowExtensionArray._from_sequence(values)
             if pa.types.is_duration(self._pa_array.type):
@@ -444,79 +431,28 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
                 dtype = ArrowDtype(pa.duration("s"))
                 result = result.astype(dtype)  # type: ignore[assignment]
             return result
-
-        elif pa.types.is_timestamp(arr.type) and pa.types.is_timestamp(self.dtype.pyarrow_dtype):
-            # Preserve the original array's timestamp unit (i.e. us/ns/...)
-            original_unit = self.dtype.pyarrow_dtype.unit
-            tz = arr.type.tz
-
-            # Only convert if units don't match
-            if arr.type.unit != original_unit:
-                target_pa_dtype = pa.timestamp(original_unit, tz=tz)
-                arr = arr.cast(target_pa_dtype)
-
-            # Create ArrowExtensionArray with the processed array
-            return self._from_pyarrow_array(arr)
-        elif pa.types.is_floating(self._pa_array.type):
-            try:
-                if self._pa_array.type == pa.float32():
-                    coerced = [
-                        None if (v is None or v is pd.NA or (isinstance(v, float) and np.isnan(v)))
-                        else np.float32(v)
-                        for v in values
-                    ]
-                    arr = pa.array(coerced, type=pa.float32(), from_pandas=True)
-                else:
-                    arr = pa.array(values, type=self._pa_array.type, from_pandas=True)
-            except pa.ArrowInvalid:
-                arr = pa.array(values, from_pandas=True)
+
         elif pa.types.is_date(arr.type) and pa.types.is_date(self._pa_array.type):
             arr = arr.cast(self._pa_array.type)
         elif pa.types.is_time(arr.type) and pa.types.is_time(self._pa_array.type):
             arr = arr.cast(self._pa_array.type)
         elif pa.types.is_decimal(arr.type) and pa.types.is_decimal(self._pa_array.type):
             arr = arr.cast(self._pa_array.type)
-        elif is_numeric_dtype(self.dtype):
-            if pa.types.is_integer(self._pa_array.type):
-                try:
-                    # Handle the case where Python map gives floats (e.g., 1 → 1.0)
-                    floats = [v for v in values if isinstance(v, float) and v is not None]
-                    if floats and all(math.isfinite(v) and v.is_integer() for v in floats):
-                        arr = pa.array([int(v) if isinstance(v, float) else v for v in values],
-                                       type=self._pa_array.type,
-                                       from_pandas=True)
-                        return self._from_pyarrow_array(arr)
-
-                    # Special handling for unsigned integers
-                    if pa.types.is_unsigned_integer(self._pa_array.type):
-                        if any((isinstance(v, (int, np.integer, float)) and v is not None and v < 0)
-                               for v in values):
-                            # Promote to signed int (wider type to hold negatives)
-                            signed_type = pa.int16() if self._pa_array.type == pa.uint8() else pa.int64()
-                            arr = pa.array(values, type=signed_type, from_pandas=True)
-                        else:
-                            arr = pa.array(values, type=self._pa_array.type, from_pandas=True)
-                    else:
-                        arr = pa.array(values, type=self._pa_array.type, from_pandas=True)
-                except (pa.ArrowInvalid, OverflowError):
-                    arr = pa.array(values, from_pandas=True)
-
-        elif (
-            (pa.types.is_integer(arr.type) and pa.types.is_integer(self._pa_array.type))
-            or (pa.types.is_floating(arr.type) and pa.types.is_integer(self._pa_array.type))
-        ):
+        elif pa.types.is_integer(arr.type) and pa.types.is_integer(self._pa_array.type):
            try:
                 arr = arr.cast(self._pa_array.type)
             except pa.lib.ArrowInvalid:
                 # e.g. test_combine_add if we can't cast
                 pass
-        elif pa.types.is_floating(arr.type) and pa.types.is_floating(self._pa_array.type):
+        elif pa.types.is_floating(arr.type) and pa.types.is_floating(
+            self._pa_array.type
+        ):
             try:
                 arr = arr.cast(self._pa_array.type)
             except pa.lib.ArrowInvalid:
                 # e.g. test_combine_add if we can't cast
                 pass
-
+
         if isinstance(self.dtype, StringDtype):
             if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
                 # ArrowStringArray preserves dtype.na_value
@@ -526,7 +462,6 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
                 # result instead
                 return super()._cast_pointwise_result(values)
             return ArrowExtensionArray(arr)
-
        return self._from_pyarrow_array(arr)
 
     @classmethod
@@ -794,28 +729,15 @@ def __getitem__(self, item: PositionalIndexer):
             return self._from_pyarrow_array(value)
         else:
             pa_type = self._pa_array.type
-            # Special case: timestamp : avoid overflow
-            if pa.types.is_timestamp(pa_type):
-                return pd.Timestamp(value.as_py())
-            # Special case: duration
-            if pa.types.is_duration(pa_type):
-                return pd.Timedelta(value.as_py())
             scalar = value.as_py()
-
             if scalar is None:
                 return self._dtype.na_value
-            elif pa.types.is_timestamp(pa_type):
+            elif pa.types.is_timestamp(pa_type) and pa_type.unit != "ns":
                 # GH 53326
-                ts = pd.Timestamp(scalar)
-                if pa_type.unit != "ns":
-                    return ts.as_unit(pa_type.unit)
-                return ts
-            elif pa.types.is_duration(pa_type):
+                return Timestamp(scalar).as_unit(pa_type.unit)
+            elif pa.types.is_duration(pa_type) and pa_type.unit != "ns":
                 # GH 53326
-                td = pd.Timedelta(scalar)
-                if pa_type.unit != "ns":
-                    return td.as_unit(pa_type.unit)
-                return td
+                return Timedelta(scalar).as_unit(pa_type.unit)
             else:
                 return scalar
 
@@ -1670,12 +1592,16 @@ def to_numpy(
             result[~mask] = data[~mask]._pa_array.to_numpy()
         return result
 
-    def map(self, mapper, na_action: Literal["ignore"] | None = None):
+    def map(self, mapper,
+            na_action: Literal["ignore"] | None = None,
+            preserve_dtype: bool = False):
         if is_numeric_dtype(self.dtype):
             result = map_array(self.to_numpy(), mapper, na_action=na_action)
-            return self._cast_pointwise_result(result)
+            if preserve_dtype:
+                result = self._cast_pointwise_result(result)
+            return result
         else:
-            return super().map(mapper, na_action)
+            return super().map(mapper, na_action, preserve_dtype=preserve_dtype)
 
     @doc(ExtensionArray.duplicated)
     def duplicated(
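For context, a minimal usage sketch of the reworked ArrowExtensionArray.map. This assumes a pandas build with this branch applied (the preserve_dtype keyword does not exist in released pandas), and the exact result dtypes are approximate:

    # Illustration only; requires pyarrow and a build of this branch.
    import pandas as pd

    arr = pd.array([1, 2, 3], dtype="int64[pyarrow]")

    # Default: the mapped values are returned as produced by map_array
    # (a plain NumPy array for numeric Arrow dtypes), with no cast back.
    plain = arr.map(lambda x: x + 1)

    # With preserve_dtype=True, _cast_pointwise_result is applied, so the
    # values are cast back toward an Arrow-backed array where possible.
    preserved = arr.map(lambda x: x + 1, preserve_dtype=True)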

pandas/core/arrays/base.py

Lines changed: 12 additions & 2 deletions
@@ -2510,7 +2510,9 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
 
         return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
 
-    def map(self, mapper, na_action: Literal["ignore"] | None = None):
+    def map(self, mapper,
+            na_action: Literal["ignore"] | None = None,
+            preserve_dtype: bool = False):
         """
         Map values using an input mapping or function.
 
@@ -2522,6 +2524,12 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
             If 'ignore', propagate NA values, without passing them to the
             mapping correspondence. If 'ignore' is not supported, a
             ``NotImplementedError`` should be raised.
+        preserve_dtype : bool, default False
+            If True, attempt to cast the elementwise result back to the
+            original ExtensionArray type (and dtype) when possible. This is
+            primarily intended for identity or dtype-preserving mappings.
+            If False, the result of the mapping is returned as produced by
+            the underlying implementation (typically a NumPy ndarray).
 
         Returns
         -------
@@ -2531,7 +2539,9 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
             a MultiIndex will be returned.
         """
         results = map_array(self, mapper, na_action=na_action)
-        return self._cast_pointwise_result(results)
+        if preserve_dtype:
+            results = self._cast_pointwise_result(results)
+        return results
 
     # ------------------------------------------------------------------------
     # GroupBy Methods
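A hedged sketch of the contract documented in the new docstring, again assuming this branch is installed: with the default preserve_dtype=False the raw map_array result is returned, while preserve_dtype=True routes it through _cast_pointwise_result.

    # Illustration of the documented preserve_dtype contract (branch only).
    import pandas as pd

    ea = pd.array(["a", "b", None], dtype="string")

    # Dtype-changing mapper: the result stays whatever map_array produced,
    # since no cast back to the extension dtype is requested.
    lengths = ea.map(lambda x: len(x), na_action="ignore")

    # Dtype-preserving mapper: opting in asks the array to cast the
    # pointwise result back toward its own dtype when possible.
    upper = ea.map(str.upper, na_action="ignore", preserve_dtype=True)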

pandas/core/arrays/masked.py

Lines changed: 7 additions & 2 deletions
@@ -1394,9 +1394,14 @@ def max(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
         )
         return self._wrap_reduction_result("max", result, skipna=skipna, axis=axis)
 
-    def map(self, mapper, na_action: Literal["ignore"] | None = None):
+    def map(self, mapper,
+            na_action: Literal["ignore"] | None = None,
+            preserve_dtype: bool = False):
+        """See ExtensionArray.map."""
         result = map_array(self.to_numpy(), mapper, na_action=na_action)
-        return self._cast_pointwise_result(result)
+        if preserve_dtype:
+            result = self._cast_pointwise_result(result)
+        return result
 
     @overload
     def any(
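The masked (nullable) arrays follow the same pattern. A small usage sketch under the same branch-only assumption, with a nullable integer array:

    # Illustration only; preserve_dtype exists only on this branch.
    import pandas as pd

    vals = pd.array([1, 2, None], dtype="Int64")

    # na_action="ignore" keeps pd.NA out of the mapper, as in released pandas.
    doubled = vals.map(lambda x: x * 2, na_action="ignore")

    # Opting in re-applies _cast_pointwise_result so the doubled values can
    # come back as a nullable integer array rather than a plain ndarray.
    doubled_masked = vals.map(lambda x: x * 2, na_action="ignore", preserve_dtype=True)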
