512 lines
16 KiB
Python
512 lines
16 KiB
Python
"""
|
|
Support for multirange types adaptation.
|
|
"""
|
|
|
|
# Copyright (C) 2021 The Psycopg Team
|
|
|
|
from decimal import Decimal
|
|
from typing import Any, Generic, List, Iterable
|
|
from typing import MutableSequence, Optional, Type, Union, overload
|
|
from datetime import date, datetime
|
|
|
|
from .. import errors as e
|
|
from .. import postgres
|
|
from ..pq import Format
|
|
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
|
|
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
|
|
from .._struct import pack_len, unpack_len
|
|
from ..postgres import INVALID_OID, TEXT_OID
|
|
from .._typeinfo import MultirangeInfo as MultirangeInfo # exported here
|
|
|
|
from .range import Range, T, load_range_text, load_range_binary
|
|
from .range import dump_range_text, dump_range_binary, fail_dump
|
|
|
|
|
|
class Multirange(MutableSequence[Range[T]]):
    """Python representation for a PostgreSQL multirange type.

    :param items: Sequence of ranges to initialise the object.

    The object behaves as a mutable list of `Range` objects: values added via
    the constructor, item or slice assignment, or `insert()` are checked to be
    `Range` instances, otherwise `TypeError` is raised.
    """

    def __init__(self, items: Iterable[Range[T]] = ()):
        self._ranges: List[Range[T]] = list(map(self._check_type, items))

    def _check_type(self, item: Any) -> Range[Any]:
        """Return *item* if it is a `Range`, otherwise raise `TypeError`."""
        if not isinstance(item, Range):
            raise TypeError(
                f"Multirange is a sequence of Range, got {type(item).__name__}"
            )
        return item

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._ranges!r})"

    def __str__(self) -> str:
        return f"{{{', '.join(map(str, self._ranges))}}}"

    @overload
    def __getitem__(self, index: int) -> Range[T]:
        ...

    @overload
    def __getitem__(self, index: slice) -> "Multirange[T]":
        ...

    def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
        # Dispatch on slice rather than on int, so that any integer-like
        # index (an object implementing __index__) behaves as it does with a
        # list instead of being handled as a slice.
        if isinstance(index, slice):
            return Multirange(self._ranges[index])
        return self._ranges[index]

    def __len__(self) -> int:
        return len(self._ranges)

    @overload
    def __setitem__(self, index: int, value: Range[T]) -> None:
        ...

    @overload
    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
        ...

    def __setitem__(
        self,
        index: Union[int, slice],
        value: Union[Range[T], Iterable[Range[T]]],
    ) -> None:
        if not isinstance(index, slice):
            # Single item: validate once, then store.
            self._ranges[index] = self._check_type(value)
        elif not isinstance(value, Iterable):
            raise TypeError("can only assign an iterable")
        else:
            # Validate every item before touching the list, so that a bad
            # item leaves the multirange unchanged.
            self._ranges[index] = [self._check_type(v) for v in value]

    def __delitem__(self, index: Union[int, slice]) -> None:
        del self._ranges[index]

    def insert(self, index: int, value: Range[T]) -> None:
        self._ranges.insert(index, self._check_type(value))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return False
        return self._ranges == other._ranges

    # Order is arbitrary but consistent

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges < other._ranges

    def __le__(self, other: Any) -> bool:
        # Return NotImplemented explicitly, as __lt__ does: otherwise the `<`
        # below raises a TypeError naming the wrong operator, and the
        # reflected method of *other* is never tried.
        if not isinstance(other, Multirange):
            return NotImplemented
        return self == other or self < other  # type: ignore

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges > other._ranges

    def __ge__(self, other: Any) -> bool:
        # See __le__ for why NotImplemented is returned explicitly.
        if not isinstance(other, Multirange):
            return NotImplemented
        return self == other or self > other  # type: ignore
|
|
|
|
|
|
# Subclasses pinning a specific subtype. Usually not needed: the generic
# Multirange dumper infers the subtype from the bounds of the object dumped.
|
|
|
|
|
|
class Int4Multirange(Multirange[int]):
    """Multirange with `int` bounds, dumped with the ``int4multirange`` oid."""
