8000 Fixed #36210 -- Allowed Subquery usage in further lookups against com… · django/django@f5fe954 · GitHub
[go: up one dir, main page]

Skip to content

Commit f5fe954

Browse files
Fixed #36210 -- Allowed Subquery usage in further lookups against composite pks.
Follow-up to 8561100. co-authored-by: Simon Charette <charette.s@gmail.com>
1 parent 926e2a2 commit f5fe954

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

django/db/backends/base/features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@ class BaseDatabaseFeatures:
376376
# Does the backend support native tuple lookups (=, >, <, IN)?
377377
supports_tuple_lookups = True
378378

379+
# Does the backend support native tuple gt(e), lt(e) comparisons in subqueries?
380+
supports_tuple_comparison_in_subquery = True
381+
379382
# Collation names for use by the Django test suite.
380383
test_collations = {
381384
"ci": None, # Case-insensitive.

django/db/backends/oracle/features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2121
can_return_columns_from_insert = True
2222
supports_subqueries_in_group_by = False
2323
ignores_unnecessary_order_by_in_subqueries = False
24+
supports_tuple_comparison_in_subquery = False
2425
supports_transactions = True
2526
supports_timezones = False
2627
has_native_duration_field = True

django/db/models/expressions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,7 @@ def __init__(self, queryset, output_field=None, **extra):
17561756
# Allow the usage of both QuerySet and sql.Query objects.
17571757
self.query = getattr(queryset, "query", queryset).clone()
17581758
self.query.subquery = True
1759+
self.template = extra.pop("template", self.template)
17591760
self.extra = extra
17601761
super().__init__(output_field)
17611762

@@ -1768,6 +1769,14 @@ def set_source_expressions(self, exprs):
17681769
def _resolve_output_field(self):
17691770
return self.query.output_field
17701771

1772+
def resolve_expression(self, *args, **kwargs):
1773+
resolved = super().resolve_expression(*args, **kwargs)
1774+
if self.template == Subquery.template:
1775+
# Subquery is an unnecessary shim for a resolved query as it
1776+
# complexifies the lookup's right-hand-side introspection.
1777+
return resolved.query
1778+
return resolved
1779+
17711780
def copy(self):
17721781
clone = super().copy()
17731782
clone.query = clone.query.clone()

django/db/models/fields/tuple_lookups.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22

33
from django.core.exceptions import EmptyResultSet
4+
from django.db import NotSupportedError
45
from django.db.models import Field
56
from django.db.models.expressions import (
67
ColPairs,
@@ -129,6 +130,19 @@ def as_sql(self, compiler, connection):
129130
return super().as_sql(compiler, connection)
130131

131132

133+
class TupleComparisonMixin:
134+
def as_sql(self, compiler, connection):
135+
if (
136+
not connection.features.supports_tuple_comparison_in_subquery
137+
and isinstance(self.rhs, Query)
138+
and self.rhs.subquery
139+
):
140+
lookup = self.lookup_name
141+
msg = f"{lookup} for composite fields is not supported on this backend"
142+
raise NotSupportedError(msg)
143+
return super().as_sql(compiler, connection)
144+
145+
132146
class TupleExact(TupleLookupMixin, Exact):
133147
def get_fallback_sql(self, compiler, connection):
134148
if isinstance(self.rhs, Query):
@@ -165,7 +179,7 @@ def as_sql(self, compiler, connection):
165179
return root.as_sql(compiler, connection)
166180

167181

168-
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
182+
class TupleGreaterThan(TupleLookupMixin, TupleComparisonMixin, GreaterThan):
169183
def get_fallback_sql(self, compiler, connection):
170184
# Process right-hand-side to trigger sanitization.
171185
self.process_rhs(compiler, connection)
@@ -193,7 +207,9 @@ def get_fallback_sql(self, compiler, connection):
193207
return root.as_sql(compiler, connection)
194208

195209

196-
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
210+
class TupleGreaterThanOrEqual(
211+
TupleLookupMixin, TupleComparisonMixin, GreaterThanOrEqual
212+
):
197213
def get_fallback_sql(self, compiler, connection):
198214
# Process right-hand-side to trigger sanitization.
199215
self.process_rhs(compiler, connection)
@@ -221,7 +237,7 @@ def get_fallback_sql(self, compiler, connection):
221237
return root.as_sql(compiler, connection)
222238

223239

224-
class TupleLessThan(TupleLookupMixin, LessThan):
240+
class TupleLessThan(TupleLookupMixin, TupleComparisonMixin, LessThan):
225241
def get_fallback_sql(self, compiler, connection):
226242
# Process right-hand-side to trigger sanitization.
227243
self.process_rhs(compiler, connection)
@@ -249,7 +265,7 @@ def get_fallback_sql(self, compiler, connection):
249265
return root.as_sql(compiler, connection)
250266

251267

252-
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
268+
class TupleLessThanOrEqual(TupleLookupMixin, TupleComparisonMixin, LessThanOrEqual):
253269
def get_fallback_sql(self, compiler, connection):
254270
# Process right-hand-side to trigger sanitization.
255271
self.process_rhs(compiler, connection)

django/db/models/sql/query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ class Query(BaseExpression):
235235

236236
filter_is_sticky = False
237237
subquery = False
238+
contains_subquery = True
238239

239240
# SQL-related attributes.
240241
# Select and related select clauses are expressions to use in the SELECT
@@ -1226,7 +1227,7 @@ def add_annotation(self, annotation, alias, select=True):
12261227

12271228
@property
12281229
def _subquery_fields_len(self):
1229-
if self.has_select_fields:
1230+
if self.has_select_fields and tuple(self.selected) != ("pk",):
12301231
return len(self.selected)
12311232
return len(self.model._meta.pk_fields)
12321233

tests/composite_pk/test_filter.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.db import NotSupportedError
12
from django.db.models import (
23
Case,
34
F,
@@ -10,7 +11,7 @@
1011
)
1112
from django.db.models.functions import Cast
1213
from django.db.models.lookups import Exact
13-
from django.test import TestCase, skipUnlessDBFeature
14+
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
1415

1516
from .models import Comment, Tenant, User
1617

@@ -476,6 +477,38 @@ def test_outer_ref_pk(self):
476477
queryset = Comment.objects.filter(**{f"id{lookup}": subquery})
477478
self.assertEqual(queryset.count(), expected_count)
478479

480+
def test_outer_ref_pk_filter_on_pk_exact(self):
481+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
482+
qs = Comment.objects.filter(pk=subquery)
483+
self.assertEqual(qs.count(), 2)
484+
485+
@skipUnlessDBFeature("supports_tuple_comparison_in_subquery")
486+
def test_outer_ref_pk_filter_on_pk_comparison(self):
487+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
488+
tests = [
489+
("gt", 0),
490+
("gte", 2),
491+
("lt", 0),
492+
("lte", 2),
493+
]
494+
for lookup, expected_count in tests:
495+
with self.subTest(f"pk__{lookup}"):
496+
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
497+
self.assertEqual(qs.count(), expected_count)
498+
499+
@skipIfDBFeature("supports_tuple_comparison_in_subquery")
500+
def test_outer_ref_pk_filter_on_pk_comparison_unsupported(self):
501+
subquery = Subquery(User.objects.filter(pk=OuterRef("pk")).values("pk")[:1])
502+
tests = ["gt", "gte", "lt", "lte"]
503+
for lookup in tests:
504+
with self.subTest(f"pk__{lookup}"):
505+
qs = Comment.objects.filter(**{f"pk__{lookup}": subquery})
506+
with self.assertRaisesMessage(
507+
NotSupportedError,
508+
f"{lookup} for composite fields is not supported on this backend",
509+
):
510+
qs.count()
511+
479512
def test_unsupported_rhs(self):
480513
pk = Exact(F("tenant_id"), 1)
481514
msg = (

0 commit comments

Comments
 (0)
0