1#
2# Copyright (C) 2009-2020 the sqlparse authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of python-sqlparse and is released under
6# the BSD License: https://opensource.org/licenses/BSD-3-Clause
7
8from sqlparse import sql
9from sqlparse import tokens as T
10from sqlparse.exceptions import SQLParseError
11from sqlparse.utils import imt, recurse
12
# Upper bound on how deep the grouping helpers may recurse. Protects against
# denial-of-service through pathologically nested input.
# ``None`` switches the check off (not recommended for untrusted input).
MAX_GROUPING_DEPTH = 100

# Upper bound on the number of tokens a single grouping call will accept.
# Protects against denial-of-service through oversized token lists.
# ``None`` switches the check off (not recommended for untrusted input).
MAX_GROUPING_TOKENS = 10_000

# Token-type families shared by several grouping rules in this module.
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
T_NAME = (T.Name, T.Name.Placeholder)
25
26
def _group_matching(tlist, cls, depth=0):
    """Group token runs that are delimited by an opening and closing token.

    ``tlist`` is scanned once. The position of every token matching
    ``cls.M_OPEN`` is remembered; when a token matching ``cls.M_CLOSE``
    appears, everything from the most recent unmatched opener up to and
    including that closer is wrapped into a ``cls`` group through
    ``tlist.group_tokens``. Groups of a different class met during the scan
    are processed recursively with ``depth + 1``.

    :param tlist: token list that is modified in place.
    :param cls: group class providing ``M_OPEN`` and ``M_CLOSE``.
    :param depth: current recursion depth, used for the depth limit.
    :raises SQLParseError: if ``depth`` exceeds ``MAX_GROUPING_DEPTH`` or
        ``tlist`` holds more than ``MAX_GROUPING_TOKENS`` tokens (a limit
        set to ``None`` is not enforced).
    """
    depth_limited = MAX_GROUPING_DEPTH is not None
    if depth_limited and depth > MAX_GROUPING_DEPTH:
        raise SQLParseError(
            f"Maximum grouping depth exceeded ({MAX_GROUPING_DEPTH})."
        )

    # Refuse oversized token lists up front to prevent DoS attacks.
    size_limited = MAX_GROUPING_TOKENS is not None
    if size_limited and len(tlist.tokens) > MAX_GROUPING_TOKENS:
        raise SQLParseError(
            f"Maximum number of tokens exceeded ({MAX_GROUPING_TOKENS})."
        )

    pending_opens = []
    # Every grouping collapses several entries of ``tlist`` into one, so
    # positions in the snapshot drift away from positions in ``tlist``.
    # ``shrunk`` is the accumulated drift.
    shrunk = 0

    # Iterate over a snapshot because ``tlist`` is modified while we walk.
    for position, tok in enumerate(list(tlist)):
        if tok.is_whitespace:
            # Whitespace is frequent; skipping it first avoids the more
            # expensive checks below.
            continue

        if tok.is_group and not isinstance(tok, cls):
            # A group of another type (e.g. a parenthesis) may contain
            # delimiters of ``cls`` (e.g. a CASE), so look inside it.
            _group_matching(tok, cls, depth + 1)
            continue

        if tok.match(*cls.M_OPEN):
            pending_opens.append(position - shrunk)
            continue

        if not tok.match(*cls.M_CLOSE):
            continue

        if not pending_opens:
            # Closer without opener: invalid SQL with unbalanced tokens.
            # Keep scanning in case other, valid groups follow.
            continue

        first = pending_opens.pop()
        last = position - shrunk
        tlist.group_tokens(cls, first, last)
        shrunk += last - first
74
75
def group_brackets(tlist):
    """Group spans delimited by ``sql.SquareBrackets`` open/close tokens."""
    _group_matching(tlist, cls=sql.SquareBrackets)
78
79
def group_parenthesis(tlist):
    """Group spans delimited by ``sql.Parenthesis`` open/close tokens."""
    _group_matching(tlist, cls=sql.Parenthesis)
