""" psycopg copy support """ # Copyright (C) 2020 The Psycopg Team import re import queue import struct import asyncio import threading from abc import ABC, abstractmethod from types import TracebackType from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING from . import pq from . import adapt from . import errors as e from .abc import Buffer, ConnectionType, PQGen, Transformer from ._compat import create_task from ._cmodule import _psycopg from ._encodings import pgconn_encoding from .generators import copy_from, copy_to, copy_end if TYPE_CHECKING: from .cursor import BaseCursor, Cursor from .cursor_async import AsyncCursor from .connection import Connection # noqa: F401 from .connection_async import AsyncConnection # noqa: F401 PY_TEXT = adapt.PyFormat.TEXT PY_BINARY = adapt.PyFormat.BINARY TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY COPY_IN = pq.ExecStatus.COPY_IN COPY_OUT = pq.ExecStatus.COPY_OUT ACTIVE = pq.TransactionStatus.ACTIVE # Size of data to accumulate before sending it down the network. We fill a # buffer this size field by field, and when it passes the threshold size # we ship it, so it may end up being bigger than this. BUFFER_SIZE = 32 * 1024 # Maximum data size we want to queue to send to the libpq copy. Sending a # buffer too big to be handled can cause an infinite loop in the libpq # (#255) so we want to split it in more digestable chunks. MAX_BUFFER_SIZE = 4 * BUFFER_SIZE # Note: making this buffer too large, e.g. # MAX_BUFFER_SIZE = 1024 * 1024 # makes operations *way* slower! Probably triggering some quadraticity # in the libpq memory management and data sending. # Max size of the write queue of buffers. More than that copy will block # Each buffer should be around BUFFER_SIZE size. QUEUE_SIZE = 1024 class BaseCopy(Generic[ConnectionType]): """ Base implementation for the copy user interface. 
Two subclasses expose real methods with the sync/async differences. The difference between the text and binary format is managed by two different `Formatter` subclasses. Writing (the I/O part) is implemented in the subclasses by a `Writer` or `AsyncWriter` instance. Normally writing implies sending copy data to a database, but a different writer might be chosen, e.g. to stream data into a file for later use. """ _Self = TypeVar("_Self", bound="BaseCopy[Any]") formatter: "Formatter" def __init__( self, cursor: "BaseCursor[ConnectionType, Any]", *, binary: Optional[bool] = None, ): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn result = cursor.pgresult if result: self._direction = result.status if self._direction != COPY_IN and self._direction != COPY_OUT: raise e.ProgrammingError( "the cursor should have performed a COPY operation;" f" its status is {pq.ExecStatus(self._direction).name} instead" ) else: self._direction = COPY_IN if binary is None: binary = bool(result and result.binary_tuples) tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) if binary: self.formatter = BinaryFormatter(tx) else: self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) self._finished = False def __repr__(self) -> str: cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" info = pq.misc.connection_summary(self._pgconn) return f"<{cls} {info} at 0x{id(self):x}>" def _enter(self) -> None: if self._finished: raise TypeError("copy blocks can be used only once") def set_types(self, types: Sequence[Union[int, str]]) -> None: """ Set the types expected in a COPY operation. The types must be specified as a sequence of oid or PostgreSQL type names (e.g. ``int4``, ``timestamptz[]``). This operation overcomes the lack of metadata returned by PostgreSQL when a COPY operation begins: - On :sql:`COPY TO`, `!set_types()` allows to specify what types the operation returns. 
If `!set_types()` is not used, the data will be returned as unparsed strings or bytes instead of Python objects. - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the database expects. This is especially useful in binary copy, because PostgreSQL will apply no cast rule. """ registry = self.cursor.adapters.types oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] if self._direction == COPY_IN: self.formatter.transformer.set_dumper_types(oids, self.formatter.format) else: self.formatter.transformer.set_loader_types(oids, self.formatter.format) # High level copy protocol generators (state change of the Copy object) def _read_gen(self) -> PQGen[Buffer]: if self._finished: return memoryview(b"") res = yield from copy_from(self._pgconn) if isinstance(res, memoryview): return res # res is the final PGresult self._finished = True # This result is a COMMAND_OK which has info about the number of rows # returned, but not about the columns, which is instead an information # that was received on the COPY_OUT result at the beginning of COPY. # So, don't replace the results in the cursor, just update the rowcount. nrows = res.command_tuples self.cursor._rowcount = nrows if nrows is not None else -1 return memoryview(b"") def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]: data = yield from self._read_gen() if not data: return None row = self.formatter.parse_row(data) if row is None: # Get the final result to finish the copy operation yield from self._read_gen() self._finished = True return None return row def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: if not exc: return if self._pgconn.transaction_status != ACTIVE: # The server has already finished to send copy data. The connection # is already in a good state. 
return # Throw a cancel to the server, then consume the rest of the copy data # (which might or might not have been already transferred entirely to # the client, so we won't necessary see the exception associated with # canceling). self.connection.cancel() try: while (yield from self._read_gen()): pass except e.QueryCanceled: pass class Copy(BaseCopy["Connection[Any]"]): """Manage a :sql:`COPY` operation. :param cursor: the cursor where the operation is performed. :param binary: if `!True`, write binary format. :param writer: the object to write to destination. If not specified, write to the `!cursor` connection. Choosing `!binary` is not necessary if the cursor has executed a :sql:`COPY` operation, because the operation result describes the format too. The parameter is useful when a `!Copy` object is created manually and no operation is performed on the cursor, such as when using ``writer=``\\ `~psycopg.copy.FileWriter`. """ __module__ = "psycopg" writer: "Writer" def __init__( self, cursor: "Cursor[Any]", *, binary: Optional[bool] = None, writer: Optional["Writer"] = None, ): super().__init__(cursor, binary=binary) if not writer: writer = LibpqWriter(cursor) self.writer = writer self._write = writer.write def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: self._enter() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.finish(exc_val) # End user sync interface def __iter__(self) -> Iterator[Buffer]: """Implement block-by-block iteration on :sql:`COPY TO`.""" while True: data = self.read() if not data: break yield data def read(self) -> Buffer: """ Read an unparsed row after a :sql:`COPY TO` operation. Return an empty string when the data is finished. """ return self.connection.wait(self._read_gen()) def rows(self) -> Iterator[Tuple[Any, ...]]: """ Iterate on the result of a :sql:`COPY TO` operation record by record. 
Note that the records returned will be tuples of unparsed strings or bytes, unless data types are specified using `set_types()`. """ while True: record = self.read_row() if record is None: break yield record def read_row(self) -> Optional[Tuple[Any, ...]]: """ Read a parsed row of data from a table after a :sql:`COPY TO` operation. Return `!None` when the data is finished. Note that the records returned will be tuples of unparsed strings or bytes, unless data types are specified using `set_types()`. """ return self.connection.wait(self._read_row_gen()) def write(self, buffer: Union[Buffer, str]) -> None: """ Write a block of data to a table after a :sql:`COPY FROM` operation. If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In text mode it can be either `!bytes` or `!str`. """ data = self.formatter.write(buffer) if data: self._write(data) def write_row(self, row: Sequence[Any]) -> None: """Write a record to a table after a :sql:`COPY FROM` operation.""" data = self.formatter.write_row(row) if data: self._write(data) def finish(self, exc: Optional[BaseException]) -> None: """Terminate the copy operation and free the resources allocated. You shouldn't need to call this function yourself: it is usually called by exit. It is available if, despite what is documented, you end up using the `Copy` object outside a block. """ if self._direction == COPY_IN: data = self.formatter.end() if data: self._write(data) self.writer.finish(exc) self._finished = True else: self.connection.wait(self._end_copy_out_gen(exc)) class Writer(ABC): """ A class to write copy data somewhere. """ @abstractmethod def write(self, data: Buffer) -> None: """ Write some data to destination. """ ... def finish(self, exc: Optional[BaseException] = None) -> None: """ Called when write operations are finished. If operations finished with an error, it will be passed to ``exc``. """ pass class LibpqWriter(Writer): """ A `Writer` to write copy data to a Postgres database. 
""" def __init__(self, cursor: "Cursor[Any]"): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn def write(self, data: Buffer) -> None: if len(data) <= MAX_BUFFER_SIZE: # Most used path: we don't need to split the buffer in smaller # bits, so don't make a copy. self.connection.wait(copy_to(self._pgconn, data)) else: # Copy a buffer too large in chunks to avoid causing a memory # error in the libpq, which may cause an infinite loop (#255). for i in range(0, len(data), MAX_BUFFER_SIZE): self.connection.wait( copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) ) def finish(self, exc: Optional[BaseException] = None) -> None: bmsg: Optional[bytes] if exc: msg = f"error from Python: {type(exc).__qualname__} - {exc}" bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") else: bmsg = None res = self.connection.wait(copy_end(self._pgconn, bmsg)) self.cursor._results = [res] class QueuedLibpqDriver(LibpqWriter): """ A writer using a buffer to queue data to write to a Postgres database. `write()` returns immediately, so that the main thread can be CPU-bound formatting messages, while a worker thread can be IO-bound waiting to write on the connection. """ def __init__(self, cursor: "Cursor[Any]"): super().__init__(cursor) self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE) self._worker: Optional[threading.Thread] = None self._worker_error: Optional[BaseException] = None def worker(self) -> None: """Push data to the server when available from the copy queue. Terminate reading when the queue receives a false-y value, or in case of error. The function is designed to be run in a separate thread. """ try: while True: data = self._queue.get(block=True, timeout=24 * 60 * 60) if not data: break self.connection.wait(copy_to(self._pgconn, data)) except BaseException as ex: # Propagate the error to the main thread. 
self._worker_error = ex def write(self, data: Buffer) -> None: if not self._worker: # warning: reference loop, broken by _write_end self._worker = threading.Thread(target=self.worker) self._worker.daemon = True self._worker.start() # If the worker thread raies an exception, re-raise it to the caller. if self._worker_error: raise self._worker_error if len(data) <= MAX_BUFFER_SIZE: # Most used path: we don't need to split the buffer in smaller # bits, so don't make a copy. self._queue.put(data) else: # Copy a buffer too large in chunks to avoid causing a memory # error in the libpq, which may cause an infinite loop (#255). for i in range(0, len(data), MAX_BUFFER_SIZE): self._queue.put(data[i : i + MAX_BUFFER_SIZE]) def finish(self, exc: Optional[BaseException] = None) -> None: self._queue.put(b"") if self._worker: self._worker.join() self._worker = None # break the loop # Check if the worker thread raised any exception before terminating. if self._worker_error: raise self._worker_error super().finish(exc) class FileWriter(Writer): """ A `Writer` to write copy data to a file-like object. :param file: the file where to write copy data. It must be open for writing in binary mode. 
""" def __init__(self, file: IO[bytes]): self.file = file def write(self, data: Buffer) -> None: self.file.write(data) class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" __module__ = "psycopg" writer: "AsyncWriter" def __init__( self, cursor: "AsyncCursor[Any]", *, binary: Optional[bool] = None, writer: Optional["AsyncWriter"] = None, ): super().__init__(cursor, binary=binary) if not writer: writer = AsyncLibpqWriter(cursor) self.writer = writer self._write = writer.write async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: self._enter() return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: await self.finish(exc_val) async def __aiter__(self) -> AsyncIterator[Buffer]: while True: data = await self.read() if not data: break yield data async def read(self) -> Buffer: return await self.connection.wait(self._read_gen()) async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: while True: record = await self.read_row() if record is None: break yield record async def read_row(self) -> Optional[Tuple[Any, ...]]: return await self.connection.wait(self._read_row_gen()) async def write(self, buffer: Union[Buffer, str]) -> None: data = self.formatter.write(buffer) if data: await self._write(data) async def write_row(self, row: Sequence[Any]) -> None: data = self.formatter.write_row(row) if data: await self._write(data) async def finish(self, exc: Optional[BaseException]) -> None: if self._direction == COPY_IN: data = self.formatter.end() if data: await self._write(data) await self.writer.finish(exc) self._finished = True else: await self.connection.wait(self._end_copy_out_gen(exc)) class AsyncWriter(ABC): """ A class to write copy data somewhere (for async connections). """ @abstractmethod async def write(self, data: Buffer) -> None: ... 
async def finish(self, exc: Optional[BaseException] = None) -> None: pass class AsyncLibpqWriter(AsyncWriter): """ An `AsyncWriter` to write copy data to a Postgres database. """ def __init__(self, cursor: "AsyncCursor[Any]"): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn async def write(self, data: Buffer) -> None: if len(data) <= MAX_BUFFER_SIZE: # Most used path: we don't need to split the buffer in smaller # bits, so don't make a copy. await self.connection.wait(copy_to(self._pgconn, data)) else: # Copy a buffer too large in chunks to avoid causing a memory # error in the libpq, which may cause an infinite loop (#255). for i in range(0, len(data), MAX_BUFFER_SIZE): await self.connection.wait( copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) ) async def finish(self, exc: Optional[BaseException] = None) -> None: bmsg: Optional[bytes] if exc: msg = f"error from Python: {type(exc).__qualname__} - {exc}" bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") else: bmsg = None res = await self.connection.wait(copy_end(self._pgconn, bmsg)) self.cursor._results = [res] class AsyncQueuedLibpqWriter(AsyncLibpqWriter): """ An `AsyncWriter` using a buffer to queue data to write. `write()` returns immediately, so that the main thread can be CPU-bound formatting messages, while a worker thread can be IO-bound waiting to write on the connection. """ def __init__(self, cursor: "AsyncCursor[Any]"): super().__init__(cursor) self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE) self._worker: Optional[asyncio.Future[None]] = None async def worker(self) -> None: """Push data to the server when available from the copy queue. Terminate reading when the queue receives a false-y value. The function is designed to be run in a separate task. 
""" while True: data = await self._queue.get() if not data: break await self.connection.wait(copy_to(self._pgconn, data)) async def write(self, data: Buffer) -> None: if not self._worker: self._worker = create_task(self.worker()) if len(data) <= MAX_BUFFER_SIZE: # Most used path: we don't need to split the buffer in smaller # bits, so don't make a copy. await self._queue.put(data) else: # Copy a buffer too large in chunks to avoid causing a memory # error in the libpq, which may cause an infinite loop (#255). for i in range(0, len(data), MAX_BUFFER_SIZE): await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) async def finish(self, exc: Optional[BaseException] = None) -> None: await self._queue.put(b"") if self._worker: await asyncio.gather(self._worker) self._worker = None # break reference loops if any await super().finish(exc) class Formatter(ABC): """ A class which understand a copy format (text, binary). """ format: pq.Format def __init__(self, transformer: Transformer): self.transformer = transformer self._write_buffer = bytearray() self._row_mode = False # true if the user is using write_row() @abstractmethod def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: ... @abstractmethod def write(self, buffer: Union[Buffer, str]) -> Buffer: ... @abstractmethod def write_row(self, row: Sequence[Any]) -> Buffer: ... @abstractmethod def end(self) -> Buffer: ... 
class TextFormatter(Formatter):
    """
    A `Formatter` for the text copy format.

    :param transformer: the object used to dump/load the row values.
    :param encoding: the Python codec used to encode `!str` data passed to
        `write()`.
    """

    format = TEXT

    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
        super().__init__(transformer)
        self._encoding = encoding

    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
        """Parse a line of text copy data; return `!None` if `!data` is empty."""
        if data:
            return parse_row_text(data, self.transformer)
        else:
            return None

    def write(self, buffer: Union[Buffer, str]) -> Buffer:
        """Return `!buffer` as bytes-like data, encoding it if it is a `!str`."""
        # The text format has no signature, so there is no state to record.
        return self._ensure_bytes(buffer)

    def write_row(self, row: Sequence[Any]) -> Buffer:
        """Format `!row` into the internal buffer.

        Return the buffer content once it grows past `!BUFFER_SIZE`,
        otherwise return empty bytes.
        """
        # Note down that we are writing in row mode: it means we will have
        # to take care of the end-of-copy marker too
        self._row_mode = True

        format_row_text(row, self.transformer, self._write_buffer)
        if len(self._write_buffer) > BUFFER_SIZE:
            buffer, self._write_buffer = self._write_buffer, bytearray()
            return buffer
        else:
            return b""

    def end(self) -> Buffer:
        """Return the data still in the internal buffer and reset it."""
        buffer, self._write_buffer = self._write_buffer, bytearray()
        return buffer

    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
        """Encode `!data` if it is a `!str`, else return it unchanged."""
        if isinstance(data, str):
            return data.encode(self._encoding)
        else:
            # Assume, for simplicity, that the user is not passing stupid
            # things to the write function. If that's the case, things
            # will fail downstream.
            return data


class BinaryFormatter(Formatter):
    """
    A `Formatter` for the binary copy format.

    :param transformer: the object used to dump/load the row values.
    """

    format = BINARY

    def __init__(self, transformer: Transformer):
        super().__init__(transformer)
        # True once the signature has been read (parsing) or written.
        self._signature_sent = False

    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
        """Parse a block of binary copy data into a row.

        The first block must start with the binary signature, which is
        stripped before parsing. Return `!None` if what is left is the
        trailer.

        :raises DataError: if the first block doesn't start with the
            expected signature.
        """
        if not self._signature_sent:
            if data[: len(_binary_signature)] != _binary_signature:
                raise e.DataError(
                    "binary copy doesn't start with the expected signature"
                )
            self._signature_sent = True
            data = data[len(_binary_signature) :]

        # Check for the trailer also on the first block: if the copy has no
        # row, the signature may be followed directly by the trailer, which
        # must not be parsed as a row.
        if data == _binary_trailer:
            return None

        return parse_row_binary(data, self.transformer)

    def write(self, buffer: Union[Buffer, str]) -> Buffer:
        """Return `!buffer` unchanged, assuming it includes the signature.

        :raises TypeError: if `!buffer` is a `!str`.
        """
        data = self._ensure_bytes(buffer)
        self._signature_sent = True
        return data

    def write_row(self, row: Sequence[Any]) -> Buffer:
        """Format `!row` into the internal buffer, after the signature if
        not written yet.

        Return the buffer content once it grows past `!BUFFER_SIZE`,
        otherwise return empty bytes.
        """
        # Note down that we are writing in row mode: it means we will have
        # to take care of the end-of-copy marker too
        self._row_mode = True

        if not self._signature_sent:
            self._write_buffer += _binary_signature
            self._signature_sent = True

        format_row_binary(row, self.transformer, self._write_buffer)
        if len(self._write_buffer) > BUFFER_SIZE:
            buffer, self._write_buffer = self._write_buffer, bytearray()
            return buffer
        else:
            return b""

    def end(self) -> Buffer:
        """Return the data left to write, adding signature and trailer as
        needed, and reset the internal buffer."""
        # If we have sent no data we need to send the signature
        # and the trailer
        if not self._signature_sent:
            self._write_buffer += _binary_signature
            self._write_buffer += _binary_trailer

        elif self._row_mode:
            # if we have sent data already, we have sent the signature
            # too (either with the first row, or we assume that in
            # block mode the signature is included).
            # Write the trailer only if we are sending rows (with the
            # assumption that who is copying binary data is sending the
            # whole format).
            self._write_buffer += _binary_trailer

        buffer, self._write_buffer = self._write_buffer, bytearray()
        return buffer

    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
        """Return `!data` unchanged, refusing `!str`.

        :raises TypeError: if `!data` is a `!str`.
        """
        if isinstance(data, str):
            raise TypeError("cannot copy str data in binary mode: use bytes instead")
        else:
            # Assume, for simplicity, that the user is not passing stupid
            # things to the write function. If that's the case, things
            # will fail downstream.
            return data


def _format_row_text(
    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray:
    """Convert a row of objects to the data to send for copy.

    Append the data to `!out` (a new `!bytearray` if not specified) and
    return it. Fields are separated by tabs, `!None` is written as ``\\N``,
    the row is terminated by a newline. An empty row is a newline alone.
    """
    if out is None:
        out = bytearray()

    if not row:
        out += b"\n"
        return out

    for item in row:
        if item is not None:
            dumper = tx.get_dumper(item, PY_TEXT)
            b = dumper.dump(item)
            # Escape the characters which are special in the text format.
            out += _dump_re.sub(_dump_sub, b)
        else:
            out += rb"\N"
        out += b"\t"

    # Replace the tab after the last field with the end of line.
    out[-1:] = b"\n"
    return out


def _format_row_binary(
    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray:
    """Convert a row of objects to the data to send for binary copy.

    Append the data to `!out` (a new `!bytearray` if not specified) and
    return it. The row is written as a 16 bits field count, followed, for
    each field, by a 32 bits length and the field data, or by the null
    marker if the field is `!None`.
    """
    if out is None:
        out = bytearray()

    out += _pack_int2(len(row))
    adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
    for b in adapted:
        if b is not None:
            out += _pack_int4(len(b))
            out += b
        else:
            out += _binary_null

    return out


def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
    """Parse a line of text copy data and load it through `!tx`.

    The line is split on tabs, the last character is dropped as line
    terminator, ``\\N`` fields become `!None` and the other ones are
    unescaped.
    """
    if not isinstance(data, bytes):
        data = bytes(data)
    fields = data.split(b"\t")
    fields[-1] = fields[-1][:-1]  # drop \n
    row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
    return tx.load_sequence(row)


def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
    """Parse a row of binary copy data and load it through `!tx`.

    Read a 16 bits field count, then a 32 bits length per field followed by
    that many bytes of data; a negative length is loaded as `!None`.
    """
    row: List[Optional[Buffer]] = []
    nfields = _unpack_int2(data, 0)[0]
    pos = 2
    for _ in range(nfields):
        length = _unpack_int4(data, pos)[0]
        pos += 4
        if length >= 0:
            row.append(data[pos : pos + length])
            pos += length
        else:
            row.append(None)

    return tx.load_sequence(row)


# Big-endian signed 16 and 32 bits integers packers/unpackers.
_pack_int2 = struct.Struct("!h").pack
_pack_int4 = struct.Struct("!i").pack
_unpack_int2 = struct.Struct("!h").unpack_from
_unpack_int4 = struct.Struct("!i").unpack_from

_binary_signature = (
    b"PGCOPY\n\xff\r\n\0"  # Signature
    b"\x00\x00\x00\x00"  # flags
    b"\x00\x00\x00\x00"  # extra length
)
_binary_trailer = b"\xff\xff"
_binary_null = b"\xff\xff\xff\xff"

# Characters escaped in text copy data and their escaped form.
_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
_dump_repl = {
    b"\b": b"\\b",
    b"\t": b"\\t",
    b"\n": b"\\n",
    b"\v": b"\\v",
    b"\f": b"\\f",
    b"\r": b"\\r",
    b"\\": b"\\\\",
}


def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
    """Return the escaped form of the character matched by `!_dump_re`."""
    return __map[m.group(0)]


# Inverse of the above: escape sequences and the characters they represent.
_load_re = re.compile(b"\\\\[btnvfr\\\\]")
_load_repl = {v: k for k, v in _dump_repl.items()}


def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
    """Return the character for the escape sequence matched by `!_load_re`."""
    return __map[m.group(0)]


# Override functions with fast versions if available
if _psycopg:
    format_row_text = _psycopg.format_row_text
    format_row_binary = _psycopg.format_row_binary
    parse_row_text = _psycopg.parse_row_text
    parse_row_binary = _psycopg.parse_row_binary

else:
    format_row_text = _format_row_text
    format_row_binary = _format_row_binary
    parse_row_text = _parse_row_text
    parse_row_binary = _parse_row_binary