Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/django/db/models/functions/text.py: 66%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

173 statements  

1from django.db import NotSupportedError 

2from django.db.models.expressions import Func, Value 

3from django.db.models.fields import CharField, IntegerField, TextField 

4from django.db.models.functions import Cast, Coalesce 

5from django.db.models.lookups import Transform 

6 

7 

class MySQLSHA2Mixin:
    """
    Render SHA-2 family hashes on MySQL as ``SHA2(expression, bits)``.

    The bit length is whatever follows the first three characters of the
    class's ``function`` name, e.g. ``"SHA256"`` -> ``256``.
    """

    def as_mysql(self, compiler, connection, **extra_context):
        bits = self.function[3:]
        return super().as_sql(
            compiler,
            connection,
            template=f"SHA2(%(expressions)s, {bits})",
            **extra_context,
        )

16 

17 

class OracleHashMixin:
    """
    Render hashes on Oracle with STANDARD_HASH.

    The expression is converted to raw AL32UTF8 bytes and hashed with the
    algorithm named by ``%(function)s``. The digest is then hex-encoded and
    lowercased.
    """

    def as_oracle(self, compiler, connection, **extra_context):
        hash_template = (
            "LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
            "%(expressions)s, 'AL32UTF8'), '%(function)s')))"
        )
        return super().as_sql(
            compiler, connection, template=hash_template, **extra_context
        )

29 

30 

class PostgreSQLSHAMixin:
    """
    Render SHA hashes on PostgreSQL as
    ``ENCODE(DIGEST(expression, '<algorithm>'), 'hex')``.

    The algorithm is the lowercased ``function`` name.
    """

    def as_postgresql(self, compiler, connection, **extra_context):
        # NOTE(review): DIGEST() is presumably provided by the pgcrypto
        # extension — confirm it is installed where this is used.
        algorithm = self.function.lower()
        return super().as_sql(
            compiler,
            connection,
            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
            function=algorithm,
            **extra_context,
        )

40 

41 

class Chr(Transform):
    """
    Turn a numeric code into a character.

    By default this renders CHR(); MySQL and SQLite use CHAR() instead.
    """

    lookup_name = "chr"
    function = "CHR"
    output_field = CharField()

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL: CHAR(<code> USING utf16).
        mysql_template = "%(function)s(%(expressions)s USING utf16)"
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template=mysql_template,
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle: CHR(<code> USING NCHAR_CS).
        oracle_template = "%(function)s(%(expressions)s USING NCHAR_CS)"
        return super().as_sql(
            compiler,
            connection,
            template=oracle_template,
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite: plain CHAR(<code>).
        return super().as_sql(
            compiler, connection, function="CHAR", **extra_context
        )

66 

67 

class ConcatPair(Func):
    """
    Concatenate exactly two arguments.

    ``Concat`` builds a chain of these because not all backend databases
    support more than two arguments.
    """

    function = "CONCAT"

    def pipes_concat_sql(self, compiler, connection, **extra_context):
        """
        Render the pair as ``(lhs || rhs)``.

        Each operand is wrapped in COALESCE(operand, ''), so a NULL operand
        does not turn the whole result into NULL.
        """
        null_safe = self.coalesce()
        # Call the as_sql() that follows ConcatPair in the MRO (Func's), bound
        # to the coalesced copy rather than to self.
        return super(ConcatPair, null_safe).as_sql(
            compiler,
            connection,
            template="(%(expressions)s)",
            arg_joiner=" || ",
            **extra_context,
        )

    # SQLite renders the pair with the || operator form.
    as_sqlite = pipes_concat_sql

    def as_postgresql(self, compiler, connection, **extra_context):
        """
        Cast every operand whose output field is not a CharField or TextField
        to text, then render with the || operator form.
        """
        text_fields = (CharField, TextField)
        clone = self.copy()
        operands = []
        for operand in clone.get_source_expressions():
            if not isinstance(operand.output_field, text_fields):
                operand = Cast(operand, TextField())
            operands.append(operand)
        clone.set_source_expressions(operands)
        return clone.pipes_concat_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        """Render the pair as ``CONCAT_WS('', lhs, rhs)``."""
        # An empty separator makes CONCAT_WS a plain concatenation that
        # ignores NULL arguments.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

    def coalesce(self):
        """Return a copy with each operand wrapped in COALESCE(operand, '')."""
        # A NULL on either side would otherwise make the whole expression NULL.
        clone = self.copy()
        wrapped = [
            Coalesce(operand, Value(""))
            for operand in clone.get_source_expressions()
        ]
        clone.set_source_expressions(wrapped)
        return clone

122 

123 

class Concat(Func):
    """
    Concatenate text fields together.

    On backends where any NULL argument would make the entire expression
    NULL, each argument is wrapped in COALESCE so the result stays non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        chain = self._paired(expressions, output_field=extra.get("output_field"))
        super().__init__(chain, **extra)

    def _paired(self, expressions, output_field):
        """
        Fold the expressions into right-nested ConcatPairs:
        [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))).
        """
        # Build the innermost pair from the last two expressions, then wrap
        # outward one leading expression at a time.
        nested = ConcatPair(
            expressions[-2], expressions[-1], output_field=output_field
        )
        for expression in reversed(expressions[:-2]):
            nested = ConcatPair(expression, nested, output_field=output_field)
        return nested

151 

152 

class Left(Func):
    """
    Take characters from the start of a string with LEFT().

    Oracle and SQLite render the equivalent SUBSTR(expression, 1, length)
    instead.
    """

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string
        """
        if not hasattr(length, "resolve_expression") and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        """Build the equivalent Substr(expression, 1, length)."""
        expression, length = self.source_expressions
        return Substr(expression, Value(1), length)

    def as_oracle(self, compiler, connection, **extra_context):
        substr = self.get_substr()
        return substr.as_oracle(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)

