"""
Transaction context managers returned by Connection.transaction()
"""

# Copyright (C) 2020 The Psycopg Team

import logging

from types import TracebackType
from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING

from . import pq
from . import sql
from . import errors as e
from .abc import ConnectionType, PQGen

if TYPE_CHECKING:
    from typing import Any
    from .connection import Connection
    from .connection_async import AsyncConnection

IDLE = pq.TransactionStatus.IDLE

OK = pq.ConnStatus.OK

logger = logging.getLogger(__name__)


class Rollback(Exception):
    """
    Exit the current `Transaction` context immediately and roll back any
    changes made within this context.

    If a transaction context is specified in the constructor, roll back
    enclosing transaction contexts up to and including the one specified.
    """

    __module__ = "psycopg"

    def __init__(
        self,
        transaction: Union["Transaction", "AsyncTransaction", None] = None,
    ):
        self.transaction = transaction

    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}({self.transaction!r})"


class OutOfOrderTransactionNesting(e.ProgrammingError):
    """Out-of-order transaction nesting detected"""


class BaseTransaction(Generic[ConnectionType]):
    """
    Shared implementation of the sync and async transaction blocks.

    The database work is expressed as generators (`_enter_gen`, `_exit_gen`)
    which the subclasses drive through their connection's ``wait()``.

    Nesting is tracked through the connection's ``_num_transactions``
    counter: each entered block records the counter value it found
    (``_stack_index``) and must find the same value again when it exits.
    """

    def __init__(
        self,
        connection: ConnectionType,
        savepoint_name: Optional[str] = None,
        force_rollback: bool = False,
    ):
        self._conn = connection
        self.pgconn = self._conn.pgconn
        self._savepoint_name = savepoint_name or ""
        self.force_rollback = force_rollback
        self._entered = self._exited = False
        self._outer_transaction = False
        self._stack_index = -1

    @property
    def savepoint_name(self) -> Optional[str]:
        """
        The name of the savepoint; `!None` if handling the main transaction.
        """
        # The value may change on __enter__: an inner block without an
        # explicit name gets an auto-generated one. The un-entered state is
        # outside the public interface, so that's acceptable.
        return self._savepoint_name

    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = pq.misc.connection_summary(self.pgconn)
        if not self._entered:
            status = "inactive"
        elif not self._exited:
            status = "active"
        else:
            status = "terminated"

        sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
        return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"

    def _enter_gen(self) -> PQGen[None]:
        """
        Enter the block: push it on the connection stack and run the
        BEGIN/SAVEPOINT commands.

        Raise `TypeError` if the object was already entered once.
        """
        if self._entered:
            raise TypeError("transaction blocks can be used only once")
        self._entered = True

        self._push_savepoint()
        try:
            for command in self._get_enter_commands():
                yield from self._conn._exec_command(command)
        except Exception:
            # When __enter__ raises, Python never calls __exit__ for this
            # block, so undo the push here. Otherwise the connection's nesting
            # counter would stay incremented and the enclosing block would
            # detect a bogus out-of-order nesting (and skip its own rollback).
            self._conn._num_transactions -= 1
            self._exited = True
            raise

    def _exit_gen(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> PQGen[bool]:
        """
        Exit the block: commit on success, roll back on error or if
        `force_rollback` is set.

        Return True only to swallow a `Rollback` exception targeting this
        block; errors in the rollback itself are logged and ignored, except
        for `OutOfOrderTransactionNesting`, which is propagated.
        """
        if not exc_val and not self.force_rollback:
            yield from self._commit_gen()
            return False

        # Try to roll back, but if there are problems (connection in a bad
        # state) just warn without clobbering the exception bubbling up.
        try:
            return (yield from self._rollback_gen(exc_val))
        except OutOfOrderTransactionNesting:
            # Let the out-of-order nesting error replace any exception raised
            # in the block. This keeps the behaviour consistent with
            # _commit_gen and makes sure the user fixes this condition, which
            # is unrelated to any operational error raised in the block.
            raise
        except Exception as exc2:
            logger.warning("error ignored in rollback of %s: %s", self, exc2)
            return False

    def _commit_gen(self) -> PQGen[None]:
        """
        Pop the block from the stack and run the RELEASE/COMMIT commands.

        Raise `OutOfOrderTransactionNesting` (before sending any command) if
        the block is not the innermost one.
        """
        ex = self._pop_savepoint("commit")
        self._exited = True
        if ex:
            raise ex

        for command in self._get_commit_commands():
            yield from self._conn._exec_command(command)

    def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
        """
        Pop the block from the stack and run the rollback commands.

        Return True if *exc_val* is a `Rollback` aimed at this block (or at no
        block in particular), meaning the exception should be swallowed.
        """
        if isinstance(exc_val, Rollback):
            # Lazy %-formatting: the connection is only rendered to a string
            # if the debug record is actually emitted.
            logger.debug("%s: Explicit rollback from: ", self._conn, exc_info=True)

        ex = self._pop_savepoint("rollback")
        self._exited = True
        if ex:
            raise ex

        for command in self._get_rollback_commands():
            yield from self._conn._exec_command(command)

        if isinstance(exc_val, Rollback):
            if not exc_val.transaction or exc_val.transaction is self:
                return True  # Swallow the exception

        return False

    def _get_enter_commands(self) -> Iterator[bytes]:
        """Yield the transaction-start command (outer block) and/or SAVEPOINT."""
        if self._outer_transaction:
            yield self._conn._get_tx_start_command()

        if self._savepoint_name:
            yield (
                sql.SQL("SAVEPOINT {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

    def _get_commit_commands(self) -> Iterator[bytes]:
        """Yield RELEASE for an inner savepoint, or COMMIT for the outer block."""
        if self._savepoint_name and not self._outer_transaction:
            yield (
                sql.SQL("RELEASE {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._num_transactions
            yield b"COMMIT"

    def _get_rollback_commands(self) -> Iterator[bytes]:
        """
        Yield ROLLBACK TO + RELEASE for an inner savepoint, or ROLLBACK for
        the outer block, followed by prepared-statements maintenance commands.
        """
        if self._savepoint_name and not self._outer_transaction:
            yield (
                sql.SQL("ROLLBACK TO {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )
            yield (
                sql.SQL("RELEASE {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._num_transactions
            yield b"ROLLBACK"

        # Also clear the prepared statements cache.
        if self._conn._prepared.clear():
            yield from self._conn._prepared.get_maintenance_commands()

    def _push_savepoint(self) -> None:
        """
        Push the transaction on the connection transactions stack.

        Also set the internal state of the object and verify consistency.
        """
        self._outer_transaction = self.pgconn.transaction_status == IDLE
        if self._outer_transaction:
            # outer transaction: if no name it's only a begin, else
            # there will be an additional savepoint
            assert not self._conn._num_transactions
        else:
            # inner transaction: it always has a name
            if not self._savepoint_name:
                self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"

        self._stack_index = self._conn._num_transactions
        self._conn._num_transactions += 1

    def _pop_savepoint(self, action: str) -> Optional[Exception]:
        """
        Pop the transaction from the connection transactions stack.

        Also verify the state consistency: return (don't raise) an
        `OutOfOrderTransactionNesting` if this block is not the innermost one,
        otherwise return None.
        """
        self._conn._num_transactions -= 1
        if self._conn._num_transactions == self._stack_index:
            return None

        return OutOfOrderTransactionNesting(
            f"transaction {action} at the wrong nesting level: {self}"
        )


class Transaction(BaseTransaction["Connection[Any]"]):
    """
    Returned by `Connection.transaction()` to handle a transaction block.
    """

    __module__ = "psycopg"

    _Self = TypeVar("_Self", bound="Transaction")

    @property
    def connection(self) -> "Connection[Any]":
        """The connection the object is managing."""
        return self._conn

    def __enter__(self: _Self) -> _Self:
        with self._conn.lock:
            self._conn.wait(self._enter_gen())
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Only talk to the server if the connection is still usable.
        if self.pgconn.status == OK:
            with self._conn.lock:
                return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
        else:
            return False


class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
    """
    Returned by `AsyncConnection.transaction()` to handle a transaction block.
    """

    __module__ = "psycopg"

    _Self = TypeVar("_Self", bound="AsyncTransaction")

    @property
    def connection(self) -> "AsyncConnection[Any]":
        """The connection the object is managing."""
        return self._conn

    async def __aenter__(self: _Self) -> _Self:
        async with self._conn.lock:
            await self._conn.wait(self._enter_gen())
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Only talk to the server if the connection is still usable.
        if self.pgconn.status == OK:
            async with self._conn.lock:
                return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
        else:
            return False