mysteriendrama/lib/python3.11/site-packages/psycopg/types/numeric.py
2023-07-26 21:33:29 +02:00

496 lines
13 KiB
Python

"""
Adapters for numeric types.
"""
# Copyright (C) 2020 The Psycopg Team
import struct
from math import log
from typing import Any, Callable, DefaultDict, Dict, Tuple, Union, cast
from decimal import Decimal, DefaultContext, Context
from .. import postgres
from .. import errors as e
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader, PyFormat
from .._struct import pack_int2, pack_uint2, unpack_int2
from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
from .._struct import pack_int8, unpack_int8
from .._struct import pack_float4, pack_float8, unpack_float4, unpack_float8
# Exposed here
from .._wrappers import (
Int2 as Int2,
Int4 as Int4,
Int8 as Int8,
IntNumeric as IntNumeric,
Oid as Oid,
Float4 as Float4,
Float8 as Float8,
)
class _IntDumper(Dumper):
    """
    Text dumper base for Python ints, including int subclasses (e.g. IntEnum).

    Non-int objects are rejected with a DataError.
    """

    def dump(self, obj: Any) -> Buffer:
        cls = type(obj)
        if cls is int:
            return str(obj).encode()
        if not issubclass(cls, int):
            raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
        # Normalise subclasses so one with a custom __str__ (such as IntEnum)
        # still dumps as plain digits.
        return str(int(obj)).encode()

    def quote(self, obj: Any) -> Buffer:
        dumped = self.dump(obj)
        if obj >= 0:
            return dumped
        # Leading space for negatives: avoids producing "--" (an SQL comment)
        # when the literal follows a minus sign.
        return b" " + dumped
class _SpecialValuesDumper(Dumper):
    """
    Text dumper base for numbers that have non-finite special values.

    Subclasses populate `_special`, mapping the output of `dump()` for a
    special value (e.g. b"nan") to the SQL literal used when quoting it.
    """

    # dump() output -> replacement literal returned by quote().
    _special: Dict[bytes, bytes] = {}

    def dump(self, obj: Any) -> bytes:
        return str(obj).encode()

    def quote(self, obj: Any) -> bytes:
        value = self.dump(obj)

        if value in self._special:
            return self._special[value]

        # Prefix a space to anything rendered with a leading minus, so that
        # e.g. "1-%s" cannot become "1--1", which starts an SQL comment.
        # Test the rendered text rather than `obj >= 0`: negative zero
        # (float -0.0, Decimal("-0")) compares >= 0 but still renders as "-0".
        if value.startswith(b"-"):
            return b" " + value
        return value
class FloatDumper(_SpecialValuesDumper):
    """Text dumper for Python floats (and the Float8 wrapper), sent as ``float8``."""

    oid = postgres.types["float8"].oid

    # str() of the non-finite floats -> typed string literal used by quote(),
    # since the bare words inf/nan are not numeric literals in SQL.
    _special = {
        b"inf": b"'Infinity'::float8",
        b"-inf": b"'-Infinity'::float8",
        b"nan": b"'NaN'::float8",
    }
class Float4Dumper(FloatDumper):
    """Text dumper for the Float4 wrapper, sent as ``float4``."""

    oid = postgres.types["float4"].oid

    # Cast quoted special values to float4 as well, consistent with `oid`;
    # the mapping inherited from FloatDumper would make them float8.
    _special = {
        b"inf": b"'Infinity'::float4",
        b"-inf": b"'-Infinity'::float4",
        b"nan": b"'NaN'::float4",
    }
class FloatBinaryDumper(Dumper):
    """Binary dumper for Python floats, sent as ``float8`` via pack_float8()."""

    format = Format.BINARY
    oid = postgres.types["float8"].oid

    def dump(self, obj: float) -> bytes:
        return pack_float8(obj)
class Float4BinaryDumper(FloatBinaryDumper):
    """Binary dumper for the Float4 wrapper, sent as ``float4`` via pack_float4()."""

    oid = postgres.types["float4"].oid

    def dump(self, obj: float) -> bytes:
        return pack_float4(obj)
class DecimalDumper(_SpecialValuesDumper):
    """Text dumper for decimal.Decimal values, sent as ``numeric``."""

    oid = postgres.types["numeric"].oid

    # dump() output for non-finite Decimals -> typed literal used by quote().
    _special = {
        b"Infinity": b"'Infinity'::numeric",
        b"-Infinity": b"'-Infinity'::numeric",
        b"NaN": b"'NaN'::numeric",
    }

    def dump(self, obj: Decimal) -> bytes:
        # is_nan() is true for both quiet and signalling NaN; both are sent
        # as plain "NaN".
        return b"NaN" if obj.is_nan() else str(obj).encode()
