# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Select / poll helper""" import errno import select import sys from typing import Any, Optional # PYTHON-2320: Jython does not fully support poll on SSL sockets, # https://bugs.jython.org/issue2900 _HAVE_POLL = hasattr(select, "poll") and not sys.platform.startswith("java") _SelectError = getattr(select, "error", OSError) def _errno_from_exception(exc): if hasattr(exc, "errno"): return exc.errno if exc.args: return exc.args[0] return None class SocketChecker: def __init__(self) -> None: self._poller: Optional[select.poll] if _HAVE_POLL: self._poller = select.poll() else: self._poller = None def select( self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0 ) -> bool: """Select for reads or writes with a timeout in seconds (or None). Returns True if the socket is readable/writable, False on timeout. """ res: Any while True: try: if self._poller: mask = select.POLLERR | select.POLLHUP if read: mask = mask | select.POLLIN | select.POLLPRI if write: mask = mask | select.POLLOUT self._poller.register(sock, mask) try: # poll() timeout is in milliseconds. select() # timeout is in seconds. timeout_ = None if timeout is None else timeout * 1000 res = self._poller.poll(timeout_) # poll returns a possibly-empty list containing # (fd, event) 2-tuples for the descriptors that have # events or errors to report. Return True if the list # is not empty. return bool(res) finally: self._poller.unregister(sock) else: rlist = [sock] if read else [] wlist = [sock] if write else [] res = select.select(rlist, wlist, [sock], timeout) # select returns a 3-tuple of lists of objects that are # ready: subsets of the first three arguments. Return # True if any of the lists are not empty. return any(res) except (_SelectError, OSError) as exc: # type: ignore if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN): continue raise def socket_closed(self, sock: Any) -> bool: """Return True if we know socket has been closed, False otherwise.""" try: return self.select(sock, read=True) except (RuntimeError, KeyError): # RuntimeError is raised during a concurrent poll. KeyError # is raised by unregister if the socket is not in the poller. # These errors should not be possible since we protect the # poller with a mutex. raise except ValueError: # ValueError is raised by register/unregister/select if the # socket file descriptor is negative or outside the range for # select (> 1023). return True except Exception: # Any other exceptions should be attributed to a closed # or invalid socket. return True