303 lines
9.9 KiB
Python
303 lines
9.9 KiB
Python
# Copyright 2009-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Bits and pieces used by the driver that don't really fit elsewhere."""
|
|
|
|
import sys
|
|
import traceback
|
|
from collections import abc
|
|
from typing import Any, List, NoReturn
|
|
|
|
from bson.son import SON
|
|
from pymongo import ASCENDING
|
|
from pymongo.errors import (
|
|
CursorNotFound,
|
|
DuplicateKeyError,
|
|
ExecutionTimeout,
|
|
NotPrimaryError,
|
|
OperationFailure,
|
|
WriteConcernError,
|
|
WriteError,
|
|
WTimeoutError,
|
|
_wtimeout_error,
|
|
)
|
|
from pymongo.hello import HelloCompat
|
|
|
|
# "Node is shutting down" error codes, as listed in the SDAM spec.
_SHUTDOWN_CODES = frozenset(
    {
        11600,  # InterruptedAtShutdown
        91,  # ShutdownInProgress
    }
)
# Per the SDAM spec, the "not primary" codes are merged with the "node is
# recovering" codes; the "node is shutting down" codes are a subset of the
# latter.
_NOT_PRIMARY_CODES: frozenset = _SHUTDOWN_CODES.union(
    {
        10058,  # LegacyNotPrimary <=3.2 "not primary" error code
        10107,  # NotWritablePrimary
        13435,  # NotPrimaryNoSecondaryOk
        11602,  # InterruptedDueToReplStateChange
        13436,  # NotPrimaryOrSecondary
        189,  # PrimarySteppedDown
    }
)
# Retryable codes, from the retryable writes spec: every "not primary"
# code plus the codes listed below.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES.union(
    {
        7,  # HostNotFound
        6,  # HostUnreachable
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
    }
)

# Code the server raises when re-authentication is required.
_REAUTHENTICATION_REQUIRED_CODE = 391
|
|
|
|
|
|
def _gen_index_name(keys):
    """Build an index name from the fields the index covers.

    Each item of ``keys`` is rendered as ``<first>_<second>`` and the
    rendered items are joined with underscores.
    """
    render = "{}_{}".format
    return "_".join(render(*pair) for pair in keys)
|
|
|
|
|
|
def _index_list(key_or_list, direction=None):
    """Normalize a key specification into a list of (key, direction) pairs.

    Accepts such a list already, a single key, or a single key together
    with a direction. String entries given without a direction default to
    ``ASCENDING``.

    Raises TypeError when a direction is supplied with a non-string key,
    or when no direction is supplied and ``key_or_list`` is not a string,
    an items view, a list or a tuple.
    """
    if direction is not None:
        # An explicit direction only makes sense next to a single key.
        if isinstance(key_or_list, str):
            return [(key_or_list, direction)]
        raise TypeError("Expected a string and a direction")

    if isinstance(key_or_list, str):
        return [(key_or_list, ASCENDING)]
    if isinstance(key_or_list, abc.ItemsView):
        return list(key_or_list)
    if not isinstance(key_or_list, (list, tuple)):
        raise TypeError("if no direction is specified, key_or_list must be an instance of list")

    # Bare string entries become ascending pairs; anything else is kept as is.
    return [
        (entry, ASCENDING) if isinstance(entry, str) else entry
        for entry in key_or_list
    ]
|
|
|
|
|
|
def _index_document(index_list):
    """Build the index-specifying document for a list of pairs.

    Takes a non-empty list or tuple of (key, direction) pairs and returns
    a ``SON`` mapping each key to its direction, in the order given. A
    bare string entry is treated as ``(entry, ASCENDING)``.

    Raises TypeError for a mapping, for any other non-list/tuple input,
    for a key that is not a str, or for a value that is not a str, int or
    mapping. Raises ValueError for an empty list or tuple.
    """
    if isinstance(index_list, abc.Mapping):
        suggestion = list(index_list.items())
        raise TypeError(
            "passing a dict to sort/create_index/hint is not "
            "allowed - use a list of tuples instead. did you "
            "mean %r?" % suggestion
        )
    if not isinstance(index_list, (list, tuple)):
        raise TypeError("must use a list of (key, direction) pairs, not: " + repr(index_list))
    if len(index_list) == 0:
        raise ValueError("key_or_list must not be the empty list")

    spec: SON[str, Any] = SON()
    for entry in index_list:
        key, value = (entry, ASCENDING) if isinstance(entry, str) else entry
        if not isinstance(key, str):
            raise TypeError("first item in each key pair must be an instance of str")
        if not isinstance(value, (str, int, abc.Mapping)):
            raise TypeError(
                "second item in each key pair must be 1, -1, "
                "'2d', or another valid MongoDB index specifier."
            )
        spec[key] = value
    return spec
