mysteriendrama/lib/python3.11/site-packages/psycopg/_queries.py
2023-07-26 21:33:29 +02:00

376 lines
11 KiB
Python

"""
Utility module to manipulate queries
"""
# Copyright (C) 2020 The Psycopg Team
import re
from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
from typing import Sequence, Tuple, Union, TYPE_CHECKING
from functools import lru_cache
from . import pq
from . import errors as e
from .sql import Composable
from .abc import Buffer, Query, Params
from ._enums import PyFormat
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .abc import Transformer
# One chunk of a query split by `_split_query()`:
# - pre: the literal query bytes preceding the placeholder;
# - item: the positional index (int) or the name (str) of the placeholder;
# - format: the format requested by the placeholder (%s, %t or %b).
# The last chunk of a split query holds the text after the last placeholder,
# with a dummy item 0 and format AUTO.
QueryPart = NamedTuple(
    "QueryPart",
    [
        ("pre", bytes),
        ("item", Union[int, str]),
        ("format", PyFormat),
    ],
)
class PostgresQuery:
    """
    Helper to convert a Python query and parameters into Postgres format.

    Call `convert()` with a query and its parameters, then read the result
    from the `query`, `params`, `types` and `formats` attributes. `dump()`
    can be called again with a new set of parameters for the same query.
    """

    __slots__ = """
        query params types formats
        _tx _want_formats _parts _encoding _order
        """.split()

    def __init__(self, transformer: "Transformer"):
        self._tx = transformer
        self._encoding = conn_encoding(transformer.connection)

        # Results, exposed to the caller.
        self.query = b""
        self.params: Optional[Sequence[Optional[Buffer]]] = None
        # Tuples, so they can be used as keys, e.g. in prepared statements.
        self.types: Tuple[int, ...] = ()
        # The formats really passed to Postgres...
        self.formats: Optional[Sequence[pq.Format]] = None

        # ...and the ones requested by the user in the placeholders.
        self._want_formats: Optional[List[PyFormat]] = None
        # Placeholder names in the order their values are needed (named
        # placeholders only).
        self._order: Optional[List[str]] = None
        # The split query; only set by convert() when parameters are given.
        self._parts: List[QueryPart]

    def convert(self, query: Query, vars: Optional[Params]) -> None:
        """
        Set up the query and parameters to convert.

        The results of this function can be obtained accessing the object
        attributes (`query`, `params`, `types`, `formats`).
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is None:
            # No parameters: the query is used verbatim, without looking
            # for placeholders.
            self.query = bquery
            self._want_formats = None
            self._order = None
        else:
            converted = _query2pg(bquery, self._encoding)
            self.query, self._want_formats, self._order, self._parts = converted

        self.dump(vars)

    def dump(self, vars: Optional[Params]) -> None:
        """
        Process a new set of variables on the query processed by `convert()`.

        This method updates `params`, `types` and `formats`.
        """
        if vars is None:
            self.params = None
            self.types = ()
            self.formats = None
            return

        params = _validate_and_reorder_params(self._parts, vars, self._order)
        assert self._want_formats is not None
        self.params = self._tx.dump_sequence(params, self._want_formats)
        self.types = self._tx.types or ()
        self.formats = self._tx.formats
class PostgresClientQuery(PostgresQuery):
    """
    PostgresQuery subclass merging query and arguments client-side.

    Each parameter is converted with the transformer's `as_literal()` (None
    becomes ``NULL``) and the resulting bytes are %-interpolated into a
    template derived from the query. `types` and `formats` are left at the
    values set by `PostgresQuery.__init__()`.
    """

    __slots__ = ("template",)

    def convert(self, query: Query, vars: Optional[Params]) -> None:
        """
        Set up the query and parameters to convert.

        The results of this function can be obtained accessing the object
        attributes (`query`, `params`, `types`, `formats`).
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is None:
            # No parameters: the query is used verbatim.
            self.query = bquery
            self._order = None
        else:
            self.template, self._order, self._parts = _query2pg_client(
                bquery, self._encoding
            )

        self.dump(vars)

    def dump(self, vars: Optional[Params]) -> None:
        """
        Process a new set of variables on the query processed by `convert()`.

        This method updates `params` and, when variables are given, `query`.
        """
        if vars is None:
            self.params = None
            return

        params = _validate_and_reorder_params(self._parts, vars, self._order)
        literals = []
        for param in params:
            literals.append(b"NULL" if param is None else self._tx.as_literal(param))
        self.params = tuple(literals)
        # Every placeholder in the template is a plain '%s'.
        self.query = self.template % self.params
