# mysteriendrama/lib/python3.11/site-packages/psycopg/_transform.py
# 2023-07-26 21:33:29 +02:00
#
# 351 lines
# 11 KiB
# Python

"""
Helper object to transform values between Python and PostgreSQL
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import DefaultDict, TYPE_CHECKING
from collections import defaultdict
from typing_extensions import TypeAlias
from . import pq
from . import postgres
from . import errors as e
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
from .rows import Row, RowMaker
from .postgres import INVALID_OID, TEXT_OID
from ._encodings import pgconn_encoding
if TYPE_CHECKING:
    # Imported only for static type checking, to avoid import cycles.
    from .abc import Dumper, Loader
    from .adapt import AdaptersMap
    from .pq.abc import PGresult
    from .connection import BaseConnection

# Short names for the text formats, bound once at import time.
TEXT = pq.Format.TEXT
PY_TEXT = PyFormat.TEXT

# Cache shapes used by Transformer:
# dumper key (as used by get_dumper) -> Dumper instance
DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
# oid -> Dumper instance
OidDumperCache: TypeAlias = Dict[int, "Dumper"]
# oid -> Loader instance
LoaderCache: TypeAlias = Dict[int, "Loader"]
class Transformer(AdaptContext):
    """
    An object that can adapt efficiently between Python and PostgreSQL.

    The life cycle of the object is the query, so it is assumed that attributes
    such as the server version or the connection encoding will not change. The
    object keeps state (caches of dumpers, loaders and type names) so that
    adapting several values of the same type can be optimised.
    """

    __module__ = "psycopg.adapt"

    # _nfields and _ntuples are assigned by set_pgresult(), so they are listed
    # here as well: otherwise they only work if some base class happens to
    # provide an instance __dict__, which defeats the purpose of __slots__.
    __slots__ = """
        types formats
        _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
        _oid_dumpers _oid_types _row_dumpers _row_loaders
        _nfields _ntuples
        """.split()

    # Oids of the last dumped parameters (set by dump_sequence() or
    # set_dumper_types()).
    types: Optional[Tuple[int, ...]]
    # Wire formats of the last dumped parameters, parallel to `types`.
    formats: Optional[List[pq.Format]]

    _adapters: "AdaptersMap"
    _pgresult: Optional["PGresult"]
    # Oid of the None dumper: computed lazily by _get_none_oid().
    _none_oid: int
    # Number of columns and rows of the current result (see set_pgresult()).
    _nfields: int
    _ntuples: int

    def __init__(self, context: Optional[AdaptContext] = None):
        self._pgresult = self.types = self.formats = None
        self._nfields = self._ntuples = 0

        # WARNING: don't store context, or you'll create a loop with the Cursor
        if context:
            self._adapters = context.adapters
            self._conn = context.connection
        else:
            self._adapters = postgres.adapters
            self._conn = None

        # mapping fmt, class -> Dumper instance
        self._dumpers: DefaultDict[PyFormat, DumperCache]
        self._dumpers = defaultdict(dict)

        # mapping fmt, oid -> Dumper instance
        # Not often used, so create it only if needed.
        self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
        self._oid_dumpers = None

        # mapping fmt, oid -> Loader instance (the tuple is indexed by format)
        self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})

        # per-position dumpers, set by set_dumper_types()
        self._row_dumpers: Optional[List["Dumper"]] = None

        # sequence of load functions from value to python
        # the length of the result columns
        self._row_loaders: List[LoadFunc] = []

        # mapping oid -> type sql representation
        self._oid_types: Dict[int, bytes] = {}

        # empty until first requested through the `encoding` property
        self._encoding = ""

    @classmethod
    def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
        """
        Return a Transformer from an AdaptContext.

        If the context is a Transformer instance, just return it.
        """
        if isinstance(context, Transformer):
            return context
        else:
            return cls(context)

    @property
    def connection(self) -> Optional["BaseConnection[Any]"]:
        return self._conn

    @property
    def encoding(self) -> str:
        """
        Python encoding name of the connection, or "utf-8" without one.

        Computed on first access and then cached.
        """
        if not self._encoding:
            conn = self.connection
            self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
        return self._encoding

    @property
    def adapters(self) -> "AdaptersMap":
        return self._adapters

    @property
    def pgresult(self) -> Optional["PGresult"]:
        return self._pgresult

    def set_pgresult(
        self,
        result: Optional["PGresult"],
        *,
        set_loaders: bool = True,
        format: Optional[pq.Format] = None,
    ) -> None:
        """
        Set the result to load rows from.

        Record the result's number of rows and columns. If `set_loaders` is
        true, also rebuild the per-column load functions. Every column is
        loaded with the same format: `format` if given, otherwise the format
        of the first column of the result.
        """
        self._pgresult = result

        if not result:
            self._nfields = self._ntuples = 0
            if set_loaders:
                self._row_loaders = []
            return

        self._ntuples = result.ntuples
        nf = self._nfields = result.nfields

        if not set_loaders:
            return

        if not nf:
            self._row_loaders = []
            return

        fmt: pq.Format
        fmt = result.fformat(0) if format is None else format  # type: ignore
        self._row_loaders = [
            self.get_loader(result.ftype(i), fmt).load for i in range(nf)
        ]

    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
        """
        Fix the dumpers to use for each parameter position, by oid.

        Also set `types` and `formats` accordingly. Subsequent calls to
        dump_sequence() use these dumpers instead of choosing them by value.
        """
        self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
        self.types = tuple(types)
        self.formats = [format] * len(types)

    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
        """
        Set the per-column load functions from a sequence of oids.
        """
        self._row_loaders = [self.get_loader(oid, format).load for oid in types]

    def dump_sequence(
        self, params: Sequence[Any], formats: Sequence[PyFormat]
    ) -> Sequence[Optional[Buffer]]:
        """
        Dump a sequence of parameters, returning one buffer per parameter.

        None parameters are returned as None. Unless set_dumper_types() has
        configured per-position dumpers, `types` and `formats` are updated
        with the oid and wire format chosen for each parameter. None
        parameters get the None dumper oid and the text format.
        """
        nparams = len(params)
        out: List[Optional[Buffer]] = [None] * nparams

        # If we have dumpers, it means set_dumper_types had been called, in
        # which case self.types and self.formats are set to sequences of the
        # right size.
        if self._row_dumpers:
            for i in range(nparams):
                param = params[i]
                if param is not None:
                    out[i] = self._row_dumpers[i].dump(param)
            return out

        types = [self._get_none_oid()] * nparams
        pqformats = [TEXT] * nparams

        for i in range(nparams):
            param = params[i]
            if param is None:
                continue
            dumper = self.get_dumper(param, formats[i])
            out[i] = dumper.dump(param)
            types[i] = dumper.oid
            pqformats[i] = dumper.format

        self.types = tuple(types)
        self.formats = pqformats

        return out

    def as_literal(self, obj: Any) -> bytes:
        """
        Return `obj` quoted as a SQL literal.

        If the quoted value ends with a quote and the dumper's oid is neither
        0 nor text, add an explicit `::type` cast. The type name comes from
        the adapters' type registry and is memoized per oid.
        """
        dumper = self.get_dumper(obj, PY_TEXT)
        rv = dumper.quote(obj)
        # If the result is quoted, and the oid not unknown or text,
        # add an explicit type cast.
        # Check the last char because the first one might be 'E'.
        oid = dumper.oid
        if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
            try:
                type_sql = self._oid_types[oid]
            except KeyError:
                ti = self.adapters.types.get(oid)
                if ti:
                    if oid < 8192:
                        # builtin: prefer "timestamptz" to "timestamp with time zone"
                        type_sql = ti.name.encode(self.encoding)
                    else:
                        type_sql = ti.regtype.encode(self.encoding)
                    if oid == ti.array_oid:
                        type_sql += b"[]"
                else:
                    # unknown oid: remember that no cast can be added
                    type_sql = b""
                self._oid_types[oid] = type_sql

            if type_sql:
                rv = b"%s::%s" % (rv, type_sql)

        if not isinstance(rv, bytes):
            rv = bytes(rv)
        return rv

    def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
        """
        Return a Dumper instance to dump `!obj`.
        """
        # Normally, the type of the object dictates how to dump it
        key = type(obj)

        # Reuse an existing Dumper class for objects of the same type
        cache = self._dumpers[format]
        try:
            dumper = cache[key]
        except KeyError:
            # If it's the first time we see this type, look for a dumper
            # configured for it.
            dcls = self.adapters.get_dumper(key, format)
            cache[key] = dumper = dcls(key, self)

        # Check if the dumper requires an upgrade to handle this specific value
        key1 = dumper.get_key(obj, format)
        if key1 is key:
            return dumper

        # If it does, ask the dumper to create its own upgraded version
        try:
            return cache[key1]
        except KeyError:
            dumper = cache[key1] = dumper.upgrade(obj, format)
            return dumper

    def _get_none_oid(self) -> int:
        """
        Return the oid of the dumper configured for None, caching it.

        Raise InterfaceError if looking up the None dumper raises KeyError.
        """
        try:
            return self._none_oid
        except AttributeError:
            pass

        try:
            rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
        except KeyError:
            # NOTE(review): confirm AdaptersMap.get_dumper() signals a missing
            # dumper with KeyError; other exception types propagate unchanged.
            raise e.InterfaceError("None dumper not found") from None

        return rv

    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
        """
        Return a Dumper to dump an object to the type with given oid.
        """
        if not self._oid_dumpers:
            self._oid_dumpers = ({}, {})

        # Reuse an existing Dumper instance created for the same oid
        cache = self._oid_dumpers[format]
        try:
            return cache[oid]
        except KeyError:
            # If it's the first time we see this oid, look for a dumper
            # configured for it.
            dcls = self.adapters.get_dumper_by_oid(oid, format)
            cache[oid] = dumper = dcls(NoneType, self)

        return dumper

    def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
        """
        Load rows `row0` (included) to `row1` (excluded) of the current result.

        Each row is built by `make_row` from a list of loaded values (None for
        NULLs). Raise InterfaceError if no result is set or if either bound
        is outside 0..ntuples.
        """
        res = self._pgresult
        if not res:
            raise e.InterfaceError("result not set")

        if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
            raise e.InterfaceError(
                f"rows must be included between 0 and {self._ntuples}"
            )

        records = []
        for row in range(row0, row1):
            record: List[Any] = [None] * self._nfields
            for col in range(self._nfields):
                val = res.get_value(row, col)
                if val is not None:
                    record[col] = self._row_loaders[col](val)
            records.append(make_row(record))

        return records

    def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
        """
        Load a single row of the current result, built by `make_row`.

        Return None if no result is set or if `row` is out of range.
        """
        res = self._pgresult
        if not res:
            return None

        if not 0 <= row < self._ntuples:
            return None

        record: List[Any] = [None] * self._nfields
        for col in range(self._nfields):
            val = res.get_value(row, col)
            if val is not None:
                record[col] = self._row_loaders[col](val)

        return make_row(record)

    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
        """
        Load a sequence of raw values using the current per-column loaders.

        None values stay None. Raise ProgrammingError if the number of values
        differs from the number of loaders.
        """
        if len(self._row_loaders) != len(record):
            raise e.ProgrammingError(
                f"cannot load sequence of {len(record)} items:"
                f" {len(self._row_loaders)} loaders registered"
            )

        return tuple(
            (self._row_loaders[i](val) if val is not None else None)
            for i, val in enumerate(record)
        )

    def get_loader(self, oid: int, format: pq.Format) -> "Loader":
        """
        Return a (cached) Loader instance for the given oid and format.

        Fall back to the loader registered for INVALID_OID if no loader is
        configured for `oid`. Raise InterfaceError if neither exists.
        """
        try:
            return self._loaders[format][oid]
        except KeyError:
            pass

        loader_cls = self._adapters.get_loader(oid, format)
        if not loader_cls:
            loader_cls = self._adapters.get_loader(INVALID_OID, format)
            if not loader_cls:
                raise e.InterfaceError("unknown oid loader not found")
        loader = self._loaders[format][oid] = loader_cls(oid, self)
        return loader