|
|
|
|
|
|
def _check_command_response(
    response, max_wire_version, allowable_errors=None, parse_write_concern_error=False
):
    """Check the response to a command and raise if it reports an error.

    Returns None when the response is ok, or when the failure's code (or,
    if the failure carries no code, its error message) is listed in
    ``allowable_errors``. Otherwise raises NotPrimaryError,
    DuplicateKeyError, ExecutionTimeout, CursorNotFound or
    OperationFailure depending on the error. ``max_wire_version`` is only
    forwarded to the raised exception.

    When ``parse_write_concern_error`` is true and the response carries a
    ``writeConcernError``, that error is raised instead; any top-level
    ``errorLabels`` are first copied onto the writeConcernError document,
    which mutates ``response``.
    """
    if "ok" not in response:
        # Server didn't recognize our message as a command.
        raise OperationFailure(
            response.get("$err"), response.get("code"), response, max_wire_version
        )

    if parse_write_concern_error and "writeConcernError" in response:
        wc_error = response["writeConcernError"]
        labels = response.get("errorLabels")
        if labels:
            wc_error.update({"errorLabels": labels})
        _raise_write_concern_error(wc_error)

    if response["ok"]:
        return

    # Mongos returns the error details in a 'raw' object for some errors:
    # use the first shard entry that has an errmsg and is not ok, falling
    # back to the top-level response.
    details = response
    if "raw" in response:
        details = next(
            (
                shard
                for shard in response["raw"].values()
                if shard.get("errmsg") and not shard.get("ok")
            ),
            response,
        )

    errmsg = details["errmsg"]
    code = details.get("code")

    # For allowable errors, the error message is only consulted when the
    # response does not include a code.
    if allowable_errors:
        candidate = errmsg if code is None else code
        if candidate in allowable_errors:
            return

    # Server is "not primary" or "recovering": decided by code when there
    # is one, otherwise by the error message.
    if code is None:
        not_primary = HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg
    else:
        not_primary = code in _NOT_PRIMARY_CODES
    if not_primary:
        raise NotPrimaryError(errmsg, response)

    # Other errors.
    if code in (11000, 11001, 12582):
        # findAndModify with upsert can raise duplicate key error.
        error_class = DuplicateKeyError
    elif code == 50:
        error_class = ExecutionTimeout
    elif code == 43:
        error_class = CursorNotFound
    else:
        error_class = OperationFailure
    raise error_class(errmsg, code, response, max_wire_version)
|
|
|
|
|
|
def _raise_last_write_error(write_errors: List[Any]) -> NoReturn:
    """Raise an exception for the final entry of ``write_errors``.

    Raises DuplicateKeyError when that entry's code is 11000, otherwise
    WriteError.
    """
    # If the last batch had multiple errors only report
    # the last error to emulate continue_on_error.
    last = write_errors[-1]
    code = last.get("code")
    if code == 11000:
        raise DuplicateKeyError(last.get("errmsg"), 11000, last)
    raise WriteError(last.get("errmsg"), code, last)
|
|
|
|
|
|
def _raise_write_concern_error(error: Any) -> NoReturn:
    """Raise the exception matching a write concern error document.

    Raises WTimeoutError when ``_wtimeout_error(error)`` is true, otherwise
    WriteConcernError. The exception carries the document's ``errmsg`` and
    ``code`` along with the document itself.
    """
    # Make sure timeouts surface as WTimeoutError.
    error_class = WTimeoutError if _wtimeout_error(error) else WriteConcernError
    raise error_class(error.get("errmsg"), error.get("code"), error)
