mysteriendrama/lib/python3.11/site-packages/psycopg/types/array.py
2023-07-26 21:33:29 +02:00

461 lines
14 KiB
Python

"""
Adapters for arrays
"""
# Copyright (C) 2020 The Psycopg Team
import re
import struct
from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
from .. import pq
from .. import errors as e
from .. import postgres
from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._compat import cache, prod
from .._struct import pack_len, unpack_len
from .._cmodule import _psycopg
from ..postgres import TEXT_OID, INVALID_OID
from .._typeinfo import TypeInfo
_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
_struct_dim = struct.Struct("!II") # dim, lower bound
_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
TEXT_ARRAY_OID = postgres.types["text"].array_oid
PY_TEXT = PyFormat.TEXT
PQ_BINARY = pq.Format.BINARY
class BaseListDumper(RecursiveDumper):
    """Base class for list dumpers: resolve which dumper handles the elements."""
    # Oid of the array elements; 0 = unknown.  The subclasses generated by
    # register_array() override it with the base type's oid.
    element_oid = 0
    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
        # NoneType is passed when the dumper is looked up by oid only, with
        # no Python value available: behave as a generic list dumper then.
        if cls is NoneType:
            cls = list
        super().__init__(cls, context)
        self.sub_dumper: Optional[Dumper] = None
        if self.element_oid and context:
            # The element type is known upfront: fix the element dumper now.
            sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
            self.sub_dumper = sdclass(NoneType, context)
    def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
        """
        Find the first non-null element of an eventually nested list

        Return None if the list contains (recursively) no non-null item.
        For ints, return the value needing the widest representation, so the
        caller can choose an oid large enough for every item (see below).

        Raise DataError if the items would dump with different oids.
        """
        items = list(self._flatiter(L, set()))
        types = {type(item): item for item in items}
        if not types:
            return None
        if len(types) == 1:
            t, v = types.popitem()
        else:
            # More than one type in the list. It might be still good, as long
            # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
            dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
            oids = set(d.oid for d in dumpers)
            if len(oids) == 1:
                t, v = types.popitem()
            else:
                raise e.DataError(
                    "cannot dump lists of mixed types;"
                    f" got: {', '.join(sorted(t.__name__ for t in types))}"
                )
        # Checking for precise type. If the type is a subclass (e.g. Int4)
        # we assume the user knows what type they are passing.
        if t is not int:
            return v
        # If we got an int, let's see what is the biggest one in order to
        # choose the smallest OID and allow Postgres to do the right cast.
        imax: int = max(items)
        imin: int = min(items)
        if imin >= 0:
            return imax
        else:
            # All-or-some negatives: -imin - 1 needs the same number of bits
            # as imin, so return whichever magnitude is the widest.
            return max(imax, -imin - 1)
    def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
        # Yield the non-None scalar items of a nested list, depth-first.
        # Only plain lists are recursed into (list subclasses are yielded as
        # items).  `seen` holds ids of visited lists to detect recursion.
        if id(L) in seen:
            raise e.DataError("cannot dump a recursive list")
        seen.add(id(L))
        for item in L:
            if type(item) is list:
                yield from self._flatiter(item, seen)
            elif item is not None:
                yield item
        return None
    def _get_base_type_info(self, base_oid: int) -> TypeInfo:
        """
        Return info about the base type.
        Return text info as fallback.
        """
        if base_oid:
            info = self._tx.adapters.types.get(base_oid)
            if info:
                return info
        return self._tx.adapters.types["text"]
