# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MONGODB-OIDC Authentication helpers."""
import os
import threading
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, List, Optional

import bson
from bson.binary import Binary
from bson.son import SON

from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE


@dataclass
class _OIDCProperties:
    """Mechanism properties for MONGODB-OIDC authentication."""

    # Called as request_token_callback(idp_info, context) to obtain a token.
    request_token_callback: Optional[Callable[..., Dict]]
    # Called as refresh_token_callback(idp_info, context) when a previous
    # callback response exists.
    refresh_token_callback: Optional[Callable[..., Dict]]
    # Built-in provider name (e.g. "aws"); when set, the allowed-hosts check
    # in _get_authenticator is skipped.
    provider_name: Optional[str]
    # Host patterns: exact host names or "*.suffix" wildcards.
    allowed_hosts: List[str]


# A cached token with less than this many minutes left is treated as expired.
TOKEN_BUFFER_MINUTES = 5
# Passed to the token callbacks as context["timeout_seconds"].
CALLBACK_TIMEOUT_SECONDS = 5 * 60
# Lifetime of a _CACHE entry; extended whenever the entry starts an
# authentication or obtains a new token.
CACHE_TIMEOUT_MINUTES = 60 * 5
# Passed to the token callbacks as context["version"].
CALLBACK_VERSION = 0

# Authenticators keyed by principal, server address and callback identity.
_CACHE: Dict[str, "_OIDCAuthenticator"] = {}


def _get_authenticator(credentials, address):
    """Return the cached _OIDCAuthenticator for *credentials* and *address*.

    Expired entries are evicted from the cache first. A new authenticator is
    created and cached when none exists for this principal, server address
    and pair of callbacks.

    :Parameters:
      - `credentials`: credentials whose ``mechanism_properties`` is an
        :class:`_OIDCProperties`.
      - `address`: ``(host, port)`` of the server being authenticated to.

    :Raises:
      ConfigurationError if no ``provider_name`` is configured and the host
      matches none of the ``allowed_hosts`` patterns.
    """
    # Clear out old items in the cache. Iterate over a snapshot and use pop()
    # so a concurrent insert/eviction by another thread cannot raise
    # "dictionary changed size during iteration" or KeyError.
    now_utc = datetime.now(timezone.utc)
    expired = [
        key
        for key, value in list(_CACHE.items())
        if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc
    ]
    for key in expired:
        _CACHE.pop(key, None)

    # Extract values.
    principal_name = credentials.username
    properties = credentials.mechanism_properties
    request_cb = properties.request_token_callback
    refresh_cb = properties.refresh_token_callback

    # Validate that the address is allowed. Built-in providers skip this check.
    host = address[0]
    if not properties.provider_name:
        allowed_hosts = properties.allowed_hosts
        found = any(
            patt == host or (patt.startswith("*.") and host.endswith(patt[1:]))
            for patt in allowed_hosts
        )
        if not found:
            raise ConfigurationError(
                f"Refusing to connect to {host}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
            )

    # Get or create the cache item. Only build a new authenticator on a miss;
    # setdefault keeps the insert atomic if two threads miss at once.
    cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}"
    authenticator = _CACHE.get(cache_key)
    if authenticator is None:
        authenticator = _CACHE.setdefault(
            cache_key, _OIDCAuthenticator(username=principal_name, properties=properties)
        )
    return authenticator


def _get_cache_exp():
    """Return the UTC expiry time for a cache entry used right now."""
    return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES)


def _build_sasl_start(payload, auto_authorize=False):
    """Return a MONGODB-OIDC saslStart command carrying BSON-encoded *payload*.

    When *auto_authorize* is true, ``autoAuthorize: 1`` is appended as the
    last field of the command.
    """
    cmd = SON(
        [
            ("saslStart", 1),
            ("mechanism", "MONGODB-OIDC"),
            ("payload", Binary(bson.encode(payload))),
        ]
    )
    if auto_authorize:
        cmd["autoAuthorize"] = 1
    return cmd


