1#
2# Copyright (C) 2009-2020 the sqlparse authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of python-sqlparse and is released under
6# the BSD License: https://opensource.org/licenses/BSD-3-Clause
7
8from sqlparse import sql
9from sqlparse import tokens as T
10from sqlparse.utils import recurse, imt
11
# Token-type shortcuts shared by several grouping predicates below.
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
T_NAME = (T.Name, T.Name.Placeholder)
15
16
def _group_matching(tlist, cls):
    """Groups Tokens that have beginning and end.

    Scans *tlist* for tokens matching ``cls.M_OPEN`` / ``cls.M_CLOSE`` and
    wraps each balanced pair (inclusive) into a new ``cls`` group.  Nesting
    is handled with the ``opens`` stack; an unbalanced closer is skipped.
    """
    opens = []
    # group_tokens() below collapses a span of tokens into one, so indices
    # taken from the original enumeration must be shifted left by the
    # number of tokens consumed so far.
    tidx_offset = 0
    for idx, token in enumerate(list(tlist)):
        tidx = idx - tidx_offset

        if token.is_whitespace:
            # ~50% of tokens will be whitespace. Will checking early
            # for them avoid 3 comparisons, but then add 1 more comparison
            # for the other ~50% of tokens...
            continue

        if token.is_group and not isinstance(token, cls):
            # Check inside previously grouped (i.e. parenthesis) if group
            # of different type is inside (i.e., case). though ideally should
            # should check for all open/close tokens at once to avoid recursion
            _group_matching(token, cls)
            continue

        if token.match(*cls.M_OPEN):
            opens.append(tidx)

        elif token.match(*cls.M_CLOSE):
            try:
                open_idx = opens.pop()
            except IndexError:
                # this indicates invalid sql and unbalanced tokens.
                # instead of break, continue in case other "valid" groups exist
                continue
            close_idx = tidx
            tlist.group_tokens(cls, open_idx, close_idx)
            # (close_idx - open_idx) tokens were merged into a single group.
            tidx_offset += close_idx - open_idx
50
51
def group_brackets(tlist):
    """Group ``[...]`` pairs into :class:`sql.SquareBrackets`."""
    _group_matching(tlist, sql.SquareBrackets)
54
55
def group_parenthesis(tlist):
    """Group ``(...)`` pairs into :class:`sql.Parenthesis`."""
    _group_matching(tlist, sql.Parenthesis)
58
59
def group_case(tlist):
    """Group ``CASE ... END`` spans into :class:`sql.Case`."""
    _group_matching(tlist, sql.Case)
62
63
def group_if(tlist):
    """Group ``IF ... END IF`` spans into :class:`sql.If`."""
    _group_matching(tlist, sql.If)
66
67
def group_for(tlist):
    """Group ``FOR ... END LOOP`` spans into :class:`sql.For`."""
    _group_matching(tlist, sql.For)
70
71
def group_begin(tlist):
    """Group ``BEGIN ... END`` spans into :class:`sql.Begin`."""
    _group_matching(tlist, sql.Begin)
74
75
def group_typecasts(tlist):
    """Group ``expr::type`` casts into a single Identifier."""
    def is_cast_op(token):
        return token.match(T.Punctuation, '::')

    def has_token(token):
        return token is not None

    def span(tlist, pidx, tidx, nidx):
        # Group from the token before '::' through the one after it.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_cast_op,
           valid_prev=has_token, valid_next=has_token, post=span)
88
89
def group_tzcasts(tlist):
    """Group ``AT TIME ZONE`` casts with their preceding operand."""
    def is_tzcast(token):
        return token.ttype == T.Keyword.TZCast

    def has_token(token):
        return token is not None

    def next_ok(token):
        if token is None:
            return False
        return (token.is_whitespace
                or token.match(T.Keyword, 'AS')
                or token.match(*sql.TypedLiteral.M_CLOSE))

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_tzcast,
           valid_prev=has_token, valid_next=next_ok, post=span)
108
109
def group_typed_literal(tlist):
    """Group typed literals such as ``DATE '2020-01-01'``.

    Definitely not complete, see e.g.:
    https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    https://www.postgresql.org/docs/9.1/datatype-datetime.html
    https://www.postgresql.org/docs/9.1/functions-datetime.html
    """
    def opens_literal(token):
        return imt(token, m=sql.TypedLiteral.M_OPEN)

    def is_typed_literal(token):
        return isinstance(token, sql.TypedLiteral)

    def has_token(token):
        return token is not None

    def closes_literal(token):
        return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

    def extends_literal(token):
        return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

    def span(tlist, pidx, tidx, nidx):
        return tidx, nidx

    # First pass creates the TypedLiteral groups ...
    _group(tlist, sql.TypedLiteral, opens_literal, has_token,
           closes_literal, span, extend=False)
    # ... second pass absorbs trailing extension keywords into them.
    _group(tlist, sql.TypedLiteral, is_typed_literal, has_token,
           extends_literal, span, extend=True)