class ListDumper(BaseListDumper):
    """Dump a Python list as a PostgreSQL array in text format."""
    # Element separator; register_array() subclasses override it with the
    # base type's delimiter.
    delimiter = b","
    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
        """Return the key used by the transformer to cache this dumper."""
        if self.oid:
            return self.cls
        # No oid yet: make the key depend on the element type too, so lists
        # of different types get different (upgraded) dumpers.
        item = self._find_list_element(obj, format)
        if item is None:
            return self.cls
        sd = self._tx.get_dumper(item, format)
        return (self.cls, sd.get_key(item, format))
    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
        """Return a dumper specialized on the type of obj's elements."""
        # If we have an oid we don't need to upgrade
        if self.oid:
            return self
        item = self._find_list_element(obj, format)
        if item is None:
            # Empty lists can only be dumped as text if the type is unknown.
            return self
        sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd
        # We consider an array of unknowns as unknown, so we can dump empty
        # lists or lists containing only None elements.
        if sd.oid != INVALID_OID:
            info = self._get_base_type_info(sd.oid)
            dumper.oid = info.array_oid or TEXT_ARRAY_OID
            dumper.delimiter = info.delimiter.encode()
        else:
            dumper.oid = INVALID_OID
        return dumper
    # Double quotes and backslashes embedded in element values will be
    # backslash-escaped.
    _re_esc = re.compile(rb'(["\\])')
    def dump(self, obj: List[Any]) -> bytes:
        """Return the text representation of obj, e.g. b'{1,NULL,{2}}'."""
        tokens: List[Buffer] = []
        needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
        def dump_list(obj: List[Any]) -> None:
            if not obj:
                tokens.append(b"{}")
                return
            tokens.append(b"{")
            for item in obj:
                if isinstance(item, list):
                    dump_list(item)
                elif item is not None:
                    ad = self._dump_item(item)
                    if needs_quotes(ad):
                        if not isinstance(ad, bytes):
                            ad = bytes(ad)  # escaping needs a concrete bytes object
                        ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
                    tokens.append(ad)
                else:
                    tokens.append(b"NULL")
                # Separator after every item; the last one is overwritten
                # by the closing brace below.
                tokens.append(self.delimiter)
            tokens[-1] = b"}"
        dump_list(obj)
        return b"".join(tokens)
    def _dump_item(self, item: Any) -> Buffer:
        # Use the fixed element dumper if known, else look one up per item.
        if self.sub_dumper:
            return self.sub_dumper.dump(item)
        else:
            return self._tx.get_dumper(item, PY_TEXT).dump(item)
@cache
def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
"""Return a regexp to recognise when a value needs quotes
from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
The array output routine will put double quotes around element values if
they are empty strings, contain curly braces, delimiter characters,
double quotes, backslashes, or white space, or match the word NULL.
"""
return re.compile(
rb"""(?xi)
^$ # the empty string
| ["{}%s\\\s] # or a char to escape
| ^null$ # or the word NULL
"""
% delimiter
)
class ListBinaryDumper(BaseListDumper):
    """Dump a Python list as a PostgreSQL array in binary format."""
    format = pq.Format.BINARY
    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
        """Return the key used by the transformer to cache this dumper."""
        if self.oid:
            return self.cls
        item = self._find_list_element(obj, format)
        if item is None:
            return (self.cls,)
        sd = self._tx.get_dumper(item, format)
        return (self.cls, sd.get_key(item, format))
    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
        """Return a dumper specialized on the type of obj's elements."""
        # If we have an oid we don't need to upgrade
        if self.oid:
            return self
        item = self._find_list_element(obj, format)
        if item is None:
            # Unknown element type: fall back on the text dumper, which can
            # represent an empty/all-None list without knowing the oid.
            return ListDumper(self.cls, self._tx)
        sd = self._tx.get_dumper(item, format.from_pq(self.format))
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd
        info = self._get_base_type_info(sd.oid)
        dumper.oid = info.array_oid or TEXT_ARRAY_OID
        return dumper
    def dump(self, obj: List[Any]) -> bytes:
        """Return the binary representation of obj.

        Layout: 12-byte header (ndims, hasnull, element oid), one
        (length, lower bound) pair per dimension, then each element as a
        4-byte length (-1 for NULL) followed by its data.
        """
        # Postgres won't take unknown for element oid: fall back on text
        sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
        if not obj:
            return _pack_head(0, 0, sub_oid)
        data: List[Buffer] = [b"", b""]  # placeholders to avoid a resize
        dims: List[int] = []
        hasnull = 0
        def calc_dims(L: List[Any]) -> None:
            # Walk the first branch of the nested lists to collect each
            # dimension's length; dump_list() verifies the other branches.
            if isinstance(L, self.cls):
                if not L:
                    raise e.DataError("lists cannot contain empty lists")
                dims.append(len(L))
                calc_dims(L[0])
        calc_dims(obj)
        def dump_list(L: List[Any], dim: int) -> None:
            nonlocal hasnull
            if len(L) != dims[dim]:
                raise e.DataError("nested lists have inconsistent lengths")
            if dim == len(dims) - 1:
                # Innermost dimension: emit the elements themselves.
                for item in L:
                    if item is not None:
                        # If we get here, the sub_dumper must have been set
                        ad = self.sub_dumper.dump(item)  # type: ignore[union-attr]
                        data.append(pack_len(len(ad)))
                        data.append(ad)
                    else:
                        hasnull = 1
                        data.append(b"\xff\xff\xff\xff")  # length -1 means NULL
            else:
                for item in L:
                    if not isinstance(item, self.cls):
                        raise e.DataError("nested lists have inconsistent depths")
                    dump_list(item, dim + 1)  # type: ignore
        dump_list(obj, 0)
        # Backfill the header and dimensions placeholders (lower bound is 1).
        data[0] = _pack_head(len(dims), hasnull, sub_oid)
        data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
        return b"".join(data)
