diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 752d08a526d8c..53ceb1f92ce2d 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -156,6 +156,8 @@ All warnings for upcoming changes in pandas will have the base class :class:`pan Other enhancements ^^^^^^^^^^^^^^^^^^ +- :class:`pandas.NamedAgg` now supports passing ``*args`` and ``**kwargs`` + to calls of ``aggfunc`` (:issue:`58283`) - :func:`pandas.merge` propagates the ``attrs`` attribute to the result if all inputs have identical ``attrs``, as has so far already been the case for :func:`pandas.concat`. diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 468f24a07cb4a..1098ceb4c3929 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1745,7 +1745,13 @@ def reconstruct_func( >>> reconstruct_func("min") (False, 'min', None, None) """ - relabeling = func is None and is_multi_agg_with_relabel(**kwargs) + from pandas.core.groupby.generic import NamedAgg + + relabeling = func is None and ( + is_multi_agg_with_relabel(**kwargs) + or any(isinstance(v, NamedAgg) for v in kwargs.values()) + ) + columns: tuple[str, ...] | None = None order: npt.NDArray[np.intp] | None = None @@ -1766,9 +1772,22 @@ def reconstruct_func( # "Callable[..., Any] | str | list[Callable[..., Any] | str] | # MutableMapping[Hashable, Callable[..., Any] | str | list[Callable[..., Any] | # str]] | None") + converted_kwargs = {} + for key, val in kwargs.items(): + if isinstance(val, NamedAgg): + aggfunc = val.aggfunc + if val.args or val.kwargs: + aggfunc = lambda x, func=aggfunc, a=val.args, kw=val.kwargs: func( + x, *a, **kw + ) + converted_kwargs[key] = (val.column, aggfunc) + else: + converted_kwargs[key] = val + func, columns, order = normalize_keyword_aggregation( # type: ignore[assignment] - kwargs + converted_kwargs ) + assert func is not None return relabeling, func, columns, order diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index d279594617235..512f82c495007 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -10,13 +10,13 @@ from collections import abc from collections.abc import Callable +import dataclasses from functools import partial from textwrap import dedent from typing import ( TYPE_CHECKING, Any, Literal, - NamedTuple, TypeAlias, TypeVar, cast, @@ -113,12 +113,11 @@ @set_module("pandas") -class NamedAgg(NamedTuple): +@dataclasses.dataclass +class NamedAgg: """ Helper for column specific aggregation with control over output column names. - Subclass of typing.NamedTuple. - Parameters ---------- column : Hashable @@ -126,6 +125,8 @@ class NamedAgg(NamedTuple): aggfunc : function or str Function to apply to the provided column. If string, the name of a built-in pandas function. + *args, **kwargs : Any + Optional positional and keyword arguments passed to ``aggfunc``. See Also -------- @@ -137,14 +138,57 @@ class NamedAgg(NamedTuple): >>> agg_a = pd.NamedAgg(column="a", aggfunc="min") >>> agg_1 = pd.NamedAgg(column=1, aggfunc=lambda x: np.mean(x)) >>> df.groupby("key").agg(result_a=agg_a, result_1=agg_1) - result_a result_1 + result_a result_1 key 1 -1 10.5 2 1 12.0 + + >>> def n_between(ser, low, high, **kwargs): + ... return ser.between(low, high, **kwargs).sum() + + >>> agg_between = pd.NamedAgg("a", n_between, 0, 1) + >>> df.groupby("key").agg(count_between=agg_between) + count_between + key + 1 1 + 2 1 + + >>> agg_between_kw = pd.NamedAgg("a", n_between, 0, 1, inclusive="both") + >>> df.groupby("key").agg(count_between_kw=agg_between_kw) + count_between_kw + key + 1 1 + 2 1 """ column: Hashable aggfunc: AggScalar + args: tuple[Any, ...] = () + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + + def __init__( + self, + column: Hashable, + aggfunc: Callable[..., Any] | str, + *args: Any, + **kwargs: Any, + ) -> None: + self.column = column + self.aggfunc = aggfunc + self.args = args + self.kwargs = kwargs + + def __getitem__(self, key: int) -> Any: + """Provide backward-compatible tuple-style access.""" + if key == 0: + return self.column + elif key == 1: + return self.aggfunc + elif key == 2: + return self.args + elif key == 3: + return self.kwargs + raise IndexError("index out of range") @set_module("pandas.api.typing") diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index c968587c469d1..2dc4911459989 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -866,6 +866,64 @@ def test_agg_namedtuple(self): expected = df.groupby("A").agg(b=("B", "sum"), c=("B", "count")) tm.assert_frame_equal(result, expected) + def n_between(self, ser, low, high, **kwargs): + return ser.between(low, high, **kwargs).sum() + + def test_namedagg_args(self): + # https://github.com/pandas-dev/pandas/issues/58283 + df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]}) + + result = df.groupby("A").agg( + count_between=pd.NamedAgg("B", self.n_between, 0, 1) + ) + expected = DataFrame({"count_between": [1, 1]}, index=Index([0, 1], name="A")) + tm.assert_frame_equal(result, expected) + + def test_namedagg_kwargs(self): + # https://github.com/pandas-dev/pandas/issues/58283 + df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]}) + + result = df.groupby("A").agg( + count_between_kw=pd.NamedAgg("B", self.n_between, 0, 1, inclusive="both") + ) + expected = DataFrame( + {"count_between_kw": [1, 1]}, index=Index([0, 1], name="A") + ) + tm.assert_frame_equal(result, expected) + + def test_namedagg_args_and_kwargs(self): + # https://github.com/pandas-dev/pandas/issues/58283 + df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]}) + + result = df.groupby("A").agg( + count_between_mix=pd.NamedAgg( + "B", self.n_between, 0, 1, inclusive="neither" + ) + ) + expected = DataFrame( + {"count_between_mix": [0, 0]}, index=Index([0, 1], name="A") + ) + tm.assert_frame_equal(result, expected) + + def test_multiple_named_agg_with_args_and_kwargs(self): + # https://github.com/pandas-dev/pandas/issues/58283 + df = DataFrame({"A": [0, 1, 2, 3], "B": [1, 2, 3, 4]}) + + result = df.groupby("A").agg( + n_between01=pd.NamedAgg("B", self.n_between, 0, 1), + n_between13=pd.NamedAgg("B", self.n_between, 1, 3), + n_between02=pd.NamedAgg("B", self.n_between, 0, 2), + ) + expected = DataFrame( + { + "n_between01": [1, 0, 0, 0], + "n_between13": [1, 1, 1, 0], + "n_between02": [1, 1, 0, 0], + }, + index=Index([0, 1, 2, 3], name="A"), + ) + tm.assert_frame_equal(result, expected) + def test_mangled(self): df = DataFrame({"A": [0, 1], "B": [1, 2], "C": [3, 4]}) result = df.groupby("A").agg(b=("B", lambda x: 0), c=("C", lambda x: 1))