501 lines
14 KiB
Python
501 lines
14 KiB
Python
|
"""
|
||
|
Information about PostgreSQL types
|
||
|
|
||
|
These types allow reading information from the system catalog and provide
information to the adapters if needed.
|
||
|
"""
|
||
|
|
||
|
# Copyright (C) 2020 The Psycopg Team
|
||
|
from enum import Enum
|
||
|
from typing import Any, Dict, Iterator, Optional, overload
|
||
|
from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
|
||
|
from typing_extensions import TypeAlias
|
||
|
|
||
|
from . import errors as e
|
||
|
from .abc import AdaptContext, Query
|
||
|
from .rows import dict_row
|
||
|
from ._encodings import conn_encoding
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from .connection import BaseConnection, Connection
|
||
|
from .connection_async import AsyncConnection
|
||
|
from .sql import Identifier, SQL
|
||
|
|
||
|
T = TypeVar("T", bound="TypeInfo")
|
||
|
RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
|
||
|
|
||
|
|
||
|
class TypeInfo:
    """
    Hold information about a PostgreSQL base type.
    """

    __module__ = "psycopg.types"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        delimiter: str = ",",
    ):
        self.name = name
        self.oid = oid
        self.array_oid = array_oid
        # Fall back on the plain name when no regtype string is provided.
        self.regtype = regtype or name
        self.delimiter = delimiter

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__qualname__}:"
            f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
        )

    @overload
    @classmethod
    def fetch(
        cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
    ) -> Optional[T]:
        ...

    @overload
    @classmethod
    async def fetch(
        cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, "Identifier"]
    ) -> Optional[T]:
        ...

    @classmethod
    def fetch(
        cls: Type[T], conn: "BaseConnection[Any]", name: Union[str, "Identifier"]
    ) -> Any:
        """Query a system catalog to read information about a type.

        :param conn: the connection to query. With a `!Connection` the result
            is returned directly; with an `!AsyncConnection` a coroutine is
            returned, which must be awaited.
        :param name: the name of the type to look for. A composable object
            (such as an `!Identifier`) is converted to a string using `!conn`.
        :return: an instance of `!cls`, or `!None` if the type is not found.
        :raise TypeError: if `!conn` is neither a `!Connection` nor an
            `!AsyncConnection`.
        :raise ProgrammingError: if more than one type matches `!name`.
        """
        from .sql import Composable
        from .connection import Connection
        from .connection_async import AsyncConnection

        if isinstance(name, Composable):
            name = name.as_string(conn)

        if isinstance(conn, Connection):
            return cls._fetch(conn, name)
        elif isinstance(conn, AsyncConnection):
            return cls._fetch_async(conn, name)
        else:
            raise TypeError(
                f"expected Connection or AsyncConnection, got {type(conn).__name__}"
            )

    @classmethod
    def _fetch(cls: Type[T], conn: "Connection[Any]", name: str) -> Optional[T]:
        """Run the info query on a sync connection (see `fetch()`)."""
        # This might result in a nested transaction. What we want is to leave
        # the function with the connection in the state we found (either idle
        # or intrans)
        try:
            with conn.transaction():
                # SET LOCAL: the utf8 encoding only lasts for this transaction.
                # Presumably needed so that non-ASCII names can be exchanged.
                if conn_encoding(conn) == "ascii":
                    conn.execute("set local client_encoding to utf8")
                with conn.cursor(row_factory=dict_row) as cur:
                    cur.execute(cls._get_info_query(conn), {"name": name})
                    recs = cur.fetchall()
        except e.UndefinedObject:
            # Raised by the `::regtype` cast (used when `to_regtype()` is not
            # available) if the type doesn't exist.
            return None

        return cls._from_records(name, recs)

    @classmethod
    async def _fetch_async(
        cls: Type[T], conn: "AsyncConnection[Any]", name: str
    ) -> Optional[T]:
        """Run the info query on an async connection (see `fetch()`)."""
        # Same logic as `_fetch()`: the transaction block restores the
        # connection state the function found.
        try:
            async with conn.transaction():
                if conn_encoding(conn) == "ascii":
                    await conn.execute("set local client_encoding to utf8")
                async with conn.cursor(row_factory=dict_row) as cur:
                    await cur.execute(cls._get_info_query(conn), {"name": name})
                    recs = await cur.fetchall()
        except e.UndefinedObject:
            return None

        return cls._from_records(name, recs)

    @classmethod
    def _from_records(
        cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
    ) -> Optional[T]:
        """
        Build an instance from the rows returned by the info query.

        Return `!None` if there are no rows. Raise `ProgrammingError` if there
        is more than one row. Each row's keys are passed to `!cls` as
        constructor arguments.
        """
        if len(recs) == 1:
            return cls(**recs[0])
        elif not recs:
            return None
        else:
            raise e.ProgrammingError(f"found {len(recs)} different types named {name}")

    def register(self, context: Optional[AdaptContext] = None) -> None:
        """
        Register the type information, globally or in the specified `!context`.

        If the type has an array oid, `!register_array()` is also called for
        it, with the same `!context`.
        """
        # Compare against None explicitly: a context object must never be
        # mistaken for "no context" because of its truth value.
        if context is not None:
            types = context.adapters.types
        else:
            from . import postgres

            types = postgres.types

        types.add(self)

        if self.array_oid:
            from .types.array import register_array

            register_array(self, context)

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """
        Return the query reading the type info from the catalog.

        The query expects a ``name`` parameter. The column names match the
        constructor arguments of the class.
        """
        from .sql import SQL

        return SQL(
            """\
SELECT
typname AS name, oid, typarray AS array_oid,
oid::regtype::text AS regtype, typdelim AS delimiter
FROM pg_type t
WHERE t.oid = {regtype}
ORDER BY t.oid
"""
        ).format(regtype=cls._to_regtype(conn))

    @classmethod
    def _has_to_regtype_function(cls, conn: "BaseConnection[Any]") -> bool:
        """Return True if the server connected to supports `to_regtype()`."""
        # to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
        info = conn.info
        if info.vendor == "PostgreSQL":
            return info.server_version >= 90400
        elif info.vendor == "CockroachDB":
            return info.server_version >= 220200
        else:
            return False

    @classmethod
    def _to_regtype(cls, conn: "BaseConnection[Any]") -> "SQL":
        """Return the SQL snippet converting the ``name`` parameter to a type oid."""
        # `to_regtype()` returns the type oid or NULL, unlike the :: operator,
        # which returns the type or raises an exception, which requires
        # a transaction rollback and leaves traces in the server logs.

        from .sql import SQL

        if cls._has_to_regtype_function(conn):
            return SQL("to_regtype(%(name)s)")
        else:
            return SQL("%(name)s::regtype")

    def _added(self, registry: "TypesRegistry") -> None:
        """Method called by the `!registry` when the object is added there."""
        pass