class ArrayLoader(RecursiveLoader):
    """Load a PostgreSQL array in text format into a Python list."""

    # Both attributes are overridden by the per-type subclasses created
    # in register_array().
    delimiter = b","
    base_oid: int

    def load(self, data: Buffer) -> List[Any]:
        element_loader = self._tx.get_loader(self.base_oid, self.format)
        return _load_text(data, element_loader, self.delimiter)
class ArrayBinaryLoader(RecursiveLoader):
    """Load a PostgreSQL array in binary format into a Python list."""

    format = pq.Format.BINARY

    def load(self, data: Buffer) -> List[Any]:
        # The element oid travels in the binary payload itself, so no
        # per-type subclass is needed: delegate to the module helper.
        return _load_binary(data, self._tx)
def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
    """Register text/binary loaders and dumpers for the array type in `info`.

    :param info: description of the array's base type (must have array_oid).
    :param context: where to register the adapters; default to the global ones.
    :raises ValueError: if `info` doesn't describe an array.
    """
    if not info.array_oid:
        raise ValueError(f"the type info {info} doesn't describe an array")

    adapters = context.adapters if context else postgres.adapters
    delimiter = info.delimiter.encode()

    # Text loader: a per-type subclass carrying the base oid and delimiter
    # (use the implementation from the _psycopg module when available).
    text_loader_base: Type[Any] = getattr(_psycopg, "ArrayLoader", ArrayLoader)
    text_loader = type(
        f"{info.name.title()}{text_loader_base.__name__}",
        (text_loader_base,),
        {"base_oid": info.oid, "delimiter": delimiter},
    )
    adapters.register_loader(info.array_oid, text_loader)

    # Binary loader: the element oid is in the payload, no subclass needed.
    adapters.register_loader(
        info.array_oid, getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
    )

    # Text dumper subclass for lists of this base type.
    text_dumper = type(
        f"{info.name.title()}{ListDumper.__name__}",
        (ListDumper,),
        {"oid": info.array_oid, "element_oid": info.oid, "delimiter": delimiter},
    )
    adapters.register_dumper(None, text_dumper)

    # Binary dumper subclass (no delimiter in the binary format).
    binary_dumper = type(
        f"{info.name.title()}{ListBinaryDumper.__name__}",
        (ListBinaryDumper,),
        {"oid": info.array_oid, "element_oid": info.oid},
    )
    adapters.register_dumper(None, binary_dumper)
def register_default_adapters(context: AdaptContext) -> None:
    """Register the generic list dumpers on `context`."""
    # The text dumper is more flexible as it can handle lists of mixed type,
    # so register it later.
    context.adapters.register_dumper(list, ListBinaryDumper)
    context.adapters.register_dumper(list, ListDumper)
