8000 Fixed #24887 -- Removed one-arg limit from models.aggregate · gtossou/django@4a66a69 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4a66a69

Browse files
gchptimgraham
authored andcommitted
Fixed #24887 -- Removed one-arg limit from models.aggregate
1 parent 6c592e7 commit 4a66a69

File tree

4 files changed

+49
-21
lines changed

4 files changed

+49
-21
lines changed

django/contrib/gis/db/models/aggregates.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ def as_oracle(self, compiler, connection):
2525

2626
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
2727
c = super(GeoAggregate, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
28-
if not hasattr(c.input_field.field, 'geom_type'):
29-
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
28+
for expr in c.get_source_expressions():
29+
if not hasattr(expr.field, 'geom_type&# 10000 39;):
30+
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
3031
return c
3132

3233
def convert_value(self, value, expression, connection, context):

django/db/backends/sqlite3/operations.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,19 @@ def check_expression_support(self, expression):
3535
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
3636
bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
3737
if isinstance(expression, bad_aggregates):
38-
try:
39-
output_field = expression.input_field.output_field
40-
if isinstance(output_field, bad_fields):
41-
raise NotImplementedError(
42-
'You cannot use Sum, Avg, StdDev and Variance aggregations '
43-
'on date/time fields in sqlite3 '
44-
'since date/time is saved as text.')
45-
except FieldError:
46-
# not every sub-expression has an output_field which is fine to
47-
# ignore
48-
pass
38+
for expr in expression.get_source_expressions():
39+
try:
40+
output_field = expr.output_field
41+
if isinstance(output_field, bad_fields):
42+
raise NotImplementedError(
43+
'You cannot use Sum, Avg, StdDev, and Variance '
44+
'aggregations on date/time fields in sqlite3 '
45+
'since date/time is saved as text.'
46+
)
47+
except FieldError:
48+
# Not every subexpression has an output_field which is fine
49+
# to ignore.
50+
pass
4951

5052
def date_extract_sql(self, lookup_type, field_name):
5153
# sqlite doesn't support extract, so we fake it with the user-defined

django/db/models/aggregates.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ class Aggregate(Func):
1515
name = None
1616

1717
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
18-
assert len(self.source_expressions) == 1
1918
# Aggregates are not allowed in UPDATE queries, so ignore for_save
2019
c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
21-
if c.source_expressions[0].contains_aggregate and not summarize:
22-
name = self.source_expressions[0].name
23-
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
24-
c.name, name, name))
20+
if not summarize:
21+
expressions = c.get_source_expressions()
22+
for index, expr in enumerate(expressions):
23+
if expr.contains_aggregate:
24+
before_resolved = self.get_source_expressions()[index]
25+
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
26+
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
2527
c._patch_aggregate(query) # backward-compatibility support
2628
return c
2729

@@ -31,8 +33,9 @@ def input_field(self):
3133

3234
@property
3335
def default_alias(self):
34-
if hasattr(self.source_expressions[0], 'name'):
35-
return '%s__%s' % (self.source_expressions[0].name, self.name.lower())
36+
expressions = self.get_source_expressions()
37+
if len(expressions) == 1 and hasattr(expressions[0], 'name'):
38+
return '%s__%s' % (expressions[0].name, self.name.lower())
3639
raise TypeError("Complex expressions require an alias")
3740

3841
def get_group_by_cols(self):

tests/aggregation/tests.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,9 +985,31 @@ def test_annotate_over_annotate(self):
985985
self.assertEqual(author.sum_age, other_author.sum_age)
986986

987987
def test_annotated_aggregate_over_annotated_aggregate(self):
988-
with six.assertRaisesRegex(self, FieldError, "Cannot compute Sum\('id__max'\): 'id__max' is an aggregate"):
988+
with self.assertRaisesMessage(FieldError, "Cannot compute Sum('id__max'): 'id__max' is an aggregate"):
989989
Book.objects.annotate(Max('id')).annotate(Sum('id__max'))
990990

991+
class MyMax(Max):
992+
def as_sql(self, compiler, connection):
993+
self.set_source_expressions(self.get_source_expressions()[0:1])
994+
return super(MyMax, self).as_sql(compiler, connection)
995+
996+
with self.assertRaisesMessage(FieldError, "Cannot compute Max('id__max'): 'id__max' is an aggregate"):
997+
Book.objects.annotate(Max('id')).annotate(my_max=MyMax('id__max', 'price'))
998+
999+
def test_multi_arg_aggregate(self):
1000+
class MyMax(Max):
1001+
def as_sql(self, compiler, connection):
1002+
self.set_source_expressions(self.get_source_expressions()[0:1])
1003+
return super(MyMax, self).as_sql(compiler, connection)
1004+
1005+
with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'):
1006+
Book.objects.aggregate(MyMax('pages', 'price'))
1007+
1008+
with self.assertRaisesMessage(TypeError, 'Complex annotations require an alias'):
1009+
Book.objects.annotate(MyMax('pages', 'price'))
1010+
1011+
Book.objects.aggregate(max_field=MyMax('pages', 'price'))
1012+
9911013
def test_add_implementation(self):
9921014
class MySum(Sum):
9931015
pass

0 commit comments

Comments
 (0)
0