#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause
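"""Reindentation filter for python-sqlparse.

ReindentFilter rewrites a parsed statement's token stream in place,
inserting newline and indentation tokens so that clauses such as FROM,
WHERE, GROUP BY, and ORDER BY start on their own, consistently indented
lines.
"""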

from sqlparse import sql, tokens as T
from sqlparse.utils import offset, indent


class ReindentFilter:
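    """Reindent a parsed SQL statement by inserting whitespace tokens.

    A short summary of the options, as read from the code below:
    ``width`` and ``char`` define the indentation unit, ``n`` is the
    newline string, ``wrap_after`` wraps identifier lists once a line
    grows past that many characters, ``comma_first`` puts commas at the
    start of wrapped lines, ``indent_after_first`` starts the filter one
    indentation level deep, and ``indent_columns`` indents column lists
    by one unit instead of aligning them to the first column.
    """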
    def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
                 comma_first=False, indent_after_first=False,
                 indent_columns=False):
        self.n = n
        self.width = width
        self.char = char
        self.indent = 1 if indent_after_first else 0
        self.offset = 0
        self.wrap_after = wrap_after
        self.comma_first = comma_first
        self.indent_columns = indent_columns
        # Processing state: the statement currently being reindented,
        # the previously processed statement, and the last Function seen.
        self._curr_stmt = None
        self._last_stmt = None
        self._last_func = None
    def _flatten_up_to_token(self, token):
        """Yield all tokens up to, but excluding, the given token."""
        if token.is_group:
            token = next(token.flatten())

        for t in self._curr_stmt.flatten():
            if t == token:
                break
            yield t
    @property
    def leading_ws(self):
        return self.offset + self.indent * self.width
    def _get_offset(self, token):
        raw = ''.join(map(str, self._flatten_up_to_token(token)))
        line = (raw or '\n').splitlines()[-1]
        # Now take the current offset into account and return the
        # relative offset.
        return len(line) - len(self.char * self.leading_ws)
    def nl(self, offset=0):
        # Build a whitespace token holding a newline plus the current
        # leading indentation (optionally adjusted by `offset`).
        return sql.Token(
            T.Whitespace,
            self.n + self.char * max(0, self.leading_ws + offset))
    def _next_token(self, tlist, idx=-1):
        split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                       'GROUP BY', 'ORDER BY', 'UNION', 'VALUES',
                       'SET', 'BETWEEN', 'EXCEPT', 'HAVING', 'LIMIT')
        m_split = T.Keyword, split_words, True
        tidx, token = tlist.token_next_by(m=m_split, idx=idx)

        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)

            # Skip the AND that pairs with BETWEEN so the
            # "BETWEEN x AND y" expression stays on one line.
            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)

        return tidx, token
    def _split_kwds(self, tlist):
        tidx, token = self._next_token(tlist)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            uprev = str(prev_)

            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1

            if not (uprev.endswith('\n') or uprev.endswith('\r')):
                tlist.insert_before(tidx, self.nl())
                tidx += 1

            tidx, token = self._next_token(tlist, tidx)
    def _split_statements(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.is_whitespace:
                del tlist.tokens[pidx]
                tidx -= 1
            # only break if it's not the first token
            if prev_:
                tlist.insert_before(tidx, self.nl())
                tidx += 1
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)
    def _process(self, tlist):
        # Dispatch to a _process_<classname> handler when one exists,
        # falling back to _process_default.
        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)
    def _process_where(self, tlist):
        tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
        if not token:
            return
        # issue121, errors in statement fixed??
        tlist.insert_before(tidx, self.nl())
        with indent(self):
            self._process_default(tlist)
    def _process_parenthesis(self, tlist):
        ttypes = T.Keyword.DML, T.Keyword.DDL
        _, is_dml_dll = tlist.token_next_by(t=ttypes)
        fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)
        if first is None:
            return

        with indent(self, 1 if is_dml_dll else 0):
            if is_dml_dll:
                tlist.tokens.insert(0, self.nl())
            with offset(self, self._get_offset(first) + 1):
                self._process_default(tlist, not is_dml_dll)
    def _process_function(self, tlist):
        # Remember the function name token; _process_identifierlist uses
        # it to adjust offsets when wrapping argument lists.
        self._last_func = tlist[0]
        self._process_default(tlist)
    def _process_identifierlist(self, tlist):
        identifiers = list(tlist.get_identifiers())
        if self.indent_columns:
            first = next(identifiers[0].flatten())
            num_offset = 1 if self.char == '\t' else self.width
        else:
            first = next(identifiers.pop(0).flatten())
            num_offset = 1 if self.char == '\t' else self._get_offset(first)

        if not tlist.within(sql.Function) and not tlist.within(sql.Values):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            _, ws = tlist.token_next(
                                tlist.token_index(token), skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0
        else:
            # ensure whitespace
            for token in tlist:
                _, next_ws = tlist.token_next(
                    tlist.token_index(token), skip_ws=False)
                if token.value == ',' and not next_ws.is_whitespace:
                    tlist.insert_after(
                        token, sql.Token(T.Whitespace, ' '))

            end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
            adjusted_offset = 0
            if (self.wrap_after > 0
                    and end_at > (self.wrap_after - self.offset)
                    and self._last_func):
                adjusted_offset = -len(self._last_func.value) - 1

            with offset(self, adjusted_offset), indent(self):
                if adjusted_offset < 0:
                    tlist.insert_before(identifiers[0], self.nl())

                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if (self.wrap_after > 0
                            and position > (self.wrap_after - self.offset)):
                        adjust = 0
                        tlist.insert_before(token, self.nl(offset=adjust))
                        position = 0

        self._process_default(tlist)
    def _process_case(self, tlist):
        iterable = iter(tlist.get_cases())
        cond, _ = next(iterable)
        first = next(cond[0].flatten())

        with offset(self, self._get_offset(tlist[0])):
            with offset(self, self._get_offset(first)):
                for cond, value in iterable:
                    token = value[0] if cond is None else cond[0]
                    tlist.insert_before(token, self.nl())

                # Line breaks on group level are done. Now add an offset
                # of len("WHEN "), which also covers "THEN " and "ELSE ".
                with offset(self, len("WHEN ")):
                    self._process_default(tlist)
        end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE)
        if end_idx is not None:
            tlist.insert_before(end_idx, self.nl())
    def _process_values(self, tlist):
        tlist.insert_before(0, self.nl())
        tidx, token = tlist.token_next_by(i=sql.Parenthesis)
        first_token = token
        while token:
            ptidx, ptoken = tlist.token_next_by(m=(T.Punctuation, ','),
                                                idx=tidx)
            if ptoken:
                if self.comma_first:
                    adjust = -2
                    # Named nl_offset to avoid shadowing the imported
                    # offset() context manager.
                    nl_offset = self._get_offset(first_token) + adjust
                    tlist.insert_before(ptoken, self.nl(nl_offset))
                else:
                    tlist.insert_after(ptoken,
                                       self.nl(self._get_offset(token)))
            tidx, token = tlist.token_next_by(i=sql.Parenthesis, idx=tidx)
    def _process_default(self, tlist, stmts=True):
        if stmts:
            self._split_statements(tlist)
        self._split_kwds(tlist)
        for sgroup in tlist.get_sublists():
            self._process(sgroup)
    def process(self, stmt):
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            # Separate consecutive statements by a blank line, or a single
            # newline if the previous statement already ends with one.
            nl = '\n' if str(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt
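

# A minimal usage sketch: in practice this filter is enabled through
# sqlparse's public format() API with reindent=True rather than
# instantiated directly. The sample statement is illustrative only.
if __name__ == '__main__':
    import sqlparse

    raw = 'select a, b, c from t where x = 1 and y = 2 order by a'
    # reindent=True wires a ReindentFilter into the output filter stack;
    # keyword_case='upper' is a separate option used here for readability.
    print(sqlparse.format(raw, reindent=True, keyword_case='upper'))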