#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

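"""Reindentation filter for the sqlparse filter stack.

``ReindentFilter`` rewrites a parsed statement in place: it inserts
whitespace tokens (a newline plus indentation) before split keywords
such as FROM, WHERE, and GROUP BY, and wraps identifier lists,
parentheses, CASE expressions, and VALUES clauses.
"""
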
from sqlparse import sql, tokens as T
from sqlparse.utils import offset, indent


class ReindentFilter:
    def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
                 comma_first=False, indent_after_first=False,
                 indent_columns=False, compact=False):
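        """Create the filter.

        ``width`` characters of ``char`` make up one indentation step,
        ``wrap_after`` is the column after which identifier lists are
        wrapped, ``n`` is the newline string to insert, and
        ``comma_first`` puts the comma at the start of a wrapped line
        instead of at the end of the previous one.
        """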
        self.n = n
        self.width = width
        self.char = char
        self.indent = 1 if indent_after_first else 0
        self.offset = 0
        self.wrap_after = wrap_after
        self.comma_first = comma_first
        self.indent_columns = indent_columns
        self.compact = compact
        self._curr_stmt = None
        self._last_stmt = None
        self._last_func = None

    def _flatten_up_to_token(self, token):
        """Yield all tokens in the current statement up to, but
        excluding, *token*."""
        if token.is_group:
            token = next(token.flatten())

        for t in self._curr_stmt.flatten():
            if t == token:
                break
            yield t

    @property
    def leading_ws(self):
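        """Number of indentation characters for the current line."""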
        return self.offset + self.indent * self.width

    def _get_offset(self, token):
        raw = ''.join(map(str, self._flatten_up_to_token(token)))
        line = (raw or '\n').splitlines()[-1]
        # Now take current offset into account and return relative offset.
        return len(line) - len(self.char * self.leading_ws)

    def nl(self, offset=0):
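        """Return a whitespace token: newline plus the current
        indentation, shifted right by ``offset`` characters."""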
        return sql.Token(
            T.Whitespace,
            self.n + self.char * max(0, self.leading_ws + offset))

    def _next_token(self, tlist, idx=-1):
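        """Return the next keyword in ``tlist`` before which a line
        break should be inserted."""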
        split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                       'GROUP BY', 'ORDER BY', 'UNION', 'VALUES',
                       'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
        m_split = T.Keyword, split_words, True
        tidx, token = tlist.token_next_by(m=m_split, idx=idx)

        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)

            # The AND that belongs to BETWEEN ... AND must not trigger
            # a line break, so skip it and continue with the keyword
            # after it.
            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)

        return tidx, token

    def _split_kwds(self, tlist):
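        """Insert a line break before every split keyword in ``tlist``
        that is not already at the start of a line."""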
        tidx, token = self._next_token(tlist)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            uprev = str(prev_)

            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1

            if not (uprev.endswith('\n') or uprev.endswith('\r')):
                tlist.insert_before(tidx, self.nl())
                tidx += 1

            tidx, token = self._next_token(tlist, tidx)

    def _split_statements(self, tlist):
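        """Insert a line break before every DML or DDL keyword that
        does not start the token list."""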
        ttypes = T.Keyword.DML, T.Keyword.DDL
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1
            # only break if it's not the first token
            if prev_:
                tlist.insert_before(tidx, self.nl())
                tidx += 1
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)

    def _process(self, tlist):
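        """Dispatch ``tlist`` to a ``_process_<classname>`` handler,
        falling back to ``_process_default``."""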
        func_name = f'_process_{type(tlist).__name__}'
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)

    def _process_where(self, tlist):
        tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
        if not token:
            return
        # issue121: statements with errors may lack the WHERE keyword,
        # hence the guard above.
        tlist.insert_before(tidx, self.nl())
        with indent(self):
            self._process_default(tlist)

    def _process_parenthesis(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        _, is_dml_ddl = tlist.token_next_by(t=ttypes)
        fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)
        if first is None:
            return

        with indent(self, 1 if is_dml_ddl else 0):
            if is_dml_ddl:
                tlist.tokens.insert(0, self.nl())
            with offset(self, self._get_offset(first) + 1):
                self._process_default(tlist, not is_dml_ddl)

    def _process_function(self, tlist):
        self._last_func = tlist[0]
        self._process_default(tlist)

    def _process_identifierlist(self, tlist):
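        """Wrap a comma-separated list of identifiers once the line
        grows past ``wrap_after``."""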
        identifiers = list(tlist.get_identifiers())
        if self.indent_columns:
            first = next(identifiers[0].flatten())
            num_offset = 1 if self.char == '\t' else self.width
        else:
            first = next(identifiers.pop(0).flatten())
            num_offset = 1 if self.char == '\t' else self._get_offset(first)

        if not tlist.within(sql.Function) and not tlist.within(sql.Values):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            _, ws = tlist.token_next(
                                tlist.token_index(token), skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0
        else:
            # ensure whitespace
            for token in tlist:
                _, next_ws = tlist.token_next(
                    tlist.token_index(token), skip_ws=False)
                if token.value == ',' and not next_ws.is_whitespace:
                    tlist.insert_after(
                        token, sql.Token(T.Whitespace, ' '))

            end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
            adjusted_offset = 0
            if (self.wrap_after > 0
                    and end_at > (self.wrap_after - self.offset)
                    and self._last_func):
                adjusted_offset = -len(self._last_func.value) - 1

            with offset(self, adjusted_offset), indent(self):
                if adjusted_offset < 0:
                    tlist.insert_before(identifiers[0], self.nl())
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if (self.wrap_after > 0
                            and position > (self.wrap_after - self.offset)):
                        adjust = 0
                        tlist.insert_before(token, self.nl(offset=adjust))
                        position = 0
        self._process_default(tlist)

    def _process_case(self, tlist):
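        """Put each WHEN/ELSE branch of a CASE expression on its own
        line and align it with the first condition."""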
        iterable = iter(tlist.get_cases())
        cond, _ = next(iterable)
        first = next(cond[0].flatten())

        with offset(self, self._get_offset(tlist[0])):
            with offset(self, self._get_offset(first)):
                for cond, value in iterable:
                    str_cond = ''.join(str(x) for x in cond or [])
                    str_value = ''.join(str(x) for x in value)
                    end_pos = self.offset + 1 + len(str_cond) + len(str_value)
                    if (not self.compact and end_pos > self.wrap_after):
                        token = value[0] if cond is None else cond[0]
                        tlist.insert_before(token, self.nl())

                # Line breaks on group level are done. Now add an offset
                # of len("WHEN "), which also covers "THEN " and "ELSE ".
                with offset(self, len("WHEN ")):
                    self._process_default(tlist)
            end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE)
            if end_idx is not None and not self.compact:
                tlist.insert_before(end_idx, self.nl())

    def _process_values(self, tlist):
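        """Put each parenthesized group of a VALUES clause on its own
        line."""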
        tlist.insert_before(0, self.nl())
        tidx, token = tlist.token_next_by(i=sql.Parenthesis)
        first_token = token
        while token:
            ptidx, ptoken = tlist.token_next_by(m=(T.Punctuation, ','),
                                                idx=tidx)
            if ptoken:
                if self.comma_first:
                    adjust = -2
                    nl_offset = self._get_offset(first_token) + adjust
                    tlist.insert_before(ptoken, self.nl(nl_offset))
                else:
                    tlist.insert_after(ptoken,
                                       self.nl(self._get_offset(token)))
            tidx, token = tlist.token_next_by(i=sql.Parenthesis, idx=tidx)

    def _process_default(self, tlist, stmts=True):
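        """Split statements and keywords in ``tlist``, then recurse
        into its sub-groups."""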
        if stmts:
            self._split_statements(tlist)
        self._split_kwds(tlist)
        for sgroup in tlist.get_sublists():
            self._process(sgroup)

    def process(self, stmt):
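        """Filter-stack entry point: reindent ``stmt`` in place and
        separate it from the previously processed statement."""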
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = '\n' if str(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt
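

# A minimal usage sketch, illustrative only: in normal use this filter
# is wired into a FilterStack by sqlparse.formatter when reindent=True
# is passed to the public API.
if __name__ == '__main__':
    import sqlparse

    print(sqlparse.format(
        'select a, b from tbl where a = 1 and b = 2',
        reindent=True, indent_width=2))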