Skip to content

Commit 8cbdf74

Browse files
Fix QuerysetAggregateWrapper implementation to handle aggregate queries properly
Co-Authored-By: Nishant Singh <saysnishant@gmail.com>
1 parent 5412ff2 commit 8cbdf74

File tree

2 files changed

+115
-68
lines changed

2 files changed

+115
-68
lines changed

django_querysets_single_query_fetch/service.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,14 @@ class QuerysetAggregateWrapper:
5656
"""
5757
Wrapper around queryset to indicate that we want to fetch the result of .aggregate()
5858
This is useful for executing aggregate queries in a single database query along with other querysets.
59+
60+
Since aggregates don't support lazy evaluation, we need to store the queryset and
61+
the aggregate expressions separately.
5962
"""
6063

61-
def __init__(self, queryset: QuerySet, **aggregates) -> None:
64+
def __init__(self, queryset: QuerySet, **aggregate_expressions) -> None:
6265
self.queryset = queryset
63-
self.aggregates = {}
64-
for key in aggregates:
65-
if key == 'total_price':
66-
self.aggregates[key] = Sum('selling_price')
67-
elif key == 'count':
68-
self.aggregates[key] = Count('id')
69-
elif key == 'avg_price':
70-
self.aggregates[key] = Avg('selling_price')
71-
elif key == 'max_price':
72-
self.aggregates[key] = Max('selling_price')
73-
elif key == 'min_price':
74-
self.aggregates[key] = Min('selling_price')
75-
self.aggregate_result = {}
66+
self.aggregate_expressions = aggregate_expressions
7667

7768

7869
QuerysetWrapperType = Union[
@@ -220,16 +211,23 @@ def _get_compiler_from_queryset(self, queryset: QuerysetWrapperType) -> Any:
220211

221212
return compiler
222213

223-
def _get_sanitized_sql_param(self, param: str) -> str:
214+
def _get_sanitized_sql_param(self, param) -> str:
215+
if param is None:
216+
return "NULL"
217+
if isinstance(param, (int, float)):
218+
return str(param)
219+
if isinstance(param, bool):
220+
return "TRUE" if param else "FALSE"
221+
222+
param_str = str(param)
223+
224224
try:
225225
from psycopg import sql
226-
227-
return sql.quote(param)
226+
return sql.quote(param_str)
228227
except ImportError:
229228
try:
230229
from psycopg2.extensions import QuotedString
231-
232-
return QuotedString(param).getquoted().decode("utf-8")
230+
return QuotedString(param_str).getquoted().decode("utf-8")
233231
except ImportError:
234232
raise ImportError("psycopg or psycopg2 not installed")
235233

@@ -270,7 +268,46 @@ def _get_django_sql_for_queryset(self, queryset: QuerysetWrapperType) -> str:
270268
django_sql = sql % quoted_params
271269

272270
if isinstance(queryset, QuerysetAggregateWrapper):
273-
return ""
271+
272+
compiler = self._get_compiler_from_queryset(queryset.queryset)
273+
sql, params = compiler.as_sql()
274+
275+
if isinstance(params, dict):
276+
quoted_params = {}
277+
for key, value in params.items():
278+
quoted_params[key] = self._get_sanitized_sql_param(value)
279+
base_sql = sql % quoted_params
280+
else:
281+
quoted_params = []
282+
for value in params:
283+
quoted_params.append(self._get_sanitized_sql_param(value))
284+
base_sql = sql % tuple(quoted_params)
285+
286+
aggregate_sql_parts = []
287+
for key, value in queryset.aggregate_expressions.items():
288+
if isinstance(value, Sum):
289+
field = value.source_expressions[0].name
290+
aggregate_sql_parts.append(f"'{key}', SUM(subquery.{field})")
291+
elif isinstance(value, Count):
292+
field = value.source_expressions[0].name
293+
if field == '*':
294+
aggregate_sql_parts.append(f"'{key}', COUNT(*)")
295+
else:
296+
aggregate_sql_parts.append(f"'{key}', COUNT(subquery.{field})")
297+
elif isinstance(value, Avg):
298+
field = value.source_expressions[0].name
299+
aggregate_sql_parts.append(f"'{key}', AVG(subquery.{field})")
300+
elif isinstance(value, Max):
301+
field = value.source_expressions[0].name
302+
aggregate_sql_parts.append(f"'{key}', MAX(subquery.{field})")
303+
elif isinstance(value, Min):
304+
field = value.source_expressions[0].name
305+
aggregate_sql_parts.append(f"'{key}', MIN(subquery.{field})")
306+
307+
if aggregate_sql_parts:
308+
return f"(SELECT array_to_json(array[row(json_build_object({', '.join(aggregate_sql_parts)}))]) FROM ({base_sql}) AS subquery)"
309+
else:
310+
return "(SELECT array_to_json(array[row('{}'::jsonb)]))"
274311
else:
275312
return f"(SELECT COALESCE(json_agg(item), '[]') FROM ({django_sql}) item)"
276313

@@ -398,10 +435,14 @@ def _convert_raw_results_to_final_queryset_results(
398435
if isinstance(queryset, QuerysetCountWrapper):
399436
queryset_results = queryset_raw_results[0]["__count"]
400437
elif isinstance(queryset, QuerysetAggregateWrapper):
401-
if queryset_raw_results:
402-
queryset_results = queryset_raw_results[0]
438+
if queryset_raw_results and len(queryset_raw_results) > 0:
439+
nested_result = queryset_raw_results[0].get('f1', {})
440+
queryset_results = nested_result
403441
else:
404-
queryset_results = queryset.queryset.aggregate(**queryset.aggregates)
442+
queryset_results = {key: None for key in queryset.aggregate_expressions.keys()}
443+
for key, value in queryset.aggregate_expressions.items():
444+
if isinstance(value, Count):
445+
queryset_results[key] = 0
405446
else:
406447
if isinstance(queryset, QuerysetGetOrNoneWrapper):
407448
django_queryset = queryset.queryset

testapp/tests/test_aggregate_wrapper_for_postgres.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,55 +30,57 @@ def setUp(self) -> None:
3030
def test_simple_aggregate(self):
3131
"""Test simple aggregate with Sum"""
3232
queryset = StoreProduct.objects.filter()
33-
aggregate_queryset = queryset.aggregate(total_price=Sum("selling_price"))
33+
aggregate_expressions = {'total_price': Sum("selling_price")}
34+
expected_result = queryset.aggregate(**aggregate_expressions)
3435

3536
with self.assertNumQueries(1):
3637
results = QuerysetsSingleQueryFetch(
37-
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
38+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_expressions)]
3839
).execute()
3940

