mysteriendrama/lib/python3.11/site-packages/psycopg/types/enum.py
2023-07-26 21:33:29 +02:00

178 lines
5.0 KiB
Python

"""
Adapters for the enum type.
"""
from enum import Enum
from typing import Any, Dict, Generic, Optional, Mapping, Sequence
from typing import Tuple, Type, TypeVar, Union, cast
from typing_extensions import TypeAlias
from .. import postgres
from .. import errors as e
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
from .._encodings import conn_encoding
from .._typeinfo import EnumInfo as EnumInfo # exported here
# Type variable standing for a concrete Enum subclass.
E = TypeVar("E", bound=Enum)
# Enum member -> bytes: type of the lookup table used by the per-enum dumpers.
EnumDumpMap: TypeAlias = Dict[E, bytes]
# bytes -> enum member: type of the lookup table used by the per-enum loaders.
EnumLoadMap: TypeAlias = Dict[bytes, E]
# Accepted shapes for the `mapping` argument of `register_enum()`: a mapping
# from member to label, a sequence of (member, label) pairs, or None.
EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
class _BaseEnumLoader(Loader, Generic[E]):
    """
    Loader for a specific Enum class.

    This base class is not usable on its own: `register_enum()` creates
    subclasses of it at runtime, filling in `enum` and `_load_map`.
    """

    # Python enum class whose members are returned by `load()`.
    enum: Type[E]
    # Lookup table from the received label, as bytes, to the enum member.
    _load_map: EnumLoadMap[E]

    def load(self, data: Buffer) -> E:
        """Return the enum member registered for the label in *data*.

        :param data: The label received, as a bytes-like buffer.
        :raises DataError: If *data* has no entry in `_load_map`.
        """
        # The map is keyed by `bytes` objects, so normalize any other
        # buffer type before the lookup.
        if not isinstance(data, bytes):
            data = bytes(data)

        try:
            return self._load_map[data]
        except KeyError:
            # Decode leniently: the label is only used in the error message.
            enc = conn_encoding(self.connection)
            label = data.decode(enc, "replace")
            # The KeyError is an implementation detail of the lookup: hide
            # it so the traceback shows only the meaningful DataError.
            raise e.DataError(
                f"bad member for enum {self.enum.__qualname__}: {label!r}"
            ) from None
class _BaseEnumDumper(Dumper, Generic[E]):
    """
    Dumper bound to one specific Enum class.

    `register_enum()` creates subclasses of this class at runtime, providing
    the `enum` and `_dump_map` attributes.
    """

    # Python enum class handled by this dumper.
    enum: Type[E]
    # Lookup table from enum member to the bytes to send.
    _dump_map: EnumDumpMap[E]

    def dump(self, value: E) -> Buffer:
        """Return the bytes registered in `_dump_map` for *value*."""
        encoded = self._dump_map[value]
        return encoded
class EnumDumper(Dumper):
    """
    Dumper for any Enum class, without a specific registration.

    A member is dumped as its name, encoded in the connection encoding.
    """

    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
        super().__init__(cls, context)
        # Resolve the encoding once, at construction, rather than per dump.
        connection = self.connection
        self._encoding = conn_encoding(connection)

    def dump(self, value: E) -> Buffer:
        """Return the name of *value* encoded with the stored encoding."""
        member_name = value.name
        return member_name.encode(self._encoding)
class EnumBinaryDumper(EnumDumper):
    """
    Variant of `EnumDumper` declaring the binary format.

    The dumping logic is inherited unchanged from `EnumDumper`.
    """

    format = Format.BINARY
