"""
Support for multirange types adaptation.
"""

# Copyright (C) 2021 The Psycopg Team

from decimal import Decimal
from typing import Any, Generic, List, Iterable
from typing import MutableSequence, Optional, Type, Union, overload
from datetime import date, datetime

from .. import errors as e
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Buffer, Dumper, DumperKey
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._struct import pack_len, unpack_len
from ..postgres import INVALID_OID, TEXT_OID
from .._typeinfo import MultirangeInfo as MultirangeInfo  # exported here

from .range import Range, T, load_range_text, load_range_binary
from .range import dump_range_text, dump_range_binary, fail_dump


class Multirange(MutableSequence[Range[T]]):
    """Python representation for a PostgreSQL multirange type.

    The object is a mutable sequence whose items must all be `Range`
    instances; every insertion path validates the items it receives.

    :param items: Sequence of ranges to initialise the object.
    :raises TypeError: if any item is not a `Range`.
    """

    def __init__(self, items: Iterable[Range[T]] = ()):
        # Validate eagerly so that the internal list only holds Range objects.
        self._ranges: List[Range[T]] = list(map(self._check_type, items))

    def _check_type(self, item: Any) -> Range[Any]:
        """Return *item* unchanged if it is a `Range`, else raise `TypeError`."""
        if not isinstance(item, Range):
            raise TypeError(
                f"Multirange is a sequence of Range, got {type(item).__name__}"
            )
        return item

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._ranges!r})"

    def __str__(self) -> str:
        # Comma-separated str() of each range, wrapped in curly braces.
        return f"{{{', '.join(map(str, self._ranges))}}}"

    @overload
    def __getitem__(self, index: int) -> Range[T]:
        ...

    @overload
    def __getitem__(self, index: slice) -> "Multirange[T]":
        ...

    def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
        """Return the range at *index*, or a new `Multirange` for a slice."""
        if isinstance(index, int):
            return self._ranges[index]
        else:
            return Multirange(self._ranges[index])

    def __len__(self) -> int:
        return len(self._ranges)

    @overload
    def __setitem__(self, index: int, value: Range[T]) -> None:
        ...

    @overload
    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
        ...

    def __setitem__(
        self,
        index: Union[int, slice],
        value: Union[Range[T], Iterable[Range[T]]],
    ) -> None:
        """Replace one item (int index) or a run of items (slice index).

        :raises TypeError: if an assigned item is not a `Range`, or if a
            non-iterable value is assigned to a slice.
        """
        if isinstance(index, int):
            # A single validation is enough: _check_type() returns the item.
            self._ranges[index] = self._check_type(value)
        elif not isinstance(value, Iterable):
            raise TypeError("can only assign an iterable")
        else:
            # map() is lazy: each element is validated as the list consumes it.
            self._ranges[index] = map(self._check_type, value)

    def __delitem__(self, index: Union[int, slice]) -> None:
        del self._ranges[index]

    def insert(self, index: int, value: Range[T]) -> None:
        """Insert *value* before *index* after checking it is a `Range`."""
        self._ranges.insert(index, self._check_type(value))

    def __eq__(self, other: Any) -> bool:
        # Objects that are not a Multirange never compare equal.
        if not isinstance(other, Multirange):
            return False
        return self._ranges == other._ranges

    # Order is arbitrary but consistent

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges < other._ranges

    def __le__(self, other: Any) -> bool:
        return self == other or self < other  # type: ignore

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges > other._ranges

    def __ge__(self, other: Any) -> bool:
        return self == other or self > other  # type: ignore


# Subclasses to specify a specific subtype.
# Usually not needed


class Int4Multirange(Multirange[int]):
    pass


class Int8Multirange(Multirange[int]):
    pass


class NumericMultirange(Multirange[Decimal]):
    pass


class DateMultirange(Multirange[date]):
    pass


class TimestampMultirange(Multirange[datetime]):
    pass


class TimestamptzMultirange(Multirange[datetime]):
    pass