|
||
|
|
||
|
|
||
|
class RangeInfo(TypeInfo):
    """Manage information about a range type."""

    __module__ = "psycopg.types.range"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        subtype_oid: int,
    ):
        super().__init__(name=name, oid=oid, array_oid=array_oid, regtype=regtype)
        # Oid of the range element type (`pg_range.rngsubtype`).
        self.subtype_oid = subtype_oid

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """Return the catalog query for a range type, joined with `pg_range`."""
        from .sql import SQL

        query = SQL(
            """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngtypid
WHERE t.oid = {regtype}
"""
        )
        return query.format(regtype=cls._to_regtype(conn))

    def _added(self, registry: "TypesRegistry") -> None:
        # Also index this info by its element type, so that it can be looked
        # up with `get_by_subtype(RangeInfo, subtype)`.
        key = (RangeInfo, self.subtype_oid)
        registry._registry[key] = self
|
||
|
|
||
|
|
||
|
class MultirangeInfo(TypeInfo):
    """Manage information about a multirange type."""

    __module__ = "psycopg.types.multirange"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        range_oid: int,
        subtype_oid: int,
    ):
        super().__init__(name=name, oid=oid, array_oid=array_oid, regtype=regtype)
        # Oid of the range type this multirange is built on (`rngtypid`)...
        self.range_oid = range_oid
        # ...and oid of that range's element type (`rngsubtype`).
        self.subtype_oid = subtype_oid

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """
        Return the catalog query for a multirange type.

        Raise `NotSupportedError` if the server version is below 14.
        """
        from .sql import SQL

        version = conn.info.server_version
        if version < 140000:
            raise e.NotSupportedError(
                "multirange types are only available from PostgreSQL 14"
            )

        query = SQL(
            """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngmultitypid
WHERE t.oid = {regtype}
"""
        )
        return query.format(regtype=cls._to_regtype(conn))

    def _added(self, registry: "TypesRegistry") -> None:
        # Index this info by both its range type and its element type, so
        # that `get_by_subtype(MultirangeInfo, ...)` finds it from either.
        for related_oid in (self.range_oid, self.subtype_oid):
            registry._registry[MultirangeInfo, related_oid] = self