82
83
def group_case(tlist):
    """Group spans delimited by ``sql.Case`` open/close tokens."""
    _group_matching(tlist, cls=sql.Case)
86
87
def group_if(tlist):
    """Group spans delimited by ``sql.If`` open/close tokens."""
    _group_matching(tlist, cls=sql.If)
90
91
def group_for(tlist):
    """Group spans delimited by ``sql.For`` open/close tokens."""
    _group_matching(tlist, cls=sql.For)
94
95
def group_begin(tlist):
    """Group spans delimited by ``sql.Begin`` open/close tokens."""
    _group_matching(tlist, cls=sql.Begin)
98
99
def group_typecasts(tlist):
    """Wrap ``<prev> :: <next>`` token triples into ``sql.Identifier``.

    The middle token must be the punctuation ``::``; the only requirement
    on its neighbours is that they exist.
    """
    def is_cast_operator(token):
        return token.match(T.Punctuation, '::')

    def exists(token):
        return token is not None

    def whole_span(tlist, pidx, tidx, nidx):
        # Group from the previous token through the next one.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_cast_operator,
           valid_prev=exists, valid_next=exists, post=whole_span)
112
113
def group_tzcasts(tlist):
    """Wrap a ``T.Keyword.TZCast`` token and its neighbours into an
    ``sql.Identifier``.

    A previous token must exist, and the next token must be whitespace,
    the keyword ``AS`` or a match for ``sql.TypedLiteral.M_CLOSE``.
    """
    def is_tzcast(token):
        return token.ttype == T.Keyword.TZCast

    def has_prev(token):
        return token is not None

    def acceptable_next(token):
        if token is None:
            return False
        return (
            token.is_whitespace
            or token.match(T.Keyword, 'AS')
            or token.match(*sql.TypedLiteral.M_CLOSE)
        )

    def whole_span(tlist, pidx, tidx, nidx):
        # Group from the previous token through the next one.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_tzcast,
           valid_prev=has_prev, valid_next=acceptable_next, post=whole_span)
132
133
def group_typed_literal(tlist):
    """Build ``sql.TypedLiteral`` groups in two passes.

    Pass one groups a token matching ``sql.TypedLiteral.M_OPEN`` with the
    following token when that one matches ``M_CLOSE``. Pass two extends an
    existing ``sql.TypedLiteral`` with the following token when it matches
    ``M_EXTEND``. In both passes the group starts at the matched token, not
    at the token before it.
    """
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html
    literal = sql.TypedLiteral

    def opens_literal(token):
        return imt(token, m=literal.M_OPEN)

    def is_literal_group(token):
        return isinstance(token, literal)

    def has_prev(token):
        return token is not None

    def closes_literal(token):
        return token is not None and token.match(*literal.M_CLOSE)

    def extends_literal(token):
        return token is not None and token.match(*literal.M_EXTEND)

    def from_match_to_next(tlist, pidx, tidx, nidx):
        # The previous token stays outside the group.
        return tidx, nidx

    _group(tlist, literal, opens_literal, has_prev, closes_literal,
           from_match_to_next, extend=False)
    _group(tlist, literal, is_literal_group, has_prev, extends_literal,
           from_match_to_next, extend=True)