|
|
|
|
|
|
def _get_wce_doc(result):
    """Return the writeConcernError document from ``result``, or None.

    When ``result`` also carries non-empty ``errorLabels``, they are
    attached to the returned document in place.
    """
    wce = result.get("writeConcernError")
    if not wce:
        return wce

    # The server reports errorLabels at the top level but it's more
    # convenient to attach it to the writeConcernError doc itself.
    labels = result.get("errorLabels")
    if labels:
        wce["errorLabels"] = labels
    return wce
|
|
|
|
|
|
def _check_write_command_response(result):
    """Backward compatibility helper for write command error handling.

    Raises for the last entry of ``writeErrors`` when any are present;
    otherwise raises for the write concern error when there is one.
    Returns None when ``result`` reports neither.
    """
    # Write errors take precedence over write concern errors.
    errors = result.get("writeErrors")
    if errors:
        _raise_last_write_error(errors)

    wce_doc = _get_wce_doc(result)
    if wce_doc:
        _raise_write_concern_error(wce_doc)
|
|
|
|
|
|
def _fields_list_to_dict(fields, option_name):
    """Turn a collection of field names into a ``{name: 1}`` dictionary.

    ["a", "b"] becomes {"a": 1, "b": 1}

    and

    ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}

    A mapping is returned unchanged. ``option_name`` is only used in the
    TypeError messages raised when ``fields`` is not a mapping, sequence
    or set, or when one of its items is not a str.
    """
    if isinstance(fields, abc.Mapping):
        return fields

    if not isinstance(fields, (abc.Sequence, abc.Set)):
        raise TypeError(f"{option_name} must be a mapping or list of key names")

    for name in fields:
        if not isinstance(name, str):
            raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
    return dict.fromkeys(fields, 1)
|
|
|
|
|
|
def _handle_exception():
    """Print exceptions raised by subscribers to stderr.

    Writes the traceback of the exception currently being handled to
    ``sys.stderr``. Does nothing when ``sys.stderr`` is falsy, and ignores
    an OSError raised while writing.
    """
    # Heavily influenced by logging.Handler.handleError.

    # See note here:
    # https://docs.python.org/3.4/library/sys.html#sys.__stderr__
    if not sys.stderr:
        return

    exc_type, exc_value, exc_tb = sys.exc_info()
    try:
        traceback.print_exception(exc_type, exc_value, exc_tb, None, sys.stderr)
    except OSError:
        # Reporting is best effort; a failing stderr is not an error here.
        pass
    finally:
        # Drop the local references to the exception and its traceback.
        del exc_type, exc_value, exc_tb
|
|
|
|
|
|
def _handle_reauth(func):
    """Decorate ``func`` to re-authenticate and retry once on demand.

    The returned wrapper accepts an extra ``no_reauth`` keyword argument
    (default False) that is removed before ``func`` is called. When
    ``func`` raises OperationFailure with the "re-authentication
    required" code and ``no_reauth`` is false, the wrapper looks through
    the positional arguments for a SocketInfo, or an object with a
    ``sock_info`` attribute, re-authenticates it and calls ``func`` one
    more time. Any other failure, or a failure for which no socket can be
    found, is re-raised unchanged.

    The wrapper keeps the name, docstring and other metadata of ``func``.
    """
    from functools import wraps

    @wraps(func)
    def inner(*args, **kwargs):
        no_reauth = kwargs.pop("no_reauth", False)
        try:
            return func(*args, **kwargs)
        except OperationFailure as exc:
            if no_reauth or exc.code != _REAUTHENTICATION_REQUIRED_CODE:
                raise

            # Imported here rather than at module level, as the original
            # did, and only on the re-authentication path so the common
            # path does not execute an import statement on every call.
            from pymongo.pool import SocketInfo

            # Look for an argument that either is a SocketInfo
            # or has a socket_info attribute, so we can trigger
            # a reauth.
            sock_info = None
            for arg in args:
                if isinstance(arg, SocketInfo):
                    sock_info = arg
                    break
                if hasattr(arg, "sock_info"):
                    sock_info = arg.sock_info
                    break
            if not sock_info:
                # Nothing to re-authenticate: surface the original error.
                raise

            sock_info.authenticate(reauthenticate=True)
            # Retry exactly once; a second failure propagates to the caller.
            return func(*args, **kwargs)

    return inner
|