|
||
|
|
||
|
|
||
|
class CompositeInfo(TypeInfo):
    """Manage information about a composite type."""

    __module__ = "psycopg.types.composite"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        field_names: Sequence[str],
        field_types: Sequence[int],
    ):
        super().__init__(name=name, oid=oid, array_oid=array_oid, regtype=regtype)
        # Attribute names and type oids, as aggregated by the info query
        # (which orders the attributes by `attnum`).
        self.field_names = field_names
        self.field_types = field_types
        # Will be set by register() if the `factory` is a type
        self.python_type: Optional[type] = None

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """
        Return the catalog query for a composite type.

        Non-dropped user attributes (`attnum > 0`) are aggregated into the
        `field_names` and `field_types` arrays. These arrays default to empty
        when there are no attributes.
        """
        from .sql import SQL

        # '{{}}' is the escaped form of the empty array literal '{}', because
        # the string goes through `SQL.format()`.
        query = SQL(
            """\
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
coalesce(a.fnames, '{{}}') AS field_names,
coalesce(a.ftypes, '{{}}') AS field_types
FROM pg_type t
LEFT JOIN (
SELECT
attrelid,
array_agg(attname) AS fnames,
array_agg(atttypid) AS ftypes
FROM (
SELECT a.attrelid, a.attname, a.atttypid
FROM pg_attribute a
JOIN pg_type t ON t.typrelid = a.attrelid
WHERE t.oid = {regtype}
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum
) x
GROUP BY attrelid
) a ON a.attrelid = t.typrelid
WHERE t.oid = {regtype}
"""
        )
        return query.format(regtype=cls._to_regtype(conn))
|
||
|
|
||
|
|
||
|
class EnumInfo(TypeInfo):
    """Manage information about an enum type."""

    __module__ = "psycopg.types.enum"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        labels: Sequence[str],
    ):
        super().__init__(name=name, oid=oid, array_oid=array_oid)
        # Enum labels, as aggregated by the info query (ordered by
        # `enumsortorder`).
        self.labels = labels
        # Will be set by register_enum()
        self.enum: Optional[Type[Enum]] = None

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """Return the catalog query for an enum type, aggregating its labels."""
        from .sql import SQL

        query = SQL(
            """\
SELECT name, oid, array_oid, array_agg(label) AS labels
FROM (
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
e.enumlabel AS label
FROM pg_type t
LEFT JOIN pg_enum e
ON e.enumtypid = t.oid
WHERE t.oid = {regtype}
ORDER BY e.enumsortorder
) x
GROUP BY name, oid, array_oid
"""
        )
        return query.format(regtype=cls._to_regtype(conn))