162
163
def group_period(tlist):
    """Join tokens linked by ``.``, ``->`` or ``->>`` into ``sql.Identifier``.

    The token before the separator must be a square-bracket group, an
    identifier, a name or a symbol string. The token after it is validated
    only when the span is computed: if it is unsuitable, the group ends at
    the separator itself.
    """
    separators = (
        (T.Punctuation, '.'),
        (T.Operator, '->'),
        (T.Operator, '->>'),
    )

    def is_separator(token):
        return any(token.match(ttype, value) for ttype, value in separators)

    def prev_ok(token):
        return imt(token,
                   i=(sql.SquareBrackets, sql.Identifier),
                   t=(T.Name, T.String.Symbol))

    def next_ok(token):
        # issue261, allow invalid next token
        return True

    def span(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        follower = None if nidx is None else tlist[nidx]
        follower_fits = imt(
            follower,
            i=(sql.SquareBrackets, sql.Function),
            t=(T.Name, T.String.Symbol, T.Wildcard, T.String.Single))

        if follower_fits:
            return pidx, nidx
        return pidx, tidx

    _group(tlist, sql.Identifier, is_separator, prev_ok, next_ok, span)
192
193
def group_as(tlist):
    """Wrap ``<prev> AS <next>`` token triples into ``sql.Identifier``.

    The previous token must be a non-keyword or normalize to ``NULL``; the
    next token must exist and must not be of type DML, DDL or CTE.
    """
    def is_as_keyword(token):
        return token.is_keyword and token.normalized == 'AS'

    def prev_ok(token):
        if not token.is_keyword:
            return True
        return token.normalized == 'NULL'

    def next_ok(token):
        if token is None:
            return False
        return not imt(token, t=(T.DML, T.DDL, T.CTE))

    def whole_span(tlist, pidx, tidx, nidx):
        # Group from the previous token through the next one.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_as_keyword, prev_ok, next_ok,
           whole_span)
209
210
def group_assignment(tlist):
    """Wrap ``<prev> := <next>`` into an ``sql.Assignment`` group.

    Both neighbours must exist and must not have the exact token type
    ``T.Keyword``. When a ``;`` follows the next token, the group is
    stretched to include everything up to and including that semicolon.
    """
    def is_assign_operator(token):
        return token.match(T.Assignment, ':=')

    def is_operand(token):
        if token is None:
            return False
        return token.ttype not in (T.Keyword,)

    def span(tlist, pidx, tidx, nidx):
        semicolon_idx, _ = tlist.token_next_by(
            m=(T.Punctuation, ';'), idx=nidx)
        # Without a trailing semicolon the group ends at the next token.
        return pidx, semicolon_idx or nidx

    _group(tlist, sql.Assignment, is_assign_operator,
           valid_prev=is_operand, valid_next=is_operand, post=span)
226
227
def group_comparison(tlist):
    """Wrap ``<operand> <comparison-operator> <operand>`` into
    ``sql.Comparison``.

    An operand is a numeric, string or name token, one of several group
    classes, or a keyword that normalizes to ``NULL``.
    """
    operand_classes = (sql.Parenthesis, sql.Function, sql.Identifier,
                       sql.Operation, sql.TypedLiteral)
    operand_ttypes = T_NUMERICAL + T_STRING + T_NAME

    def is_comparison_operator(token):
        return token.ttype == T.Operator.Comparison

    def is_operand(token):
        if imt(token, t=operand_ttypes, i=operand_classes):
            return True
        if not token:
            return False
        return bool(token.is_keyword and token.normalized == 'NULL')

    def whole_span(tlist, pidx, tidx, nidx):
        # Group from the previous token through the next one.
        return pidx, nidx

    _group(tlist, sql.Comparison, is_comparison_operator,
           is_operand, is_operand, whole_span, extend=False)
245
246
@recurse(sql.Identifier)
def group_identifier(tlist):
    """Wrap every symbol-string or name token into its own
    ``sql.Identifier``."""
    wanted = (T.String.Symbol, T.Name)

    position, found = tlist.token_next_by(t=wanted)
    while found:
        # Single-token group: start and end index are the same.
        tlist.group_tokens(sql.Identifier, position, position)
        position, found = tlist.token_next_by(t=wanted, idx=position)
255
256
@recurse(sql.Over)
def group_over(tlist):
    """Group a token matching ``sql.Over.M_OPEN`` with the token after it.

    The pair becomes an ``sql.Over`` group only when the following token is
    a parenthesis group or a name.
    """
    opener = sql.Over.M_OPEN

    position, found = tlist.token_next_by(m=opener)
    while found:
        after_idx, after = tlist.token_next(position)
        if imt(after, i=sql.Parenthesis, t=T.Name):
            tlist.group_tokens(sql.Over, position, after_idx)
        position, found = tlist.token_next_by(m=opener, idx=position)