class Int2Dumper(_IntDumper):
    """Text dumper sending integers as ``int2``; used for Int2 and by IntDumper."""

    oid = postgres.types["int2"].oid
class Int4Dumper(_IntDumper):
    """Text dumper sending integers as ``int4``; used for Int4 and by IntDumper."""

    oid = postgres.types["int4"].oid
class Int8Dumper(_IntDumper):
    """Text dumper sending integers as ``int8``; used for Int8 and by IntDumper."""

    oid = postgres.types["int8"].oid
class IntNumericDumper(_IntDumper):
    """
    Text dumper sending integers as ``numeric``.

    Used for the IntNumeric wrapper and by IntDumper for ints outside the
    int8 range.
    """

    oid = postgres.types["numeric"].oid
class OidDumper(_IntDumper):
    """Text dumper for the Oid wrapper, sent as ``oid``."""

    oid = postgres.types["oid"].oid
class IntDumper(Dumper):
    """
    Dispatcher for Python ints: route each value to the narrowest dumper.

    This dumper never dumps by itself. `get_key()` and `upgrade()` pick the
    int2, int4, int8 or numeric dumper according to the value's range.
    """

    def dump(self, obj: Any) -> bytes:
        raise TypeError(
            f"{type(self).__name__} is a dispatcher to other dumpers:"
            " dump() is not supposed to be called"
        )

    def get_key(self, obj: int, format: PyFormat) -> type:
        return self.upgrade(obj, format).cls

    # Target dumpers, shared at class level; subclasses can replace them
    # (e.g. with binary variants) and reuse the dispatch in upgrade().
    _int2_dumper = Int2Dumper(Int2)
    _int4_dumper = Int4Dumper(Int4)
    _int8_dumper = Int8Dumper(Int8)
    _int_numeric_dumper = IntNumericDumper(IntNumeric)

    def upgrade(self, obj: int, format: PyFormat) -> Dumper:
        # Try ranges from the narrowest type upwards; the first fit wins.
        if -(2**15) <= obj < 2**15:
            return self._int2_dumper
        if -(2**31) <= obj < 2**31:
            return self._int4_dumper
        if -(2**63) <= obj < 2**63:
            return self._int8_dumper
        return self._int_numeric_dumper
class Int2BinaryDumper(Int2Dumper):
    """Binary ``int2`` dumper, packing the value with pack_int2()."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        return pack_int2(obj)
class Int4BinaryDumper(Int4Dumper):
    """Binary ``int4`` dumper, packing the value with pack_int4()."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        return pack_int4(obj)
class Int8BinaryDumper(Int8Dumper):
    """Binary ``int8`` dumper, packing the value with pack_int8()."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        return pack_int8(obj)
# Ratio between the number of bits needed to store a number and the number of
# PostgreSQL numeric digits (base 10000, i.e. 4 decimal digits each) needed:
# bit_length() * BIT_PER_PGDIGIT estimates how many numeric digits an int uses.
BIT_PER_PGDIGIT = log(2) / log(10_000)
class IntNumericBinaryDumper(IntNumericDumper):
    """
    Binary ``numeric`` dumper for Python ints.

    IntBinaryDumper dispatches to it for values outside the int8 range.
    """

    format = Format.BINARY

    def dump(self, obj: int) -> Buffer:
        return dump_int_to_numeric_binary(obj)
class OidBinaryDumper(OidDumper):
    """Binary ``oid`` dumper, packing the value as unsigned 4-byte int (pack_uint4)."""

    format = Format.BINARY

    def dump(self, obj: int) -> bytes:
        return pack_uint4(obj)
class IntBinaryDumper(IntDumper):
    """
    Binary counterpart of IntDumper.

    Reuses the range-based dispatch of IntDumper.upgrade(), but the target
    dumpers are the binary int2/int4/int8/numeric ones.
    """

    format = Format.BINARY

    _int2_dumper = Int2BinaryDumper(Int2)
    _int4_dumper = Int4BinaryDumper(Int4)
    _int8_dumper = Int8BinaryDumper(Int8)
    _int_numeric_dumper = IntNumericBinaryDumper(IntNumeric)
class IntLoader(Loader):
    """Load text-format int2, int4, int8 and oid values as Python ints."""

    def load(self, data: Buffer) -> int:
        # it supports bytes directly: int() parses the ASCII digits without
        # an explicit decode
        return int(data)
class Int2BinaryLoader(Loader):
    """Load a binary ``int2`` as a Python int."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        # unpack returns a tuple; take its single field.
        return unpack_int2(data)[0]