176 

177 

class Length(Transform):
    """Return the number of characters in the expression."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL uses CHAR_LENGTH() instead of LENGTH(). NOTE(review):
        # presumably because MySQL's LENGTH() counts bytes — confirm.
        return super().as_sql(
            compiler, connection, function="CHAR_LENGTH", **extra_context
        )

189 

190 

class Lower(Transform):
    """Transform rendering SQL ``LOWER()`` (lookup name ``lower``)."""

    lookup_name = "lower"
    function = "LOWER"

194 

195 

class LPad(Func):
    """
    Render ``LPAD(expression, length, fill_text)``.

    ``fill_text`` defaults to a single space.
    """

    function = "LPAD"
    output_field = CharField()

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        # Only plain values are validated; expressions and None pass through.
        is_plain_value = not hasattr(length, "resolve_expression")
        if is_plain_value and length is not None and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)

208 

209 

class LTrim(Transform):
    """Transform rendering SQL ``LTRIM()`` (lookup name ``ltrim``)."""

    lookup_name = "ltrim"
    function = "LTRIM"

213 

214 

class MD5(OracleHashMixin, Transform):
    """MD5 hash transform; Oracle rendering comes from ``OracleHashMixin``."""

    lookup_name = "md5"
    function = "MD5"

218 

219 

class Ord(Transform):
    """
    Return an integer character code for the expression.

    Renders ASCII() by default, ORD() on MySQL and UNICODE() on SQLite.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="ORD", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="UNICODE", **extra_context
        )

230 

231 

class Repeat(Func):
    """
    Render ``REPEAT(expression, number)``.

    Oracle has no REPEAT() here, so it is emulated with RPAD.
    """

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression, number, **extra):
        # Only plain values are validated; expressions and None pass through.
        is_plain_value = not hasattr(number, "resolve_expression")
        if is_plain_value and number is not None and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        # Pad the expression with itself up to LENGTH(expression) * number.
        expression, number = self.source_expressions
        # NOTE(review): source expressions are normally wrapped (e.g. None ->
        # Value(None)), so this None branch is likely unreachable — confirm.
        if number is None:
            length = None
        else:
            length = Length(expression) * number
        padded = RPad(expression, length, expression)
        return padded.as_sql(compiler, connection, **extra_context)