265
266
def group_arrays(tlist):
    """Attach a square-bracket group to the token before it.

    When an ``sql.SquareBrackets`` group follows a square-bracket group, an
    identifier, a function, a name or a symbol string, both are merged into
    an ``sql.Identifier``. The scan does not descend into nested groups.
    """
    def is_brackets(token):
        return isinstance(token, sql.SquareBrackets)

    def prev_ok(token):
        return imt(token,
                   i=(sql.SquareBrackets, sql.Identifier, sql.Function),
                   t=(T.Name, T.String.Symbol))

    def next_ok(token):
        # Whatever follows the brackets is irrelevant.
        return True

    def up_to_brackets(tlist, pidx, tidx, nidx):
        # The group ends at the brackets; the next token stays outside.
        return pidx, tidx

    _group(tlist, sql.Identifier, is_brackets, prev_ok, next_ok,
           up_to_brackets, extend=True, recurse=False)
285
286
def group_operator(tlist):
    """Wrap ``<operand> <operator> <operand>`` into ``sql.Operation``.

    The middle token is an operator or a wildcard; on grouping, its token
    type is rewritten to ``T.Operator``. An operand is a numeric, string or
    name token, one of several group classes, or one of the keywords
    ``CURRENT_DATE``, ``CURRENT_TIME`` and ``CURRENT_TIMESTAMP``.
    """
    operand_ttypes = T_NUMERICAL + T_STRING + T_NAME
    operand_classes = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
                       sql.Identifier, sql.Operation, sql.TypedLiteral)
    time_keywords = ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')

    def is_operator(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def is_operand(token):
        plain = imt(token, i=operand_classes, t=operand_ttypes)
        if plain:
            return plain
        return token and token.match(T.Keyword, time_keywords)

    def whole_span(tlist, pidx, tidx, nidx):
        # A wildcard used between two operands is really an operator.
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    _group(tlist, sql.Operation, is_operator,
           is_operand, is_operand, whole_span, extend=False)
308
309
def group_identifier_list(tlist):
    """Wrap comma-separated items into ``sql.IdentifierList`` groups.

    The tokens on both sides of a ``,`` must each be one of the accepted
    group classes, match the ``null``/``role`` keywords, or have one of the
    accepted token types.
    """
    role_keywords = T.Keyword, ('null', 'role')
    item_classes = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
                    sql.IdentifierList, sql.Operation)
    item_ttypes = (T_NUMERICAL + T_STRING + T_NAME
                   + (T.Keyword, T.Comment, T.Wildcard))

    def is_comma(token):
        return token.match(T.Punctuation, ',')

    def is_item(token):
        return imt(token, i=item_classes, m=role_keywords, t=item_ttypes)

    def whole_span(tlist, pidx, tidx, nidx):
        # Group from the previous token through the next one.
        return pidx, nidx

    _group(tlist, sql.IdentifierList, is_comma,
           is_item, is_item, whole_span, extend=True)
329
330
@recurse(sql.Comment)
def group_comments(tlist):
    """Merge runs of comment tokens and newlines into ``sql.Comment``.

    A run starts at a comment token and stops just before the first token
    that is neither a comment nor a newline. A run reaching the end of
    ``tlist`` (no such token found) is left ungrouped.
    """
    def belongs_to_run(tk):
        return imt(tk, t=T.Comment) or tk.is_newline

    start, comment = tlist.token_next_by(t=T.Comment)
    while comment:
        stop, terminator = tlist.token_not_matching(belongs_to_run,
                                                    idx=start)
        if terminator is not None:
            # Step back onto the last token that still belongs to the run.
            stop, terminator = tlist.token_prev(stop, skip_ws=False)
            tlist.group_tokens(sql.Comment, start, stop)

        start, comment = tlist.token_next_by(t=T.Comment, idx=start)
