# Python implementation of the MySQL client-server protocol # http://dev.mysql.com/doc/internals/en/client-server-protocol.html # Error codes: # https://dev.mysql.com/doc/refman/5.5/en/error-handling.html import errno import os import socket import struct import sys import traceback import warnings from . import _auth from .charset import charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS from . import converters from .cursors import Cursor from .optionfile import Parser from .protocol import ( dump_packet, MysqlPacket, FieldDescriptorPacket, OKPacketWrapper, EOFPacketWrapper, LoadLocalPacketWrapper, ) from . import err, VERSION_STRING try: import ssl SSL_ENABLED = True except ImportError: ssl = None SSL_ENABLED = False try: import getpass DEFAULT_USER = getpass.getuser() del getpass except (ImportError, KeyError): # KeyError occurs when there's no entry in OS database for a current user. DEFAULT_USER = None DEBUG = False TEXT_TYPES = { FIELD_TYPE.BIT, FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.STRING, FIELD_TYPE.TINY_BLOB, FIELD_TYPE.VAR_STRING, FIELD_TYPE.VARCHAR, FIELD_TYPE.GEOMETRY, } DEFAULT_CHARSET = "utf8mb4" MAX_PACKET_LEN = 2**24 - 1 def _pack_int24(n): return struct.pack("`_ in the specification. """ _sock = None _auth_plugin_name = "" _closed = False _secure = False def __init__( self, *, user=None, # The first four arguments is based on DB-API 2.0 recommendation. password="", host=None, database=None, unix_socket=None, port=0, charset="", sql_mode=None, read_default_file=None, conv=None, use_unicode=True, client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=10, read_default_group=None, autocommit=False, local_infile=False, max_allowed_packet=16 * 1024 * 1024, defer_connect=False, auth_plugin_map=None, read_timeout=None, write_timeout=None, bind_address=None, binary_prefix=False, program_name=None, server_public_key=None, ssl=None, ssl_ca=None, ssl_cert=None, ssl_disabled=None, ssl_key=None, ssl_verify_cert=None, ssl_verify_identity=None, compress=None, # not supported named_pipe=None, # not supported passwd=None, # deprecated db=None, # deprecated ): if db is not None and database is None: # We will raise warning in 2022 or later. # See https://github.com/PyMySQL/PyMySQL/issues/939 # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3) database = db if passwd is not None and not password: # We will raise warning in 2022 or later. # See https://github.com/PyMySQL/PyMySQL/issues/939 # warnings.warn( # "'passwd' is deprecated, use 'password'", DeprecationWarning, 3 # ) password = passwd if compress or named_pipe: raise NotImplementedError( "compress and named_pipe arguments are not supported" ) self._local_infile = bool(local_infile) if self._local_infile: client_flag |= CLIENT.LOCAL_FILES if read_default_group and not read_default_file: if sys.platform.startswith("win"): read_default_file = "c:\\my.ini" else: read_default_file = "/etc/my.cnf" if read_default_file: if not read_default_group: read_default_group = "client" cfg = Parser() cfg.read(os.path.expanduser(read_default_file)) def _config(key, arg): if arg: return arg try: return cfg.get(read_default_group, key) except Exception: return arg user = _config("user", user) password = _config("password", password) host = _config("host", host) database = _config("database", database) unix_socket = _config("socket", unix_socket) port = int(_config("port", port)) bind_address = _config("bind-address", bind_address) charset = _config("default-character-set", charset) if not ssl: ssl = {} if isinstance(ssl, dict): for key in ["ca", "capath", "cert", "key", "cipher"]: value = _config("ssl-" + key, ssl.get(key)) if value: ssl[key] = value self.ssl = False if not ssl_disabled: if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity: ssl = { "ca": ssl_ca, "check_hostname": bool(ssl_verify_identity), "verify_mode": ssl_verify_cert if ssl_verify_cert is not None else False, } if ssl_cert is not None: ssl["cert"] = ssl_cert if ssl_key is not None: ssl["key"] = ssl_key if ssl: if not SSL_ENABLED: raise NotImplementedError("ssl module not found") self.ssl = True client_flag |= CLIENT.SSL self.ctx = self._create_ssl_ctx(ssl) self.host = host or "localhost" self.port = port or 3306 if type(self.port) is not int: raise ValueError("port should be of type int") self.user = user or DEFAULT_USER self.password = password or b"" if isinstance(self.password, str): self.password = self.password.encode("latin1") self.db = database self.unix_socket = unix_socket self.bind_address = bind_address if not (0 < connect_timeout <= 31536000): raise ValueError("connect_timeout should be >0 and <=31536000") self.connect_timeout = connect_timeout or None if read_timeout is not None and read_timeout <= 0: raise ValueError("read_timeout should be > 0") self._read_timeout = read_timeout if write_timeout is not None and write_timeout <= 0: raise ValueError("write_timeout should be > 0") self._write_timeout = write_timeout self.charset = charset or DEFAULT_CHARSET self.use_unicode = use_unicode self.encoding = charset_by_name(self.charset).encoding client_flag |= CLIENT.CAPABILITIES if self.db: client_flag |= CLIENT.CONNECT_WITH_DB self.client_flag = client_flag self.cursorclass = cursorclass self._result = None self._affected_rows = 0 self.host_info = "Not connected" # specified autocommit mode. None means use server default. self.autocommit_mode = autocommit if conv is None: conv = converters.conversions # Need for MySQLdb compatibility. self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int} self.decoders = {k: v for (k, v) in conv.items() if type(k) is int} self.sql_mode = sql_mode self.init_command = init_command self.max_allowed_packet = max_allowed_packet self._auth_plugin_map = auth_plugin_map or {} self._binary_prefix = binary_prefix self.server_public_key = server_public_key self._connect_attrs = { "_client_name": "pymysql", "_pid": str(os.getpid()), "_client_version": VERSION_STRING, } if program_name: self._connect_attrs["program_name"] = program_name if defer_connect: self._sock = None else: self.connect() def __enter__(self): return self def __exit__(self, *exc_info): del exc_info self.close() def _create_ssl_ctx(self, sslp): if isinstance(sslp, ssl.SSLContext): return sslp ca = sslp.get("ca") capath = sslp.get("capath") hasnoca = ca is None and capath is None ctx = ssl.create_default_context(cafile=ca, capath=capath) ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) verify_mode_value = sslp.get("verify_mode") if verify_mode_value is None: ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED elif isinstance(verify_mode_value, bool): ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE else: if isinstance(verify_mode_value, str): verify_mode_value = verify_mode_value.lower() if verify_mode_value in ("none", "0", "false", "no"): ctx.verify_mode = ssl.CERT_NONE elif verify_mode_value == "optional": ctx.verify_mode = ssl.CERT_OPTIONAL elif verify_mode_value in ("required", "1", "true", "yes"): ctx.verify_mode = ssl.CERT_REQUIRED else: ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED if "cert" in sslp: ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key")) if "cipher" in sslp: ctx.set_ciphers(sslp["cipher"]) ctx.options |= ssl.OP_NO_SSLv2 ctx.options |= ssl.OP_NO_SSLv3 return ctx def close(self): """ Send the quit message and close the socket. See `Connection.close() `_ in the specification. :raise Error: If the connection is already closed. """ if self._closed: raise err.Error("Already closed") self._closed = True if self._sock is None: return send_data = struct.pack("`_ in the specification. """ self._execute_command(COMMAND.COM_QUERY, "COMMIT") self._read_ok_packet() def rollback(self): """ Roll back the current transaction. See `Connection.rollback() `_ in the specification. """ self._execute_command(COMMAND.COM_QUERY, "ROLLBACK") self._read_ok_packet() def show_warnings(self): """Send the "SHOW WARNINGS" SQL command.""" self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS") result = MySQLResult(self) result.read() return result.rows def select_db(self, db): """ Set current db. :param db: The name of the db. """ self._execute_command(COMMAND.COM_INIT_DB, db) self._read_ok_packet() def escape(self, obj, mapping=None): """Escape whatever value is passed. Non-standard, for internal use; do not use this in your applications. """ if isinstance(obj, str): return "'" + self.escape_string(obj) + "'" if isinstance(obj, (bytes, bytearray)): ret = self._quote_bytes(obj) if self._binary_prefix: ret = "_binary" + ret return ret return converters.escape_item(obj, self.charset, mapping=mapping) def literal(self, obj): """Alias for escape(). Non-standard, for internal use; do not use this in your applications. """ return self.escape(obj, self.encoders) def escape_string(self, s): if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES: return s.replace("'", "''") return converters.escape_string(s) def _quote_bytes(self, s): if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES: return "'%s'" % (s.replace(b"'", b"''").decode("ascii", "surrogateescape"),) return converters.escape_bytes(s) def cursor(self, cursor=None): """ Create a new cursor to execute queries with. :param cursor: The type of cursor to create. None means use Cursor. :type cursor: :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`, or :py:class:`SSDictCursor`. """ if cursor: return cursor(self) return self.cursorclass(self) # The following methods are INTERNAL USE ONLY (called from Cursor) def query(self, sql, unbuffered=False): # if DEBUG: # print("DEBUG: sending query:", sql) if isinstance(sql, str): sql = sql.encode(self.encoding, "surrogateescape") self._execute_command(COMMAND.COM_QUERY, sql) self._affected_rows = self._read_query_result(unbuffered=unbuffered) return self._affected_rows def next_result(self, unbuffered=False): self._affected_rows = self._read_query_result(unbuffered=unbuffered) return self._affected_rows def affected_rows(self): return self._affected_rows def kill(self, thread_id): arg = struct.pack("= 5: self.client_flag |= CLIENT.MULTI_RESULTS if self.user is None: raise ValueError("Did not specify a username") charset_id = charset_by_name(self.charset).id if isinstance(self.user, str): self.user = self.user.encode(self.encoding) data_init = struct.pack( "=5.0) data += authresp + b"\0" if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: if isinstance(self.db, str): self.db = self.db.encode(self.encoding) data += self.db + b"\0" if self.server_capabilities & CLIENT.PLUGIN_AUTH: data += (plugin_name or b"") + b"\0" if self.server_capabilities & CLIENT.CONNECT_ATTRS: connect_attrs = b"" for k, v in self._connect_attrs.items(): k = k.encode("utf-8") connect_attrs += _lenenc_int(len(k)) + k v = v.encode("utf-8") connect_attrs += _lenenc_int(len(v)) + v data += _lenenc_int(len(connect_attrs)) + connect_attrs self.write_packet(data) auth_packet = self._read_packet() # if authentication method isn't accepted the first byte # will have the octet 254 if auth_packet.is_auth_switch_request(): if DEBUG: print("received auth switch") # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest auth_packet.read_uint8() # 0xfe packet identifier plugin_name = auth_packet.read_string() if ( self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None ): auth_packet = self._process_auth(plugin_name, auth_packet) else: raise err.OperationalError("received unknown auth switch request") elif auth_packet.is_extra_auth_data(): if DEBUG: print("received extra data") # https://dev.mysql.com/doc/internals/en/successful-authentication.html if self._auth_plugin_name == "caching_sha2_password": auth_packet = _auth.caching_sha2_password_auth(self, auth_packet) elif self._auth_plugin_name == "sha256_password": auth_packet = _auth.sha256_password_auth(self, auth_packet) else: raise err.OperationalError( "Received extra packet for auth method %r", self._auth_plugin_name ) if DEBUG: print("Succeed to auth") def _process_auth(self, plugin_name, auth_packet): handler = self._get_auth_plugin_handler(plugin_name) if handler: try: return handler.authenticate(auth_packet) except AttributeError: if plugin_name != b"dialog": raise err.OperationalError( CR.CR_AUTH_PLUGIN_CANNOT_LOAD, "Authentication plugin '%s'" " not loaded: - %r missing authenticate method" % (plugin_name, type(handler)), ) if plugin_name == b"caching_sha2_password": return _auth.caching_sha2_password_auth(self, auth_packet) elif plugin_name == b"sha256_password": return _auth.sha256_password_auth(self, auth_packet) elif plugin_name == b"mysql_native_password": data = _auth.scramble_native_password(self.password, auth_packet.read_all()) elif plugin_name == b"client_ed25519": data = _auth.ed25519_password(self.password, auth_packet.read_all()) elif plugin_name == b"mysql_old_password": data = ( _auth.scramble_old_password(self.password, auth_packet.read_all()) + b"\0" ) elif plugin_name == b"mysql_clear_password": # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html data = self.password + b"\0" elif plugin_name == b"dialog": pkt = auth_packet while True: flag = pkt.read_uint8() echo = (flag & 0x06) == 0x02 last = (flag & 0x01) == 0x01 prompt = pkt.read_all() if prompt == b"Password: ": self.write_packet(self.password + b"\0") elif handler: resp = "no response - TypeError within plugin.prompt method" try: resp = handler.prompt(echo, prompt) self.write_packet(resp + b"\0") except AttributeError: raise err.OperationalError( CR.CR_AUTH_PLUGIN_CANNOT_LOAD, "Authentication plugin '%s'" " not loaded: - %r missing prompt method" % (plugin_name, handler), ) except TypeError: raise err.OperationalError( CR.CR_AUTH_PLUGIN_ERR, "Authentication plugin '%s'" " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt), ) else: raise err.OperationalError( CR.CR_AUTH_PLUGIN_CANNOT_LOAD, "Authentication plugin '%s' not configured" % (plugin_name,), ) pkt = self._read_packet() pkt.check_error() if pkt.is_ok_packet() or last: break return pkt else: raise err.OperationalError( CR.CR_AUTH_PLUGIN_CANNOT_LOAD, "Authentication plugin '%s' not configured" % plugin_name, ) self.write_packet(data) pkt = self._read_packet() pkt.check_error() return pkt def _get_auth_plugin_handler(self, plugin_name): plugin_class = self._auth_plugin_map.get(plugin_name) if not plugin_class and isinstance(plugin_name, bytes): plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii")) if plugin_class: try: handler = plugin_class(self) except TypeError: raise err.OperationalError( CR.CR_AUTH_PLUGIN_CANNOT_LOAD, "Authentication plugin '%s'" " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class), ) else: handler = None return handler # _mysql support def thread_id(self): return self.server_thread_id[0] def character_set_name(self): return self.charset def get_host_info(self): return self.host_info def get_proto_info(self): return self.protocol_version def _get_server_information(self): i = 0 packet = self._read_packet() data = packet.get_all_data() self.protocol_version = data[i] i += 1 server_end = data.find(b"\0", i) self.server_version = data[i:server_end].decode("latin1") i = server_end + 1 self.server_thread_id = struct.unpack("= i + 6: lang, stat, cap_h, salt_len = struct.unpack("= i + salt_len: # salt_len includes auth_plugin_data_part_1 and filler self.salt += data[i : i + salt_len] i += salt_len i += 1 # AUTH PLUGIN NAME may appear here. if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: # Due to Bug#59453 the auth-plugin-name is missing the terminating # NUL-char in versions prior to 5.5.10 and 5.6.2. # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake # didn't use version checks as mariadb is corrected and reports # earlier than those two. server_end = data.find(b"\0", i) if server_end < 0: # pragma: no cover - very specific upstream bug # not found \0 and last field so take it all self._auth_plugin_name = data[i:].decode("utf-8") else: self._auth_plugin_name = data[i:server_end].decode("utf-8") def get_server_info(self): return self.server_version Warning = err.Warning Error = err.Error InterfaceError = err.InterfaceError DatabaseError = err.DatabaseError DataError = err.DataError OperationalError = err.OperationalError IntegrityError = err.IntegrityError InternalError = err.InternalError ProgrammingError = err.ProgrammingError NotSupportedError = err.NotSupportedError class MySQLResult: def __init__(self, connection): """ :type connection: Connection """ self.connection = connection self.affected_rows = None self.insert_id = None self.server_status = None self.warning_count = 0 self.message = None self.field_count = 0 self.description = None self.rows = None self.has_next = None self.unbuffered_active = False def __del__(self): if self.unbuffered_active: self._finish_unbuffered_query() def read(self): try: first_packet = self.connection._read_packet() if first_packet.is_ok_packet(): self._read_ok_packet(first_packet) elif first_packet.is_load_local_packet(): self._read_load_local_packet(first_packet) else: self._read_result_packet(first_packet) finally: self.connection = None def init_unbuffered_query(self): """ :raise OperationalError: If the connection to the MySQL server is lost. :raise InternalError: """ self.unbuffered_active = True first_packet = self.connection._read_packet() if first_packet.is_ok_packet(): self._read_ok_packet(first_packet) self.unbuffered_active = False self.connection = None elif first_packet.is_load_local_packet(): self._read_load_local_packet(first_packet) self.unbuffered_active = False self.connection = None else: self.field_count = first_packet.read_length_encoded_integer() self._get_descriptions() # Apparently, MySQLdb picks this number because it's the maximum # value of a 64bit unsigned integer. Since we're emulating MySQLdb, # we set it to this instead of None, which would be preferred. self.affected_rows = 18446744073709551615 def _read_ok_packet(self, first_packet): ok_packet = OKPacketWrapper(first_packet) self.affected_rows = ok_packet.affected_rows self.insert_id = ok_packet.insert_id self.server_status = ok_packet.server_status self.warning_count = ok_packet.warning_count self.message = ok_packet.message self.has_next = ok_packet.has_next def _read_load_local_packet(self, first_packet): if not self.connection._local_infile: raise RuntimeError( "**WARN**: Received LOAD_LOCAL packet but local_infile option is false." ) load_packet = LoadLocalPacketWrapper(first_packet) sender = LoadLocalFile(load_packet.filename, self.connection) try: sender.send_data() except: self.connection._read_packet() # skip ok packet raise ok_packet = self.connection._read_packet() if ( not ok_packet.is_ok_packet() ): # pragma: no cover - upstream induced protocol error raise err.OperationalError( CR.CR_COMMANDS_OUT_OF_SYNC, "Commands Out of Sync", ) self._read_ok_packet(ok_packet) def _check_packet_is_eof(self, packet): if not packet.is_eof_packet(): return False # TODO: Support CLIENT.DEPRECATE_EOF # 1) Add DEPRECATE_EOF to CAPABILITIES # 2) Mask CAPABILITIES with server_capabilities # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper wp = EOFPacketWrapper(packet) self.warning_count = wp.warning_count self.has_next = wp.has_next return True def _read_result_packet(self, first_packet): self.field_count = first_packet.read_length_encoded_integer() self._get_descriptions() self._read_rowdata_packet() def _read_rowdata_packet_unbuffered(self): # Check if in an active query if not self.unbuffered_active: return # EOF packet = self.connection._read_packet() if self._check_packet_is_eof(packet): self.unbuffered_active = False self.connection = None self.rows = None return row = self._read_row_from_packet(packet) self.affected_rows = 1 self.rows = (row,) # rows should tuple of row for MySQL-python compatibility. return row def _finish_unbuffered_query(self): # After much reading on the MySQL protocol, it appears that there is, # in fact, no way to stop MySQL from sending all the data after # executing a query, so we just spin, and wait for an EOF packet. while self.unbuffered_active: packet = self.connection._read_packet() if self._check_packet_is_eof(packet): self.unbuffered_active = False self.connection = None # release reference to kill cyclic reference. def _read_rowdata_packet(self): """Read a rowdata packet for each data row in the result set.""" rows = [] while True: packet = self.connection._read_packet() if self._check_packet_is_eof(packet): self.connection = None # release reference to kill cyclic reference. break rows.append(self._read_row_from_packet(packet)) self.affected_rows = len(rows) self.rows = tuple(rows) def _read_row_from_packet(self, packet): row = [] for encoding, converter in self.converters: try: data = packet.read_length_coded_string() except IndexError: # No more columns in this row # See https://github.com/PyMySQL/PyMySQL/pull/434 break if data is not None: if encoding is not None: data = data.decode(encoding) if DEBUG: print("DEBUG: DATA = ", data) if converter is not None: data = converter(data) row.append(data) return tuple(row) def _get_descriptions(self): """Read a column descriptor packet for each column in the result.""" self.fields = [] self.converters = [] use_unicode = self.connection.use_unicode conn_encoding = self.connection.encoding description = [] for i in range(self.field_count): field = self.connection._read_packet(FieldDescriptorPacket) self.fields.append(field) description.append(field.description()) field_type = field.type_code if use_unicode: if field_type == FIELD_TYPE.JSON: # When SELECT from JSON column: charset = binary # When SELECT CAST(... AS JSON): charset = connection encoding # This behavior is different from TEXT / BLOB. # We should decode result by connection encoding regardless charsetnr. # See https://github.com/PyMySQL/PyMySQL/issues/488 encoding = conn_encoding # SELECT CAST(... AS JSON) elif field_type in TEXT_TYPES: if field.charsetnr == 63: # binary # TEXTs with charset=binary means BINARY types. encoding = None else: encoding = conn_encoding else: # Integers, Dates and Times, and other basic data is encoded in ascii encoding = "ascii" else: encoding = None converter = self.connection.decoders.get(field_type) if converter is converters.through: converter = None if DEBUG: print(f"DEBUG: field={field}, converter={converter}") self.converters.append((encoding, converter)) eof_packet = self.connection._read_packet() assert eof_packet.is_eof_packet(), "Protocol error, expecting EOF" self.description = tuple(description) class LoadLocalFile: def __init__(self, filename, connection): self.filename = filename self.connection = connection def send_data(self): """Send data packets from the local file to the server""" if not self.connection._sock: raise err.InterfaceError(0, "") conn = self.connection try: with open(self.filename, "rb") as open_file: packet_size = min( conn.max_allowed_packet, 16 * 1024 ) # 16KB is efficient enough while True: chunk = open_file.read(packet_size) if not chunk: break conn.write_packet(chunk) except IOError: raise err.OperationalError( ER.FILE_NOT_FOUND, f"Can't find file '{self.filename}'", ) finally: # send the empty packet to signify we are done sending data conn.write_packet(b"")