Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/db/models/functions/text.py: 67%

169 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

1from django.db import NotSupportedError 

2from django.db.models.expressions import Func, Value 

3from django.db.models.fields import CharField, IntegerField, TextField 

4from django.db.models.functions import Cast, Coalesce 

5from django.db.models.lookups import Transform 

6 

7 

class MySQLSHA2Mixin:
    """
    Render a SHA-2 hash through MySQL's ``SHA2()`` function.

    The digest size argument is taken from the class's ``function`` name by
    dropping its first three characters (``"SHA256"`` -> ``256``).
    """

    def as_mysql(self, compiler, connection, **extra_context):
        digest_bits = self.function[3:]
        # "%%" becomes a literal "%" after interpolation, leaving the
        # "%(expressions)s" placeholder intact for Func.as_sql().
        mysql_template = "SHA2(%%(expressions)s, %s)" % digest_bits
        return super().as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

16 

17 

class OracleHashMixin:
    """
    Render a hash function on Oracle with ``STANDARD_HASH``.

    ``%(function)s`` in the template is filled with the class's ``function``
    name and used as the hash algorithm argument.
    """

    def as_oracle(self, compiler, connection, **extra_context):
        # UTL_I18N.STRING_TO_RAW turns the text into AL32UTF8 bytes before
        # hashing; RAWTOHEX + LOWER render the digest as lowercase hex.
        oracle_template = (
            "LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
            "%(expressions)s, 'AL32UTF8'), '%(function)s')))"
        )
        return super().as_sql(
            compiler, connection, template=oracle_template, **extra_context
        )

29 

30 

class PostgreSQLSHAMixin:
    """
    Render a SHA hash on PostgreSQL with ``DIGEST()``.

    The algorithm name passed to ``DIGEST`` is the lowercased ``function``
    attribute, and the binary digest is hex-encoded via ``ENCODE(..., 'hex')``.
    """

    def as_postgresql(self, compiler, connection, **extra_context):
        algorithm = self.function.lower()
        return super().as_sql(
            compiler,
            connection,
            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
            function=algorithm,
            **extra_context,
        )

40 

41 

class Chr(Transform):
    """
    SQL ``CHR`` transform, registered as the ``chr`` lookup.

    MySQL and SQLite render it as ``CHAR``.
    """

    lookup_name = "chr"
    function = "CHR"

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="CHAR", **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL's CHAR() is given an explicit "USING utf16" character set.
        mysql_template = "%(function)s(%(expressions)s USING utf16)"
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template=mysql_template,
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle keeps the CHR name but appends "USING NCHAR_CS".
        oracle_template = "%(function)s(%(expressions)s USING NCHAR_CS)"
        return super().as_sql(
            compiler, connection, template=oracle_template, **extra_context
        )

65 

66 

class ConcatPair(Func):
    """
    Concatenate exactly two arguments.

    ``Concat`` nests these pairs because not every database backend accepts
    more than two arguments to its concatenation function.
    """

    function = "CONCAT"

    def coalesce(self):
        # A NULL operand would make the whole concatenation NULL, so each
        # operand is wrapped in COALESCE(operand, '') on a copy of this node.
        null_safe = self.copy()
        operands = null_safe.get_source_expressions()
        null_safe.set_source_expressions(
            [Coalesce(operand, Value("")) for operand in operands]
        )
        return null_safe

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite: join the (NULL-safe) operands with the "||" operator.
        null_safe = self.coalesce()
        return super(ConcatPair, null_safe).as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=" || ",
            **extra_context,
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL: cast every operand to text before CONCAT.
        as_text = self.copy()
        operands = as_text.get_source_expressions()
        as_text.set_source_expressions(
            [Cast(operand, TextField()) for operand in operands]
        )
        return super(ConcatPair, as_text).as_sql(
            compiler,
            connection,
            **extra_context,
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # CONCAT_WS with an empty separator ignores NULL operands.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

119 

120 

class Concat(Func):
    """
    Concatenate two or more text expressions.

    The arguments are nested into ``ConcatPair`` nodes. On backends where any
    NULL argument would make the whole result NULL, each argument is wrapped
    in COALESCE so the result stays non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions):
        # Nest the arguments to the right, e.g.
        #   [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        if len(expressions) != 2:
            head, tail = expressions[0], expressions[1:]
            return ConcatPair(head, self._paired(tail))
        return ConcatPair(*expressions)

144 

145 

class Left(Func):
    """
    Return the first ``length`` characters of a string expression.

    Oracle and SQLite are rendered through an equivalent ``Substr``.
    """

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string
        """
        # Only plain values are range-checked; expressions are passed through.
        length_is_literal = not hasattr(length, "resolve_expression")
        if length_is_literal and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        # LEFT(s, n) is rendered as SUBSTR(s, 1, n).
        source, count = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, Value(1), count)

    def as_oracle(self, compiler, connection, **extra_context):
        substr = self.get_substr()
        return substr.as_oracle(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)

169 

170 

