mysteriendrama/lib/python3.11/site-packages/psycopg/types/multirange.py
2023-07-26 21:33:29 +02:00

512 lines
16 KiB
Python

"""
Support for multirange types adaptation.
"""
# Copyright (C) 2021 The Psycopg Team
from decimal import Decimal
from typing import Any, Generic, List, Iterable
from typing import MutableSequence, Optional, Type, Union, overload
from datetime import date, datetime
from .. import errors as e
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._struct import pack_len, unpack_len
from ..postgres import INVALID_OID, TEXT_OID
from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
from .range import Range, T, load_range_text, load_range_binary
from .range import dump_range_text, dump_range_binary, fail_dump
class Multirange(MutableSequence[Range[T]]):
    """Python representation for a PostgreSQL multirange type.

    The object behaves as a mutable sequence of `Range` objects: every item
    stored in it is checked to be a `Range` instance.

    :param items: Sequence of ranges to initialise the object.
    :raises TypeError: if any of the items is not a `Range`.
    """

    def __init__(self, items: Iterable[Range[T]] = ()):
        self._ranges: List[Range[T]] = list(map(self._check_type, items))

    def _check_type(self, item: Any) -> Range[Any]:
        """Return *item* unchanged, or raise `TypeError` if it isn't a `Range`."""
        if not isinstance(item, Range):
            raise TypeError(
                f"Multirange is a sequence of Range, got {type(item).__name__}"
            )
        return item

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._ranges!r})"

    def __str__(self) -> str:
        return f"{{{', '.join(map(str, self._ranges))}}}"

    @overload
    def __getitem__(self, index: int) -> Range[T]:
        ...

    @overload
    def __getitem__(self, index: slice) -> "Multirange[T]":
        ...

    def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
        """Return a single range for an int index, a new `Multirange` for a slice."""
        if isinstance(index, int):
            return self._ranges[index]
        else:
            return Multirange(self._ranges[index])

    def __len__(self) -> int:
        return len(self._ranges)

    @overload
    def __setitem__(self, index: int, value: Range[T]) -> None:
        ...

    @overload
    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
        ...

    def __setitem__(
        self,
        index: Union[int, slice],
        value: Union[Range[T], Iterable[Range[T]]],
    ) -> None:
        """Replace one range (int index) or a run of ranges (slice).

        :raises TypeError: if a value is not a `Range`, or if a slice is
            assigned something that is not iterable.
        """
        if isinstance(index, int):
            # _check_type() returns the item it validated: one call is enough.
            self._ranges[index] = self._check_type(value)
        elif not isinstance(value, Iterable):
            raise TypeError("can only assign an iterable")
        else:
            # Validate every item before touching the list, so that a bad
            # item leaves the multirange unchanged.
            self._ranges[index] = list(map(self._check_type, value))

    def __delitem__(self, index: Union[int, slice]) -> None:
        del self._ranges[index]

    def insert(self, index: int, value: Range[T]) -> None:
        """Insert the range *value* before position *index*.

        :raises TypeError: if *value* is not a `Range`.
        """
        self._ranges.insert(index, self._check_type(value))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return False
        return self._ranges == other._ranges

    # Order is arbitrary but consistent

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges < other._ranges

    def __le__(self, other: Any) -> bool:
        return self == other or self < other  # type: ignore

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges > other._ranges

    def __ge__(self, other: Any) -> bool:
        return self == other or self > other  # type: ignore
# Subclasses to specify a specific subtype. Usually not needed
class Int4Multirange(Multirange[int]):
    """`Multirange` of `int` bounds, dumped by the ``int4multirange`` dumpers."""
class Int8Multirange(Multirange[int]):
    """`Multirange` of `int` bounds, dumped by the ``int8multirange`` dumpers."""
class NumericMultirange(Multirange[Decimal]):
    """`Multirange` of `Decimal` bounds, dumped by the ``nummultirange`` dumpers."""
class DateMultirange(Multirange[date]):
    """`Multirange` of `date` bounds, dumped by the ``datemultirange`` dumpers."""
class TimestampMultirange(Multirange[datetime]):
    """`Multirange` of `datetime` bounds, dumped by the ``tsmultirange`` dumpers."""
class TimestamptzMultirange(Multirange[datetime]):
    """`Multirange` of `datetime` bounds, dumped by the ``tstzmultirange`` dumpers."""
