1from django.db.models.expressions import ColPairs
2from django.db.models.fields import composite
3from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
4from django.db.models.lookups import (
5 Exact,
6 GreaterThan,
7 GreaterThanOrEqual,
8 In,
9 IsNull,
10 LessThan,
11 LessThanOrEqual,
12)
13
14
def get_normalized_value(value, lhs):
    """
    Return the right-hand side of a related lookup as a tuple of raw values.

    A saved model instance is turned into the tuple of attribute values for
    the lookup's target fields (taken from
    ``lhs.output_field.path_infos[-1].target_fields``). Each target field's
    ``remote_field`` chain is followed until a field declared on the
    instance's model is reached. If reading that attribute raises
    ``AttributeError``, the instance's own ``pk`` is returned instead, as a
    tuple.

    Any other value is returned unchanged when it is already a tuple, and is
    wrapped in a 1-tuple otherwise.

    Raises ValueError when ``value`` is a model instance whose primary key
    is not set.
    """
    from django.db.models import Model

    if not isinstance(value, Model):
        # Plain values: only guarantee tuple shape.
        return value if isinstance(value, tuple) else (value,)

    if not value._is_pk_set():
        raise ValueError("Model instances passed to related filters must be saved.")

    collected = []
    target_fields = lhs.output_field.path_infos[-1].target_fields
    for field in composite.unnest(target_fields):
        # Hop across relations until the field belongs to the instance's
        # model (or there is no further remote field to follow).
        while not isinstance(value, field.model) and field.remote_field:
            remote = field.remote_field
            field = remote.model._meta.get_field(remote.field_name)
        try:
            attr_value = getattr(value, field.attname)
        except AttributeError:
            # A case like Restaurant.objects.filter(place=restaurant_instance),
            # where place is a OneToOneField and the primary key of Restaurant.
            pk = value.pk
            return pk if isinstance(pk, tuple) else (pk,)
        collected.append(attr_value)
    return tuple(collected)
39
40
class RelatedIn(In):
    """
    ``In`` lookup for relation fields.

    Single-column left-hand sides are handled in ``get_prep_lookup()``:

    - Direct values, including saved model instances, are normalized with
      ``get_normalized_value()``. When the output field has ``path_infos``,
      they are then passed through the target field's ``get_prep_value()``.
    - A subquery right-hand side that reports no select fields, compared
      against a related field whose target is not a primary key, has its
      SELECT restricted via ``set_values()``.

    Multi-column (``ColPairs``) left-hand sides are compiled in ``as_sql()``.
    Direct values become a ``TupleIn``; subqueries become a
    ``SubqueryConstraint``.
    """

    def get_prep_lookup(self):
        if isinstance(self.lhs, ColPairs):
            # Multi-column relations are dealt with in as_sql().
            return super().get_prep_lookup()

        if self.rhs_is_direct_value():
            # If we get here, we are dealing with single-column relations.
            lhs = self.lhs
            self.rhs = [get_normalized_value(item, lhs)[0] for item in self.rhs]
            # We need to run the related field's get_prep_value(). Consider
            # case ForeignKey to IntegerField given value 'abc'. The
            # ForeignKey itself doesn't have validation for non-integers,
            # so we must run validation using the target field.
            output_field = lhs.output_field
            if hasattr(output_field, "path_infos"):
                # Only one target field is expected on this single-column
                # path, so the last one is the target.
                target = output_field.path_infos[-1].target_fields[-1]
                self.rhs = [target.get_prep_value(item) for item in self.rhs]
        elif not (
            getattr(self.rhs, "has_select_fields", True)
            or getattr(self.lhs.field.target_field, "primary_key", False)
        ):
            related_field = self.lhs.field
            output_field = self.lhs.output_field
            if (
                getattr(output_field, "primary_key", False)
                and output_field.model == self.rhs.model
            ):
                # A case like
                # Restaurant.objects.filter(place__in=restaurant_qs), where
                # place is a OneToOneField and the primary key of
                # Restaurant.
                select_name = related_field.name
            else:
                select_name = related_field.target_field.name
            self.rhs.set_values([select_name])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if not isinstance(self.lhs, ColPairs):
            return super().as_sql(compiler, connection)

        # Imported here rather than at module level (kept from the original).
        from django.db.models.sql.where import SubqueryConstraint

        lhs = self.lhs
        if self.rhs_is_direct_value():
            normalized = [get_normalized_value(item, lhs) for item in self.rhs]
            return compiler.compile(TupleIn(lhs, normalized))

        constraint = SubqueryConstraint(
            lhs.alias,
            [target.column for target in lhs.targets],
            [source.name for source in lhs.sources],
            self.rhs,
        )
        return compiler.compile(constraint)
95
96
class RelatedLookupMixin:
    """
    Mixin adapting a comparison lookup to relation fields.

    For a single-column left-hand side with a right-hand side that has no
    ``resolve_expression`` attribute, ``get_prep_lookup()`` normalizes the
    value with ``get_normalized_value()``. When ``prepare_rhs`` is set and
    the output field has ``path_infos``, it also runs the target field's
    ``get_prep_value()``.

    For a multi-column (``ColPairs``) left-hand side, ``as_sql()`` delegates
    to the matching entry in ``tuple_lookups``. A non-direct right-hand
    side (such as a subquery) raises ValueError.
    """

    def get_prep_lookup(self):
        if isinstance(self.lhs, ColPairs) or hasattr(self.rhs, "resolve_expression"):
            return super().get_prep_lookup()

        # If we get here, we are dealing with single-column relations.
        self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
        # We need to run the related field's get_prep_value(). Consider case
        # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
        # doesn't have validation for non-integers, so we must run validation
        # using the target field.
        output_field = self.lhs.output_field
        if self.prepare_rhs and hasattr(output_field, "path_infos"):
            # Only one target field is expected on this single-column path.
            target = output_field.path_infos[-1].target_fields[-1]
            self.rhs = target.get_prep_value(self.rhs)
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if not isinstance(self.lhs, ColPairs):
            return super().as_sql(compiler, connection)

        if not self.rhs_is_direct_value():
            raise ValueError(
                f"'{self.lookup_name}' doesn't support multi-column subqueries."
            )
        self.rhs = get_normalized_value(self.rhs, self.lhs)
        tuple_lookup_class = tuple_lookups[self.lookup_name]
        return compiler.compile(tuple_lookup_class(self.lhs, self.rhs))
128
129
class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` combined with ``RelatedLookupMixin`` for relation fields."""
132
133
class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` combined with ``RelatedLookupMixin`` for relation fields."""
136
137
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` combined with ``RelatedLookupMixin`` for relation fields."""
140
141
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` combined with ``RelatedLookupMixin`` for relation fields."""
144
145
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` combined with ``RelatedLookupMixin`` for relation fields."""
148
149
class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` combined with ``RelatedLookupMixin`` for relation fields."""