342
343
@recurse(sql.Where)
def group_where(tlist):
    """Wrap spans opened by ``sql.Where.M_OPEN`` into ``sql.Where`` groups.

    A span runs up to, but not including, the next token matching
    ``sql.Where.M_CLOSE``. Without such a token it runs to the last entry
    of ``tlist._groupable_tokens``.
    """
    start, opener = tlist.token_next_by(m=sql.Where.M_OPEN)
    while opener:
        close_idx, closer = tlist.token_next_by(m=sql.Where.M_CLOSE,
                                                idx=start)

        last = (tlist._groupable_tokens[-1] if closer is None
                else tlist.tokens[close_idx - 1])
        # TODO: convert this to eidx instead of end token.
        # i think above values are len(tlist) and eidx-1
        stop = tlist.token_index(last)
        tlist.group_tokens(sql.Where, start, stop)
        start, opener = tlist.token_next_by(m=sql.Where.M_OPEN, idx=start)
359
360
@recurse()
def group_aliased(tlist):
    """Merge an aliasable token with a directly following identifier.

    Aliasable tokens are numbers and the group classes listed below. When
    the next token is an ``sql.Identifier``, both are combined into an
    ``sql.Identifier`` (extending an existing one where possible).
    """
    aliasable = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
                 sql.Operation, sql.Comparison)

    position, found = tlist.token_next_by(i=aliasable, t=T.Number)
    while found:
        after_idx, after = tlist.token_next(position)
        if isinstance(after, sql.Identifier):
            tlist.group_tokens(sql.Identifier, position, after_idx,
                               extend=True)
        position, found = tlist.token_next_by(i=aliasable, t=T.Number,
                                              idx=position)