|
|
|
|
|
|
class Int8Multirange(Multirange[int]):
    """Multirange with `int` bounds, dumped with the ``int8multirange`` oid."""
|
|
|
|
|
|
class NumericMultirange(Multirange[Decimal]):
    """Multirange with `Decimal` bounds, dumped with the ``nummultirange`` oid."""
|
|
|
|
|
|
class DateMultirange(Multirange[date]):
    """Multirange with `date` bounds, dumped with the ``datemultirange`` oid."""
|
|
|
|
|
|
class TimestampMultirange(Multirange[datetime]):
    """Multirange with `datetime` bounds, dumped with the ``tsmultirange`` oid."""
|
|
|
|
|
|
class TimestamptzMultirange(Multirange[datetime]):
    """Multirange with `datetime` bounds, dumped with the ``tstzmultirange`` oid."""
|
|
|
|
|
|
class BaseMultirangeDumper(RecursiveDumper):
    """Base class for the multirange dumpers.

    If registered on the generic `Multirange` class, the dumper inspects a
    bound of the object to dump (see `_get_item()`) to choose the dumper for
    the bounds and the oid of the matching multirange type. Dumpers bound to
    a specific `Multirange` subclass skip this inspection.
    """

    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
        super().__init__(cls, context)
        # Dumper for the bounds; set by upgrade() on the dumper it returns.
        self.sub_dumper: Optional[Dumper] = None
        # PyFormat equivalent of this dumper's format, used to look up the
        # dumpers of the bounds.
        self._adapt_format = PyFormat.from_pq(self.format)

    def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
        """Return the key identifying the dumper to use for *obj*.

        For the generic `Multirange` the key includes the key of the bounds
        dumper, so multiranges with different bound types get different keys.
        """
        # If we are a subclass whose oid is specified we don't need upgrade
        if self.cls is not Multirange:
            return self.cls

        item = self._get_item(obj)
        if item is not None:
            sd = self._tx.get_dumper(item, self._adapt_format)
            return (self.cls, sd.get_key(item, format))
        else:
            return (self.cls,)

    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
        """Return a dumper specialised for the bounds type of *obj*."""
        # If we are a subclass whose oid is specified we don't need upgrade
        if self.cls is not Multirange:
            return self

        item = self._get_item(obj)
        if item is None:
            # Nothing to infer the subtype from: use the text dumper, sharing
            # our transformer (as the int branch below does) instead of
            # letting it build its own.
            return MultirangeDumper(self.cls, self._tx)

        dumper: BaseMultirangeDumper
        if type(item) is int:
            # postgres won't cast int4range -> int8range so we must use
            # text format and unknown oid here
            sd = self._tx.get_dumper(item, PyFormat.TEXT)
            dumper = MultirangeDumper(self.cls, self._tx)
            dumper.sub_dumper = sd
            dumper.oid = INVALID_OID
            return dumper

        sd = self._tx.get_dumper(item, format)
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd
        if sd.oid == INVALID_OID and isinstance(item, str):
            # Work around the normal mapping where text is dumped as unknown
            dumper.oid = self._get_multirange_oid(TEXT_OID)
        else:
            dumper.oid = self._get_multirange_oid(sd.oid)

        return dumper

    def _get_item(self, obj: Multirange[Any]) -> Any:
        """
        Return a member representative of the multirange

        This is the first non-None bound found scanning the ranges in order,
        checking the lower bound before the upper one; None if every range
        is unbounded (or the multirange is empty).
        """
        for r in obj:
            if r.lower is not None:
                return r.lower
            if r.upper is not None:
                return r.upper
        return None

    def _get_multirange_oid(self, sub_oid: int) -> int:
        """
        Return the oid of the multirange from the oid of its elements.

        Return INVALID_OID if no multirange type with subtype *sub_oid* is
        known to the adapters.
        """
        info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
        return info.oid if info else INVALID_OID
