1"""
2Useful auxiliary data structures for query construction. Not useful outside
3the SQL domain.
4"""
5
6from django.core.exceptions import FullResultSet
7from django.db.models.sql.constants import INNER, LOUTER
8
9
10class MultiJoin(Exception):
11 """
12 Used by join construction code to indicate the point at which a
13 multi-valued join was attempted (if the caller wants to treat that
14 exceptionally).
15 """
16
17 def __init__(self, names_pos, path_with_names):
18 self.level = names_pos
19 # The path travelled, this includes the path to the multijoin.
20 self.names_with_path = path_with_names
21
22
23class Empty:
24 pass
25
26
27class Join:
28 """
29 Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
30 FROM entry. For example, the SQL generated could be
31 LEFT OUTER JOIN "sometable" T1
32 ON ("othertable"."sometable_id" = "sometable"."id")
33
34 This class is primarily used in Query.alias_map. All entries in alias_map
35 must be Join compatible by providing the following attributes and methods:
36 - table_name (string)
37 - table_alias (possible alias for the table, can be None)
38 - join_type (can be None for those entries that aren't joined from
39 anything)
40 - parent_alias (which table is this join's parent, can be None similarly
41 to join_type)
42 - as_sql()
43 - relabeled_clone()
44 """
45
46 def __init__(
47 self,
48 table_name,
49 parent_alias,
50 table_alias,
51 join_type,
52 join_field,
53 nullable,
54 filtered_relation=None,
55 ):
56 # Join table
57 self.table_name = table_name
58 self.parent_alias = parent_alias
59 # Note: table_alias is not necessarily known at instantiation time.
60 self.table_alias = table_alias
61 # LOUTER or INNER
62 self.join_type = join_type
63 # A list of 2-tuples to use in the ON clause of the JOIN.
64 # Each 2-tuple will create one join condition in the ON clause.
65 self.join_fields = join_field.get_joining_fields()
66 self.join_cols = tuple(
67 (lhs_field.column, rhs_field.column)
68 for lhs_field, rhs_field in self.join_fields
69 )
70 # Along which field (or ForeignObjectRel in the reverse join case)
71 self.join_field = join_field
72 # Is this join nullabled?
73 self.nullable = nullable
74 self.filtered_relation = filtered_relation
75
76 def as_sql(self, compiler, connection):
77 """
78 Generate the full
79 LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
80 clause for this join.
81 """
82 join_conditions = []
83 params = []
84 qn = compiler.quote_name_unless_alias
85 # Add a join condition for each pair of joining columns.
86 for lhs, rhs in self.join_fields:
87 lhs, rhs = connection.ops.prepare_join_on_clause(
88 self.parent_alias, lhs, self.table_alias, rhs
89 )
90 lhs_sql, lhs_params = compiler.compile(lhs)
91 lhs_full_name = lhs_sql % lhs_params
92 rhs_sql, rhs_params = compiler.compile(rhs)
93 rhs_full_name = rhs_sql % rhs_params
94 join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
95
96 # Add a single condition inside parentheses for whatever
97 # get_extra_restriction() returns.
98 extra_cond = self.join_field.get_extra_restriction(
99 self.table_alias, self.parent_alias
100 )
101 if extra_cond:
102 extra_sql, extra_params = compiler.compile(extra_cond)
103 join_conditions.append("(%s)" % extra_sql)
104 params.extend(extra_params)
105 if self.filtered_relation:
106 try:
107 extra_sql, extra_params = compiler.compile(self.filtered_relation)
108 except FullResultSet:
109 pass
110 else:
111 join_conditions.append("(%s)" % extra_sql)
112 params.extend(extra_params)
113 if not join_conditions:
114 # This might be a rel on the other end of an actual declared field.
115 declared_field = getattr(self.join_field, "field", self.join_field)
116 raise ValueError(
117 "Join generated an empty ON clause. %s did not yield either "
118 "joining columns or extra restrictions." % declared_field.__class__
119 )
120 on_clause_sql = " AND ".join(join_conditions)
121 alias_str = (
122 "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
123 )
124 sql = "%s %s%s ON (%s)" % (
125 self.join_type,
126 qn(self.table_name),
127 alias_str,
128 on_clause_sql,
129 )
130 return sql, params
131
132 def relabeled_clone(self, change_map):
133 new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
134 new_table_alias = change_map.get(self.table_alias, self.table_alias)
135 if self.filtered_relation is not None:
136 filtered_relation = self.filtered_relation.relabeled_clone(change_map)
137 else:
138 filtered_relation = None
139 return self.__class__(
140 self.table_name,
141 new_parent_alias,
142 new_table_alias,
143 self.join_type,
144 self.join_field,
145 self.nullable,
146 filtered_relation=filtered_relation,
147 )
148
149 @property
150 def identity(self):
151 return (
152 self.__class__,
153 self.table_name,
154 self.parent_alias,
155 self.join_field,
156 self.filtered_relation,
157 )
158
159 def __eq__(self, other):
160 if not isinstance(other, Join):
161 return NotImplemented
162 return self.identity == other.identity
163
164 def __hash__(self):
165 return hash(self.identity)
166
167 def demote(self):
168 new = self.relabeled_clone({})
169 new.join_type = INNER
170 return new
171
172 def promote(self):
173 new = self.relabeled_clone({})
174 new.join_type = LOUTER
175 return new
176
177
178class BaseTable:
179 """
180 The BaseTable class is used for base table references in FROM clause. For
181 example, the SQL "foo" in
182 SELECT * FROM "foo" WHERE somecond
183 could be generated by this class.
184 """
185
186 join_type = None
187 parent_alias = None
188 filtered_relation = None
189
190 def __init__(self, table_name, alias):
191 self.table_name = table_name
192 self.table_alias = alias
193
194 def as_sql(self, compiler, connection):
195 alias_str = (
196 "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
197 )
198 base_sql = compiler.quote_name_unless_alias(self.table_name)
199 return base_sql + alias_str, []
200
201 def relabeled_clone(self, change_map):
202 return self.__class__(
203 self.table_name, change_map.get(self.table_alias, self.table_alias)
204 )
205
206 @property
207 def identity(self):
208 return self.__class__, self.table_name, self.table_alias
209
210 def __eq__(self, other):
211 if not isinstance(other, BaseTable):
212 return NotImplemented
213 return self.identity == other.identity
214
215 def __hash__(self):
216 return hash(self.identity)