class BaseMultirangeDumper(RecursiveDumper):
    """Shared logic for the text and binary multirange dumpers.

    When registered on the generic `Multirange` class, the dumper inspects a
    representative bound of the object to dump in order to choose the key and
    the oid to use. Subclasses registered on a specific `Multirange` subclass
    skip that step.
    """

    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
        super().__init__(cls, context)
        # Dumper for the bounds; assigned by upgrade() on the dumper it returns.
        self.sub_dumper: Optional[Dumper] = None
        self._adapt_format = PyFormat.from_pq(self.format)

    def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
        """Return the key identifying the dumper needed for *obj*.

        The key is the class alone for specific subclasses, otherwise a tuple
        of the class and, if a bound is available, the key of its dumper.
        """
        # A subclass with its own oid never needs upgrading.
        if self.cls is not Multirange:
            return self.cls

        sample = self._get_item(obj)
        if sample is None:
            return (self.cls,)

        bound_dumper = self._tx.get_dumper(sample, self._adapt_format)
        return (self.cls, bound_dumper.get_key(sample, format))

    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
        """Return a dumper whose oid and sub-dumper suit the bounds of *obj*."""
        # A subclass with its own oid never needs upgrading.
        if self.cls is not Multirange:
            return self

        sample = self._get_item(obj)
        if sample is None:
            return MultirangeDumper(self.cls)

        upgraded: BaseMultirangeDumper
        if type(sample) is int:
            # postgres won't cast int4range -> int8range so we must use
            # text format and unknown oid here
            bound_dumper = self._tx.get_dumper(sample, PyFormat.TEXT)
            upgraded = MultirangeDumper(self.cls, self._tx)
            upgraded.sub_dumper = bound_dumper
            upgraded.oid = INVALID_OID
            return upgraded

        bound_dumper = self._tx.get_dumper(sample, format)
        upgraded = type(self)(self.cls, self._tx)
        upgraded.sub_dumper = bound_dumper
        # Work around the normal mapping where text is dumped as unknown
        text_as_unknown = bound_dumper.oid == INVALID_OID and isinstance(sample, str)
        upgraded.oid = self._get_multirange_oid(
            TEXT_OID if text_as_unknown else bound_dumper.oid
        )
        return upgraded

    def _get_item(self, obj: Multirange[Any]) -> Any:
        """
        Return a member representative of the multirange: the first bound
        which is not None, or None if there is no such bound.
        """
        for rng in obj:
            if rng.lower is not None:
                return rng.lower
            if rng.upper is not None:
                return rng.upper
        return None

    def _get_multirange_oid(self, sub_oid: int) -> int:
        """
        Return the oid of the multirange from the oid of its elements.

        Return `INVALID_OID` if no multirange is known for *sub_oid*.
        """
        info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
        if info:
            return info.oid
        return INVALID_OID
class MultirangeDumper(BaseMultirangeDumper):
    """
    Dumper for multirange types.

    The dumper can upgrade to one specific for a different range type.
    """

    def dump(self, obj: Multirange[Any]) -> Buffer:
        """Return *obj* as comma-separated text ranges wrapped in braces."""
        if not obj:
            return b"{}"

        sample = self._get_item(obj)
        # With no bound to look at there is nothing to dump bounds with:
        # fail_dump is passed along instead of a real dump function.
        dump_bound = (
            self._tx.get_dumper(sample, self._adapt_format).dump
            if sample is not None
            else fail_dump
        )

        ranges = b",".join(dump_range_text(rng, dump_bound) for rng in obj)
        return b"{" + ranges + b"}"
class MultirangeBinaryDumper(BaseMultirangeDumper):
    """Dumper for multirange types in binary format."""

    format = Format.BINARY

    def dump(self, obj: Multirange[Any]) -> Buffer:
        """Return the number of ranges followed by each length-prefixed range."""
        sample = self._get_item(obj)
        # With no bound to look at there is nothing to dump bounds with:
        # fail_dump is passed along instead of a real dump function.
        dump_bound = (
            self._tx.get_dumper(sample, self._adapt_format).dump
            if sample is not None
            else fail_dump
        )

        chunks: List[Buffer] = [pack_len(len(obj))]
        for rng in obj:
            payload = dump_range_binary(rng, dump_bound)
            chunks.extend((pack_len(len(payload)), payload))
        return b"".join(chunks)
class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
    """Base for multirange loaders: looks up the loader for the bounds.

    Subclasses must define `subtype_oid`, the oid of the bounds' type.
    """

    subtype_oid: int

    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
        super().__init__(oid, context)
        # Load function for the bounds, in the same format as this loader.
        bound_loader = self._tx.get_loader(self.subtype_oid, format=self.format)
        self._load = bound_loader.load
