# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Internal helpers for CSOT."""

import functools
import time
from collections import deque
from contextvars import ContextVar, Token
from typing import Any, Callable, Deque, MutableMapping, Optional, Tuple, TypeVar, cast

from pymongo.write_concern import WriteConcern

# Context-local state shared by every helper in this module.
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))


def get_timeout() -> Optional[float]:
    """Return the current value of ``TIMEOUT``, or None when it is unset."""
    return TIMEOUT.get(None)


def get_rtt() -> float:
    """Return the current value of ``RTT`` (0.0 unless it has been set)."""
    return RTT.get()


def get_deadline() -> float:
    """Return the current value of ``DEADLINE`` (infinity unless it has been set)."""
    return DEADLINE.get()


def set_rtt(rtt: float) -> None:
    """Store ``rtt`` in the ``RTT`` context variable."""
    RTT.set(rtt)


def remaining() -> Optional[float]:
    """Return the time left before the current deadline.

    The result is ``DEADLINE`` minus ``time.monotonic()``, so it becomes
    negative once the deadline has passed. None is returned when the
    current timeout is None or 0, i.e. when no timeout is in effect.
    """
    if get_timeout():
        deadline = DEADLINE.get()
        return deadline - time.monotonic()
    return None


def clamp_remaining(max_timeout: float) -> float:
    """Return the remaining timeout, capped at ``max_timeout``.

    When no timeout is in effect, ``max_timeout`` itself is returned.
    """
    left = remaining()
    return max_timeout if left is None else min(left, max_timeout)


class _TimeoutContext:
    """Internal timeout context manager.

    Use :func:`pymongo.timeout` instead::

      with pymongo.timeout(0.5):
          client.test.test.insert_one({})

    Entering the context sets ``TIMEOUT``, ``DEADLINE`` and ``RTT``;
    leaving it restores the values they had before.
    """

    __slots__ = ("_timeout", "_tokens")

    def __init__(self, timeout: Optional[float]):
        self._timeout = timeout
        # Reset tokens recorded by __enter__; None until the context is entered.
        self._tokens: Optional[Tuple[Token, Token, Token]] = None

    def __enter__(self):
        budget = self._timeout
        timeout_token = TIMEOUT.set(budget)
        inherited = DEADLINE.get()
        # A timeout of None or 0 contributes no deadline of its own.
        if budget:
            own = time.monotonic() + budget
        else:
            own = float("inf")
        # An enclosing deadline can only be tightened, never extended.
        deadline_token = DEADLINE.set(min(inherited, own))
        rtt_token = RTT.set(0.0)
        self._tokens = (timeout_token, deadline_token, rtt_token)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        tokens = self._tokens
        if not tokens:
            return
        # Tokens are stored in (TIMEOUT, DEADLINE, RTT) order.
        for var, token in zip((TIMEOUT, DEADLINE, RTT), tokens):
            var.reset(token)


# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])


def apply(func: F) -> F:
    """Apply the client's timeoutMS to this operation.

    The decorated callable is invoked as ``func(self, *args, **kwargs)``.
    If no timeout is already set in the current context and
    ``self._timeout`` is not None, the call runs inside a
    ``_TimeoutContext(self._timeout)``; otherwise it runs unchanged.
    """

    @functools.wraps(func)
    def csot_wrapper(self, *args, **kwargs):
        # A timeout that is already in effect takes precedence.
        if get_timeout() is not None:
            return func(self, *args, **kwargs)
        client_timeout = self._timeout
        if client_timeout is None:
            return func(self, *args, **kwargs)
        with _TimeoutContext(client_timeout):
            return func(self, *args, **kwargs)

    return cast(F, csot_wrapper)


def apply_write_concern(cmd: MutableMapping, write_concern: Optional[WriteConcern]) -> None:
    """Apply the given write concern to a command.

    Nothing is written when ``write_concern`` is falsy or reports
    ``is_server_default``. While a timeout is set, the ``"wtimeout"``
    field is removed from the write concern document. The document is
    stored under ``cmd["writeConcern"]`` only if it is still non-empty.
    """
    if write_concern and not write_concern.is_server_default:
        doc = write_concern.document
        if get_timeout() is not None:
            doc.pop("wtimeout", None)
        if doc:
            cmd["writeConcern"] = doc


_MAX_RTT_SAMPLES: int = 10
_MIN_RTT_SAMPLES: int = 2


class MovingMinimum:
    """Tracks the minimum RTT within the last ``_MAX_RTT_SAMPLES`` (10) samples."""

    samples: Deque[float]

    def __init__(self) -> None:
        # The bounded deque drops the oldest sample once it is full.
        self.samples = deque(maxlen=_MAX_RTT_SAMPLES)

    def add_sample(self, sample: float) -> None:
        """Record ``sample``, silently dropping negative values."""
        if sample < 0:
            # A negative RTT most likely means the system clock changed
            # while waiting for the hello response and time.monotonic
            # was not in use. Skip it; the next sample will probably be
            # valid.
            return
        self.samples.append(sample)

    def get(self) -> float:
        """Get the min, or 0.0 if there aren't enough samples yet."""
        window = self.samples
        if len(window) < _MIN_RTT_SAMPLES:
            return 0.0
        return min(window)

    def reset(self) -> None:
        """Discard every recorded sample."""
        self.samples.clear()