class BaseMultirangeDumper(RecursiveDumper):
    """Logic shared by the text and binary multirange dumpers.

    When the dumped class is the generic `Multirange`, the dumper inspects
    the first available bound of the object to select a dumper for the
    elements and an oid for the multirange (see `upgrade()`).
    """

    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
        super().__init__(cls, context)
        # Dumper for the range bounds; assigned by upgrade().
        self.sub_dumper: Optional[Dumper] = None
        # Format used to look up the dumper for the bounds.
        self._adapt_format = PyFormat.from_pq(self.format)

    def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
        """Return the key identifying the dumper to use for *obj*.

        The key is the dumped class alone for a `Multirange` subclass;
        otherwise it is a tuple of the class and, when *obj* has at least
        one bound, the key of the dumper for that bound.
        """
        # If we are a subclass whose oid is specified we don't need upgrade
        if self.cls is not Multirange:
            return self.cls

        item = self._get_item(obj)
        if item is not None:
            sd = self._tx.get_dumper(item, self._adapt_format)
            return (self.cls, sd.get_key(item, format))
        else:
            return (self.cls,)

    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
        """Return a dumper configured for the type of the bounds of *obj*.

        Return `!self` for a `Multirange` subclass. Otherwise return a new
        dumper whose `!sub_dumper` and `!oid` are derived from a
        representative bound of *obj*.
        """
        # If we are a subclass whose oid is specified we don't need upgrade
        if self.cls is not Multirange:
            return self

        item = self._get_item(obj)
        if item is None:
            # No bound to inspect: fall back on a plain text dumper.
            return MultirangeDumper(self.cls)

        dumper: BaseMultirangeDumper
        if type(item) is int:
            # postgres won't cast int4range -> int8range so we must use
            # text format and unknown oid here
            sd = self._tx.get_dumper(item, PyFormat.TEXT)
            dumper = MultirangeDumper(self.cls, self._tx)
            dumper.sub_dumper = sd
            dumper.oid = INVALID_OID
            return dumper

        sd = self._tx.get_dumper(item, format)
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd
        if sd.oid == INVALID_OID and isinstance(item, str):
            # Work around the normal mapping where text is dumped as unknown
            dumper.oid = self._get_multirange_oid(TEXT_OID)
        else:
            dumper.oid = self._get_multirange_oid(sd.oid)

        return dumper

    def _get_item(self, obj: Multirange[Any]) -> Any:
        """
        Return a member representative of the multirange

        This is the first bound which is not `!None`, checking the lower
        bound before the upper one of each range; `!None` if there is none.
        """
        for r in obj:
            if r.lower is not None:
                return r.lower
            if r.upper is not None:
                return r.upper
        return None

    def _get_multirange_oid(self, sub_oid: int) -> int:
        """
        Return the oid of the multirange from the oid of its elements.

        Return `!INVALID_OID` if no multirange type with that subtype is
        known to the adapters.
        """
        info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
        return info.oid if info else INVALID_OID


class MultirangeDumper(BaseMultirangeDumper):
    """
    Dumper for multirange types.

    The dumper can upgrade to one specific for a different range type.
    """

    def dump(self, obj: Multirange[Any]) -> Buffer:
        """Return *obj* as comma-separated dumped ranges between braces."""
        if not obj:
            return b"{}"

        item = self._get_item(obj)
        if item is not None:
            dump = self._tx.get_dumper(item, self._adapt_format).dump
        else:
            # No bound available to choose a dumper for the elements.
            dump = fail_dump

        out: List[Buffer] = [b"{"]
        for r in obj:
            out.append(dump_range_text(r, dump))
            out.append(b",")
        # obj is not empty here, so the last chunk is the trailing comma:
        # replace it with the closing brace.
        out[-1] = b"}"
        return b"".join(out)


class MultirangeBinaryDumper(BaseMultirangeDumper):
    """Dumper for multirange types in binary format."""

    format = Format.BINARY

    def dump(self, obj: Multirange[Any]) -> Buffer:
        """Return the number of ranges followed by each length-prefixed range."""
        item = self._get_item(obj)
        if item is not None:
            dump = self._tx.get_dumper(item, self._adapt_format).dump
        else:
            # No bound available to choose a dumper for the elements.
            dump = fail_dump

        out: List[Buffer] = [pack_len(len(obj))]
        for r in obj:
            data = dump_range_binary(r, dump)
            out.append(pack_len(len(data)))
            out.append(data)
        return b"".join(out)


class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
    """Base class for multirange loaders.

    Concrete loaders must define `!subtype_oid`, which is used to look up
    the loader for the range bounds.
    """

    # Oid of the type of the range bounds; set by subclasses.
    subtype_oid: int

    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
        super().__init__(oid, context)
        self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load


