1from django.db import NotSupportedError
2from django.db.models.expressions import Func, Value
3from django.db.models.fields import CharField, IntegerField, TextField
4from django.db.models.functions import Cast, Coalesce
5from django.db.models.lookups import Transform
6
7
class MySQLSHA2Mixin:
    """Mixin adding a MySQL rendering that goes through ``SHA2()``."""

    def as_mysql(self, compiler, connection, **extra_context):
        """Render with a ``SHA2(<expressions>, <suffix>)`` template."""
        # Whatever follows the first three characters of ``function`` becomes
        # the second SHA2() argument (e.g. "SHA256" -> 256). The doubled "%%"
        # keeps the ``%(expressions)s`` placeholder intact for as_sql().
        hash_suffix = self.function[3:]
        mysql_template = "SHA2(%%(expressions)s, %s)" % hash_suffix
        return super().as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )
16
17
class OracleHashMixin:
    """Mixin adding an Oracle rendering built on ``STANDARD_HASH``."""

    def as_oracle(self, compiler, connection, **extra_context):
        """Render using the Oracle hashing template below."""
        # The template converts the expression with UTL_I18N.STRING_TO_RAW
        # ('AL32UTF8'), hashes it with STANDARD_HASH, and wraps the result in
        # LOWER(RAWTOHEX(...)). ``%(function)s`` is left for as_sql() to fill.
        oracle_template = (
            "LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
            "%(expressions)s, 'AL32UTF8'), '%(function)s')))"
        )
        return super().as_sql(
            compiler, connection, template=oracle_template, **extra_context
        )
29
30
class PostgreSQLSHAMixin:
    """Mixin adding a PostgreSQL rendering built on ``DIGEST``/``ENCODE``."""

    def as_postgresql(self, compiler, connection, **extra_context):
        """Render ``ENCODE(DIGEST(<expressions>, '<function>'), 'hex')``."""
        # The lowercased ``function`` attribute is passed as the value for the
        # template's ``%(function)s`` placeholder.
        digest_name = self.function.lower()
        postgres_template = "ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')"
        return super().as_sql(
            compiler,
            connection,
            template=postgres_template,
            function=digest_name,
            **extra_context,
        )