@lru_cache()
def _query2pg(
    query: bytes, encoding: str
) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
    """
    Convert Python query and params into something Postgres understands.

    - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
      format (``$1``, ``$2``)
    - placeholders can be %s, %t, or %b (auto, text or binary)
    - return ``query`` (bytes), ``formats`` (list of formats) ``order``
      (sequence of names used in the query, in the position they appear)
      ``parts`` (splits of queries and placeholders).

    A name repeated in the query reuses the same ``$n``. It must use the
    same format every time, otherwise ProgrammingError is raised.
    """
    parts = _split_query(query, encoding)
    chunks: List[bytes] = []
    formats: List[PyFormat] = []
    order: Optional[List[str]] = None

    first = parts[0].item
    if isinstance(first, int):
        # Positional: the n-th placeholder becomes $(n+1).
        for part in parts[:-1]:
            assert isinstance(part.item, int)
            chunks += (part.pre, b"$%d" % (part.item + 1))
            formats.append(part.format)

    elif isinstance(first, str):
        order = []
        # name -> (Postgres placeholder, format) of every name seen so far
        seen: Dict[str, Tuple[bytes, PyFormat]] = {}
        for part in parts[:-1]:
            assert isinstance(part.item, str)
            chunks.append(part.pre)
            known = seen.get(part.item)
            if known is None:
                ph = b"$%d" % (len(seen) + 1)
                seen[part.item] = (ph, part.format)
                order.append(part.item)
                formats.append(part.format)
            else:
                ph, fmt = known
                if fmt != part.format:
                    raise e.ProgrammingError(
                        f"placeholder '{part.item}' cannot have different formats"
                    )
            chunks.append(ph)

    # Text after the last placeholder.
    chunks.append(parts[-1].pre)
    return b"".join(chunks), formats, order, parts
@lru_cache()
def _query2pg_client(
    query: bytes, encoding: str
) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
    """
    Convert Python query and params into a template to perform client-side binding

    Every placeholder (of any format) becomes a plain ``%s`` in the returned
    template. For named placeholders, ``order`` lists the name of every
    occurrence, repeated names included, so that the values can be
    interpolated positionally. For positional placeholders ``order`` is None.
    """
    # '%%' is kept escaped: the template will go through '%' formatting,
    # which turns it back into a single '%'.
    parts = _split_query(query, encoding, collapse_double_percent=False)
    *placeholders, tail = parts

    template = b"".join([*(p.pre + b"%s" for p in placeholders), tail.pre])

    order: Optional[List[str]] = None
    if isinstance(parts[0].item, str):
        order = []
        for p in placeholders:
            assert isinstance(p.item, str)
            order.append(p.item)
    else:
        assert all(isinstance(p.item, int) for p in placeholders)

    return template, order, parts
def _validate_and_reorder_params(
    parts: List[QueryPart], vars: Params, order: Optional[List[str]]
) -> Sequence[Any]:
    """
    Verify the compatibility between a query and a set of params.

    :param parts: the query as split by `_split_query()`: one part per
        placeholder plus a final part holding the trailing text.
    :param vars: the parameters passed by the user: a sequence for
        positional placeholders, a mapping for named ones.
    :param order: the placeholder names, in the order in which their values
        must be returned; None for queries with positional placeholders.
    :return: the parameter values in placeholders order (`vars` itself if it
        is a sequence).
    :raise TypeError: if `vars` is neither a sequence nor a mapping, or if its
        kind doesn't match the style of the placeholders.
    :raise ProgrammingError: if the number of parameters doesn't match the
        number of placeholders, or if named parameters are missing.
    """
    nplaceholders = len(parts) - 1

    # Try concrete types, then abstract types
    t = type(vars)
    if t is list or t is tuple:
        sequence = True
    elif t is dict:
        sequence = False
    elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
        sequence = True
    elif isinstance(vars, Mapping):
        sequence = False
    else:
        raise TypeError(
            "query parameters should be a sequence or a mapping,"
            f" got {type(vars).__name__}"
        )

    if sequence:
        if len(vars) != nplaceholders:
            raise e.ProgrammingError(
                f"the query has {nplaceholders} placeholders but"
                f" {len(vars)} parameters were passed"
            )
        if vars and not isinstance(parts[0].item, int):
            raise TypeError("named placeholders require a mapping of parameters")
        return vars  # type: ignore[return-value]

    # vars is a mapping. Check it even when it is empty: an empty mapping
    # passed to a query with positional placeholders would otherwise produce
    # an empty list of values and fail later with a less clear error.
    if nplaceholders and not isinstance(parts[0].item, str):
        raise TypeError(
            "positional placeholders (%s) require a sequence of parameters"
        )
    try:
        return [vars[item] for item in order or ()]  # type: ignore[call-overload]
    except KeyError:
        # `order` may repeat a name (client-side binding lists every
        # occurrence): report each missing name only once. The KeyError
        # context adds nothing to the message, so drop it.
        missing = sorted({i for i in order or () if i not in vars})
        raise e.ProgrammingError(
            f"query parameter missing: {', '.join(missing)}"
        ) from None