class MultirangeLoader(BaseMultirangeLoader[T]):
    """Loader for multirange types in text format."""

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse *data* into a `Multirange`.

        :raises DataError: if *data* doesn't start with an open brace, if a
            separator is missing or unexpected, or if there is data after
            the closing brace.
        """
        if not data or data[0] != _START_INT:
            raise e.DataError(
                "malformed multirange starting with"
                f" {bytes(data[:1]).decode('utf8', 'replace')}"
            )

        out = Multirange[T]()
        if data == b"{}":
            return out

        # Skip the opening brace.
        pos = 1
        data = data[pos:]
        try:
            # The loop only terminates by returning or raising.
            while True:
                # NOTE(review): pos is presumably the offset just past the
                # parsed range in data -- confirm against load_range_text().
                r, pos = load_range_text(data, self._load)
                out.append(r)

                sep = data[pos]  # can raise IndexError
                if sep == _SEP_INT:
                    # Drop the parsed range and the comma, then parse again.
                    data = data[pos + 1 :]
                    continue
                elif sep == _END_INT:
                    if len(data) == pos + 1:
                        return out
                    else:
                        raise e.DataError(
                            "malformed multirange: data after closing brace"
                        )
                else:
                    raise e.DataError(
                        f"malformed multirange: found unexpected {chr(sep)}"
                    )

        except IndexError:
            raise e.DataError("malformed multirange: separator missing")


# Byte values of the separators used by the text representation.
_SEP_INT = ord(",")
_START_INT = ord("{")
_END_INT = ord("}")


class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
    """Loader for multirange types in binary format."""

    format = Format.BINARY

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse *data* into a `Multirange`.

        :raises DataError: if bytes are left after the last range.
        """
        nelems = unpack_len(data, 0)[0]
        # Each length field takes 4 bytes; the first one is the range count.
        pos = 4
        out = Multirange[T]()
        for _ in range(nelems):
            length = unpack_len(data, pos)[0]
            pos += 4
            out.append(load_range_binary(data[pos : pos + length], self._load))
            pos += length

        if pos != len(data):
            raise e.DataError("unexpected trailing data in multirange")

        return out


def register_multirange(
    info: MultirangeInfo, context: Optional[AdaptContext] = None
) -> None:
    """Register the adapters to load and dump a multirange type.

    :param info: The object with the information about the multirange to
        register.
    :param context: The context where to register the adapters. If `!None`,
        register it globally.
    :raises TypeError: if *info* is false (e.g. `!None`).

    Register loaders so that loading data of this type will result in a
    `Multirange` with bounds parsed as the right subtype.

    .. note::

        Registering the adapters doesn't affect objects already created, even
        if they are children of the registered context. For instance,
        registering the adapter globally doesn't affect already existing
        connections.
    """
    # A friendly error warning instead of an AttributeError in case fetch()
    # failed and it wasn't noticed.
    if not info:
        raise TypeError("no info passed. Is the requested multirange available?")

    # Register arrays and type info
    info.register(context)

    adapters = context.adapters if context else postgres.adapters

    # generate and register a customized text loader
    loader: Type[MultirangeLoader[Any]] = type(
        f"{info.name.title()}Loader",
        (MultirangeLoader,),
        {"subtype_oid": info.subtype_oid},
    )
    adapters.register_loader(info.oid, loader)

    # generate and register a customized binary loader
    bloader: Type[MultirangeBinaryLoader[Any]] = type(
        f"{info.name.title()}BinaryLoader",
        (MultirangeBinaryLoader,),
        {"subtype_oid": info.subtype_oid},
    )
    adapters.register_loader(info.oid, bloader)


# Text dumpers for builtin multirange types wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.


class Int4MultirangeDumper(MultirangeDumper):
    oid = postgres.types["int4multirange"].oid


class Int8MultirangeDumper(MultirangeDumper):
    oid = postgres.types["int8multirange"].oid


class NumericMultirangeDumper(MultirangeDumper):
    oid = postgres.types["nummultirange"].oid


class DateMultirangeDumper(MultirangeDumper):
    oid = postgres.types["datemultirange"].oid


class TimestampMultirangeDumper(MultirangeDumper):
    oid = postgres.types["tsmultirange"].oid


class TimestamptzMultirangeDumper(MultirangeDumper):
    oid = postgres.types["tstzmultirange"].oid


# Binary dumpers for builtin multirange types wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.


class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper bound to the oid of ``int4multirange``."""

    oid = postgres.types["int4multirange"].oid


class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper bound to the oid of ``int8multirange``."""

    oid = postgres.types["int8multirange"].oid


