461 lines
14 KiB
Python
461 lines
14 KiB
Python
"""
|
|
Adapters for arrays
|
|
"""
|
|
|
|
# Copyright (C) 2020 The Psycopg Team
|
|
|
|
import re
|
|
import struct
|
|
from typing import Any, cast, Callable, List, Optional, Pattern, Set, Tuple, Type
|
|
|
|
from .. import pq
|
|
from .. import errors as e
|
|
from .. import postgres
|
|
from ..abc import AdaptContext, Buffer, Dumper, DumperKey, NoneType, Loader, Transformer
|
|
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
|
|
from .._compat import cache, prod
|
|
from .._struct import pack_len, unpack_len
|
|
from .._cmodule import _psycopg
|
|
from ..postgres import TEXT_OID, INVALID_OID
|
|
from .._typeinfo import TypeInfo
|
|
|
|
_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
|
|
_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
|
|
_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
|
|
_struct_dim = struct.Struct("!II") # dim, lower bound
|
|
_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
|
|
_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
|
|
|
|
TEXT_ARRAY_OID = postgres.types["text"].array_oid
|
|
|
|
PY_TEXT = PyFormat.TEXT
|
|
PQ_BINARY = pq.Format.BINARY
|
|
|
|
|
|
class BaseListDumper(RecursiveDumper):
    """
    Base class for the text and binary list dumpers.

    Subclasses generated by `register_array()` override `element_oid` so
    that a per-element sub-dumper can be created eagerly in `__init__()`.
    """

    # OID of the array element type; 0 means unknown (sniffed later via
    # upgrade() in the subclasses).
    element_oid = 0

    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
        # The dumper can be looked up for NoneType too (a bare None to dump
        # as array): treat that the same as a plain list.
        if cls is NoneType:
            cls = list

        super().__init__(cls, context)
        self.sub_dumper: Optional[Dumper] = None
        if self.element_oid and context:
            sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
            self.sub_dumper = sdclass(NoneType, context)

    def _find_list_element(self, L: List[Any], format: PyFormat) -> Any:
        """
        Find the first non-null element of an eventually nested list

        Raise `~psycopg.errors.DataError` if the list is recursive or mixes
        types that dump to different oids.  For ints, return the value with
        the widest representation so the smallest suitable oid is chosen.
        """
        items = list(self._flatiter(L, set()))
        # One representative value per distinct Python type found.
        types = {type(item): item for item in items}
        if not types:
            return None

        if len(types) == 1:
            t, v = types.popitem()
        else:
            # More than one type in the list. It might be still good, as long
            # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
            dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
            oids = set(d.oid for d in dumpers)
            if len(oids) == 1:
                t, v = types.popitem()
            else:
                raise e.DataError(
                    "cannot dump lists of mixed types;"
                    f" got: {', '.join(sorted(t.__name__ for t in types))}"
                )

        # Checking for precise type. If the type is a subclass (e.g. Int4)
        # we assume the user knows what type they are passing.
        if t is not int:
            return v

        # If we got an int, let's see what is the biggest one in order to
        # choose the smallest OID and allow Postgres to do the right cast.
        imax: int = max(items)
        imin: int = min(items)
        if imin >= 0:
            return imax
        else:
            # two's complement: -imin - 1 is the positive value with the same
            # bit width as imin, so the widest magnitude wins.
            return max(imax, -imin - 1)

    def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
        # Yield all the non-None scalar items of the nested list L,
        # detecting reference cycles via the ids in `seen`.
        if id(L) in seen:
            raise e.DataError("cannot dump a recursive list")

        seen.add(id(L))

        for item in L:
            if type(item) is list:
                yield from self._flatiter(item, seen)
            elif item is not None:
                yield item

        return None

    def _get_base_type_info(self, base_oid: int) -> TypeInfo:
        """
        Return info about the base type.

        Return text info as fallback.
        """
        if base_oid:
            info = self._tx.adapters.types.get(base_oid)
            if info:
                return info

        return self._tx.adapters.types["text"]