250 

251 

class Replace(Func):
    """
    Render ``REPLACE(expression, text, replacement)``.

    ``replacement`` defaults to an empty string.
    """

    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        super().__init__(expression, text, replacement, **extra)

257 

258 

class Reverse(Transform):
    """Transform rendering ``REVERSE()``; Oracle uses a subquery instead."""

    lookup_name = "reverse"
    function = "REVERSE"

    def as_oracle(self, compiler, connection, **extra_context):
        # REVERSE in Oracle is undocumented and doesn't support multi-byte
        # strings. Instead, split the string into one row per character
        # (CONNECT BY LEVEL) and reassemble the rows in descending position
        # order with LISTAGG.
        suffix = connection.features.bare_select_suffix
        reverse_template = (
            "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
            f"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s{suffix} "
            "CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
            "GROUP BY %(expressions)s)"
        )
        sql, params = super().as_sql(
            compiler, connection, template=reverse_template, **extra_context
        )
        # %(expressions)s appears three times in the template, so the
        # parameters must be supplied three times as well.
        return sql, params * 3

279 

280 

class Right(Left):
    """
    Counterpart of ``Left`` that renders RIGHT().

    The SUBSTR fallback starts at position ``length * -1``.
    """

    function = "RIGHT"

    def get_substr(self):
        """Build the equivalent Substr(expression, length * -1, length)."""
        expression, length = self.source_expressions
        return Substr(expression, length * Value(-1), length)

290 

291 

class RPad(LPad):
    """Same arguments and validation as ``LPad``, rendered with ``RPAD``."""

    function = "RPAD"

294 

295 

class RTrim(Transform):
    """Transform rendering SQL ``RTRIM()`` (lookup name ``rtrim``)."""

    lookup_name = "rtrim"
    function = "RTRIM"

299 

300 

class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """
    SHA-1 hash transform.

    Oracle and PostgreSQL rendering come from the mixins; other backends
    use the inherited rendering.
    """

    lookup_name = "sha1"
    function = "SHA1"

304 

305 

class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 hash transform; explicitly unsupported on Oracle."""

    lookup_name = "sha224"
    function = "SHA224"

    def as_oracle(self, compiler, connection, **extra_context):
        # Fail loudly instead of emitting SQL for Oracle.
        raise NotSupportedError("SHA224 is not supported on Oracle.")

312 

313 

class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 hash transform; MySQL/Oracle/PostgreSQL rendering from mixins."""

    lookup_name = "sha256"
    function = "SHA256"

317 

318 

class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 hash transform; MySQL/Oracle/PostgreSQL rendering from mixins."""

    lookup_name = "sha384"
    function = "SHA384"

322 

323 

class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 hash transform; MySQL/Oracle/PostgreSQL rendering from mixins."""

    lookup_name = "sha512"
    function = "SHA512"

327 

328 

class StrIndex(Func):
    """
    Return a positive integer corresponding to the 1-indexed position of the
    first occurrence of a substring inside another string, or 0 if the
    substring is not found.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL renders this as STRPOS() rather than INSTR().
        return super().as_sql(
            compiler, connection, function="STRPOS", **extra_context
        )

342 

343 

class Substr(Func):
    """
    Extract part of a string, starting at a 1-based position.

    The result is optionally limited to ``length`` characters.
    """

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        if length is None:
            super().__init__(expression, pos, **extra)
        else:
            super().__init__(expression, pos, length, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite and Oracle render the function as SUBSTR().
        return super().as_sql(
            compiler, connection, function="SUBSTR", **extra_context
        )

    as_oracle = as_sqlite

367 

368 

class Trim(Transform):
    """Transform rendering SQL ``TRIM()`` (lookup name ``trim``)."""

    lookup_name = "trim"
    function = "TRIM"

372 

373 

class Upper(Transform):
    """Transform rendering SQL ``UPPER()`` (lookup name ``upper``)."""

    lookup_name = "upper"
    function = "UPPER"