class Int4BinaryLoader(Loader):
    """Load a binary ``int4`` as a Python int."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        # unpack returns a tuple; take its single field.
        return unpack_int4(data)[0]
class Int8BinaryLoader(Loader):
    """Load a binary ``int8`` as a Python int."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        # unpack returns a tuple; take its single field.
        return unpack_int8(data)[0]
class OidBinaryLoader(Loader):
    """Load a binary ``oid`` as a Python int, read as unsigned 4-byte integer."""

    format = Format.BINARY

    def load(self, data: Buffer) -> int:
        # unpack returns a tuple; take its single field.
        return unpack_uint4(data)[0]
class FloatLoader(Loader):
    """Load text-format float4 and float8 values as Python floats."""

    def load(self, data: Buffer) -> float:
        # it supports bytes directly: float() parses the ASCII text without
        # an explicit decode
        return float(data)
class Float4BinaryLoader(Loader):
    """Load a binary ``float4`` as a Python float."""

    format = Format.BINARY

    def load(self, data: Buffer) -> float:
        # unpack returns a tuple; take its single field.
        return unpack_float4(data)[0]
class Float8BinaryLoader(Loader):
    """Load a binary ``float8`` as a Python float."""

    format = Format.BINARY

    def load(self, data: Buffer) -> float:
        # unpack returns a tuple; take its single field.
        return unpack_float8(data)[0]
class NumericLoader(Loader):
    """Load a text-format ``numeric`` as a decimal.Decimal."""

    def load(self, data: Buffer) -> Decimal:
        # memoryview has no .decode(): copy it into bytes first.
        raw = bytes(data) if isinstance(data, memoryview) else data
        return Decimal(raw.decode())
DEC_DIGITS = 4  # decimal digits per Postgres "digit"

# Values of the sign field of the binary numeric header: finite positive or
# negative numbers, or one of the special values.
NUMERIC_POS = 0x0000
NUMERIC_NEG = 0x4000
NUMERIC_NAN = 0xC000
NUMERIC_PINF = 0xD000
NUMERIC_NINF = 0xF000

# Sign field of a special value -> the Decimal to load it as.
_decimal_special = {
    NUMERIC_NAN: Decimal("NaN"),
    NUMERIC_PINF: Decimal("Infinity"),
    NUMERIC_NINF: Decimal("-Infinity"),
}
class _ContextMap(DefaultDict[int, Context]):
    """
    Cache for decimal contexts to use when the precision requires it.

    Looking up a missing precision creates a Context with that precision
    and stores it for later lookups.

    Note: if the default context is used (prec=28) you can get an invalid
    operation or a rounding to 0:

    - Decimal(1000).shift(24) = Decimal('1000000000000000000000000000')
    - Decimal(1000).shift(25) = Decimal('0')
    - Decimal(1000).shift(30) raises InvalidOperation
    """

    def __missing__(self, key: int) -> Context:
        # Build on first use and memoise under the requested precision.
        ctx = self[key] = Context(prec=key)
        return ctx
# Precisions below DefaultContext.prec are served by DefaultContext itself;
# larger ones get a dedicated Context on first use (_ContextMap.__missing__).
_contexts = _ContextMap()
for i in range(DefaultContext.prec):
    _contexts[i] = DefaultContext

# Codecs for the 8-byte binary numeric header: ndigits (uint16), weight
# (int16), sign (uint16), dscale (uint16), in network (big-endian) order.
_unpack_numeric_head = cast(
    Callable[[Buffer], Tuple[int, int, int, int]],
    struct.Struct("!HhHH").unpack_from,
)
_pack_numeric_head = cast(
    Callable[[int, int, int, int], bytes],
    struct.Struct("!HhHH").pack,
)
class NumericBinaryLoader(Loader):
    """
    Load a binary ``numeric`` as a decimal.Decimal.

    Raises DataError if the header's sign field is not a known value.
    """

    format = Format.BINARY

    def load(self, data: Buffer) -> Decimal:
        # Header: number of base-10000 digits, weight of the first digit,
        # sign/special flag, display scale.
        ndigits, weight, sign, dscale = _unpack_numeric_head(data)
        if sign == NUMERIC_POS or sign == NUMERIC_NEG:
            # Fold the big-endian 16-bit digits following the 8-byte header
            # into a single Python int.
            val = 0
            for i in range(8, len(data), 2):
                val = val * 10_000 + data[i] * 0x100 + data[i + 1]

            # The digits encode val * 10000**(weight - ndigits + 1).
            # scaleb() sets the exponent to -dscale and shift() moves the
            # coefficient digits to compensate, giving that value with dscale
            # fractional digits (surplus low-order digits are shifted out).
            shift = dscale - (ndigits - weight - 1) * DEC_DIGITS
            # Precision sized from weight and dscale so the result's digits
            # fit; the default context can round or fail (see _ContextMap).
            ctx = _contexts[(weight + 2) * DEC_DIGITS + dscale]
            return (
                Decimal(val if sign == NUMERIC_POS else -val)
                .scaleb(-dscale, ctx)
                .shift(shift, ctx)
            )
        else:
            try:
                return _decimal_special[sign]
            except KeyError:
                raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