|
|
|
|
|
|
class ListDumper(BaseListDumper):
    """Dump a Python list to the PostgreSQL text array format."""

    # Element separator; per-type subclasses created by register_array()
    # override it (most types use ',', a few use ';').
    delimiter = b","

    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
        # With a fixed oid there is no upgrade, so the class alone is the key.
        if self.oid:
            return self.cls

        item = self._find_list_element(obj, format)
        if item is None:
            return self.cls

        # Include the element's key so lists of different element types get
        # different (upgraded) dumpers.
        sd = self._tx.get_dumper(item, format)
        return (self.cls, sd.get_key(item, format))

    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
        # If we have an oid we don't need to upgrade
        if self.oid:
            return self

        item = self._find_list_element(obj, format)
        if item is None:
            # Empty lists can only be dumped as text if the type is unknown.
            return self

        sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd

        # We consider an array of unknowns as unknown, so we can dump empty
        # lists or lists containing only None elements.
        if sd.oid != INVALID_OID:
            info = self._get_base_type_info(sd.oid)
            dumper.oid = info.array_oid or TEXT_ARRAY_OID
            dumper.delimiter = info.delimiter.encode()
        else:
            dumper.oid = INVALID_OID

        return dumper

    # Double quotes and backslashes embedded in element values will be
    # backslash-escaped.
    _re_esc = re.compile(rb'(["\\])')

    def dump(self, obj: List[Any]) -> bytes:
        """Serialize *obj* (an eventually nested list) to text array syntax."""
        tokens: List[Buffer] = []
        # Hoist the (cached) regexp lookup out of the recursion.
        needs_quotes = _get_needs_quotes_regexp(self.delimiter).search

        def dump_list(obj: List[Any]) -> None:
            if not obj:
                tokens.append(b"{}")
                return

            tokens.append(b"{")
            for item in obj:
                if isinstance(item, list):
                    dump_list(item)
                elif item is not None:
                    ad = self._dump_item(item)
                    if needs_quotes(ad):
                        if not isinstance(ad, bytes):
                            # re.sub needs bytes, not an arbitrary buffer.
                            ad = bytes(ad)
                        ad = b'"' + self._re_esc.sub(rb"\\\1", ad) + b'"'
                    tokens.append(ad)
                else:
                    tokens.append(b"NULL")

                tokens.append(self.delimiter)

            # Overwrite the trailing delimiter with the closing brace.
            tokens[-1] = b"}"

        dump_list(obj)

        return b"".join(tokens)

    def _dump_item(self, item: Any) -> Buffer:
        # Use the element dumper chosen by upgrade() if there is one,
        # otherwise look one up per item (text format).
        if self.sub_dumper:
            return self.sub_dumper.dump(item)
        else:
            return self._tx.get_dumper(item, PY_TEXT).dump(item)
|
|
|
|
|
|
@cache
|
|
def _get_needs_quotes_regexp(delimiter: bytes) -> Pattern[bytes]:
|
|
"""Return a regexp to recognise when a value needs quotes
|
|
|
|
from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
|
|
|
|
The array output routine will put double quotes around element values if
|
|
they are empty strings, contain curly braces, delimiter characters,
|
|
double quotes, backslashes, or white space, or match the word NULL.
|
|
"""
|
|
return re.compile(
|
|
rb"""(?xi)
|
|
^$ # the empty string
|
|
| ["{}%s\\\s] # or a char to escape
|
|
| ^null$ # or the word NULL
|
|
"""
|
|
% delimiter
|
|
)
|
|
|
|
|
|
class ListBinaryDumper(BaseListDumper):
    """Dump a Python list to the PostgreSQL binary array format."""

    format = pq.Format.BINARY

    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
        # With a fixed oid there is no upgrade, so the class alone is the key.
        if self.oid:
            return self.cls

        item = self._find_list_element(obj, format)
        if item is None:
            # A 1-tuple, deliberately different from ListDumper.get_key()'s
            # bare class: an empty/all-None list upgrades to the *text*
            # dumper, so it must not share a cache key with this one.
            return (self.cls,)

        sd = self._tx.get_dumper(item, format)
        return (self.cls, sd.get_key(item, format))

    def upgrade(self, obj: List[Any], format: PyFormat) -> "BaseListDumper":
        # If we have an oid we don't need to upgrade
        if self.oid:
            return self

        item = self._find_list_element(obj, format)
        if item is None:
            # No element to sniff: fall back on the text dumper, which can
            # represent an empty or all-None list of unknown type.
            return ListDumper(self.cls, self._tx)

        # Consistency fix: call from_pq() through the PyFormat class, as
        # ListDumper.upgrade() does. It is a classmethod, so calling it
        # through the `format` argument was equivalent but misleading.
        sd = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd
        info = self._get_base_type_info(sd.oid)
        dumper.oid = info.array_oid or TEXT_ARRAY_OID

        return dumper

    def dump(self, obj: List[Any]) -> bytes:
        """Serialize *obj* (an eventually nested list) to binary array format."""
        # Postgres won't take unknown for element oid: fall back on text.
        # (INVALID_OID is 0, so the and/or chain covers that case too.)
        sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID

        if not obj:
            # Zero dimensions: the header alone describes an empty array.
            return _pack_head(0, 0, sub_oid)

        data: List[Buffer] = [b"", b""]  # placeholders to avoid a resize
        dims: List[int] = []
        hasnull = 0

        def calc_dims(L: List[Any]) -> None:
            # Walk the first element of each nesting level to record the
            # expected length of every dimension.
            if isinstance(L, self.cls):
                if not L:
                    raise e.DataError("lists cannot contain empty lists")
                dims.append(len(L))
                calc_dims(L[0])

        calc_dims(obj)

        def dump_list(L: List[Any], dim: int) -> None:
            nonlocal hasnull
            if len(L) != dims[dim]:
                raise e.DataError("nested lists have inconsistent lengths")

            if dim == len(dims) - 1:
                # Innermost dimension: emit (length, bytes) per element.
                for item in L:
                    if item is not None:
                        # If we get here, the sub_dumper must have been set
                        ad = self.sub_dumper.dump(item)  # type: ignore[union-attr]
                        data.append(pack_len(len(ad)))
                        data.append(ad)
                    else:
                        hasnull = 1
                        # Length -1 marks a NULL element.
                        data.append(b"\xff\xff\xff\xff")
            else:
                for item in L:
                    if not isinstance(item, self.cls):
                        raise e.DataError("nested lists have inconsistent depths")
                    dump_list(item, dim + 1)  # type: ignore

        dump_list(obj, 0)

        # Fill in the placeholders: the header (ndims, hasnull, element oid)
        # and one (size, lower bound) pair per dimension (lower bound is 1).
        data[0] = _pack_head(len(dims), hasnull, sub_oid)
        data[1] = b"".join(_pack_dim(dim, 1) for dim in dims)
        return b"".join(data)
