465 lines
14 KiB
Python
465 lines
14 KiB
Python
#
|
|
# Copyright (C) 2009-2020 the sqlparse authors and contributors
|
|
# <see AUTHORS file>
|
|
#
|
|
# This module is part of python-sqlparse and is released under
|
|
# the BSD License: https://opensource.org/licenses/BSD-3-Clause
|
|
|
|
from sqlparse import sql
|
|
from sqlparse import tokens as T
|
|
from sqlparse.utils import recurse, imt
|
|
|
|
# Token-type families reused by the comparison, operator and
# identifier-list grouping passes in this module.
T_NUMERICAL = (
    T.Number,
    T.Number.Integer,
    T.Number.Float,
)
T_STRING = (
    T.String,
    T.String.Single,
    T.String.Symbol,
)
T_NAME = (
    T.Name,
    T.Name.Placeholder,
)
|
|
|
|
|
|
def _group_matching(tlist, cls):
|
|
"""Groups Tokens that have beginning and end."""
|
|
opens = []
|
|
tidx_offset = 0
|
|
for idx, token in enumerate(list(tlist)):
|
|
tidx = idx - tidx_offset
|
|
|
|
if token.is_whitespace:
|
|
# ~50% of tokens will be whitespace. Will checking early
|
|
# for them avoid 3 comparisons, but then add 1 more comparison
|
|
# for the other ~50% of tokens...
|
|
continue
|
|
|
|
if token.is_group and not isinstance(token, cls):
|
|
# Check inside previously grouped (i.e. parenthesis) if group
|
|
# of different type is inside (i.e., case). though ideally should
|
|
# should check for all open/close tokens at once to avoid recursion
|
|
_group_matching(token, cls)
|
|
continue
|
|
|
|
if token.match(*cls.M_OPEN):
|
|
opens.append(tidx)
|
|
|
|
elif token.match(*cls.M_CLOSE):
|
|
try:
|
|
open_idx = opens.pop()
|
|
except IndexError:
|
|
# this indicates invalid sql and unbalanced tokens.
|
|
# instead of break, continue in case other "valid" groups exist
|
|
continue
|
|
close_idx = tidx
|
|
tlist.group_tokens(cls, open_idx, close_idx)
|
|
tidx_offset += close_idx - open_idx
|
|
|
|
|
|
def group_brackets(tlist):
    """Group open/close pairs of ``sql.SquareBrackets`` within tlist."""
    _group_matching(tlist, cls=sql.SquareBrackets)
|
|
|
|
|
|
def group_parenthesis(tlist):
    """Group open/close pairs of ``sql.Parenthesis`` within tlist."""
    _group_matching(tlist, cls=sql.Parenthesis)
|
|
|
|
|
|
def group_case(tlist):
    """Group open/close pairs of ``sql.Case`` within tlist."""
    _group_matching(tlist, cls=sql.Case)
|
|
|
|
|
|
def group_if(tlist):
    """Group open/close pairs of ``sql.If`` within tlist."""
    _group_matching(tlist, cls=sql.If)
|
|
|
|
|
|
def group_for(tlist):
    """Group open/close pairs of ``sql.For`` within tlist."""
    _group_matching(tlist, cls=sql.For)
|
|
|
|
|
|
def group_begin(tlist):
    """Group open/close pairs of ``sql.Begin`` within tlist."""
    _group_matching(tlist, cls=sql.Begin)
|
|
|
|
|
|
def group_typecasts(tlist):
    """Group ``<prev> :: <next>`` into a single ``sql.Identifier``."""
    def is_cast_operator(token):
        return token.match(T.Punctuation, '::')

    def exists(token):
        # Both neighbours only have to be present.
        return token is not None

    def whole_span(tlist, pidx, tidx, nidx):
        # Group from the previous token through the next one.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_cast_operator,
           exists, exists, whole_span)
|
|
|
|
|
|
def group_tzcasts(tlist):
    """Group a ``T.Keyword.TZCast`` token with its neighbours.

    The previous token, the cast keyword and the next token become one
    ``sql.Identifier``.
    """
    def is_tzcast(token):
        return token.ttype == T.Keyword.TZCast

    def has_prev(token):
        return token is not None

    def acceptable_next(token):
        if token is None:
            return False
        return (
            token.is_whitespace
            or token.match(T.Keyword, 'AS')
            or token.match(*sql.TypedLiteral.M_CLOSE)
        )

    def whole_span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_tzcast,
           has_prev, acceptable_next, whole_span)