372
373
@recurse(sql.Function)
def group_functions(tlist):
    """Group ``name(...)`` token pairs into ``sql.Function``.

    A name token directly followed by a parenthesis group becomes an
    ``sql.Function``; when an ``sql.Over`` group follows the parenthesis,
    it is included in the function group as well.

    Nothing is grouped when ``tlist`` contains the tokens ``CREATE`` and
    ``TABLE`` but no ``AS``, so that a plain table definition is not read
    as a function call. All three keywords are compared case-insensitively.
    """
    has_create = False
    has_table = False
    has_as = False
    for tmp_token in tlist.tokens:
        # SQL keywords are case-insensitive, so normalise once and compare
        # all three markers the same way ("as" counts just like "AS").
        upper_value = tmp_token.value.upper()
        if upper_value == 'CREATE':
            has_create = True
        elif upper_value == 'TABLE':
            has_table = True
        elif upper_value == 'AS':
            has_as = True
    if has_create and has_table and not has_as:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            over_idx, over = tlist.token_next(nidx)
            # Pull a trailing OVER group into the function group.
            if over and isinstance(over, sql.Over):
                eidx = over_idx
            else:
                eidx = nidx
            tlist.group_tokens(sql.Function, tidx, eidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
400
401
@recurse(sql.Identifier)
def group_order(tlist):
    """Group together Identifier and Asc/Desc token"""
    position, order_kw = tlist.token_next_by(t=T.Keyword.Order)
    while order_kw:
        before_idx, before = tlist.token_prev(position)
        if imt(before, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, before_idx, position)
            # The new group sits where the previous token was.
            position = before_idx
        position, order_kw = tlist.token_next_by(t=T.Keyword.Order,
                                                 idx=position)
412
413
@recurse()
def align_comments(tlist):
    """Fold each ``sql.Comment`` group into a token list that precedes it.

    When the token before a comment group is an ``sql.TokenList``, the two
    are grouped together (extending the existing list where possible).
    """
    position, comment = tlist.token_next_by(i=sql.Comment)
    while comment:
        before_idx, before = tlist.token_prev(position)
        if isinstance(before, sql.TokenList):
            tlist.group_tokens(sql.TokenList, before_idx, position,
                               extend=True)
            # The merged group sits where the previous token was.
            position = before_idx
        position, comment = tlist.token_next_by(i=sql.Comment, idx=position)
423
424
def group_values(tlist):
    """Group a ``VALUES`` keyword with the parenthesis groups after it.

    The ``sql.Values`` group runs from the first ``VALUES`` keyword to the
    last parenthesis group found after it. Nothing happens when either is
    missing.
    """
    first_idx, current = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    if not current:
        return

    position = first_idx
    last_paren_idx = None
    while current:
        if isinstance(current, sql.Parenthesis):
            last_paren_idx = position
        position, current = tlist.token_next(position)

    if last_paren_idx is not None:
        tlist.group_tokens(sql.Values, first_idx, last_paren_idx,
                           extend=True)
435
436
def group(stmt):
    """Run every grouping pass over ``stmt`` in a fixed order.

    The passes modify ``stmt`` in place, and later passes work on the
    groups built by earlier ones, so the order matters. Returns ``stmt``.
    """
    passes = (
        group_comments,

        # delimiter-based passes built on _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_over,
        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    )
    for apply_pass in passes:
        apply_pass(stmt)
    return stmt
471
472
def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True,
           depth=0
           ):
    """Groups together tokens that are joined by a middle token. i.e. x < y

    ``tlist`` is scanned once, skipping whitespace. For each token accepted
    by ``match``, the previous non-whitespace token and the token returned
    by ``tlist.token_next`` are checked with ``valid_prev`` / ``valid_next``.
    If a previous token exists and both checks pass,
    ``post(tlist, pidx, tidx, nidx)`` supplies the ``(from_idx, to_idx)``
    span that is wrapped into a ``cls`` group.

    :param tlist: token list that is modified in place.
    :param cls: group class to create.
    :param match: predicate identifying the middle token.
    :param valid_prev: predicate for the token before the middle token.
    :param valid_next: predicate for the token after it (may get ``None``).
    :param post: callable returning the span to group. It is called without
        a ``None`` check, so callers have to supply it despite the default.
    :param extend: forwarded to ``tlist.group_tokens``.
    :param recurse: whether to descend into nested groups that are not
        instances of ``cls``. This parameter shadows the imported
        ``recurse`` decorator inside this function.
    :param depth: current recursion depth, used for the depth limit.
    :raises SQLParseError: if ``depth`` exceeds ``MAX_GROUPING_DEPTH`` or
        ``tlist`` holds more than ``MAX_GROUPING_TOKENS`` tokens (a limit
        set to ``None`` is not enforced).
    """
    depth_limited = MAX_GROUPING_DEPTH is not None
    if depth_limited and depth > MAX_GROUPING_DEPTH:
        raise SQLParseError(
            f"Maximum grouping depth exceeded ({MAX_GROUPING_DEPTH})."
        )

    # Refuse oversized token lists up front to prevent DoS attacks.
    size_limited = MAX_GROUPING_TOKENS is not None
    if size_limited and len(tlist.tokens) > MAX_GROUPING_TOKENS:
        raise SQLParseError(
            f"Maximum number of tokens exceeded ({MAX_GROUPING_TOKENS})."
        )

    # Grouping collapses several entries of ``tlist`` into one, so snapshot
    # positions drift from live positions. ``shrunk`` tracks the drift.
    shrunk = 0
    left_idx, left = None, None

    # Iterate over a snapshot because ``tlist`` is modified while we walk.
    for position, tok in enumerate(list(tlist)):
        current = position - shrunk
        # A negative live index shouldn't happen; skip defensively.
        # Whitespace never takes part in grouping.
        if current < 0 or tok.is_whitespace:
            continue

        if recurse and tok.is_group and not isinstance(tok, cls):
            _group(tok, cls, match, valid_prev, valid_next,
                   post, extend, True, depth + 1)

        if match(tok):
            right_idx, right = tlist.token_next(current)
            if left and valid_prev(left) and valid_next(right):
                start, stop = post(tlist, left_idx, current, right_idx)
                merged = tlist.group_tokens(cls, start, stop, extend=extend)

                shrunk += stop - start
                # The new group is the left-hand neighbour of whatever
                # comes next, which allows chains such as ``a.b.c``.
                left_idx, left = start, merged
                continue

        left_idx, left = current, tok