# Prebuilt binary numerics for the special values: a header with no digits.
NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
NUMERIC_PINF_BIN = _pack_numeric_head(0, 0, NUMERIC_PINF, 0)
NUMERIC_NINF_BIN = _pack_numeric_head(0, 0, NUMERIC_NINF, 0)
class DecimalBinaryDumper(Dumper):
    """Binary ``numeric`` dumper for decimal.Decimal values."""

    format = Format.BINARY
    oid = postgres.types["numeric"].oid

    def dump(self, obj: Decimal) -> Buffer:
        return dump_decimal_to_numeric_binary(obj)
class NumericDumper(DecimalDumper):
    """Text ``numeric`` dumper accepting either an int or a Decimal."""

    def dump(self, obj: Union[Decimal, int]) -> bytes:
        if not isinstance(obj, int):
            # Decimals take the DecimalDumper path (which handles NaN).
            return super().dump(obj)
        return str(obj).encode()
class NumericBinaryDumper(Dumper):
    """Binary ``numeric`` dumper accepting either an int or a Decimal."""

    format = Format.BINARY
    oid = postgres.types["numeric"].oid

    def dump(self, obj: Union[Decimal, int]) -> Buffer:
        encoder = (
            dump_int_to_numeric_binary
            if isinstance(obj, int)
            else dump_decimal_to_numeric_binary
        )
        return encoder(obj)
def dump_decimal_to_numeric_binary(obj: Decimal) -> Union[bytearray, bytes]:
    """
    Encode a Decimal in the PostgreSQL binary ``numeric`` format.

    NaN (quiet or signalling) and +/-Infinity return one of the prebuilt
    header-only constants. Finite values return a bytearray holding the
    8-byte header (ndigits, weight, sign, dscale) followed by base-10000
    digits packed with pack_uint2(). Trailing zero digits are not emitted.
    Zero, of either sign, is encoded as a positive zero with no digits.
    """
    sign, digits, exp = obj.as_tuple()
    # as_tuple() reports the exponent as 'n' (NaN), 'N' (sNaN) or 'F'
    # (Infinity) for non-finite values.
    if exp == "n" or exp == "N":
        return NUMERIC_NAN_BIN
    elif exp == "F":
        return NUMERIC_NINF_BIN if sign else NUMERIC_PINF_BIN

    # Weights of py digits into a pg digit according to their positions.
    # Starting with an index wi != 0 is equivalent to prepending 0's to
    # the digits tuple, but without really changing it.
    weights = (1000, 100, 10, 1)
    wi = 0

    ndigits = nzdigits = len(digits)

    # Find the last nonzero digit
    while nzdigits > 0 and digits[nzdigits - 1] == 0:
        nzdigits -= 1

    if exp <= 0:
        # -exp digits of the coefficient are fractional.
        dscale = -exp
    else:
        dscale = 0
        # align the py digits to the pg digits if there's some py exponent
        ndigits += exp % DEC_DIGITS

    if not nzdigits:
        return _pack_numeric_head(0, 0, NUMERIC_POS, dscale)

    # Equivalent of 0-padding left to align the py digits to the pg digits
    # but without changing the digits tuple.
    mod = (ndigits - dscale) % DEC_DIGITS
    if mod:
        wi = DEC_DIGITS - mod
        ndigits += wi

    tmp = nzdigits + wi
    out = bytearray(
        _pack_numeric_head(
            tmp // DEC_DIGITS + (tmp % DEC_DIGITS and 1),  # ndigits
            (ndigits + exp) // DEC_DIGITS - 1,  # weight
            NUMERIC_NEG if sign else NUMERIC_POS,  # sign
            dscale,
        )
    )

    # Combine each group of DEC_DIGITS py digits into one pg digit.
    pgdigit = 0
    for i in range(nzdigits):
        pgdigit += weights[wi] * digits[i]
        wi += 1
        if wi >= DEC_DIGITS:
            out += pack_uint2(pgdigit)
            pgdigit = wi = 0

    # Flush a final partial group; its unfilled low positions count as 0.
    if pgdigit:
        out += pack_uint2(pgdigit)

    return out