4041
self.assertEqual(len(results), 1)
4142
aggregate_result = results[0]
4243

43-
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
44+
self.assertEqual(len(aggregate_result), len(expected_result))
4445
self.assertIn('total_price', aggregate_result)
4546
self.assertAlmostEqual(
4647
float(aggregate_result['total_price']),
47-
float(aggregate_queryset['total_price']),
48+
float(expected_result['total_price']),
4849
places=2
4950
)
5051

5152
def test_multiple_aggregates(self):
5253
"""Test multiple aggregates in a single query"""
5354
queryset = StoreProduct.objects.filter()
54-
aggregate_queryset = queryset.aggregate(
55-
total_price=Sum("selling_price"),
56-
count=Count("id"),
57-
avg_price=Avg("selling_price"),
58-
max_price=Max("selling_price"),
59-
min_price=Min("selling_price"),
60-
)
55+
aggregate_expressions = {
56+
'total_price': Sum("selling_price"),
57+
'count': Count("id"),
58+
'avg_price': Avg("selling_price"),
59+
'max_price': Max("selling_price"),
60+
'min_price': Min("selling_price"),
61+
}
62+
expected_result = queryset.aggregate(**aggregate_expressions)
6163

6264
with self.assertNumQueries(1):
6365
results = QuerysetsSingleQueryFetch(
64-
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
66+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_expressions)]
6567
).execute()
6668

6769
self.assertEqual(len(results), 1)
6870
aggregate_result = results[0]
6971

70-
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
72+
self.assertEqual(len(aggregate_result), len(expected_result))
7173

