#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

from sqlparse import sql
from sqlparse import tokens as T
from sqlparse.exceptions import SQLParseError
from sqlparse.utils import recurse, imt

# Maximum recursion depth for grouping operations, to prevent DoS attacks.
# Set to None to disable the limit (not recommended for untrusted input).
MAX_GROUPING_DEPTH = 100

# Maximum number of tokens to process in one grouping operation, to prevent
# DoS attacks.
# Set to None to disable the limit (not recommended for untrusted input).
MAX_GROUPING_TOKENS = 10000

T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
T_NAME = (T.Name, T.Name.Placeholder)
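
# Illustrative usage sketch (assumption: this module lives at
# sqlparse.engine.grouping, as in upstream python-sqlparse). Callers that
# parse trusted but deeply nested SQL can relax the limits before parsing:
#
#     import sqlparse.engine.grouping as grouping
#     grouping.MAX_GROUPING_DEPTH = 500     # deeper nesting, trusted input
#     grouping.MAX_GROUPING_TOKENS = None   # disable the token-count limit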


def _group_matching(tlist, cls, depth=0):
    """Groups Tokens that have beginning and end."""
    if MAX_GROUPING_DEPTH is not None and depth > MAX_GROUPING_DEPTH:
        raise SQLParseError(
            f"Maximum grouping depth exceeded ({MAX_GROUPING_DEPTH})."
        )

    # Limit the number of tokens to prevent DoS attacks
    if MAX_GROUPING_TOKENS is not None \
            and len(tlist.tokens) > MAX_GROUPING_TOKENS:
        raise SQLParseError(
            f"Maximum number of tokens exceeded ({MAX_GROUPING_TOKENS})."
        )

    opens = []
    tidx_offset = 0
    token_list = list(tlist)

    for idx, token in enumerate(token_list):
        tidx = idx - tidx_offset

        if token.is_whitespace:
            # ~50% of tokens are whitespace. Checking for them early
            # avoids 3 comparisons for those tokens, at the cost of 1
            # extra comparison for the other ~50%.
            continue

        if token.is_group and not isinstance(token, cls):
            # Check inside previously grouped tokens (e.g. a parenthesis)
            # in case a group of a different type (e.g. a case) is nested
            # there. Ideally all open/close tokens would be checked in a
            # single pass to avoid this recursion.
            _group_matching(token, cls, depth + 1)
            continue

        if token.match(*cls.M_OPEN):
            opens.append(tidx)

        elif token.match(*cls.M_CLOSE):
            try:
                open_idx = opens.pop()
            except IndexError:
                # This indicates invalid sql and unbalanced tokens.
                # Instead of breaking, continue in case other valid
                # groups exist.
                continue
            close_idx = tidx
            tlist.group_tokens(cls, open_idx, close_idx)
            tidx_offset += close_idx - open_idx
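
# Illustrative example (uses only sqlparse's public API): parsing runs this
# matching pass for each group type, so balanced open/close pairs collapse
# into single group tokens, innermost pairs first:
#
#     import sqlparse
#     stmt = sqlparse.parse('SELECT (a + (b))')[0]
#     # stmt now contains one sql.Parenthesis token whose children
#     # include the inner Parenthesis for "(b)".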


def group_brackets(tlist):
    _group_matching(tlist, sql.SquareBrackets)


def group_parenthesis(tlist):
    _group_matching(tlist, sql.Parenthesis)


def group_case(tlist):
    _group_matching(tlist, sql.Case)


def group_if(tlist):
    _group_matching(tlist, sql.If)


def group_for(tlist):
    _group_matching(tlist, sql.For)


def group_begin(tlist):
    _group_matching(tlist, sql.Begin)


def group_typecasts(tlist):
    def match(token):
        return token.match(T.Punctuation, '::')

    def valid(token):
        return token is not None

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
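
# For example (illustrative), in "SELECT amount::numeric" the three tokens
# "amount", "::" and "numeric" are grouped into a single sql.Identifier.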


def group_tzcasts(tlist):
    def match(token):
        return token.ttype == T.Keyword.TZCast

    def valid_prev(token):
        return token is not None

    def valid_next(token):
        return token is not None and (
            token.is_whitespace
            or token.match(T.Keyword, 'AS')
            or token.match(*sql.TypedLiteral.M_CLOSE)
        )

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)


def group_typed_literal(tlist):
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html
    def match(token):
        return imt(token, m=sql.TypedLiteral.M_OPEN)

    def match_to_extend(token):
        return isinstance(token, sql.TypedLiteral)

    def valid_prev(token):
        return token is not None

    def valid_next(token):
        return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

    def valid_final(token):
        return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

    def post(tlist, pidx, tidx, nidx):
        return tidx, nidx

    _group(tlist, sql.TypedLiteral, match, valid_prev, valid_next,
           post, extend=False)
    _group(tlist, sql.TypedLiteral, match_to_extend, valid_prev, valid_final,
           post, extend=True)
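
# For example (illustrative), "DATE '2020-01-01'" is grouped into a single
# sql.TypedLiteral by the first pass; the second pass then extends literals
# such as "INTERVAL '1' DAY" with the trailing unit keyword, assuming that
# keyword is listed in sql.TypedLiteral.M_EXTEND.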


def group_period(tlist):
    def match(token):
        for ttype, value in ((T.Punctuation, '.'),
                             (T.Operator, '->'),
                             (T.Operator, '->>')):
            if token.match(ttype, value):
                return True
        return False

    def valid_prev(token):
        sqlcls = sql.SquareBrackets, sql.Identifier
        ttypes = T.Name, T.String.Symbol
        return imt(token, i=sqlcls, t=ttypes)

    def valid_next(token):
        # issue261, allow invalid next token
        return True

    def post(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        sqlcls = sql.SquareBrackets, sql.Function
        ttypes = T.Name, T.String.Symbol, T.Wildcard, T.String.Single
        next_ = tlist[nidx] if nidx is not None else None
        valid_next = imt(next_, i=sqlcls, t=ttypes)

        return (pidx, nidx) if valid_next else (pidx, tidx)

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
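
# For example (illustrative), "schema.table" and the JSON access operators
# in "data->'key'" are joined with their operands into one sql.Identifier;
# per post() above, a dangling period with no valid next token ("schema.")
# is still grouped with its preceding operand.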


def group_as(tlist):
    def match(token):
        return token.is_keyword and token.normalized == 'AS'

    def valid_prev(token):
        return token.normalized == 'NULL' or not token.is_keyword

    def valid_next(token):
        ttypes = T.DML, T.DDL, T.CTE
        return token is not None and not imt(token, t=ttypes)

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)


def group_assignment(tlist):
    def match(token):
        return token.match(T.Assignment, ':=')

    def valid(token):
        return token is not None and token.ttype not in (T.Keyword,)

    def post(tlist, pidx, tidx, nidx):
        m_semicolon = T.Punctuation, ';'
        snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx)
        nidx = snidx or nidx
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Assignment, match, valid_prev, valid_next, post)
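
# For example (illustrative), "foo := 1;" is grouped into a sql.Assignment
# that post() extends through the terminating semicolon when one is found.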


def group_comparison(tlist):
    sqlcls = (sql.Parenthesis, sql.Function, sql.Identifier,
              sql.Operation, sql.TypedLiteral)
    ttypes = T_NUMERICAL + T_STRING + T_NAME

    def match(token):
        return token.ttype == T.Operator.Comparison

    def valid(token):
        if imt(token, t=ttypes, i=sqlcls):
            return True
        elif token and token.is_keyword and token.normalized == 'NULL':
            return True
        else:
            return False

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Comparison, match,
           valid_prev, valid_next, post, extend=False)
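
# For example (illustrative), in "WHERE a = 1" the tokens "a", "=" and "1"
# become a single sql.Comparison; the NULL branch in valid() also accepts
# comparisons against NULL, e.g. "a != NULL".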


@recurse(sql.Identifier)
def group_identifier(tlist):
    ttypes = (T.String.Symbol, T.Name)

    tidx, token = tlist.token_next_by(t=ttypes)
    while token:
        tlist.group_tokens(sql.Identifier, tidx, tidx)
        tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)


@recurse(sql.Over)
def group_over(tlist):
    tidx, token = tlist.token_next_by(m=sql.Over.M_OPEN)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if imt(next_, i=sql.Parenthesis, t=T.Name):
            tlist.group_tokens(sql.Over, tidx, nidx)
        tidx, token = tlist.token_next_by(m=sql.Over.M_OPEN, idx=tidx)


def group_arrays(tlist):
    sqlcls = sql.SquareBrackets, sql.Identifier, sql.Function
    ttypes = T.Name, T.String.Symbol

    def match(token):
        return isinstance(token, sql.SquareBrackets)

    def valid_prev(token):
        return imt(token, i=sqlcls, t=ttypes)

    def valid_next(token):
        return True

    def post(tlist, pidx, tidx, nidx):
        return pidx, tidx

    _group(tlist, sql.Identifier, match,
           valid_prev, valid_next, post, extend=True, recurse=False)
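
# For example (illustrative), in "tags[1]" the SquareBrackets group is
# folded into the preceding name, extending it to a single sql.Identifier.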


def group_operator(tlist):
    ttypes = T_NUMERICAL + T_STRING + T_NAME
    sqlcls = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
              sql.Identifier, sql.Operation, sql.TypedLiteral)

    def match(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def valid(token):
        return imt(token, i=sqlcls, t=ttypes) \
            or (token and token.match(
                T.Keyword,
                ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))

    def post(tlist, pidx, tidx, nidx):
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Operation, match,
           valid_prev, valid_next, post, extend=False)
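
# For example (illustrative), "a + b" becomes one sql.Operation; post()
# also normalizes the matched token's ttype to plain T.Operator.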


def group_identifier_list(tlist):
    m_role = T.Keyword, ('null', 'role')
    sqlcls = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
              sql.IdentifierList, sql.Operation)
    ttypes = (T_NUMERICAL + T_STRING + T_NAME
              + (T.Keyword, T.Comment, T.Wildcard))

    def match(token):
        return token.match(T.Punctuation, ',')

    def valid(token):
        return imt(token, i=sqlcls, m=m_role, t=ttypes)

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.IdentifierList, match,
           valid_prev, valid_next, post, extend=True)
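
# For example (illustrative), in "SELECT a, b, c" the comma-joined names
# collapse into a single sql.IdentifierList; extend=True lets each further
# comma grow the existing list instead of nesting a new one.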


@recurse(sql.Comment)
def group_comments(tlist):
    tidx, token = tlist.token_next_by(t=T.Comment)
    while token:
        eidx, end = tlist.token_not_matching(
            lambda tk: imt(tk, t=T.Comment) or tk.is_newline, idx=tidx)
        if end is not None:
            eidx, end = tlist.token_prev(eidx, skip_ws=False)
            tlist.group_tokens(sql.Comment, tidx, eidx)

        tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx)


@recurse(sql.Where)
def group_where(tlist):
    tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN)
    while token:
        eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx)

        if end is None:
            end = tlist._groupable_tokens[-1]
        else:
            end = tlist.tokens[eidx - 1]
        # TODO: convert this to use eidx instead of the end token.
        # I think the values above are len(tlist) and eidx - 1.
        eidx = tlist.token_index(end)
        tlist.group_tokens(sql.Where, tidx, eidx)
        tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx)


@recurse()
def group_aliased(tlist):
    I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
               sql.Operation, sql.Comparison)

    tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Identifier):
            tlist.group_tokens(sql.Identifier, tidx, nidx, extend=True)
        tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx)


@recurse(sql.Function)
def group_functions(tlist):
    # Skip column definition lists in "CREATE TABLE ... (...)" statements
    # (without AS): their "name(...)" tokens would otherwise be misread
    # as function calls.
    has_create = False
    has_table = False
    has_as = False
    for tmp_token in tlist.tokens:
        if tmp_token.value.upper() == 'CREATE':
            has_create = True
        if tmp_token.value.upper() == 'TABLE':
            has_table = True
        if tmp_token.value == 'AS':
            has_as = True
    if has_create and has_table and not has_as:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            over_idx, over = tlist.token_next(nidx)
            if over and isinstance(over, sql.Over):
                eidx = over_idx
            else:
                eidx = nidx
            tlist.group_tokens(sql.Function, tidx, eidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
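
# For example (illustrative), "COUNT(id)" becomes a sql.Function; when a
# window clause follows, as in "SUM(x) OVER (PARTITION BY y)", the grouped
# sql.Over token is included in the Function's span as well.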


@recurse(sql.Identifier)
def group_order(tlist):
    """Group together Identifier and Asc/Desc token."""
    tidx, token = tlist.token_next_by(t=T.Keyword.Order)
    while token:
        pidx, prev_ = tlist.token_prev(tidx)
        if imt(prev_, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, pidx, tidx)
            tidx = pidx
        tidx, token = tlist.token_next_by(t=T.Keyword.Order, idx=tidx)


@recurse()
def align_comments(tlist):
    tidx, token = tlist.token_next_by(i=sql.Comment)
    while token:
        pidx, prev_ = tlist.token_prev(tidx)
        if isinstance(prev_, sql.TokenList):
            tlist.group_tokens(sql.TokenList, pidx, tidx, extend=True)
            tidx = pidx
        tidx, token = tlist.token_next_by(i=sql.Comment, idx=tidx)


def group_values(tlist):
    tidx, token = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    start_idx = tidx
    end_idx = -1
    while token:
        if isinstance(token, sql.Parenthesis):
            end_idx = tidx
        tidx, token = tlist.token_next(tidx)
    if end_idx != -1:
        tlist.group_tokens(sql.Values, start_idx, end_idx, extend=True)
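
# For example (illustrative), in "INSERT INTO t VALUES (1), (2)" everything
# from the VALUES keyword through the last parenthesis group becomes one
# sql.Values token.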


def group(stmt):
    for func in [
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_over,
        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    ]:
        func(stmt)
    return stmt
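
# Typical entry point (illustrative, via sqlparse's public API): parse()
# runs group() on each statement, so the ordered pipeline above executes
# on every parse:
#
#     import sqlparse
#     stmt = sqlparse.parse('SELECT a, b FROM t WHERE a = 1')[0]
#     # stmt.tokens now holds grouped nodes such as IdentifierList and
#     # Where instead of a flat token stream.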


def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True,
           depth=0
           ):
    """Groups together tokens that are joined by a middle token, e.g. x < y."""
    if MAX_GROUPING_DEPTH is not None and depth > MAX_GROUPING_DEPTH:
        raise SQLParseError(
            f"Maximum grouping depth exceeded ({MAX_GROUPING_DEPTH})."
        )

    # Limit the number of tokens to prevent DoS attacks
    if MAX_GROUPING_TOKENS is not None \
            and len(tlist.tokens) > MAX_GROUPING_TOKENS:
        raise SQLParseError(
            f"Maximum number of tokens exceeded ({MAX_GROUPING_TOKENS})."
        )

    tidx_offset = 0
    pidx, prev_ = None, None
    token_list = list(tlist)

    for idx, token in enumerate(token_list):
        tidx = idx - tidx_offset
        if tidx < 0:  # tidx shouldn't get negative
            continue

        if token.is_whitespace:
            continue

        if recurse and token.is_group and not isinstance(token, cls):
            _group(token, cls, match, valid_prev, valid_next,
                   post, extend, True, depth + 1)

        if match(token):
            nidx, next_ = tlist.token_next(tidx)
            if prev_ and valid_prev(prev_) and valid_next(next_):
                from_idx, to_idx = post(tlist, pidx, tidx, nidx)
                grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

                tidx_offset += to_idx - from_idx
                pidx, prev_ = from_idx, grp
                continue

        pidx, prev_ = tidx, token