386 lines
15 KiB
Python
386 lines
15 KiB
Python
|
# Copyright 2019-present MongoDB, Inc.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||
|
# may not use this file except in compliance with the License. You
|
||
|
# may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||
|
# implied. See the License for the specific language governing
|
||
|
# permissions and limitations under the License.
|
||
|
|
||
|
"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's
|
||
|
context.
|
||
|
"""
|
||
|
|
||
|
import socket as _socket
|
||
|
import ssl as _stdlibssl
|
||
|
import sys as _sys
|
||
|
import time as _time
|
||
|
from errno import EINTR as _EINTR
|
||
|
from ipaddress import ip_address as _ip_address
|
||
|
|
||
|
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
|
||
|
from OpenSSL import SSL as _SSL
|
||
|
from OpenSSL import crypto as _crypto
|
||
|
from service_identity import CertificateError as _SICertificateError
|
||
|
from service_identity import VerificationError as _SIVerificationError
|
||
|
from service_identity.pyopenssl import verify_hostname as _verify_hostname
|
||
|
from service_identity.pyopenssl import verify_ip_address as _verify_ip_address
|
||
|
|
||
|
from pymongo.errors import ConfigurationError as _ConfigurationError
|
||
|
from pymongo.errors import _CertificateError
|
||
|
from pymongo.ocsp_cache import _OCSPCache
|
||
|
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
|
||
|
from pymongo.socket_checker import SocketChecker as _SocketChecker
|
||
|
from pymongo.socket_checker import _errno_from_exception
|
||
|
from pymongo.write_concern import validate_boolean
|
||
|
|
||
|
# certifi is an optional dependency. Record whether it could be imported so
# SSLContext._load_certifi can raise a helpful ConfigurationError when it is
# needed but missing.
try:
    import certifi
except ImportError:
    _HAVE_CERTIFI = False
else:
    _HAVE_CERTIFI = True
|
||
|
|
||
|
# Protocol selector exposed under the same name as the stdlib ssl module.
PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD

# Context option flags that PyOpenSSL always provides.
OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
# PyOpenSSL does not document this flag, so fall back to 0 (setting no
# option bits) when the installed version does not define it.
OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

# Feature flags: SNI support is always present with PyOpenSSL, and
# IS_PYOPENSSL presumably lets callers tell this module apart from the
# stdlib ssl module.
HAS_SNI = True
IS_PYOPENSSL = True

# Base exception class: PyOpenSSL's generic SSL error.
SSLError = _SSL.Error
|
||
|
|
||
|
# Translation from the stdlib ssl.CERT_* constants to PyOpenSSL verify flags,
# following CPython's own mapping:
# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
_VERIFY_MAP = {
    _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
    _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
    _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}

# Inverse lookup (PyOpenSSL flags -> ssl.CERT_*), used when reading the
# context's verify mode back.
_REVERSE_VERIFY_MAP = dict(zip(_VERIFY_MAP.values(), _VERIFY_MAP.keys()))
|
||
|
|
||
|
|
||
|
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def _is_ip_address(address):
    """Return True if ``address`` parses as an IPv4 or IPv6 address.

    Anything ``ipaddress.ip_address`` rejects with ValueError or UnicodeError
    (e.g. a DNS hostname) yields False.
    """
    try:
        _ip_address(address)
    except (ValueError, UnicodeError):  # noqa: B014
        return False
    else:
        return True
|
||
|
|
||
|
|
||
|
# PyOpenSSL "try again" conditions that _sslConn._call waits on and retries.
# According to the docs for Connection.send it can also raise
# WantX509LookupError, which should be retried as well.
BLOCKING_IO_ERRORS = (
    _SSL.WantReadError,
    _SSL.WantWriteError,
    _SSL.WantX509LookupError,
)
|
||
|
|
||
|
|
||
|
def _ragged_eof(exc):
    """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.

    A ragged EOF is recognised purely by the exception's args being exactly
    ``(-1, "Unexpected EOF")``.
    """
    ragged_signature = (-1, "Unexpected EOF")
    return exc.args == ragged_signature
|
||
|
|
||
|
|
||
|
# https://github.com/pyca/pyopenssl/issues/168
# https://github.com/pyca/pyopenssl/issues/176
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
    """PyOpenSSL Connection that emulates the stdlib's blocking I/O semantics.

    PyOpenSSL signals "would block" by raising WantReadError, WantWriteError
    or WantX509LookupError. This subclass waits on the socket with a
    SocketChecker and retries the operation, honoring the socket's timeout.
    """

    def __init__(self, ctx, sock, suppress_ragged_eofs):
        self.socket_checker = _SocketChecker()
        # When True, recv()/recv_into() report a ragged EOF as a clean EOF
        # (b"" / 0), like the stdlib's suppress_ragged_eofs option.
        self.suppress_ragged_eofs = suppress_ragged_eofs
        super().__init__(ctx, sock)

    def _call(self, call, *args, **kwargs):
        """Run ``call(*args, **kwargs)``, retrying while it asks to wait for I/O.

        Raises:
            socket.timeout: the socket's timeout elapsed before ``call``
                completed.
            SSLError: the underlying socket was closed.
        """
        timeout = self.gettimeout()
        # A falsy timeout (None or 0) means no deadline is enforced here.
        deadline = _time.monotonic() + timeout if timeout else None
        while True:
            try:
                return call(*args, **kwargs)
            except BLOCKING_IO_ERRORS as exc:
                # Check for closed socket.
                if self.fileno() == -1:
                    if deadline is not None and _time.monotonic() > deadline:
                        raise _socket.timeout("timed out")
                    raise SSLError("Underlying socket has been closed")
                if isinstance(exc, _SSL.WantReadError):
                    want_read, want_write = True, False
                elif isinstance(exc, _SSL.WantWriteError):
                    want_read, want_write = False, True
                else:
                    # WantX509LookupError: wait for either direction.
                    want_read, want_write = True, True
                if deadline is None:
                    wait = timeout
                else:
                    # Wait only for the time left before the deadline so that
                    # repeated retries cannot stretch past the overall timeout.
                    wait = deadline - _time.monotonic()
                    if wait <= 0:
                        raise _socket.timeout("timed out")
                self.socket_checker.select(self, want_read, want_write, wait)
                if deadline is not None and _time.monotonic() > deadline:
                    raise _socket.timeout("timed out")

    def do_handshake(self, *args, **kwargs):
        """Perform the TLS handshake, blocking (up to the timeout) as needed."""
        return self._call(super().do_handshake, *args, **kwargs)

    def recv(self, *args, **kwargs):
        """Receive data; returns b"" on a suppressed ragged EOF."""
        try:
            return self._call(super().recv, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if self.suppress_ragged_eofs and _ragged_eof(exc):
                return b""
            raise

    def recv_into(self, *args, **kwargs):
        """Receive into a buffer; returns 0 on a suppressed ragged EOF."""
        try:
            return self._call(super().recv_into, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if self.suppress_ragged_eofs and _ragged_eof(exc):
                return 0
            raise

    def sendall(self, buf, flags=0):
        """Send every byte of ``buf``, looping over partial writes.

        Raises:
            OSError: "connection closed" if a write reports 0 or fewer bytes
                sent, or any non-EINTR OSError from the send itself.
        """
        # View the buffer as unsigned bytes so the total length and slice
        # offsets are in bytes, matching the byte count send() returns.
        # len(buf) would count elements for non-byte buffers (e.g. array('i')).
        view = memoryview(buf).cast("B")
        total_length = view.nbytes
        total_sent = 0
        while total_sent < total_length:
            try:
                sent = self._call(super().send, view[total_sent:], flags)
            # XXX: It's not clear if this can actually happen. PyOpenSSL
            # doesn't appear to have any interrupt handling, nor any interrupt
            # errors for OpenSSL connections.
            except OSError as exc:  # noqa: B014
                if _errno_from_exception(exc) == _EINTR:
                    continue
                raise
            # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
            # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
            if sent <= 0:
                raise OSError("connection closed")
            total_sent += sent
|
||
|
|
||
|
|
||
|
class _CallbackData:
    """Mutable state handed to the OCSP client callback as its ``data`` argument.

    SSLContext fills in ``check_ocsp_endpoint`` and, when needed,
    ``trusted_ca_certs``; a fresh OCSP response cache is created per instance.
    """

    def __init__(self):
        self.ocsp_response_cache = _OCSPCache()
        # Both start unset (None) until the owning context configures them.
        self.check_ocsp_endpoint = None
        self.trusted_ca_certs = None
|
||
|
|
||
|
|
||
|
class SSLContext:
    """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
    context.
    """

    __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")

    def __init__(self, protocol):
        self._protocol = protocol
        self._ctx = _SSL.Context(self._protocol)
        self._callback_data = _CallbackData()
        self._check_hostname = True
        # OCSP
        # XXX: Find a better place to do this someday, since this is client
        # side configuration and wrap_socket tries to support both client and
        # server side sockets.
        self._callback_data.check_ocsp_endpoint = True
        self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)

    @property
    def protocol(self):
        """The protocol version chosen when constructing the context.
        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self):
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

    def __set_verify_mode(self, value):
        """Setter for verify_mode."""

        def _cb(connobj, x509obj, errnum, errdepth, retcode):
            # It seems we don't need to do anything here. Twisted doesn't,
            # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
            # for the callback option. It's weird that PyOpenSSL requires
            # this.
            return retcode

        self._ctx.set_verify(_VERIFY_MAP[value], _cb)

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def __get_check_hostname(self):
        return self._check_hostname

    def __set_check_hostname(self, value):
        validate_boolean("check_hostname", value)
        self._check_hostname = value

    # When True (the default), wrap_socket verifies server_hostname against
    # the peer certificate after the handshake.
    check_hostname = property(__get_check_hostname, __set_check_hostname)

    def __get_check_ocsp_endpoint(self):
        return self._callback_data.check_ocsp_endpoint

    def __set_check_ocsp_endpoint(self, value):
        validate_boolean("check_ocsp", value)
        self._callback_data.check_ocsp_endpoint = value

    # Stored on the shared callback data so the OCSP callback can read it.
    check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)

    def __get_options(self):
        # Calling set_options adds the option to the existing bitmask and
        # returns the new bitmask.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
        return self._ctx.set_options(0)

    def __set_options(self, value):
        # Explicitly convert to int, since newer CPython versions
        # use enum.IntFlag for options. The values are the same
        # regardless of implementation.
        self._ctx.set_options(int(value))

    # Note: assigning ORs the new bits into the existing options; it does
    # not clear previously set bits.
    options = property(__get_options, __set_options)

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """Load a private key and the corresponding certificate. The certfile
        string must be the path to a single file in PEM format containing the
        certificate as well as any number of CA certificates needed to
        establish the certificate's authenticity. The keyfile string, if
        present, must point to a file containing the private key. Otherwise
        the private key will be taken from certfile as well.

        As with the stdlib's SSLContext.load_cert_chain, ``password`` may be
        a str (encoded as UTF-8), bytes/bytearray, or a callable returning
        one of those. Any other type raises TypeError.
        """
        # Match CPython behavior
        # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
        # Password callback MUST be set first or it will be ignored.
        if password:

            def _pwcb(max_length, prompt_twice, user_data):
                # XXX:We could check the password length against what OpenSSL
                # tells us is the max, but we can't raise an exception, so...
                # warn?
                value = password() if callable(password) else password
                if isinstance(value, str):
                    return value.encode("utf-8")
                if isinstance(value, (bytes, bytearray)):
                    return bytes(value)
                raise TypeError("password should be a string or bytes")

            self._ctx.set_passwd_cb(_pwcb)
        self._ctx.use_certificate_chain_file(certfile)
        self._ctx.use_privatekey_file(keyfile or certfile)
        self._ctx.check_privatekey()

    def load_verify_locations(self, cafile=None, capath=None):
        """Load a set of "certification authority"(CA) certificates used to
        validate other peers' certificates when `~verify_mode` is other than
        ssl.CERT_NONE.
        """
        self._ctx.load_verify_locations(cafile, capath)
        # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
        # Only a CA file can be loaded this way; with capath alone there is
        # no file to pass to _load_trusted_ca_certs.
        if cafile is not None and not hasattr(_SSL.Connection, "get_verified_chain"):
            self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)

    def _load_certifi(self):
        """Attempt to load CA certs from certifi.

        Raises ConfigurationError when certifi is not installed.
        """
        if _HAVE_CERTIFI:
            self.load_verify_locations(certifi.where())
        else:
            raise _ConfigurationError(
                "tlsAllowInvalidCertificates is False but no system "
                "CA certificates could be loaded. Please install the "
                "certifi package, or provide a path to a CA file using "
                "the tlsCAFile option"
            )

    def _load_wincerts(self, store):
        """Attempt to load CA certs from Windows trust store.

        Only DER ("x509_asn") certificates trusted for all purposes or for
        server authentication are added to the context's cert store.
        """
        cert_store = self._ctx.get_cert_store()
        oid = _stdlibssl.Purpose.SERVER_AUTH.oid
        for cert, encoding, trust in _stdlibssl.enum_certificates(store):  # type: ignore
            if encoding == "x509_asn":
                if trust is True or oid in trust:
                    cert_store.add_cert(
                        _crypto.X509.from_cryptography(_load_der_x509_certificate(cert))
                    )

    def load_default_certs(self):
        """A PyOpenSSL version of load_default_certs from CPython."""
        # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
        # incapable on macOS.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths
        if _sys.platform == "win32":
            try:
                for storename in ("CA", "ROOT"):
                    self._load_wincerts(storename)
            except PermissionError:
                # Fall back to certifi
                self._load_certifi()
        elif _sys.platform == "darwin":
            self._load_certifi()
        self._ctx.set_default_verify_paths()

    def set_default_verify_paths(self):
        """Specify that the platform provided CA certificates are to be used
        for verification purposes.
        """
        # Note: See PyOpenSSL's docs for limitations, which are similar
        # but not the same as CPython's.
        self._ctx.set_default_verify_paths()

    def wrap_socket(
        self,
        sock,
        server_side=False,
        do_handshake_on_connect=True,
        suppress_ragged_eofs=True,
        server_hostname=None,
        session=None,
    ):
        """Wrap an existing Python socket sock and return a TLS socket
        object.

        Raises _CertificateError if check_hostname is enabled, the handshake
        is performed here, and the peer certificate does not match
        server_hostname.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        if session:
            ssl_conn.set_session(session)
        # Truthiness (not ``is True``) matches the stdlib's handling of
        # server_side.
        if server_side:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                ssl_conn.request_ocsp()
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            ssl_conn.do_handshake()
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                try:
                    if _is_ip_address(server_hostname):
                        _verify_ip_address(ssl_conn, server_hostname)
                    else:
                        _verify_hostname(ssl_conn, server_hostname)
                except (_SICertificateError, _SIVerificationError) as exc:
                    raise _CertificateError(str(exc)) from exc
        return ssl_conn
|