""" Helper object to transform values between Python and PostgreSQL """ # Copyright (C) 2020 The Psycopg Team from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import DefaultDict, TYPE_CHECKING from collections import defaultdict from typing_extensions import TypeAlias from . import pq from . import postgres from . import errors as e from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType from .rows import Row, RowMaker from .postgres import INVALID_OID, TEXT_OID from ._encodings import pgconn_encoding if TYPE_CHECKING: from .abc import Dumper, Loader from .adapt import AdaptersMap from .pq.abc import PGresult from .connection import BaseConnection DumperCache: TypeAlias = Dict[DumperKey, "Dumper"] OidDumperCache: TypeAlias = Dict[int, "Dumper"] LoaderCache: TypeAlias = Dict[int, "Loader"] TEXT = pq.Format.TEXT PY_TEXT = PyFormat.TEXT class Transformer(AdaptContext): """ An object that can adapt efficiently between Python and PostgreSQL. The life cycle of the object is the query, so it is assumed that attributes such as the server version or the connection encoding will not change. The object have its state so adapting several values of the same type can be optimised. """ __module__ = "psycopg.adapt" __slots__ = """ types formats _conn _adapters _pgresult _dumpers _loaders _encoding _none_oid _oid_dumpers _oid_types _row_dumpers _row_loaders """.split() types: Optional[Tuple[int, ...]] formats: Optional[List[pq.Format]] _adapters: "AdaptersMap" _pgresult: Optional["PGresult"] _none_oid: int def __init__(self, context: Optional[AdaptContext] = None): self._pgresult = self.types = self.formats = None # WARNING: don't store context, or you'll create a loop with the Cursor if context: self._adapters = context.adapters self._conn = context.connection else: self._adapters = postgres.adapters self._conn = None # mapping fmt, class -> Dumper instance self._dumpers: DefaultDict[PyFormat, DumperCache] self._dumpers = defaultdict(dict) # mapping fmt, oid -> Dumper instance # Not often used, so create it only if needed. self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]] self._oid_dumpers = None # mapping fmt, oid -> Loader instance self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {}) self._row_dumpers: Optional[List["Dumper"]] = None # sequence of load functions from value to python # the length of the result columns self._row_loaders: List[LoadFunc] = [] # mapping oid -> type sql representation self._oid_types: Dict[int, bytes] = {} self._encoding = "" @classmethod def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": """ Return a Transformer from an AdaptContext. If the context is a Transformer instance, just return it. """ if isinstance(context, Transformer): return context else: return cls(context) @property def connection(self) -> Optional["BaseConnection[Any]"]: return self._conn @property def encoding(self) -> str: if not self._encoding: conn = self.connection self._encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8" return self._encoding @property def adapters(self) -> "AdaptersMap": return self._adapters @property def pgresult(self) -> Optional["PGresult"]: return self._pgresult def set_pgresult( self, result: Optional["PGresult"], *, set_loaders: bool = True, format: Optional[pq.Format] = None, ) -> None: self._pgresult = result if not result: self._nfields = self._ntuples = 0 if set_loaders: self._row_loaders = [] return self._ntuples = result.ntuples nf = self._nfields = result.nfields if not set_loaders: return if not nf: self._row_loaders = [] return fmt: pq.Format fmt = result.fformat(0) if format is None else format # type: ignore self._row_loaders = [ self.get_loader(result.ftype(i), fmt).load for i in range(nf) ] def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types] self.types = tuple(types) self.formats = [format] * len(types) def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: self._row_loaders = [self.get_loader(oid, format).load for oid in types] def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] ) -> Sequence[Optional[Buffer]]: nparams = len(params) out: List[Optional[Buffer]] = [None] * nparams # If we have dumpers, it means set_dumper_types had been called, in # which case self.types and self.formats are set to sequences of the # right size. if self._row_dumpers: for i in range(nparams): param = params[i] if param is not None: out[i] = self._row_dumpers[i].dump(param) return out types = [self._get_none_oid()] * nparams pqformats = [TEXT] * nparams for i in range(nparams): param = params[i] if param is None: continue dumper = self.get_dumper(param, formats[i]) out[i] = dumper.dump(param) types[i] = dumper.oid pqformats[i] = dumper.format self.types = tuple(types) self.formats = pqformats return out def as_literal(self, obj: Any) -> bytes: dumper = self.get_dumper(obj, PY_TEXT) rv = dumper.quote(obj) # If the result is quoted, and the oid not unknown or text, # add an explicit type cast. # Check the last char because the first one might be 'E'. oid = dumper.oid if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID: try: type_sql = self._oid_types[oid] except KeyError: ti = self.adapters.types.get(oid) if ti: if oid < 8192: # builtin: prefer "timestamptz" to "timestamp with time zone" type_sql = ti.name.encode(self.encoding) else: type_sql = ti.regtype.encode(self.encoding) if oid == ti.array_oid: type_sql += b"[]" else: type_sql = b"" self._oid_types[oid] = type_sql if type_sql: rv = b"%s::%s" % (rv, type_sql) if not isinstance(rv, bytes): rv = bytes(rv) return rv def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper": """ Return a Dumper instance to dump `!obj`. """ # Normally, the type of the object dictates how to dump it key = type(obj) # Reuse an existing Dumper class for objects of the same type cache = self._dumpers[format] try: dumper = cache[key] except KeyError: # If it's the first time we see this type, look for a dumper # configured for it. dcls = self.adapters.get_dumper(key, format) cache[key] = dumper = dcls(key, self) # Check if the dumper requires an upgrade to handle this specific value key1 = dumper.get_key(obj, format) if key1 is key: return dumper # If it does, ask the dumper to create its own upgraded version try: return cache[key1] except KeyError: dumper = cache[key1] = dumper.upgrade(obj, format) return dumper def _get_none_oid(self) -> int: try: return self._none_oid except AttributeError: pass try: rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid except KeyError: raise e.InterfaceError("None dumper not found") return rv def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper": """ Return a Dumper to dump an object to the type with given oid. """ if not self._oid_dumpers: self._oid_dumpers = ({}, {}) # Reuse an existing Dumper class for objects of the same type cache = self._oid_dumpers[format] try: return cache[oid] except KeyError: # If it's the first time we see this type, look for a dumper # configured for it. dcls = self.adapters.get_dumper_by_oid(oid, format) cache[oid] = dumper = dcls(NoneType, self) return dumper def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: res = self._pgresult if not res: raise e.InterfaceError("result not set") if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples): raise e.InterfaceError( f"rows must be included between 0 and {self._ntuples}" ) records = [] for row in range(row0, row1): record: List[Any] = [None] * self._nfields for col in range(self._nfields): val = res.get_value(row, col) if val is not None: record[col] = self._row_loaders[col](val) records.append(make_row(record)) return records def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: res = self._pgresult if not res: return None if not 0 <= row < self._ntuples: return None record: List[Any] = [None] * self._nfields for col in range(self._nfields): val = res.get_value(row, col) if val is not None: record[col] = self._row_loaders[col](val) return make_row(record) def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: if len(self._row_loaders) != len(record): raise e.ProgrammingError( f"cannot load sequence of {len(record)} items:" f" {len(self._row_loaders)} loaders registered" ) return tuple( (self._row_loaders[i](val) if val is not None else None) for i, val in enumerate(record) ) def get_loader(self, oid: int, format: pq.Format) -> "Loader": try: return self._loaders[format][oid] except KeyError: pass loader_cls = self._adapters.get_loader(oid, format) if not loader_cls: loader_cls = self._adapters.get_loader(INVALID_OID, format) if not loader_cls: raise e.InterfaceError("unknown oid loader not found") loader = self._loaders[format][oid] = loader_cls(oid, self) return loader