138
139
def group_period(tlist):
    """Group dotted / arrow accesses (``a.b``, ``a->b``, ``a->>b``)."""
    separators = ((T.Punctuation, '.'),
                  (T.Operator, '->'),
                  (T.Operator, '->>'))

    def is_separator(token):
        return any(token.match(ttype, value) for ttype, value in separators)

    def prev_ok(token):
        return imt(token,
                   i=(sql.SquareBrackets, sql.Identifier),
                   t=(T.Name, T.String.Symbol))

    def next_any(token):
        # issue261, allow invalid next token
        return True

    def span(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        next_ = tlist[nidx] if nidx is not None else None
        if imt(next_,
               i=(sql.SquareBrackets, sql.Function),
               t=(T.Name, T.String.Symbol, T.Wildcard, T.String.Single)):
            return pidx, nidx
        return pidx, tidx

    _group(tlist, sql.Identifier, is_separator,
           valid_prev=prev_ok, valid_next=next_any, post=span)
168
169
def group_as(tlist):
    """Group ``<expr> AS <alias>`` into an Identifier."""
    def is_as_keyword(token):
        return token.is_keyword and token.normalized == 'AS'

    def prev_ok(token):
        # Any non-keyword may be aliased; of the keywords, only NULL.
        if token.is_keyword:
            return token.normalized == 'NULL'
        return True

    def next_ok(token):
        # A statement-starting keyword after AS means this is not an alias.
        return token is not None and not imt(token, t=(T.DML, T.DDL, T.CTE))

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_as_keyword,
           valid_prev=prev_ok, valid_next=next_ok, post=span)
185
186
def group_assignment(tlist):
    """Group ``x := y`` assignments, extending through a trailing ';'."""
    def is_assign_op(token):
        return token.match(T.Assignment, ':=')

    def operand_ok(token):
        return token is not None and token.ttype not in (T.Keyword,)

    def span(tlist, pidx, tidx, nidx):
        # Extend the group up to a following semicolon, if one exists.
        semi_idx, _ = tlist.token_next_by(m=(T.Punctuation, ';'), idx=nidx)
        return pidx, semi_idx or nidx

    _group(tlist, sql.Assignment, is_assign_op,
           valid_prev=operand_ok, valid_next=operand_ok, post=span)
202
203
def group_comparison(tlist):
    """Group ``a <op> b`` spans into :class:`sql.Comparison`."""
    operand_classes = (sql.Parenthesis, sql.Function, sql.Identifier,
                       sql.Operation, sql.TypedLiteral)
    operand_ttypes = T_NUMERICAL + T_STRING + T_NAME

    def is_comparison_op(token):
        return token.ttype == T.Operator.Comparison

    def operand_ok(token):
        # Valid operands are expression-like tokens, plus the keyword NULL.
        if imt(token, t=operand_ttypes, i=operand_classes):
            return True
        return bool(token and token.is_keyword
                    and token.normalized == 'NULL')

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Comparison, is_comparison_op,
           valid_prev=operand_ok, valid_next=operand_ok, post=span,
           extend=False)
226
227
@recurse(sql.Identifier)
def group_identifier(tlist):
    """Wrap bare names and quoted symbols in Identifier groups."""
    name_types = (T.String.Symbol, T.Name)

    name_idx, name_tok = tlist.token_next_by(t=name_types)
    while name_tok is not None:
        tlist.group_tokens(sql.Identifier, name_idx, name_idx)
        name_idx, name_tok = tlist.token_next_by(t=name_types, idx=name_idx)
236
237
@recurse(sql.Over)
def group_over(tlist):
    """Group an OVER keyword with its window spec or window name."""
    over_idx, over_kw = tlist.token_next_by(m=sql.Over.M_OPEN)
    while over_kw is not None:
        win_idx, window = tlist.token_next(over_idx)
        if imt(window, i=sql.Parenthesis, t=T.Name):
            tlist.group_tokens(sql.Over, over_idx, win_idx)
        over_idx, over_kw = tlist.token_next_by(m=sql.Over.M_OPEN,
                                                idx=over_idx)