class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper bound to the oid of ``nummultirange``."""

    oid = postgres.types["nummultirange"].oid


class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper bound to the oid of ``datemultirange``."""

    oid = postgres.types["datemultirange"].oid


class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper bound to the oid of ``tsmultirange``."""

    oid = postgres.types["tsmultirange"].oid


class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper bound to the oid of ``tstzmultirange``."""

    oid = postgres.types["tstzmultirange"].oid


# Text loaders for builtin multirange types


class Int4MultirangeLoader(MultirangeLoader[int]):
    """Text loader whose bounds are loaded as ``int4``."""

    subtype_oid = postgres.types["int4"].oid


class Int8MultirangeLoader(MultirangeLoader[int]):
    """Text loader whose bounds are loaded as ``int8``."""

    subtype_oid = postgres.types["int8"].oid


class NumericMultirangeLoader(MultirangeLoader[Decimal]):
    """Text loader whose bounds are loaded as ``numeric``."""

    subtype_oid = postgres.types["numeric"].oid


class DateMultirangeLoader(MultirangeLoader[date]):
    """Text loader whose bounds are loaded as ``date``."""

    subtype_oid = postgres.types["date"].oid


class TimestampMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader whose bounds are loaded as ``timestamp``."""

    subtype_oid = postgres.types["timestamp"].oid


class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader whose bounds are loaded as ``timestamptz``."""

    subtype_oid = postgres.types["timestamptz"].oid


# Binary loaders for builtin multirange types


class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader whose bounds are loaded as ``int4``."""

    subtype_oid = postgres.types["int4"].oid


class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader whose bounds are loaded as ``int8``."""

    subtype_oid = postgres.types["int8"].oid


class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
    """Binary loader whose bounds are loaded as ``numeric``."""

    subtype_oid = postgres.types["numeric"].oid


class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
    """Binary loader whose bounds are loaded as ``date``."""

    subtype_oid = postgres.types["date"].oid


class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader whose bounds are loaded as ``timestamp``."""

    subtype_oid = postgres.types["timestamp"].oid


class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader whose bounds are loaded as ``timestamptz``."""

    subtype_oid = postgres.types["timestamptz"].oid


def register_default_adapters(context: AdaptContext) -> None:
    """Register the builtin multirange dumpers and loaders on *context*.

    Every dumper is registered before any loader, and the entries of each
    table are registered from top to bottom.
    """
    registry = context.adapters

    # (python class, dumper class) pairs, in registration order.
    dumper_table = (
        (Multirange, MultirangeBinaryDumper),
        (Multirange, MultirangeDumper),
        (Int4Multirange, Int4MultirangeDumper),
        (Int8Multirange, Int8MultirangeDumper),
        (NumericMultirange, NumericMultirangeDumper),
        (DateMultirange, DateMultirangeDumper),
        (TimestampMultirange, TimestampMultirangeDumper),
        (TimestamptzMultirange, TimestamptzMultirangeDumper),
        (Int4Multirange, Int4MultirangeBinaryDumper),
        (Int8Multirange, Int8MultirangeBinaryDumper),
        (NumericMultirange, NumericMultirangeBinaryDumper),
        (DateMultirange, DateMultirangeBinaryDumper),
        (TimestampMultirange, TimestampMultirangeBinaryDumper),
        (TimestamptzMultirange, TimestamptzMultirangeBinaryDumper),
    )
    for py_class, dumper_class in dumper_table:
        registry.register_dumper(py_class, dumper_class)

    # (postgres type name, loader class) pairs, in registration order.
    loader_table = (
        ("int4multirange", Int4MultirangeLoader),
        ("int8multirange", Int8MultirangeLoader),
        ("nummultirange", NumericMultirangeLoader),
        ("datemultirange", DateMultirangeLoader),
        ("tsmultirange", TimestampMultirangeLoader),
        ("tstzmultirange", TimestampTZMultirangeLoader),
        ("int4multirange", Int4MultirangeBinaryLoader),
        ("int8multirange", Int8MultirangeBinaryLoader),
        ("nummultirange", NumericMultirangeBinaryLoader),
        ("datemultirange", DateMultirangeBinaryLoader),
        ("tsmultirange", TimestampMultirangeBinaryLoader),
        ("tstzmultirange", TimestampTZMultirangeBinaryLoader),
    )
    for type_name, loader_class in loader_table:
        registry.register_loader(type_name, loader_class)