import os import sys from io import StringIO from django.apps import apps from django.conf import settings from django.core import serializers from django.db import router from django.db.transaction import atomic from django.utils.module_loading import import_string # The prefix to put on the default database name when creating # the test database. TEST_DATABASE_PREFIX = "test_" class BaseDatabaseCreation: """ Encapsulate backend-specific differences pertaining to creation and destruction of the test database. """ def __init__(self, connection): self.connection = connection def _nodb_cursor(self): return self.connection._nodb_cursor() def log(self, msg): sys.stderr.write(msg + os.linesep) def create_test_db( self, verbosity=1, autoclobber=False, serialize=True, keepdb=False ): """ Create a test database, prompting the user for confirmation if the database already exists. Return the name of the test database created. """ # Don't import django.core.management if it isn't needed. from django.core.management import call_command test_database_name = self._get_test_db_name() if verbosity >= 1: action = "Creating" if keepdb: action = "Using existing" self.log( "%s test database for alias %s..." % ( action, self._get_database_display_str(verbosity, test_database_name), ) ) # We could skip this call if keepdb is True, but we instead # give it the keepdb param. This is to handle the case # where the test DB doesn't exist, in which case we need to # create it, then just not destroy it. If we instead skip # this, we will get an exception. self._create_test_db(verbosity, autoclobber, keepdb) self.connection.close() settings.DATABASES[self.connection.alias]["NAME"] = test_database_name self.connection.settings_dict["NAME"] = test_database_name try: if self.connection.settings_dict["TEST"]["MIGRATE"] is False: # Disable migrations for all apps. old_migration_modules = settings.MIGRATION_MODULES settings.MIGRATION_MODULES = { app.label: None for app in apps.get_app_configs() } # We report migrate messages at one level lower than that # requested. This ensures we don't get flooded with messages during # testing (unless you really ask to be flooded). call_command( "migrate", verbosity=max(verbosity - 1, 0), interactive=False, database=self.connection.alias, run_syncdb=True, ) finally: if self.connection.settings_dict["TEST"]["MIGRATE"] is False: settings.MIGRATION_MODULES = old_migration_modules # We then serialize the current state of the database into a string # and store it on the connection. This slightly horrific process is so people # who are testing on databases without transactions or who are using # a TransactionTestCase still get a clean database on every test run. if serialize: self.connection._test_serialized_contents = self.serialize_db_to_string() call_command("createcachetable", database=self.connection.alias) # Ensure a connection for the side effect of initializing the test database. self.connection.ensure_connection() if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true": self.mark_expected_failures_and_skips() return test_database_name def set_as_test_mirror(self, primary_settings_dict): """ Set this database up to be used in testing as a mirror of a primary database whose settings are given. """ self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"] def serialize_db_to_string(self): """ Serialize all data in the database into a JSON string. Designed only for test runner usage; will not handle large amounts of data. """ # Iteratively return every object for all models to serialize. def get_objects(): from django.db.migrations.loader import MigrationLoader loader = MigrationLoader(self.connection) for app_config in apps.get_app_configs(): if ( app_config.models_module is not None and app_config.label in loader.migrated_apps and app_config.name not in settings.TEST_NON_SERIALIZED_APPS ): for model in app_config.get_models(): if model._meta.can_migrate( self.connection ) and router.allow_migrate_model(self.connection.alias, model): queryset = model._base_manager.using( self.connection.alias, ).order_by(model._meta.pk.name) yield from queryset.iterator() # Serialize to a string out = StringIO() serializers.serialize("json", get_objects(), indent=None, stream=out) return out.getvalue() def deserialize_db_from_string(self, data): """ Reload the database with data from a string generated by the serialize_db_to_string() method. """ data = StringIO(data) table_names = set() # Load data in a transaction to handle forward references and cycles. with atomic(using=self.connection.alias): # Disable constraint checks, because some databases (MySQL) doesn't # support deferred checks. with self.connection.constraint_checks_disabled(): for obj in serializers.deserialize( "json", data, using=self.connection.alias ): obj.save() table_names.add(obj.object.__class__._meta.db_table) # Manually check for any invalid keys that might have been added, # because constraint checks were disabled. self.connection.check_constraints(table_names=table_names) def _get_database_display_str(self, verbosity, database_name): """ Return display string for a database for use in various actions. """ return "'%s'%s" % ( self.connection.alias, (" ('%s')" % database_name) if verbosity >= 2 else "", ) def _get_test_db_name(self): """ Internal implementation - return the name of the test DB that will be created. Only useful when called from create_test_db() and _create_test_db() and when no external munging is done with the 'NAME' settings. """ if self.connection.settings_dict["TEST"]["NAME"]: return self.connection.settings_dict["TEST"]["NAME"] return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"] def _execute_create_test_db(self, cursor, parameters, keepdb=False): cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters) def _create_test_db(self, verbosity, autoclobber, keepdb=False): """ Internal implementation - create the test db tables. """ test_database_name = self._get_test_db_name() test_db_params = { "dbname": self.connection.ops.quote_name(test_database_name), "suffix": self.sql_table_creation_suffix(), } # Create the test database and connect to it. with self._nodb_cursor() as cursor: try: self._execute_create_test_db(cursor, test_db_params, keepdb) except Exception as e: # if we want to keep the db, then no need to do any of the below, # just return and skip it all. if keepdb: return test_database_name self.log("Got an error creating the test database: %s" % e) if not autoclobber: confirm = input( "Type 'yes' if you would like to try deleting the test " "database '%s', or 'no' to cancel: " % test_database_name ) if autoclobber or confirm == "yes": try: if verbosity >= 1: self.log( "Destroying old test database for alias %s..." % ( self._get_database_display_str( verbosity, test_database_name ), ) ) cursor.execute("DROP DATABASE %(dbname)s" % test_db_params) self._execute_create_test_db(cursor, test_db_params, keepdb) except Exception as e: self.log("Got an error recreating the test database: %s" % e) sys.exit(2) else: self.log("Tests cancelled.") sys.exit(1) return test_database_name def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False): """ Clone a test database. """ source_database_name = self.connection.settings_dict["NAME"] if verbosity >= 1: action = "Cloning test database" if keepdb: action = "Using existing clone" self.log( "%s for alias %s..." % ( action, self._get_database_display_str(verbosity, source_database_name), ) ) # We could skip this call if keepdb is True, but we instead # give it the keepdb param. See create_test_db for details. self._clone_test_db(suffix, verbosity, keepdb) def get_test_db_clone_settings(self, suffix): """ Return a modified connection settings dict for the n-th clone of a DB. """ # When this function is called, the test database has been created # already and its name has been copied to settings_dict['NAME'] so # we don't need to call _get_test_db_name. orig_settings_dict = self.connection.settings_dict return { **orig_settings_dict, "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix), } def _clone_test_db(self, suffix, verbosity, keepdb=False): """ Internal implementation - duplicate the test db tables. """ raise NotImplementedError( "The database backend doesn't support cloning databases. " "Disable the option to run tests in parallel processes." ) def destroy_test_db( self, old_database_name=None, verbosity=1, keepdb=False, suffix=None ): """ Destroy a test database, prompting the user for confirmation if the database already exists. """ self.connection.close() if suffix is None: test_database_name = self.connection.settings_dict["NAME"] else: test_database_name = self.get_test_db_clone_settings(suffix)["NAME"] if verbosity >= 1: action = "Destroying" if keepdb: action = "Preserving" self.log( "%s test database for alias %s..." % ( action, self._get_database_display_str(verbosity, test_database_name), ) ) # if we want to preserve the database # skip the actual destroying piece. if not keepdb: self._destroy_test_db(test_database_name, verbosity) # Restore the original database name if old_database_name is not None: settings.DATABASES[self.connection.alias]["NAME"] = old_database_name self.connection.settings_dict["NAME"] = old_database_name def _destroy_test_db(self, test_database_name, verbosity): """ Internal implementation - remove the test db tables. """ # Remove the test database to clean up after # ourselves. Connect to the previous database (not the test database) # to do so, because it's not allowed to delete a database while being # connected to it. with self._nodb_cursor() as cursor: cursor.execute( "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name) ) def mark_expected_failures_and_skips(self): """ Mark tests in Django's test suite which are expected failures on this database and test which should be skipped on this database. """ # Only load unittest if we're actually testing. from unittest import expectedFailure, skip for test_name in self.connection.features.django_test_expected_failures: test_case_name, _, test_method_name = test_name.rpartition(".") test_app = test_name.split(".")[0] # Importing a test app that isn't installed raises RuntimeError. if test_app in settings.INSTALLED_APPS: test_case = import_string(test_case_name) test_method = getattr(test_case, test_method_name) setattr(test_case, test_method_name, expectedFailure(test_method)) for reason, tests in self.connection.features.django_test_skips.items(): for test_name in tests: test_case_name, _, test_method_name = test_name.rpartition(".") test_app = test_name.split(".")[0] # Importing a test app that isn't installed raises RuntimeError. if test_app in settings.INSTALLED_APPS: test_case = import_string(test_case_name) test_method = getattr(test_case, test_method_name) setattr(test_case, test_method_name, skip(reason)(test_method)) def sql_table_creation_suffix(self): """ SQL to append to the end of the test table creation statements. """ return "" def test_db_signature(self): """ Return a tuple with elements of self.connection.settings_dict (a DATABASES setting value) that uniquely identify a database accordingly to the RDBMS particularities. """ settings_dict = self.connection.settings_dict return ( settings_dict["HOST"], settings_dict["PORT"], settings_dict["ENGINE"], self._get_test_db_name(), ) def setup_worker_connection(self, _worker_id): settings_dict = self.get_test_db_clone_settings(str(_worker_id)) # connection.settings_dict must be updated in place for changes to be # reflected in django.db.connections. If the following line assigned # connection.settings_dict = settings_dict, new threads would connect # to the default database instead of the appropriate clone. self.connection.settings_dict.update(settings_dict) self.connection.close()