246
247
def group_arrays(tlist):
    """Attach subscript brackets (``arr[1]``) to the preceding base."""
    def is_subscript(token):
        return isinstance(token, sql.SquareBrackets)

    def prev_ok(token):
        return imt(token,
                   i=(sql.SquareBrackets, sql.Identifier, sql.Function),
                   t=(T.Name, T.String.Symbol))

    def next_any(token):
        return True

    def span(tlist, pidx, tidx, nidx):
        # Only the base and its brackets are grouped; nidx is ignored.
        return pidx, tidx

    _group(tlist, sql.Identifier, is_subscript,
           valid_prev=prev_ok, valid_next=next_any, post=span,
           extend=True, recurse=False)
266
267
def group_operator(tlist):
    """Group arithmetic expressions (``a + b``) into sql.Operation."""
    operand_ttypes = T_NUMERICAL + T_STRING + T_NAME
    operand_classes = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
                       sql.Identifier, sql.Operation, sql.TypedLiteral)

    def is_operator(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def operand_ok(token):
        # Expression-like tokens plus the CURRENT_* datetime keywords.
        if imt(token, i=operand_classes, t=operand_ttypes):
            return True
        return bool(token and token.match(
            T.Keyword,
            ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))

    def span(tlist, pidx, tidx, nidx):
        # Normalize the middle token (e.g. a Wildcard used as '*') so the
        # resulting Operation always carries a plain Operator.
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    _group(tlist, sql.Operation, is_operator,
           valid_prev=operand_ok, valid_next=operand_ok, post=span,
           extend=False)
289
290
def group_identifier_list(tlist):
    """Group comma-separated expressions into an IdentifierList."""
    m_role = T.Keyword, ('null', 'role')
    item_classes = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
                    sql.IdentifierList, sql.Operation)
    item_ttypes = (T_NUMERICAL + T_STRING + T_NAME
                   + (T.Keyword, T.Comment, T.Wildcard))

    def is_comma(token):
        return token.match(T.Punctuation, ',')

    def item_ok(token):
        return imt(token, i=item_classes, m=m_role, t=item_ttypes)

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.IdentifierList, is_comma,
           valid_prev=item_ok, valid_next=item_ok, post=span, extend=True)
310
311
@recurse(sql.Comment)
def group_comments(tlist):
    """Group runs of comment tokens (and embedded newlines) into
    :class:`sql.Comment` groups."""
    tidx, token = tlist.token_next_by(t=T.Comment)
    while token:
        # Find the first token that is neither a comment nor a newline;
        # everything before it belongs to the current comment run.
        eidx, end = tlist.token_not_matching(
            lambda tk: imt(tk, t=T.Comment) or tk.is_newline, idx=tidx)
        if end is not None:
            # Step back to the last token of the run itself (skip_ws=False
            # so a trailing newline is not skipped past).
            eidx, end = tlist.token_prev(eidx, skip_ws=False)
            tlist.group_tokens(sql.Comment, tidx, eidx)

        tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx)
323
324
@recurse(sql.Where)
def group_where(tlist):
    """Group a WHERE clause up to (but excluding) the next clause keyword,
    or to the end of the statement if none follows."""
    tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN)
    while token:
        eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx)

        if end is None:
            # No closing keyword: clause runs to the last groupable token.
            end = tlist._groupable_tokens[-1]
        else:
            # Stop right before the closing keyword.
            end = tlist.tokens[eidx - 1]
        # TODO: convert this to eidx instead of end token.
        # i think above values are len(tlist) and eidx-1
        eidx = tlist.token_index(end)
        tlist.group_tokens(sql.Where, tidx, eidx)
        tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx)
340
341
@recurse()
def group_aliased(tlist):
    """Merge a trailing alias Identifier into the expression before it."""
    aliasable = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
                 sql.Operation, sql.Comparison)

    expr_idx, expr = tlist.token_next_by(i=aliasable, t=T.Number)
    while expr is not None:
        alias_idx, alias = tlist.token_next(expr_idx)
        if isinstance(alias, sql.Identifier):
            tlist.group_tokens(sql.Identifier, expr_idx, alias_idx,
                               extend=True)
        expr_idx, expr = tlist.token_next_by(i=aliasable, t=T.Number,
                                             idx=expr_idx)
