8000 Fixed #36210 -- Allowed Subquery usage in further lookups against com… · django/django@57bbeaf · GitHub
[go: up one dir, main page]

Skip to content

Commit 57bbeaf

Browse files
Fixed #36210 -- Allowed Subquery usage in further lookups against composite pks.
Follow-up to 8561100. co-authored-by: Simon Charette <charette.s@gmail.com>
1 parent 926e2a2 commit 57bbeaf

File tree

4 files changed

+55
-5
lines changed

4 files changed

+55
-5
lines changed

django/db/models/expressions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,7 @@ def __init__(self, queryset, output_field=None, **extra):
17561756
# Allow the usage of both QuerySet and sql.Query objects.
17571757
self.query = getattr(queryset, "query", queryset).clone()
17581758
self.query.subquery = True
1759+
self.template = extra.pop("template", self.template)
17591760
self.extra = extra
17601761
super().__init__(output_field)
17611762

@@ -1768,6 +1769,14 @@ def set_source_expressions(self, exprs):
17681769
def _resolve_output_field(self):
17691770
return self.query.output_field
17701771

1772+
def resolve_expression(self, *args, **kwargs):
1773+
resolved = super().resolve_expression(*args, **kwargs)
1774+
if self.template == Subquery.template:
1775+
# Subquery is an unnecessary shim for a resolved query as it
1776+
# complexifies the lookup's right-hand-side introspection.
1777+
return resolved.query
1778+
return resolved
1779+
17711780
def copy(self):
17721781
clone = super().copy()
17731782
clone.query = clone.query.clone()

django/db/models/fields/tuple_lookups.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22

33
from django.core.exceptions import EmptyResultSet
4+
from django.db import NotSupportedError
45
from django.db.models import Field
56
from django.db.models.expressions import (
67
ColPairs,
@@ -129,6 +130,20 @@ def as_sql(self, compiler, connection):
129130
return super().as_sql(compiler, connection)
130131

131132

133+
class TupleComparisonMixin:
134+
def as_oracle(self, compiler, connection):
135+
"""
136+
Regardless of whether a subquery has been limited to return
137+
a single row, this error is raised for <, <=, >, and >=:
138+
ORA-01796: this operator cannot be used with lists
139+
"""
140+
if isinstance(self.rhs, Query) and self.rhs.subquery:
141+
lookup = self.lookup_name
142+
msg = f"{lookup} for composite fields is not supported on this backend"
143+
raise NotSupportedError(msg)
144+
return super().as_oracle(compiler, connection)
145+
146+
132147
class TupleExact(TupleLookupMixin, Exact):
133148
def get_fallback_sql(self, compiler, connection):
134149
if isinstance(self.rhs, Query):
@@ -165,7 +180,7 @@ def as_sql(self, compiler, connection):
165180
return root.as_sql(compiler, connection)
166181

167182

168-
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
183+
class TupleGreaterThan(TupleLookupMixin, TupleComparisonMixin, GreaterThan):
169184
def get_fallback_sql(self, compiler, connection):
170185
# Process right-hand-side to trigger sanitization.
171186
self.process_rhs(compiler, connection)
@@ -193,7 +208,9 @@ def get_fallback_sql(self, compiler, connection):
193208
return root.as_sql(compiler, connection)
194209

195210

196-
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
211+
class TupleGreaterThanOrEqual(
212+
TupleLookupMixin, TupleComparisonMixin, GreaterThanOrEqual
213+
):
197214
def get_fallback_sql(self, compiler, connection):
198215
# Process right-hand-side to trigger sanitization.
199216
self.process_rhs(compiler, connection)
@@ -221,7 +238,7 @@ def get_fallback_sql(self, compiler, connection):
221238
return root.as_sql(compiler, connection)
222239

223240

224-
class TupleLessThan(TupleLookupMixin, LessThan):
241+
class TupleLessThan(TupleLookupMixin, TupleComparisonMixin, LessThan):
225242
def get_fallback_sql(self, compiler, connection):
226243
# Process right-hand-side to trigger sanitization.
227244
self.process_rhs(compiler, connection)
@@ -249,7 +266,7 @@ def get_fallback_sql(self, compiler, connection):
249266
return root.as_sql(compiler, connection)
250267

251268

252-
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
269+
class TupleLessThanOrEqual(TupleLookupMixin, TupleComparisonMixin, LessThanOrEqual):
253270
def get_fallback_sql(self, compiler, connection):
254271
# Process right-hand-side to trigger sanitization.
255272
self.process_rhs(compiler, connection)

django/db/models/sql/query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ class Query(BaseExpression):
235235

236236
filter_is_sticky = False
237237
subquery = False
238+
contains_subquery = True
238239

239240
# SQL-related attributes.
240241
# Select and related select clauses are expressions to use in the SELECT
@@ -1226,7 +1227,7 @@ def add_annotation(self, annotation, alias, select=True):
12261227

12271228
@property
12281229
def _subquery_fields_len(self):
1229-
if self.has_select_fields:
1230+
if self.has_select_fields and tuple(self.selected) != ("pk",):
12301231
return len(self.selected)
12311232
return len(self.model._meta.pk_fields)
12321233

tests/composite_pk/test_filter.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.db import NotSupportedError, connection
12
from django.db.models import (
23
Case,
34
F,
@@ -476,6 +477,28 @@ def test_outer_ref_pk(self):
476477
queryset = Comment.objects.filter(**{f"id{lookup}": subquery})
477478
self.assertEqual(queryset.count(), expected_count)
478479

480+
def test_outer_ref_pk_filter_on_pk(self):
481+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
482+
tests = [
483+
("exact", 2),
484+
("gt", 0),
485+
("gte", 2),
486+
("lt", 0),
487+
("lte", 2),
488+
]
489+
for lookup, expected_count in tests:
490+
with self.subTest(f"pk__{lookup}"):
491+
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
492+
if lookup != "exact" and connection.vendor == "oracle":
493+
with self.assertRaisesMessage(
494+
NotSupportedError,
495+
f"{lookup} "
496+
"for composite fields is not supported on this backend",
497+
):
498+
qs.count()
499+
else:
500+
self.assertEqual(qs.count(), expected_count)
501+
479502
def test_unsupported_rhs(self):
480503
pk = Exact(F("tenant_id"), 1)
481504
msg = (

0 commit comments

Comments
 (0)
0