|
|
|
|
|
|
class MultirangeDumper(BaseMultirangeDumper):
    """
    Dumper for multirange types, producing their text representation.

    The dumper can upgrade to one specific for a different range type.
    """

    def dump(self, obj: Multirange[Any]) -> Buffer:
        # An empty multirange is represented by a fixed literal.
        if not obj:
            return b"{}"

        bound = self._get_item(obj)
        dump_bound = (
            self._tx.get_dumper(bound, self._adapt_format).dump
            if bound is not None
            else fail_dump
        )

        # "{" + comma-separated ranges + "}"
        parts = [dump_range_text(rng, dump_bound) for rng in obj]
        return b"".join((b"{", b",".join(parts), b"}"))
|
|
|
|
|
|
class MultirangeBinaryDumper(BaseMultirangeDumper):
    """Dumper for multirange types, producing their binary representation."""

    format = Format.BINARY

    def dump(self, obj: Multirange[Any]) -> Buffer:
        bound = self._get_item(obj)
        if bound is not None:
            dump_bound = self._tx.get_dumper(bound, self._adapt_format).dump
        else:
            dump_bound = fail_dump

        # Layout: number of ranges, then every range prefixed by its length.
        chunks: List[Buffer] = [pack_len(len(obj))]
        for rng in obj:
            payload = dump_range_binary(rng, dump_bound)
            chunks.extend((pack_len(len(payload)), payload))
        return b"".join(chunks)
|
|
|
|
|
|
class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
    """Base class for the multirange loaders.

    Subclasses must define `subtype_oid`, the oid of the bounds type, which
    is used to look up the loader for the bounds in the loader's format.
    """

    subtype_oid: int

    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
        super().__init__(oid, context)
        bound_loader = self._tx.get_loader(self.subtype_oid, format=self.format)
        self._load = bound_loader.load
|
|
|
|
|
|
class MultirangeLoader(BaseMultirangeLoader[T]):
    """Load a multirange from its text representation, e.g. ``{[1,2),[3,4)}``."""

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse *data* into a `Multirange`.

        :raises DataError: if *data* is not a well-formed multirange.
        """
        if not data or data[0] != _START_INT:
            raise e.DataError(
                "malformed multirange starting with"
                f" {bytes(data[:1]).decode('utf8', 'replace')}"
            )

        out = Multirange[T]()
        if data == b"{}":
            return out

        # Skip the opening brace. From here on `pos` is relative to `data`,
        # which is re-sliced past each separator.
        pos = 1
        data = data[pos:]
        try:
            while True:
                r, pos = load_range_text(data, self._load)
                out.append(r)

                sep = data[pos]  # can raise IndexError
                if sep == _SEP_INT:
                    data = data[pos + 1 :]
                    continue
                elif sep == _END_INT:
                    if len(data) == pos + 1:
                        return out
                    else:
                        raise e.DataError(
                            "malformed multirange: data after closing brace"
                        )
                else:
                    raise e.DataError(
                        f"malformed multirange: found unexpected {chr(sep)}"
                    )

        except IndexError:
            # The input ended where a separator was expected: report it as a
            # data error, without the irrelevant IndexError context.
            raise e.DataError("malformed multirange: separator missing") from None
|
|
|
|
|
|
# Byte values of the multirange text syntax characters. Indexing a bytes-like
# buffer yields ints, so the parser compares against these.
_SEP_INT = ord(b",")
_START_INT = ord(b"{")
_END_INT = ord(b"}")
|
|
|
|
|
|
class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
    """Load a multirange from its binary representation.

    The data is a 4-byte count of ranges, followed by each range in binary
    format prefixed by its 4-byte length.
    """

    format = Format.BINARY

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse *data* into a `Multirange`.

        :raises DataError: if a range length is invalid or exceeds the data
            available, or if there is data left after the last range.
        """
        nelems = unpack_len(data, 0)[0]
        pos = 4
        out = Multirange[T]()
        for _ in range(nelems):
            length = unpack_len(data, pos)[0]
            pos += 4
            end = pos + length
            # Validate before slicing: a short or backwards slice would
            # otherwise be handed to load_range_binary.
            if length < 0 or end > len(data):
                raise e.DataError("malformed multirange: invalid range length")
            out.append(load_range_binary(data[pos:end], self._load))
            pos = end

        if pos != len(data):
            raise e.DataError("unexpected trailing data in multirange")

        return out