def dump_int_to_numeric_binary(obj: int) -> bytearray:
    """
    Encode a Python int in the PostgreSQL binary ``numeric`` format.

    The number of base-10000 digits is estimated from the bit length and the
    buffer is pre-sized and zero-filled. Digits are then written from the
    least significant one backwards, so an over-estimate leaves leading zero
    digits (e.g. 9999 is written as the two digits 0, 9999).
    """
    ndigits = int(obj.bit_length() * BIT_PER_PGDIGIT) + 1
    negative = obj < 0
    magnitude = -obj if negative else obj

    # 8-byte header followed by ndigits 2-byte digits, all zeroed.
    buf = bytearray(b"\x00\x00" * (ndigits + 4))
    buf[:8] = _pack_numeric_head(
        ndigits, ndigits - 1, NUMERIC_NEG if negative else NUMERIC_POS, 0
    )

    # Fill the digits right to left, starting from the last slot.
    pos = 8 + 2 * (ndigits - 1)
    while magnitude:
        magnitude, digit = divmod(magnitude, 10_000)
        buf[pos : pos + 2] = pack_uint2(digit)
        pos -= 2

    return buf
def register_default_adapters(context: AdaptContext) -> None:
    """
    Register the numeric dumpers and loaders on *context*'s adapters map.

    Registration order matters. Judging by the Decimal comment below, the
    dumper registered last for a Python type becomes the default for that
    type. The ``numeric`` dumpers also have ordering constraints among
    themselves.
    """
    adapters = context.adapters
    adapters.register_dumper(int, IntDumper)
    adapters.register_dumper(int, IntBinaryDumper)
    adapters.register_dumper(float, FloatDumper)
    adapters.register_dumper(float, FloatBinaryDumper)
    adapters.register_dumper(Int2, Int2Dumper)
    adapters.register_dumper(Int4, Int4Dumper)
    adapters.register_dumper(Int8, Int8Dumper)
    # Every other wrapper gets both a text and a binary dumper. Without this
    # line, IntNumeric in binary format presumably falls back (via its int
    # base class) to IntBinaryDumper and is sent as int2/int4/int8.
    # Registered before the text dumper so text stays the default, as for
    # Decimal below.
    adapters.register_dumper(IntNumeric, IntNumericBinaryDumper)
    adapters.register_dumper(IntNumeric, IntNumericDumper)
    adapters.register_dumper(Oid, OidDumper)

    # The binary dumper is currently some 30% slower, so default to text
    # (see tests/scripts/testdec.py for a rough benchmark)
    # Also, must be after IntNumericDumper
    adapters.register_dumper("decimal.Decimal", DecimalBinaryDumper)
    adapters.register_dumper("decimal.Decimal", DecimalDumper)

    # Used only by oid, can take both int and Decimal as input
    adapters.register_dumper(None, NumericBinaryDumper)
    adapters.register_dumper(None, NumericDumper)

    adapters.register_dumper(Float4, Float4Dumper)
    adapters.register_dumper(Float8, FloatDumper)
    adapters.register_dumper(Int2, Int2BinaryDumper)
    adapters.register_dumper(Int4, Int4BinaryDumper)
    adapters.register_dumper(Int8, Int8BinaryDumper)
    adapters.register_dumper(Oid, OidBinaryDumper)
    adapters.register_dumper(Float4, Float4BinaryDumper)
    adapters.register_dumper(Float8, FloatBinaryDumper)

    adapters.register_loader("int2", IntLoader)
    adapters.register_loader("int4", IntLoader)
    adapters.register_loader("int8", IntLoader)
    adapters.register_loader("oid", IntLoader)
    adapters.register_loader("int2", Int2BinaryLoader)
    adapters.register_loader("int4", Int4BinaryLoader)
    adapters.register_loader("int8", Int8BinaryLoader)
    adapters.register_loader("oid", OidBinaryLoader)
    adapters.register_loader("float4", FloatLoader)
    adapters.register_loader("float8", FloatLoader)
    adapters.register_loader("float4", Float4BinaryLoader)
    adapters.register_loader("float8", Float8BinaryLoader)
    adapters.register_loader("numeric", NumericLoader)
    adapters.register_loader("numeric", NumericBinaryLoader)