@dataclass
class _OIDCAuthenticator:
    """Cached MONGODB-OIDC authentication state.

    See _get_authenticator for how instances are keyed and cached.
    """

    username: str
    properties: _OIDCProperties
    # Server reply from the first SASL step when it contained an "issuer".
    idp_info: Optional[Dict] = field(default=None)
    # Last result returned by a token callback.
    idp_resp: Optional[Dict] = field(default=None)
    # Snapshot of idp_info_gen_id taken when a reauthentication found the
    # connection's token still current; run_command compares against it to
    # decide whether to retry authentication.
    reauth_gen_id: int = field(default=0)
    # Incremented each time new IdP info is received from the server.
    idp_info_gen_id: int = field(default=0)
    # Incremented each time a callback produces a new token.
    token_gen_id: int = field(default=0)
    # Expiry of the cached token; None when unknown or cleared.
    token_exp_utc: Optional[datetime] = field(default=None)
    # When this authenticator becomes eligible for eviction from _CACHE.
    cache_exp_utc: datetime = field(default_factory=_get_cache_exp)
    # Serializes token callback invocations.
    lock: threading.Lock = field(default_factory=threading.Lock)

    def get_current_token(self, use_callbacks=True):
        """Return an access token, invoking a callback if the cached one is stale.

        A cached token is valid while at least TOKEN_BUFFER_MINUTES remain
        before ``token_exp_utc``.

        :Parameters:
          - `use_callbacks`: when False, never invoke a callback and return
            None if there is no valid cached token.

        :Returns:
          The access token, or None (see `use_callbacks`).

        :Raises:
          ValueError if the callback result is not a dict, lacks
          ``access_token``, or contains a field other than ``access_token``,
          ``expires_in_seconds`` or ``refresh_token``.
        """
        properties = self.properties
        request_cb = properties.request_token_callback
        refresh_cb = properties.refresh_token_callback
        if not use_callbacks:
            request_cb = None
            refresh_cb = None

        buffer_seconds = TOKEN_BUFFER_MINUTES * 60
        current_valid_token = False
        if self.token_exp_utc is not None:
            now_utc = datetime.now(timezone.utc)
            remaining = (self.token_exp_utc - now_utc).total_seconds()
            current_valid_token = remaining >= buffer_seconds

        if not use_callbacks and not current_valid_token:
            return None

        if not current_valid_token and request_cb is not None:
            prev_token = self.idp_resp and self.idp_resp["access_token"]
            with self.lock:
                # See if the token was changed while we were waiting for the
                # lock.
                new_token = self.idp_resp and self.idp_resp["access_token"]
                if new_token != prev_token:
                    return new_token

                refresh_token = self.idp_resp and self.idp_resp.get("refresh_token")
                context = {
                    "timeout_seconds": CALLBACK_TIMEOUT_SECONDS,
                    "version": CALLBACK_VERSION,
                    "refresh_token": refresh_token or "",
                }

                # Refresh only when there is a previous response to refresh
                # and a refresh callback; otherwise request a new token.
                if self.idp_resp is None or refresh_cb is None:
                    self.idp_resp = request_cb(self.idp_info, context)
                else:
                    self.idp_resp = refresh_cb(self.idp_info, context)
                self.cache_exp_utc = _get_cache_exp()
                self.token_gen_id += 1

        token_result = self.idp_resp

        # Validate callback return value.
        if not isinstance(token_result, dict):
            raise ValueError("OIDC callback returned invalid result")

        if "access_token" not in token_result:
            raise ValueError("OIDC callback did not return an access_token")

        expected = ("access_token", "expires_in_seconds", "refresh_token")
        for key in token_result:
            if key not in expected:
                raise ValueError(f'Unexpected field in callback result "{key}"')

        token = token_result["access_token"]

        # Only record an expiry that leaves room for the buffer; shorter-lived
        # tokens leave token_exp_utc untouched.
        if "expires_in_seconds" in token_result:
            expires_in = int(token_result["expires_in_seconds"])
            if expires_in >= buffer_seconds:
                now_utc = datetime.now(timezone.utc)
                self.token_exp_utc = now_utc + timedelta(seconds=expires_in)

        return token

    def auth_start_cmd(self, use_callbacks=True):
        """Build the saslStart command for this authenticator.

        - For the "aws" provider, sends the token read from the file named by
          the AWS_WEB_IDENTITY_TOKEN_FILE environment variable.
        - Without IdP info, sends the optional principal name ("n") with
          ``autoAuthorize``.
        - Otherwise sends the current access token.

        Returns None when no token is available (see get_current_token).
        """
        properties = self.properties

        # Handle aws provider credentials.
        if properties.provider_name == "aws":
            aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"]
            with open(aws_identity_file) as fid:
                token = fid.read().strip()
            return _build_sasl_start({"jwt": token})

        principal_name = self.username
        # Starting an authentication keeps this entry alive in _CACHE.
        self.cache_exp_utc = _get_cache_exp()

        if self.idp_info is None:
            # Send the SASL start with the optional principal name.
            payload = {}
            if principal_name:
                payload["n"] = principal_name
            return _build_sasl_start(payload, auto_authorize=True)

        token = self.get_current_token(use_callbacks)
        if not token:
            return None
        return _build_sasl_start({"jwt": token})

    def clear(self):
        """Forget the IdP info, the last callback response and the token expiry."""
        self.idp_info = None
        self.idp_resp = None
        self.token_exp_utc = None

    def run_command(self, sock_info, cmd):
        """Run SASL *cmd* against the ``$external`` database on *sock_info*.

        Any OperationFailure clears the cached state. If the failure is a
        reauthentication-required error for a command whose payload carries a
        "jwt", authentication is retried, unless new IdP info has arrived
        since the last reauthentication (idp_info_gen_id > reauth_gen_id).
        All other failures are re-raised.
        """
        try:
            return sock_info.command("$external", cmd, no_reauth=True)
        except OperationFailure as exc:
            self.clear()
            if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
                if "jwt" in bson.decode(cmd["payload"]):  # type:ignore[attr-defined]
                    if self.idp_info_gen_id > self.reauth_gen_id:
                        raise
                    return self.authenticate(sock_info, reauthenticate=True)
            raise

    def authenticate(self, sock_info, reauthenticate=False):
        """Run the MONGODB-OIDC SASL conversation on *sock_info*.

        Uses the speculative authentication reply when it succeeded;
        otherwise sends saslStart. If the server is not done, answers with a
        saslContinue carrying the current token. The connection's
        ``oidc_token_gen_id`` records which token generation it used.

        :Returns:
          None if done after the first step, else the saslContinue reply.

        :Raises:
          OperationFailure if the conversation does not complete.
        """
        if reauthenticate:
            prev_id = getattr(sock_info, "oidc_token_gen_id", None)
            # Check if we've already changed tokens.
            if prev_id == self.token_gen_id:
                self.reauth_gen_id = self.idp_info_gen_id
                self.token_exp_utc = None
                if not self.properties.refresh_token_callback:
                    self.clear()

        ctx = sock_info.auth_ctx
        if ctx and ctx.speculate_succeeded():
            resp = ctx.speculative_authenticate
        else:
            resp = self.run_command(sock_info, self.auth_start_cmd())

        if resp["done"]:
            sock_info.oidc_token_gen_id = self.token_gen_id
            return None

        server_resp: Dict = bson.decode(resp["payload"])
        if "issuer" in server_resp:
            self.idp_info = server_resp
            self.idp_info_gen_id += 1

        conversation_id = resp["conversationId"]
        token = self.get_current_token()
        sock_info.oidc_token_gen_id = self.token_gen_id
        bin_payload = Binary(bson.encode({"jwt": token}))
        cmd = SON(
            [
                ("saslContinue", 1),
                ("conversationId", conversation_id),
                ("payload", bin_payload),
            ]
        )
        resp = self.run_command(sock_info, cmd)
        if not resp["done"]:
            self.clear()
            raise OperationFailure("SASL conversation failed to complete.")
        return resp


def _authenticate_oidc(credentials, sock_info, reauthenticate):
    """Authenticate using MONGODB-OIDC."""
    authenticator = _get_authenticator(credentials, sock_info.address)
    return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)