|
|
|
|
|
|
def group_typed_literal(tlist):
    """Group typed literals into ``sql.TypedLiteral`` in two passes.

    The first pass groups a ``TypedLiteral.M_OPEN`` token with a following
    ``TypedLiteral.M_CLOSE`` token; the second pass extends an existing
    ``TypedLiteral`` group with a following ``TypedLiteral.M_EXTEND`` token.
    """
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html
    literal = sql.TypedLiteral

    def opens_literal(token):
        return imt(token, m=literal.M_OPEN)

    def is_literal_group(token):
        return isinstance(token, literal)

    def has_prev(token):
        return token is not None

    def closes_literal(token):
        if token is None:
            return False
        return token.match(*literal.M_CLOSE)

    def extends_literal(token):
        if token is None:
            return False
        return token.match(*literal.M_EXTEND)

    def from_match(tlist, pidx, tidx, nidx):
        # The group starts at the matched token, not at its predecessor.
        return tidx, nidx

    # Pass 1: build the basic literal.
    _group(tlist, literal, opens_literal, has_prev, closes_literal,
           from_match, extend=False)
    # Pass 2: grow literals created above.
    _group(tlist, literal, is_literal_group, has_prev, extends_literal,
           from_match, extend=True)
|
|
|
|
|
|
def group_period(tlist):
    """Group names joined by a ``.`` into one ``sql.Identifier``."""
    def is_period(token):
        return token.match(T.Punctuation, '.')

    def acceptable_prev(token):
        return imt(token,
                   i=(sql.SquareBrackets, sql.Identifier),
                   t=(T.Name, T.String.Symbol))

    def any_next(token):
        # issue261, allow invalid next token
        return True

    def span(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        follower = None if nidx is None else tlist[nidx]
        follower_ok = imt(follower,
                          i=(sql.SquareBrackets, sql.Function),
                          t=(T.Name, T.String.Symbol, T.Wildcard))
        if follower_ok:
            return pidx, nidx
        # Unusable follower: stop the group at the period itself.
        return pidx, tidx

    _group(tlist, sql.Identifier, is_period,
           acceptable_prev, any_next, span)
|
|
|
|
|
|
def group_as(tlist):
    """Group ``<prev> AS <next>`` into a single ``sql.Identifier``."""
    def is_as(token):
        return token.is_keyword and token.normalized == 'AS'

    def acceptable_prev(token):
        # NULL is the one keyword allowed in front of AS.
        if token.normalized == 'NULL':
            return True
        return not token.is_keyword

    def acceptable_next(token):
        # DML, DDL and CTE tokens may not follow AS inside the group.
        if imt(token, t=(T.DML, T.DDL, T.CTE)):
            return False
        return token is not None

    def whole_span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_as,
           acceptable_prev, acceptable_next, whole_span)
|
|
|
|
|
|
def group_assignment(tlist):
    """Group ``<prev> := <next>`` into a ``sql.Assignment``.

    The group is stretched to the next ``;`` found after the right-hand
    token, when there is one.
    """
    def is_assign(token):
        return token.match(T.Assignment, ':=')

    def acceptable(token):
        if token is None:
            return False
        # NOTE: the right-hand side is the token type itself, not a
        # one-element tuple, so this relies on T.Keyword's own
        # containment check (presumably also covering Keyword subtypes;
        # confirm in sqlparse.tokens).
        return token.ttype not in T.Keyword

    def span(tlist, pidx, tidx, nidx):
        semicolon_idx, _ = tlist.token_next_by(
            m=(T.Punctuation, ';'), idx=nidx)
        return pidx, (semicolon_idx or nidx)

    _group(tlist, sql.Assignment, is_assign, acceptable, acceptable, span)
|
|
|
|
|
|
def group_comparison(tlist):
    """Group ``<operand> <comparison operator> <operand>``.

    The three tokens are collapsed into a ``sql.Comparison``.
    """
    operand_classes = (sql.Parenthesis, sql.Function, sql.Identifier,
                       sql.Operation, sql.TypedLiteral)
    operand_types = T_NUMERICAL + T_STRING + T_NAME

    def is_comparison(token):
        return token.ttype == T.Operator.Comparison

    def is_operand(token):
        if imt(token, t=operand_types, i=operand_classes):
            return True
        # The NULL keyword is accepted as an operand as well.
        return bool(token and token.is_keyword
                    and token.normalized == 'NULL')

    def whole_span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Comparison, is_comparison,
           is_operand, is_operand, whole_span, extend=False)