72-
for key in aggregate_queryset.keys():
74+
for key in expected_result.keys():
7375
self.assertIn(key, aggregate_result)
74-
if isinstance(aggregate_queryset[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
76+
if isinstance(expected_result[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
7577
self.assertAlmostEqual(
7678
float(aggregate_result[key]),
77-
float(aggregate_queryset[key]),
79+
float(expected_result[key]),
7880
places=2
7981
)
8082
else:
81-
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
83+
self.assertEqual(aggregate_result[key], expected_result[key])
8284

8385
self.assertEqual(aggregate_result['count'], 4)
8486
self.assertAlmostEqual(
@@ -90,31 +92,32 @@ def test_multiple_aggregates(self):
9092
def test_filtered_aggregate(self):
9193
"""Test aggregate with filter"""
9294
queryset = StoreProduct.objects.filter(category=self.category1)
93-
aggregate_queryset = queryset.aggregate(
94-
total_price=Sum("selling_price"),
95-
count=Count("id"),
96-
)
95+
aggregate_expressions = {
96+
'total_price': Sum("selling_price"),
97+
'count': Count("id"),
98+
}
99+
expected_result = queryset.aggregate(**aggregate_expressions)
97100

98101
with self.assertNumQueries(1):
99102
results = QuerysetsSingleQueryFetch(
100-
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
103+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_expressions)]
101104
).execute()
102105

103106
self.assertEqual(len(results), 1)
104107
aggregate_result = results[0]
105108

106-
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
109+
self.assertEqual(len(aggregate_result), len(expected_result))
107110

108-
for key in aggregate_queryset.keys():
111+
for key in expected_result.keys():
109112
self.assertIn(key, aggregate_result)
110-
if isinstance(aggregate_queryset[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
113+
if isinstance(expected_result[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
111114
self.assertAlmostEqual(
112115
float(aggregate_result[key]),
113-
float(aggregate_queryset[key]),
116+
float(expected_result[key]),
114117
places=2
115118
)
116119
else:
117-
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
120+
self.assertEqual(aggregate_result[key], expected_result[key])
118121

119122
self.assertEqual(aggregate_result['count'], 2) # Only products in category1
120123
self.assertAlmostEqual(
@@ -126,40 +129,43 @@ def test_filtered_aggregate(self):
126129
def test_empty_aggregate(self):
127130
"""Test aggregate on empty queryset"""
128131
queryset = StoreProduct.objects.filter(id=-1) # No matches
129-
aggregate_queryset = queryset.aggregate(
130-
total_price=Sum("selling_price"),
131-
count=Count("id"),
132-
)
132+
aggregate_expressions = {
133+
'total_price': Sum("selling_price"),
134+
'count': Count("id"),
135+
}
136+
expected_result = queryset.aggregate(**aggregate_expressions)
133137

134138
with self.assertNumQueries(1):
135139
results = QuerysetsSingleQueryFetch(
136-
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
140+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_expressions)]
137141
).execute()
138142

139143
self.assertEqual(len(results), 1)
140144
aggregate_result = results[0]
141145

142-
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
146+
self.assertEqual(len(aggregate_result), len(expected_result))
143147

144-
for key in aggregate_queryset.keys():
148+
for key in expected_result.keys():
145149
self.assertIn(key, aggregate_result)
146-
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
150+
self.assertEqual(aggregate_result[key], expected_result[key])
147151

148152
self.assertEqual(aggregate_result['count'], 0)
149153
self.assertIsNone(aggregate_result['total_price'])
150154

151155
def test_mix_with_other_querysets(self):
152156
"""Test mixture of aggregate wrapper and other querysets"""
153-
aggregate_queryset = StoreProduct.objects.filter().aggregate(
154-
total_price=Sum("selling_price"),
155-
count=Count("id"),
156-
)
157+
queryset = StoreProduct.objects.filter()
158+
aggregate_expressions = {
159+
'total_price': Sum("selling_price"),
160+
'count': Count("id"),
161+
}
162+
expected_result = queryset.aggregate(**aggregate_expressions)
157163
regular_queryset = StoreProductCategory.objects.filter()
158164

159165
with self.assertNumQueries(1):
160166
results = QuerysetsSingleQueryFetch(
161167
querysets=[
162-
QuerysetAggregateWrapper(queryset=StoreProduct.objects.filter(), **aggregate_queryset),
168+
QuerysetAggregateWrapper(queryset=queryset, **aggregate_expressions),
163169
regular_queryset
164170
]
165171
).execute()
@@ -168,18 +174,18 @@ def test_mix_with_other_querysets(self):
168174
aggregate_result = results[0]
169175
categories = results[1]
170176

171-
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
177+
self.assertEqual(len(aggregate_result), len(expected_result))
172178

173-
for key in aggregate_queryset.keys():
179+
for key in expected_result.keys():
174180
self.assertIn(key, aggregate_result)
175-
if isinstance(aggregate_queryset[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
181+
if isinstance(expected_result[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
176182
self.assertAlmostEqual(
177183
float(aggregate_result[key]),
178-
float(aggregate_queryset[key]),
184+
float(expected_result[key]),
179185
places=2
180186
)
181187
else:
182-
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
188+
self.assertEqual(aggregate_result[key], expected_result[key])
183189

184190
regular_categories = list(regular_queryset)
185191
self.assertEqual(len(categories), len(regular_categories))

0 commit comments

Comments
 (0)