#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause
7
8from sqlparse import sql
9from sqlparse import tokens as T
10from sqlparse.utils import indent, offset
11
12
class AlignedIndentFilter:
    """Re-indent SQL so clause keywords are right-aligned on one column.

    Every keyword matched by ``split_words`` is moved onto a new line and
    padded so that it ends in the same column as ``SELECT``. Parenthesized
    subqueries, identifier lists and CASE expressions are aligned relative
    to that column. ``process`` modifies the statement in place.
    """

    # JOIN variants such as "LEFT OUTER JOIN", "CROSS JOIN" or plain "JOIN".
    join_words = (r'((LEFT\s+|RIGHT\s+|FULL\s+)?'
                  r'(INNER\s+|OUTER\s+|STRAIGHT\s+)?|'
                  r'(CROSS\s+|NATURAL\s+)?)?JOIN\b')
    # "GROUP BY" / "ORDER BY".
    by_words = r'(GROUP|ORDER)\s+BY\b'
    # Regexes for the keywords that start a new, aligned line.
    split_words = ('FROM',
                   join_words, 'ON', by_words,
                   'WHERE', 'AND', 'OR',
                   'HAVING', 'LIMIT',
                   'UNION', 'VALUES',
                   'SET', 'BETWEEN', 'EXCEPT')

    def __init__(self, char=' ', n='\n'):
        """Create the filter.

        :param char: character used for alignment padding.
        :param n: line-break string emitted before each split keyword.
        """
        self.n = n
        # extra columns added to every generated line break (see nl)
        self.offset = 0
        # indentation level; each level adds (2 + _max_kwd_len) columns
        self.indent = 0
        self.char = char
        # alignment column: split keywords are padded to end here
        self._max_kwd_len = len('select')

    def nl(self, offset=1):
        """Return a whitespace token: a line break plus alignment padding.

        :param offset: an int number of extra columns, or a keyword string
            whose length is subtracted so that the keyword, once written
            after the padding, ends on the alignment column.
        """
        # offset = 1 represent a single space after SELECT
        if not isinstance(offset, int):
            offset = -len(offset)
        # add two for the space and parenthesis
        indent_width = self.indent * (2 + self._max_kwd_len)
        width = self._max_kwd_len + offset + indent_width + self.offset
        return sql.Token(T.Whitespace, self.n + self.char * width)

    def _process_statement(self, tlist):
        """Drop a leading whitespace token (top level only), then align."""
        if (tlist.tokens and tlist.tokens[0].is_whitespace
                and self.indent == 0):
            tlist.tokens.pop(0)

        # process the main query body; wrapping it in a plain TokenList
        # makes _process fall through to _process_default instead of
        # dispatching back here
        self._process(sql.TokenList(tlist.tokens))

    def _process_parenthesis(self, tlist):
        """Align a parenthesized subquery; leave other parentheses alone."""
        # if this isn't a subquery, don't re-indent
        _, token = tlist.token_next_by(m=(T.DML, 'SELECT'))
        if token is None:
            return

        with indent(self):
            # break right after "(" so SELECT opens the indented block
            tlist.insert_after(tlist[0], self.nl('SELECT'))
            # process the inside of the parenthesis
            self._process_default(tlist)

        # de-indent last parenthesis
        tlist.insert_before(tlist[-1], self.nl())

    def _process_identifierlist(self, tlist):
        """Put every identifier after the first on its own aligned line."""
        # columns being selected
        identifiers = list(tlist.get_identifiers())
        # the first identifier stays on the keyword's line; slicing (rather
        # than pop(0)) also tolerates an empty identifier list
        for token in identifiers[1:]:
            tlist.insert_before(token, self.nl())
        self._process_default(tlist)

    def _process_case(self, tlist):
        """Align WHEN/ELSE/END under each other and pad conditions so the
        following THEN keywords line up."""
        offset_ = len('case ') + len('when ')
        cases = tlist.get_cases(skip_ws=True)
        # align the end as well (only if there actually is an END token;
        # appending None would break insert_before below)
        end_token = tlist.token_next_by(m=(T.Keyword, 'END'))[1]
        if end_token is not None:
            cases.append((None, [end_token]))

        # rendered width of each condition (0 for ELSE / END entries)
        condition_width = [len(' '.join(map(str, cond))) if cond else 0
                           for cond, _ in cases]
        max_cond_width = max(condition_width, default=0)

        for i, (cond, value) in enumerate(cases):
            # cond is None when 'else or end'
            stmt = cond[0] if cond else value[0]

            if i > 0:
                # subtract the keyword length so WHEN/ELSE/END end together
                tlist.insert_before(stmt, self.nl(offset_ - len(str(stmt))))
            if cond:
                # pad after the condition so every THEN starts in one column
                ws = sql.Token(T.Whitespace, self.char * (
                    max_cond_width - condition_width[i]))
                tlist.insert_after(cond[-1], ws)

    def _next_token(self, tlist, idx=-1):
        """Return ``(index, token)`` of the next split keyword after *idx*.

        BETWEEN is in ``split_words`` only so it can be detected here: it is
        never returned itself, and the AND that directly follows it is
        skipped so ``BETWEEN x AND y`` stays on one line.
        """
        # NOTE(review): matching uses regex=True; if Token.match searches
        # rather than full-matches, short words such as 'ON', 'OR' or 'SET'
        # may also hit longer keywords (e.g. OFFSET) -- confirm intended.
        split_words = T.Keyword, self.split_words, True
        tidx, token = tlist.token_next_by(m=split_words, idx=idx)
        # treat "BETWEEN x and y" as a single statement
        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)
            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)
        return tidx, token

    def _split_kwds(self, tlist):
        """Insert an aligned line break before every split keyword."""
        tidx, token = self._next_token(tlist)
        while token:
            # joins, group/order by are special case. only consider the first
            # word as aligner
            multi_word = (
                token.match(T.Keyword, self.join_words, regex=True)
                or token.match(T.Keyword, self.by_words, regex=True)
            )
            token_indent = token.value.split()[0] if multi_word else str(token)
            tlist.insert_before(token, self.nl(token_indent))
            # the keyword moved one slot right; resume searching after it
            tidx += 1
            tidx, token = self._next_token(tlist, tidx)

    def _process_default(self, tlist):
        """Split keywords in *tlist*, then recurse into its sub-groups."""
        self._split_kwds(tlist)
        # process any sub-sub statements
        for sgroup in tlist.get_sublists():
            idx = tlist.token_index(sgroup)
            _, prev_ = tlist.token_prev(idx)
            # HACK: make "group/order by" work. Longer than max_len.
            # Only "GROUP"/"ORDER" is aligned (see _split_kwds), so the group
            # that follows needs 3 extra columns to clear the trailing " BY".
            offset_ = 3 if (
                prev_ and prev_.match(T.Keyword, self.by_words, regex=True)
            ) else 0
            with offset(self, offset_):
                self._process(sgroup)

    def _process(self, tlist):
        """Dispatch to ``_process_<lowercased class name>`` or the default."""
        handler = getattr(self,
                          f'_process_{type(tlist).__name__}'.lower(),
                          self._process_default)
        handler(tlist)

    def process(self, stmt):
        """Align *stmt* in place and return it."""
        self._process(stmt)
        return stmt