def register_all_arrays(context: AdaptContext) -> None:
    """
    Associate the array oid of all the types in Loader.globals.

    This function is designed to be called once at import time, after having
    registered all the base loaders.
    """
    for info in context.adapters.types:
        if info.array_oid:
            info.register(context)
def _load_text(
    data: Buffer,
    loader: Loader,
    delimiter: bytes = b",",
    # Compiled once as a default argument: matches a backslash escape pair.
    __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
) -> List[Any]:
    """Parse the text representation of an array into a (nested) list.

    :param data: the array literal (e.g. b"{1,2,NULL}"), possibly prefixed
        by the dimensions information (e.g. b"[0:2]={...}").
    :param loader: the loader used to convert each element token.
    :param delimiter: the element separator of the array's base type.
    :raises e.DataError: if the input is malformed.
    """
    rv = None
    stack: List[Any] = []
    a: List[Any] = []
    rv = a
    load = loader.load

    # Remove the dimensions information prefix (``[...]=``)
    if data and data[0] == b"["[0]:
        if isinstance(data, memoryview):
            data = bytes(data)
        idx = data.find(b"=")
        if idx == -1:
            raise e.DataError("malformed array: no '=' after dimension information")
        data = data[idx + 1 :]

    re_parse = _get_array_parse_regexp(delimiter)
    for m in re_parse.finditer(data):
        t = m.group(1)
        if t == b"{":
            # Open a new sublist: `a` (created at the previous brace or at
            # init) becomes current; prepare the next one.
            if stack:
                stack[-1].append(a)
            stack.append(a)
            a = []
        elif t == b"}":
            if not stack:
                raise e.DataError("malformed array: unexpected '}'")
            rv = stack.pop()
        else:
            if not stack:
                # Bug fix: short tokens used to produce an empty message
                # ("unexpected ''") because the `+ "..."` bound inside the
                # if-arm of the conditional and the else-arm was just "".
                if len(t) > 10:
                    wat = t[:10].decode("utf8", "replace") + "..."
                else:
                    wat = t.decode("utf8", "replace")
                raise e.DataError(f"malformed array: unexpected '{wat}'")
            if t == b"NULL":
                v = None
            else:
                if t.startswith(b'"'):
                    # Strip the quotes and resolve backslash escapes.
                    t = __re_unescape.sub(rb"\1", t[1:-1])
                v = load(t)
            stack[-1].append(v)

    assert rv is not None
    return rv
@cache
def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
"""
Return a regexp to tokenize an array representation into item and brackets
"""
return re.compile(
rb"""(?xi)
( [{}] # open or closed bracket
| " (?: [^"\\] | \\. )* " # or a quoted string
| [^"{}%s\\]+ # or an unquoted non-empty string
) ,?
"""
% delimiter
)
def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
    """Parse the binary representation of an array into a (nested) list.

    Layout: 12-byte header (ndims, hasnull, element oid), one
    (length, lower bound) pair per dimension, then each element as a
    4-byte length (-1 for NULL) followed by its data.
    """
    ndims, hasnull, oid = _unpack_head(data)
    load = tx.get_loader(oid, PQ_BINARY).load
    if not ndims:
        return []
    # Dimensions start at offset 12; each entry is 8 bytes, of which we only
    # read the length (the lower bound is ignored).  `p` ends up pointing at
    # the first element.
    p = 12 + 8 * ndims
    dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
    nelems = prod(dims)
    out: List[Any] = [None] * nelems
    for i in range(nelems):
        size = unpack_len(data, p)[0]
        p += 4
        if size == -1:
            # NULL element: keep the preallocated None in place.
            continue
        out[i] = load(data[p : p + size])
        p += size
    # for ndims > 1 we have to aggregate the flat list into nested sub-lists,
    # grouping by the innermost dimension first
    for dim in dims[-1:0:-1]:
        out = [out[i : i + dim] for i in range(0, len(out), dim)]
    return out