|
|
|
|
|
|
def register_multirange(
    info: MultirangeInfo, context: Optional[AdaptContext] = None
) -> None:
    """Register the adapters to load and dump a multirange type.

    :param info: The object with the information about the multirange to
        register.
    :param context: The context where to register the adapters. If `!None`,
        register it globally.

    Register loaders so that loading data of this type will result in a
    `Multirange` whose `Range` elements have bounds parsed as the right
    subtype.

    .. note::

        Registering the adapters doesn't affect objects already created, even
        if they are children of the registered context. For instance,
        registering the adapter globally doesn't affect already existing
        connections.
    """
    # A friendly error warning instead of an AttributeError in case fetch()
    # failed and it wasn't noticed.
    if not info:
        raise TypeError("no info passed. Is the requested multirange available?")

    # Register arrays and type info
    info.register(context)

    adapters = context.adapters if context else postgres.adapters

    # generate and register a customized text loader
    loader: Type[MultirangeLoader[Any]] = type(
        f"{info.name.title()}Loader",
        (MultirangeLoader,),
        {"subtype_oid": info.subtype_oid},
    )
    adapters.register_loader(info.oid, loader)

    # generate and register a customized binary loader
    bloader: Type[MultirangeBinaryLoader[Any]] = type(
        f"{info.name.title()}BinaryLoader",
        (MultirangeBinaryLoader,),
        {"subtype_oid": info.subtype_oid},
    )
    adapters.register_loader(info.oid, bloader)
|
|
|
|
|
|
# Text dumpers for the builtin multirange wrapper types
|
|
# These are registered on specific subtypes so that the upgrade mechanism
|
|
# doesn't kick in.
|
|
|
|
|
|
class Int4MultirangeDumper(MultirangeDumper):
    """Text dumper always using the ``int4multirange`` oid."""

    oid = postgres.types["int4multirange"].oid
|
|
|
|
|
|
class Int8MultirangeDumper(MultirangeDumper):
    """Text dumper always using the ``int8multirange`` oid."""

    oid = postgres.types["int8multirange"].oid
|
|
|
|
|
|
class NumericMultirangeDumper(MultirangeDumper):
    """Text dumper always using the ``nummultirange`` oid."""

    oid = postgres.types["nummultirange"].oid
|
|
|
|
|
|
class DateMultirangeDumper(MultirangeDumper):
    """Text dumper always using the ``datemultirange`` oid."""

    oid = postgres.types["datemultirange"].oid
|
|
|
|
|
|
class TimestampMultirangeDumper(MultirangeDumper):
    """Text dumper always using the ``tsmultirange`` oid."""

    oid = postgres.types["tsmultirange"].oid
|
|
|
|
|
|
class TimestamptzMultirangeDumper(MultirangeDumper):
    """Text dumper always using the ``tstzmultirange`` oid."""

    oid = postgres.types["tstzmultirange"].oid
|
|
|
|
|
|
# Binary dumpers for the builtin multirange wrapper types
|
|
# These are registered on specific subtypes so that the upgrade mechanism
|
|
# doesn't kick in.
|
|
|
|
|
|
class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper always using the ``int4multirange`` oid."""

    oid = postgres.types["int4multirange"].oid
|
|
|
|
|
|
class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper always using the ``int8multirange`` oid."""

    oid = postgres.types["int8multirange"].oid
|
|
|
|
|
|
class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper always using the ``nummultirange`` oid."""

    oid = postgres.types["nummultirange"].oid
|
|
|
|
|
|
class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper always using the ``datemultirange`` oid."""

    oid = postgres.types["datemultirange"].oid
|
|
|
|
|
|
class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper always using the ``tsmultirange`` oid."""

    oid = postgres.types["tsmultirange"].oid
|
|
|
|
|
|
class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper always using the ``tstzmultirange`` oid."""

    oid = postgres.types["tstzmultirange"].oid
|
|
|
|
|
|
# Text loaders for builtin multirange types
|
|
|
|
|
|
class Int4MultirangeLoader(MultirangeLoader[int]):
    """Text loader for multiranges whose bounds are ``int4``."""

    subtype_oid = postgres.types["int4"].oid