def register_enum(
    info: EnumInfo,
    context: Optional[AdaptContext] = None,
    enum: Optional[Type[E]] = None,
    *,
    mapping: EnumMapping[E] = None,
) -> None:
    """Register the adapters to load and dump an enum type.

    :param info: The object with the information about the enum to register.
    :param context: The context where to register the adapters. If `!None`,
        register it globally.
    :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
        a new enum will be generated and exposed as `EnumInfo.enum`.
    :param mapping: Override the mapping between `!enum` members and `!info`
        labels.
    :raises TypeError: If `!info` is falsy.
    """
    if not info:
        raise TypeError("no info passed. Is the requested enum available?")

    if enum is None:
        # No Python enum supplied: build one from the labels in `info`.
        generated = Enum(info.name.title(), info.labels, module=__name__)
        enum = cast(Type[E], generated)

    info.enum = enum
    adapters = context.adapters if context else postgres.adapters
    info.register(context)

    # Loaders: the text one first, then the binary one, both on the type oid.
    load_map = _make_load_map(info, enum, mapping, context)
    loader_ns: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}

    text_loader = type(
        f"{info.name.title()}Loader", (_BaseEnumLoader,), loader_ns
    )
    adapters.register_loader(info.oid, text_loader)

    binary_loader = type(
        f"{info.name.title()}BinaryLoader",
        (_BaseEnumLoader,),
        dict(loader_ns, format=Format.BINARY),
    )
    adapters.register_loader(info.oid, binary_loader)

    # Dumpers: same order, registered on the Python enum class.
    dump_map = _make_dump_map(info, enum, mapping, context)
    dumper_ns: Dict[str, Any] = {
        "oid": info.oid,
        "enum": info.enum,
        "_dump_map": dump_map,
    }

    text_dumper = type(f"{enum.__name__}Dumper", (_BaseEnumDumper,), dumper_ns)
    adapters.register_dumper(info.enum, text_dumper)

    binary_dumper = type(
        f"{enum.__name__}BinaryDumper",
        (_BaseEnumDumper,),
        dict(dumper_ns, format=Format.BINARY),
    )
    adapters.register_dumper(info.enum, binary_dumper)
def _make_load_map(
    info: EnumInfo,
    enum: Type[E],
    mapping: EnumMapping[E],
    context: Optional[AdaptContext],
) -> EnumLoadMap[E]:
    """Build the table from encoded label to enum member used for loading.

    Each label in `info.labels` is first matched to the member of *enum*
    with the same name; the entries in *mapping*, if any, are applied
    afterwards, so they add to or override the name-based ones.
    """
    encoding = conn_encoding(context.connection if context else None)
    load_map: EnumLoadMap[E] = {}

    for label in info.labels:
        try:
            member = enum[label]
        except KeyError:
            # tolerate a missing enum, assuming it won't be used. If it is we
            # will get a DataError on fetch.
            continue
        load_map[label.encode(encoding)] = member

    if not mapping:
        return load_map

    # Accept either a mapping or a sequence of (member, label) pairs.
    pairs = mapping.items() if isinstance(mapping, Mapping) else mapping
    for member, label in pairs:
        load_map[label.encode(encoding)] = member

    return load_map
def _make_dump_map(
    info: EnumInfo,
    enum: Type[E],
    mapping: EnumMapping[E],
    context: Optional[AdaptContext],
) -> EnumDumpMap[E]:
    """Build the table from enum member to encoded label used for dumping.

    Every member of *enum* is first mapped to its own name; the entries in
    *mapping*, if any, are applied afterwards and replace those defaults.
    """
    encoding = conn_encoding(context.connection if context else None)
    dump_map: EnumDumpMap[E] = {
        member: member.name.encode(encoding) for member in enum
    }

    if not mapping:
        return dump_map

    # Accept either a mapping or a sequence of (member, label) pairs.
    pairs = mapping.items() if isinstance(mapping, Mapping) else mapping
    for member, label in pairs:
        dump_map[member] = label.encode(encoding)

    return dump_map
def register_default_adapters(context: AdaptContext) -> None:
    """Register the generic Enum dumpers on the adapters of *context*."""
    adapters = context.adapters
    # Keep this order: binary first, text last.
    # NOTE(review): presumably the later registration takes precedence when
    # no format is requested - confirm against the adapters map.
    for dumper in (EnumBinaryDumper, EnumDumper):
        adapters.register_dumper(Enum, dumper)