class Length(Transform):
    """Return the number of characters in the expression."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL counts characters with CHAR_LENGTH.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )

182 

183 

class Lower(Transform):
    """SQL ``LOWER`` transform, registered as the ``lower`` lookup."""

    lookup_name = "lower"
    function = "LOWER"

187 

188 

class LPad(Func):
    """
    SQL ``LPAD`` of ``expression`` to ``length`` using ``fill_text``.

    ``fill_text`` defaults to a single-space ``Value``.
    """

    function = "LPAD"
    output_field = CharField()

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        # None and expressions pass through unchecked; plain values must be >= 0.
        if (
            length is not None
            and not hasattr(length, "resolve_expression")
            and length < 0
        ):
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)

201 

202 

class LTrim(Transform):
    """SQL ``LTRIM`` transform, registered as the ``ltrim`` lookup."""

    lookup_name = "ltrim"
    function = "LTRIM"

206 

207 

class MD5(OracleHashMixin, Transform):
    """MD5 hash transform (``md5`` lookup); Oracle goes through STANDARD_HASH."""

    lookup_name = "md5"
    function = "MD5"

211 

212 

class Ord(Transform):
    """
    SQL ``ASCII`` transform returning an integer, registered as ``ord``.

    MySQL renders it as ``ORD`` and SQLite as ``UNICODE``.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="UNICODE", **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="ORD", **extra_context
        )

223 

224 

class Repeat(Func):
    """
    SQL ``REPEAT`` of ``expression`` ``number`` times.

    Oracle emulates it with ``RPAD``.
    """

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression, number, **extra):
        # None and expressions pass through unchecked; plain values must be >= 0.
        if (
            number is not None
            and not hasattr(number, "resolve_expression")
            and number < 0
        ):
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        # Right-pad the expression with itself up to
        # LENGTH(expression) * number characters.
        expression, number = self.source_expressions
        if number is None:
            # NOTE(review): Func presumably wraps a literal None in a Value,
            # which would make this branch unreachable; confirm.
            pad_length = None
        else:
            pad_length = Length(expression) * number
        padded = RPad(expression, pad_length, expression)
        return padded.as_sql(compiler, connection, **extra_context)

243 

244 

class Replace(Func):
    """
    SQL ``REPLACE(expression, text, replacement)``.

    ``replacement`` defaults to an empty-string ``Value``.
    """

    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        super().__init__(expression, text, replacement, **extra)

250 

251 

class Reverse(Transform):
    """SQL ``REVERSE`` transform, registered as the ``reverse`` lookup."""

    lookup_name = "reverse"
    function = "REVERSE"

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle's REVERSE is undocumented and doesn't support multi-byte
        # strings. Instead, split the string into one-character rows numbered
        # by LEVEL and reassemble them with LISTAGG in descending order.
        reversing_subquery = (
            "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
            "(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
            "FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
            "GROUP BY %(expressions)s)"
        )
        return super().as_sql(
            compiler, connection, template=reversing_subquery, **extra_context
        )

270 

271 

class Right(Left):
    """``Left`` counterpart rendered with SQL ``RIGHT``."""

    function = "RIGHT"

    def get_substr(self):
        # Start position is the negated length; on the Substr-based backends
        # a negative start presumably counts from the end of the string.
        source, count = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, count * Value(-1))

279 

280 

class RPad(LPad):
    """``LPad`` counterpart rendered with SQL ``RPAD``; same argument checks."""

    function = "RPAD"

283 

284 

class RTrim(Transform):
    """SQL ``RTRIM`` transform, registered as the ``rtrim`` lookup."""

    lookup_name = "rtrim"
    function = "RTRIM"

288 

289 

class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-1 hash transform (``sha1`` lookup) with Oracle/PostgreSQL variants."""

    lookup_name = "sha1"
    function = "SHA1"

293 

294 

class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 hash transform (``sha224`` lookup); rejected on Oracle."""

    lookup_name = "sha224"
    function = "SHA224"

    def as_oracle(self, compiler, connection, **extra_context):
        # No Oracle rendering exists for SHA-224, so fail explicitly.
        raise NotSupportedError("SHA224 is not supported on Oracle.")

301 

302 

class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 hash transform (``sha256`` lookup) with per-backend mixins."""

    lookup_name = "sha256"
    function = "SHA256"

306 

307 

class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 hash transform (``sha384`` lookup) with per-backend mixins."""

    lookup_name = "sha384"
    function = "SHA384"

311 

312 

class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 hash transform (``sha512`` lookup) with per-backend mixins."""

    lookup_name = "sha512"
    function = "SHA512"

316 

317 

class StrIndex(Func):
    """
    Return the 1-indexed position of the first occurrence of a substring
    inside another string as a positive integer, or 0 when the substring is
    absent. Rendered as ``INSTR`` (``STRPOS`` on PostgreSQL).
    """

    output_field = IntegerField()
    arity = 2
    function = "INSTR"

    def as_postgresql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="STRPOS", **extra_context
        )

331 

332 

class Substr(Func):
    """
    Extract a substring.

    Rendered as ``SUBSTRING``, or ``SUBSTR`` on SQLite and Oracle.
    """

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        # The length argument is omitted entirely when not supplied.
        arguments = (expression, pos) if length is None else (expression, pos, length)
        super().__init__(*arguments, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="SUBSTR", **extra_context
        )

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="SUBSTR", **extra_context
        )

356 

357 

class Trim(Transform):
    """SQL ``TRIM`` transform, registered as the ``trim`` lookup."""

    lookup_name = "trim"
    function = "TRIM"

361 

362 

class Upper(Transform):
    """SQL ``UPPER`` transform, registered as the ``upper`` lookup."""

    lookup_name = "upper"
    function = "UPPER"