# Match every '%' sequence in a query:
# - group(0) is the whole match: '%%', '%s', '%(name)s', but also invalid
#   sequences such as '% ', '%x' or '%(' (an unterminated name), which
#   `_split_query()` validates and reports;
# - group(1) is the name of a '%(name)X' placeholder, None for any other
#   match.
# The verbose pattern comment "(braces)" refers to parentheses.
# Only the (?x) flag is set, so '.' doesn't match a newline: a '%' followed
# by a newline, or at the very end of the query, is not matched at all and
# stays in the surrounding text.
_re_placeholder = re.compile(
rb"""(?x)
% # a literal %
(?:
(?:
\( ([^)]+) \) # or a name in (braces)
. # followed by a format
)
|
(?:.) # or any char, really
)
"""
)
def _split_query(
    query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
) -> List[QueryPart]:
    """
    Split `query` at its placeholders, validating them.

    Return one `QueryPart` per placeholder, holding the text preceding it,
    its item (positional index, or name for ``%(name)s``) and its format.
    The list ends with a part holding the text after the last placeholder
    (item 0, format AUTO).

    ``%%`` is not a placeholder: it is merged into the surrounding text,
    unescaped to ``%`` if `collapse_double_percent` is true, kept as ``%%``
    otherwise.

    :raise ProgrammingError: on an incomplete or unsupported placeholder, or
        if positional and named placeholders are mixed.
    """
    rv: List[QueryPart] = []
    phtype: Optional[type] = None
    # Text accumulated since the previous real placeholder, including the
    # '%%' escapes merged into it.
    pending = b""
    cur = 0

    for m in _re_placeholder.finditer(query):
        start, end = m.span(0)
        pre = pending + query[cur:start]
        cur = end
        ph = m.group(0)

        if ph == b"%%":
            # Not a placeholder: fold it into the text of the next part,
            # unescaped to '%' if necessary.
            pending = pre + (b"%" if collapse_double_percent else ph)
            continue
        pending = b""

        if ph == b"%(":
            raise e.ProgrammingError(
                "incomplete placeholder:"
                f" '{query[start:].split()[0].decode(encoding)}'"
            )
        if ph == b"% ":
            # explicit message for a typical error
            raise e.ProgrammingError(
                "incomplete placeholder: '%'; if you want to use '%' as an"
                " operator you can double it up, i.e. use '%%'"
            )
        if ph[-1:] not in b"sbt":
            raise e.ProgrammingError(
                "only '%s', '%b', '%t' are allowed as placeholders, got"
                f" '{ph.decode(encoding)}'"
            )

        # The name of the placeholder, or its position among the
        # placeholders found so far.
        name = m.group(1)
        item: Union[int, str] = name.decode(encoding) if name else len(rv)
        if phtype is None:
            phtype = type(item)
        elif type(item) is not phtype:
            raise e.ProgrammingError(
                "positional and named placeholders cannot be mixed"
            )

        rv.append(QueryPart(pre, item, _ph_to_fmt[ph[-1:]]))

    # Text after the last placeholder.
    rv.append(QueryPart(pending + query[cur:], 0, PyFormat.AUTO))
    return rv
# Map the character closing a placeholder ('%s', '%(name)t', ...) to the
# format requested for its parameter: 's' auto, 't' text, 'b' binary.
_ph_to_fmt = dict(
    [
        (b"s", PyFormat.AUTO),
        (b"t", PyFormat.TEXT),
        (b"b", PyFormat.BINARY),
    ]
)