382 lines
15 KiB
Python
382 lines
15 KiB
Python
|
import os
|
||
|
import sys
|
||
|
from io import StringIO
|
||
|
|
||
|
from django.apps import apps
|
||
|
from django.conf import settings
|
||
|
from django.core import serializers
|
||
|
from django.db import router
|
||
|
from django.db.transaction import atomic
|
||
|
from django.utils.module_loading import import_string
|
||
|
|
||
|
# The prefix to put on the default database name when creating
# the test database. BaseDatabaseCreation._get_test_db_name() returns
# TEST_DATABASE_PREFIX + settings_dict["NAME"] unless the TEST "NAME"
# setting is given explicitly.
TEST_DATABASE_PREFIX = "test_"
|
||
|
|
||
|
|
||
|
class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection):
        # The database connection wrapper whose test database is managed.
        self.connection = connection

    def _nodb_cursor(self):
        """
        Return the connection's ``_nodb_cursor()``. It is used below as a
        context manager to run CREATE DATABASE / DROP DATABASE statements.
        """
        return self.connection._nodb_cursor()

    def log(self, msg):
        """Write ``msg`` followed by ``os.linesep`` to standard error."""
        sys.stderr.write(msg + os.linesep)

    def create_test_db(
        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
    ):
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.

        After creation, the connection and ``settings.DATABASES`` are pointed
        at the test database. ``migrate`` is then run; migrations are disabled
        for every app when the TEST "MIGRATE" setting is False. Next, the
        database contents are optionally serialized onto the connection, and
        finally ``createcachetable`` is run.

        :param verbosity: 0 is silent, >= 1 logs progress, and >= 2 also
            shows the database name. ``migrate`` runs one level lower.
        :param autoclobber: drop an existing test database without prompting.
        :param serialize: store ``serialize_db_to_string()`` on
            ``connection._test_serialized_contents``.
        :param keepdb: reuse an existing test database instead of
            recreating it.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = "Creating"
            if keepdb:
                action = "Using existing"

            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        # Read the flag once so that the override and its restoration in the
        # ``finally`` clause always agree. Save the old value before entering
        # ``try`` so the ``finally`` can never hit an unbound name.
        disable_migrations = self.connection.settings_dict["TEST"]["MIGRATE"] is False
        if disable_migrations:
            # Disable migrations for all apps.
            old_migration_modules = settings.MIGRATION_MODULES
            settings.MIGRATION_MODULES = {
                app.label: None for app in apps.get_app_configs()
            }
        try:
            # We report migrate messages at one level lower than that
            # requested. This ensures we don't get flooded with messages during
            # testing (unless you really ask to be flooded).
            call_command(
                "migrate",
                verbosity=max(verbosity - 1, 0),
                interactive=False,
                database=self.connection.alias,
                run_syncdb=True,
            )
        finally:
            if disable_migrations:
                settings.MIGRATION_MODULES = old_migration_modules

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so
        # people who are testing on databases without transactions or who are
        # using a TransactionTestCase still get a clean database on every test
        # run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        call_command("createcachetable", database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test
        # database.
        self.connection.ensure_connection()

        if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
            self.mark_expected_failures_and_skips()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given. Only the "NAME" setting is copied
        from ``primary_settings_dict``.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def serialize_db_to_string(self):
        """
        Serialize all data in the database into a JSON string.
        Designed only for test runner usage; will not handle large
        amounts of data.

        A model is included only when all of the following hold:
        - its app has a models module;
        - its app is in the migration loader's ``migrated_apps``;
        - its app is not listed in ``settings.TEST_NON_SERIALIZED_APPS``;
        - the model can migrate on this connection and the routers allow it.

        Rows are read with the base manager and ordered by primary key.
        """

        # Iteratively return every object for all models to serialize.
        def get_objects():
            from django.db.migrations.loader import MigrationLoader

            loader = MigrationLoader(self.connection)
            for app_config in apps.get_app_configs():
                if (
                    app_config.models_module is not None
                    and app_config.label in loader.migrated_apps
                    and app_config.name not in settings.TEST_NON_SERIALIZED_APPS
                ):
                    for model in app_config.get_models():
                        if model._meta.can_migrate(
                            self.connection
                        ) and router.allow_migrate_model(self.connection.alias, model):
                            queryset = model._base_manager.using(
                                self.connection.alias,
                            ).order_by(model._meta.pk.name)
                            yield from queryset.iterator()

        # Serialize to a string
        out = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=out)
        return out.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Reload the database with data from a string generated by
        the serialize_db_to_string() method.
        """
        data = StringIO(data)
        table_names = set()
        # Load data in a transaction to handle forward references and cycles.
        with atomic(using=self.connection.alias):
            # Disable constraint checks, because some databases (MySQL) don't
            # support deferred checks.
            with self.connection.constraint_checks_disabled():
                for obj in serializers.deserialize(
                    "json", data, using=self.connection.alias
                ):
                    obj.save()
                    table_names.add(obj.object.__class__._meta.db_table)
            # Manually check for any invalid keys that might have been added,
            # because constraint checks were disabled.
            self.connection.check_constraints(table_names=table_names)

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.

        The result is the quoted connection alias. When ``verbosity >= 2``,
        it is followed by the quoted ``database_name`` in parentheses.
        """
        return "'%s'%s" % (
            self.connection.alias,
            (" ('%s')" % database_name) if verbosity >= 2 else "",
        )

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.

        An explicit (truthy) TEST "NAME" setting wins. Otherwise the name is
        TEST_DATABASE_PREFIX prepended to the configured "NAME".
        """
        if self.connection.settings_dict["TEST"]["NAME"]:
            return self.connection.settings_dict["TEST"]["NAME"]
        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        """
        Execute CREATE DATABASE using ``parameters["dbname"]`` (already
        quoted) and ``parameters["suffix"]``.

        ``keepdb`` is unused here. It is presumably accepted so that backend
        overrides can use it.
        """
        cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test db tables.

        If creation fails and ``keepdb`` is True, the existing database is
        kept and its name is returned. Otherwise the user is asked whether to
        drop and recreate it; ``autoclobber`` skips the prompt. A failed
        recreation exits the process with status 2. Declining the prompt
        exits with status 1.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as e:
                # if we want to keep the db, then no need to do any of the below,
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                self.log("Got an error creating the test database: %s" % e)
                # ``confirm`` is only read when autoclobber is False (the
                # ``or`` below short-circuits), so it is always bound when used.
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self.log(
                                "Destroying old test database for alias %s..."
                                % (
                                    self._get_database_display_str(
                                        verbosity, test_database_name
                                    ),
                                )
                            )
                        cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                        self._execute_create_test_db(cursor, test_db_params, keepdb)
                    except Exception as e:
                        self.log("Got an error recreating the test database: %s" % e)
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.

        The source is the connection's current "NAME". The actual work is
        delegated to ``_clone_test_db()``; ``autoclobber`` is not used here.
        """
        source_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            action = "Cloning test database"
            if keepdb:
                action = "Using existing clone"
            self.log(
                "%s for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, source_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.

        The result is a shallow copy of the connection's settings with
        ``"_<suffix>"`` appended to "NAME".
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {
            **orig_settings_dict,
            "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
        }

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.

        Raise NotImplementedError in this base class.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes."
        )

    def destroy_test_db(
        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
    ):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.

        With ``suffix``, the clone named by ``get_test_db_clone_settings()``
        is targeted; otherwise the connection's current "NAME" is. With
        ``keepdb``, nothing is dropped. If ``old_database_name`` is given, it
        is restored as "NAME" in both ``settings.DATABASES`` and the
        connection settings.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict["NAME"]
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]

        if verbosity >= 1:
            action = "Destroying"
            if keepdb:
                action = "Preserving"
            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # if we want to preserve the database
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - remove the test db tables.

        Issue DROP DATABASE for the quoted ``test_database_name``.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
            )

    def mark_expected_failures_and_skips(self):
        """
        Mark tests in Django's test suite which are expected failures on this
        database and test which should be skipped on this database.

        Test names come from ``connection.features``:
        ``django_test_expected_failures`` is an iterable of dotted test paths,
        and ``django_test_skips`` maps a reason to dotted test paths.
        """
        # Only load unittest if we're actually testing.
        from unittest import expectedFailure, skip

        for test_name in self.connection.features.django_test_expected_failures:
            self._decorate_test_method(test_name, expectedFailure)
        for reason, tests in self.connection.features.django_test_skips.items():
            for test_name in tests:
                self._decorate_test_method(test_name, skip(reason))

    def _decorate_test_method(self, test_name, decorator):
        """
        Replace the method named by the dotted ``test_name``
        ("app.module.TestCase.test_method") with ``decorator(method)``.

        Nothing is done if the first path component is not in
        ``settings.INSTALLED_APPS``.
        """
        test_case_name, _, test_method_name = test_name.rpartition(".")
        test_app = test_name.split(".")[0]
        # Importing a test app that isn't installed raises RuntimeError.
        if test_app in settings.INSTALLED_APPS:
            test_case = import_string(test_case_name)
            test_method = getattr(test_case, test_method_name)
            setattr(test_case, test_method_name, decorator(test_method))

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the test table creation statements.

        The base implementation returns an empty string.
        """
        return ""

    def test_db_signature(self):
        """
        Return a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        accordingly to the RDBMS particularities.

        The tuple is (HOST, PORT, ENGINE, test database name).
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            self._get_test_db_name(),
        )

    def setup_worker_connection(self, _worker_id):
        """
        Point this connection at the clone database for ``_worker_id``, then
        close the connection.
        """
        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
        # connection.settings_dict must be updated in place for changes to be
        # reflected in django.db.connections. If the following line assigned
        # connection.settings_dict = settings_dict, new threads would connect
        # to the default database instead of the appropriate clone.
        self.connection.settings_dict.update(settings_dict)
        self.connection.close()