class MultirangeLoader(BaseMultirangeLoader[T]):
    """Loader for multirange types in text format."""

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse a ``{range,range,...}`` literal into a `Multirange`.

        :raises DataError: if the data doesn't start with ``{``, if a range
            is not followed by ``,`` or ``}``, or if anything follows the
            closing brace.
        """
        if not data or data[0] != _START_INT:
            raise e.DataError(
                "malformed multirange starting with"
                f" {bytes(data[:1]).decode('utf8', 'replace')}"
            )

        out = Multirange[T]()
        if data == b"{}":
            return out

        # Drop the opening brace: from here on `data` always starts at the
        # next range to parse.
        data = data[1:]
        try:
            # The loop can only be left by returning or raising.
            while True:
                # NOTE(review): load_range_text() is used here as returning
                # the parsed range and the position just past it - confirm.
                r, pos = load_range_text(data, self._load)
                out.append(r)

                sep = data[pos]  # can raise IndexError
                if sep == _SEP_INT:
                    data = data[pos + 1 :]
                elif sep == _END_INT:
                    if len(data) != pos + 1:
                        raise e.DataError(
                            "malformed multirange: data after closing brace"
                        )
                    return out
                else:
                    raise e.DataError(
                        f"malformed multirange: found unexpected {chr(sep)}"
                    )
        except IndexError as ex:
            raise e.DataError("malformed multirange: separator missing") from ex
_SEP_INT = ord(",")
_START_INT = ord("{")
_END_INT = ord("}")
class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
    """Loader for multirange types in binary format."""

    format = Format.BINARY

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse a count followed by that many length-prefixed ranges.

        :raises DataError: if data is left over after the last range.
        """
        count = unpack_len(data, 0)[0]
        offset = 4  # the count read above takes the first 4 bytes
        result = Multirange[T]()
        for _ in range(count):
            size = unpack_len(data, offset)[0]
            start = offset + 4
            offset = start + size
            result.append(load_range_binary(data[start:offset], self._load))

        if offset != len(data):
            raise e.DataError("unexpected trailing data in multirange")

        return result
def register_multirange(
    info: MultirangeInfo, context: Optional[AdaptContext] = None
) -> None:
    """Register the adapters to load and dump a multirange type.

    :param info: The object with the information about the multirange to
        register.
    :param context: The context where to register the adapters. If `!None`,
        register it globally.
    :raises TypeError: if *info* is false (e.g. `!None`).

    Register loaders so that loading data of this type will result in a
    `Multirange` with bounds parsed as the right subtype.

    .. note::

        Registering the adapters doesn't affect objects already created, even
        if they are children of the registered context. For instance,
        registering the adapter globally doesn't affect already existing
        connections.
    """
    # A friendly error instead of an AttributeError in case fetch() failed
    # and it wasn't noticed.
    if not info:
        raise TypeError("no info passed. Is the requested multirange available?")

    # Register arrays and type info
    info.register(context)

    adapters = context.adapters if context else postgres.adapters

    # Generate and register a customized loader for each format: text first,
    # then binary. Each one is a subclass bound to the subtype of this type.
    flavours = (
        (MultirangeLoader, "Loader"),
        (MultirangeBinaryLoader, "BinaryLoader"),
    )
    for base, suffix in flavours:
        loader: Type[BaseMultirangeLoader[Any]] = type(
            f"{info.name.title()}{suffix}",
            (base,),
            {"subtype_oid": info.subtype_oid},
        )
        adapters.register_loader(info.oid, loader)
# Text dumpers for the wrappers of the builtin multirange types.
# These are registered on the specific subclasses so that the upgrade
# mechanism doesn't kick in.
class Int4MultirangeDumper(MultirangeDumper):
    """Text dumper with the oid fixed to ``int4multirange``."""

    oid = postgres.types["int4multirange"].oid
class Int8MultirangeDumper(MultirangeDumper):
    """Text dumper with the oid fixed to ``int8multirange``."""

    oid = postgres.types["int8multirange"].oid
class NumericMultirangeDumper(MultirangeDumper):
    """Text dumper with the oid fixed to ``nummultirange``."""

    oid = postgres.types["nummultirange"].oid
class DateMultirangeDumper(MultirangeDumper):
    """Text dumper with the oid fixed to ``datemultirange``."""

    oid = postgres.types["datemultirange"].oid
class TimestampMultirangeDumper(MultirangeDumper):
    """Text dumper with the oid fixed to ``tsmultirange``."""

    oid = postgres.types["tsmultirange"].oid
class TimestamptzMultirangeDumper(MultirangeDumper):
    """Text dumper with the oid fixed to ``tstzmultirange``."""

    oid = postgres.types["tstzmultirange"].oid
# Binary dumpers for the wrappers of the builtin multirange types.
# These are registered on the specific subclasses so that the upgrade
# mechanism doesn't kick in.
class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper with the oid fixed to ``int4multirange``."""

    oid = postgres.types["int4multirange"].oid
class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper with the oid fixed to ``int8multirange``."""

    oid = postgres.types["int8multirange"].oid
