# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Internal network layer helper methods.""" import datetime import errno import socket import struct import time from bson import _decode_all_selective from pymongo import _csot, helpers, message, ssl_support from pymongo.common import MAX_MESSAGE_SIZE from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( NotPrimaryError, OperationFailure, ProtocolError, _OperationCancelled, ) from pymongo.message import _UNPACK_REPLY, _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.socket_checker import _errno_from_exception _UNPACK_HEADER = struct.Struct(" max_bson_size: message._raise_document_too_large(name, size, max_bson_size) else: request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) if publish: encoding_duration = datetime.datetime.now() - start listeners.publish_command_start( orig, dbname, request_id, address, service_id=sock_info.service_id ) start = datetime.datetime.now() try: sock_info.sock.sendall(msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. reply = None response_doc = {"ok": 1} else: reply = receive_message(sock_info, request_id) sock_info.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( codec_options=codec_options, user_fields=user_fields ) response_doc = unpacked_docs[0] if client: client._process_response(response_doc, session) if check: helpers._check_command_response( response_doc, sock_info.max_wire_version, allowable_errors, parse_write_concern_error=parse_write_concern_error, ) except Exception as exc: if publish: duration = (datetime.datetime.now() - start) + encoding_duration if isinstance(exc, (NotPrimaryError, OperationFailure)): failure = exc.details else: failure = message._convert_exception(exc) listeners.publish_command_failure( duration, failure, name, request_id, address, service_id=sock_info.service_id ) raise if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( duration, response_doc, name, request_id, address, service_id=sock_info.service_id, speculative_hello=speculative_hello, ) if client and client._encrypter and reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) response_doc = _decode_all_selective(decrypted, codec_options, user_fields)[0] return response_doc _UNPACK_COMPRESSION_HEADER = struct.Struct(" max_message_size: raise ProtocolError( "Message length ({!r}) is larger than server max " "message size ({!r})".format(length, max_message_size) ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( _receive_data_on_socket(sock_info, 9, deadline) ) data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id) else: data = _receive_data_on_socket(sock_info, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] except KeyError: raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}") return unpack_reply(data) _POLL_TIMEOUT = 0.5 def wait_for_read(sock_info, deadline): """Block until at least one byte is read, or a timeout, or a cancel.""" context = sock_info.cancel_context # Only Monitor connections can be cancelled. if context: sock = sock_info.sock timed_out = False while True: # SSLSocket can have buffered data which won't be caught by select. if hasattr(sock, "pending") and sock.pending() > 0: readable = True else: # Wait up to 500ms for the socket to become readable and then # check for cancellation. if deadline: remaining = deadline - time.monotonic() # When the timeout has expired perform one final check to # see if the socket is readable. This helps avoid spurious # timeouts on AWS Lambda and other FaaS environments. if remaining <= 0: timed_out = True timeout = max(min(remaining, _POLL_TIMEOUT), 0) else: timeout = _POLL_TIMEOUT readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout) if context.cancelled: raise _OperationCancelled("hello cancelled") if readable: return if timed_out: raise socket.timeout("timed out") # Errors raised by sockets (and TLS sockets) when in non-blocking mode. BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS) def _receive_data_on_socket(sock_info, length, deadline): buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 while bytes_read < length: try: wait_for_read(sock_info, deadline) # CSOT: Update timeout. When the timeout has expired perform one # final non-blocking recv. This helps avoid spurious timeouts when # the response is actually already buffered on the client. if _csot.get_timeout(): sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0)) chunk_length = sock_info.sock.recv_into(mv[bytes_read:]) except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") except OSError as exc: # noqa: B014 if _errno_from_exception(exc) == errno.EINTR: continue raise if chunk_length == 0: raise OSError("connection closed") bytes_read += chunk_length return mv