|
||
|
|
||
|
|
||
|
class TypesRegistry:
    """
    Container for the information about types in a database.
    """

    __module__ = "psycopg.types"

    def __init__(self, template: Optional["TypesRegistry"] = None):
        self._registry: Dict[RegistryKey, TypeInfo]

        # Make a shallow copy: it will become a proper copy if the registry
        # is edited.
        if template is not None:
            self._registry = template._registry
            self._own_state = False
            # The dict is now shared, so the template must also copy it
            # before its next write.
            template._own_state = False
        else:
            self.clear()

    def clear(self) -> None:
        """Remove every entry, giving the registry a new dict of its own."""
        self._registry = {}
        self._own_state = True

    def add(self, info: TypeInfo) -> None:
        """
        Add `!info` to the registry.

        The object is indexed by its oid and array oid (when nonzero) and by
        its name. It is also indexed by its `!regtype`, unless that key is
        already in use. Finally `!info._added()` is called, so the object can
        add further keys.
        """
        self._ensure_own_state()
        if info.oid:
            self._registry[info.oid] = info
        if info.array_oid:
            self._registry[info.array_oid] = info
        self._registry[info.name] = info

        if info.regtype and info.regtype not in self._registry:
            self._registry[info.regtype] = info

        # Allow info to customise further their relation with the registry
        info._added(self)

    def __iter__(self) -> Iterator[TypeInfo]:
        """Iterate over the registered objects, yielding each one only once."""
        # The same object is stored under several keys: dedupe by identity.
        seen = set()
        for t in self._registry.values():
            if id(t) not in seen:
                seen.add(id(t))
                yield t

    @overload
    def __getitem__(self, key: Union[str, int]) -> TypeInfo:
        ...

    @overload
    def __getitem__(self, key: Tuple[Type[T], int]) -> T:
        ...

    def __getitem__(self, key: RegistryKey) -> TypeInfo:
        """
        Return info about a type, specified by name or oid

        :param key: the name or oid of the type to look for. A trailing
            ``[]`` on a name is stripped, returning the info of the element
            type (which also carries the array oid). A ``(TypeInfo subclass,
            oid)`` tuple is also accepted, to look up the keys added by
            `!TypeInfo._added()`.

        Raise KeyError if not found.
        """
        if isinstance(key, str):
            if key.endswith("[]"):
                key = key[:-2]
        elif not isinstance(key, (int, tuple)):
            raise TypeError(f"the key must be an oid or a name, got {type(key)}")
        try:
            return self._registry[key]
        except KeyError:
            # The new message fully describes the failure: drop the chained
            # internal KeyError from the traceback.
            raise KeyError(
                f"couldn't find the type {key!r} in the types registry"
            ) from None

    @overload
    def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
        ...

    @overload
    def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
        ...

    def get(self, key: RegistryKey) -> Optional[TypeInfo]:
        """
        Return info about a type, specified by name or oid

        :param key: the name or oid of the type to look for.

        Unlike `__getitem__`, return None if not found.
        """
        try:
            return self[key]
        except KeyError:
            return None

    def get_oid(self, name: str) -> int:
        """
        Return the oid of a PostgreSQL type by name.

        :param name: the name of the type to look for.

        Return the array oid if the type ends with "``[]``"

        Raise KeyError if the name is unknown.
        """
        t = self[name]
        if name.endswith("[]"):
            return t.array_oid
        else:
            return t.oid

    def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
        """
        Return info about a `TypeInfo` subclass by its element name or oid.

        :param cls: the subtype of `!TypeInfo` to look for. Currently
            supported are `~psycopg.types.range.RangeInfo` and
            `~psycopg.types.multirange.MultirangeInfo`.
        :param subtype: The name or OID of the subtype of the element to look for.
        :return: The `!TypeInfo` object of class `!cls` whose subtype is
            `!subtype`. `!None` if the element or its range are not found.
        """
        try:
            info = self[subtype]
        except KeyError:
            return None
        return self.get((cls, info.oid))

    def _ensure_own_state(self) -> None:
        # Time to write! so, copy.
        # The dict may be shared with a template or with a registry created
        # from this one: copy it before the first modification.
        if not self._own_state:
            self._registry = self._registry.copy()
            self._own_state = True
|