|
|
|
|
|
|
class ArrayLoader(RecursiveLoader):
    """Load a text-format array into an eventually nested Python list."""

    # Element separator and base type oid; concrete values are set on the
    # subclasses generated by register_array().
    delimiter = b","
    base_oid: int

    def load(self, data: Buffer) -> List[Any]:
        element_loader = self._tx.get_loader(self.base_oid, self.format)
        return _load_text(data, element_loader, self.delimiter)
|
|
|
|
|
|
class ArrayBinaryLoader(RecursiveLoader):
    """Load a binary-format array into an eventually nested Python list."""

    format = pq.Format.BINARY

    def load(self, data: Buffer) -> List[Any]:
        # All the parsing work happens in the module-level helper.
        return _load_binary(data, self._tx)
|
|
|
|
|
|
def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
    """
    Register text/binary loaders and dumpers for the array type in *info*.

    Raise ValueError if *info* doesn't describe an array type.
    """
    if not info.array_oid:
        raise ValueError(f"the type info {info} doesn't describe an array")

    adapters = context.adapters if context else postgres.adapters

    # Text loader: a per-type subclass (the C implementation if available).
    text_base: Type[Any] = getattr(_psycopg, "ArrayLoader", ArrayLoader)
    text_loader = type(
        f"{info.name.title()}{text_base.__name__}",
        (text_base,),
        {"base_oid": info.oid, "delimiter": info.delimiter.encode()},
    )
    adapters.register_loader(info.array_oid, text_loader)

    # Binary loader: the generic one handles every element type.
    bin_loader = getattr(_psycopg, "ArrayBinaryLoader", ArrayBinaryLoader)
    adapters.register_loader(info.array_oid, bin_loader)

    # Text dumper: a per-type ListDumper subclass, registered for NoneType.
    text_dumper = type(
        f"{info.name.title()}{ListDumper.__name__}",
        (ListDumper,),
        {
            "oid": info.array_oid,
            "element_oid": info.oid,
            "delimiter": info.delimiter.encode(),
        },
    )
    adapters.register_dumper(None, text_dumper)

    # Binary dumper: same, minus the delimiter (unused in binary format).
    bin_dumper = type(
        f"{info.name.title()}{ListBinaryDumper.__name__}",
        (ListBinaryDumper,),
        {"oid": info.array_oid, "element_oid": info.oid},
    )
    adapters.register_dumper(None, bin_dumper)
|
|
|
|
|
|
def register_default_adapters(context: AdaptContext) -> None:
    """Register the generic list dumpers on *context*."""
    # The text dumper is more flexible as it can handle lists of mixed type,
    # so register it later.
    context.adapters.register_dumper(list, ListBinaryDumper)
    context.adapters.register_dumper(list, ListDumper)
|
|
|
|
|
|
def register_all_arrays(context: AdaptContext) -> None:
    """
    Associate the array oid of all the types in Loader.globals.

    This function is designed to be called once at import time, after having
    registered all the base loaders.
    """
    for t in context.adapters.types:
        # Skip types that have no array counterpart.
        if not t.array_oid:
            continue
        t.register(context)
|
|
|
|
|
|
def _load_text(
    data: Buffer,
    loader: Loader,
    delimiter: bytes = b",",
    __re_unescape: Pattern[bytes] = re.compile(rb"\\(.)"),
) -> List[Any]:
    """
    Parse the text representation of an array into a (nested) Python list.

    Each non-NULL element is converted with *loader*; *delimiter* is the
    type-specific separator. Raise `~psycopg.errors.DataError` on malformed
    input.
    """
    stack: List[Any] = []
    # `a` is the next list to attach on an open brace; `rv` the result.
    a: List[Any] = []
    rv: Optional[List[Any]] = a
    load = loader.load

    # Remove the dimensions information prefix (``[...]=``)
    if data and data[0] == b"["[0]:
        if isinstance(data, memoryview):
            data = bytes(data)
        idx = data.find(b"=")
        if idx == -1:
            raise e.DataError("malformed array: no '=' after dimension information")
        data = data[idx + 1 :]

    re_parse = _get_array_parse_regexp(delimiter)
    for m in re_parse.finditer(data):
        t = m.group(1)
        if t == b"{":
            if stack:
                stack[-1].append(a)
            stack.append(a)
            a = []

        elif t == b"}":
            if not stack:
                raise e.DataError("malformed array: unexpected '}'")
            rv = stack.pop()

        else:
            if not stack:
                # Bug fix: show the offending token (truncated with an
                # ellipsis when long) instead of an empty string — the old
                # precedence made the else branch return "" for short tokens.
                if len(t) > 10:
                    wat = t[:10].decode("utf8", "replace") + "..."
                else:
                    wat = t.decode("utf8", "replace")
                raise e.DataError(f"malformed array: unexpected '{wat}'")
            if t == b"NULL":
                v = None
            else:
                if t.startswith(b'"'):
                    # Strip the quotes and unescape backslashed chars.
                    t = __re_unescape.sub(rb"\1", t[1:-1])
                v = load(t)

            stack[-1].append(v)

    assert rv is not None
    return rv
|
|
|
|
|
|
@cache
|
|
def _get_array_parse_regexp(delimiter: bytes) -> Pattern[bytes]:
|
|
"""
|
|
Return a regexp to tokenize an array representation into item and brackets
|
|
"""
|
|
return re.compile(
|
|
rb"""(?xi)
|
|
( [{}] # open or closed bracket
|
|
| " (?: [^"\\] | \\. )* " # or a quoted string
|
|
| [^"{}%s\\]+ # or an unquoted non-empty string
|
|
) ,?
|
|
"""
|
|
% delimiter
|
|
)
|
|
|
|
|
|
def _load_binary(data: Buffer, tx: Transformer) -> List[Any]:
    """
    Parse the binary representation of an array into a (nested) Python list.

    Each element is converted with the binary loader registered on *tx*
    for the element oid found in the array header.
    """
    # Header: number of dimensions, has-null flag, element type oid.
    # (hasnull is not needed: NULLs are detected per-element below.)
    ndims, hasnull, oid = _unpack_head(data)
    load = tx.get_loader(oid, PQ_BINARY).load

    if not ndims:
        return []

    # After the 12-byte header there are 8 bytes per dimension
    # (size, lower bound); only the size is used here.
    p = 12 + 8 * ndims
    dims = [_unpack_dim(data, i)[0] for i in range(12, p, 8)]
    nelems = prod(dims)

    # Elements follow in row-major order, each as a 4-byte length + payload.
    out: List[Any] = [None] * nelems
    for i in range(nelems):
        size = unpack_len(data, p)[0]
        p += 4
        if size == -1:
            # Length -1 marks a NULL element (already None in out).
            continue
        out[i] = load(data[p : p + size])
        p += size

    # for ndims > 1 we have to aggregate the array into sub-arrays
    for dim in dims[-1:0:-1]:
        out = [out[i : i + dim] for i in range(0, len(out), dim)]

    return out
|