|
|
|
|
|
|
@recurse(sql.Identifier)
def group_identifier(tlist):
    """Wrap each ``T.String.Symbol`` / ``T.Name`` token in an Identifier.

    Each matching token becomes its own one-token ``sql.Identifier``.
    """
    name_types = (T.String.Symbol, T.Name)

    pos, found = tlist.token_next_by(t=name_types)
    while found:
        # Start and end index are equal: a one-token group.
        tlist.group_tokens(sql.Identifier, pos, pos)
        pos, found = tlist.token_next_by(t=name_types, idx=pos)
|
|
|
|
|
|
def group_arrays(tlist):
    """Merge a ``sql.SquareBrackets`` group into the token before it.

    The preceding token and the brackets form one ``sql.Identifier``; the
    token after the brackets is never included.
    """
    def is_brackets(token):
        return isinstance(token, sql.SquareBrackets)

    def acceptable_prev(token):
        return imt(token,
                   i=(sql.SquareBrackets, sql.Identifier, sql.Function),
                   t=(T.Name, T.String.Symbol))

    def any_next(token):
        return True

    def up_to_brackets(tlist, pidx, tidx, nidx):
        return pidx, tidx

    _group(tlist, sql.Identifier, is_brackets,
           acceptable_prev, any_next, up_to_brackets,
           extend=True, recurse=False)
|
|
|
|
|
|
def group_operator(tlist):
    """Group ``<operand> <operator> <operand>`` into a ``sql.Operation``.

    ``T.Wildcard`` tokens are accepted as operators too; the matched
    token's ttype is rewritten to ``T.Operator`` when a group is formed.
    """
    operand_types = T_NUMERICAL + T_STRING + T_NAME
    operand_classes = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
                       sql.Identifier, sql.Operation, sql.TypedLiteral)
    date_keywords = ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')

    def is_operator(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def is_operand(token):
        return (imt(token, i=operand_classes, t=operand_types)
                or (token and token.match(T.Keyword, date_keywords)))

    def mark_and_span(tlist, pidx, tidx, nidx):
        # Side effect: normalise the joining token to a plain operator.
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    _group(tlist, sql.Operation, is_operator,
           is_operand, is_operand, mark_and_span, extend=False)
|
|
|
|
|
|
def group_identifier_list(tlist):
    """Group comma-separated items into a ``sql.IdentifierList``."""
    role_keywords = T.Keyword, ('null', 'role')
    item_classes = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
                    sql.IdentifierList, sql.Operation)
    item_types = (T_NUMERICAL + T_STRING + T_NAME
                  + (T.Keyword, T.Comment, T.Wildcard))

    def is_comma(token):
        return token.match(T.Punctuation, ',')

    def is_item(token):
        return imt(token, i=item_classes, m=role_keywords, t=item_types)

    def whole_span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.IdentifierList, is_comma,
           is_item, is_item, whole_span, extend=True)
|
|
|
|
|
|
@recurse(sql.Comment)
def group_comments(tlist):
    """Group runs of comment tokens into ``sql.Comment``.

    A run starts at a ``T.Comment`` token and ends at the token just
    before the first token that is neither a comment nor whitespace.
    If no such terminating token exists, the run is left ungrouped.
    """
    def comment_or_space(tk):
        return imt(tk, t=T.Comment) or tk.is_whitespace

    start, comment = tlist.token_next_by(t=T.Comment)
    while comment:
        stop, terminator = tlist.token_not_matching(
            comment_or_space, idx=start)
        if terminator is not None:
            # Step back to the last token of the run (whitespace counts).
            stop, _ = tlist.token_prev(stop, skip_ws=False)
            tlist.group_tokens(sql.Comment, start, stop)

        start, comment = tlist.token_next_by(t=T.Comment, idx=start)
|
|
|
|
|
|
@recurse(sql.Where)
def group_where(tlist):
    """Group ``sql.Where`` clauses.

    A clause runs from a token matching ``sql.Where.M_OPEN`` up to, but
    not including, the next token matching ``sql.Where.M_CLOSE``. Without
    such a token it runs to the last entry of ``tlist._groupable_tokens``.
    """
    start, opener = tlist.token_next_by(m=sql.Where.M_OPEN)
    while opener:
        close_idx, closer = tlist.token_next_by(
            m=sql.Where.M_CLOSE, idx=start)

        last = (tlist._groupable_tokens[-1] if closer is None
                else tlist.tokens[close_idx - 1])
        # TODO: convert this to eidx instead of end token.
        # i think above values are len(tlist) and eidx-1
        stop = tlist.token_index(last)
        tlist.group_tokens(sql.Where, start, stop)
        start, opener = tlist.token_next_by(m=sql.Where.M_OPEN, idx=start)