class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper with the oid fixed to ``nummultirange``."""

    oid = postgres.types["nummultirange"].oid
class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper with the oid fixed to ``datemultirange``."""

    oid = postgres.types["datemultirange"].oid
class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper with the oid fixed to ``tsmultirange``."""

    oid = postgres.types["tsmultirange"].oid
class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper with the oid fixed to ``tstzmultirange``."""

    oid = postgres.types["tstzmultirange"].oid
# Text loaders for builtin multirange types
class Int4MultirangeLoader(MultirangeLoader[int]):
    """Text loader whose bounds are loaded with the ``int4`` loader."""

    subtype_oid = postgres.types["int4"].oid
class Int8MultirangeLoader(MultirangeLoader[int]):
    """Text loader whose bounds are loaded with the ``int8`` loader."""

    subtype_oid = postgres.types["int8"].oid
class NumericMultirangeLoader(MultirangeLoader[Decimal]):
    """Text loader whose bounds are loaded with the ``numeric`` loader."""

    subtype_oid = postgres.types["numeric"].oid
class DateMultirangeLoader(MultirangeLoader[date]):
    """Text loader whose bounds are loaded with the ``date`` loader."""

    subtype_oid = postgres.types["date"].oid
class TimestampMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader whose bounds are loaded with the ``timestamp`` loader."""

    subtype_oid = postgres.types["timestamp"].oid
class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader whose bounds are loaded with the ``timestamptz`` loader."""

    subtype_oid = postgres.types["timestamptz"].oid
# Binary loaders for builtin multirange types
class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader whose bounds are loaded with the ``int4`` loader."""

    subtype_oid = postgres.types["int4"].oid
class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader whose bounds are loaded with the ``int8`` loader."""

    subtype_oid = postgres.types["int8"].oid
class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
    """Binary loader whose bounds are loaded with the ``numeric`` loader."""

    subtype_oid = postgres.types["numeric"].oid
class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
    """Binary loader whose bounds are loaded with the ``date`` loader."""

    subtype_oid = postgres.types["date"].oid
class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader whose bounds are loaded with the ``timestamp`` loader."""

    subtype_oid = postgres.types["timestamp"].oid
class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader whose bounds are loaded with the ``timestamptz`` loader."""

    subtype_oid = postgres.types["timestamptz"].oid
def register_default_adapters(context: AdaptContext) -> None:
    """Register the builtin multirange dumpers and loaders on *context*."""
    adapters = context.adapters

    # The registration order is kept exactly as listed.
    # NOTE(review): presumably order matters when several adapters are
    # registered for the same class or type name - confirm against the
    # adapters map before reordering.
    dumpers = (
        (Multirange, MultirangeBinaryDumper),
        (Multirange, MultirangeDumper),
        (Int4Multirange, Int4MultirangeDumper),
        (Int8Multirange, Int8MultirangeDumper),
        (NumericMultirange, NumericMultirangeDumper),
        (DateMultirange, DateMultirangeDumper),
        (TimestampMultirange, TimestampMultirangeDumper),
        (TimestamptzMultirange, TimestamptzMultirangeDumper),
        (Int4Multirange, Int4MultirangeBinaryDumper),
        (Int8Multirange, Int8MultirangeBinaryDumper),
        (NumericMultirange, NumericMultirangeBinaryDumper),
        (DateMultirange, DateMultirangeBinaryDumper),
        (TimestampMultirange, TimestampMultirangeBinaryDumper),
        (TimestamptzMultirange, TimestamptzMultirangeBinaryDumper),
    )
    for wrapper, dumper in dumpers:
        adapters.register_dumper(wrapper, dumper)

    loaders = (
        ("int4multirange", Int4MultirangeLoader),
        ("int8multirange", Int8MultirangeLoader),
        ("nummultirange", NumericMultirangeLoader),
        ("datemultirange", DateMultirangeLoader),
        ("tsmultirange", TimestampMultirangeLoader),
        ("tstzmultirange", TimestampTZMultirangeLoader),
        ("int4multirange", Int4MultirangeBinaryLoader),
        ("int8multirange", Int8MultirangeBinaryLoader),
        ("nummultirange", NumericMultirangeBinaryLoader),
        ("datemultirange", DateMultirangeBinaryLoader),
        ("tsmultirange", TimestampMultirangeBinaryLoader),
        ("tstzmultirange", TimestampTZMultirangeBinaryLoader),
    )
    for type_name, loader in loaders:
        adapters.register_loader(type_name, loader)