353
354
@recurse(sql.Function)
def group_functions(tlist):
    """Group a Name followed by a Parenthesis into a Function, absorbing a
    following Over (window) group if present.

    Plain ``CREATE TABLE`` statements (without ``AS``) are skipped so that
    column definitions like ``col VARCHAR(10)`` are not mistaken for calls.
    """
    has_create = False
    has_table = False
    has_as = False
    for tmp_token in tlist.tokens:
        value = tmp_token.value.upper()
        if value == 'CREATE':
            has_create = True
        elif value == 'TABLE':
            has_table = True
        elif value == 'AS':
            # Fix: compare case-insensitively, consistent with the CREATE
            # and TABLE checks above. A lowercase 'as' previously left
            # has_as False, so function calls inside a lowercase
            # CREATE TABLE ... AS SELECT were never grouped.
            has_as = True
    if has_create and has_table and not has_as:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            # Include a trailing OVER (...) window in the Function group.
            over_idx, over = tlist.token_next(nidx)
            if over and isinstance(over, sql.Over):
                eidx = over_idx
            else:
                eidx = nidx
            tlist.group_tokens(sql.Function, tidx, eidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
381
382
@recurse(sql.Identifier)
def group_order(tlist):
    """Group together Identifier and Asc/Desc token"""
    order_idx, order_kw = tlist.token_next_by(t=T.Keyword.Order)
    while order_kw is not None:
        expr_idx, expr = tlist.token_prev(order_idx)
        if imt(expr, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, expr_idx, order_idx)
            order_idx = expr_idx
        order_idx, order_kw = tlist.token_next_by(t=T.Keyword.Order,
                                                  idx=order_idx)
393
394
@recurse()
def align_comments(tlist):
    """Fold a Comment into the TokenList group immediately before it."""
    cmt_idx, comment = tlist.token_next_by(i=sql.Comment)
    while comment is not None:
        prev_idx, prev_tok = tlist.token_prev(cmt_idx)
        if isinstance(prev_tok, sql.TokenList):
            tlist.group_tokens(sql.TokenList, prev_idx, cmt_idx, extend=True)
            cmt_idx = prev_idx
        cmt_idx, comment = tlist.token_next_by(i=sql.Comment, idx=cmt_idx)
404
405
def group_values(tlist):
    """Group a VALUES keyword with its trailing parenthesized tuples."""
    start_idx, token = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    scan_idx = start_idx
    end_idx = -1
    # Walk forward, remembering the last Parenthesis seen after VALUES.
    while token is not None:
        if isinstance(token, sql.Parenthesis):
            end_idx = scan_idx
        scan_idx, token = tlist.token_next(scan_idx)
    if end_idx != -1:
        tlist.group_tokens(sql.Values, start_idx, end_idx, extend=True)
416
417
def group(stmt):
    """Run all grouping passes over *stmt* in order and return it."""
    passes = (
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_over,
        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    )
    for grouper in passes:
        grouper(stmt)
    return stmt
452
453
def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True
           ):
    """Groups together tokens that are joined by a middle token. i.e. x < y

    :param tlist: token list, modified in place.
    :param cls: group class to create (e.g. ``sql.Comparison``).
    :param match: predicate for the "middle" token (e.g. the operator).
    :param valid_prev: predicate for the non-whitespace token before it.
    :param valid_next: predicate for the token after it.
    :param post: callback ``(tlist, pidx, tidx, nidx) -> (from, to)``
        returning the index span to group.
    :param extend: passed to ``group_tokens``; extend an existing ``cls``
        group instead of nesting a new one.
    :param recurse: also descend into non-``cls`` subgroups.
    """

    # Grouping collapses tokens, so indices taken from the original
    # enumeration are shifted left by the number of tokens consumed.
    tidx_offset = 0
    pidx, prev_ = None, None
    for idx, token in enumerate(list(tlist)):
        tidx = idx - tidx_offset
        if tidx < 0:  # tidx shouldn't get negative
            continue

        if token.is_whitespace:
            continue

        if recurse and token.is_group and not isinstance(token, cls):
            _group(token, cls, match, valid_prev, valid_next, post, extend)

        if match(token):
            nidx, next_ = tlist.token_next(tidx)
            if prev_ and valid_prev(prev_) and valid_next(next_):
                from_idx, to_idx = post(tlist, pidx, tidx, nidx)
                grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

                # The new group replaces (to_idx - from_idx + 1) tokens.
                tidx_offset += to_idx - from_idx
                # The freshly created group becomes the "previous" token.
                pidx, prev_ = from_idx, grp
                continue

        pidx, prev_ = tidx, token