|
|
|
|
|
|
@recurse()
def group_aliased(tlist):
    """Merge an expression with the ``sql.Identifier`` that follows it.

    Candidates are the classes listed in ``aliasable`` plus ``T.Number``
    tokens; the pair is grouped as one ``sql.Identifier``.
    """
    aliasable = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
                 sql.Operation, sql.Comparison)

    pos, candidate = tlist.token_next_by(i=aliasable, t=T.Number)
    while candidate:
        follow_idx, follower = tlist.token_next(pos)
        if isinstance(follower, sql.Identifier):
            tlist.group_tokens(sql.Identifier, pos, follow_idx, extend=True)
        pos, candidate = tlist.token_next_by(
            i=aliasable, t=T.Number, idx=pos)
|
|
|
|
|
|
@recurse(sql.Function)
def group_functions(tlist):
    """Group a ``T.Name`` token and a following Parenthesis as a Function.

    Nothing is grouped when the direct tokens of ``tlist`` contain both
    CREATE and TABLE but no AS. All three words are compared
    case-insensitively against each token's ``value``.
    """
    # AS used to be matched case-sensitively, unlike CREATE and TABLE, so
    # a lower-case ``create table ... as select f(x)`` lost its function
    # grouping. Upper-casing every value treats the three keywords alike.
    seen = {tok.value.upper() for tok in tlist.tokens}
    if 'CREATE' in seen and 'TABLE' in seen and 'AS' not in seen:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            tlist.group_tokens(sql.Function, tidx, nidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
|
|
|
|
|
|
def group_order(tlist):
    """Group an Identifier or number with its ordering keyword.

    The ordering keyword is a ``T.Keyword.Order`` token; it is grouped
    with the token before it into one ``sql.Identifier``.
    """
    pos, keyword = tlist.token_next_by(t=T.Keyword.Order)
    while keyword:
        before_idx, before = tlist.token_prev(pos)
        if imt(before, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, before_idx, pos)
            # The new group sits at before_idx; resume the search there.
            pos = before_idx
        pos, keyword = tlist.token_next_by(t=T.Keyword.Order, idx=pos)
|
|
|
|
|
|
@recurse()
def align_comments(tlist):
    """Attach each ``sql.Comment`` to a preceding ``sql.TokenList``."""
    pos, comment = tlist.token_next_by(i=sql.Comment)
    while comment:
        before_idx, before = tlist.token_prev(pos)
        if isinstance(before, sql.TokenList):
            tlist.group_tokens(sql.TokenList, before_idx, pos, extend=True)
            # Continue the search from the merged group's position.
            pos = before_idx
        pos, comment = tlist.token_next_by(i=sql.Comment, idx=pos)
|
|
|
|
|
|
def group_values(tlist):
    """Group a VALUES keyword and its parentheses into ``sql.Values``.

    The group runs from the first VALUES keyword to the last
    ``sql.Parenthesis`` found after it; without one nothing is grouped.
    """
    cursor, current = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    first = cursor
    last_paren = -1  # -1 means "no parenthesis seen yet"
    while current:
        if isinstance(current, sql.Parenthesis):
            last_paren = cursor
        cursor, current = tlist.token_next(cursor)

    if last_paren == -1:
        return
    tlist.group_tokens(sql.Values, first, last_paren, extend=True)
|
|
|
|
|
|
def group(stmt):
    """Apply every grouping pass to ``stmt`` in order and return it."""
    passes = (
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    )
    for apply_pass in passes:
        apply_pass(stmt)
    return stmt
|
|
|
|
|
|
def _group(tlist, cls, match,
|
|
valid_prev=lambda t: True,
|
|
valid_next=lambda t: True,
|
|
post=None,
|
|
extend=True,
|
|
recurse=True
|
|
):
|
|
"""Groups together tokens that are joined by a middle token. i.e. x < y"""
|
|
|
|
tidx_offset = 0
|
|
pidx, prev_ = None, None
|
|
for idx, token in enumerate(list(tlist)):
|
|
tidx = idx - tidx_offset
|
|
if tidx < 0: # tidx shouldn't get negative
|
|
continue
|
|
|
|
if token.is_whitespace:
|
|
continue
|
|
|
|
if recurse and token.is_group and not isinstance(token, cls):
|
|
_group(token, cls, match, valid_prev, valid_next, post, extend)
|
|
|
|
if match(token):
|
|
nidx, next_ = tlist.token_next(tidx)
|
|
if prev_ and valid_prev(prev_) and valid_next(next_):
|
|
from_idx, to_idx = post(tlist, pidx, tidx, nidx)
|
|
grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)
|
|
|
|
tidx_offset += to_idx - from_idx
|
|
pidx, prev_ = from_idx, grp
|
|
continue
|
|
|
|
pidx, prev_ = tidx, token
|