|
|
|
|
|
|
class Int8MultirangeLoader(MultirangeLoader[int]):
    """Text loader for multiranges whose bounds are ``int8``."""

    subtype_oid = postgres.types["int8"].oid
|
|
|
|
|
|
class NumericMultirangeLoader(MultirangeLoader[Decimal]):
    """Text loader for multiranges whose bounds are ``numeric``."""

    subtype_oid = postgres.types["numeric"].oid
|
|
|
|
|
|
class DateMultirangeLoader(MultirangeLoader[date]):
    """Text loader for multiranges whose bounds are ``date``."""

    subtype_oid = postgres.types["date"].oid
|
|
|
|
|
|
class TimestampMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader for multiranges whose bounds are ``timestamp``."""

    subtype_oid = postgres.types["timestamp"].oid
|
|
|
|
|
|
class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader for multiranges whose bounds are ``timestamptz``."""

    subtype_oid = postgres.types["timestamptz"].oid
|
|
|
|
|
|
# Binary loaders for builtin multirange types
|
|
|
|
|
|
class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader for multiranges whose bounds are ``int4``."""

    subtype_oid = postgres.types["int4"].oid
|
|
|
|
|
|
class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader for multiranges whose bounds are ``int8``."""

    subtype_oid = postgres.types["int8"].oid
|
|
|
|
|
|
class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
    """Binary loader for multiranges whose bounds are ``numeric``."""

    subtype_oid = postgres.types["numeric"].oid
|
|
|
|
|
|
class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
    """Binary loader for multiranges whose bounds are ``date``."""

    subtype_oid = postgres.types["date"].oid
|
|
|
|
|
|
class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader for multiranges whose bounds are ``timestamp``."""

    subtype_oid = postgres.types["timestamp"].oid
|
|
|
|
|
|
class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader for multiranges whose bounds are ``timestamptz``."""

    subtype_oid = postgres.types["timestamptz"].oid
|
|
|
|
|
|
def register_default_adapters(context: AdaptContext) -> None:
    """Register the dumpers and loaders of the builtin multirange types.

    The adapters are registered on ``context.adapters``.
    """
    adapters = context.adapters

    # NOTE(review): the registration order below is kept exactly as it was;
    # it may decide which dumper is preferred for a class - confirm before
    # reordering.
    dumpers = (
        (Multirange, MultirangeBinaryDumper),
        (Multirange, MultirangeDumper),
        (Int4Multirange, Int4MultirangeDumper),
        (Int8Multirange, Int8MultirangeDumper),
        (NumericMultirange, NumericMultirangeDumper),
        (DateMultirange, DateMultirangeDumper),
        (TimestampMultirange, TimestampMultirangeDumper),
        (TimestamptzMultirange, TimestamptzMultirangeDumper),
        (Int4Multirange, Int4MultirangeBinaryDumper),
        (Int8Multirange, Int8MultirangeBinaryDumper),
        (NumericMultirange, NumericMultirangeBinaryDumper),
        (DateMultirange, DateMultirangeBinaryDumper),
        (TimestampMultirange, TimestampMultirangeBinaryDumper),
        (TimestamptzMultirange, TimestamptzMultirangeBinaryDumper),
    )
    for py_type, dumper_cls in dumpers:
        adapters.register_dumper(py_type, dumper_cls)

    loaders = (
        ("int4multirange", Int4MultirangeLoader),
        ("int8multirange", Int8MultirangeLoader),
        ("nummultirange", NumericMultirangeLoader),
        ("datemultirange", DateMultirangeLoader),
        ("tsmultirange", TimestampMultirangeLoader),
        ("tstzmultirange", TimestampTZMultirangeLoader),
        ("int4multirange", Int4MultirangeBinaryLoader),
        ("int8multirange", Int8MultirangeBinaryLoader),
        ("nummultirange", NumericMultirangeBinaryLoader),
        ("datemultirange", DateMultirangeBinaryLoader),
        ("tsmultirange", TimestampMultirangeBinaryLoader),
        ("tstzmultirange", TimestampTZMultirangeBinaryLoader),
    )
    for pg_type_name, loader_cls in loaders:
        adapters.register_loader(pg_type_name, loader_cls)
|