40
41
class Chr(Transform):
    """
    Transform with lookup name ``chr`` whose default SQL function is ``CHR``
    and whose output field is a ``CharField``.
    """

    lookup_name = "chr"
    function = "CHR"
    output_field = CharField()

    def as_mysql(self, compiler, connection, **extra_context):
        """Render as ``CHAR(<expressions> USING utf16)`` for MySQL."""
        mysql_template = "%(function)s(%(expressions)s USING utf16)"
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template=mysql_template,
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        """Render with a ``USING NCHAR_CS`` clause for Oracle."""
        oracle_template = "%(function)s(%(expressions)s USING NCHAR_CS)"
        return super().as_sql(
            compiler, connection, template=oracle_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with the ``CHAR`` function name for SQLite."""
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            **extra_context,
        )
66
67
class ConcatPair(Func):
    """
    Join exactly two arguments into one string. `Concat` builds on this class
    because not every database backend supports more than two arguments.
    """

    function = "CONCAT"

    def pipes_concat_sql(self, compiler, connection, **extra_context):
        """Render the pair joined by ``||`` inside parentheses."""
        # Render a coalesced copy rather than ``self``; the explicit two-arg
        # super() call skips this class's own backend-specific methods.
        null_safe = self.coalesce()
        return super(ConcatPair, null_safe).as_sql(
            compiler,
            connection,
            template="(%(expressions)s)",
            arg_joiner=" || ",
            **extra_context,
        )

    as_sqlite = pipes_concat_sql

    def as_postgresql(self, compiler, connection, **extra_context):
        """Cast non-text arguments to ``TextField`` and join with ``||``."""
        text_field_types = (CharField, TextField)
        clone = self.copy()
        text_sources = []
        for source in clone.get_source_expressions():
            if not isinstance(source.output_field, text_field_types):
                source = Cast(source, TextField())
            text_sources.append(source)
        clone.set_source_expressions(text_sources)
        return clone.pipes_concat_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        """Render as ``CONCAT_WS('', <expressions>)`` for MySQL."""
        # CONCAT_WS with an empty separator is used so that NULLs are ignored.
        mysql_template = "%(function)s('', %(expressions)s)"
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template=mysql_template,
            **extra_context,
        )

    def coalesce(self):
        """Return a copy whose arguments are each wrapped in ``Coalesce``."""
        # A null on either side would make the whole expression null, so each
        # argument falls back to an empty string.
        clone = self.copy()
        wrapped_sources = [
            Coalesce(source, Value("")) for source in clone.get_source_expressions()
        ]
        clone.set_source_expressions(wrapped_sources)
        return clone
122
123
class Concat(Func):
    """
    Join several text expressions into one. On backends where any null
    argument would make the whole expression null, each argument is wrapped
    in a coalesce function so that the result is non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        """
        expressions: two or more expressions to concatenate
        extra: passed on to the parent constructor; its optional
            ``output_field`` entry is also given to every nested pair

        Raise ValueError when fewer than two expressions are given.
        """
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        nested_pairs = self._paired(
            expressions, output_field=extra.get("output_field")
        )
        super().__init__(nested_pairs, **extra)

    def _paired(self, expressions, output_field):
        """Nest the expressions into right-leaning ``ConcatPair`` objects."""
        # Wrap pairs of expressions in successive concat functions, e.g.
        # [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        first, remainder = expressions[0], expressions[1:]
        if len(remainder) == 1:
            second = remainder[0]
        else:
            second = self._paired(remainder, output_field=output_field)
        return ConcatPair(first, second, output_field=output_field)
151
152
class Left(Func):
    """
    Two-argument function rendering ``LEFT``; the output field is a
    ``CharField``.
    """

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string

        Raise ValueError when a plain (non-expression) length is below 1.
        """
        # Only plain values are checked; objects with ``resolve_expression``
        # are passed through untouched.
        if not hasattr(length, "resolve_expression") and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        """Return a ``Substr`` starting at position 1 with the same length."""
        sources = self.source_expressions
        return Substr(sources[0], Value(1), sources[1])

    def as_oracle(self, compiler, connection, **extra_context):
        """Render through the ``Substr`` from ``get_substr()`` on Oracle."""
        substr = self.get_substr()
        return substr.as_oracle(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render through the ``Substr`` from ``get_substr()`` on SQLite."""
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)
176
177
class Length(Transform):
    """Give the character count of the expression as an integer."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        """Render with ``CHAR_LENGTH`` as the function name on MySQL."""
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )
189
190
class Lower(Transform):
    """Transform with lookup name ``lower`` whose SQL function is ``LOWER``."""

    lookup_name = "lower"
    function = "LOWER"
194
195
class LPad(Func):
    """Function rendering ``LPAD``; the output field is a ``CharField``."""

    function = "LPAD"
    output_field = CharField()

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        """
        expression: the first argument handed to the SQL function
        length: the second argument; ``None`` and expression objects are
            accepted without validation
        fill_text: the third argument, a single space by default

        Raise ValueError when a plain length is negative.
        """
        is_plain_value = not hasattr(length, "resolve_expression")
        if is_plain_value and length is not None and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
208
209
class LTrim(Transform):
    """Transform with lookup name ``ltrim`` whose SQL function is ``LTRIM``."""

    lookup_name = "ltrim"
    function = "LTRIM"
213
214
class MD5(OracleHashMixin, Transform):
    """
    Transform with lookup name ``md5`` whose SQL function is ``MD5``; the
    Oracle rendering comes from ``OracleHashMixin``.
    """

    lookup_name = "md5"
    function = "MD5"
218
219
class Ord(Transform):
    """
    Transform with lookup name ``ord`` whose default SQL function is
    ``ASCII`` and whose output field is an ``IntegerField``.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        """Render with ``ORD`` as the function name on MySQL."""
        return super().as_sql(
            compiler,
            connection,
            function="ORD",
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with ``UNICODE`` as the function name on SQLite."""
        return super().as_sql(
            compiler,
            connection,
            function="UNICODE",
            **extra_context,
        )
230
231
class Repeat(Func):
    """Function rendering ``REPEAT``; the output field is a ``CharField``."""

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression, number, **extra):
        """
        expression: the first argument handed to the SQL function
        number: the second argument; ``None`` and expression objects are
            accepted without validation

        Raise ValueError when a plain number is negative.
        """
        is_plain_value = not hasattr(number, "resolve_expression")
        if is_plain_value and number is not None and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        """Render on Oracle as an ``RPad`` of the expression with itself."""
        # The padded length is the expression's length times the repeat
        # count, and the expression doubles as the fill text.
        expression, number = self.source_expressions
        if number is None:
            length = None
        else:
            length = Length(expression) * number
        padded = RPad(expression, length, expression)
        return padded.as_sql(compiler, connection, **extra_context)
250
251
class Replace(Func):
    """Function rendering ``REPLACE`` with three arguments."""

    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        """
        expression: the first argument handed to the SQL function
        text: the second argument
        replacement: the third argument, an empty string value by default
        """
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
257
258
class Reverse(Transform):
    """Transform with lookup name ``reverse`` whose SQL function is ``REVERSE``."""

    lookup_name = "reverse"
    function = "REVERSE"

    def as_oracle(self, compiler, connection, **extra_context):
        """Render a character-by-character reversing subquery for Oracle."""
        # Oracle's REVERSE is undocumented and doesn't support multi-byte
        # strings, so a special subquery is used instead: one row per
        # character, aggregated back in descending position order.
        bare_select_suffix = connection.features.bare_select_suffix
        reverse_template = (
            "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
            "(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s"
            f"{bare_select_suffix} "
            "CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
            "GROUP BY %(expressions)s)"
        )
        sql, params = super().as_sql(
            compiler, connection, template=reverse_template, **extra_context
        )
        # ``%(expressions)s`` occurs three times in the template, so the
        # parameters are repeated three times to match.
        return sql, params * 3
279
280
class Right(Left):
    """Variant of ``Left`` whose SQL function is ``RIGHT``."""

    function = "RIGHT"

    def get_substr(self):
        """Return a ``Substr`` starting at minus the length, that many long."""
        sources = self.source_expressions
        text, length = sources[0], sources[1]
        return Substr(text, length * Value(-1), length)
290
291
class RPad(LPad):
    """Variant of ``LPad`` whose SQL function is ``RPAD``."""

    function = "RPAD"
294
295
class RTrim(Transform):
    """Transform with lookup name ``rtrim`` whose SQL function is ``RTRIM``."""

    lookup_name = "rtrim"
    function = "RTRIM"
299
300
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """
    Transform with lookup name ``sha1`` whose SQL function is ``SHA1``; the
    Oracle and PostgreSQL renderings come from the mixins.
    """

    lookup_name = "sha1"
    function = "SHA1"
304
305
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """
    Transform with lookup name ``sha224`` whose SQL function is ``SHA224``;
    the MySQL and PostgreSQL renderings come from the mixins.
    """

    lookup_name = "sha224"
    function = "SHA224"

    def as_oracle(self, compiler, connection, **extra_context):
        """Always raise NotSupportedError: there is no Oracle rendering."""
        message = "SHA224 is not supported on Oracle."
        raise NotSupportedError(message)
312
313
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """
    Transform with lookup name ``sha256`` whose SQL function is ``SHA256``;
    the MySQL, Oracle and PostgreSQL renderings come from the mixins.
    """

    lookup_name = "sha256"
    function = "SHA256"
317
318
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """
    Transform with lookup name ``sha384`` whose SQL function is ``SHA384``;
    the MySQL, Oracle and PostgreSQL renderings come from the mixins.
    """

    lookup_name = "sha384"
    function = "SHA384"
322
323
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """
    Transform with lookup name ``sha512`` whose SQL function is ``SHA512``;
    the MySQL, Oracle and PostgreSQL renderings come from the mixins.
    """

    lookup_name = "sha512"
    function = "SHA512"
327
328
class StrIndex(Func):
    """
    Give the 1-indexed position, as a positive integer, at which a substring
    first occurs inside another string; the result is 0 when the substring
    does not occur.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Render with ``STRPOS`` as the function name on PostgreSQL."""
        return super().as_sql(
            compiler,
            connection,
            function="STRPOS",
            **extra_context,
        )
342
343
class Substr(Func):
    """Function rendering ``SUBSTRING``; the output field is a ``CharField``."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        Raise ValueError when a plain (non-expression) pos is below 1.
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        # The length argument is only passed on when one was supplied.
        if length is None:
            expressions = [expression, pos]
        else:
            expressions = [expression, pos, length]
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with ``SUBSTR`` as the function name on SQLite."""
        return super().as_sql(
            compiler,
            connection,
            function="SUBSTR",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        """Render with ``SUBSTR`` as the function name on Oracle."""
        return super().as_sql(
            compiler,
            connection,
            function="SUBSTR",
            **extra_context,
        )
367
368
class Trim(Transform):
    """Transform with lookup name ``trim`` whose SQL function is ``TRIM``."""

    lookup_name = "trim"
    function = "TRIM"
372
373
class Upper(Transform):
    """Transform with lookup name ``upper`` whose SQL function is ``UPPER``."""

    lookup_name = "upper"
    function = "UPPER"