Skip to content

Commit af63a1c

Browse files
ENH: pd.NamedAgg forwards *args and **kwargs to aggfunc (#62729)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 607e489 commit af63a1c

File tree

4 files changed

+130
-7
lines changed

4 files changed

+130
-7
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ All warnings for upcoming changes in pandas will have the base class :class:`pan
156156

157157
Other enhancements
158158
^^^^^^^^^^^^^^^^^^
159+
- :class:`pandas.NamedAgg` now supports passing ``*args`` and ``**kwargs``
160+
to calls of ``aggfunc`` (:issue:`58283`)
159161
- :func:`pandas.merge` propagates the ``attrs`` attribute to the result if all
160162
inputs have identical ``attrs``, as has so far already been the case for
161163
:func:`pandas.concat`.

pandas/core/apply.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,7 +1745,13 @@ def reconstruct_func(
17451745
>>> reconstruct_func("min")
17461746
(False, 'min', None, None)
17471747
"""
1748-
relabeling = func is None and is_multi_agg_with_relabel(**kwargs)
1748+
from pandas.core.groupby.generic import NamedAgg
1749+
1750+
relabeling = func is None and (
1751+
is_multi_agg_with_relabel(**kwargs)
1752+
or any(isinstance(v, NamedAgg) for v in kwargs.values())
1753+
)
1754+
17491755
columns: tuple[str, ...] | None = None
17501756
order: npt.NDArray[np.intp] | None = None
17511757

@@ -1766,9 +1772,22 @@ def reconstruct_func(
17661772
# "Callable[..., Any] | str | list[Callable[..., Any] | str] |
17671773
# MutableMapping[Hashable, Callable[..., Any] | str | list[Callable[..., Any] |
17681774
# str]] | None")
1775+
converted_kwargs = {}
1776+
for key, val in kwargs.items():
1777+
if isinstance(val, NamedAgg):
1778+
aggfunc = val.aggfunc
1779+
if val.args or val.kwargs:
1780+
aggfunc = lambda x, func=aggfunc, a=val.args, kw=val.kwargs: func(
1781+
x, *a, **kw
1782+
)
1783+
converted_kwargs[key] = (val.column, aggfunc)
1784+
else:
1785+
converted_kwargs[key] = val
1786+
17691787
func, columns, order = normalize_keyword_aggregation( # type: ignore[assignment]
1770-
kwargs
1788+
converted_kwargs
17711789
)
1790+
17721791
assert func is not None
17731792

17741793
return relabeling, func, columns, order

pandas/core/groupby/generic.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
from collections import abc
1212
from collections.abc import Callable
13+
import dataclasses
1314
from functools import partial
1415
from textwrap import dedent
1516
from typing import (
1617
TYPE_CHECKING,
1718
Any,
1819
Literal,
19-
NamedTuple,
2020
TypeAlias,
2121
TypeVar,
2222
cast,
@@ -113,19 +113,20 @@
113113

114114

115115
@set_module("pandas")
116-
class NamedAgg(NamedTuple):
116+
@dataclasses.dataclass
117+
class NamedAgg:
117118
"""
118119
Helper for column specific aggregation with control over output column names.
119120
120-
Subclass of typing.NamedTuple.
121-
122121
Parameters
123122
----------
124123
column : Hashable
125124
Column label in the DataFrame to apply aggfunc.
126125
aggfunc : function or str
127126
Function to apply to the provided column. If string, the name of a built-in
128127
pandas function.
128+
*args, **kwargs : Any
129+
Optional positional and keyword arguments passed to ``aggfunc``.
129130
130131
See Also
131132
--------
@@ -137,14 +138,57 @@ class NamedAgg(NamedTuple):
137138
>>> agg_a = pd.NamedAgg(column="a", aggfunc="min")
138139
>>> agg_1 = pd.NamedAgg(column=1, aggfunc=lambda x: np.mean(x))
139140
>>> df.groupby("key").agg(result_a=agg_a, result_1=agg_1)
140-
result_a result_1
141+
result_a result_1
141142
key
142143
1 -1 10.5
143144
2 1 12.0
145+
146+
>>> def n_between(ser, low, high, **kwargs):
147+
... return ser.between(low, high, **kwargs).sum()
148+
149+
>>> agg_between = pd.NamedAgg("a", n_between, 0, 1)
150+
>>> df.groupby("key").agg(count_between=agg_between)
151+
count_between
152+
key
153+
1 1
154+
2 1
155+
156+
>>> agg_between_kw = pd.NamedAgg("a", n_between, 0, 1, inclusive="both")
157+
>>> df.groupby("key").agg(count_between_kw=agg_between_kw)
158+
count_between_kw
159+
key
160+
1 1
161+
2 1
144162
"""
145163

146164
column: Hashable
147165
aggfunc: AggScalar
166+
args: tuple[Any, ...] = ()
167+
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
168+
169+
def __init__(
170+
self,
171+
column: Hashable,
172+
aggfunc: Callable[..., Any] | str,
173+
*args: Any,
174+
**kwargs: Any,
175+
) -> None:
176+
self.column = column
177+
self.aggfunc = aggfunc
178+
self.args = args
179+
self.kwargs = kwargs
180+
181+
def __getitem__(self, key: int) -> Any:
182+
"""Provide backward-compatible tuple-style access."""
183+
if key == 0:
184+
return self.column
185+
elif key == 1:
186+
return self.aggfunc
187+
elif key == 2:
188+
return self.args
189+
elif key == 3:
190+
return self.kwargs
191+
raise IndexError("index out of range")
148192

149193

150194
@set_module("pandas.api.typing")

pandas/tests/groupby/aggregate/test_aggregate.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,64 @@ def test_agg_namedtuple(self):
866866
expected = df.groupby("A").agg(b=("B", "sum"), c=("B", "count"))
867867
tm.assert_frame_equal(result, expected)
868868

869+
def n_between(self, ser, low, high, **kwargs):
870+
return ser.between(low, high, **kwargs).sum()
871+
872+
def test_namedagg_args(self):
873+
# https://github.com/pandas-dev/pandas/issues/58283
874+
df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]})
875+
876+
result = df.groupby("A").agg(
877+
count_between=pd.NamedAgg("B", self.n_between, 0, 1)
878+
)
879+
expected = DataFrame({"count_between": [1, 1]}, index=Index([0, 1], name="A"))
880+
tm.assert_frame_equal(result, expected)
881+
882+
def test_namedagg_kwargs(self):
883+
# https://github.com/pandas-dev/pandas/issues/58283
884+
df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]})
885+
886+
result = df.groupby("A").agg(
887+
count_between_kw=pd.NamedAgg("B", self.n_between, 0, 1, inclusive="both")
888+
)
889+
expected = DataFrame(
890+
{"count_between_kw": [1, 1]}, index=Index([0, 1], name="A")
891+
)
892+
tm.assert_frame_equal(result, expected)
893+
894+
def test_namedagg_args_and_kwargs(self):
895+
# https://github.com/pandas-dev/pandas/issues/58283
896+
df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]})
897+
898+
result = df.groupby("A").agg(
899+
count_between_mix=pd.NamedAgg(
900+
"B", self.n_between, 0, 1, inclusive="neither"
901+
)
902+
)
903+
expected = DataFrame(
904+
{"count_between_mix": [0, 0]}, index=Index([0, 1], name="A")
905+
)
906+
tm.assert_frame_equal(result, expected)
907+
908+
def test_multiple_named_agg_with_args_and_kwargs(self):
909+
# https://github.com/pandas-dev/pandas/issues/58283
910+
df = DataFrame({"A": [0, 1, 2, 3], "B": [1, 2, 3, 4]})
911+
912+
result = df.groupby("A").agg(
913+
n_between01=pd.NamedAgg("B", self.n_between, 0, 1),
914+
n_between13=pd.NamedAgg("B", self.n_between, 1, 3),
915+
n_between02=pd.NamedAgg("B", self.n_between, 0, 2),
916+
)
917+
expected = DataFrame(
918+
{
919+
"n_between01": [1, 0, 0, 0],
920+
"n_between13": [1, 1, 1, 0],
921+
"n_between02": [1, 1, 0, 0],
922+
},
923+
index=Index([0, 1, 2, 3], name="A"),
924+
)
925+
tm.assert_frame_equal(result, expected)
926+
869927
def test_mangled(self):
870928
df = DataFrame({"A": [0, 1], "B": [1, 2], "C": [3, 4]})
871929
result = df.groupby("A").agg(b=("B", lambda x: 0), c=("C", lambda x: 1))

0 commit comments

Comments
 (0)