1#
2# Copyright (C) 2009-2020 the sqlparse authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of python-sqlparse and is released under
6# the BSD License: https://opensource.org/licenses/BSD-3-Clause
7
8from sqlparse import sql
9from sqlparse import tokens as T
10from sqlparse.utils import indent, offset
11
12
class ReindentFilter:
    """Statement filter that re-indents a parsed SQL statement in place.

    Line breaks are inserted before selected keywords, between DML/DDL
    statements, and inside parentheses, identifier lists, CASE
    expressions and VALUES clauses.  Indentation is tracked with two
    counters, ``indent`` (levels of ``width`` characters) and ``offset``
    (raw characters).  Both are adjusted inside ``with indent(self)`` /
    ``with offset(self, n)`` blocks, using helpers from ``sqlparse.utils``.

    Args:
        width: Number of ``char`` characters per indentation level.
        char: Character used for indentation.
        wrap_after: Column width used to decide when identifier lists
            and CASE branches wrap.
        n: Newline sequence emitted by :meth:`nl`.
        comma_first: Break before commas instead of after them.
        indent_after_first: Start with one indentation level.
        indent_columns: Indent identifier-list columns by ``width``
            instead of aligning them with the first column.
        compact: Do not force CASE branches and END onto new lines.
    """

    def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
                 comma_first=False, indent_after_first=False,
                 indent_columns=False, compact=False):
        self.n = n
        self.width = width
        self.char = char
        self.indent = 1 if indent_after_first else 0
        self.offset = 0
        self.wrap_after = wrap_after
        self.comma_first = comma_first
        self.indent_columns = indent_columns
        self.compact = compact
        # Statement being processed; _get_offset measures columns
        # against its flattened text.
        self._curr_stmt = None
        # Previously processed statement; decides the separator that
        # process() prepends to the next one.
        self._last_stmt = None
        # First token of the most recently seen function, read by
        # _process_identifierlist.  NOTE(review): never reset between
        # statements, so it can carry over from an earlier one.
        self._last_func = None

    def _flatten_up_to_token(self, token):
        """Yield the leaf tokens of the current statement before *token*.

        For a group token the boundary is its first leaf.  The boundary
        token itself is not yielded.
        """
        if token.is_group:
            token = next(token.flatten())

        for t in self._curr_stmt.flatten():
            if t == token:
                break
            yield t

    @property
    def leading_ws(self):
        """Indentation width (in ``char`` units) for a new line.

        May be negative when a negative offset is active.
        """
        return self.offset + self.indent * self.width

    def _get_offset(self, token):
        """Return the column of *token* relative to the current indentation.

        The statement text preceding *token* is rendered, its last line
        is measured, and the current leading whitespace is subtracted.
        """
        raw = ''.join(map(str, self._flatten_up_to_token(token)))
        # An empty prefix still needs one (empty) line to measure.
        # NOTE(review): if ``raw`` ends with a line break, splitlines()
        # drops the empty final line and the previous line is measured
        # instead -- confirm this cannot happen for measured tokens.
        line = (raw or '\n').splitlines()[-1]
        # Now take current offset into account and return relative offset.
        # (Kept as len(char * n): it is 0 for negative n, unlike
        # len(char) * n.)
        return len(line) - len(self.char * self.leading_ws)

    def nl(self, offset=0):
        """Return a whitespace token: ``self.n`` plus the current indent.

        *offset* shifts the indentation by that many characters; the
        total is clamped at zero characters.
        """
        return sql.Token(
            T.Whitespace,
            self.n + self.char * max(0, self.leading_ws + offset))

    def _next_token(self, tlist, idx=-1):
        """Return ``(index, token)`` of the next keyword to break before.

        The search starts after *idx*.  ``BETWEEN`` is never a break
        point.  If the next split keyword after it is ``AND`` (as in
        ``BETWEEN x AND y``), that ``AND`` is skipped too, so the range
        stays on one line.
        """
        split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                       'GROUP BY', 'ORDER BY', 'UNION', 'VALUES',
                       'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
        # The trailing True presumably switches to regex matching
        # (note the '$' anchors) -- confirm against token_next_by.
        m_split = T.Keyword, split_words, True
        tidx, token = tlist.token_next_by(m=m_split, idx=idx)

        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)

            # Only the AND belonging to BETWEEN is skipped; ordinary
            # ANDs remain break points.
            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)

        return tidx, token

    def _split_kwds(self, tlist):
        """Put each split keyword of *tlist* at the start of a new line."""
        tidx, token = self._next_token(tlist)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            uprev = str(prev_)

            # Drop whitespace right before the keyword; the newline
            # inserted below takes its place.
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1

            # Skip the break when the previous token already ended a
            # line (e.g. a single-line comment).  NOTE(review): a deleted
            # whitespace token ending in a newline is not replaced here.
            if not uprev.endswith(('\n', '\r')):
                tlist.insert_before(tidx, self.nl())
                tidx += 1

            tidx, token = self._next_token(tlist, tidx)

    def _split_statements(self, tlist):
        """Start every DML/DDL keyword of *tlist* on a new line.

        Whitespace directly before the keyword is removed.  A newline is
        inserted only if the keyword has a preceding token.
        """
        ttypes = T.Keyword.DML, T.Keyword.DDL
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1
            # only break if it's not the first token
            if prev_:
                tlist.insert_before(tidx, self.nl())
                tidx += 1
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)

    def _process(self, tlist):
        """Dispatch *tlist* to ``_process_<classname>`` (lower-cased).

        Groups without a dedicated handler go to :meth:`_process_default`.
        """
        func_name = f'_process_{type(tlist).__name__}'
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)

    def _process_where(self, tlist):
        """Start a WHERE clause on a new line and indent its contents."""
        tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
        if not token:
            return
        # issue121, errors in statement fixed??
        tlist.insert_before(tidx, self.nl())
        with indent(self):
            self._process_default(tlist)

    def _process_parenthesis(self, tlist):
        """Lay out a parenthesized group.

        If it contains a DML/DDL keyword (e.g. a subquery), a newline is
        inserted before it inside ``indent(self, 1)`` and statement
        splitting is disabled for its contents.  In every case the
        contents are offset to the column after the opening parenthesis.
        """
        ttypes = T.Keyword.DML, T.Keyword.DDL
        _, is_dml_dll = tlist.token_next_by(t=ttypes)
        _, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)
        if first is None:
            return

        with indent(self, 1 if is_dml_dll else 0):
            if is_dml_dll:
                # Built inside the indent block so it uses the new depth.
                tlist.tokens.insert(0, self.nl())
            with offset(self, self._get_offset(first) + 1):
                self._process_default(tlist, not is_dml_dll)

    def _process_function(self, tlist):
        """Remember the function's first token, then process it normally."""
        # Presumably the function name; its length is used by
        # _process_identifierlist to shift wrapped argument lists.
        self._last_func = tlist[0]
        self._process_default(tlist)

    def _process_identifierlist(self, tlist):
        """Wrap the items of an identifier list (e.g. a column list).

        Outside functions and VALUES clauses, a line break is inserted
        once the running width exceeds ``wrap_after``.  Continuation
        lines are aligned with the first identifier, or indented by
        ``width`` with ``indent_columns``.  Inside functions and VALUES,
        a space is ensured after every comma, and the list wraps only
        when ``wrap_after`` is positive.
        """
        identifiers = list(tlist.get_identifiers())
        if not identifiers:
            # Nothing to align against; just recurse into sub-groups.
            self._process_default(tlist)
            return

        if self.indent_columns:
            # The first identifier stays in the list so it can wrap too.
            num_offset = 1 if self.char == '\t' else self.width
        else:
            # The first identifier stays put; the others align with it.
            first = next(identifiers.pop(0).flatten())
            num_offset = 1 if self.char == '\t' else self._get_offset(first)

        if not tlist.within(sql.Function) and not tlist.within(sql.Values):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            # Break before the comma and let it hang two
                            # columns to the left of the column start.
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            _, ws = tlist.token_next(
                                tlist.token_index(token), skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0
        else:
            # Ensure a space after every comma.  A trailing comma has no
            # next token, so guard against None.
            for token in tlist:
                if token.value != ',':
                    continue
                _, next_ws = tlist.token_next(
                    tlist.token_index(token), skip_ws=False)
                if next_ws is not None and not next_ws.is_whitespace:
                    tlist.insert_after(
                        token, sql.Token(T.Whitespace, ' '))

            end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
            adjusted_offset = 0
            if (self.wrap_after > 0
                    and end_at > (self.wrap_after - self.offset)
                    and self._last_func):
                # Shift left by the last function's name length plus one
                # (presumably its '(').
                adjusted_offset = -len(self._last_func.value) - 1

            with offset(self, adjusted_offset), indent(self):
                # After the pop above the list may be empty.
                if adjusted_offset < 0 and identifiers:
                    tlist.insert_before(identifiers[0], self.nl())
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if (self.wrap_after > 0
                            and position > (self.wrap_after - self.offset)):
                        tlist.insert_before(token, self.nl())
                        position = 0
        self._process_default(tlist)

    def _process_case(self, tlist):
        """Put the branches of a CASE expression on separate lines.

        Later branches are aligned with the first one, and END goes on a
        new line at the column of CASE.  With ``compact`` these breaks
        are skipped.  Otherwise a branch breaks when its end would pass
        ``wrap_after``.
        """
        cases = iter(tlist.get_cases())
        first_cond, _ = next(cases)
        first = next(first_cond[0].flatten())

        with offset(self, self._get_offset(tlist[0])):
            with offset(self, self._get_offset(first)):
                for cond, value in cases:
                    str_cond = ''.join(str(x) for x in cond or [])
                    str_value = ''.join(str(x) for x in value)
                    end_pos = self.offset + 1 + len(str_cond) + len(str_value)
                    if not self.compact and end_pos > self.wrap_after:
                        # cond is None for a branch without a condition
                        # (presumably ELSE); break before its value then.
                        token = value[0] if cond is None else cond[0]
                        tlist.insert_before(token, self.nl())

                # Line breaks on group level are done. let's add an offset of
                # len "when ", "then ", "else "
                with offset(self, len("WHEN ")):
                    self._process_default(tlist)
            end_idx, _ = tlist.token_next_by(m=sql.Case.M_CLOSE)
            if end_idx is not None and not self.compact:
                tlist.insert_before(end_idx, self.nl())

    def _process_values(self, tlist):
        """Put each parenthesized row of a VALUES clause on its own line.

        A newline is inserted before the clause.  The comma after each
        row is followed by a break aligned with that row.  With
        ``comma_first``, the break goes before the comma instead, two
        columns left of the first row.
        """
        tlist.insert_before(0, self.nl())
        tidx, token = tlist.token_next_by(i=sql.Parenthesis)
        first_token = token
        while token:
            _, comma = tlist.token_next_by(m=(T.Punctuation, ','), idx=tidx)
            if comma:
                if self.comma_first:
                    # Named to avoid shadowing the imported ``offset``
                    # context manager inside this method.
                    comma_offset = self._get_offset(first_token) - 2
                    tlist.insert_before(comma, self.nl(comma_offset))
                else:
                    tlist.insert_after(comma,
                                       self.nl(self._get_offset(token)))
            tidx, token = tlist.token_next_by(i=sql.Parenthesis, idx=tidx)

    def _process_default(self, tlist, stmts=True):
        """Split statements (if *stmts*) and keywords, then recurse."""
        if stmts:
            self._split_statements(tlist)
        self._split_kwds(tlist)
        for sgroup in tlist.get_sublists():
            self._process(sgroup)

    def process(self, stmt):
        """Reindent *stmt* in place and return it.

        From the second statement on, a separator is prepended after
        processing.  It is one newline if the previous statement already
        ended with ``'\\n'``, otherwise a blank line.
        """
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = '\n' if str(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt