496 lines
13 KiB
Python
496 lines
13 KiB
Python
|
"""
|
||
|
Adapters for numeric types.
|
||
|
"""
|
||
|
|
||
|
# Copyright (C) 2020 The Psycopg Team
|
||
|
|
||
|
import struct
|
||
|
from math import log
|
||
|
from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
|
||
|
from decimal import Decimal, DefaultContext, Context
|
||
|
|
||
|
from .. import postgres
|
||
|
from .. import errors as e
|
||
|
from ..pq import Format
|
||
|
from ..abc import AdaptContext
|
||
|
from ..adapt import Buffer, Dumper, Loader, PyFormat
|
||
|
from .._struct import pack_int2, pack_uint2, unpack_int2
|
||
|
from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
|
||
|
from .._struct import pack_int8, unpack_int8
|
||
|
from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
|
||
|
|
||
|
# Exposed here
|
||
|
from .._wrappers import (
|
||
|
Int2 as Int2,
|
||
|
Int4 as Int4,
|
||
|
Int8 as Int8,
|
||
|
IntNumeric as IntNumeric,
|
||
|
Oid as Oid,
|
||
|
Float4 as Float4,
|
||
|
Float8 as Float8,
|
||
|
)
|
||
|
|
||
|
|
||
|
class _IntDumper(Dumper):
    """Common base of the integer dumpers: dump an int as its decimal digits."""

    def dump(self, obj: Any) -> Buffer:
        """Return ``str(obj)`` encoded as bytes.

        Instances of int subclasses are converted with `int()` first.

        Raises:
            DataError: if *obj* is not an `int` or an instance of a subclass.
        """
        kind = type(obj)
        if kind is int:
            return str(obj).encode()

        if not issubclass(kind, int):
            raise e.DataError(f"integer expected, got {type(obj).__name__!r}")

        # Go through int() so that IntEnum members are dumped as numbers.
        return str(int(obj)).encode()

    def quote(self, obj: Any) -> Buffer:
        """Return the dumped value, prefixed by a space if *obj* is negative."""
        dumped = self.dump(obj)
        if obj >= 0:
            return dumped

        # NOTE(review): the space presumably keeps the minus sign from merging
        # with a preceding operator in the query - confirm.
        return b" " + dumped
|
||
|
|
||
|
|
||
|
class _SpecialValuesDumper(Dumper):
    """Base dumper for numbers whose special values need a dedicated literal.

    Subclasses fill `_special` with a map from the dumped form of a special
    value to the literal that `quote()` returns for it.
    """

    _special: Dict[bytes, bytes] = {}

    def dump(self, obj: Any) -> bytes:
        """Return ``str(obj)`` encoded as bytes."""
        text = str(obj)
        return text.encode()

    def quote(self, obj: Any) -> bytes:
        """Return the literal for *obj*.

        Special values are looked up in `_special`; any other value is
        returned as dumped, prefixed by a space if negative.
        """
        dumped = self.dump(obj)

        literal = self._special.get(dumped)
        if literal is not None:
            return literal

        if obj >= 0:
            return dumped

        # NOTE(review): the space presumably keeps the minus sign from merging
        # with a preceding operator in the query - confirm.
        return b" " + dumped
|
||
|
|
||
|
|
||
|
class FloatDumper(_SpecialValuesDumper):
    """Float dumper declaring the ``float8`` oid.

    Dumping and quoting are inherited from `_SpecialValuesDumper`.
    """

    oid = postgres.types["float8"].oid

    # Dumped form of the non-finite floats -> literal returned by quote().
    _special = {
        b"inf": b"'Infinity'::float8",
        b"-inf": b"'-Infinity'::float8",
        b"nan": b"'NaN'::float8",
    }
|
||
|
|
||
|
|
||
|
class Float4Dumper(FloatDumper):
    """Same as `FloatDumper`, but declaring the ``float4`` oid."""

    oid = postgres.types["float4"].oid
|
||
|
|
||
|
|
||
|
class FloatBinaryDumper(Dumper):
    """Binary-format float dumper declaring the ``float8`` oid."""

    format = Format.BINARY
    oid = postgres.types["float8"].oid

    def dump(self, obj: float) -> bytes:
        """Return *obj* packed by `pack_float8()`."""
        return pack_float8(obj)
|
||
|
|
||
|
|
||
|
class Float4BinaryDumper(FloatBinaryDumper):
    """Binary-format float dumper declaring the ``float4`` oid."""

    oid = postgres.types["float4"].oid

    def dump(self, obj: float) -> bytes:
        """Return *obj* packed by `pack_float4()`."""
        return pack_float4(obj)
|
||
|
|
||
|
|
||
|
class DecimalDumper(_SpecialValuesDumper):
    """`Decimal` dumper declaring the ``numeric`` oid."""

    oid = postgres.types["numeric"].oid

    # Dumped form of the special values -> literal returned by quote().
    _special = {
        b"Infinity": b"'Infinity'::numeric",
        b"-Infinity": b"'-Infinity'::numeric",
        b"NaN": b"'NaN'::numeric",
    }

    def dump(self, obj: Decimal) -> bytes:
        """Return ``str(obj)`` as bytes, or ``b"NaN"`` for any kind of NaN."""
        # is_nan() is true for both NaN and sNaN: dump them the same way.
        return b"NaN" if obj.is_nan() else str(obj).encode()
|
||
|
|
||
|
|
||
|
class Int2Dumper(_IntDumper):
    """Integer dumper declaring the ``int2`` oid."""

    oid = postgres.types["int2"].oid
|
||
|
|
||
|
|
||
|
class Int4Dumper(_IntDumper):
    """Integer dumper declaring the ``int4`` oid."""

    oid = postgres.types["int4"].oid
|
||
|
|
||
|
|
||
|
class Int8Dumper(_IntDumper):
    """Integer dumper declaring the ``int8`` oid."""

    oid = postgres.types["int8"].oid
|
||
|
|
||
|
|
||
|
class IntNumericDumper(_IntDumper):
    """Integer dumper declaring the ``numeric`` oid."""

    oid = postgres.types["numeric"].oid
|
||
|
|
||
|
|
||
|
class OidDumper(_IntDumper):
    """Integer dumper declaring the ``oid`` oid."""

    oid = postgres.types["oid"].oid
|
||
|
|
||
|
|
||
|
class IntDumper(Dumper):
    """Dispatcher choosing a sized integer dumper from the value to dump.

    It never dumps anything itself: `upgrade()` returns the dumper to use.
    """

    # Dumper instances handed out by upgrade(), one per integer size.
    _int2_dumper = Int2Dumper(Int2)
    _int4_dumper = Int4Dumper(Int4)
    _int8_dumper = Int8Dumper(Int8)
    _int_numeric_dumper = IntNumericDumper(IntNumeric)

    def dump(self, obj: Any) -> bytes:
        """Always raise `TypeError`: dumping is delegated to other dumpers."""
        message = (
            f"{type(self).__name__} is a dispatcher to other dumpers:"
            " dump() is not supposed to be called"
        )
        raise TypeError(message)

    def get_key(self, obj: int, format: PyFormat) -> type:
        """Return the `cls` attribute of the dumper selected for *obj*."""
        selected = self.upgrade(obj, format)
        return selected.cls

    def upgrade(self, obj: int, format: PyFormat) -> Dumper:
        """Return the smallest dumper whose range contains *obj*.

        The ranges tried are, in order, the signed 16, 32 and 64 bit ones;
        anything larger is handled by the numeric dumper. *format* is not
        used to make the choice.
        """
        if -(2**15) <= obj < 2**15:
            return self._int2_dumper
        if -(2**31) <= obj < 2**31:
            return self._int4_dumper
        if -(2**63) <= obj < 2**63:
            return self._int8_dumper
        return self._int_numeric_dumper
|
||
|
|
||
|
|
||
|
class Int2BinaryDumper(Int2Dumper):
    """Binary-format variant of `Int2Dumper`."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        """Return *obj* packed by `pack_int2()`."""
        return pack_int2(obj)
|
||
|
|
||
|
|
||
|
class Int4BinaryDumper(Int4Dumper):
    """Binary-format variant of `Int4Dumper`."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        """Return *obj* packed by `pack_int4()`."""
        return pack_int4(obj)
|
||
|
|
||
|
|
||
|
class Int8BinaryDumper(Int8Dumper):
    """Binary-format variant of `Int8Dumper`."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        """Return *obj* packed by `pack_int8()`."""
        return pack_int8(obj)
|
||
|
|
||
|
|
||
|
# Ratio between the number of pg digits (groups of 4 decimal digits, i.e. base
# 10_000) required to store a number and the number of bits required.
# dump_int_to_numeric_binary() multiplies it by int.bit_length() to size the
# output buffer.
BIT_PER_PGDIGIT = log(2) / log(10_000)
|
||
|
|
||
|
|
||
|
class IntNumericBinaryDumper(IntNumericDumper):
    """Binary-format variant of `IntNumericDumper`."""

    format = Format.BINARY

    def dump(self, obj: int) -> Buffer:
        """Return *obj* converted by `dump_int_to_numeric_binary()`."""
        return dump_int_to_numeric_binary(obj)
|
||
|
|
||
|
|
||
|
class OidBinaryDumper(OidDumper):
    """Binary-format variant of `OidDumper`."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        """Return *obj* packed by `pack_uint4()`."""
        return pack_uint4(obj)
|
||
|
|
||
|
|
||
|
class IntBinaryDumper(IntDumper):
    """Binary-format variant of `IntDumper`, dispatching to binary dumpers."""

    format = Format.BINARY

    # Replace the dumpers handed out by IntDumper.upgrade() with binary ones.
    _int2_dumper = Int2BinaryDumper(Int2)
    _int4_dumper = Int4BinaryDumper(Int4)
    _int8_dumper = Int8BinaryDumper(Int8)
    _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
|
||
|
|
||
|
|
||
|
class IntLoader(Loader):
    """Loader converting the received data to `int`."""

    def load(self, data: Buffer) -> int:
        """Return *data* parsed by `int()`."""
        # it supports bytes directly
        return int(data)
|
||
|
|
||
|
|
||
|
class Int2BinaryLoader(Loader):
    """Binary-format loader using `unpack_int2()`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        """Return the first item unpacked from *data* by `unpack_int2()`."""
        return unpack_int2(data)[0]
|
||
|
|
||
|
|
||
|
class Int4BinaryLoader(Loader):
    """Binary-format loader using `unpack_int4()`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        """Return the first item unpacked from *data* by `unpack_int4()`."""
        return unpack_int4(data)[0]
|
||
|
|
||
|
|
||
|
class Int8BinaryLoader(Loader):
    """Binary-format loader using `unpack_int8()`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        """Return the first item unpacked from *data* by `unpack_int8()`."""
        return unpack_int8(data)[0]
|
||
|
|
||
|
|
||
|
class OidBinaryLoader(Loader):
    """Binary-format loader using `unpack_uint4()`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        """Return the first item unpacked from *data* by `unpack_uint4()`."""
        return unpack_uint4(data)[0]
|
||
|
|
||
|
|
||
|
class FloatLoader(Loader):
    """Loader converting the received data to `float`."""

    def load(self, data: Buffer) -> float:
        """Return *data* parsed by `float()`."""
        # it supports bytes directly
        return float(data)
|
||
|
|
||
|
|
||
|
class Float4BinaryLoader(Loader):
    """Binary-format loader using `unpack_float4()`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> float:
        """Return the first item unpacked from *data* by `unpack_float4()`."""
        return unpack_float4(data)[0]
|
||
|
|
||
|
|
||
|
class Float8BinaryLoader(Loader):
    """Binary-format loader using `unpack_float8()`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> float:
        """Return the first item unpacked from *data* by `unpack_float8()`."""
        return unpack_float8(data)[0]
|
||
|
|
||
|
|
||
|
class NumericLoader(Loader):
    """Loader converting the received data to `Decimal`."""

    def load(self, data: Buffer) -> Decimal:
        """Return the `Decimal` built from the decoded content of *data*."""
        # A memoryview has no decode(): copy it into a bytes object first.
        raw = bytes(data) if isinstance(data, memoryview) else data
        return Decimal(raw.decode())
|
||
|
|
||
|
|
||
|
DEC_DIGITS = 4  # decimal digits per Postgres "digit"

# Values of the sign field of the binary numeric header, as compared by
# NumericBinaryLoader and packed by the dump_*_to_numeric_binary() functions.
NUMERIC_POS = 0x0000
NUMERIC_NEG = 0x4000
NUMERIC_NAN = 0xC000
NUMERIC_PINF = 0xD000
NUMERIC_NINF = 0xF000

# Decimal returned by NumericBinaryLoader for each special sign value.
_decimal_special = {
    NUMERIC_NAN: Decimal("NaN"),
    NUMERIC_PINF: Decimal("Infinity"),
    NUMERIC_NINF: Decimal("-Infinity"),
}
|
||
|
|
||
|
|
||
|
class _ContextMap(DefaultDict[int, Context]):
|
||
|
"""
|
||
|
Cache for decimal contexts to use when the precision requires it.
|
||
|
|
||
|
Note: if the default context is used (prec=28) you can get an invalid
|
||
|
operation or a rounding to 0:
|
||
|
|
||
|
- Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
|
||
|
- Decimal(1000).shift(25) = Decimal('0')
|
||
|
- Decimal(1000).shift(30) raises InvalidOperation
|
||
|
"""
|
||
|
|
||
|
def __missing__(self, key: int) -> Context:
|
||
|
val = Context(prec=key)
|
||
|
self[key] = val
|
||
|
return val
|
||
|
|
||
|
|
||
|
# Contexts used by NumericBinaryLoader, indexed by the precision required.
_contexts = _ContextMap()
# Every precision lower than the default context's one maps to DefaultContext
# itself; higher precisions are created on demand by _ContextMap.__missing__().
for i in range(DefaultContext.prec):
    _contexts[i] = DefaultContext

# Read/write the four-field header (struct format "!HhHH", 8 bytes) of the
# binary numeric representation; this module uses the fields as
# (ndigits, weight, sign, dscale).
_unpack_numeric_head = cast(
    Callable[[Buffer], Tuple[int, int, int, int]],
    struct.Struct("!HhHH").unpack_from,
)
_pack_numeric_head = cast(
    Callable[[int, int, int, int], bytes],
    struct.Struct("!HhHH").pack,
)
|
||
|
|
||
|
|
||
|
class NumericBinaryLoader(Loader):
    """Binary-format loader converting the received data to `Decimal`."""

    format = Format.BINARY

    def load(self, data: Buffer) -> Decimal:
        """Return the `Decimal` encoded in *data*.

        *data* starts with the header read by `_unpack_numeric_head()` as
        (ndigits, weight, sign, dscale); from offset 8 on, every pair of
        bytes is read as a big-endian base-10_000 digit.

        Raises:
            DataError: if the sign field is not one of the known values.
        """
        ndigits, weight, sign, dscale = _unpack_numeric_head(data)

        if sign != NUMERIC_POS and sign != NUMERIC_NEG:
            # Not a finite number: the sign field selects the special value.
            try:
                return _decimal_special[sign]
            except KeyError:
                raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None

        # Accumulate all the base-10_000 digits into a single Python integer.
        unscaled = 0
        for pos in range(8, len(data), 2):
            unscaled = unscaled * 10_000 + data[pos] * 0x100 + data[pos + 1]
        if sign == NUMERIC_NEG:
            unscaled = -unscaled

        shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
        # Use a context with enough precision (see the _ContextMap docstring).
        ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
        scaled = Decimal(unscaled).scaleb(-dscale, ctx)
        return scaled.shift(shift, ctx)
|
||
|
|
||
|
|
||
|
# Header-only binary representation of the special numeric values, returned
# as they are by dump_decimal_to_numeric_binary().
NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
|
||
|
|
||
|
|
||
|
class DecimalBinaryDumper(Dumper):
    """Binary-format `Decimal` dumper declaring the ``numeric`` oid."""

    format = Format.BINARY
    oid = postgres.types["numeric"].oid

    def dump(self, obj: Decimal) -> Buffer:
        """Return *obj* converted by `dump_decimal_to_numeric_binary()`."""
        return dump_decimal_to_numeric_binary(obj)
|
||
|
|
||
|
|
||
|
class NumericDumper(DecimalDumper):
    """`DecimalDumper` variant accepting `int` values too."""

    def dump(self, obj: Union[Decimal, int]) -> bytes:
        """Return the dumped form of *obj*, either a `Decimal` or an `int`."""
        if not isinstance(obj, int):
            # Decimal: the parent class takes care of the NaN values.
            return super().dump(obj)

        return str(obj).encode()
|
||
|
|
||
|
|
||
|
class NumericBinaryDumper(Dumper):
    """Binary-format ``numeric`` dumper accepting `Decimal` and `int` values."""

    format = Format.BINARY
    oid = postgres.types["numeric"].oid

    def dump(self, obj: Union[Decimal, int]) -> Buffer:
        """Return *obj* converted by the function matching its type."""
        convert = (
            dump_int_to_numeric_binary
            if isinstance(obj, int)
            else dump_decimal_to_numeric_binary
        )
        return convert(obj)
|
||
|
|
||
|
|
||
|
def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
    """Return the binary numeric representation of the `Decimal` *obj*.

    The result is the header packed by `_pack_numeric_head()` as (ndigits,
    weight, sign, dscale) followed by one 2-byte unsigned value for each
    group of `DEC_DIGITS` decimal digits (a pg digit).

    NaN and infinities return the prebuilt `NUMERIC_NAN_BIN`,
    `NUMERIC_PINF_BIN` and `NUMERIC_NINF_BIN` values; a number whose digits
    are all zero returns a header with no digit after it.
    """
    sign, digits, exp = obj.as_tuple()
    # Special values are reported with a string instead of a number as exp.
    if exp == "n" or exp == "N":
        return NUMERIC_NAN_BIN
    elif exp == "F":
        return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN

    # Weights of py digits into a pg digit according to their positions.
    # Starting with an index wi != 0 is equivalent to prepending 0's to
    # the digits tuple, but without really changing it.
    weights = (1000, 100, 10, 1)
    wi = 0

    ndigits = nzdigits = len(digits)

    # Find the last nonzero digit
    while nzdigits > 0 and digits[nzdigits - 1] == 0:
        nzdigits -= 1

    if exp <= 0:
        # The digits after the decimal point are as many as the exponent says.
        dscale = -exp
    else:
        dscale = 0
        # align the py digits to the pg digits if there's some py exponent
        ndigits += exp % DEC_DIGITS

    if not nzdigits:
        # Only zeros: no digit to emit, but the scale is preserved.
        return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)

    # Equivalent of 0-padding left to align the py digits to the pg digits
    # but without changing the digits tuple.
    mod = (ndigits - dscale) % DEC_DIGITS
    if mod:
        wi = DEC_DIGITS - mod
        ndigits += wi

    # Number of py digits to emit, including the virtual left padding.
    tmp = nzdigits + wi
    out = bytearray(
        _pack_numeric_head(
            # Round up: a partially filled group still takes one pg digit.
            tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1),  # ndigits
            (ndigits + exp) // DEC_DIGITS - 1,  # weight
            NUMERIC_NEG if sign else NUMERIC_POS,  # sign
            dscale,
        )
    )

    # Pack every DEC_DIGITS py digits into a pg digit; the trailing zeros
    # (from nzdigits on) are never emitted.
    pgdigit = 0
    for i in range(nzdigits):
        pgdigit += weights[wi] * digits[i]
        wi += 1
        if wi >= DEC_DIGITS:
            out += pack_uint2(pgdigit)
            pgdigit = wi = 0

    # Emit the last, partially filled, pg digit (unless it is zero).
    if pgdigit:
        out += pack_uint2(pgdigit)

    return out
|
||
|
|
||
|
|
||
|
def dump_int_to_numeric_binary(obj: int) -> bytearray:
    """Return the binary numeric representation of the integer *obj*.

    The result is the header packed by `_pack_numeric_head()` followed by
    *ndigits* 2-byte base-10_000 digits, most significant first. *ndigits* is
    estimated from the bit length of *obj*, and the leading digits which are
    not needed are left to zero.
    """
    ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1

    # Zero-filled buffer: 4 header fields and ndigits digits, 2 bytes each.
    buf = bytearray(2 * (ndigits + 4))

    sign = NUMERIC_NEG if obj < 0 else NUMERIC_POS
    remaining = abs(obj)

    buf[:8] = _pack_numeric_head(ndigits, ndigits - 1, sign, 0)

    # Fill the digits backwards, starting from the least significant one.
    pos = 8 + 2 * (ndigits - 1)
    while remaining:
        remaining, group = divmod(remaining, 10_000)
        buf[pos : pos + 2] = pack_uint2(group)
        pos -= 2

    return buf
|
||
|
|
||
|
|
||
|
def register_default_adapters(context: AdaptContext) -> None:
    """Register the numeric dumpers and loaders of this module on *context*.

    Every adapter is registered on ``context.adapters``.

    NOTE(review): the order of the registrations looks significant (see the
    comments below, e.g. the text Decimal dumper registered last to be the
    default) - keep it unchanged unless verified otherwise.
    """
    adapters = context.adapters
    adapters.register_dumper(int, IntDumper)
    adapters.register_dumper(int, IntBinaryDumper)
    adapters.register_dumper(float, FloatDumper)
    adapters.register_dumper(float, FloatBinaryDumper)
    adapters.register_dumper(Int2, Int2Dumper)
    adapters.register_dumper(Int4, Int4Dumper)
    adapters.register_dumper(Int8, Int8Dumper)
    adapters.register_dumper(IntNumeric, IntNumericDumper)
    adapters.register_dumper(Oid, OidDumper)

    # The binary dumper is currently some 30% slower, so default to text
    # (see tests/scripts/testdec.py for a rough benchmark)
    # Also, must be after IntNumericDumper
    adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
    adapters.register_dumper("decimal.Decimal", DecimalDumper)

    # Used only by oid, can take both int and Decimal as input
    adapters.register_dumper(None, NumericBinaryDumper)
    adapters.register_dumper(None, NumericDumper)

    # Remaining wrapper dumpers: text for the floats, then the binary ones.
    adapters.register_dumper(Float4, Float4Dumper)
    adapters.register_dumper(Float8, FloatDumper)
    adapters.register_dumper(Int2, Int2BinaryDumper)
    adapters.register_dumper(Int4, Int4BinaryDumper)
    adapters.register_dumper(Int8, Int8BinaryDumper)
    adapters.register_dumper(Oid, OidBinaryDumper)
    adapters.register_dumper(Float4, Float4BinaryDumper)
    adapters.register_dumper(Float8, FloatBinaryDumper)

    # Loaders, registered by type name: for each family the plain loader
    # first, then the Format.BINARY one.
    adapters.register_loader("int2", IntLoader)
    adapters.register_loader("int4", IntLoader)
    adapters.register_loader("int8", IntLoader)
    adapters.register_loader("oid", IntLoader)
    adapters.register_loader("int2", Int2BinaryLoader)
    adapters.register_loader("int4", Int4BinaryLoader)
    adapters.register_loader("int8", Int8BinaryLoader)
    adapters.register_loader("oid", OidBinaryLoader)
    adapters.register_loader("float4", FloatLoader)
    adapters.register_loader("float8", FloatLoader)
    adapters.register_loader("float4", Float4BinaryLoader)
    adapters.register_loader("float8", Float8BinaryLoader)
    adapters.register_loader("numeric", NumericLoader)
    adapters.register_loader("numeric", NumericBinaryLoader)
|