"""
Support for prepared statements
"""

# Copyright (C) 2020 The Psycopg Team

from enum import IntEnum, auto
from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
from collections import OrderedDict
from typing_extensions import TypeAlias

from . import pq
from ._compat import Deque
from ._queries import PostgresQuery

if TYPE_CHECKING:
    from .pq.abc import PGresult

# A cache key: the query and the types of its parameters.
Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]

COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK


class Prepare(IntEnum):
    # The query must not be prepared (yet).
    NO = auto()
    # The query is already prepared: use the name returned with this value.
    YES = auto()
    # The query should be prepared now, with the name returned.
    SHOULD = auto()


class PrepareManager:
    """Keep track of which queries are prepared, or should be."""

    # Number of times a query is executed before it is prepared.
    # `None` disables prepared statements altogether.
    prepare_threshold: Optional[int] = 5

    # Maximum number of prepared statements on the connection.
    prepared_max: int = 100

    def __init__(self) -> None:
        # (query, types) -> number of times the query was seen.
        self._counts: OrderedDict[Key, int] = OrderedDict()

        # (query, types) -> name of the statement, if it was prepared.
        self._names: OrderedDict[Key, bytes] = OrderedDict()

        # Progressive number used to generate prepared statements names.
        self._prepared_idx = 0

        # Commands queued to align the server state to our state.
        self._maint_commands = Deque[bytes]()

    @staticmethod
    def key(query: PostgresQuery) -> Key:
        """Return the cache key of 'query': its `query` and `types`."""
        return query.query, query.types

    def get(
        self, query: PostgresQuery, prepare: Optional[bool] = None
    ) -> Tuple[Prepare, bytes]:
        """
        Tell whether 'query' is prepared and whether it should be.

        Return a pair (state, name). The name is empty if the state is
        `Prepare.NO`; it is the name of the existing statement if the state
        is `Prepare.YES`, and a newly generated name if it is
        `Prepare.SHOULD`.
        """
        threshold = self.prepare_threshold
        if threshold is None or prepare is False:
            # Either preparing is disabled or the user doesn't want it
            # for this query.
            return Prepare.NO, b""

        cache_key = self.key(query)
        known = self._names.get(cache_key)
        if known:
            # Already prepared in this session.
            return Prepare.YES, known

        seen = self._counts.get(cache_key, 0)
        if not prepare and seen < threshold:
            # Not requested explicitly and not seen often enough yet.
            return Prepare.NO, b""

        # Seen enough times, or explicitly requested: generate a new name.
        idx = self._prepared_idx
        self._prepared_idx = idx + 1
        return Prepare.SHOULD, f"_pg3_{idx}".encode()

    def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
        """Check if we need to discard our entire state.

        It should happen on rollback or on dropping objects, because the
        same object may get recreated and postgres would fail internal
        lookups.

        Return the value returned by `clear()` if such a command is found
        in 'results', otherwise False.
        """
        if not self._names and prep != Prepare.SHOULD:
            # Nothing is prepared and nothing is about to be.
            return False

        for res in results:
            if res.status != COMMAND_OK:
                continue

            tag = res.command_status
            if not tag:
                continue

            if tag == b"ROLLBACK" or tag.startswith(b"DROP "):
                return self.clear()

        return False

    @staticmethod
    def _check_results(results: Sequence["PGresult"]) -> bool:
        """Return False if 'results' are invalid for prepared statement cache."""
        if len(results) != 1:
            # We cannot prepare a multiple statement
            return False

        # We don't prepare failed queries or other weird results
        return results[0].status in (COMMAND_OK, TUPLES_OK)

    def _rotate(self) -> None:
        """Evict an old value from the cache.

        If it was prepared, deallocate it. Do it only once: if the cache was
        resized, deallocate gradually.
        """
        limit = self.prepared_max

        if len(self._counts) > limit:
            self._counts.popitem(last=False)

        if len(self._names) > limit:
            _, stale = self._names.popitem(last=False)
            self._maint_commands.append(b"DEALLOCATE " + stale)

    def maybe_add_to_cache(
        self, query: PostgresQuery, prep: Prepare, name: bytes
    ) -> Optional[Key]:
        """Handle 'query' for possible addition to the cache.

        If a new entry has been added, return its key. Return None otherwise
        (meaning the query is already in cache or cache is not enabled).

        Note: This method is only called in pipeline mode.
        """
        # don't do anything if prepared statements are disabled
        if self.prepare_threshold is None:
            return None

        cache_key = self.key(query)
        to_prepare = prep is Prepare.SHOULD

        if cache_key in self._counts:
            if to_prepare:
                # Promote the entry from "counted" to "prepared".
                del self._counts[cache_key]
                self._names[cache_key] = name
            else:
                # Seen once more: bump the counter and mark as recently used.
                self._counts[cache_key] += 1
                self._counts.move_to_end(cache_key)
            return None

        if cache_key in self._names:
            # Already prepared: only mark as recently used.
            self._names.move_to_end(cache_key)
            return None

        # Never seen before: this is a new entry.
        if to_prepare:
            self._names[cache_key] = name
        else:
            self._counts[cache_key] = 1
        return cache_key

    def validate(
        self,
        key: Key,
        prep: Prepare,
        name: bytes,
        results: Sequence["PGresult"],
    ) -> None:
        """Validate cached entry with 'key' by checking query 'results'.

        Possibly queue commands to perform maintenance on database side:
        they can be retrieved using `get_maintenance_commands()`.

        Note: this method is only called in pipeline mode.
        """
        if self._should_discard(prep, results):
            return

        if self._check_results(results):
            self._rotate()
        else:
            # The results are not suitable: forget about the entry.
            self._names.pop(key, None)
            self._counts.pop(key, None)

    def clear(self) -> bool:
        """Clear the internal state.

        If any statement was prepared, replace the queued maintenance
        commands with a command to clear the state of the server and return
        True. Otherwise return False.
        """
        self._counts.clear()
        if not self._names:
            return False

        self._names.clear()
        self._maint_commands.clear()
        self._maint_commands.append(b"DEALLOCATE ALL")
        return True

    def get_maintenance_commands(self) -> Iterator[bytes]:
        """
        Iterate over the commands needed to align the server state to our state

        Every command is removed from the queue as it is yielded.
        """
        while self._maint_commands:
            yield self._maint_commands.popleft()