from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
from django.db.backends.postgresql.psycopg_any import sql
from django.db.backends.utils import strip_quotes


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    Schema editor carrying the PostgreSQL-specific SQL templates and hooks.

    On top of the base editor it merges query parameters client-side, manages
    the extra ``*_like`` pattern-ops indexes for varchar/text columns, handles
    identity columns and their sequences when auto fields change type, and
    supports creating/dropping indexes concurrently.
    """

    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )
    sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"

    def execute(self, sql, params=()):
        """
        Execute ``sql``, interpolating ``params`` on the client side.

        When ``params`` is None the statement is handed to the base
        implementation untouched. Otherwise the statement is composed with
        ``connection.ops.compose_sql()`` and the base implementation receives
        the finished string with ``params=None``.
        """
        # Merge the query client-side, as PostgreSQL won't do it server-side.
        if params is None:
            return super().execute(sql, params)
        sql = self.connection.ops.compose_sql(str(sql), params)
        # Don't let the superclass touch anything.
        return super().execute(sql, None)

    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    # NOTE: the attribute name is misspelled ("indentity"). It is kept as is
    # because it is part of this class's interface and is looked up by that
    # name in _alter_column_type_sql().
    sql_drop_indentity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

    def quote_value(self, value):
        """
        Return ``value`` quoted as a SQL literal for the current connection.

        Percent signs in strings are doubled before quoting.
        """
        if isinstance(value, str):
            value = value.replace("%", "%%")
        return sql.quote(value, self.connection.connection)

    def _field_indexes_sql(self, model, field):
        """
        Return the base editor's index statements for ``field``, followed by
        the ``*_like`` index statement when one applies.
        """
        output = super()._field_indexes_sql(model, field)
        like_index_statement = self._create_like_index_sql(model, field)
        if like_index_statement is not None:
            output.append(like_index_statement)
        return output

    def _field_data_type(self, field):
        """
        Return the data type used to compare fields when deciding on a cast.

        Relations use ``rel_db_type()``; other fields use the connection's
        ``data_types`` entry for their internal type, falling back to
        ``db_type()``.
        """
        if field.is_relation:
            return field.rel_db_type(self.connection)
        return self.connection.data_types.get(
            field.get_internal_type(),
            field.db_type(self.connection),
        )

    def _field_base_data_types(self, field):
        """
        Yield the innermost base data type of an array field, recursing
        through nested array fields.
        """
        # Yield base data types for array fields.
        if field.base_field.get_internal_type() == "ArrayField":
            yield from self._field_base_data_types(field.base_field)
        else:
            yield self._field_data_type(field.base_field)

    def _create_like_index_sql(self, model, field):
        """
        Return the statement to create an index with varchar operator pattern
        when the column type is 'varchar' or 'text', otherwise return None.

        None is also returned when the field is neither indexed nor unique,
        has no database type, is an array column, or sets ``db_collation``.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is not None and (field.db_index or field.unique):
            # Fields with database column types of `varchar` and `text` need
            # a second index that specifies their operator class, which is
            # needed when performing correct LIKE queries outside the
            # C locale. See #12234.
            #
            # The same doesn't apply to array fields such as varchar[size]
            # and text[size], so skip them.
            if "[" in db_type:
                return None
            # Non-deterministic collations on PostgreSQL don't support indexes
            # for operator classes varchar_pattern_ops/text_pattern_ops.
            if getattr(field, "db_collation", None):
                return None
            if db_type.startswith("varchar"):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix="_like",
                    opclasses=["varchar_pattern_ops"],
                )
            elif db_type.startswith("text"):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix="_like",
                    opclasses=["text_pattern_ops"],
                )
        return None

    def _using_sql(self, new_field, old_field):
        """
        Return the ``USING`` cast template when the data type differs between
        ``old_field`` and ``new_field``, otherwise an empty string.

        When both fields are array fields, their base data types are compared
        instead of the fields' own data types.
        """
        using_sql = " USING %(column)s::%(type)s"
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
            # Compare base data types for array fields.
            if list(self._field_base_data_types(old_field)) != list(
                self._field_base_data_types(new_field)
            ):
                return using_sql
        elif self._field_data_type(old_field) != self._field_data_type(new_field):
            return using_sql
        return ""

    def _get_sequence_name(self, table, column):
        """
        Return the name of the sequence that introspection reports for
        ``column`` of ``table``, or None when there is no such sequence.
        """
        with self.connection.cursor() as cursor:
            for sequence in self.connection.introspection.get_sequences(cursor, table):
                if sequence["column"] == column:
                    return sequence["name"]
        return None

    def _alter_column_type_sql(
        self, model, old_field, new_field, new_type, old_collation, new_collation
    ):
        """
        Return ``(fragment, other_actions)`` for changing a column's type.

        Side effects: the ``*_like`` index is dropped when an indexed or
        unique varchar/text/citext column moves to a different type family,
        and ``self.sql_alter_column_type`` is reassigned on the instance
        (with a ``USING`` cast appended when the data type changed).

        Auto fields get extra handling:
        * non-auto -> auto: add IDENTITY as a follow-up action.
        * auto -> non-auto: drop IDENTITY immediately and drop the backing
          sequence, if any, as a follow-up action.
        * auto -> different auto: alter the backing sequence's type, if any,
          as a follow-up action.
        """
        # Drop indexes on varchar/text/citext columns that are changing to a
        # different type.
        old_db_params = old_field.db_parameters(connection=self.connection)
        old_type = old_db_params["type"]
        if (old_field.db_index or old_field.unique) and (
            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
            or (old_type.startswith("text") and not new_type.startswith("text"))
            or (old_type.startswith("citext") and not new_type.startswith("citext"))
        ):
            index_name = self._create_index_name(
                model._meta.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_name))

        self.sql_alter_column_type = (
            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
        )
        # Cast when data type changed.
        if using_sql := self._using_sql(new_field, old_field):
            self.sql_alter_column_type += using_sql
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        # Make ALTER TYPE with IDENTITY make sense.
        table = strip_quotes(model._meta.db_table)
        auto_field_types = {
            "AutoField",
            "BigAutoField",
            "SmallAutoField",
        }
        old_is_auto = old_internal_type in auto_field_types
        new_is_auto = new_internal_type in auto_field_types
        if new_is_auto and not old_is_auto:
            column = strip_quotes(new_field.column)
            return (
                (
                    self.sql_alter_column_type
                    % {
                        "column": self.quote_name(column),
                        "type": new_type,
                        "collation": "",
                    },
                    [],
                ),
                [
                    (
                        self.sql_add_identity
                        % {
                            "table": self.quote_name(table),
                            "column": self.quote_name(column),
                        },
                        [],
                    ),
                ],
            )
        elif old_is_auto and not new_is_auto:
            column = strip_quotes(new_field.column)
            # Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
            # it).
            self.execute(
                self.sql_drop_indentity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(column),
                }
            )
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            # Drop the sequence if exists (Django 4.1+ identity columns don't
            # have it).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_delete_sequence
                        % {
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    )
                ]
            return fragment, other_actions
        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            column = strip_quotes(new_field.column)
            db_types = {
                "AutoField": "integer",
                "BigAutoField": "bigint",
                "SmallAutoField": "smallint",
            }
            # Alter the sequence type if exists (Django 4.1+ identity columns
            # don't have it).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_alter_sequence_type
                        % {
                            "sequence": self.quote_name(sequence_name),
                            "type": db_types[new_internal_type],
                        },
                        [],
                    ),
                ]
            return fragment, other_actions
        else:
            return super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )

    def _alter_column_collation_sql(
        self, model, new_field, new_type, new_collation, old_field
    ):
        """
        Return ``(sql, params)`` for changing a column's collation, appending
        a ``USING`` cast when the data type changed as well.
        """
        sql = self.sql_alter_column_collate
        # Cast when data type changed.
        if using_sql := self._using_sql(new_field, old_field):
            sql += using_sql
        return (
            sql
            % {
                "column": self.quote_name(new_field.column),
                "type": new_type,
                "collation": " " + self._collate_sql(new_collation)
                if new_collation
                else "",
            },
            [],
        )

    def _alter_field(
        self,
        model,
        old_field,
        new_field,
        old_type,
        new_type,
        old_db_params,
        new_db_params,
        strict=False,
    ):
        """
        Alter the field via the base editor, then keep the PostgreSQL-specific
        ``*_like`` index in step with the field's index/unique settings.
        """
        super()._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
            old_db_params,
            new_db_params,
            strict,
        )
        # Added an index? Create any PostgreSQL-specific indexes.
        if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
            not old_field.unique and new_field.unique
        ):
            like_index_statement = self._create_like_index_sql(model, new_field)
            if like_index_statement is not None:
                self.execute(like_index_statement)

        # Removed an index? Drop any PostgreSQL-specific indexes.
        if old_field.unique and not (new_field.db_index or new_field.unique):
            index_to_remove = self._create_index_name(
                model._meta.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_to_remove))

    def _index_columns(self, table, columns, col_suffixes, opclasses):
        """
        Return an ``IndexColumns`` reference when operator classes are given;
        otherwise defer to the base implementation.
        """
        if opclasses:
            return IndexColumns(
                table,
                columns,
                self.quote_name,
                col_suffixes=col_suffixes,
                opclasses=opclasses,
            )
        return super()._index_columns(table, columns, col_suffixes, opclasses)

    def add_index(self, model, index, concurrently=False):
        """Create ``index`` on ``model``, optionally with CONCURRENTLY."""
        self.execute(
            index.create_sql(model, self, concurrently=concurrently), params=None
        )

    def remove_index(self, model, index, concurrently=False):
        """Drop ``index`` from ``model``, optionally with CONCURRENTLY."""
        self.execute(index.remove_sql(model, self, concurrently=concurrently))

    def _delete_index_sql(self, model, name, sql=None, concurrently=False):
        """
        Return the statement that drops the index ``name``.

        An explicitly passed ``sql`` template takes precedence, matching
        _create_index_sql(). Without one, the concurrent or regular
        ``DROP INDEX`` template is chosen from ``concurrently``.
        """
        sql = sql or (
            self.sql_delete_index_concurrently
            if concurrently
            else self.sql_delete_index
        )
        return super()._delete_index_sql(model, name, sql)

    def _create_index_sql(
        self,
        model,
        *,
        fields=None,
        name=None,
        suffix="",
        using="",
        db_tablespace=None,
        col_suffixes=(),
        sql=None,
        opclasses=(),
        condition=None,
        concurrently=False,
        include=None,
        expressions=None,
    ):
        """
        Return the statement that creates an index.

        An explicitly passed ``sql`` template takes precedence. Without one,
        the concurrent or regular ``CREATE INDEX`` template is chosen from
        ``concurrently``. All other arguments are forwarded to the base
        implementation unchanged.
        """
        sql = sql or (
            self.sql_create_index
            if not concurrently
            else self.sql_create_index_concurrently
        )
        return super()._create_index_sql(
            model,
            fields=fields,
            name=name,
            suffix=suffix,
            using=using,
            db_tablespace=db_tablespace,
            col_suffixes=col_suffixes,
            sql=sql,
            opclasses=opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
        )