Ajoutez des fichiers projet.

This commit is contained in:
Ambulance Clerc
2021-12-18 18:43:17 +01:00
parent 3c4d48ed26
commit 46254605fc
4842 changed files with 732322 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
from django.core import signals
from django.db.utils import (
DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, ConnectionHandler,
ConnectionRouter, DatabaseError, DataError, Error, IntegrityError,
InterfaceError, InternalError, NotSupportedError, OperationalError,
ProgrammingError,
)
from django.utils.connection import ConnectionProxy
__all__ = [
'connection', 'connections', 'router', 'DatabaseError', 'IntegrityError',
'InternalError', 'ProgrammingError', 'DataError', 'NotSupportedError',
'Error', 'InterfaceError', 'OperationalError', 'DEFAULT_DB_ALIAS',
'DJANGO_VERSION_PICKLE_KEY',
]
connections = ConnectionHandler()
router = ConnectionRouter()
# For backwards compatibility. Prefer connections['default'] instead.
connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)
# Register an event to reset saved queries when a Django request is started.
def reset_queries(**kwargs):
for conn in connections.all():
conn.queries_log.clear()
signals.request_started.connect(reset_queries)
# Register an event to reset transaction state and close connections past
# their lifetime.
def close_old_connections(**kwargs):
for conn in connections.all():
conn.close_if_unusable_or_obsolete()
signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)

View File

@@ -0,0 +1,687 @@
import _thread
import copy
import threading
import time
import warnings
from collections import deque
from contextlib import contextmanager
try:
import zoneinfo
except ImportError:
from backports import zoneinfo
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseErrorWrapper
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
NO_DB_ALIAS = '__no_db__'
# RemovedInDjango50Warning
def timezone_constructor(tzname):
if settings.USE_DEPRECATED_PYTZ:
import pytz
return pytz.timezone(tzname)
return zoneinfo.ZoneInfo(tzname)
class BaseDatabaseWrapper:
"""Represent a database connection."""
# Mapping of Field objects to their column types.
data_types = {}
# Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
data_types_suffix = {}
# Mapping of Field objects to their SQL for CHECK constraints.
data_type_check_constraints = {}
ops = None
vendor = 'unknown'
display_name = 'unknown'
SchemaEditorClass = None
# Classes instantiated in __init__().
client_class = None
creation_class = None
features_class = None
introspection_class = None
ops_class = None
validation_class = BaseDatabaseValidation
queries_limit = 9000
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
# Connection related attributes.
# The underlying database connection.
self.connection = None
# `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules.
self.settings_dict = settings_dict
self.alias = alias
# Query logging in debug mode or when explicitly enabled.
self.queries_log = deque(maxlen=self.queries_limit)
self.force_debug_cursor = False
# Transaction related attributes.
# Tracks if the connection is in autocommit mode. Per PEP 249, by
# default, it isn't.
self.autocommit = False
# Tracks if the connection is in a transaction managed by 'atomic'.
self.in_atomic_block = False
# Increment to generate unique savepoint ids.
self.savepoint_state = 0
# List of savepoints created by 'atomic'.
self.savepoint_ids = []
# Tracks if the outermost 'atomic' block should commit on exit,
# ie. if autocommit was active on entry.
self.commit_on_exit = True
# Tracks if the transaction should be rolled back to the next
# available savepoint because of an exception in an inner block.
self.needs_rollback = False
# Connection termination related attributes.
self.close_at = None
self.closed_in_transaction = False
self.errors_occurred = False
# Thread-safety related attributes.
self._thread_sharing_lock = threading.Lock()
self._thread_sharing_count = 0
self._thread_ident = _thread.get_ident()
# A list of no-argument functions to run when the transaction commits.
# Each entry is an (sids, func) tuple, where sids is a set of the
# active savepoint IDs when this function was registered.
self.run_on_commit = []
# Should we run the on-commit hooks the next time set_autocommit(True)
# is called?
self.run_commit_hooks_on_set_autocommit_on = False
# A stack of wrappers to be invoked around execute()/executemany()
# calls. Each entry is a function taking five arguments: execute, sql,
# params, many, and context. It's the function's responsibility to
# call execute(sql, params, many, context).
self.execute_wrappers = []
self.client = self.client_class(self)
self.creation = self.creation_class(self)
self.features = self.features_class(self)
self.introspection = self.introspection_class(self)
self.ops = self.ops_class(self)
self.validation = self.validation_class(self)
def ensure_timezone(self):
"""
Ensure the connection's timezone is set to `self.timezone_name` and
return whether it changed or not.
"""
return False
@cached_property
def timezone(self):
"""
Return a tzinfo of the database connection time zone.
This is only used when time zone support is enabled. When a datetime is
read from the database, it is always returned in this time zone.
When the database backend supports time zones, it doesn't matter which
time zone Django uses, as long as aware datetimes are used everywhere.
Other users connecting to the database can choose their own time zone.
When the database backend doesn't support time zones, the time zone
Django uses may be constrained by the requirements of other users of
the database.
"""
if not settings.USE_TZ:
return None
elif self.settings_dict['TIME_ZONE'] is None:
return timezone.utc
else:
return timezone_constructor(self.settings_dict['TIME_ZONE'])
@cached_property
def timezone_name(self):
"""
Name of the time zone of the database connection.
"""
if not settings.USE_TZ:
return settings.TIME_ZONE
elif self.settings_dict['TIME_ZONE'] is None:
return 'UTC'
else:
return self.settings_dict['TIME_ZONE']
@property
def queries_logged(self):
return self.force_debug_cursor or settings.DEBUG
@property
def queries(self):
if len(self.queries_log) == self.queries_log.maxlen:
warnings.warn(
"Limit for query logging exceeded, only the last {} queries "
"will be returned.".format(self.queries_log.maxlen))
return list(self.queries_log)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
"""Return a dict of parameters suitable for get_new_connection."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_connection_params() method')
def get_new_connection(self, conn_params):
"""Open a connection to the database."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_new_connection() method')
def init_connection_state(self):
"""Initialize the database connection settings."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method')
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method')
# ##### Backend-specific methods for creating connections #####
@async_unsafe
def connect(self):
"""Connect to the database. Assume that the connection is closed."""
# Check for invalid configurations.
self.check_settings()
# In case the previous connection was closed while in an atomic block
self.in_atomic_block = False
self.savepoint_ids = []
self.needs_rollback = False
# Reset parameters defining when to close the connection
max_age = self.settings_dict['CONN_MAX_AGE']
self.close_at = None if max_age is None else time.monotonic() + max_age
self.closed_in_transaction = False
self.errors_occurred = False
# Establish the connection
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.set_autocommit(self.settings_dict['AUTOCOMMIT'])
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
def check_settings(self):
if self.settings_dict['TIME_ZONE'] is not None and not settings.USE_TZ:
raise ImproperlyConfigured(
"Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
% self.alias
)
@async_unsafe
def ensure_connection(self):
"""Guarantee that a connection to the database is established."""
if self.connection is None:
with self.wrap_database_errors:
self.connect()
# ##### Backend-specific wrappers for PEP-249 connection methods #####
def _prepare_cursor(self, cursor):
"""
Validate the connection is usable and perform database cursor wrapping.
"""
self.validate_thread_sharing()
if self.queries_logged:
wrapped_cursor = self.make_debug_cursor(cursor)
else:
wrapped_cursor = self.make_cursor(cursor)
return wrapped_cursor
def _cursor(self, name=None):
self.ensure_connection()
with self.wrap_database_errors:
return self._prepare_cursor(self.create_cursor(name))
def _commit(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.commit()
def _rollback(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.rollback()
def _close(self):
if self.connection is not None:
with self.wrap_database_errors:
return self.connection.close()
# ##### Generic wrappers for PEP-249 connection methods #####
@async_unsafe
def cursor(self):
"""Create a cursor, opening a connection if necessary."""
return self._cursor()
@async_unsafe
def commit(self):
"""Commit a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
self._commit()
# A successful commit means that the database connection works.
self.errors_occurred = False
self.run_commit_hooks_on_set_autocommit_on = True
@async_unsafe
def rollback(self):
"""Roll back a transaction and reset the dirty flag."""
self.validate_thread_sharing()
self.validate_no_atomic_block()
self._rollback()
# A successful rollback means that the database connection works.
self.errors_occurred = False
self.needs_rollback = False
self.run_on_commit = []
@async_unsafe
def close(self):
"""Close the connection to the database."""
self.validate_thread_sharing()
self.run_on_commit = []
# Don't call validate_no_atomic_block() to avoid making it difficult
# to get rid of a connection in an invalid state. The next connect()
# will reset the transaction state anyway.
if self.closed_in_transaction or self.connection is None:
return
try:
self._close()
finally:
if self.in_atomic_block:
self.closed_in_transaction = True
self.needs_rollback = True
else:
self.connection = None
# ##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_create_sql(sid))
def _savepoint_rollback(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_rollback_sql(sid))
def _savepoint_commit(self, sid):
with self.cursor() as cursor:
cursor.execute(self.ops.savepoint_commit_sql(sid))
def _savepoint_allowed(self):
# Savepoints cannot be created outside a transaction
return self.features.uses_savepoints and not self.get_autocommit()
# ##### Generic savepoint management methods #####
@async_unsafe
def savepoint(self):
"""
Create a savepoint inside the current transaction. Return an
identifier for the savepoint that will be used for the subsequent
rollback or commit. Do nothing if savepoints are not supported.
"""
if not self._savepoint_allowed():
return
thread_ident = _thread.get_ident()
tid = str(thread_ident).replace('-', '')
self.savepoint_state += 1
sid = "s%s_x%d" % (tid, self.savepoint_state)
self.validate_thread_sharing()
self._savepoint(sid)
return sid
@async_unsafe
def savepoint_rollback(self, sid):
"""
Roll back to a savepoint. Do nothing if savepoints are not supported.
"""
if not self._savepoint_allowed():
return
self.validate_thread_sharing()
self._savepoint_rollback(sid)
# Remove any callbacks registered while this savepoint was active.
self.run_on_commit = [
(sids, func) for (sids, func) in self.run_on_commit if sid not in sids
]
@async_unsafe
def savepoint_commit(self, sid):
"""
Release a savepoint. Do nothing if savepoints are not supported.
"""
if not self._savepoint_allowed():
return
self.validate_thread_sharing()
self._savepoint_commit(sid)
@async_unsafe
def clean_savepoints(self):
"""
Reset the counter used to generate unique savepoint ids in this thread.
"""
self.savepoint_state = 0
# ##### Backend-specific transaction management methods #####
def _set_autocommit(self, autocommit):
"""
Backend-specific implementation to enable or disable autocommit.
"""
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a _set_autocommit() method')
# ##### Generic transaction management methods #####
def get_autocommit(self):
"""Get the autocommit state."""
self.ensure_connection()
return self.autocommit
def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
"""
Enable or disable autocommit.
The usual way to start a transaction is to turn autocommit off.
SQLite does not properly start a transaction when disabling
autocommit. To avoid this buggy behavior and to actually enter a new
transaction, an explicit BEGIN is required. Using
force_begin_transaction_with_broken_autocommit=True will issue an
explicit BEGIN with SQLite. This option will be ignored for other
backends.
"""
self.validate_no_atomic_block()
self.ensure_connection()
start_transaction_under_autocommit = (
force_begin_transaction_with_broken_autocommit and not autocommit and
hasattr(self, '_start_transaction_under_autocommit')
)
if start_transaction_under_autocommit:
self._start_transaction_under_autocommit()
else:
self._set_autocommit(autocommit)
self.autocommit = autocommit
if autocommit and self.run_commit_hooks_on_set_autocommit_on:
self.run_and_clear_commit_hooks()
self.run_commit_hooks_on_set_autocommit_on = False
def get_rollback(self):
"""Get the "needs rollback" flag -- for *advanced use* only."""
if not self.in_atomic_block:
raise TransactionManagementError(
"The rollback flag doesn't work outside of an 'atomic' block.")
return self.needs_rollback
def set_rollback(self, rollback):
"""
Set or unset the "needs rollback" flag -- for *advanced use* only.
"""
if not self.in_atomic_block:
raise TransactionManagementError(
"The rollback flag doesn't work outside of an 'atomic' block.")
self.needs_rollback = rollback
def validate_no_atomic_block(self):
"""Raise an error if an atomic block is active."""
if self.in_atomic_block:
raise TransactionManagementError(
"This is forbidden when an 'atomic' block is active.")
def validate_no_broken_transaction(self):
if self.needs_rollback:
raise TransactionManagementError(
"An error occurred in the current transaction. You can't "
"execute queries until the end of the 'atomic' block.")
# ##### Foreign key constraints checks handling #####
@contextmanager
def constraint_checks_disabled(self):
"""
Disable foreign key constraint checking.
"""
disabled = self.disable_constraint_checking()
try:
yield
finally:
if disabled:
self.enable_constraint_checking()
def disable_constraint_checking(self):
"""
Backends can implement as needed to temporarily disable foreign key
constraint checking. Should return True if the constraints were
disabled and will need to be reenabled.
"""
return False
def enable_constraint_checking(self):
"""
Backends can implement as needed to re-enable foreign key constraint
checking.
"""
pass
def check_constraints(self, table_names=None):
"""
Backends can override this method if they can apply constraint
checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
IntegrityError if any invalid foreign key references are encountered.
"""
pass
# ##### Connection termination handling #####
def is_usable(self):
"""
Test if the database connection is usable.
This method may assume that self.connection is not None.
Actual implementations should take care not to raise exceptions
as that may prevent Django from recycling unusable connections.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an is_usable() method")
def close_if_unusable_or_obsolete(self):
"""
Close the current connection if unrecoverable errors have occurred
or if it outlived its maximum age.
"""
if self.connection is not None:
# If the application didn't restore the original autocommit setting,
# don't take chances, drop the connection.
if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']:
self.close()
return
# If an exception other than DataError or IntegrityError occurred
# since the last commit / rollback, check if the connection works.
if self.errors_occurred:
if self.is_usable():
self.errors_occurred = False
else:
self.close()
return
if self.close_at is not None and time.monotonic() >= self.close_at:
self.close()
return
# ##### Thread safety handling #####
@property
def allow_thread_sharing(self):
with self._thread_sharing_lock:
return self._thread_sharing_count > 0
def inc_thread_sharing(self):
with self._thread_sharing_lock:
self._thread_sharing_count += 1
def dec_thread_sharing(self):
with self._thread_sharing_lock:
if self._thread_sharing_count <= 0:
raise RuntimeError('Cannot decrement the thread sharing count below zero.')
self._thread_sharing_count -= 1
def validate_thread_sharing(self):
"""
Validate that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `inc_thread_sharing()`
method). Raise an exception if the validation fails.
"""
if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
raise DatabaseError(
"DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, _thread.get_ident())
)
# ##### Miscellaneous #####
def prepare_database(self):
"""
Hook to do any database check or preparation, generally called before
migrating a project or an app.
"""
pass
@cached_property
def wrap_database_errors(self):
"""
Context manager and decorator that re-throws backend-specific database
exceptions using Django's common wrappers.
"""
return DatabaseErrorWrapper(self)
def chunked_cursor(self):
"""
Return a cursor that tries to avoid caching in the database (if
supported by the database), otherwise return a regular cursor.
"""
return self.cursor()
def make_debug_cursor(self, cursor):
"""Create a cursor that logs all queries in self.queries_log."""
return utils.CursorDebugWrapper(cursor, self)
def make_cursor(self, cursor):
"""Create a cursor without debug logging."""
return utils.CursorWrapper(cursor, self)
@contextmanager
def temporary_connection(self):
"""
Context manager that ensures that a connection is established, and
if it opened one, closes it to avoid leaving a dangling connection.
This is useful for operations outside of the request-response cycle.
Provide a cursor: with self.temporary_connection() as cursor: ...
"""
must_close = self.connection is None
try:
with self.cursor() as cursor:
yield cursor
finally:
if must_close:
self.close()
@contextmanager
def _nodb_cursor(self):
"""
Return a cursor from an alternative connection to be used when there is
no need to access the main database, specifically for test db
creation/deletion. This also prevents the production database from
being exposed to potential child threads while (or after) the test
database is destroyed. Refs #10868, #17786, #16969.
"""
conn = self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
try:
with conn.cursor() as cursor:
yield cursor
finally:
conn.close()
def schema_editor(self, *args, **kwargs):
"""
Return a new instance of this backend's SchemaEditor.
"""
if self.SchemaEditorClass is None:
raise NotImplementedError(
'The SchemaEditorClass attribute of this database wrapper is still None')
return self.SchemaEditorClass(self, *args, **kwargs)
def on_commit(self, func):
if not callable(func):
raise TypeError("on_commit()'s callback must be a callable.")
if self.in_atomic_block:
# Transaction in progress; save for execution on commit.
self.run_on_commit.append((set(self.savepoint_ids), func))
elif not self.get_autocommit():
raise TransactionManagementError('on_commit() cannot be used in manual transaction management')
else:
# No transaction in progress and in autocommit mode; execute
# immediately.
func()
def run_and_clear_commit_hooks(self):
self.validate_no_atomic_block()
current_run_on_commit = self.run_on_commit
self.run_on_commit = []
while current_run_on_commit:
sids, func = current_run_on_commit.pop(0)
func()
@contextmanager
def execute_wrapper(self, wrapper):
"""
Return a context manager under which the wrapper is applied to suitable
database query executions.
"""
self.execute_wrappers.append(wrapper)
try:
yield
finally:
self.execute_wrappers.pop()
def copy(self, alias=None):
"""
Return a copy of this connection.
For tests that require two connections to the same database.
"""
settings_dict = copy.deepcopy(self.settings_dict)
if alias is None:
alias = self.alias
return type(self)(settings_dict, alias)

View File

@@ -0,0 +1,25 @@
import os
import subprocess
class BaseDatabaseClient:
"""Encapsulate backend-specific methods for opening a client shell."""
# This should be a string representing the name of the executable
# (e.g., "psql"). Subclasses must override this.
executable_name = None
def __init__(self, connection):
# connection is an instance of BaseDatabaseWrapper.
self.connection = connection
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
raise NotImplementedError(
'subclasses of BaseDatabaseClient must provide a '
'settings_to_cmd_args_env() method or override a runshell().'
)
def runshell(self, parameters):
args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters)
env = {**os.environ, **env} if env else None
subprocess.run(args, env=env, check=True)

View File

@@ -0,0 +1,342 @@
import os
import sys
from io import StringIO
from unittest import expectedFailure, skip
from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
from django.db.transaction import atomic
from django.utils.module_loading import import_string
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'
class BaseDatabaseCreation:
"""
Encapsulate backend-specific differences pertaining to creation and
destruction of the test database.
"""
def __init__(self, connection):
self.connection = connection
def _nodb_cursor(self):
return self.connection._nodb_cursor()
def log(self, msg):
sys.stderr.write(msg + os.linesep)
def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False):
"""
Create a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
"""
# Don't import django.core.management if it isn't needed.
from django.core.management import call_command
test_database_name = self._get_test_db_name()
if verbosity >= 1:
action = 'Creating'
if keepdb:
action = "Using existing"
self.log('%s test database for alias %s...' % (
action,
self._get_database_display_str(verbosity, test_database_name),
))
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. This is to handle the case
# where the test DB doesn't exist, in which case we need to
# create it, then just not destroy it. If we instead skip
# this, we will get an exception.
self._create_test_db(verbosity, autoclobber, keepdb)
self.connection.close()
settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
self.connection.settings_dict["NAME"] = test_database_name
try:
if self.connection.settings_dict['TEST']['MIGRATE'] is False:
# Disable migrations for all apps.
old_migration_modules = settings.MIGRATION_MODULES
settings.MIGRATION_MODULES = {
app.label: None
for app in apps.get_app_configs()
}
# We report migrate messages at one level lower than that
# requested. This ensures we don't get flooded with messages during
# testing (unless you really ask to be flooded).
call_command(
'migrate',
verbosity=max(verbosity - 1, 0),
interactive=False,
database=self.connection.alias,
run_syncdb=True,
)
finally:
if self.connection.settings_dict['TEST']['MIGRATE'] is False:
settings.MIGRATION_MODULES = old_migration_modules
# We then serialize the current state of the database into a string
# and store it on the connection. This slightly horrific process is so people
# who are testing on databases without transactions or who are using
# a TransactionTestCase still get a clean database on every test run.
if serialize:
self.connection._test_serialized_contents = self.serialize_db_to_string()
call_command('createcachetable', database=self.connection.alias)
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
if os.environ.get('RUNNING_DJANGOS_TEST_SUITE') == 'true':
self.mark_expected_failures_and_skips()
return test_database_name
def set_as_test_mirror(self, primary_settings_dict):
"""
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
self.connection.settings_dict['NAME'] = primary_settings_dict['NAME']
def serialize_db_to_string(self):
"""
Serialize all data in the database into a JSON string.
Designed only for test runner usage; will not handle large
amounts of data.
"""
# Iteratively return every object for all models to serialize.
def get_objects():
from django.db.migrations.loader import MigrationLoader
loader = MigrationLoader(self.connection)
for app_config in apps.get_app_configs():
if (
app_config.models_module is not None and
app_config.label in loader.migrated_apps and
app_config.name not in settings.TEST_NON_SERIALIZED_APPS
):
for model in app_config.get_models():
if (
model._meta.can_migrate(self.connection) and
router.allow_migrate_model(self.connection.alias, model)
):
queryset = model._base_manager.using(
self.connection.alias,
).order_by(model._meta.pk.name)
yield from queryset.iterator()
# Serialize to a string
out = StringIO()
serializers.serialize("json", get_objects(), indent=None, stream=out)
return out.getvalue()
def deserialize_db_from_string(self, data):
"""
Reload the database with data from a string generated by
the serialize_db_to_string() method.
"""
data = StringIO(data)
table_names = set()
# Load data in a transaction to handle forward references and cycles.
with atomic(using=self.connection.alias):
# Disable constraint checks, because some databases (MySQL) doesn't
# support deferred checks.
with self.connection.constraint_checks_disabled():
for obj in serializers.deserialize('json', data, using=self.connection.alias):
obj.save()
table_names.add(obj.object.__class__._meta.db_table)
# Manually check for any invalid keys that might have been added,
# because constraint checks were disabled.
self.connection.check_constraints(table_names=table_names)
def _get_database_display_str(self, verbosity, database_name):
"""
Return display string for a database for use in various actions.
"""
return "'%s'%s" % (
self.connection.alias,
(" ('%s')" % database_name) if verbosity >= 2 else '',
)
def _get_test_db_name(self):
"""
Internal implementation - return the name of the test DB that will be
created. Only useful when called from create_test_db() and
_create_test_db() and when no external munging is done with the 'NAME'
settings.
"""
if self.connection.settings_dict['TEST']['NAME']:
return self.connection.settings_dict['TEST']['NAME']
return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
cursor.execute('CREATE DATABASE %(dbname)s %(suffix)s' % parameters)
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
"""
Internal implementation - create the test db tables.
"""
test_database_name = self._get_test_db_name()
test_db_params = {
'dbname': self.connection.ops.quote_name(test_database_name),
'suffix': self.sql_table_creation_suffix(),
}
# Create the test database and connect to it.
with self._nodb_cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
# if we want to keep the db, then no need to do any of the below,
# just return and skip it all.
if keepdb:
return test_database_name
self.log('Got an error creating the test database: %s' % e)
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
return test_database_name
def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
"""
Clone a test database.
"""
source_database_name = self.connection.settings_dict['NAME']
if verbosity >= 1:
action = 'Cloning test database'
if keepdb:
action = 'Using existing clone'
self.log('%s for alias %s...' % (
action,
self._get_database_display_str(verbosity, source_database_name),
))
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. See create_test_db for details.
self._clone_test_db(suffix, verbosity, keepdb)
def get_test_db_clone_settings(self, suffix):
"""
Return a modified connection settings dict for the n-th clone of a DB.
"""
# When this function is called, the test database has been created
# already and its name has been copied to settings_dict['NAME'] so
# we don't need to call _get_test_db_name.
orig_settings_dict = self.connection.settings_dict
return {**orig_settings_dict, 'NAME': '{}_{}'.format(orig_settings_dict['NAME'], suffix)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
"""
Internal implementation - duplicate the test db tables.
"""
raise NotImplementedError(
"The database backend doesn't support cloning databases. "
"Disable the option to run tests in parallel processes.")
def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, suffix=None):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists.
"""
self.connection.close()
if suffix is None:
test_database_name = self.connection.settings_dict['NAME']
else:
test_database_name = self.get_test_db_clone_settings(suffix)['NAME']
if verbosity >= 1:
action = 'Destroying'
if keepdb:
action = 'Preserving'
self.log('%s test database for alias %s...' % (
action,
self._get_database_display_str(verbosity, test_database_name),
))
# if we want to preserve the database
# skip the actual destroying piece.
if not keepdb:
self._destroy_test_db(test_database_name, verbosity)
# Restore the original database name
if old_database_name is not None:
settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
self.connection.settings_dict["NAME"] = old_database_name
def _destroy_test_db(self, test_database_name, verbosity):
"""
Internal implementation - remove the test db tables.
"""
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
with self._nodb_cursor() as cursor:
cursor.execute("DROP DATABASE %s"
% self.connection.ops.quote_name(test_database_name))
def mark_expected_failures_and_skips(self):
"""
Mark tests in Django's test suite which are expected failures on this
database and test which should be skipped on this database.
"""
for test_name in self.connection.features.django_test_expected_failures:
test_case_name, _, test_method_name = test_name.rpartition('.')
test_app = test_name.split('.')[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
test_method = getattr(test_case, test_method_name)
setattr(test_case, test_method_name, expectedFailure(test_method))
for reason, tests in self.connection.features.django_test_skips.items():
for test_name in tests:
test_case_name, _, test_method_name = test_name.rpartition('.')
test_app = test_name.split('.')[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
test_method = getattr(test_case, test_method_name)
setattr(test_case, test_method_name, skip(reason)(test_method))
def sql_table_creation_suffix(self):
"""
SQL to append to the end of the test table creation statements.
"""
return ''
def test_db_signature(self):
"""
Return a tuple with elements of self.connection.settings_dict (a
DATABASES setting value) that uniquely identify a database
accordingly to the RDBMS particularities.
"""
settings_dict = self.connection.settings_dict
return (
settings_dict['HOST'],
settings_dict['PORT'],
settings_dict['ENGINE'],
self._get_test_db_name(),
)

View File

@@ -0,0 +1,367 @@
from django.db import ProgrammingError
from django.utils.functional import cached_property
class BaseDatabaseFeatures:
gis_enabled = False
# Oracle can't group by LOB (large object) data types.
allows_group_by_lob = True
allows_group_by_pk = False
allows_group_by_selected_pks = False
empty_fetchmany_value = []
update_can_self_select = True
# Does the backend distinguish between '' and None?
interprets_empty_strings_as_nulls = False
# Does the backend allow inserting duplicate NULL rows in a nullable
# unique field? All core backends implement this correctly, but other
# databases such as SQL Server do not.
supports_nullable_unique_constraints = True
# Does the backend allow inserting duplicate rows when a unique_together
# constraint exists and some fields are nullable but not all of them?
supports_partially_nullable_unique_constraints = True
# Does the backend support initially deferrable unique constraints?
supports_deferrable_unique_constraints = False
can_use_chunked_reads = True
can_return_columns_from_insert = False
can_return_rows_from_bulk_insert = False
has_bulk_insert = True
uses_savepoints = True
can_release_savepoints = False
# If True, don't use integer foreign keys referring to, e.g., positive
# integer primary keys.
related_fields_match_type = False
allow_sliced_subqueries_with_in = True
has_select_for_update = False
has_select_for_update_nowait = False
has_select_for_update_skip_locked = False
has_select_for_update_of = False
has_select_for_no_key_update = False
# Does the database's SELECT FOR UPDATE OF syntax require a column rather
# than a table?
select_for_update_of_column = False
# Does the default test database allow multiple connections?
# Usually an indication that the test database is in-memory
test_db_allows_multiple_connections = True
# Can an object be saved without an explicit primary key?
supports_unspecified_pk = False
# Can a fixture contain forward references? i.e., are
# FK constraints checked at the end of transaction, or
# at the end of each save operation?
supports_forward_references = True
# Does the backend truncate names properly when they are too long?
truncates_names = False
# Is there a REAL datatype in addition to floats/doubles?
has_real_datatype = False
supports_subqueries_in_group_by = True
# Does the backend ignore unnecessary ORDER BY clauses in subqueries?
ignores_unnecessary_order_by_in_subqueries = True
# Is there a true datatype for uuid?
has_native_uuid_field = False
# Is there a true datatype for timedeltas?
has_native_duration_field = False
# Does the database driver supports same type temporal data subtraction
# by returning the type used to store duration field?
supports_temporal_subtraction = False
# Does the __regex lookup support backreferencing and grouping?
supports_regex_backreferencing = True
# Can date/datetime lookups be performed using a string?
supports_date_lookup_using_string = True
# Can datetimes with timezones be used?
supports_timezones = True
# Does the database have a copy of the zoneinfo database?
has_zoneinfo_database = True
# When performing a GROUP BY, is an ORDER BY NULL required
# to remove any ordering?
requires_explicit_null_ordering_when_grouping = False
# Does the backend order NULL values as largest or smallest?
nulls_order_largest = False
# Does the backend support NULLS FIRST and NULLS LAST in ORDER BY?
supports_order_by_nulls_modifier = True
# Does the backend orders NULLS FIRST by default?
order_by_nulls_first = False
# The database's limit on the number of query parameters.
max_query_params = None
# Can an object have an autoincrement primary key of 0?
allows_auto_pk_0 = True
# Do we need to NULL a ForeignKey out, or can the constraint check be
# deferred
can_defer_constraint_checks = False
# date_interval_sql can properly handle mixed Date/DateTime fields and timedeltas
supports_mixed_date_datetime_comparisons = True
# Does the backend support tablespaces? Default to False because it isn't
# in the SQL standard.
supports_tablespaces = False
# Does the backend reset sequences between tests?
supports_sequence_reset = True
# Can the backend introspect the default value of a column?
can_introspect_default = True
# Confirm support for introspected foreign keys
# Every database can do this reliably, except MySQL,
# which can't do it for MyISAM tables
can_introspect_foreign_keys = True
# Map fields which some backends may not be able to differentiate to the
# field it's introspected as.
introspected_field_types = {
'AutoField': 'AutoField',
'BigAutoField': 'BigAutoField',
'BigIntegerField': 'BigIntegerField',
'BinaryField': 'BinaryField',
'BooleanField': 'BooleanField',
'CharField': 'CharField',
'DurationField': 'DurationField',
'GenericIPAddressField': 'GenericIPAddressField',
'IntegerField': 'IntegerField',
'PositiveBigIntegerField': 'PositiveBigIntegerField',
'PositiveIntegerField': 'PositiveIntegerField',
'PositiveSmallIntegerField': 'PositiveSmallIntegerField',
'SmallAutoField': 'SmallAutoField',
'SmallIntegerField': 'SmallIntegerField',
'TimeField': 'TimeField',
}
# Can the backend introspect the column order (ASC/DESC) for indexes?
supports_index_column_ordering = True
# Does the backend support introspection of materialized views?
can_introspect_materialized_views = False
# Support for the DISTINCT ON clause
can_distinct_on_fields = False
# Does the backend prevent running SQL queries in broken transactions?
atomic_transactions = True
# Can we roll back DDL in a transaction?
can_rollback_ddl = False
# Does it support operations requiring references rename in a transaction?
supports_atomic_references_rename = True
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
supports_combined_alters = False
# Does it support foreign keys?
supports_foreign_keys = True
# Can it create foreign key constraints inline when adding columns?
can_create_inline_fk = True
# Does it automatically index foreign keys?
indexes_foreign_keys = True
# Does it support CHECK constraints?
supports_column_check_constraints = True
supports_table_check_constraints = True
# Does the backend support introspection of CHECK constraints?
can_introspect_check_constraints = True
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
# parameter passing? Note this can be provided by the backend even if not
# supported by the Python driver
supports_paramstyle_pyformat = True
# Does the backend require literal defaults, rather than parameterized ones?
requires_literal_defaults = False
# Does the backend require a connection reset after each material schema change?
connection_persists_old_columns = False
# What kind of error does the backend throw when accessing closed cursor?
closed_cursor_error_class = ProgrammingError
# Does 'a' LIKE 'A' match?
has_case_insensitive_like = True
# Suffix for backends that don't support "SELECT xxx;" queries.
bare_select_suffix = ''
# If NULL is implied on columns without needing to be explicitly specified
implied_column_null = False
# Does the backend support "select for update" queries with limit (and offset)?
supports_select_for_update_with_limit = True
# Does the backend ignore null expressions in GREATEST and LEAST queries unless
# every expression is null?
greatest_least_ignores_nulls = False
# Can the backend clone databases for parallel test execution?
# Defaults to False to allow third-party backends to opt-in.
can_clone_databases = False
# Does the backend consider table names with different casing to
# be equal?
ignores_table_name_case = False
# Place FOR UPDATE right after FROM clause. Used on MSSQL.
for_update_after_from = False
# Combinatorial flags
supports_select_union = True
supports_select_intersection = True
supports_select_difference = True
supports_slicing_ordering_in_compound = False
supports_parentheses_in_compound = True
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
# expressions?
supports_aggregate_filter_clause = False
# Does the backend support indexing a TextField?
supports_index_on_text_field = True
# Does the backend support window expressions (expression OVER (...))?
supports_over_clause = False
supports_frame_range_fixed_distance = False
only_supports_unbounded_with_preceding_and_following = False
# Does the backend support CAST with precision?
supports_cast_with_precision = True
# How many second decimals does the database return when casting a value to
# a type with time?
time_cast_precision = 6
# SQL to create a procedure for use by the Django test suite. The
# functionality of the procedure isn't important.
create_test_procedure_without_params_sql = None
create_test_procedure_with_int_param_sql = None
# Does the backend support keyword parameters for cursor.callproc()?
supports_callproc_kwargs = False
# What formats does the backend EXPLAIN syntax support?
supported_explain_formats = set()
# Does DatabaseOperations.explain_query_prefix() raise ValueError if
# unknown kwargs are passed to QuerySet.explain()?
validates_explain_options = True
# Does the backend support the default parameter in lead() and lag()?
supports_default_in_lead_lag = True
# Does the backend support ignoring constraint or uniqueness errors during
# INSERT?
supports_ignore_conflicts = True
# Does this backend require casting the results of CASE expressions used
# in UPDATE statements to ensure the expression has the correct type?
requires_casted_case_in_updates = False
# Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
supports_partial_indexes = True
supports_functions_in_partial_indexes = True
# Does the backend support covering indexes (CREATE INDEX ... INCLUDE ...)?
supports_covering_indexes = False
# Does the backend support indexes on expressions?
supports_expression_indexes = True
# Does the backend treat COLLATE as an indexed expression?
collate_as_index_expression = False
# Does the database allow more than one constraint or index on the same
# field(s)?
allows_multiple_constraints_on_same_fields = True
# Does the backend support boolean expressions in SELECT and GROUP BY
# clauses?
supports_boolean_expr_in_select_clause = True
# Does the backend support JSONField?
supports_json_field = True
# Can the backend introspect a JSONField?
can_introspect_json_field = True
# Does the backend support primitives in JSONField?
supports_primitives_in_json_field = True
# Is there a true datatype for JSON?
has_native_json_field = False
# Does the backend use PostgreSQL-style JSON operators like '->'?
has_json_operators = False
# Does the backend support __contains and __contained_by lookups for
# a JSONField?
supports_json_field_contains = True
# Does value__d__contains={'f': 'g'} (without a list around the dict) match
# {'d': [{'f': 'g'}]}?
json_key_contains_list_matching_requires_list = False
# Does the backend support JSONObject() database function?
has_json_object_function = True
# Does the backend support column collations?
supports_collation_on_charfield = True
supports_collation_on_textfield = True
# Does the backend support non-deterministic collations?
supports_non_deterministic_collations = True
# Collation names for use by the Django test suite.
test_collations = {
'ci': None, # Case-insensitive.
'cs': None, # Case-sensitive.
'non_default': None, # Non-default.
'swedish_ci': None # Swedish case-insensitive.
}
# SQL template override for tests.aggregation.tests.NowUTC
test_now_utc_template = None
# A set of dotted paths to tests in Django's test suite that are expected
# to fail on this database.
django_test_expected_failures = set()
# A map of reasons to sets of dotted paths to tests in Django's test suite
# that should be skipped for this database.
django_test_skips = {}
def __init__(self, connection):
self.connection = connection
@cached_property
def supports_explaining_query_execution(self):
"""Does this backend support explaining query execution?"""
return self.connection.ops.explain_prefix is not None
@cached_property
def supports_transactions(self):
"""Confirm support for transactions."""
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.set_autocommit(False)
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
self.connection.rollback()
self.connection.set_autocommit(True)
cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
count, = cursor.fetchone()
cursor.execute('DROP TABLE ROLLBACK_TEST')
return count == 0
def allows_group_by_selected_pks_on_model(self, model):
if not self.allows_group_by_selected_pks:
return False
return model._meta.managed

View File

@@ -0,0 +1,194 @@
from collections import namedtuple
# Structure returned by DatabaseIntrospection.get_table_list()
TableInfo = namedtuple('TableInfo', ['name', 'type'])
# Structure returned by the DB-API cursor.description interface (PEP 249)
FieldInfo = namedtuple(
'FieldInfo',
'name type_code display_size internal_size precision scale null_ok '
'default collation'
)
class BaseDatabaseIntrospection:
"""Encapsulate backend-specific introspection utilities."""
data_types_reverse = {}
def __init__(self, connection):
self.connection = connection
def get_field_type(self, data_type, description):
"""
Hook for a database backend to use the cursor description to
match a Django field type to a database column.
For Oracle, the column data_type on its own is insufficient to
distinguish between a FloatField and IntegerField, for example.
"""
return self.data_types_reverse[data_type]
def identifier_converter(self, name):
"""
Apply a conversion to the identifier for the purposes of comparison.
The default identifier converter is for case sensitive comparison.
"""
return name
def table_names(self, cursor=None, include_views=False):
"""
Return a list of names of all tables that exist in the database.
Sort the returned table list by Python's default sorting. Do NOT use
the database's ORDER BY here to avoid subtle differences in sorting
order between databases.
"""
def get_names(cursor):
return sorted(ti.name for ti in self.get_table_list(cursor)
if include_views or ti.type == 't')
if cursor is None:
with self.connection.cursor() as cursor:
return get_names(cursor)
return get_names(cursor)
def get_table_list(self, cursor):
"""
Return an unsorted list of TableInfo named tuples of all tables and
views that exist in the database.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method')
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
raise NotImplementedError(
'subclasses of BaseDatabaseIntrospection may require a '
'get_table_description() method.'
)
def get_migratable_models(self):
from django.apps import apps
from django.db import router
return (
model
for app_config in apps.get_app_configs()
for model in router.get_migratable_models(app_config, self.connection.alias)
if model._meta.can_migrate(self.connection)
)
def django_table_names(self, only_existing=False, include_views=True):
"""
Return a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, include only the tables in the database.
"""
tables = set()
for model in self.get_migratable_models():
if not model._meta.managed:
continue
tables.add(model._meta.db_table)
tables.update(
f.m2m_db_table() for f in model._meta.local_many_to_many
if f.remote_field.through._meta.managed
)
tables = list(tables)
if only_existing:
existing_tables = set(self.table_names(include_views=include_views))
tables = [
t
for t in tables
if self.identifier_converter(t) in existing_tables
]
return tables
def installed_models(self, tables):
"""
Return a set of all models represented by the provided list of table
names.
"""
tables = set(map(self.identifier_converter, tables))
return {
m for m in self.get_migratable_models()
if self.identifier_converter(m._meta.db_table) in tables
}
def sequence_list(self):
"""
Return a list of information about all DB sequences for all models in
all apps.
"""
sequence_list = []
with self.connection.cursor() as cursor:
for model in self.get_migratable_models():
if not model._meta.managed:
continue
if model._meta.swapped:
continue
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
for f in model._meta.local_many_to_many:
# If this is an m2m using an intermediate table,
# we don't need to reset the sequence.
if f.remote_field.through._meta.auto_created:
sequence = self.get_sequences(cursor, f.m2m_db_table())
sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()):
"""
Return a list of introspected sequences for table_name. Each sequence
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
'name' key can be added if the backend supports named sequences.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
def get_relations(self, cursor, table_name):
"""
Return a dictionary of
{field_name: (field_name_other_table, other_table)} representing all
relationships to the given table.
"""
raise NotImplementedError(
'subclasses of BaseDatabaseIntrospection may require a '
'get_relations() method.'
)
def get_key_columns(self, cursor, table_name):
"""
Backends can override this to return a list of:
(column_name, referenced_table_name, referenced_column_name)
for all key columns in given table.
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_key_columns() method')
def get_primary_key_column(self, cursor, table_name):
"""
Return the name of the primary key column for the given table.
"""
for constraint in self.get_constraints(cursor, table_name).values():
if constraint['primary_key']:
return constraint['columns'][0]
return None
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index)
across one or more columns.
Return a dict mapping constraint names to their attributes,
where attributes is a dict with keys:
* columns: List of columns this covers
* primary_key: True if primary key, False otherwise
* unique: True if this is a unique constraint, False otherwise
* foreign_key: (table, column) of target, or None
* check: True if check constraint, False otherwise
* index: True if index, False otherwise.
* orders: The order (ASC/DESC) defined for the columns of indexes
* type: The type of the index (btree, hash, etc.)
Some backends may return special constraint names that don't exist
if they don't name constraints of a certain type (e.g. SQLite)
"""
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_constraints() method')

View File

@@ -0,0 +1,709 @@
import datetime
import decimal
from importlib import import_module
import sqlparse
from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.backends import utils
from django.utils import timezone
from django.utils.encoding import force_str
class BaseDatabaseOperations:
"""
Encapsulate backend-specific differences, such as the way a backend
performs ordering or calculates the ID of a recently-inserted row.
"""
compiler_module = "django.db.models.sql.compiler"
# Integer field safe ranges by `internal_type` as documented
# in docs/ref/models/fields.txt.
integer_field_ranges = {
'SmallIntegerField': (-32768, 32767),
'IntegerField': (-2147483648, 2147483647),
'BigIntegerField': (-9223372036854775808, 9223372036854775807),
'PositiveBigIntegerField': (0, 9223372036854775807),
'PositiveSmallIntegerField': (0, 32767),
'PositiveIntegerField': (0, 2147483647),
'SmallAutoField': (-32768, 32767),
'AutoField': (-2147483648, 2147483647),
'BigAutoField': (-9223372036854775808, 9223372036854775807),
}
set_operators = {
'union': 'UNION',
'intersection': 'INTERSECT',
'difference': 'EXCEPT',
}
# Mapping of Field.get_internal_type() (typically the model field's class
# name) to the data type to use for the Cast() function, if different from
# DatabaseWrapper.data_types.
cast_data_types = {}
# CharField data type if the max_length argument isn't provided.
cast_char_field_without_max_length = None
# Start and end points for window expressions.
PRECEDING = 'PRECEDING'
FOLLOWING = 'FOLLOWING'
UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
CURRENT_ROW = 'CURRENT ROW'
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
explain_prefix = None
def __init__(self, connection):
self.connection = connection
self._cache = None
def autoinc_sql(self, table, column):
"""
Return any SQL needed to support auto-incrementing primary keys, or
None if no SQL is necessary.
This SQL is executed when a table is created.
"""
return None
def bulk_batch_size(self, fields, objs):
"""
Return the maximum allowed batch size for the backend. The fields
are the fields going to be inserted in the batch, the objs contains
all the objects to be inserted.
"""
return len(objs)
def cache_key_culling_sql(self):
"""
Return an SQL query that retrieves the first cache key greater than the
n smallest.
This is used by the 'db' cache backend to determine where to start
culling.
"""
return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s"
def unification_cast_sql(self, output_field):
"""
Given a field instance, return the SQL that casts the result of a union
to that type. The resulting string should contain a '%s' placeholder
for the expression being cast.
"""
return '%s'
def date_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
extracts a value from the given date field field_name.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
truncates the given date or datetime field field_name to a date object
with only the given specificity.
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
def datetime_cast_date_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to date value.
"""
raise NotImplementedError(
'subclasses of BaseDatabaseOperations may require a '
'datetime_cast_date_sql() method.'
)
def datetime_cast_time_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to time value.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method')
def datetime_extract_sql(self, lookup_type, field_name, tzname):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that extracts a value from the given
datetime field field_name.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method')
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that truncates the given datetime field
field_name to a datetime object with only the given specificity.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
that truncates the given time or datetime field field_name to a time
object with only the given specificity.
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')
def time_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
that extracts a value from the given time field field_name.
"""
return self.date_extract_sql(lookup_type, field_name)
def deferrable_sql(self):
"""
Return the SQL to make a constraint "initially deferred" during a
CREATE TABLE statement.
"""
return ''
def distinct_sql(self, fields, params):
"""
Return an SQL DISTINCT clause which removes duplicate rows from the
result set. If any fields are given, only check the given fields for
duplicates.
"""
if fields:
raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
else:
return ['DISTINCT'], []
def fetch_returned_insert_columns(self, cursor, returning_params):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the newly created data.
"""
return cursor.fetchone()
def field_cast_sql(self, db_type, internal_type):
"""
Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
(e.g. 'GenericIPAddressField'), return the SQL to cast it before using
it in a WHERE statement. The resulting string should contain a '%s'
placeholder for the column being searched against.
"""
return '%s'
def force_no_ordering(self):
"""
Return a list used in the "ORDER BY" clause to force no ordering at
all. Return an empty list to include nothing in the ordering.
"""
return []
def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
"""
Return the FOR UPDATE SQL clause to lock rows for an update operation.
"""
return 'FOR%s UPDATE%s%s%s' % (
' NO KEY' if no_key else '',
' OF %s' % ', '.join(of) if of else '',
' NOWAIT' if nowait else '',
' SKIP LOCKED' if skip_locked else '',
)
def _get_limit_offset_params(self, low_mark, high_mark):
offset = low_mark or 0
if high_mark is not None:
return (high_mark - offset), offset
elif offset:
return self.connection.ops.no_limit_value(), offset
return None, offset
def limit_offset_sql(self, low_mark, high_mark):
"""Return LIMIT/OFFSET SQL clause."""
limit, offset = self._get_limit_offset_params(low_mark, high_mark)
return ' '.join(sql for sql in (
('LIMIT %d' % limit) if limit else None,
('OFFSET %d' % offset) if offset else None,
) if sql)
def last_executed_query(self, cursor, sql, params):
"""
Return a string of the query last executed by the given cursor, with
placeholders replaced with actual values.
`sql` is the raw query containing placeholders and `params` is the
sequence of parameters. These are used by default, but this method
exists for database backends to provide a better implementation
according to their own quoting schemes.
"""
# Convert params to contain string values.
def to_string(s):
return force_str(s, strings_only=True, errors='replace')
if isinstance(params, (list, tuple)):
u_params = tuple(to_string(val) for val in params)
elif params is None:
u_params = ()
else:
u_params = {to_string(k): to_string(v) for k, v in params.items()}
return "QUERY = %r - PARAMS = %r" % (sql, u_params)
def last_insert_id(self, cursor, table_name, pk_name):
"""
Given a cursor object that has just performed an INSERT statement into
a table that has an auto-incrementing ID, return the newly created ID.
`pk_name` is the name of the primary-key column.
"""
return cursor.lastrowid
def lookup_cast(self, lookup_type, internal_type=None):
"""
Return the string to use in a query when performing lookups
("contains", "like", etc.). It should contain a '%s' placeholder for
the column being searched against.
"""
return "%s"
def max_in_list_size(self):
"""
Return the maximum number of items that can be passed in a single 'IN'
list condition, or None if the backend does not impose a limit.
"""
return None
def max_name_length(self):
"""
Return the maximum length of table and column names, or None if there
is no limit.
"""
return None
def no_limit_value(self):
"""
Return the value to use for the LIMIT when we are wanting "LIMIT
infinity". Return None if the limit clause can be omitted in this case.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a no_limit_value() method')
def pk_default_value(self):
"""
Return the value to use during an INSERT statement to specify that
the field should use its default value.
"""
return 'DEFAULT'
def prepare_sql_script(self, sql):
"""
Take an SQL script that may contain multiple lines and return a list
of statements to feed to successive cursor.execute() calls.
Since few databases are able to process raw SQL scripts in a single
cursor.execute() call and PEP 249 doesn't talk about this use case,
the default implementation is conservative.
"""
return [
sqlparse.format(statement, strip_comments=True)
for statement in sqlparse.split(sql) if statement
]
def process_clob(self, value):
"""
Return the value of a CLOB column, for backends that return a locator
object that requires additional processing.
"""
return value
def return_insert_columns(self, fields):
"""
For backends that support returning columns as part of an insert query,
return the SQL and params to append to the INSERT query. The returned
fragment should contain a format string to hold the appropriate column.
"""
pass
def compiler(self, compiler_name):
"""
Return the SQLCompiler class corresponding to the given name,
in the namespace corresponding to the `compiler_module` attribute
on this backend.
"""
if self._cache is None:
self._cache = import_module(self.compiler_module)
return getattr(self._cache, compiler_name)
def quote_name(self, name):
"""
Return a quoted version of the given table, index, or column name. Do
not quote the given name if it's already been quoted.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a quote_name() method')
def regex_lookup(self, lookup_type):
"""
Return the string to use in a query when performing regular expression
lookups (using "regex" or "iregex"). It should contain a '%s'
placeholder for the column being searched against.
If the feature is not supported (or part of it is not supported), raise
NotImplementedError.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a regex_lookup() method')
def savepoint_create_sql(self, sid):
"""
Return the SQL for starting a new savepoint. Only required if the
"uses_savepoints" feature is True. The "sid" parameter is a string
for the savepoint id.
"""
return "SAVEPOINT %s" % self.quote_name(sid)
def savepoint_commit_sql(self, sid):
"""
Return the SQL for committing the given savepoint.
"""
return "RELEASE SAVEPOINT %s" % self.quote_name(sid)
def savepoint_rollback_sql(self, sid):
"""
Return the SQL for rolling back the given savepoint.
"""
return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)
def set_time_zone_sql(self):
"""
Return the SQL that will set the connection's time zone.
Return '' if the backend doesn't support time zones.
"""
return ''
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
"""
Return a list of SQL statements required to remove all data from
the given database tables (without actually removing the tables
themselves).
The `style` argument is a Style object as returned by either
color_style() or no_style() in django.core.management.color.
If `reset_sequences` is True, the list includes SQL statements required
to reset the sequences.
The `allow_cascade` argument determines whether truncation may cascade
to tables with foreign keys pointing the tables being truncated.
PostgreSQL requires a cascade even if these tables are empty.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations must provide an sql_flush() method')
def execute_sql_flush(self, sql_list):
"""Execute a list of SQL statements to flush the database."""
with transaction.atomic(
using=self.connection.alias,
savepoint=self.connection.features.can_rollback_ddl,
):
with self.connection.cursor() as cursor:
for sql in sql_list:
cursor.execute(sql)
def sequence_reset_by_name_sql(self, style, sequences):
"""
Return a list of the SQL statements required to reset sequences
passed in `sequences`.
The `style` argument is a Style object as returned by either
color_style() or no_style() in django.core.management.color.
"""
return []
def sequence_reset_sql(self, style, model_list):
"""
Return a list of the SQL statements required to reset sequences for
the given models.
The `style` argument is a Style object as returned by either
color_style() or no_style() in django.core.management.color.
"""
return [] # No sequence reset required by default.
def start_transaction_sql(self):
"""Return the SQL statement required to start a transaction."""
return "BEGIN;"
def end_transaction_sql(self, success=True):
"""Return the SQL statement required to end a transaction."""
if not success:
return "ROLLBACK;"
return "COMMIT;"
def tablespace_sql(self, tablespace, inline=False):
"""
Return the SQL that will be used in a query to define the tablespace.
Return '' if the backend doesn't support tablespaces.
If `inline` is True, append the SQL to a row; otherwise append it to
the entire CREATE TABLE or CREATE INDEX statement.
"""
return ''
def prep_for_like_query(self, x):
"""Prepare a value for use in a LIKE query."""
return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
# Same as prep_for_like_query(), but called for "iexact" matches, which
# need not necessarily be implemented using "LIKE" in the backend.
prep_for_iexact_query = prep_for_like_query
def validate_autopk_value(self, value):
"""
Certain backends do not accept some values for "serial" fields
(for example zero in MySQL). Raise a ValueError if the value is
invalid, otherwise return the validated value.
"""
return value
def adapt_unknown_value(self, value):
"""
Transform a value to something compatible with the backend driver.
This method only depends on the type of the value. It's designed for
cases where the target type isn't known, such as .raw() SQL queries.
As a consequence it may not work perfectly in all circumstances.
"""
if isinstance(value, datetime.datetime): # must be before date
return self.adapt_datetimefield_value(value)
elif isinstance(value, datetime.date):
return self.adapt_datefield_value(value)
elif isinstance(value, datetime.time):
return self.adapt_timefield_value(value)
elif isinstance(value, decimal.Decimal):
return self.adapt_decimalfield_value(value)
else:
return value
def adapt_datefield_value(self, value):
"""
Transform a date value to an object compatible with what is expected
by the backend driver for date columns.
"""
if value is None:
return None
return str(value)
def adapt_datetimefield_value(self, value):
"""
Transform a datetime value to an object compatible with what is expected
by the backend driver for datetime columns.
"""
if value is None:
return None
return str(value)
def adapt_timefield_value(self, value):
"""
Transform a time value to an object compatible with what is expected
by the backend driver for time columns.
"""
if value is None:
return None
if timezone.is_aware(value):
raise ValueError("Django does not support timezone-aware times.")
return str(value)
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
"""
Transform a decimal.Decimal value to an object compatible with what is
expected by the backend driver for decimal (numeric) columns.
"""
return utils.format_number(value, max_digits, decimal_places)
def adapt_ipaddressfield_value(self, value):
"""
Transform a string representation of an IP address into the expected
type for the backend driver.
"""
return value or None
def year_lookup_bounds_for_date_field(self, value, iso_year=False):
"""
Return a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateField value using a year
lookup.
`value` is an int, containing the looked-up year.
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
"""
if iso_year:
first = datetime.date.fromisocalendar(value, 1, 1)
second = (
datetime.date.fromisocalendar(value + 1, 1, 1) -
datetime.timedelta(days=1)
)
else:
first = datetime.date(value, 1, 1)
second = datetime.date(value, 12, 31)
first = self.adapt_datefield_value(first)
second = self.adapt_datefield_value(second)
return [first, second]
def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
"""
Return a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateTimeField value using a year
lookup.
`value` is an int, containing the looked-up year.
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
"""
if iso_year:
first = datetime.datetime.fromisocalendar(value, 1, 1)
second = (
datetime.datetime.fromisocalendar(value + 1, 1, 1) -
datetime.timedelta(microseconds=1)
)
else:
first = datetime.datetime(value, 1, 1)
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
if settings.USE_TZ:
tz = timezone.get_current_timezone()
first = timezone.make_aware(first, tz)
second = timezone.make_aware(second, tz)
first = self.adapt_datetimefield_value(first)
second = self.adapt_datetimefield_value(second)
return [first, second]
def get_db_converters(self, expression):
"""
Return a list of functions needed to convert field data.
Some field types on some backends do not provide data in the correct
format, this is the hook for converter functions.
"""
return []
def convert_durationfield_value(self, value, expression, connection):
if value is not None:
return datetime.timedelta(0, 0, value)
def check_expression_support(self, expression):
"""
Check that the backend supports the provided expression.
This is used on specific backends to rule out known expressions
that have problematic or nonexistent implementations. If the
expression has a known problem, the backend should raise
NotSupportedError.
"""
pass
def conditional_expression_supported_in_where_clause(self, expression):
"""
Return True, if the conditional expression is supported in the WHERE
clause.
"""
return True
def combine_expression(self, connector, sub_expressions):
"""
Combine a list of subexpressions into a single expression, using
the provided connecting operator. This is required because operators
can vary between backends (e.g., Oracle with %% and &) and between
subexpression types (e.g., date expressions).
"""
conn = ' %s ' % connector
return conn.join(sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
return self.combine_expression(connector, sub_expressions)
def binary_placeholder_sql(self, value):
"""
Some backends require special syntax to insert binary content (MySQL
for example uses '_binary %s').
"""
return '%s'
def modify_insert_params(self, placeholder, params):
"""
Allow modification of insert parameters. Needed for Oracle Spatial
backend due to #10888.
"""
return params
def integer_field_range(self, internal_type):
"""
Given an integer field internal type (e.g. 'PositiveIntegerField'),
return a tuple of the (min_value, max_value) form representing the
range of the column type bound to the field.
"""
return self.integer_field_ranges[internal_type]
def subtract_temporals(self, internal_type, lhs, rhs):
if self.connection.features.supports_temporal_subtraction:
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
return '(%s - %s)' % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
raise NotSupportedError("This backend does not support %s subtraction." % internal_type)
def window_frame_start(self, start):
if isinstance(start, int):
if start < 0:
return '%d %s' % (abs(start), self.PRECEDING)
elif start == 0:
return self.CURRENT_ROW
elif start is None:
return self.UNBOUNDED_PRECEDING
raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
def window_frame_end(self, end):
if isinstance(end, int):
if end == 0:
return self.CURRENT_ROW
elif end > 0:
return '%d %s' % (end, self.FOLLOWING)
elif end is None:
return self.UNBOUNDED_FOLLOWING
raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
def window_frame_rows_start_end(self, start=None, end=None):
"""
Return SQL for start and end points in an OVER clause window frame.
"""
if not self.connection.features.supports_over_clause:
raise NotSupportedError('This backend does not support window expressions.')
return self.window_frame_start(start), self.window_frame_end(end)
def window_frame_range_start_end(self, start=None, end=None):
start_, end_ = self.window_frame_rows_start_end(start, end)
if (
self.connection.features.only_supports_unbounded_with_preceding_and_following and
((start and start < 0) or (end and end > 0))
):
raise NotSupportedError(
'%s only supports UNBOUNDED together with PRECEDING and '
'FOLLOWING.' % self.connection.display_name
)
return start_, end_
def explain_query_prefix(self, format=None, **options):
if not self.connection.features.supports_explaining_query_execution:
raise NotSupportedError('This backend does not support explaining query execution.')
if format:
supported_formats = self.connection.features.supported_explain_formats
normalized_format = format.upper()
if normalized_format not in supported_formats:
msg = '%s is not a recognized format.' % normalized_format
if supported_formats:
msg += ' Allowed formats: %s' % ', '.join(sorted(supported_formats))
raise ValueError(msg)
if options:
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
return self.explain_prefix
def insert_statement(self, ignore_conflicts=False):
return 'INSERT INTO'
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return ''

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
class BaseDatabaseValidation:
"""Encapsulate backend-specific validation."""
def __init__(self, connection):
self.connection = connection
def check(self, **kwargs):
return []
def check_field(self, field, **kwargs):
errors = []
# Backends may implement a check_field_type() method.
if (hasattr(self, 'check_field_type') and
# Ignore any related fields.
not getattr(field, 'remote_field', None)):
# Ignore fields with unsupported features.
db_supports_all_required_features = all(
getattr(self.connection.features, feature, False)
for feature in field.model._meta.required_db_features
)
if db_supports_all_required_features:
field_type = field.db_type(self.connection)
# Ignore non-concrete fields.
if field_type is not None:
errors.extend(self.check_field_type(field, field_type))
return errors

View File

@@ -0,0 +1,232 @@
"""
Helpers to manipulate deferred DDL statements that might need to be adjusted or
discarded within when executing a migration.
"""
from copy import deepcopy
class Reference:
"""Base class that defines the reference interface."""
def references_table(self, table):
"""
Return whether or not this instance references the specified table.
"""
return False
def references_column(self, table, column):
"""
Return whether or not this instance references the specified column.
"""
return False
def rename_table_references(self, old_table, new_table):
"""
Rename all references to the old_name to the new_table.
"""
pass
def rename_column_references(self, table, old_column, new_column):
"""
Rename all references to the old_column to the new_column.
"""
pass
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, str(self))
def __str__(self):
raise NotImplementedError('Subclasses must define how they should be converted to string.')
class Table(Reference):
"""Hold a reference to a table."""
def __init__(self, table, quote_name):
self.table = table
self.quote_name = quote_name
def references_table(self, table):
return self.table == table
def rename_table_references(self, old_table, new_table):
if self.table == old_table:
self.table = new_table
def __str__(self):
return self.quote_name(self.table)
class TableColumns(Table):
"""Base class for references to multiple columns of a table."""
def __init__(self, table, columns):
self.table = table
self.columns = columns
def references_column(self, table, column):
return self.table == table and column in self.columns
def rename_column_references(self, table, old_column, new_column):
if self.table == table:
for index, column in enumerate(self.columns):
if column == old_column:
self.columns[index] = new_column
class Columns(TableColumns):
"""Hold a reference to one or many columns."""
def __init__(self, table, columns, quote_name, col_suffixes=()):
self.quote_name = quote_name
self.col_suffixes = col_suffixes
super().__init__(table, columns)
def __str__(self):
def col_str(column, idx):
col = self.quote_name(column)
try:
suffix = self.col_suffixes[idx]
if suffix:
col = '{} {}'.format(col, suffix)
except IndexError:
pass
return col
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class IndexName(TableColumns):
"""Hold a reference to an index name."""
def __init__(self, table, columns, suffix, create_index_name):
self.suffix = suffix
self.create_index_name = create_index_name
super().__init__(table, columns)
def __str__(self):
return self.create_index_name(self.table, self.columns, self.suffix)
class IndexColumns(Columns):
def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
self.opclasses = opclasses
super().__init__(table, columns, quote_name, col_suffixes)
def __str__(self):
def col_str(column, idx):
# Index.__init__() guarantees that self.opclasses is the same
# length as self.columns.
col = '{} {}'.format(self.quote_name(column), self.opclasses[idx])
try:
suffix = self.col_suffixes[idx]
if suffix:
col = '{} {}'.format(col, suffix)
except IndexError:
pass
return col
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class ForeignKeyName(TableColumns):
"""Hold a reference to a foreign key name."""
def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
self.to_reference = TableColumns(to_table, to_columns)
self.suffix_template = suffix_template
self.create_fk_name = create_fk_name
super().__init__(from_table, from_columns,)
def references_table(self, table):
return super().references_table(table) or self.to_reference.references_table(table)
def references_column(self, table, column):
return (
super().references_column(table, column) or
self.to_reference.references_column(table, column)
)
def rename_table_references(self, old_table, new_table):
super().rename_table_references(old_table, new_table)
self.to_reference.rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
super().rename_column_references(table, old_column, new_column)
self.to_reference.rename_column_references(table, old_column, new_column)
def __str__(self):
suffix = self.suffix_template % {
'to_table': self.to_reference.table,
'to_column': self.to_reference.columns[0],
}
return self.create_fk_name(self.table, self.columns, suffix)
class Statement(Reference):
"""
Statement template and formatting parameters container.
Allows keeping a reference to a statement without interpolating identifiers
that might have to be adjusted if they're referencing a table or column
that is removed
"""
def __init__(self, template, **parts):
self.template = template
self.parts = parts
def references_table(self, table):
return any(
hasattr(part, 'references_table') and part.references_table(table)
for part in self.parts.values()
)
def references_column(self, table, column):
return any(
hasattr(part, 'references_column') and part.references_column(table, column)
for part in self.parts.values()
)
def rename_table_references(self, old_table, new_table):
for part in self.parts.values():
if hasattr(part, 'rename_table_references'):
part.rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
for part in self.parts.values():
if hasattr(part, 'rename_column_references'):
part.rename_column_references(table, old_column, new_column)
def __str__(self):
return self.template % self.parts
class Expressions(TableColumns):
def __init__(self, table, expressions, compiler, quote_value):
self.compiler = compiler
self.expressions = expressions
self.quote_value = quote_value
columns = [col.target.column for col in self.compiler.query._gen_cols([self.expressions])]
super().__init__(table, columns)
def rename_table_references(self, old_table, new_table):
if self.table != old_table:
return
self.expressions = self.expressions.relabeled_clone({old_table: new_table})
super().rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
if self.table != table:
return
expressions = deepcopy(self.expressions)
self.columns = []
for col in self.compiler.query._gen_cols([expressions]):
if col.target.column == old_column:
col.target.column = new_column
self.columns.append(col.target.column)
self.expressions = expressions
def __str__(self):
sql, params = self.compiler.compile(self.expressions)
params = map(self.quote_value, params)
return sql % tuple(params)

View File

@@ -0,0 +1,73 @@
"""
Dummy database backend for Django.
Django uses this if the database ENGINE setting is empty (None or empty string).
Each of these API functions, except connection.close(), raise
ImproperlyConfigured.
"""
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.base.client import BaseDatabaseClient
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.dummy.features import DummyDatabaseFeatures
def complain(*args, **kwargs):
raise ImproperlyConfigured("settings.DATABASES is improperly configured. "
"Please supply the ENGINE value. Check "
"settings documentation for more details.")
def ignore(*args, **kwargs):
pass
class DatabaseOperations(BaseDatabaseOperations):
quote_name = complain
class DatabaseClient(BaseDatabaseClient):
runshell = complain
class DatabaseCreation(BaseDatabaseCreation):
create_test_db = ignore
destroy_test_db = ignore
class DatabaseIntrospection(BaseDatabaseIntrospection):
get_table_list = complain
get_table_description = complain
get_relations = complain
get_indexes = complain
get_key_columns = complain
class DatabaseWrapper(BaseDatabaseWrapper):
operators = {}
# Override the base class implementations with null
# implementations. Anything that tries to actually
# do something raises complain; anything that tries
# to rollback or undo something raises ignore.
_cursor = complain
ensure_connection = complain
_commit = complain
_rollback = ignore
_close = ignore
_savepoint = ignore
_savepoint_commit = complain
_savepoint_rollback = ignore
_set_autocommit = complain
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DummyDatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
def is_usable(self):
return True

View File

@@ -0,0 +1,6 @@
from django.db.backends.base.features import BaseDatabaseFeatures
class DummyDatabaseFeatures(BaseDatabaseFeatures):
supports_transactions = False
uses_savepoints = False

View File

@@ -0,0 +1,405 @@
"""
MySQL database backend for Django.
Requires mysqlclient: https://pypi.org/project/mysqlclient/
"""
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile
try:
import MySQLdb as Database
except ImportError as err:
raise ImproperlyConfigured(
'Error loading MySQLdb module.\n'
'Did you install mysqlclient?'
) from err
from MySQLdb.constants import CLIENT, FIELD_TYPE
from MySQLdb.converters import conversions
# Some of these import MySQLdb, so import them after checking if it's installed.
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .validation import DatabaseValidation
version = Database.version_info
if version < (1, 4, 0):
raise ImproperlyConfigured('mysqlclient 1.4.0 or newer is required; you have %s.' % Database.__version__)
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Django
# expects time.
django_conversions = {
**conversions,
**{FIELD_TYPE.TIME: backend_utils.typecast_time},
}
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
server_version_re = _lazy_re_compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
class CursorWrapper:
"""
A thin wrapper around MySQLdb's normal cursor class that catches particular
exception instances and reraises them with the correct types.
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
to the particular underlying representation returned by Connection.cursor().
"""
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
3819, # CHECK constraint is violated
4025, # CHECK constraint failed
)
def __init__(self, cursor):
self.cursor = cursor
def execute(self, query, args=None):
try:
# args is None means no string interpolation
return self.cursor.execute(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def executemany(self, query, args):
try:
return self.cursor.executemany(query, args)
except Database.OperationalError as e:
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
raise IntegrityError(*tuple(e.args))
raise
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql'
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'integer AUTO_INCREMENT',
'BigAutoField': 'bigint AUTO_INCREMENT',
'BinaryField': 'longblob',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime(6)',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'JSONField': 'json',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint UNSIGNED',
'PositiveIntegerField': 'integer UNSIGNED',
'PositiveSmallIntegerField': 'smallint UNSIGNED',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallint AUTO_INCREMENT',
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time(6)',
'UUIDField': 'char(32)',
}
# For these data types:
# - MySQL < 8.0.13 and MariaDB < 10.2.1 don't accept default values and
# implicitly treat them as nullable
# - all versions of MySQL and MariaDB don't support full width database
# indexes
_limited_data_types = (
'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text',
'mediumtext', 'longtext', 'json',
)
operators = {
'exact': '= %s',
'iexact': 'LIKE %s',
'contains': 'LIKE BINARY %s',
'icontains': 'LIKE %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE BINARY %s',
'endswith': 'LIKE BINARY %s',
'istartswith': 'LIKE %s',
'iendswith': 'LIKE %s',
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
'icontains': "LIKE CONCAT('%%', {}, '%%')",
'startswith': "LIKE BINARY CONCAT({}, '%%')",
'istartswith': "LIKE CONCAT({}, '%%')",
'endswith': "LIKE BINARY CONCAT('%%', {})",
'iendswith': "LIKE CONCAT('%%', {})",
}
isolation_levels = {
'read uncommitted',
'read committed',
'repeatable read',
'serializable',
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def get_connection_params(self):
kwargs = {
'conv': django_conversions,
'charset': 'utf8',
}
settings_dict = self.settings_dict
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['database'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['password'] = settings_dict['PASSWORD']
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs['client_flag'] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
options = settings_dict['OPTIONS'].copy()
isolation_level = options.pop('isolation_level', 'read committed')
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
"Use one of %s, or None." % (
isolation_level,
', '.join("'%s'" % s for s in sorted(self.isolation_levels))
))
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
connection = Database.connect(**conn_params)
# bytes encoder in mysqlclient doesn't work and was added only to
# prevent KeyErrors in Django < 2.0. We can remove this workaround when
# mysqlclient 2.1 becomes the minimal mysqlclient supported by Django.
# See https://github.com/PyMySQL/mysqlclient/issues/489
if connection.encoders.get(bytes) is bytes:
connection.encoders.pop(bytes)
return connection
def init_connection_state(self):
assignments = []
if self.features.is_sql_auto_is_null_enabled:
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
assignments.append('SET SQL_AUTO_IS_NULL = 0')
if self.isolation_level:
assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper())
if assignments:
with self.cursor() as cursor:
cursor.execute('; '.join(assignments))
@async_unsafe
def create_cursor(self, name=None):
cursor = self.connection.cursor()
return CursorWrapper(cursor)
def _rollback(self):
try:
BaseDatabaseWrapper._rollback(self)
except Database.NotSupportedError:
pass
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit(autocommit)
def disable_constraint_checking(self):
"""
Disable foreign key checks, primarily for use in adding rows with
forward references. Always return True to indicate constraint checks
need to be re-enabled.
"""
with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=0')
return True
def enable_constraint_checking(self):
"""
Re-enable foreign key checks after they have been disabled.
"""
# Override needs_rollback in case constraint_checks_disabled is
# nested inside transaction.atomic.
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
with self.cursor() as cursor:
cursor.execute('SET foreign_key_checks=1')
finally:
self.needs_rollback = needs_rollback
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
""" % (
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self):
try:
self.connection.ping()
except Database.Error:
return False
else:
return True
@cached_property
def display_name(self):
return 'MariaDB' if self.mysql_is_mariadb else 'MySQL'
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
check_constraints = {
'PositiveBigIntegerField': '`%(column)s` >= 0',
'PositiveIntegerField': '`%(column)s` >= 0',
'PositiveSmallIntegerField': '`%(column)s` >= 0',
}
if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
# MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
# a check constraint.
check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)'
return check_constraints
return {}
@cached_property
def mysql_server_data(self):
with self.temporary_connection() as cursor:
# Select some server variables and test if the time zone
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
# timezone isn't loaded into the mysql.time_zone table.
cursor.execute("""
SELECT VERSION(),
@@sql_mode,
@@default_storage_engine,
@@sql_auto_is_null,
@@lower_case_table_names,
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
""")
row = cursor.fetchone()
return {
'version': row[0],
'sql_mode': row[1],
'default_storage_engine': row[2],
'sql_auto_is_null': bool(row[3]),
'lower_case_table_names': bool(row[4]),
'has_zoneinfo_database': bool(row[5]),
}
@cached_property
def mysql_server_info(self):
return self.mysql_server_data['version']
@cached_property
def mysql_version(self):
match = server_version_re.match(self.mysql_server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % self.mysql_server_info)
return tuple(int(x) for x in match.groups())
@cached_property
def mysql_is_mariadb(self):
return 'mariadb' in self.mysql_server_info.lower()
@cached_property
def sql_mode(self):
sql_mode = self.mysql_server_data['sql_mode']
return set(sql_mode.split(',') if sql_mode else ())

View File

@@ -0,0 +1,60 @@
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
env = None
database = settings_dict['OPTIONS'].get(
'database',
settings_dict['OPTIONS'].get('db', settings_dict['NAME']),
)
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
password = settings_dict['OPTIONS'].get(
'password',
settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD'])
)
host = settings_dict['OPTIONS'].get('host', settings_dict['HOST'])
port = settings_dict['OPTIONS'].get('port', settings_dict['PORT'])
server_ca = settings_dict['OPTIONS'].get('ssl', {}).get('ca')
client_cert = settings_dict['OPTIONS'].get('ssl', {}).get('cert')
client_key = settings_dict['OPTIONS'].get('ssl', {}).get('key')
defaults_file = settings_dict['OPTIONS'].get('read_default_file')
charset = settings_dict['OPTIONS'].get('charset')
# Seems to be no good way to set sql_mode with CLI.
if defaults_file:
args += ["--defaults-file=%s" % defaults_file]
if user:
args += ["--user=%s" % user]
if password:
# The MYSQL_PWD environment variable usage is discouraged per
# MySQL's documentation due to the possibility of exposure through
# `ps` on old Unix flavors but --password suffers from the same
# flaw on even more systems. Usage of an environment variable also
# prevents password exposure if the subprocess.run(check=True) call
# raises a CalledProcessError since the string representation of
# the latter includes all of the provided `args`.
env = {'MYSQL_PWD': password}
if host:
if '/' in host:
args += ["--socket=%s" % host]
else:
args += ["--host=%s" % host]
if port:
args += ["--port=%s" % port]
if server_ca:
args += ["--ssl-ca=%s" % server_ca]
if client_cert:
args += ["--ssl-cert=%s" % client_cert]
if client_key:
args += ["--ssl-key=%s" % client_key]
if charset:
args += ['--default-character-set=%s' % charset]
if database:
args += [database]
args.extend(parameters)
return args, env

View File

@@ -0,0 +1,71 @@
from django.core.exceptions import FieldError
from django.db.models.expressions import Col
from django.db.models.sql import compiler
class SQLCompiler(compiler.SQLCompiler):
def as_subquery_condition(self, alias, columns, compiler):
qn = compiler.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
return '(%s) IN (%s)' % (', '.join('%s.%s' % (qn(alias), qn2(column)) for column in columns), sql), params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def as_sql(self):
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
# the SQLDeleteCompiler's default implementation when multiple tables
# are involved since MySQL/MariaDB will generate a more efficient query
# plan than when using a subquery.
where, having = self.query.where.split_having()
if self.single_alias or having:
# DELETE FROM cannot be used when filtering against aggregates
# since it doesn't allow for GROUP BY and HAVING clauses.
return super().as_sql()
result = [
'DELETE %s FROM' % self.quote_name_unless_alias(
self.query.get_initial_alias()
)
]
from_sql, from_params = self.get_from_clause()
result.extend(from_sql)
where_sql, where_params = self.compile(where)
if where_sql:
result.append('WHERE %s' % where_sql)
return ' '.join(result), tuple(from_params) + tuple(where_params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
def as_sql(self):
update_query, update_params = super().as_sql()
# MySQL and MariaDB support UPDATE ... ORDER BY syntax.
if self.query.order_by:
order_by_sql = []
order_by_params = []
db_table = self.query.get_meta().db_table
try:
for resolved, (sql, params, _) in self.get_order_by():
if (
isinstance(resolved.expression, Col) and
resolved.expression.alias != db_table
):
# Ignore ordering if it contains joined fields, because
# they cannot be used in the ORDER BY clause.
raise FieldError
order_by_sql.append(sql)
order_by_params.extend(params)
update_query += ' ORDER BY ' + ', '.join(order_by_sql)
update_params += tuple(order_by_params)
except FieldError:
# Ignore ordering if it contains annotations, because they're
# removed in .update() and cannot be resolved.
pass
return update_query, update_params
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass

View File

@@ -0,0 +1,68 @@
import os
import subprocess
import sys
from django.db.backends.base.creation import BaseDatabaseCreation
from .client import DatabaseClient
class DatabaseCreation(BaseDatabaseCreation):
def sql_table_creation_suffix(self):
suffix = []
test_settings = self.connection.settings_dict['TEST']
if test_settings['CHARSET']:
suffix.append('CHARACTER SET %s' % test_settings['CHARSET'])
if test_settings['COLLATION']:
suffix.append('COLLATE %s' % test_settings['COLLATION'])
return ' '.join(suffix)
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if len(e.args) < 1 or e.args[0] != 1007:
# All errors except "database exists" (1007) cancel tests.
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
else:
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
'dbname': self.connection.ops.quote_name(target_database_name),
'suffix': self.sql_table_creation_suffix(),
}
with self._nodb_cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
if keepdb:
# If the database should be kept, skip everything else.
return
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name):
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, [])
dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name]
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc:
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()

View File

@@ -0,0 +1,268 @@
import operator
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
empty_fetchmany_value = ()
allows_group_by_pk = True
related_fields_match_type = True
# MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
allow_sliced_subqueries_with_in = False
has_select_for_update = True
supports_forward_references = False
supports_regex_backreferencing = False
supports_date_lookup_using_string = False
supports_timezones = False
requires_explicit_null_ordering_when_grouping = True
can_release_savepoints = True
atomic_transactions = False
can_clone_databases = True
supports_temporal_subtraction = True
supports_select_intersection = False
supports_select_difference = False
supports_slicing_ordering_in_compound = True
supports_index_on_text_field = False
has_case_insensitive_like = False
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN
DECLARE V_I INTEGER;
SET V_I = 1;
END;
"""
create_test_procedure_with_int_param_sql = """
CREATE PROCEDURE test_procedure (P_I INTEGER)
BEGIN
DECLARE V_I INTEGER;
SET V_I = P_I;
END;
"""
# Neither MySQL nor MariaDB support partial indexes.
supports_partial_indexes = False
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
# indexed expression.
collate_as_index_expression = True
supports_order_by_nulls_modifier = False
order_by_nulls_first = True
@cached_property
def test_collations(self):
charset = 'utf8'
if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 6):
# utf8 is an alias for utf8mb3 in MariaDB 10.6+.
charset = 'utf8mb3'
return {
'ci': f'{charset}_general_ci',
'non_default': f'{charset}_esperanto_ci',
'swedish_ci': f'{charset}_swedish_ci',
}
test_now_utc_template = 'UTC_TIMESTAMP'
@cached_property
def django_test_skips(self):
skips = {
"This doesn't work on MySQL.": {
'db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround',
'db_functions.comparison.test_least.LeastTests.test_coalesce_workaround',
},
'Running on MySQL requires utf8mb4 encoding (#18392).': {
'model_fields.test_textfield.TextFieldTests.test_emoji',
'model_fields.test_charfield.TestCharField.test_emoji',
},
"MySQL doesn't support functional indexes on a function that "
"returns JSON": {
'schema.tests.SchemaTests.test_func_index_json_key_transform',
},
"MySQL supports multiplying and dividing DurationFields by a "
"scalar value but it's not implemented (#25287).": {
'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide',
},
}
if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
skips.update({
'GROUP BY optimization does not work properly when '
'ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.': {
'aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued',
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
},
})
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
skips.update({
'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
},
'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
},
})
if (
self.connection.mysql_is_mariadb and
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2)
):
skips.update({
'https://jira.mariadb.org/browse/MDEV-19598': {
'schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key',
},
})
if (
self.connection.mysql_is_mariadb and
(10, 4, 12) < self.connection.mysql_version < (10, 5)
):
skips.update({
'https://jira.mariadb.org/browse/MDEV-22775': {
'schema.tests.SchemaTests.test_alter_pk_with_self_referential_field',
},
})
if not self.supports_explain_analyze:
skips.update({
'MariaDB and MySQL >= 8.0.18 specific.': {
'queries.test_explain.ExplainTests.test_mysql_analyze',
},
})
return skips
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
return self.connection.mysql_server_data['default_storage_engine']
@cached_property
def allows_auto_pk_0(self):
"""
Autoincrement primary key can be set to 0 if it doesn't generate new
autoincrement values.
"""
return 'NO_AUTO_VALUE_ON_ZERO' in self.connection.sql_mode
@cached_property
def update_can_self_select(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 3, 2)
@cached_property
def can_introspect_foreign_keys(self):
"Confirm support for introspected foreign keys"
return self._mysql_storage_engine != 'MyISAM'
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
'BinaryField': 'TextField',
'BooleanField': 'IntegerField',
'DurationField': 'BigIntegerField',
'GenericIPAddressField': 'CharField',
}
@cached_property
def can_return_columns_from_insert(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 5, 0)
can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))
@cached_property
def has_zoneinfo_database(self):
return self.connection.mysql_server_data['has_zoneinfo_database']
@cached_property
def is_sql_auto_is_null_enabled(self):
return self.connection.mysql_server_data['sql_auto_is_null']
@cached_property
def supports_over_clause(self):
if self.connection.mysql_is_mariadb:
return True
return self.connection.mysql_version >= (8, 0, 2)
supports_frame_range_fixed_distance = property(operator.attrgetter('supports_over_clause'))
@cached_property
def supports_column_check_constraints(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 2, 1)
return self.connection.mysql_version >= (8, 0, 16)
supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints'))
@cached_property
def can_introspect_check_constraints(self):
if self.connection.mysql_is_mariadb:
version = self.connection.mysql_version
return (version >= (10, 2, 22) and version < (10, 3)) or version >= (10, 3, 10)
return self.connection.mysql_version >= (8, 0, 16)
@cached_property
def has_select_for_update_skip_locked(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 6)
return self.connection.mysql_version >= (8, 0, 1)
@cached_property
def has_select_for_update_nowait(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 3, 0)
return self.connection.mysql_version >= (8, 0, 1)
@cached_property
def has_select_for_update_of(self):
return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)
@cached_property
def supports_explain_analyze(self):
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (8, 0, 18)
@cached_property
def supported_explain_formats(self):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
# backends.
formats = {'JSON', 'TEXT', 'TRADITIONAL'}
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 16):
formats.add('TREE')
return formats
@cached_property
def supports_transactions(self):
"""
All storage engines except MyISAM support transactions.
"""
return self._mysql_storage_engine != 'MyISAM'
@cached_property
def ignores_table_name_case(self):
return self.connection.mysql_server_data['lower_case_table_names']
@cached_property
def supports_default_in_lead_lag(self):
# To be added in https://jira.mariadb.org/browse/MDEV-12981.
return not self.connection.mysql_is_mariadb
@cached_property
def supports_json_field(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 2, 7)
return self.connection.mysql_version >= (5, 7, 8)
@cached_property
def can_introspect_json_field(self):
if self.connection.mysql_is_mariadb:
return self.supports_json_field and self.can_introspect_check_constraints
return self.supports_json_field
@cached_property
def supports_index_column_ordering(self):
return (
not self.connection.mysql_is_mariadb and
self.connection.mysql_version >= (8, 0, 1)
)
@cached_property
def supports_expression_indexes(self):
return (
not self.connection.mysql_is_mariadb and
self.connection.mysql_version >= (8, 0, 13)
)

View File

@@ -0,0 +1,309 @@
from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models import Index
from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint'))
InfoLine = namedtuple(
'InfoLine',
'col_name data_type max_len num_prec num_scale extra column_default '
'collation is_unsigned'
)
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField',
FIELD_TYPE.DECIMAL: 'DecimalField',
FIELD_TYPE.NEWDECIMAL: 'DecimalField',
FIELD_TYPE.DATE: 'DateField',
FIELD_TYPE.DATETIME: 'DateTimeField',
FIELD_TYPE.DOUBLE: 'FloatField',
FIELD_TYPE.FLOAT: 'FloatField',
FIELD_TYPE.INT24: 'IntegerField',
FIELD_TYPE.JSON: 'JSONField',
FIELD_TYPE.LONG: 'IntegerField',
FIELD_TYPE.LONGLONG: 'BigIntegerField',
FIELD_TYPE.SHORT: 'SmallIntegerField',
FIELD_TYPE.STRING: 'CharField',
FIELD_TYPE.TIME: 'TimeField',
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
FIELD_TYPE.TINY: 'IntegerField',
FIELD_TYPE.TINY_BLOB: 'TextField',
FIELD_TYPE.MEDIUM_BLOB: 'TextField',
FIELD_TYPE.LONG_BLOB: 'TextField',
FIELD_TYPE.VAR_STRING: 'CharField',
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if 'auto_increment' in description.extra:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
if description.is_unsigned:
if field_type == 'BigIntegerField':
return 'PositiveBigIntegerField'
elif field_type == 'IntegerField':
return 'PositiveIntegerField'
elif field_type == 'SmallIntegerField':
return 'PositiveSmallIntegerField'
# JSON data type is an alias for LONGTEXT in MariaDB, use check
# constraints clauses to introspect JSONField.
if description.has_json_constraint:
return 'JSONField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("SHOW FULL TABLES")
return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1]))
for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface."
"""
json_constraints = {}
if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field:
# JSON data type is an alias for LONGTEXT in MariaDB, select
# JSON_VALID() constraints to introspect JSONField.
cursor.execute("""
SELECT c.constraint_name AS column_name
FROM information_schema.check_constraints AS c
WHERE
c.table_name = %s AND
LOWER(c.check_clause) = 'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
c.constraint_schema = DATABASE()
""", [table_name])
json_constraints = {row[0] for row in cursor.fetchall()}
# A default collation for the given table.
cursor.execute("""
SELECT table_collation
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
""", [table_name])
row = cursor.fetchone()
default_column_collation = row[0] if row else ''
# information_schema database gives more accurate results for some figures:
# - varchar length returned by cursor.description is an internal length,
# not visible length (#5725)
# - precision and scale (for decimal fields) (#5014)
# - auto_increment is not available in cursor.description
cursor.execute("""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
CASE
WHEN collation_name = %s THEN NULL
ELSE collation_name
END AS collation_name,
CASE
WHEN column_type LIKE '%% unsigned' THEN 1
ELSE 0
END AS is_unsigned
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()
""", [default_column_collation, table_name])
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
def to_int(i):
return int(i) if i is not None else i
fields = []
for line in cursor.description:
info = field_info[line[0]]
fields.append(FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.collation,
info.extra,
info.is_unsigned,
line[0] in json_constraints,
))
return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if 'auto_increment' in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{'table': table_name, 'column': field_info.name}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
constraints = self.get_key_columns(cursor, table_name)
relations = {}
for my_fieldname, other_table, other_field in constraints:
relations[my_fieldname] = (other_field, other_table)
return relations
def get_key_columns(self, cursor, table_name):
"""
Return a list of (column_name, referenced_table_name, referenced_column_name)
for all key columns in the given table.
"""
key_columns = []
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
key_columns.extend(cursor.fetchall())
return key_columns
def get_storage_engine(self, cursor, table_name):
"""
Retrieve the storage engine for a given table. Return the default
storage engine if the table doesn't exist.
"""
cursor.execute("""
SELECT engine
FROM information_schema.tables
WHERE
table_name = %s AND
table_schema = DATABASE()
""", [table_name])
result = cursor.fetchone()
if not result:
return self.connection.features._mysql_storage_engine
return result[0]
def _parse_constraint_columns(self, check_clause, columns):
check_columns = OrderedSet()
statement = sqlparse.parse(check_clause)[0]
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if (
token.ttype == sqlparse.tokens.Name and
self.connection.ops.quote_name(token.value) == token.value and
token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1])
return check_columns
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Get the actual constraint names and columns
name_query = """
SELECT kc.`constraint_name`, kc.`column_name`,
kc.`referenced_table_name`, kc.`referenced_column_name`,
c.`constraint_type`
FROM
information_schema.key_column_usage AS kc,
information_schema.table_constraints AS c
WHERE
kc.table_schema = DATABASE() AND
c.table_schema = kc.table_schema AND
c.constraint_name = kc.constraint_name AND
c.constraint_type != 'CHECK' AND
kc.table_name = %s
ORDER BY kc.`ordinal_position`
"""
cursor.execute(name_query, [table_name])
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
'columns': OrderedSet(),
'primary_key': kind == 'PRIMARY KEY',
'unique': kind in {'PRIMARY KEY', 'UNIQUE'},
'index': False,
'check': False,
'foreign_key': (ref_table, ref_column) if ref_column else None,
}
if self.connection.features.supports_index_column_ordering:
constraints[constraint]['orders'] = []
constraints[constraint]['columns'].add(column)
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
columns = {info.name for info in self.get_table_description(cursor, table_name)}
if self.connection.mysql_is_mariadb:
type_query = """
SELECT c.constraint_name, c.check_clause
FROM information_schema.check_constraints AS c
WHERE
c.constraint_schema = DATABASE() AND
c.table_name = %s
"""
else:
type_query = """
SELECT cc.constraint_name, cc.check_clause
FROM
information_schema.check_constraints AS cc,
information_schema.table_constraints AS tc
WHERE
cc.constraint_schema = DATABASE() AND
tc.table_schema = cc.constraint_schema AND
cc.constraint_name = tc.constraint_name AND
tc.constraint_type = 'CHECK' AND
tc.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
constraint_columns = self._parse_constraint_columns(check_clause, columns)
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index
constraints[constraint] = {
'columns': constraint_columns,
'primary_key': False,
'unique': False,
'index': False,
'check': True,
'foreign_key': None,
}
# Now add in the indexes
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
for table, non_unique, index, colseq, column, order, type_ in [
x[:6] + (x[10],) for x in cursor.fetchall()
]:
if index not in constraints:
constraints[index] = {
'columns': OrderedSet(),
'primary_key': False,
'unique': not non_unique,
'check': False,
'foreign_key': None,
}
if self.connection.features.supports_index_column_ordering:
constraints[index]['orders'] = []
constraints[index]['index'] = True
constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower()
constraints[index]['columns'].add(column)
if self.connection.features.supports_index_column_ordering:
constraints[index]['orders'].append('DESC' if order == 'D' else 'ASC')
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint['columns'] = list(constraint['columns'])
return constraints

View File

@@ -0,0 +1,378 @@
import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.utils import timezone
from django.utils.encoding import force_str
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.mysql.compiler"
# MySQL stores positive fields as UNSIGNED ints.
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
'PositiveSmallIntegerField': (0, 65535),
'PositiveIntegerField': (0, 4294967295),
'PositiveBigIntegerField': (0, 18446744073709551615),
}
cast_data_types = {
'AutoField': 'signed integer',
'BigAutoField': 'signed integer',
'SmallAutoField': 'signed integer',
'CharField': 'char(%(max_length)s)',
'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)',
'TextField': 'char',
'IntegerField': 'signed integer',
'BigIntegerField': 'signed integer',
'SmallIntegerField': 'signed integer',
'PositiveBigIntegerField': 'unsigned integer',
'PositiveIntegerField': 'unsigned integer',
'PositiveSmallIntegerField': 'unsigned integer',
'DurationField': 'signed integer',
}
cast_char_field_without_max_length = 'char'
explain_prefix = 'EXPLAIN'
def date_extract_sql(self, lookup_type, field_name):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == 'week_day':
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return "DAYOFWEEK(%s)" % field_name
elif lookup_type == 'iso_week_day':
# WEEKDAY() returns an integer, 0-6, Monday=0.
return "WEEKDAY(%s) + 1" % field_name
elif lookup_type == 'week':
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return "WEEK(%s, 3)" % field_name
elif lookup_type == 'iso_year':
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
else:
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
'year': '%%Y-01-01',
'month': '%%Y-%%m-01',
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
elif lookup_type == 'quarter':
return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % (
field_name, field_name
)
elif lookup_type == 'week':
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (
field_name, field_name
)
else:
return "DATE(%s)" % (field_name)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f'{sign}{offset}' if offset else tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
field_name,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE(%s)" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "TIME(%s)" % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
if lookup_type == 'quarter':
return (
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
"INTERVAL QUARTER({field_name}) QUARTER - " +
"INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
).format(field_name=field_name)
if lookup_type == 'week':
return (
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
"INTERVAL WEEKDAY({field_name}) DAY), "
"'%%Y-%%m-%%d 00:00:00') AS DATETIME)"
).format(field_name=field_name)
try:
i = fields.index(lookup_type) + 1
except ValueError:
sql = field_name
else:
format_str = ''.join(format[:i] + format_def[i:])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
'hour': '%%H:00:00',
'minute': '%%H:%%i:00',
'second': '%%H:%%i:%%s',
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str)
else:
return "TIME(%s)" % (field_name)
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
return 'INTERVAL %s MICROSECOND' % sql
def force_no_ordering(self):
"""
"ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
columns. If no ordering would otherwise be applied, we don't want any
implicit sorting going on.
"""
return [(None, ("NULL", [], False))]
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
return value
def last_executed_query(self, cursor, sql, params):
# With MySQLdb, cursor objects have an (undocumented) "_executed"
# attribute where the exact query sent to the database is saved.
# See MySQLdb/cursors.py in the source distribution.
# MySQLdb returns string, PyMySQL bytes.
return force_str(getattr(cursor, '_executed', None), errors='replace')
def no_limit_value(self):
# 2**64 - 1, as recommended by the MySQL documentation
return 18446744073709551615
def quote_name(self, name):
if name.startswith("`") and name.endswith("`"):
return name # Quoting once is enough.
return "`%s`" % name
def return_insert_columns(self, fields):
# MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
# statement.
if not fields:
return '', ()
columns = [
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
sql = ['SET FOREIGN_KEY_CHECKS = 0;']
if reset_sequences:
# It's faster to TRUNCATE tables that require a sequence reset
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
sql.extend(
'%s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
style.SQL_FIELD(self.quote_name(table_name)),
) for table_name in tables
)
else:
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
# and preserves sequences.
sql.extend(
'%s %s %s;' % (
style.SQL_KEYWORD('DELETE'),
style.SQL_KEYWORD('FROM'),
style.SQL_FIELD(self.quote_name(table_name)),
) for table_name in tables
)
sql.append('SET FOREIGN_KEY_CHECKS = 1;')
return sql
def sequence_reset_by_name_sql(self, style, sequences):
return [
'%s %s %s %s = 1;' % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(sequence_info['table'])),
style.SQL_FIELD('AUTO_INCREMENT'),
) for sequence_info in sequences
]
def validate_autopk_value(self, value):
# Zero in AUTO_INCREMENT field does not work without the
# NO_AUTO_VALUE_ON_ZERO SQL mode.
if value == 0 and not self.connection.features.allows_auto_pk_0:
raise ValueError('The database backend does not accept 0 as a '
'value for AutoField.')
return value
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.")
return str(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")
return value.isoformat(timespec='microseconds')
def max_name_length(self):
return 64
def pk_default_value(self):
return 'NULL'
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
if connector == '^':
return 'POW(%s)' % ','.join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
# return an unsigned integer.
elif connector in ('&', '|', '<<', '#'):
connector = '^' if connector == '#' else connector
return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions)
elif connector == '>>':
lhs, rhs = sub_expressions
return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
return super().combine_expression(connector, sub_expressions)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == 'BooleanField':
converters.append(self.convert_booleanfield_value)
elif internal_type == 'DateTimeField':
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
return converters
def convert_booleanfield_value(self, value, expression, connection):
if value in (0, 1):
value = bool(value)
return value
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def binary_placeholder_sql(self, value):
return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
if internal_type == 'TimeField':
if self.connection.mysql_is_mariadb:
# MariaDB includes the microsecond component in TIME_TO_SEC as
# a decimal. MySQL returns an integer without microseconds.
return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % {
'lhs': lhs_sql, 'rhs': rhs_sql
}, (*lhs_params, *rhs_params)
return (
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
) % {'lhs': lhs_sql, 'rhs': rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2
params = (*rhs_params, *lhs_params)
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
def explain_query_prefix(self, format=None, **options):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
if format and format.upper() == 'TEXT':
format = 'TRADITIONAL'
elif not format and 'TREE' in self.connection.features.supported_explain_formats:
# Use TREE by default (if supported) as it's more informative.
format = 'TREE'
analyze = options.pop('analyze', False)
prefix = super().explain_query_prefix(format, **options)
if analyze and self.connection.features.supports_explain_analyze:
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
prefix = 'ANALYZE' if self.connection.mysql_is_mariadb else prefix + ' ANALYZE'
if format and not (analyze and not self.connection.mysql_is_mariadb):
# Only MariaDB supports the analyze option with formats.
prefix += ' FORMAT=%s' % format
return prefix
def regex_lookup(self, lookup_type):
# REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
# doesn't exist in MySQL 5.x or in MariaDB.
if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb:
if lookup_type == 'regex':
return '%s REGEXP BINARY %s'
return '%s REGEXP %s'
match_option = 'c' if lookup_type == 'regex' else 'i'
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, ignore_conflicts=False):
return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
if internal_type == 'JSONField':
if self.connection.mysql_is_mariadb or lookup_type in (
'iexact', 'contains', 'icontains', 'startswith', 'istartswith',
'endswith', 'iendswith', 'regex', 'iregex',
):
lookup = 'JSON_UNQUOTE(%s)'
return lookup

View File

@@ -0,0 +1,160 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import NOT_PROVIDED
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"
sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
sql_alter_column_no_default_null = 'ALTER COLUMN %(column)s SET DEFAULT NULL'
# No 'CASCADE' which works as a no-op in MySQL but is undocumented
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
sql_create_column_inline_fk = (
', ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) '
'REFERENCES %(to_table)s(%(to_column)s)'
)
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s'
@property
def sql_delete_check(self):
if self.connection.mysql_is_mariadb:
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding IF EXISTS clause prevents migrations
# crash. Constraint is removed during a "MODIFY" column statement.
return 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s'
return 'ALTER TABLE %(table)s DROP CHECK %(name)s'
@property
def sql_rename_column(self):
# MariaDB >= 10.5.2 and MySQL >= 8.0.4 support an
# "ALTER TABLE ... RENAME COLUMN" statement.
if self.connection.mysql_is_mariadb:
if self.connection.mysql_version >= (10, 5, 2):
return super().sql_rename_column
elif self.connection.mysql_version >= (8, 0, 4):
return super().sql_rename_column
return 'ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s'
def quote_value(self, value):
self.connection.ensure_connection()
if isinstance(value, str):
value = value.replace('%', '%%')
# MySQLdb escapes to string, PyMySQL to bytes.
quoted = self.connection.connection.escape(value, self.connection.connection.encoders)
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted
def _is_limited_data_type(self, field):
db_type = field.db_type(self.connection)
return db_type is not None and db_type.lower() in self.connection._limited_data_types
def skip_default(self, field):
if not self._supports_limited_data_type_defaults:
return self._is_limited_data_type(field)
return False
def skip_default_on_alter(self, field):
if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
# MySQL doesn't support defaults for BLOB and TEXT in the
# ALTER COLUMN statement.
return True
return False
@property
def _supports_limited_data_type_defaults(self):
# MariaDB >= 10.2.1 and MySQL >= 8.0.13 supports defaults for BLOB
# and TEXT.
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 2, 1)
return self.connection.mysql_version >= (8, 0, 13)
def _column_default_sql(self, field):
if (
not self.connection.mysql_is_mariadb and
self._supports_limited_data_type_defaults and
self._is_limited_data_type(field)
):
# MySQL supports defaults for BLOB and TEXT columns only if the
# default value is written as an expression i.e. in parentheses.
return '(%s)'
return super()._column_default_sql(field)
def add_field(self, model, field):
super().add_field(model, field)
# Simulate the effect of a one-off default.
# field.default may be unhashable, so a set isn't used for "in" check.
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
effective_default = self.effective_default(field)
self.execute('UPDATE %(table)s SET %(column)s = %%s' % {
'table': self.quote_name(model._meta.db_table),
'column': self.quote_name(field.column),
}, [effective_default])
def _field_should_be_indexed(self, model, field):
if not super()._field_should_be_indexed(model, field):
return False
storage = self.connection.introspection.get_storage_engine(
self.connection.cursor(), model._meta.db_table
)
# No need to create an index for ForeignKey fields except if
# db_constraint=False because the index from that constraint won't be
# created.
if (storage == "InnoDB" and
field.get_internal_type() == 'ForeignKey' and
field.db_constraint):
return False
return not self._is_limited_data_type(field)
def _delete_composed_index(self, model, fields, *args):
"""
MySQL can remove an implicit FK index on a field when that field is
covered by another index like a unique_together. "covered" here means
that the more complex index starts like the simpler one.
http://bugs.mysql.com/bug.php?id=37910 / Django ticket #24757
We check here before removing the [unique|index]_together if we have to
recreate a FK index.
"""
first_field = model._meta.get_field(fields[0])
if first_field.get_internal_type() == 'ForeignKey':
constraint_names = self._constraint_names(model, [first_field.column], index=True)
if not constraint_names:
self.execute(
self._create_index_sql(model, fields=[first_field], suffix='')
)
return super()._delete_composed_index(model, fields, *args)
def _set_field_new_type_null_status(self, field, new_type):
"""
Keep the null property of the old field. If it has changed, it will be
handled separately.
"""
if field.null:
new_type += " NULL"
else:
new_type += " NOT NULL"
return new_type
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _rename_field_sql(self, table, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._rename_field_sql(table, old_field, new_field, new_type)

View File

@@ -0,0 +1,69 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
from django.utils.version import get_docs_version
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
issues = super().check(**kwargs)
issues.extend(self._check_sql_mode(**kwargs))
return issues
def _check_sql_mode(self, **kwargs):
if not (self.connection.sql_mode & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}):
return [checks.Warning(
"%s Strict Mode is not set for database connection '%s'"
% (self.connection.display_name, self.connection.alias),
hint=(
"%s's Strict Mode fixes many data integrity problems in "
"%s, such as data truncation upon insertion, by "
"escalating warnings into errors. It is strongly "
"recommended you activate it. See: "
"https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode"
% (
self.connection.display_name,
self.connection.display_name,
get_docs_version(),
),
),
id='mysql.W002',
)]
return []
def check_field_type(self, field, field_type):
"""
MySQL has the following field length restriction:
No character (varchar) fields can have a length exceeding 255
characters if they have a unique index on them.
MySQL doesn't support a database index on some data types.
"""
errors = []
if (field_type.startswith('varchar') and field.unique and
(field.max_length is None or int(field.max_length) > 255)):
errors.append(
checks.Warning(
'%s may not allow unique CharFields to have a max_length '
'> 255.' % self.connection.display_name,
obj=field,
hint=(
'See: https://docs.djangoproject.com/en/%s/ref/'
'databases/#mysql-character-fields' % get_docs_version()
),
id='mysql.W003',
)
)
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
'%s does not support a database index on %s columns.'
% (self.connection.display_name, field_type),
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id='fields.W162',
)
)
return errors

View File

@@ -0,0 +1,554 @@
"""
Oracle database backend for Django.
Requires cx_Oracle: https://oracle.github.io/python-cx_Oracle/
"""
import datetime
import decimal
import os
import platform
from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
def _setup_environment(environ):
# Cygwin requires some special voodoo to set the environment variables
# properly so that Oracle will see them.
if platform.system().upper().startswith('CYGWIN'):
try:
import ctypes
except ImportError as e:
raise ImproperlyConfigured("Error loading ctypes: %s; "
"the Oracle backend requires ctypes to "
"operate correctly under Cygwin." % e)
kernel32 = ctypes.CDLL('kernel32')
for name, value in environ:
kernel32.SetEnvironmentVariableA(name, value)
else:
os.environ.update(environ)
_setup_environment([
# Oracle takes client-side character set encoding from the environment.
('NLS_LANG', '.AL32UTF8'),
# This prevents Unicode from getting mangled by getting encoded into the
# potentially non-Unicode database character set.
('ORA_NCHAR_LITERAL_REPLACE', 'TRUE'),
])
try:
import cx_Oracle as Database
except ImportError as e:
raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
# Some of these import cx_Oracle, so import them after checking if it's installed.
from .client import DatabaseClient # NOQA
from .creation import DatabaseCreation # NOQA
from .features import DatabaseFeatures # NOQA
from .introspection import DatabaseIntrospection # NOQA
from .operations import DatabaseOperations # NOQA
from .schema import DatabaseSchemaEditor # NOQA
from .utils import Oracle_datetime, dsn # NOQA
from .validation import DatabaseValidation # NOQA
@contextmanager
def wrap_oracle_errors():
try:
yield
except Database.DatabaseError as e:
# cx_Oracle raises a cx_Oracle.DatabaseError exception with the
# following attributes and values:
# code = 2091
# message = 'ORA-02091: transaction rolled back
# 'ORA-02291: integrity constraint (TEST_DJANGOTEST.SYS
# _C00102056) violated - parent key not found'
# or:
# 'ORA-00001: unique constraint (DJANGOTEST.DEFERRABLE_
# PINK_CONSTRAINT) violated
# Convert that case to Django's IntegrityError exception.
x = e.args[0]
if (
hasattr(x, 'code') and
hasattr(x, 'message') and
x.code == 2091 and
('ORA-02291' in x.message or 'ORA-00001' in x.message)
):
raise IntegrityError(*tuple(e.args))
raise
class _UninitializedOperatorsDescriptor:
def __get__(self, instance, cls=None):
# If connection.operators is looked up before a connection has been
# created, transparently initialize connection.operators to avert an
# AttributeError.
if instance is None:
raise AttributeError("operators not available as class attribute")
# Creating a cursor will initialize the operators.
instance.cursor().close()
return instance.__dict__['operators']
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'oracle'
display_name = 'Oracle'
# This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
#
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed.
data_types = {
'AutoField': 'NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'BigAutoField': 'NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'BinaryField': 'BLOB',
'BooleanField': 'NUMBER(1)',
'CharField': 'NVARCHAR2(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'INTERVAL DAY(9) TO SECOND(6)',
'FileField': 'NVARCHAR2(%(max_length)s)',
'FilePathField': 'NVARCHAR2(%(max_length)s)',
'FloatField': 'DOUBLE PRECISION',
'IntegerField': 'NUMBER(11)',
'JSONField': 'NCLOB',
'BigIntegerField': 'NUMBER(19)',
'IPAddressField': 'VARCHAR2(15)',
'GenericIPAddressField': 'VARCHAR2(39)',
'OneToOneField': 'NUMBER(11)',
'PositiveBigIntegerField': 'NUMBER(19)',
'PositiveIntegerField': 'NUMBER(11)',
'PositiveSmallIntegerField': 'NUMBER(11)',
'SlugField': 'NVARCHAR2(%(max_length)s)',
'SmallAutoField': 'NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'SmallIntegerField': 'NUMBER(11)',
'TextField': 'NCLOB',
'TimeField': 'TIMESTAMP',
'URLField': 'VARCHAR2(%(max_length)s)',
'UUIDField': 'VARCHAR2(32)',
}
data_type_check_constraints = {
'BooleanField': '%(qn_column)s IN (0,1)',
'JSONField': '%(qn_column)s IS JSON',
'PositiveBigIntegerField': '%(qn_column)s >= 0',
'PositiveIntegerField': '%(qn_column)s >= 0',
'PositiveSmallIntegerField': '%(qn_column)s >= 0',
}
# Oracle doesn't support a database index on these columns.
_limited_data_types = ('clob', 'nclob', 'blob')
operators = _UninitializedOperatorsDescriptor()
_standard_operators = {
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
}
_likec_operators = {
**_standard_operators,
'contains': "LIKEC %s ESCAPE '\\'",
'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
'startswith': "LIKEC %s ESCAPE '\\'",
'endswith': "LIKEC %s ESCAPE '\\'",
'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, %, _)
# should be escaped on the database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
_pattern_ops = {
'contains': "'%%' || {} || '%%'",
'icontains': "'%%' || UPPER({}) || '%%'",
'startswith': "{} || '%%'",
'istartswith': "UPPER({}) || '%%'",
'endswith': "'%%' || {}",
'iendswith': "'%%' || UPPER({})",
}
_standard_pattern_ops = {k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
" ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
for k, v in _pattern_ops.items()}
_likec_pattern_ops = {k: "LIKEC " + v + " ESCAPE '\\'"
for k, v in _pattern_ops.items()}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
validation_class = DatabaseValidation
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
self.features.can_return_columns_from_insert = use_returning_into
def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
del conn_params['use_returning_into']
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
return Database.connect(
user=self.settings_dict['USER'],
password=self.settings_dict['PASSWORD'],
dsn=dsn(self.settings_dict),
**conn_params,
)
def init_connection_state(self):
cursor = self.create_cursor()
# Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
# these are set in single statement it isn't clear what is supposed
# to happen.
cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
# Set Oracle date to ANSI date format. This only needs to execute
# once when we create a new connection. We also set the Territory
# to 'AMERICA' which forces Sunday to evaluate to a '1' in
# TO_CHAR().
cursor.execute(
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" +
(" TIME_ZONE = 'UTC'" if settings.USE_TZ else '')
)
cursor.close()
if 'operators' not in self.__dict__:
# Ticket #14149: Check whether our LIKE implementation will
# work for this connection or we need to fall back on LIKEC.
# This check is performed only once per DatabaseWrapper
# instance per thread, since subsequent connections will use
# the same settings.
cursor = self.create_cursor()
try:
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'],
['X'])
except Database.DatabaseError:
self.operators = self._likec_operators
self.pattern_ops = self._likec_pattern_ops
else:
self.operators = self._standard_operators
self.pattern_ops = self._standard_pattern_ops
cursor.close()
self.connection.stmtcachesize = 20
# Ensure all changes are preserved even when AUTOCOMMIT is False.
if not self.get_autocommit():
self.commit()
@async_unsafe
def create_cursor(self, name=None):
return FormatStylePlaceholderCursor(self.connection)
def _commit(self):
if self.connection is not None:
with wrap_oracle_errors():
return self.connection.commit()
# Oracle doesn't support releasing savepoints. But we fake them when query
# logging is enabled to keep query counts consistent with other backends.
def _savepoint_commit(self, sid):
if self.queries_logged:
self.queries_log.append({
'sql': '-- RELEASE SAVEPOINT %s (faked)' % self.ops.quote_name(sid),
'time': '0.000',
})
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
with self.cursor() as cursor:
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
self.connection.ping()
except Database.Error:
return False
else:
return True
@cached_property
def cx_oracle_version(self):
return tuple(int(x) for x in Database.version.split('.'))
@cached_property
def oracle_version(self):
with self.temporary_connection():
return tuple(int(x) for x in self.connection.version.split('.'))
class OracleParam:
"""
Wrapper object for formatting parameters for Oracle. If the string
representation of the value is large enough (greater than 4000 characters)
the input size needs to be set as CLOB. Alternatively, if the parameter
has an `input_size` attribute, then the value of the `input_size` attribute
will be used instead. Otherwise, no input size will be set for the
parameter when executing the query.
"""
def __init__(self, param, cursor, strings_only=False):
# With raw SQL queries, datetimes can reach this function
# without being converted by DateTimeField.get_db_prep_value.
if settings.USE_TZ and (isinstance(param, datetime.datetime) and
not isinstance(param, Oracle_datetime)):
param = Oracle_datetime.from_datetime(param)
string_size = 0
# Oracle doesn't recognize True and False correctly.
if param is True:
param = 1
elif param is False:
param = 0
if hasattr(param, 'bind_parameter'):
self.force_bytes = param.bind_parameter(cursor)
elif isinstance(param, (Database.Binary, datetime.timedelta)):
self.force_bytes = param
else:
# To transmit to the database, we need Unicode if supported
# To get size right, we must consider bytes.
self.force_bytes = force_str(param, cursor.charset, strings_only)
if isinstance(self.force_bytes, str):
# We could optimize by only converting up to 4000 bytes here
string_size = len(force_bytes(param, cursor.charset, strings_only))
if hasattr(param, 'input_size'):
# If parameter has `input_size` attribute, use that.
self.input_size = param.input_size
elif string_size > 4000:
# Mark any string param greater than 4000 characters as a CLOB.
self.input_size = Database.CLOB
elif isinstance(param, datetime.datetime):
self.input_size = Database.TIMESTAMP
else:
self.input_size = None
class VariableWrapper:
"""
An adapter class for cursor variables that prevents the wrapped object
from being converted into a string when used to instantiate an OracleParam.
This can be used generally for any other object that should be passed into
Cursor.execute as-is.
"""
def __init__(self, var):
self.var = var
def bind_parameter(self, cursor):
return self.var
def __getattr__(self, key):
return getattr(self.var, key)
def __setattr__(self, key, value):
if key == 'var':
self.__dict__[key] = value
else:
setattr(self.var, key, value)
class FormatStylePlaceholderCursor:
"""
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
style. This fixes it -- but note that if you want to use a literal "%s" in
a query, you'll need to use "%%s".
"""
charset = 'utf-8'
def __init__(self, connection):
self.cursor = connection.cursor()
self.cursor.outputtypehandler = self._output_type_handler
@staticmethod
def _output_number_converter(value):
return decimal.Decimal(value) if '.' in value else int(value)
@staticmethod
def _get_decimal_converter(precision, scale):
if scale == 0:
return int
context = decimal.Context(prec=precision)
quantize_value = decimal.Decimal(1).scaleb(-scale)
return lambda v: decimal.Decimal(v).quantize(quantize_value, context=context)
@staticmethod
def _output_type_handler(cursor, name, defaultType, length, precision, scale):
"""
Called for each db column fetched from cursors. Return numbers as the
appropriate Python type.
"""
if defaultType == Database.NUMBER:
if scale == -127:
if precision == 0:
# NUMBER column: decimal-precision floating point.
# This will normally be an integer from a sequence,
# but it could be a decimal value.
outconverter = FormatStylePlaceholderCursor._output_number_converter
else:
# FLOAT column: binary-precision floating point.
# This comes from FloatField columns.
outconverter = float
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntegerField and DecimalField columns.
outconverter = FormatStylePlaceholderCursor._get_decimal_converter(precision, scale)
else:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
# or Decimal based on whether it has a decimal point.
outconverter = FormatStylePlaceholderCursor._output_number_converter
return cursor.var(
Database.STRING,
size=255,
arraysize=cursor.arraysize,
outconverter=outconverter,
)
def _format_params(self, params):
try:
return {k: OracleParam(v, self, True) for k, v in params.items()}
except AttributeError:
return tuple(OracleParam(p, self, True) for p in params)
def _guess_input_sizes(self, params_list):
# Try dict handling; if that fails, treat as sequence
if hasattr(params_list[0], 'keys'):
sizes = {}
for params in params_list:
for k, value in params.items():
if value.input_size:
sizes[k] = value.input_size
if sizes:
self.setinputsizes(**sizes)
else:
# It's not a list of dicts; it's a list of sequences
sizes = [None] * len(params_list[0])
for params in params_list:
for i, value in enumerate(params):
if value.input_size:
sizes[i] = value.input_size
if sizes:
self.setinputsizes(*sizes)
def _param_generator(self, params):
# Try dict handling; if that fails, treat as sequence
if hasattr(params, 'items'):
return {k: v.force_bytes for k, v in params.items()}
else:
return [p.force_bytes for p in params]
def _fix_for_params(self, query, params, unify_by_values=False):
# cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
# it does want a trailing ';' but not a trailing '/'. However, these
# characters must be included in the original query in case the query
# is being passed to SQL*Plus.
if query.endswith(';') or query.endswith('/'):
query = query[:-1]
if params is None:
params = []
elif hasattr(params, 'keys'):
# Handle params as dict
args = {k: ":%s" % k for k in params}
query = query % args
elif unify_by_values and params:
# Handle params as a dict with unified query parameters by their
# values. It can be used only in single query execute() because
# executemany() shares the formatted query with each of the params
# list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75]
# params_dict = {0.75: ':arg0', 2: ':arg1', 'sth': ':arg2'}
# args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
# params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
params_dict = {
param: ':arg%d' % i
for i, param in enumerate(dict.fromkeys(params))
}
args = [params_dict[param] for param in params]
params = {value: key for key, value in params_dict.items()}
query = query % tuple(args)
else:
# Handle params as sequence
args = [(':arg%d' % i) for i in range(len(params))]
query = query % tuple(args)
return query, self._format_params(params)
def execute(self, query, params=None):
query, params = self._fix_for_params(query, params, unify_by_values=True)
self._guess_input_sizes([params])
with wrap_oracle_errors():
return self.cursor.execute(query, self._param_generator(params))
def executemany(self, query, params=None):
if not params:
# No params given, nothing to do
return None
# uniform treatment for sequences and iterables
params_iter = iter(params)
query, firstparams = self._fix_for_params(query, next(params_iter))
# we build a list of formatted params; as we're going to traverse it
# more than once, we can't make it lazy by using a generator
formatted = [firstparams] + [self._format_params(p) for p in params_iter]
self._guess_input_sizes(formatted)
with wrap_oracle_errors():
return self.cursor.executemany(query, [self._param_generator(p) for p in formatted])
def close(self):
try:
self.cursor.close()
except Database.InterfaceError:
# already closed
pass
def var(self, *args):
return VariableWrapper(self.cursor.var(*args))
def arrayvar(self, *args):
return VariableWrapper(self.cursor.arrayvar(*args))
def __getattr__(self, attr):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)

View File

@@ -0,0 +1,27 @@
import shutil
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlplus'
wrapper_name = 'rlwrap'
@staticmethod
def connect_string(settings_dict):
from django.db.backends.oracle.utils import dsn
return '%s/"%s"@%s' % (
settings_dict['USER'],
settings_dict['PASSWORD'],
dsn(settings_dict),
)
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, '-L', cls.connect_string(settings_dict)]
wrapper_path = shutil.which(cls.wrapper_name)
if wrapper_path:
args = [wrapper_path, *args]
args.extend(parameters)
return args, None

View File

@@ -0,0 +1,400 @@
import sys
from django.conf import settings
from django.db import DatabaseError
from django.db.backends.base.creation import BaseDatabaseCreation
from django.utils.crypto import get_random_string
from django.utils.functional import cached_property
TEST_DATABASE_PREFIX = 'test_'
class DatabaseCreation(BaseDatabaseCreation):
@cached_property
def _maindb_connection(self):
"""
This is analogous to other backends' `_nodb_connection` property,
which allows access to an "administrative" connection which can
be used to manage the test databases.
For Oracle, the only connection that can be used for that purpose
is the main (non-test) connection.
"""
settings_dict = settings.DATABASES[self.connection.alias]
user = settings_dict.get('SAVED_USER') or settings_dict['USER']
password = settings_dict.get('SAVED_PASSWORD') or settings_dict['PASSWORD']
settings_dict = {**settings_dict, 'USER': user, 'PASSWORD': password}
DatabaseWrapper = type(self.connection)
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
parameters = self._get_test_db_params()
with self._maindb_connection.cursor() as cursor:
if self._test_database_create():
try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e:
if 'ORA-01543' not in str(e):
# All errors except "tablespace already exists" cancel tests
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test database, %s, already exists. "
"Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
if verbosity >= 1:
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
try:
self._execute_test_db_destruction(cursor, parameters, verbosity)
except DatabaseError as e:
if 'ORA-29857' in str(e):
self._handle_objects_preventing_db_destruction(cursor, parameters,
verbosity, autoclobber)
else:
# Ran into a database error that isn't about leftover objects in the tablespace
self.log('Got an error destroying the old test database: %s' % e)
sys.exit(2)
except Exception as e:
self.log('Got an error destroying the old test database: %s' % e)
sys.exit(2)
try:
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
if self._test_user_create():
if verbosity >= 1:
self.log('Creating test user...')
try:
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
if 'ORA-01920' not in str(e):
# All errors except "user already exists" cancel tests
self.log('Got an error creating the test user: %s' % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log('Destroying old test user...')
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
self.log('Creating test user...')
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
self.log('Got an error recreating the test user: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters)
return self.connection.settings_dict['NAME']
def _switch_to_test_user(self, parameters):
"""
Switch to the user that's used for creating the test database.
Oracle doesn't have the concept of separate databases under the same
user, so a separate user is used; see _create_test_db(). The main user
is also needed for cleanup when testing is completed, so save its
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
"""
real_settings = settings.DATABASES[self.connection.alias]
real_settings['SAVED_USER'] = self.connection.settings_dict['SAVED_USER'] = \
self.connection.settings_dict['USER']
real_settings['SAVED_PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] = \
self.connection.settings_dict['PASSWORD']
real_test_settings = real_settings['TEST']
test_settings = self.connection.settings_dict['TEST']
real_test_settings['USER'] = real_settings['USER'] = test_settings['USER'] = \
self.connection.settings_dict['USER'] = parameters['user']
real_settings['PASSWORD'] = self.connection.settings_dict['PASSWORD'] = parameters['password']
def set_as_test_mirror(self, primary_settings_dict):
"""
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
self.connection.settings_dict['USER'] = primary_settings_dict['USER']
self.connection.settings_dict['PASSWORD'] = primary_settings_dict['PASSWORD']
def _handle_objects_preventing_db_destruction(self, cursor, parameters, verbosity, autoclobber):
# There are objects in the test tablespace which prevent dropping it
# The easy fix is to drop the test user -- but are we allowed to do so?
self.log(
'There are objects in the old test database which prevent its destruction.\n'
'If they belong to the test user, deleting the user will allow the test '
'database to be recreated.\n'
'Otherwise, you will need to find and remove each of these objects, '
'or use a different tablespace.\n'
)
if self._test_user_create():
if not autoclobber:
confirm = input("Type 'yes' to delete user %s: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log('Destroying old test user...')
self._destroy_test_user(cursor, parameters, verbosity)
except Exception as e:
self.log('Got an error destroying the test user: %s' % e)
sys.exit(2)
try:
if verbosity >= 1:
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
self._execute_test_db_destruction(cursor, parameters, verbosity)
except Exception as e:
self.log('Got an error destroying the test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled -- test database cannot be recreated.')
sys.exit(1)
else:
self.log("Django is configured to use pre-existing test user '%s',"
" and will not attempt to delete it." % parameters['user'])
self.log('Tests cancelled -- test database cannot be recreated.')
sys.exit(1)
def _destroy_test_db(self, test_database_name, verbosity=1):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
"""
self.connection.settings_dict['USER'] = self.connection.settings_dict['SAVED_USER']
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close()
parameters = self._get_test_db_params()
with self._maindb_connection.cursor() as cursor:
if self._test_user_create():
if verbosity >= 1:
self.log('Destroying test user...')
self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create():
if verbosity >= 1:
self.log('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity)
self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
self.log('_create_test_db(): dbname = %s' % parameters['user'])
if self._test_database_oracle_managed_files():
statements = [
"""
CREATE TABLESPACE %(tblspace)s
DATAFILE SIZE %(size)s
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
""",
"""
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
TEMPFILE SIZE %(size_tmp)s
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
""",
]
else:
statements = [
"""
CREATE TABLESPACE %(tblspace)s
DATAFILE '%(datafile)s' SIZE %(size)s REUSE
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
""",
"""
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
TEMPFILE '%(datafile_tmp)s' SIZE %(size_tmp)s REUSE
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
""",
]
# Ignore "tablespace already exists" error when keepdb is on.
acceptable_ora_err = 'ORA-01543' if keepdb else None
self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
self.log('_create_test_user(): username = %s' % parameters['user'])
statements = [
"""CREATE USER %(user)s
IDENTIFIED BY "%(password)s"
DEFAULT TABLESPACE %(tblspace)s
TEMPORARY TABLESPACE %(tblspace_temp)s
QUOTA UNLIMITED ON %(tblspace)s
""",
"""GRANT CREATE SESSION,
CREATE TABLE,
CREATE SEQUENCE,
CREATE PROCEDURE,
CREATE TRIGGER
TO %(user)s""",
]
# Ignore "user already exists" error when keepdb is on
acceptable_ora_err = 'ORA-01920' if keepdb else None
success = self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
# If the password was randomly generated, change the user accordingly.
if not success and self._test_settings_get('PASSWORD') is None:
set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
self._execute_statements(cursor, [set_password], parameters, verbosity)
# Most test suites can be run without "create view" and
# "create materialized view" privileges. But some need it.
for object_type in ('VIEW', 'MATERIALIZED VIEW'):
extra = 'GRANT CREATE %(object_type)s TO %(user)s'
parameters['object_type'] = object_type
success = self._execute_allow_fail_statements(cursor, [extra], parameters, verbosity, 'ORA-01031')
if not success and verbosity >= 2:
self.log('Failed to grant CREATE %s permission to test user. This may be ok.' % object_type)
def _execute_test_db_destruction(self, cursor, parameters, verbosity):
if verbosity >= 2:
self.log('_execute_test_db_destruction(): dbname=%s' % parameters['user'])
statements = [
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(self, cursor, parameters, verbosity):
if verbosity >= 2:
self.log('_destroy_test_user(): user=%s' % parameters['user'])
self.log('Be patient. This can take some time...')
statements = [
'DROP USER %(user)s CASCADE',
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _execute_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False):
for template in statements:
stmt = template % parameters
if verbosity >= 2:
print(stmt)
try:
cursor.execute(stmt)
except Exception as err:
if (not allow_quiet_fail) or verbosity >= 2:
self.log('Failed (%s)' % (err))
raise
def _execute_allow_fail_statements(self, cursor, statements, parameters, verbosity, acceptable_ora_err):
"""
Execute statements which are allowed to fail silently if the Oracle
error code given by `acceptable_ora_err` is raised. Return True if the
statements execute without an exception, or False otherwise.
"""
try:
# Statement can fail when acceptable_ora_err is not None
allow_quiet_fail = acceptable_ora_err is not None and len(acceptable_ora_err) > 0
self._execute_statements(cursor, statements, parameters, verbosity, allow_quiet_fail=allow_quiet_fail)
return True
except DatabaseError as err:
description = str(err)
if acceptable_ora_err is None or acceptable_ora_err not in description:
raise
return False
def _get_test_db_params(self):
return {
'dbname': self._test_database_name(),
'user': self._test_database_user(),
'password': self._test_database_passwd(),
'tblspace': self._test_database_tblspace(),
'tblspace_temp': self._test_database_tblspace_tmp(),
'datafile': self._test_database_tblspace_datafile(),
'datafile_tmp': self._test_database_tblspace_tmp_datafile(),
'maxsize': self._test_database_tblspace_maxsize(),
'maxsize_tmp': self._test_database_tblspace_tmp_maxsize(),
'size': self._test_database_tblspace_size(),
'size_tmp': self._test_database_tblspace_tmp_size(),
'extsize': self._test_database_tblspace_extsize(),
'extsize_tmp': self._test_database_tblspace_tmp_extsize(),
}
def _test_settings_get(self, key, default=None, prefixed=None):
"""
Return a value from the test settings dict, or a given default, or a
prefixed entry from the main settings dict.
"""
settings_dict = self.connection.settings_dict
val = settings_dict['TEST'].get(key, default)
if val is None and prefixed:
val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
return val
def _test_database_name(self):
return self._test_settings_get('NAME', prefixed='NAME')
def _test_database_create(self):
return self._test_settings_get('CREATE_DB', default=True)
def _test_user_create(self):
return self._test_settings_get('CREATE_USER', default=True)
def _test_database_user(self):
return self._test_settings_get('USER', prefixed='USER')
def _test_database_passwd(self):
password = self._test_settings_get('PASSWORD')
if password is None and self._test_user_create():
# Oracle passwords are limited to 30 chars and can't contain symbols.
password = get_random_string(30)
return password
def _test_database_tblspace(self):
return self._test_settings_get('TBLSPACE', prefixed='USER')
def _test_database_tblspace_tmp(self):
settings_dict = self.connection.settings_dict
return settings_dict['TEST'].get('TBLSPACE_TMP',
TEST_DATABASE_PREFIX + settings_dict['USER'] + '_temp')
def _test_database_tblspace_datafile(self):
tblspace = '%s.dbf' % self._test_database_tblspace()
return self._test_settings_get('DATAFILE', default=tblspace)
def _test_database_tblspace_tmp_datafile(self):
tblspace = '%s.dbf' % self._test_database_tblspace_tmp()
return self._test_settings_get('DATAFILE_TMP', default=tblspace)
def _test_database_tblspace_maxsize(self):
return self._test_settings_get('DATAFILE_MAXSIZE', default='500M')
def _test_database_tblspace_tmp_maxsize(self):
return self._test_settings_get('DATAFILE_TMP_MAXSIZE', default='500M')
def _test_database_tblspace_size(self):
return self._test_settings_get('DATAFILE_SIZE', default='50M')
def _test_database_tblspace_tmp_size(self):
return self._test_settings_get('DATAFILE_TMP_SIZE', default='50M')
def _test_database_tblspace_extsize(self):
return self._test_settings_get('DATAFILE_EXTSIZE', default='25M')
def _test_database_tblspace_tmp_extsize(self):
return self._test_settings_get('DATAFILE_TMP_EXTSIZE', default='25M')
def _test_database_oracle_managed_files(self):
return self._test_settings_get('ORACLE_MANAGED_FILES', default=False)
def _get_test_db_name(self):
"""
Return the 'production' DB name to get the test DB creation machinery
to work. This isn't a great deal in this case because DB names as
handled by Django don't have real counterparts in Oracle.
"""
return self.connection.settings_dict['NAME']
def test_db_signature(self):
settings_dict = self.connection.settings_dict
return (
settings_dict['HOST'],
settings_dict['PORT'],
settings_dict['ENGINE'],
settings_dict['NAME'],
self._test_database_user(),
)

View File

@@ -0,0 +1,120 @@
from django.db import DatabaseError, InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
# BLOB" when grouping by LOBs (#24096).
allows_group_by_lob = False
interprets_empty_strings_as_nulls = True
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_skip_locked = True
has_select_for_update_of = True
select_for_update_of_column = True
can_return_columns_from_insert = True
supports_subqueries_in_group_by = False
ignores_unnecessary_order_by_in_subqueries = False
supports_transactions = True
supports_timezones = False
has_native_duration_field = True
can_defer_constraint_checks = True
supports_partially_nullable_unique_constraints = False
supports_deferrable_unique_constraints = True
truncates_names = True
supports_tablespaces = True
supports_sequence_reset = False
can_introspect_materialized_views = True
atomic_transactions = False
supports_combined_alters = False
nulls_order_largest = True
requires_literal_defaults = True
closed_cursor_error_class = InterfaceError
bare_select_suffix = " FROM DUAL"
# select for update with limit can be achieved on Oracle, but not with the current backend.
supports_select_for_update_with_limit = False
supports_temporal_subtraction = True
# Oracle doesn't ignore quoted identifiers case but the current backend
# does by uppercasing all identifiers.
ignores_table_name_case = True
supports_index_on_text_field = False
has_case_insensitive_like = False
create_test_procedure_without_params_sql = """
CREATE PROCEDURE "TEST_PROCEDURE" AS
V_I INTEGER;
BEGIN
V_I := 1;
END;
"""
create_test_procedure_with_int_param_sql = """
CREATE PROCEDURE "TEST_PROCEDURE" (P_I INTEGER) AS
V_I INTEGER;
BEGIN
V_I := P_I;
END;
"""
supports_callproc_kwargs = True
supports_over_clause = True
supports_frame_range_fixed_distance = True
supports_ignore_conflicts = False
max_query_params = 2**16 - 1
supports_partial_indexes = False
supports_slicing_ordering_in_compound = True
allows_multiple_constraints_on_same_fields = False
supports_boolean_expr_in_select_clause = False
supports_primitives_in_json_field = False
supports_json_field_contains = False
supports_collation_on_textfield = False
test_collations = {
'ci': 'BINARY_CI',
'cs': 'BINARY',
'non_default': 'SWEDISH_CI',
'swedish_ci': 'SWEDISH_CI',
}
test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"
django_test_skips = {
"Oracle doesn't support SHA224.": {
'db_functions.text.test_sha224.SHA224Tests.test_basic',
'db_functions.text.test_sha224.SHA224Tests.test_transform',
},
"Oracle doesn't support bitwise XOR.": {
'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor',
'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null',
},
"Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": {
'expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering',
},
'Raises ORA-00600: internal error code.': {
'model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery',
},
}
django_test_expected_failures = {
# A bug in Django/cx_Oracle with respect to string handling (#23843).
'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions',
'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions',
}
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
'GenericIPAddressField': 'CharField',
'PositiveBigIntegerField': 'BigIntegerField',
'PositiveIntegerField': 'IntegerField',
'PositiveSmallIntegerField': 'IntegerField',
'SmallIntegerField': 'IntegerField',
'TimeField': 'DateTimeField',
}
@cached_property
def supports_collation_on_charfield(self):
with self.connection.cursor() as cursor:
try:
cursor.execute("SELECT CAST('a' AS VARCHAR2(4001)) FROM dual")
except DatabaseError as e:
if e.args[0].code == 910:
return False
raise
return True

View File

@@ -0,0 +1,22 @@
from django.db.models import DecimalField, DurationField, Func
class IntervalToSeconds(Func):
function = ''
template = """
EXTRACT(day from %(expressions)s) * 86400 +
EXTRACT(hour from %(expressions)s) * 3600 +
EXTRACT(minute from %(expressions)s) * 60 +
EXTRACT(second from %(expressions)s)
"""
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(expression, output_field=output_field or DecimalField(), **extra)
class SecondsToInterval(Func):
function = 'NUMTODSINTERVAL'
template = "%(function)s(%(expressions)s, 'SECOND')"
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(expression, output_field=output_field or DurationField(), **extra)

View File

@@ -0,0 +1,336 @@
from collections import namedtuple
import cx_Oracle
from django.db import models
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.utils.functional import cached_property
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json'))
class DatabaseIntrospection(BaseDatabaseIntrospection):
cache_bust_counter = 1
# Maps type objects to Django Field types.
@cached_property
def data_types_reverse(self):
if self.connection.cx_oracle_version < (8,):
return {
cx_Oracle.BLOB: 'BinaryField',
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.FIXED_NCHAR: 'CharField',
cx_Oracle.INTERVAL: 'DurationField',
cx_Oracle.NATIVE_FLOAT: 'FloatField',
cx_Oracle.NCHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}
else:
return {
cx_Oracle.DB_TYPE_DATE: 'DateField',
cx_Oracle.DB_TYPE_BINARY_DOUBLE: 'FloatField',
cx_Oracle.DB_TYPE_BLOB: 'BinaryField',
cx_Oracle.DB_TYPE_CHAR: 'CharField',
cx_Oracle.DB_TYPE_CLOB: 'TextField',
cx_Oracle.DB_TYPE_INTERVAL_DS: 'DurationField',
cx_Oracle.DB_TYPE_NCHAR: 'CharField',
cx_Oracle.DB_TYPE_NCLOB: 'TextField',
cx_Oracle.DB_TYPE_NVARCHAR: 'CharField',
cx_Oracle.DB_TYPE_NUMBER: 'DecimalField',
cx_Oracle.DB_TYPE_TIMESTAMP: 'DateTimeField',
cx_Oracle.DB_TYPE_VARCHAR: 'CharField',
}
def get_field_type(self, data_type, description):
if data_type == cx_Oracle.NUMBER:
precision, scale = description[4:6]
if scale == 0:
if precision > 11:
return 'BigAutoField' if description.is_autofield else 'BigIntegerField'
elif 1 < precision < 6 and description.is_autofield:
return 'SmallAutoField'
elif precision == 1:
return 'BooleanField'
elif description.is_autofield:
return 'AutoField'
else:
return 'IntegerField'
elif scale == -127:
return 'FloatField'
elif data_type == cx_Oracle.NCLOB and description.is_json:
return 'JSONField'
return super().get_field_type(data_type, description)
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("""
SELECT table_name, 't'
FROM user_tables
WHERE
NOT EXISTS (
SELECT 1
FROM user_mviews
WHERE user_mviews.mview_name = user_tables.table_name
)
UNION ALL
SELECT view_name, 'v' FROM user_views
UNION ALL
SELECT mview_name, 'v' FROM user_mviews
""")
return [TableInfo(self.identifier_converter(row[0]), row[1]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
# user_tab_columns gives data default for columns
cursor.execute("""
SELECT
user_tab_cols.column_name,
user_tab_cols.data_default,
CASE
WHEN user_tab_cols.collation = user_tables.default_collation
THEN NULL
ELSE user_tab_cols.collation
END collation,
CASE
WHEN user_tab_cols.char_used IS NULL
THEN user_tab_cols.data_length
ELSE user_tab_cols.char_length
END as internal_size,
CASE
WHEN user_tab_cols.identity_column = 'YES' THEN 1
ELSE 0
END as is_autofield,
CASE
WHEN EXISTS (
SELECT 1
FROM user_json_columns
WHERE
user_json_columns.table_name = user_tab_cols.table_name AND
user_json_columns.column_name = user_tab_cols.column_name
)
THEN 1
ELSE 0
END as is_json
FROM user_tab_cols
LEFT OUTER JOIN
user_tables ON user_tables.table_name = user_tab_cols.table_name
WHERE user_tab_cols.table_name = UPPER(%s)
""", [table_name])
field_map = {
column: (internal_size, default if default != 'NULL' else None, collation, is_autofield, is_json)
for column, default, collation, internal_size, is_autofield, is_json in cursor.fetchall()
}
self.cache_bust_counter += 1
cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
self.connection.ops.quote_name(table_name),
self.cache_bust_counter))
description = []
for desc in cursor.description:
name = desc[0]
internal_size, default, collation, is_autofield, is_json = field_map[name]
name = name % {} # cx_Oracle, for some reason, doubles percent signs.
description.append(FieldInfo(
self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0,
desc[5] or 0, *desc[6:], default, collation, is_autofield, is_json,
))
return description
def identifier_converter(self, name):
"""Identifier comparison is case insensitive under Oracle."""
return name.lower()
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute("""
SELECT
user_tab_identity_cols.sequence_name,
user_tab_identity_cols.column_name
FROM
user_tab_identity_cols,
user_constraints,
user_cons_columns cols
WHERE
user_constraints.constraint_name = cols.constraint_name
AND user_constraints.table_name = user_tab_identity_cols.table_name
AND cols.column_name = user_tab_identity_cols.column_name
AND user_constraints.constraint_type = 'P'
AND user_tab_identity_cols.table_name = UPPER(%s)
""", [table_name])
# Oracle allows only one identity column per table.
row = cursor.fetchone()
if row:
return [{
'name': self.identifier_converter(row[0]),
'table': self.identifier_converter(table_name),
'column': self.identifier_converter(row[1]),
}]
# To keep backward compatibility for AutoFields that aren't Oracle
# identity columns.
for f in table_fields:
if isinstance(f, models.AutoField):
return [{'table': table_name, 'column': f.column}]
return []
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
table_name = table_name.upper()
cursor.execute("""
SELECT ca.column_name, cb.table_name, cb.column_name
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
WHERE user_constraints.table_name = %s AND
user_constraints.constraint_name = ca.constraint_name AND
user_constraints.r_constraint_name = cb.constraint_name AND
ca.position = cb.position""", [table_name])
return {
self.identifier_converter(field_name): (
self.identifier_converter(rel_field_name),
self.identifier_converter(rel_table_name),
) for field_name, rel_table_name, rel_field_name in cursor.fetchall()
}
def get_key_columns(self, cursor, table_name):
cursor.execute("""
SELECT ccol.column_name, rcol.table_name AS referenced_table, rcol.column_name AS referenced_column
FROM user_constraints c
JOIN user_cons_columns ccol
ON ccol.constraint_name = c.constraint_name
JOIN user_cons_columns rcol
ON rcol.constraint_name = c.r_constraint_name
WHERE c.table_name = %s AND c.constraint_type = 'R'""", [table_name.upper()])
return [
tuple(self.identifier_converter(cell) for cell in row)
for row in cursor.fetchall()
]
def get_primary_key_column(self, cursor, table_name):
cursor.execute("""
SELECT
cols.column_name
FROM
user_constraints,
user_cons_columns cols
WHERE
user_constraints.constraint_name = cols.constraint_name AND
user_constraints.constraint_type = 'P' AND
user_constraints.table_name = UPPER(%s) AND
cols.position = 1
""", [table_name])
row = cursor.fetchone()
return self.identifier_converter(row[0]) if row else None
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Loop over the constraints, getting PKs, uniques, and checks
cursor.execute("""
SELECT
user_constraints.constraint_name,
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
CASE user_constraints.constraint_type
WHEN 'P' THEN 1
ELSE 0
END AS is_primary_key,
CASE
WHEN user_constraints.constraint_type IN ('P', 'U') THEN 1
ELSE 0
END AS is_unique,
CASE user_constraints.constraint_type
WHEN 'C' THEN 1
ELSE 0
END AS is_check_constraint
FROM
user_constraints
LEFT OUTER JOIN
user_cons_columns cols ON user_constraints.constraint_name = cols.constraint_name
WHERE
user_constraints.constraint_type = ANY('P', 'U', 'C')
AND user_constraints.table_name = UPPER(%s)
GROUP BY user_constraints.constraint_name, user_constraints.constraint_type
""", [table_name])
for constraint, columns, pk, unique, check in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
'columns': columns.split(','),
'primary_key': pk,
'unique': unique,
'foreign_key': None,
'check': check,
'index': unique, # All uniques come with an index
}
# Foreign key constraints
cursor.execute("""
SELECT
cons.constraint_name,
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
LOWER(rcols.table_name),
LOWER(rcols.column_name)
FROM
user_constraints cons
INNER JOIN
user_cons_columns rcols ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
LEFT OUTER JOIN
user_cons_columns cols ON cons.constraint_name = cols.constraint_name
WHERE
cons.constraint_type = 'R' AND
cons.table_name = UPPER(%s)
GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name
""", [table_name])
for constraint, columns, other_table, other_column in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
'primary_key': False,
'unique': False,
'foreign_key': (other_table, other_column),
'check': False,
'index': False,
'columns': columns.split(','),
}
# Now get indexes
cursor.execute("""
SELECT
ind.index_name,
LOWER(ind.index_type),
LOWER(ind.uniqueness),
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.column_position),
LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
FROM
user_ind_columns cols, user_indexes ind
WHERE
cols.table_name = UPPER(%s) AND
NOT EXISTS (
SELECT 1
FROM user_constraints cons
WHERE ind.index_name = cons.index_name
) AND cols.index_name = ind.index_name
GROUP BY ind.index_name, ind.index_type, ind.uniqueness
""", [table_name])
for constraint, type_, unique, columns, orders in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
'primary_key': False,
'unique': unique == 'unique',
'foreign_key': None,
'check': False,
'index': True,
'type': 'idx' if type_ == 'normal' else type_,
'columns': columns.split(','),
'orders': orders.split(','),
}
return constraints

View File

@@ -0,0 +1,647 @@
import datetime
import uuid
from functools import lru_cache
from django.conf import settings
from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import (
split_tzname_delta, strip_quotes, truncate_name,
)
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile
from .base import Database
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations):
# Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
# SmallAutoField, to preserve backward compatibility.
integer_field_ranges = {
'SmallIntegerField': (-99999999999, 99999999999),
'IntegerField': (-99999999999, 99999999999),
'BigIntegerField': (-9999999999999999999, 9999999999999999999),
'PositiveBigIntegerField': (0, 9999999999999999999),
'PositiveSmallIntegerField': (0, 99999999999),
'PositiveIntegerField': (0, 99999999999),
'SmallAutoField': (-99999, 99999),
'AutoField': (-99999999999, 99999999999),
'BigAutoField': (-9999999999999999999, 9999999999999999999),
}
set_operators = {**BaseDatabaseOperations.set_operators, 'difference': 'MINUS'}
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
_sequence_reset_sql = """
DECLARE
table_value integer;
seq_value integer;
seq_name user_tab_identity_cols.sequence_name%%TYPE;
BEGIN
BEGIN
SELECT sequence_name INTO seq_name FROM user_tab_identity_cols
WHERE table_name = '%(table_name)s' AND
column_name = '%(column_name)s';
EXCEPTION WHEN NO_DATA_FOUND THEN
seq_name := '%(no_autofield_sequence_name)s';
END;
SELECT NVL(MAX(%(column)s), 0) INTO table_value FROM %(table)s;
SELECT NVL(last_number - cache_size, 0) INTO seq_value FROM user_sequences
WHERE sequence_name = seq_name;
WHILE table_value > seq_value LOOP
EXECUTE IMMEDIATE 'SELECT "'||seq_name||'".nextval FROM DUAL'
INTO seq_value;
END LOOP;
END;
/"""
# Oracle doesn't support string without precision; use the max string size.
cast_char_field_without_max_length = 'NVARCHAR2(2000)'
cast_data_types = {
'AutoField': 'NUMBER(11)',
'BigAutoField': 'NUMBER(19)',
'SmallAutoField': 'NUMBER(5)',
'TextField': cast_char_field_without_max_length,
}
def cache_key_culling_sql(self):
return 'SELECT cache_key FROM %s ORDER BY cache_key OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY'
def date_extract_sql(self, lookup_type, field_name):
if lookup_type == 'week_day':
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
return "TO_CHAR(%s, 'D')" % field_name
elif lookup_type == 'iso_week_day':
return "TO_CHAR(%s - 1, 'D')" % field_name
elif lookup_type == 'week':
# IW = ISO week number
return "TO_CHAR(%s, 'IW')" % field_name
elif lookup_type == 'quarter':
return "TO_CHAR(%s, 'Q')" % field_name
elif lookup_type == 'iso_year':
return "TO_CHAR(%s, 'IYYY')" % field_name
else:
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
elif lookup_type == 'quarter':
return "TRUNC(%s, 'Q')" % field_name
elif lookup_type == 'week':
return "TRUNC(%s, 'IW')" % field_name
else:
return "TRUNC(%s)" % field_name
# Oracle crashes with "ORA-03113: end-of-file on communication channel"
# if the time zone name is passed in parameter. Use interpolation instead.
# https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
# This regexp matches all time zone names from the zoneinfo database.
_tzname_re = _lazy_re_compile(r'^[\w/:+-]+$')
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f'{sign}{offset}' if offset else tzname
def _convert_field_to_tz(self, field_name, tzname):
if not (settings.USE_TZ and tzname):
return field_name
if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname)
# Convert from connection timezone to the local time, returning
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
# TIME ZONE details.
if self.connection.timezone_name != tzname:
return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % (
field_name,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return 'TRUNC(%s)' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
# Since `TimeField` values are stored as TIMESTAMP change to the
# default date and convert the field to the specified timezone.
convert_datetime_sql = (
"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), "
"'YYYY-MM-DD HH24:MI:SS.FF')"
) % self._convert_field_to_tz(field_name, tzname)
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % (
field_name, convert_datetime_sql,
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
elif lookup_type == 'quarter':
sql = "TRUNC(%s, 'Q')" % field_name
elif lookup_type == 'week':
sql = "TRUNC(%s, 'IW')" % field_name
elif lookup_type == 'day':
sql = "TRUNC(%s)" % field_name
elif lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == 'minute':
sql = "TRUNC(%s, 'MI')" % field_name
else:
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
# The implementation is similar to `datetime_trunc_sql` as both
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == 'minute':
sql = "TRUNC(%s, 'MI')" % field_name
elif lookup_type == 'second':
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type in ['JSONField', 'TextField']:
converters.append(self.convert_textfield_value)
elif internal_type == 'BinaryField':
converters.append(self.convert_binaryfield_value)
elif internal_type == 'BooleanField':
converters.append(self.convert_booleanfield_value)
elif internal_type == 'DateTimeField':
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
# Oracle stores empty strings as null. If the field accepts the empty
# string, undo this to adhere to the Django convention of using
# the empty string instead of null.
if expression.output_field.empty_strings_allowed:
converters.append(
self.convert_empty_bytes
if internal_type == 'BinaryField' else
self.convert_empty_string
)
return converters
def convert_textfield_value(self, value, expression, connection):
if isinstance(value, Database.LOB):
value = value.read()
return value
def convert_binaryfield_value(self, value, expression, connection):
if isinstance(value, Database.LOB):
value = force_bytes(value.read())
return value
def convert_booleanfield_value(self, value, expression, connection):
if value in (0, 1):
value = bool(value)
return value
# cx_Oracle always returns datetime.datetime objects for
# DATE and TIMESTAMP columns, but Django wants to see a
# python datetime.date, .time, or .datetime.
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_datefield_value(self, value, expression, connection):
if isinstance(value, Database.Timestamp):
value = value.date()
return value
def convert_timefield_value(self, value, expression, connection):
if isinstance(value, Database.Timestamp):
value = value.time()
return value
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
@staticmethod
def convert_empty_string(value, expression, connection):
return '' if value is None else value
@staticmethod
def convert_empty_bytes(value, expression, connection):
return b'' if value is None else value
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_columns(self, cursor, returning_params):
columns = []
for param in returning_params:
value = param.get_value()
if value == []:
raise DatabaseError(
'The database did not return a new row id. Probably '
'"ORA-1403: no data found" was raised internally but was '
'hidden by the Oracle OCI library (see '
'https://code.djangoproject.com/ticket/28859).'
)
columns.append(value[0])
return tuple(columns)
def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith('LOB') and internal_type != 'JSONField':
return "DBMS_LOB.SUBSTR(%s)"
else:
return "%s"
def no_limit_value(self):
return None
def limit_offset_sql(self, low_mark, high_mark):
fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
return ' '.join(sql for sql in (
('OFFSET %d ROWS' % offset) if offset else None,
('FETCH FIRST %d ROWS ONLY' % fetch) if fetch else None,
) if sql)
def last_executed_query(self, cursor, sql, params):
# https://cx-oracle.readthedocs.io/en/latest/cursor.html#Cursor.statement
# The DB API definition does not define this attribute.
statement = cursor.statement
# Unlike Psycopg's `query` and MySQLdb`'s `_executed`, cx_Oracle's
# `statement` doesn't contain the query parameters. Substitute
# parameters manually.
if isinstance(params, (tuple, list)):
for i, param in enumerate(params):
statement = statement.replace(':arg%d' % i, force_str(param, errors='replace'))
elif isinstance(params, dict):
for key, param in params.items():
statement = statement.replace(':%s' % key, force_str(param, errors='replace'))
return statement
def last_insert_id(self, cursor, table_name, pk_name):
sq_name = self._get_sequence_name(cursor, strip_quotes(table_name), pk_name)
cursor.execute('"%s".currval' % sq_name)
return cursor.fetchone()[0]
def lookup_cast(self, lookup_type, internal_type=None):
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
return "UPPER(%s)"
if internal_type == 'JSONField' and lookup_type == 'exact':
return 'DBMS_LOB.SUBSTR(%s)'
return "%s"
def max_in_list_size(self):
return 1000
def max_name_length(self):
return 30
def pk_default_value(self):
return "NULL"
def prep_for_iexact_query(self, x):
return x
def process_clob(self, value):
if value is None:
return ''
return value.read()
def quote_name(self, name):
# SQL92 requires delimited (quoted) names to be case-sensitive. When
# not quoted, Oracle has case-insensitive behavior for identifiers, but
# always defaults to uppercase.
# We simplify things by making Oracle identifiers always uppercase.
if not name.startswith('"') and not name.endswith('"'):
name = '"%s"' % truncate_name(name, self.max_name_length())
# Oracle puts the query text into a (query % args) construct, so % signs
# in names need to be escaped. The '%%' will be collapsed back to '%' at
# that stage so we aren't really making the name longer here.
name = name.replace('%', '%%')
return name.upper()
def regex_lookup(self, lookup_type):
if lookup_type == 'regex':
match_option = "'c'"
else:
match_option = "'i'"
return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
def return_insert_columns(self, fields):
if not fields:
return '', ()
field_names = []
params = []
for field in fields:
field_names.append('%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
))
params.append(InsertVar(field))
return 'RETURNING %s INTO %s' % (
', '.join(field_names),
', '.join(['%s'] * len(params)),
), tuple(params)
def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor:
if recursive:
cursor.execute("""
SELECT
user_tables.table_name, rcons.constraint_name
FROM
user_tables
JOIN
user_constraints cons
ON (user_tables.table_name = cons.table_name AND cons.constraint_type = ANY('P', 'U'))
LEFT JOIN
user_constraints rcons
ON (user_tables.table_name = rcons.table_name AND rcons.constraint_type = 'R')
START WITH user_tables.table_name = UPPER(%s)
CONNECT BY NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
GROUP BY
user_tables.table_name, rcons.constraint_name
HAVING user_tables.table_name != UPPER(%s)
ORDER BY MAX(level) DESC
""", (table_name, table_name))
else:
cursor.execute("""
SELECT
cons.table_name, cons.constraint_name
FROM
user_constraints cons
WHERE
cons.constraint_type = 'R'
AND cons.table_name = UPPER(%s)
""", (table_name,))
return cursor.fetchall()
@cached_property
def _foreign_key_constraints(self):
# 512 is large enough to fit the ~330 tables (as of this writing) in
# Django's test suite.
return lru_cache(maxsize=512)(self.__foreign_key_constraints)
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
truncated_tables = {table.upper() for table in tables}
constraints = set()
# Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE foreign
# keys which Django doesn't define. Emulate the PostgreSQL behavior
# which truncates all dependent tables by manually retrieving all
# foreign key constraints and resolving dependencies.
for table in tables:
for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade):
if allow_cascade:
truncated_tables.add(foreign_table)
constraints.add((foreign_table, constraint))
sql = [
'%s %s %s %s %s %s %s %s;' % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD('DISABLE'),
style.SQL_KEYWORD('CONSTRAINT'),
style.SQL_FIELD(self.quote_name(constraint)),
style.SQL_KEYWORD('KEEP'),
style.SQL_KEYWORD('INDEX'),
) for table, constraint in constraints
] + [
'%s %s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
) for table in truncated_tables
] + [
'%s %s %s %s %s %s;' % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD('ENABLE'),
style.SQL_KEYWORD('CONSTRAINT'),
style.SQL_FIELD(self.quote_name(constraint)),
) for table, constraint in constraints
]
if reset_sequences:
sequences = [
sequence
for sequence in self.connection.introspection.sequence_list()
if sequence['table'].upper() in truncated_tables
]
# Since we've just deleted all the rows, running our sequence ALTER
# code will reset the sequence to 0.
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
def sequence_reset_by_name_sql(self, style, sequences):
sql = []
for sequence_info in sequences:
no_autofield_sequence_name = self._get_no_autofield_sequence_name(sequence_info['table'])
table = self.quote_name(sequence_info['table'])
column = self.quote_name(sequence_info['column'] or 'id')
query = self._sequence_reset_sql % {
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': strip_quotes(column),
}
sql.append(query)
return sql
def sequence_reset_sql(self, style, model_list):
output = []
query = self._sequence_reset_sql
for model in model_list:
for f in model._meta.local_fields:
if isinstance(f, AutoField):
no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table)
table = self.quote_name(model._meta.db_table)
column = self.quote_name(f.column)
output.append(query % {
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': strip_quotes(column),
})
# Only one AutoField is allowed per model, so don't
# continue to loop
break
return output
def start_transaction_sql(self):
return ''
def tablespace_sql(self, tablespace, inline=False):
if inline:
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
else:
return "TABLESPACE %s" % self.quote_name(tablespace)
def adapt_datefield_value(self, value):
"""
Transform a date value to an object compatible with what is expected
by the backend driver for date columns.
The default implementation transforms the date to text, but that is not
necessary for Oracle.
"""
return value
def adapt_datetimefield_value(self, value):
"""
Transform a datetime value to an object compatible with what is expected
by the backend driver for datetime columns.
If naive datetime is passed assumes that is in UTC. Normally Django
models.DateTimeField makes sure that if USE_TZ is True passed datetime
is timezone aware.
"""
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# cx_Oracle doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError("Oracle backend does not support timezone-aware datetimes when USE_TZ is False.")
return Oracle_datetime.from_datetime(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
if isinstance(value, str):
return datetime.datetime.strptime(value, '%H:%M:%S')
# Oracle doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("Oracle backend does not support timezone-aware times.")
return Oracle_datetime(1900, 1, 1, value.hour, value.minute,
value.second, value.microsecond)
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
return value
def combine_expression(self, connector, sub_expressions):
lhs, rhs = sub_expressions
if connector == '%%':
return 'MOD(%s)' % ','.join(sub_expressions)
elif connector == '&':
return 'BITAND(%s)' % ','.join(sub_expressions)
elif connector == '|':
return 'BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s' % {'lhs': lhs, 'rhs': rhs}
elif connector == '<<':
return '(%(lhs)s * POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
elif connector == '>>':
return 'FLOOR(%(lhs)s / POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
elif connector == '^':
return 'POWER(%s)' % ','.join(sub_expressions)
elif connector == '#':
raise NotSupportedError('Bitwise XOR is not supported in Oracle.')
return super().combine_expression(connector, sub_expressions)
def _get_no_autofield_sequence_name(self, table):
"""
Manually created sequence name to keep backward compatibility for
AutoFields that aren't Oracle identity columns.
"""
name_length = self.max_name_length() - 3
return '%s_SQ' % truncate_name(strip_quotes(table), name_length).upper()
def _get_sequence_name(self, cursor, table, pk_name):
cursor.execute("""
SELECT sequence_name
FROM user_tab_identity_cols
WHERE table_name = UPPER(%s)
AND column_name = UPPER(%s)""", [table, pk_name])
row = cursor.fetchone()
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
def bulk_insert_sql(self, fields, placeholder_rows):
query = []
for row in placeholder_rows:
select = []
for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if fields[i]:
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
placeholder = BulkInsertMapper.types.get(internal_type, '%s') % placeholder
# Add columns aliases to the first select to avoid "ORA-00918:
# column ambiguously defined" when two or more columns in the
# first select have the same value.
if not query:
placeholder = '%s col_%s' % (placeholder, i)
select.append(placeholder)
query.append('SELECT %s FROM DUAL' % ', '.join(select))
# Bulk insert to tables with Oracle identity columns causes Oracle to
# add sequence.nextval to it. Sequence.nextval cannot be used with the
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
return 'SELECT * FROM (%s)' % ' UNION ALL '.join(query)
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), params
return super().subtract_temporals(internal_type, lhs, rhs)
def bulk_batch_size(self, fields, objs):
"""Oracle restricts the number of parameters in a query."""
if fields:
return self.connection.features.max_query_params // len(fields)
return len(objs)
def conditional_expression_supported_in_where_clause(self, expression):
"""
Oracle supports only EXISTS(...) or filters in the WHERE clause, others
must be compared with True.
"""
if isinstance(expression, (Exists, Lookup, WhereNode)):
return True
if isinstance(expression, ExpressionWrapper) and expression.conditional:
return self.conditional_expression_supported_in_where_clause(expression.expression)
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False

View File

@@ -0,0 +1,211 @@
import copy
import datetime
import re
from django.db import DatabaseError
from django.db.backends.base.schema import (
BaseDatabaseSchemaEditor, _related_non_m2m_objects,
)
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
sql_alter_column_null = "MODIFY %(column)s NULL"
sql_alter_column_not_null = "MODIFY %(column)s NOT NULL"
sql_alter_column_default = "MODIFY %(column)s DEFAULT %(default)s"
sql_alter_column_no_default = "MODIFY %(column)s DEFAULT NULL"
sql_alter_column_no_default_null = sql_alter_column_no_default
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_create_column_inline_fk = 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
def quote_value(self, value):
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
return "'%s'" % value
elif isinstance(value, str):
return "'%s'" % value.replace("\'", "\'\'").replace('%', '%%')
elif isinstance(value, (bytes, bytearray, memoryview)):
return "'%s'" % value.hex()
elif isinstance(value, bool):
return "1" if value else "0"
else:
return str(value)
def remove_field(self, model, field):
# If the column is an identity column, drop the identity before
# removing the field.
if self._is_identity_column(model._meta.db_table, field.column):
self._drop_identity(model._meta.db_table, field.column)
super().remove_field(model, field)
def delete_model(self, model):
# Run superclass action
super().delete_model(model)
# Clean up manually created sequence.
self.execute("""
DECLARE
i INTEGER;
BEGIN
SELECT COUNT(1) INTO i FROM USER_SEQUENCES
WHERE SEQUENCE_NAME = '%(sq_name)s';
IF i = 1 THEN
EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
END IF;
END;
/""" % {'sq_name': self.connection.ops._get_no_autofield_sequence_name(model._meta.db_table)})
def alter_field(self, model, old_field, new_field, strict=False):
try:
super().alter_field(model, old_field, new_field, strict)
except DatabaseError as e:
description = str(e)
# If we're changing type to an unsupported type we need a
# SQLite-ish workaround
if 'ORA-22858' in description or 'ORA-22859' in description:
self._alter_field_type_workaround(model, old_field, new_field)
# If an identity column is changing to a non-numeric type, drop the
# identity first.
elif 'ORA-30675' in description:
self._drop_identity(model._meta.db_table, old_field.column)
self.alter_field(model, old_field, new_field, strict)
# If a primary key column is changing to an identity column, drop
# the primary key first.
elif 'ORA-30673' in description and old_field.primary_key:
self._delete_primary_key(model, strict=True)
self._alter_field_type_workaround(model, old_field, new_field)
else:
raise
def _alter_field_type_workaround(self, model, old_field, new_field):
"""
Oracle refuses to change from some type to other type.
What we need to do instead is:
- Add a nullable version of the desired field with a temporary name. If
the new column is an auto field, then the temporary column can't be
nullable.
- Update the table to transfer values from old to new
- Drop old column
- Rename the new column and possibly drop the nullable property
"""
# Make a new field that's like the new one but with a temporary
# column name.
new_temp_field = copy.deepcopy(new_field)
new_temp_field.null = (new_field.get_internal_type() not in ('AutoField', 'BigAutoField', 'SmallAutoField'))
new_temp_field.column = self._generate_temp_name(new_field.column)
# Add it
self.add_field(model, new_temp_field)
# Explicit data type conversion
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf
# /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
new_value = self.quote_name(old_field.column)
old_type = old_field.db_type(self.connection)
if re.match('^N?CLOB', old_type):
new_value = "TO_CHAR(%s)" % new_value
old_type = 'VARCHAR2'
if re.match('^N?VARCHAR2', old_type):
new_internal_type = new_field.get_internal_type()
if new_internal_type == 'DateField':
new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
elif new_internal_type == 'DateTimeField':
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
elif new_internal_type == 'TimeField':
# TimeField are stored as TIMESTAMP with a 1900-01-01 date part.
new_value = "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
# Transfer values across
self.execute("UPDATE %s set %s=%s" % (
self.quote_name(model._meta.db_table),
self.quote_name(new_temp_field.column),
new_value,
))
# Drop the old field
self.remove_field(model, old_field)
# Rename and possibly make the new field NOT NULL
super().alter_field(model, new_temp_field, new_field)
# Recreate foreign key (if necessary) because the old field is not
# passed to the alter_field() and data types of new_temp_field and
# new_field always match.
new_type = new_field.db_type(self.connection)
if (
(old_field.primary_key and new_field.primary_key) or
(old_field.unique and new_field.unique)
) and old_type != new_type:
for _, rel in _related_non_m2m_objects(new_temp_field, new_field):
if rel.field.db_constraint:
self.execute(self._create_fk_sql(rel.related_model, rel.field, '_fk'))
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
auto_field_types = {'AutoField', 'BigAutoField', 'SmallAutoField'}
# Drop the identity if migrating away from AutoField.
if (
old_field.get_internal_type() in auto_field_types and
new_field.get_internal_type() not in auto_field_types and
self._is_identity_column(model._meta.db_table, new_field.column)
):
self._drop_identity(model._meta.db_table, new_field.column)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def normalize_name(self, name):
"""
Get the properly shortened and uppercased identifier as returned by
quote_name() but without the quotes.
"""
nn = self.quote_name(name)
if nn[0] == '"' and nn[-1] == '"':
nn = nn[1:-1]
return nn
def _generate_temp_name(self, for_name):
"""Generate temporary names for workarounds that need temp columns."""
suffix = hex(hash(for_name)).upper()[1:]
return self.normalize_name(for_name + "_" + suffix)
def prepare_default(self, value):
return self.quote_value(value)
def _field_should_be_indexed(self, model, field):
create_index = super()._field_should_be_indexed(model, field)
db_type = field.db_type(self.connection)
if db_type is not None and db_type.lower() in self.connection._limited_data_types:
return False
return create_index
def _unique_should_be_added(self, old_field, new_field):
return (
super()._unique_should_be_added(old_field, new_field) and
not self._field_became_primary_key(old_field, new_field)
)
def _is_identity_column(self, table_name, column_name):
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT
CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END
FROM user_tab_cols
WHERE table_name = %s AND
column_name = %s
""", [self.normalize_name(table_name), self.normalize_name(column_name)])
row = cursor.fetchone()
return row[0] if row else False
def _drop_identity(self, table_name, column_name):
self.execute('ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY' % {
'table': self.quote_name(table_name),
'column': self.quote_name(column_name),
})
def _get_default_collation(self, table_name):
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT default_collation FROM user_tables WHERE table_name = %s
""", [self.normalize_name(table_name)])
return cursor.fetchone()[0]
def _alter_column_collation_sql(self, model, new_field, new_type, new_collation):
if new_collation is None:
new_collation = self._get_default_collation(model._meta.db_table)
return super()._alter_column_collation_sql(model, new_field, new_type, new_collation)

View File

@@ -0,0 +1,90 @@
import datetime
from .base import Database
class InsertVar:
"""
A late-binding cursor variable that can be passed to Cursor.execute
as a parameter, in order to receive the id of the row created by an
insert statement.
"""
types = {
'AutoField': int,
'BigAutoField': int,
'SmallAutoField': int,
'IntegerField': int,
'BigIntegerField': int,
'SmallIntegerField': int,
'PositiveBigIntegerField': int,
'PositiveSmallIntegerField': int,
'PositiveIntegerField': int,
'FloatField': Database.NATIVE_FLOAT,
'DateTimeField': Database.TIMESTAMP,
'DateField': Database.Date,
'DecimalField': Database.NUMBER,
}
def __init__(self, field):
internal_type = getattr(field, 'target_field', field).get_internal_type()
self.db_type = self.types.get(internal_type, str)
self.bound_param = None
def bind_parameter(self, cursor):
self.bound_param = cursor.cursor.var(self.db_type)
return self.bound_param
def get_value(self):
return self.bound_param.getvalue()
class Oracle_datetime(datetime.datetime):
"""
A datetime object, with an additional class attribute
to tell cx_Oracle to save the microseconds too.
"""
input_size = Database.TIMESTAMP
@classmethod
def from_datetime(cls, dt):
return Oracle_datetime(
dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second, dt.microsecond,
)
class BulkInsertMapper:
BLOB = 'TO_BLOB(%s)'
CLOB = 'TO_CLOB(%s)'
DATE = 'TO_DATE(%s)'
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
NUMBER = 'TO_NUMBER(%s)'
TIMESTAMP = 'TO_TIMESTAMP(%s)'
types = {
'AutoField': NUMBER,
'BigAutoField': NUMBER,
'BigIntegerField': NUMBER,
'BinaryField': BLOB,
'BooleanField': NUMBER,
'DateField': DATE,
'DateTimeField': TIMESTAMP,
'DecimalField': NUMBER,
'DurationField': INTERVAL,
'FloatField': NUMBER,
'IntegerField': NUMBER,
'PositiveBigIntegerField': NUMBER,
'PositiveIntegerField': NUMBER,
'PositiveSmallIntegerField': NUMBER,
'SmallAutoField': NUMBER,
'SmallIntegerField': NUMBER,
'TextField': CLOB,
'TimeField': TIMESTAMP,
}
def dsn(settings_dict):
if settings_dict['PORT']:
host = settings_dict['HOST'].strip() or 'localhost'
return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']

View File

@@ -0,0 +1,22 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
class DatabaseValidation(BaseDatabaseValidation):
def check_field_type(self, field, field_type):
"""Oracle doesn't support a database index on some data types."""
errors = []
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
'Oracle does not support a database index on %s columns.'
% field_type,
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id='fields.W162',
)
)
return errors

View File

@@ -0,0 +1,353 @@
"""
PostgreSQL database backend for Django.
Requires psycopg 2: https://www.psycopg.org/
"""
import asyncio
import threading
import warnings
from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError, connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import (
CursorDebugWrapper as BaseCursorDebugWrapper,
)
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
from django.utils.version import get_version_tuple
try:
import psycopg2 as Database
import psycopg2.extensions
import psycopg2.extras
except ImportError as e:
raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
def psycopg2_version():
version = psycopg2.__version__.split(' ', 1)[0]
return get_version_tuple(version)
PSYCOPG2_VERSION = psycopg2_version()
if PSYCOPG2_VERSION < (2, 5, 4):
raise ImproperlyConfigured("psycopg2_version 2.5.4 or newer is required; you have %s" % psycopg2.__version__)
# Some of these import psycopg2, so import them after checking if it's installed.
from .client import DatabaseClient # NOQA
from .creation import DatabaseCreation # NOQA
from .features import DatabaseFeatures # NOQA
from .introspection import DatabaseIntrospection # NOQA
from .operations import DatabaseOperations # NOQA
from .schema import DatabaseSchemaEditor # NOQA
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
psycopg2.extras.register_uuid()
# Register support for inet[] manually so we don't have to handle the Inet()
# object on load all the time.
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
'INETARRAY',
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'postgresql'
display_name = 'PostgreSQL'
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'serial',
'BigAutoField': 'bigserial',
'BinaryField': 'bytea',
'BooleanField': 'boolean',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'interval',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'inet',
'GenericIPAddressField': 'inet',
'JSONField': 'jsonb',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint',
'PositiveIntegerField': 'integer',
'PositiveSmallIntegerField': 'smallint',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallserial',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'uuid',
}
data_type_check_constraints = {
'PositiveBigIntegerField': '"%(column)s" >= 0',
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
operators = {
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': 'LIKE %s',
'icontains': 'LIKE UPPER(%s)',
'regex': '~ %s',
'iregex': '~* %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE %s',
'endswith': 'LIKE %s',
'istartswith': 'LIKE UPPER(%s)',
'iendswith': 'LIKE UPPER(%s)',
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
pattern_ops = {
'contains': "LIKE '%%' || {} || '%%'",
'icontains': "LIKE '%%' || UPPER({}) || '%%'",
'startswith': "LIKE {} || '%%'",
'istartswith': "LIKE UPPER({}) || '%%'",
'endswith': "LIKE '%%' || {}",
'iendswith': "LIKE '%%' || UPPER({})",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
# PostgreSQL backend-specific attributes.
_named_cursor_idx = 0
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if (
settings_dict['NAME'] == '' and
not settings_dict.get('OPTIONS', {}).get('service')
):
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME or OPTIONS['service'] value."
)
if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
"in settings.DATABASES." % (
settings_dict['NAME'],
len(settings_dict['NAME']),
self.ops.max_name_length(),
)
)
conn_params = {}
if settings_dict['NAME']:
conn_params = {
'database': settings_dict['NAME'],
**settings_dict['OPTIONS'],
}
elif settings_dict['NAME'] is None:
# Connect to the default 'postgres' db.
settings_dict.get('OPTIONS', {}).pop('service', None)
conn_params = {'database': 'postgres', **settings_dict['OPTIONS']}
else:
conn_params = {**settings_dict['OPTIONS']}
conn_params.pop('isolation_level', None)
if settings_dict['USER']:
conn_params['user'] = settings_dict['USER']
if settings_dict['PASSWORD']:
conn_params['password'] = settings_dict['PASSWORD']
if settings_dict['HOST']:
conn_params['host'] = settings_dict['HOST']
if settings_dict['PORT']:
conn_params['port'] = settings_dict['PORT']
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
connection = Database.connect(**conn_params)
# self.isolation_level must be set:
# - after connecting to the database in order to obtain the database's
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
options = self.settings_dict['OPTIONS']
try:
self.isolation_level = options['isolation_level']
except KeyError:
self.isolation_level = connection.isolation_level
else:
# Set the isolation level to the value from OPTIONS.
if self.isolation_level != connection.isolation_level:
connection.set_session(isolation_level=self.isolation_level)
# Register dummy loads() to avoid a round trip from psycopg2's decode
# to json.dumps() to json.loads(), when using a custom decoder in
# JSONField.
psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
return connection
def ensure_timezone(self):
if self.connection is None:
return False
conn_timezone_name = self.connection.get_parameter_status('TimeZone')
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
return True
return False
def init_connection_state(self):
self.connection.set_client_encoding('UTF8')
timezone_changed = self.ensure_timezone()
if timezone_changed:
# Commit after setting the time zone (see #17062)
if not self.get_autocommit():
self.connection.commit()
@async_unsafe
def create_cursor(self, name=None):
if name:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
else:
cursor = self.connection.cursor()
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
return cursor
def tzinfo_factory(self, offset):
return self.timezone
@async_unsafe
def chunked_cursor(self):
self._named_cursor_idx += 1
# Get the current async task
# Note that right now this is behind @async_unsafe, so this is
# unreachable, but in future we'll start loosening this restriction.
# For now, it's here so that every use of "threading" is
# also async-compatible.
try:
current_task = asyncio.current_task()
except RuntimeError:
current_task = None
# Current task can be none even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = 'sync'
# Use that and the thread ident to get a unique name
return self._cursor(
name='_django_curs_%d_%s_%d' % (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
self._named_cursor_idx,
)
)
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
with self.cursor() as cursor:
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
cursor.execute('SELECT 1')
except Database.Error:
return False
else:
return True
@contextmanager
def _nodb_cursor(self):
cursor = None
try:
with super()._nodb_cursor() as cursor:
yield cursor
except (Database.DatabaseError, WrappedDatabaseError):
if cursor is not None:
raise
warnings.warn(
"Normally Django will use a connection to the 'postgres' database "
"to avoid running initialization queries against the production "
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
RuntimeWarning
)
for connection in connections.all():
if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
conn = self.__class__(
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias,
)
try:
with conn.cursor() as cursor:
yield cursor
finally:
conn.close()
break
else:
raise
@cached_property
def pg_version(self):
with self.temporary_connection():
return self.connection.server_version
def make_debug_cursor(self, cursor):
return CursorDebugWrapper(cursor, self)
class CursorDebugWrapper(BaseCursorDebugWrapper):
def copy_expert(self, sql, file, *args):
with self.debug_sql(sql):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
with self.debug_sql(sql='COPY %s TO STDOUT' % table):
return self.cursor.copy_to(file, table, *args, **kwargs)

View File

@@ -0,0 +1,64 @@
import signal
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
options = settings_dict.get('OPTIONS', {})
host = settings_dict.get('HOST')
port = settings_dict.get('PORT')
dbname = settings_dict.get('NAME')
user = settings_dict.get('USER')
passwd = settings_dict.get('PASSWORD')
passfile = options.get('passfile')
service = options.get('service')
sslmode = options.get('sslmode')
sslrootcert = options.get('sslrootcert')
sslcert = options.get('sslcert')
sslkey = options.get('sslkey')
if not dbname and not service:
# Connect to the default 'postgres' db.
dbname = 'postgres'
if user:
args += ['-U', user]
if host:
args += ['-h', host]
if port:
args += ['-p', str(port)]
if dbname:
args += [dbname]
args.extend(parameters)
env = {}
if passwd:
env['PGPASSWORD'] = str(passwd)
if service:
env['PGSERVICE'] = str(service)
if sslmode:
env['PGSSLMODE'] = str(sslmode)
if sslrootcert:
env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert:
env['PGSSLCERT'] = str(sslcert)
if sslkey:
env['PGSSLKEY'] = str(sslkey)
if passfile:
env['PGPASSFILE'] = str(passfile)
return args, (env or None)
def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try:
# Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
super().runshell(parameters)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)

View File

@@ -0,0 +1,80 @@
import sys
from psycopg2 import errorcodes
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
def _get_database_create_suffix(self, encoding=None, template=None):
suffix = ""
if encoding:
suffix += " ENCODING '{}'".format(encoding)
if template:
suffix += " TEMPLATE {}".format(self._quote_name(template))
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
test_settings = self.connection.settings_dict['TEST']
if test_settings.get('COLLATION') is not None:
raise ImproperlyConfigured(
'PostgreSQL does not support collation setting at database '
'creation time.'
)
return self._get_database_create_suffix(
encoding=test_settings['CHARSET'],
template=test_settings.get('TEMPLATE'),
)
def _database_exists(self, cursor, database_name):
cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
if keepdb and self._database_exists(cursor, parameters['dbname']):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
# All errors except "database already exists" cancel tests.
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
# exists".
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
# CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
# to the template database.
self.connection.close()
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
'dbname': self._quote_name(target_database_name),
'suffix': self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)

View File

@@ -0,0 +1,97 @@
import operator
from django.db import InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True
can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True
has_real_datatype = True
has_native_uuid_field = True
has_native_duration_field = True
has_native_json_field = True
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_of = True
has_select_for_update_skip_locked = True
has_select_for_no_key_update = True
can_release_savepoints = True
supports_tablespaces = True
supports_transactions = True
can_introspect_materialized_views = True
can_distinct_on_fields = True
can_rollback_ddl = True
supports_combined_alters = True
nulls_order_largest = True
closed_cursor_error_class = InterfaceError
has_case_insensitive_like = False
greatest_least_ignores_nulls = True
can_clone_databases = True
supports_temporal_subtraction = True
supports_slicing_ordering_in_compound = True
create_test_procedure_without_params_sql = """
CREATE FUNCTION test_procedure () RETURNS void AS $$
DECLARE
V_I INTEGER;
BEGIN
V_I := 1;
END;
$$ LANGUAGE plpgsql;"""
create_test_procedure_with_int_param_sql = """
CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$
DECLARE
V_I INTEGER;
BEGIN
V_I := P_I;
END;
$$ LANGUAGE plpgsql;"""
requires_casted_case_in_updates = True
supports_over_clause = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
validates_explain_options = False # A query will error on invalid options.
supports_deferrable_unique_constraints = True
has_json_operators = True
json_key_contains_list_matching_requires_list = True
test_collations = {
'non_default': 'sv-x-icu',
'swedish_ci': 'sv-x-icu',
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
django_test_skips = {
'opclasses are PostgreSQL only.': {
'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses',
},
}
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
'PositiveBigIntegerField': 'BigIntegerField',
'PositiveIntegerField': 'IntegerField',
'PositiveSmallIntegerField': 'SmallIntegerField',
}
@cached_property
def is_postgresql_11(self):
return self.connection.pg_version >= 110000
@cached_property
def is_postgresql_12(self):
return self.connection.pg_version >= 120000
@cached_property
def is_postgresql_13(self):
return self.connection.pg_version >= 130000
has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11'))
supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))

View File

@@ -0,0 +1,234 @@
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
from django.db.models import Index
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: 'BooleanField',
17: 'BinaryField',
20: 'BigIntegerField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
700: 'FloatField',
701: 'FloatField',
869: 'GenericIPAddressField',
1042: 'CharField', # blank-padded
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1186: 'DurationField',
1266: 'TimeField',
1700: 'DecimalField',
2950: 'UUIDField',
3802: 'JSONField',
}
# A hook for subclasses.
index_default_access_method = 'btree'
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.default and 'nextval' in description.default:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("""
SELECT c.relname,
CASE WHEN c.relispartition THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""")
return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
cursor.execute("""
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation
FROM pg_attribute a
LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
LEFT JOIN pg_collation co ON a.attcollation = co.oid
JOIN pg_type t ON a.atttypid = t.oid
JOIN pg_class c ON a.attrelid = c.oid
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""", [table_name])
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return [
FieldInfo(
line.name,
line.type_code,
line.display_size,
line.internal_size,
line.precision,
line.scale,
*field_map[line.name],
)
for line in cursor.description
]
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute("""
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid = 'pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
""", [table_name])
return [
{'name': row[0], 'table': table_name, 'column': row[1]}
for row in cursor.fetchall()
]
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
return {row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)}
def get_key_columns(self, cursor, table_name):
cursor.execute("""
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
WHERE
c1.relname = %s AND
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
""", [table_name])
return cursor.fetchall()
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns. Also retrieve the definition of expression-based
indexes.
"""
constraints = {}
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
cursor.execute("""
SELECT
c.conname,
array(
SELECT attname
FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
JOIN pg_attribute AS ca ON cols.colid = ca.attnum
WHERE ca.attrelid = c.conrelid
ORDER BY cols.arridx
),
c.contype,
(SELECT fkc.relname || '.' || fka.attname
FROM pg_attribute AS fka
JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
cl.reloptions
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
""", [table_name])
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
"primary_key": kind == "p",
"unique": kind in ["p", "u"],
"foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
"check": kind == "c",
"index": False,
"definition": None,
"options": options,
}
# Now get indexes
cursor.execute("""
SELECT
indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
FROM (
SELECT
c2.relname as indexname, idx.*, attr.attname, am.amname,
CASE
WHEN idx.indexprs IS NOT NULL THEN
pg_get_indexdef(idx.indexrelid)
END AS exprdef,
CASE am.amname
WHEN %s THEN
CASE (option & 1)
WHEN 1 THEN 'DESC' ELSE 'ASC'
END
END as ordering,
c2.reloptions as attoptions
FROM (
SELECT *
FROM pg_index i, unnest(i.indkey, i.indoption) WITH ORDINALITY koi(key, option, arridx)
) idx
LEFT JOIN pg_class c ON idx.indrelid = c.oid
LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
LEFT JOIN pg_am am ON c2.relam = am.oid
LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
""", [self.index_default_access_method, table_name])
for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
if index not in constraints:
basic_index = (
type_ == self.index_default_access_method and
# '_btree' references
# django.contrib.postgres.indexes.BTreeIndex.suffix.
not index.endswith('_btree') and options is None
)
constraints[index] = {
"columns": columns if columns != [None] else [],
"orders": orders if orders != [None] else [],
"primary_key": primary,
"unique": unique,
"foreign_key": None,
"check": False,
"index": True,
"type": Index.suffix if basic_index else type_,
"definition": definition,
"options": options,
}
return constraints

View File

@@ -0,0 +1,276 @@
from psycopg2.extras import Inet
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = 'varchar'
explain_prefix = 'EXPLAIN'
cast_data_types = {
'AutoField': 'integer',
'BigAutoField': 'bigint',
'SmallAutoField': 'smallint',
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
# These fields cannot be implicitly cast back in the default
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
return '%s'
def date_extract_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
if lookup_type == 'week_day':
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
elif lookup_type == 'iso_week_day':
return "EXTRACT('isodow' FROM %s)" % field_name
elif lookup_type == 'iso_year':
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
sign = '-' if sign == '+' else '+'
return f'{tzname}{sign}{offset}'
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return '(%s)::date' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return '(%s)::time' % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
if internal_type in ('IPAddressField', 'GenericIPAddressField'):
lookup = "HOST(%s)"
elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
lookup = '%s::citext'
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
lookup = 'UPPER(%s)' % lookup
return lookup
def no_limit_value(self):
return None
def prepare_sql_script(self, sql):
return [sql]
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
return '"%s"' % name
def set_time_zone_sql(self):
return "SET TIME ZONE %s"
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
sql_parts = [
style.SQL_KEYWORD('TRUNCATE'),
', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
]
if reset_sequences:
sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY'))
if allow_cascade:
sql_parts.append(style.SQL_KEYWORD('CASCADE'))
return ['%s;' % ' '.join(sql_parts)]
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
table_name = sequence_info['table']
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
column_name = sequence_info['column'] or 'id'
sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
))
return sql
def tablespace_sql(self, tablespace, inline=False):
if inline:
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
else:
return "TABLESPACE %s" % self.quote_name(tablespace)
def sequence_reset_sql(self, style, model_list):
from django.db import models
output = []
qn = self.quote_name
for model in model_list:
# Use `coalesce` to set the sequence for each model to the max pk value if there are records,
# or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
# if there are records (as the max pk value is already in use), otherwise set it to false.
# Use pg_get_serial_sequence to get the underlying sequence name from the table name
# and column name (available since PostgreSQL 8)
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
style.SQL_KEYWORD('IS NOT'),
style.SQL_KEYWORD('FROM'),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
break # Only one AutoField is allowed per model, so don't bother continuing.
return output
def prep_for_iexact_query(self, x):
return x
def max_name_length(self):
"""
Return the maximum length of an identifier.
The maximum length of an identifier is 63 by default, but can be
changed by recompiling PostgreSQL after editing the NAMEDATALEN
macro in src/include/pg_config_manual.h.
This implementation returns 63, but can be overridden by a custom
database backend that inherits most of its behavior from this one.
"""
return 63
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
else:
return ['DISTINCT'], []
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
# The query attribute is a Psycopg extension to the DB API 2.0.
if cursor.query is not None:
return cursor.query.decode()
return None
def return_insert_columns(self, fields):
if not fields:
return '', ()
columns = [
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
return "VALUES " + values_sql
def adapt_datefield_value(self, value):
return value
def adapt_datetimefield_value(self, value):
return value
def adapt_timefield_value(self, value):
return value
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
return value
def adapt_ipaddressfield_value(self, value):
if value:
return Inet(value)
return None
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params
return super().subtract_temporals(internal_type, lhs, rhs)
def explain_query_prefix(self, format=None, **options):
prefix = super().explain_query_prefix(format)
extra = {}
if format:
extra['FORMAT'] = format
if options:
extra.update({
name.upper(): 'true' if value else 'false'
for name, value in options.items()
})
if extra:
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
return prefix
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)

View File

@@ -0,0 +1,238 @@
import psycopg2
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
from django.db.backends.utils import strip_quotes
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'
sql_create_index = (
'CREATE INDEX %(name)s ON %(table)s%(using)s '
'(%(columns)s)%(include)s%(extra)s%(condition)s'
)
sql_create_index_concurrently = (
'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s '
'(%(columns)s)%(include)s%(extra)s%(condition)s'
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
'; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE'
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'
def quote_value(self, value):
if isinstance(value, str):
value = value.replace('%', '%%')
adapted = psycopg2.extensions.adapt(value)
if hasattr(adapted, 'encoding'):
adapted.encoding = 'utf8'
# getquoted() returns a quoted bytestring of the adapted value.
return adapted.getquoted().decode()
def _field_indexes_sql(self, model, field):
output = super()._field_indexes_sql(model, field)
like_index_statement = self._create_like_index_sql(model, field)
if like_index_statement is not None:
output.append(like_index_statement)
return output
def _field_data_type(self, field):
if field.is_relation:
return field.rel_db_type(self.connection)
return self.connection.data_types.get(
field.get_internal_type(),
field.db_type(self.connection),
)
def _field_base_data_types(self, field):
# Yield base data types for array fields.
if field.base_field.get_internal_type() == 'ArrayField':
yield from self._field_base_data_types(field.base_field)
else:
yield self._field_data_type(field.base_field)
def _create_like_index_sql(self, model, field):
"""
Return the statement to create an index with varchar operator pattern
when the column type is 'varchar' or 'text', otherwise return None.
"""
db_type = field.db_type(connection=self.connection)
if db_type is not None and (field.db_index or field.unique):
# Fields with database column types of `varchar` and `text` need
# a second index that specifies their operator class, which is
# needed when performing correct LIKE queries outside the
# C locale. See #12234.
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
if '[' in db_type:
return None
if db_type.startswith('varchar'):
return self._create_index_sql(
model,
fields=[field],
suffix='_like',
opclasses=['varchar_pattern_ops'],
)
elif db_type.startswith('text'):
return self._create_index_sql(
model,
fields=[field],
suffix='_like',
opclasses=['text_pattern_ops'],
)
return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
# Cast when data type changed.
using_sql = ' USING %(column)s::%(type)s'
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type:
# Compare base data types for array fields.
if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense.
table = strip_quotes(model._meta.db_table)
serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
if new_type.lower() in serial_fields_map:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type % {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
},
[],
),
[
(
self.sql_delete_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_create_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_alter_column % {
"table": self.quote_name(table),
"changes": self.sql_alter_column_default % {
"column": self.quote_name(column),
"default": "nextval('%s')" % self.quote_name(sequence_name),
}
},
[],
),
(
self.sql_set_sequence_max % {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_set_sequence_owner % {
'table': self.quote_name(table),
'column': self.quote_name(column),
'sequence': self.quote_name(sequence_name),
},
[],
),
],
)
elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map:
# Drop the sequence if migrating away from AutoField.
column = strip_quotes(new_field.column)
sequence_name = '%s_%s_seq' % (table, column)
fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type)
return fragment, [
(
self.sql_delete_sequence % {
'sequence': self.quote_name(sequence_name),
},
[],
),
]
else:
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
if (old_field.db_index or old_field.unique) and (
(old_type.startswith('varchar') and not new_type.startswith('varchar')) or
(old_type.startswith('text') and not new_type.startswith('text')) or
(old_type.startswith('citext') and not new_type.startswith('citext'))
):
index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_name))
super()._alter_field(
model, old_field, new_field, old_type, new_type, old_db_params,
new_db_params, strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
(not old_field.unique and new_field.unique)):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
self, model, *, fields=None, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
condition=None, concurrently=False, include=None, expressions=None,
):
sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
return super()._create_index_sql(
model, fields=fields, name=name, suffix=suffix, using=using,
db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql,
opclasses=opclasses, condition=condition, include=include,
expressions=expressions,
)

View File

@@ -0,0 +1,3 @@
from django.dispatch import Signal
connection_created = Signal()

View File

@@ -0,0 +1,621 @@
"""
SQLite backend for the sqlite3 module in the standard library.
"""
import datetime
import decimal
import functools
import hashlib
import math
import operator
import random
import re
import statistics
import warnings
from itertools import chain
from sqlite3 import dbapi2 as Database
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import (
BaseDatabaseWrapper, timezone_constructor,
)
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.regex_helper import _lazy_re_compile
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
def decoder(conv_func):
"""
Convert bytestrings from Python's sqlite3 interface to a regular string.
"""
return lambda s: conv_func(s.decode())
def none_guard(func):
"""
Decorator that returns None if any of the arguments to the decorated
function are None. Many SQL functions return NULL if any of their arguments
are NULL. This decorator simplifies the implementation of this for the
custom functions registered below.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
return None if None in args else func(*args, **kwargs)
return wrapper
def list_aggregate(function):
"""
Return an aggregate class that accumulates values in a list and applies
the provided function to the data.
"""
return type('ListAggregate', (list,), {'finalize': function, 'step': list.append})
def check_sqlite_version():
if Database.sqlite_version_info < (3, 9, 0):
raise ImproperlyConfigured(
'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version
)
check_sqlite_version()
Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))
Database.register_adapter(decimal.Decimal, str)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'sqlite'
display_name = 'SQLite'
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
data_types = {
'AutoField': 'integer',
'BigAutoField': 'integer',
'BinaryField': 'BLOB',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'decimal',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'real',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'JSONField': 'text',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint unsigned',
'PositiveIntegerField': 'integer unsigned',
'PositiveSmallIntegerField': 'smallint unsigned',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'integer',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'char(32)',
}
data_type_check_constraints = {
'PositiveBigIntegerField': '"%(column)s" >= 0',
'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
data_types_suffix = {
'AutoField': 'AUTOINCREMENT',
'BigAutoField': 'AUTOINCREMENT',
'SmallAutoField': 'AUTOINCREMENT',
}
# SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it.
# See https://www.sqlite.org/lang_expr.html for an explanation.
operators = {
'exact': '= %s',
'iexact': "LIKE %s ESCAPE '\\'",
'contains': "LIKE %s ESCAPE '\\'",
'icontains': "LIKE %s ESCAPE '\\'",
'regex': 'REGEXP %s',
'iregex': "REGEXP '(?i)' || %s",
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': "LIKE %s ESCAPE '\\'",
'endswith': "LIKE %s ESCAPE '\\'",
'istartswith': "LIKE %s ESCAPE '\\'",
'iendswith': "LIKE %s ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'",
'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
'startswith': r"LIKE {} || '%%' ESCAPE '\'",
'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'",
'endswith': r"LIKE '%%' || {} ESCAPE '\'",
'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
def get_connection_params(self):
settings_dict = self.settings_dict
if not settings_dict['NAME']:
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value.")
kwargs = {
'database': settings_dict['NAME'],
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
**settings_dict['OPTIONS'],
}
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
# higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
warnings.warn(
'The `check_same_thread` option was provided and set to '
'True. It will be overridden with False. Use the '
'`DatabaseWrapper.allow_thread_sharing` property instead '
'for controlling thread shareability.',
RuntimeWarning
)
kwargs.update({'check_same_thread': False, 'uri': True})
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
create_deterministic_function = functools.partial(
conn.create_function,
deterministic=True,
)
create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
create_deterministic_function('regexp', 2, _sqlite_regexp)
create_deterministic_function('ACOS', 1, none_guard(math.acos))
create_deterministic_function('ASIN', 1, none_guard(math.asin))
create_deterministic_function('ATAN', 1, none_guard(math.atan))
create_deterministic_function('ATAN2', 2, none_guard(math.atan2))
create_deterministic_function('BITXOR', 2, none_guard(operator.xor))
create_deterministic_function('CEILING', 1, none_guard(math.ceil))
create_deterministic_function('COS', 1, none_guard(math.cos))
create_deterministic_function('COT', 1, none_guard(lambda x: 1 / math.tan(x)))
create_deterministic_function('DEGREES', 1, none_guard(math.degrees))
create_deterministic_function('EXP', 1, none_guard(math.exp))
create_deterministic_function('FLOOR', 1, none_guard(math.floor))
create_deterministic_function('LN', 1, none_guard(math.log))
create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x)))
create_deterministic_function('LPAD', 3, _sqlite_lpad)
create_deterministic_function('MD5', 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest()))
create_deterministic_function('MOD', 2, none_guard(math.fmod))
create_deterministic_function('PI', 0, lambda: math.pi)
create_deterministic_function('POWER', 2, none_guard(operator.pow))
create_deterministic_function('RADIANS', 1, none_guard(math.radians))
create_deterministic_function('REPEAT', 2, none_guard(operator.mul))
create_deterministic_function('REVERSE', 1, none_guard(lambda x: x[::-1]))
create_deterministic_function('RPAD', 3, _sqlite_rpad)
create_deterministic_function('SHA1', 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest()))
create_deterministic_function('SHA224', 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest()))
create_deterministic_function('SHA256', 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest()))
create_deterministic_function('SHA384', 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest()))
create_deterministic_function('SHA512', 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest()))
create_deterministic_function('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0)))
create_deterministic_function('SIN', 1, none_guard(math.sin))
create_deterministic_function('SQRT', 1, none_guard(math.sqrt))
create_deterministic_function('TAN', 1, none_guard(math.tan))
# Don't use the built-in RANDOM() function because it returns a value
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
conn.create_function('RAND', 0, random.random)
conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev))
conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev))
conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance))
conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance))
conn.execute('PRAGMA foreign_keys = ON')
return conn
def init_connection_state(self):
pass
def create_cursor(self, name=None):
return self.connection.cursor(factory=SQLiteCursorWrapper)
@async_unsafe
def close(self):
self.validate_thread_sharing()
# If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on
# an in-memory db.
if not self.is_in_memory_db():
BaseDatabaseWrapper.close(self)
def _savepoint_allowed(self):
# When 'isolation_level' is not None, sqlite3 commits before each
# savepoint; it's a bug. When it is None, savepoints don't make sense
# because autocommit is enabled. The only exception is inside 'atomic'
# blocks. To work around that bug, on SQLite, 'atomic' starts a
# transaction explicitly rather than simply disable autocommit.
return self.in_atomic_block
def _set_autocommit(self, autocommit):
if autocommit:
level = None
else:
# sqlite3's internal default is ''. It's different from None.
# See Modules/_sqlite/connection.c.
level = ''
# 'isolation_level' is a misleading API.
# SQLite always runs at the SERIALIZABLE isolation level.
with self.wrap_database_errors:
self.connection.isolation_level = level
def disable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute('PRAGMA foreign_keys = OFF')
# Foreign key constraints cannot be turned off while in a multi-
# statement transaction. Fetch the current state of the pragma
# to determine if constraints are effectively disabled.
enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0]
return not bool(enabled)
def enable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute('PRAGMA foreign_keys = ON')
def check_constraints(self, table_names=None):
"""
Check each table name in `table_names` for rows with invalid foreign
key references. This method is intended to be used in conjunction with
`disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint
checks were off.
"""
if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor:
if table_names is None:
violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
else:
violations = chain.from_iterable(
cursor.execute(
'PRAGMA foreign_key_check(%s)'
% self.ops.quote_name(table_name)
).fetchall()
for table_name in table_names
)
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
for table_name, rowid, referenced_table_name, foreign_key_index in violations:
foreign_key = cursor.execute(
'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name)
).fetchall()[foreign_key_index]
column_name, referenced_column_name = foreign_key[3:5]
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
primary_key_value, bad_value = cursor.execute(
'SELECT %s, %s FROM %s WHERE rowid = %%s' % (
self.ops.quote_name(primary_key_column_name),
self.ops.quote_name(column_name),
self.ops.quote_name(table_name),
),
(rowid,),
).fetchone()
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % (
table_name, primary_key_value, table_name, column_name,
bad_value, referenced_table_name, referenced_column_name
)
)
else:
with self.cursor() as cursor:
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
def is_usable(self):
return True
def _start_transaction_under_autocommit(self):
"""
Start a transaction explicitly in autocommit mode.
Staying in autocommit mode works around a bug of sqlite3 that breaks
savepoints when autocommit is disabled.
"""
self.cursor().execute("BEGIN")
def is_in_memory_db(self):
return self.creation.is_in_memory_db(self.settings_dict['NAME'])
FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s')
class SQLiteCursorWrapper(Database.Cursor):
"""
Django uses "format" style placeholders, but pysqlite2 uses "qmark" style.
This fixes it -- but note that if you want to use a literal "%s" in a query,
you'll need to use "%%s".
"""
def execute(self, query, params=None):
if params is None:
return Database.Cursor.execute(self, query)
query = self.convert_query(query)
return Database.Cursor.execute(self, query, params)
def executemany(self, query, param_list):
query = self.convert_query(query)
return Database.Cursor.executemany(self, query, param_list)
def convert_query(self, query):
return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
if dt is None:
return None
try:
dt = backend_utils.typecast_timestamp(dt)
except (TypeError, ValueError):
return None
if conn_tzname:
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
if tzname is not None and tzname != conn_tzname:
tzname, sign, offset = backend_utils.split_tzname_delta(tzname)
if offset:
hours, minutes = offset.split(':')
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
dt += offset_delta if sign == '+' else -offset_delta
dt = timezone.localtime(dt, timezone_constructor(tzname))
return dt
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == 'year':
return "%i-01-01" % dt.year
elif lookup_type == 'quarter':
month_in_quarter = dt.month - (dt.month - 1) % 3
return '%i-%02i-01' % (dt.year, month_in_quarter)
elif lookup_type == 'month':
return "%i-%02i-01" % (dt.year, dt.month)
elif lookup_type == 'week':
dt = dt - datetime.timedelta(days=dt.weekday())
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
elif lookup_type == 'day':
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
if dt is None:
return None
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt_parsed is None:
try:
dt = backend_utils.typecast_time(dt)
except (ValueError, TypeError):
return None
else:
dt = dt_parsed
if lookup_type == 'hour':
return "%02i:00:00" % dt.hour
elif lookup_type == 'minute':
return "%02i:%02i:00" % (dt.hour, dt.minute)
elif lookup_type == 'second':
return "%02i:%02i:%02i" % (dt.hour, dt.minute, dt.second)
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
return dt.date().isoformat()
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
return dt.time().isoformat()
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == 'week_day':
return (dt.isoweekday() % 7) + 1
elif lookup_type == 'iso_week_day':
return dt.isoweekday()
elif lookup_type == 'week':
return dt.isocalendar()[1]
elif lookup_type == 'quarter':
return math.ceil(dt.month / 3)
elif lookup_type == 'iso_year':
return dt.isocalendar()[0]
else:
return getattr(dt, lookup_type)
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == 'year':
return "%i-01-01 00:00:00" % dt.year
elif lookup_type == 'quarter':
month_in_quarter = dt.month - (dt.month - 1) % 3
return '%i-%02i-01 00:00:00' % (dt.year, month_in_quarter)
elif lookup_type == 'month':
return "%i-%02i-01 00:00:00" % (dt.year, dt.month)
elif lookup_type == 'week':
dt = dt - datetime.timedelta(days=dt.weekday())
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
elif lookup_type == 'day':
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
elif lookup_type == 'hour':
return "%i-%02i-%02i %02i:00:00" % (dt.year, dt.month, dt.day, dt.hour)
elif lookup_type == 'minute':
return "%i-%02i-%02i %02i:%02i:00" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)
elif lookup_type == 'second':
return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
def _sqlite_time_extract(lookup_type, dt):
if dt is None:
return None
try:
dt = backend_utils.typecast_time(dt)
except (ValueError, TypeError):
return None
return getattr(dt, lookup_type)
def _sqlite_prepare_dtdelta_param(conn, param):
if conn in ['+', '-']:
if isinstance(param, int):
return datetime.timedelta(0, 0, param)
else:
return backend_utils.typecast_timestamp(param)
return param
@none_guard
def _sqlite_format_dtdelta(conn, lhs, rhs):
"""
LHS and RHS can be either:
- An integer number of microseconds
- A string representing a datetime
- A scalar value, e.g. float
"""
conn = conn.strip()
try:
real_lhs = _sqlite_prepare_dtdelta_param(conn, lhs)
real_rhs = _sqlite_prepare_dtdelta_param(conn, rhs)
except (ValueError, TypeError):
return None
if conn == '+':
# typecast_timestamp returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
out = str(real_lhs + real_rhs)
elif conn == '-':
out = str(real_lhs - real_rhs)
elif conn == '*':
out = real_lhs * real_rhs
else:
out = real_lhs / real_rhs
return out
@none_guard
def _sqlite_time_diff(lhs, rhs):
left = backend_utils.typecast_time(lhs)
right = backend_utils.typecast_time(rhs)
return (
(left.hour * 60 * 60 * 1000000) +
(left.minute * 60 * 1000000) +
(left.second * 1000000) +
(left.microsecond) -
(right.hour * 60 * 60 * 1000000) -
(right.minute * 60 * 1000000) -
(right.second * 1000000) -
(right.microsecond)
)
@none_guard
def _sqlite_timestamp_diff(lhs, rhs):
left = backend_utils.typecast_timestamp(lhs)
right = backend_utils.typecast_timestamp(rhs)
return duration_microseconds(left - right)
@none_guard
def _sqlite_regexp(re_pattern, re_string):
return bool(re.search(re_pattern, str(re_string)))
@none_guard
def _sqlite_lpad(text, length, fill_text):
if len(text) >= length:
return text[:length]
return (fill_text * length)[:length - len(text)] + text
@none_guard
def _sqlite_rpad(text, length, fill_text):
return (text + fill_text * length)[:length]

View File

@@ -0,0 +1,10 @@
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlite3'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, settings_dict['NAME'], *parameters]
return args, None

View File

@@ -0,0 +1,103 @@
import os
import shutil
import sys
from pathlib import Path
from django.db.backends.base.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
@staticmethod
def is_in_memory_db(database_name):
return not isinstance(database_name, Path) and (
database_name == ':memory:' or 'mode=memory' in database_name
)
def _get_test_db_name(self):
test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:'
if test_database_name == ':memory:':
return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias
return test_database_name
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
test_database_name = self._get_test_db_name()
if keepdb:
return test_database_name
if not self.is_in_memory_db(test_database_name):
# Erase the old test database
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name
)
if autoclobber or confirm == 'yes':
try:
os.remove(test_database_name)
except Exception as e:
self.log('Got an error deleting the old test database: %s' % e)
sys.exit(2)
else:
self.log('Tests cancelled.')
sys.exit(1)
return test_database_name
def get_test_db_clone_settings(self, suffix):
orig_settings_dict = self.connection.settings_dict
source_database_name = orig_settings_dict['NAME']
if self.is_in_memory_db(source_database_name):
return orig_settings_dict
else:
root, ext = os.path.splitext(orig_settings_dict['NAME'])
return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
# Forking automatically makes a copy of an in-memory database.
if not self.is_in_memory_db(source_database_name):
# Erase the old test database
if os.access(target_database_name, os.F_OK):
if keepdb:
return
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
try:
os.remove(target_database_name)
except Exception as e:
self.log('Got an error deleting the old test database: %s' % e)
sys.exit(2)
try:
shutil.copy(source_database_name, target_database_name)
except Exception as e:
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)
def _destroy_test_db(self, test_database_name, verbosity):
if test_database_name and not self.is_in_memory_db(test_database_name):
# Remove the SQLite database file
os.remove(test_database_name)
def test_db_signature(self):
"""
Return a tuple that uniquely identifies a test database.
This takes into account the special cases of ":memory:" and "" for
SQLite since the databases will be distinct despite having the same
TEST NAME. See https://www.sqlite.org/inmemorydb.html
"""
test_database_name = self._get_test_db_name()
sig = [self.connection.settings_dict['NAME']]
if self.is_in_memory_db(test_database_name):
sig.append(self.connection.alias)
else:
sig.append(test_database_name)
return tuple(sig)

View File

@@ -0,0 +1,126 @@
import operator
import platform
from django.db import transaction
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import OperationalError
from django.utils.functional import cached_property
from .base import Database
class DatabaseFeatures(BaseDatabaseFeatures):
# SQLite can read from a cursor since SQLite 3.6.5, subject to the caveat
# that statements within a connection aren't isolated from each other. See
# https://sqlite.org/isolation.html.
can_use_chunked_reads = True
test_db_allows_multiple_connections = False
supports_unspecified_pk = True
supports_timezones = False
max_query_params = 999
supports_mixed_date_datetime_comparisons = False
supports_transactions = True
atomic_transactions = False
can_rollback_ddl = True
can_create_inline_fk = False
supports_paramstyle_pyformat = False
can_clone_databases = True
supports_temporal_subtraction = True
ignores_table_name_case = True
supports_cast_with_precision = False
time_cast_precision = 3
can_release_savepoints = True
# Is "ALTER TABLE ... RENAME COLUMN" supported?
can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0)
supports_parentheses_in_compound = False
# Deferred constraint checks can be emulated on SQLite < 3.20 but not in a
# reasonably performant way.
supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0)
can_defer_constraint_checks = supports_pragma_foreign_key_check
supports_functions_in_partial_indexes = Database.sqlite_version_info >= (3, 15, 0)
supports_over_clause = Database.sqlite_version_info >= (3, 25, 0)
supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
order_by_nulls_first = True
supports_json_field_contains = False
test_collations = {
'ci': 'nocase',
'cs': 'binary',
'non_default': 'nocase',
}
@cached_property
def django_test_skips(self):
skips = {
'SQLite stores values rounded to 15 significant digits.': {
'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding',
},
'SQLite naively remakes the table on field alteration.': {
'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops',
'schema.tests.SchemaTests.test_unique_and_reverse_m2m',
'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries',
'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references',
},
"SQLite doesn't have a constraint.": {
'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
},
"SQLite doesn't support negative precision for ROUND().": {
'db_functions.math.test_round.RoundTests.test_null_with_negative_precision',
'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision',
'db_functions.math.test_round.RoundTests.test_float_with_negative_precision',
'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision',
},
}
if Database.sqlite_version_info < (3, 27):
skips.update({
'Nondeterministic failure on SQLite < 3.27.': {
'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank',
},
})
if self.connection.is_in_memory_db():
skips.update({
"the sqlite backend's close() method is a no-op when using an "
"in-memory database": {
'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections',
'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections',
},
})
return skips
@cached_property
def supports_atomic_references_rename(self):
# SQLite 3.28.0 bundled with MacOS 10.15 does not support renaming
# references atomically.
if platform.mac_ver()[0].startswith('10.15.') and Database.sqlite_version_info == (3, 28, 0):
return False
return Database.sqlite_version_info >= (3, 26, 0)
@cached_property
def introspected_field_types(self):
return{
**super().introspected_field_types,
'BigAutoField': 'AutoField',
'DurationField': 'BigIntegerField',
'GenericIPAddressField': 'CharField',
'SmallAutoField': 'AutoField',
}
@cached_property
def supports_json_field(self):
with self.connection.cursor() as cursor:
try:
with transaction.atomic(self.connection.alias):
cursor.execute('SELECT JSON(\'{"a": "b"}\')')
except OperationalError:
return False
return True
can_introspect_json_field = property(operator.attrgetter('supports_json_field'))
has_json_object_function = property(operator.attrgetter('supports_json_field'))
@cached_property
def can_return_columns_from_insert(self):
return Database.sqlite_version_info >= (3, 35)
can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))

View File

@@ -0,0 +1,470 @@
import re
from collections import namedtuple
import sqlparse
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint'))
field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')
def get_field_size(name):
""" Extract the size number from a "varchar(11)" type name """
m = field_size_re.search(name)
return int(m[1]) if m else None
# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallint unsigned': 'PositiveSmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'bigint': 'BigIntegerField',
'integer unsigned': 'PositiveIntegerField',
'bigint unsigned': 'PositiveBigIntegerField',
'decimal': 'DecimalField',
'real': 'FloatField',
'text': 'TextField',
'char': 'CharField',
'varchar': 'CharField',
'blob': 'BinaryField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
def __getitem__(self, key):
key = key.lower().split('(', 1)[0].strip()
return self.base_data_types_reverse[key]
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = FlexibleFieldLookupDict()
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}:
# No support for BigAutoField or SmallAutoField as SQLite treats
# all integer primary keys as signed 64-bit integers.
return 'AutoField'
if description.has_json_constraint:
return 'JSONField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute("""
SELECT name, type FROM sqlite_master
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
ORDER BY name""")
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name))
table_info = cursor.fetchall()
collations = self._get_column_collations(cursor, table_name)
json_columns = set()
if self.connection.features.can_introspect_json_field:
for line in table_info:
column = line[1]
json_constraint_sql = '%%json_valid("%s")%%' % column
has_json_constraint = cursor.execute("""
SELECT sql
FROM sqlite_master
WHERE
type = 'table' AND
name = %s AND
sql LIKE %s
""", [table_name, json_constraint_sql]).fetchone()
if has_json_constraint:
json_columns.add(column)
return [
FieldInfo(
name, data_type, None, get_field_size(data_type), None, None,
not notnull, default, collations.get(name), pk == 1, name in json_columns
)
for cid, name, data_type, notnull, default, pk in table_info
]
def get_sequences(self, cursor, table_name, table_fields=()):
pk_col = self.get_primary_key_column(cursor, table_name)
return [{'table': table_name, 'column': pk_col}]
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
# Dictionary of relations to return
relations = {}
# Schema for this table
cursor.execute(
"SELECT sql, type FROM sqlite_master "
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name]
)
create_sql, table_type = cursor.fetchone()
if table_type == 'view':
# It might be a view, then no results will be returned
return relations
results = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_desc in results.split(','):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
m = re.search(r'references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
if not m:
continue
table, column = [s.strip('"') for s in m.groups()]
if field_desc.startswith("FOREIGN KEY"):
# Find name of the target FK field
m = re.match(r'FOREIGN KEY\s*\(([^\)]*)\).*', field_desc, re.I)
field_name = m[1].strip('"')
else:
field_name = field_desc.split()[0].strip('"')
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
result = cursor.fetchall()[0]
other_table_results = result[0].strip()
li, ri = other_table_results.index('('), other_table_results.rindex(')')
other_table_results = other_table_results[li + 1:ri]
for other_desc in other_table_results.split(','):
other_desc = other_desc.strip()
if other_desc.startswith('UNIQUE'):
continue
other_name = other_desc.split(' ', 1)[0].strip('"')
if other_name == column:
relations[field_name] = (other_name, table)
break
return relations
def get_key_columns(self, cursor, table_name):
"""
Return a list of (column_name, referenced_table_name, referenced_column_name)
for all key columns in given table.
"""
key_columns = []
# Schema for this table
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
results = cursor.fetchone()[0].strip()
results = results[results.index('(') + 1:results.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_index, field_desc in enumerate(results.split(',')):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
m = re.search(r'"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I)
if not m:
continue
# This will append (column_name, referenced_table_name, referenced_column_name) to key_columns
key_columns.append(tuple(s.strip('"') for s in m.groups()))
return key_columns
def get_primary_key_column(self, cursor, table_name):
"""Return the column name of the primary key for the given table."""
# Don't use PRAGMA because that causes issues with some transactions
cursor.execute(
"SELECT sql, type FROM sqlite_master "
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name]
)
row = cursor.fetchone()
if row is None:
raise ValueError("Table %s does not exist" % table_name)
create_sql, table_type = row
if table_type == 'view':
# Views don't have a primary key.
return None
fields_sql = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
for field_desc in fields_sql.split(','):
field_desc = field_desc.strip()
m = re.match(r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc)
if m:
return m[1] if m[1] else m[2]
return None
def _get_foreign_key_constraints(self, cursor, table_name):
constraints = {}
cursor.execute('PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
# Remaining on_update/on_delete/match values are of no interest.
id_, _, table, from_, to = row[:5]
constraints['fk_%d' % id_] = {
'columns': [from_],
'primary_key': False,
'unique': False,
'foreign_key': (table, to),
'check': False,
'index': False,
}
return constraints
def _parse_column_or_constraint_definition(self, tokens, columns):
token = None
is_constraint_definition = None
field_name = None
constraint_name = None
unique = False
unique_columns = []
check = False
check_columns = []
braces_deep = 0
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, '('):
braces_deep += 1
elif token.match(sqlparse.tokens.Punctuation, ')'):
braces_deep -= 1
if braces_deep < 0:
# End of columns and constraints for table definition.
break
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
# End of current column or constraint definition.
break
# Detect column or constraint definition by first token.
if is_constraint_definition is None:
is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
if is_constraint_definition:
continue
if is_constraint_definition:
# Detect constraint name by second token.
if constraint_name is None:
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
constraint_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
constraint_name = token.value[1:-1]
# Start constraint columns parsing after UNIQUE keyword.
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique = True
unique_braces_deep = braces_deep
elif unique:
if unique_braces_deep == braces_deep:
if unique_columns:
# Stop constraint parsing.
unique = False
continue
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
unique_columns.append(token.value)
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
unique_columns.append(token.value[1:-1])
else:
# Detect field name by first token.
if field_name is None:
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
field_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
field_name = token.value[1:-1]
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique_columns = [field_name]
# Start constraint columns parsing after CHECK keyword.
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
check = True
check_braces_deep = braces_deep
elif check:
if check_braces_deep == braces_deep:
if check_columns:
# Stop constraint parsing.
check = False
continue
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
if token.value in columns:
check_columns.append(token.value)
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
if token.value[1:-1] in columns:
check_columns.append(token.value[1:-1])
unique_constraint = {
'unique': True,
'columns': unique_columns,
'primary_key': False,
'foreign_key': None,
'check': False,
'index': False,
} if unique_columns else None
check_constraint = {
'check': True,
'columns': check_columns,
'primary_key': False,
'unique': False,
'foreign_key': None,
'index': False,
} if check_columns else None
return constraint_name, unique_constraint, check_constraint, token
def _parse_table_constraints(self, sql, columns):
# Check constraint parsing is based of SQLite syntax diagram.
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
statement = sqlparse.parse(sql)[0]
constraints = {}
unnamed_constrains_index = 0
tokens = (token for token in statement.flatten() if not token.is_whitespace)
# Go to columns and constraint definition
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, '('):
break
# Parse columns and constraint definition
while True:
constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
if unique:
if constraint_name:
constraints[constraint_name] = unique
else:
unnamed_constrains_index += 1
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
if check:
if constraint_name:
constraints[constraint_name] = check
else:
unnamed_constrains_index += 1
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
if end_token.match(sqlparse.tokens.Punctuation, ')'):
break
return constraints
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns.
"""
constraints = {}
# Find inline check constraints.
try:
table_schema = cursor.execute(
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % (
self.connection.ops.quote_name(table_name),
)
).fetchone()[0]
except TypeError:
# table_name is a view.
pass
else:
columns = {info.name for info in self.get_table_description(cursor, table_name)}
constraints.update(self._parse_table_constraints(table_schema, columns))
# Get the index info
cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
# columns. Discard last 2 columns if there.
number, index, unique = row[:3]
cursor.execute(
"SELECT sql FROM sqlite_master "
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
# There's at most one row.
sql, = cursor.fetchone() or (None,)
# Inline constraints are already detected in
# _parse_table_constraints(). The reasons to avoid fetching inline
# constraints from `PRAGMA index_list` are:
# - Inline constraints can have a different name and information
# than what `PRAGMA index_list` gives.
# - Not all inline constraints may appear in `PRAGMA index_list`.
if not sql:
# An inline constraint
continue
# Get the index info for that index
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
for index_rank, column_rank, column in cursor.fetchall():
if index not in constraints:
constraints[index] = {
"columns": [],
"primary_key": False,
"unique": bool(unique),
"foreign_key": None,
"check": False,
"index": True,
}
constraints[index]['columns'].append(column)
# Add type and column orders for indexes
if constraints[index]['index']:
# SQLite doesn't support any index type other than b-tree
constraints[index]['type'] = Index.suffix
orders = self._get_index_columns_orders(sql)
if orders is not None:
constraints[index]['orders'] = orders
# Get the PK
pk_column = self.get_primary_key_column(cursor, table_name)
if pk_column:
# SQLite doesn't actually give a name to the PK constraint,
# so we invent one. This is fine, as the SQLite backend never
# deletes PK constraints by name, as you can't delete constraints
# in SQLite; we remake the table with a new PK instead.
constraints["__primary__"] = {
"columns": [pk_column],
"primary_key": True,
"unique": False, # It's not actually a unique constraint.
"foreign_key": None,
"check": False,
"index": False,
}
constraints.update(self._get_foreign_key_constraints(cursor, table_name))
return constraints
def _get_index_columns_orders(self, sql):
tokens = sqlparse.parse(sql)[0]
for token in tokens:
if isinstance(token, sqlparse.sql.Parenthesis):
columns = str(token).strip('()').split(', ')
return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns]
return None
def _get_column_collations(self, cursor, table_name):
row = cursor.execute("""
SELECT sql
FROM sqlite_master
WHERE type = 'table' AND name = %s
""", [table_name]).fetchone()
if not row:
return {}
sql = row[0]
columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ')
collations = {}
for column in columns:
tokens = column[1:].split()
column_name = tokens[0].strip('"')
for index, token in enumerate(tokens):
if token == 'COLLATE':
collation = tokens[index + 1]
break
else:
collation = None
collations[column_name] = collation
return collations

View File

@@ -0,0 +1,386 @@
import datetime
import decimal
import uuid
from functools import lru_cache
from itertools import chain
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = 'text'
cast_data_types = {
'DateField': 'TEXT',
'DateTimeField': 'TEXT',
}
explain_prefix = 'EXPLAIN QUERY PLAN'
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
def bulk_batch_size(self, fields, objs):
"""
SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
999 variables per query.
If there's only a single field to insert, the limit is 500
(SQLITE_MAX_COMPOUND_SELECT).
"""
if len(fields) == 1:
return 500
elif len(fields) > 1:
return self.connection.features.max_query_params // len(fields)
else:
return len(objs)
def check_expression_support(self, expression):
bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
if isinstance(expression, bad_aggregates):
for expr in expression.get_source_expressions():
try:
output_field = expr.output_field
except (AttributeError, FieldError):
# Not every subexpression has an output_field which is fine
# to ignore.
pass
else:
if isinstance(output_field, bad_fields):
raise NotSupportedError(
'You cannot use Sum, Avg, StdDev, and Variance '
'aggregations on date/time fields in sqlite3 '
'since date/time is saved as text.'
)
if (
isinstance(expression, models.Aggregate) and
expression.distinct and
len(expression.source_expressions) > 1
):
raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
"accepting multiple arguments."
)
def date_extract_sql(self, lookup_type, field_name):
"""
Support EXTRACT with a user-defined function django_date_extract()
that's registered in connect(). Use single quotes because this is a
string and could otherwise cause a collision with a field name.
"""
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the list of returned data.
"""
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
"""Do nothing since formatting is handled in the custom function."""
return sql
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_date_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_time_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
)
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
return 'NULL', 'NULL'
def datetime_cast_date_sql(self, field_name, tzname):
return 'django_datetime_cast_date(%s, %s, %s)' % (
field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname):
return 'django_datetime_cast_time(%s, %s, %s)' % (
field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % (
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name):
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)
def pk_default_value(self):
return "NULL"
def _quote_params_for_last_executed_query(self, params):
"""
Only for last_executed_query! Don't use this to execute SQL queries!
"""
# This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
# number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
# number of return values, default = 2000). Since Python's sqlite3
# module doesn't expose the get_limit() C API, assume the default
# limits are in effect and split the work in batches if needed.
BATCH_SIZE = 999
if len(params) > BATCH_SIZE:
results = ()
for index in range(0, len(params), BATCH_SIZE):
chunk = params[index:index + BATCH_SIZE]
results += self._quote_params_for_last_executed_query(chunk)
return results
sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
# Bypass Django's wrappers and use the underlying sqlite3 connection
# to avoid logging this query - it would trigger infinite recursion.
cursor = self.connection.connection.cursor()
# Native sqlite3 cursors cannot be used as context managers.
try:
return cursor.execute(sql, params).fetchone()
finally:
cursor.close()
def last_executed_query(self, cursor, sql, params):
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
# pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
# Unfortunately there is no way to reach self->statement from Python,
# so we quote and substitute parameters manually.
if params:
if isinstance(params, (list, tuple)):
params = self._quote_params_for_last_executed_query(params)
else:
values = tuple(params.values())
values = self._quote_params_for_last_executed_query(values)
params = dict(zip(params, values))
return sql % params
# For consistency with SQLiteCursorWrapper.execute(), just return sql
# when there are no parameters. See #13648 and #17158.
else:
return sql
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
return '"%s"' % name
def no_limit_value(self):
return -1
def __references_graph(self, table_name):
query = """
WITH tables AS (
SELECT %s name
UNION
SELECT sqlite_master.name
FROM sqlite_master
JOIN tables ON (sql REGEXP %s || tables.name || %s)
) SELECT name FROM tables;
"""
params = (
table_name,
r'(?i)\s+references\s+("|\')?',
r'("|\')?\s*\(',
)
with self.connection.cursor() as cursor:
results = cursor.execute(query, params)
return [row[0] for row in results.fetchall()]
@cached_property
def _references_graph(self):
# 512 is large enough to fit the ~330 tables (as of this writing) in
# Django's test suite.
return lru_cache(maxsize=512)(self.__references_graph)
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
sql = ['%s %s %s;' % (
style.SQL_KEYWORD('DELETE'),
style.SQL_KEYWORD('FROM'),
style.SQL_FIELD(self.quote_name(table))
) for table in tables]
if reset_sequences:
sequences = [{'table': table} for table in tables]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
def sequence_reset_by_name_sql(self, style, sequences):
if not sequences:
return []
return [
'%s %s %s %s = 0 %s %s %s (%s);' % (
style.SQL_KEYWORD('UPDATE'),
style.SQL_TABLE(self.quote_name('sqlite_sequence')),
style.SQL_KEYWORD('SET'),
style.SQL_FIELD(self.quote_name('seq')),
style.SQL_KEYWORD('WHERE'),
style.SQL_FIELD(self.quote_name('name')),
style.SQL_KEYWORD('IN'),
', '.join([
"'%s'" % sequence_info['table'] for sequence_info in sequences
]),
),
]
def adapt_datetimefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")
return str(value)
def adapt_timefield_value(self, value):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
if timezone.is_aware(value):
raise ValueError("SQLite backend does not support timezone-aware times.")
return str(value)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == 'DateTimeField':
converters.append(self.convert_datetimefield_value)
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
elif internal_type == 'DecimalField':
converters.append(self.get_decimalfield_converter(expression))
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
elif internal_type == 'BooleanField':
converters.append(self.convert_booleanfield_value)
return converters
def convert_datetimefield_value(self, value, expression, connection):
if value is not None:
if not isinstance(value, datetime.datetime):
value = parse_datetime(value)
if settings.USE_TZ and not timezone.is_aware(value):
value = timezone.make_aware(value, self.connection.timezone)
return value
def convert_datefield_value(self, value, expression, connection):
if value is not None:
if not isinstance(value, datetime.date):
value = parse_date(value)
return value
def convert_timefield_value(self, value, expression, connection):
if value is not None:
if not isinstance(value, datetime.time):
value = parse_time(value)
return value
def get_decimalfield_converter(self, expression):
# SQLite stores only 15 significant digits. Digits coming from
# float inaccuracy must be removed.
create_decimal = decimal.Context(prec=15).create_decimal_from_float
if isinstance(expression, Col):
quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)
def converter(value, expression, connection):
if value is not None:
return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
else:
def converter(value, expression, connection):
if value is not None:
return create_decimal(value)
return converter
def convert_uuidfield_value(self, value, expression, connection):
if value is not None:
value = uuid.UUID(value)
return value
def convert_booleanfield_value(self, value, expression, connection):
return bool(value) if value in (1, 0) else value
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s" % ", ".join(row)
for row in placeholder_rows
)
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
if connector == '^':
return 'POWER(%s)' % ','.join(sub_expressions)
elif connector == '#':
return 'BITXOR(%s)' % ','.join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
if connector not in ['+', '-', '*', '/']:
raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3:
raise ValueError('Too many params for timedelta operations.')
return "django_format_dtdelta(%s)" % ', '.join(fn_params)
def integer_field_range(self, internal_type):
# SQLite doesn't enforce any integer constraints
return (None, None)
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
if internal_type == 'TimeField':
return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
def insert_statement(self, ignore_conflicts=False):
return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
return '', ()
columns = [
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()

View File

@@ -0,0 +1,444 @@
import copy
from decimal import Decimal
from django.apps.registry import Apps
from django.db import NotSupportedError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes
from django.db.models import UniqueConstraint
from django.db.transaction import atomic
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s"
def __enter__(self):
# Some SQLite schema alterations need foreign key constraints to be
# disabled. Enforce it here for the duration of the schema edition.
if not self.connection.disable_constraint_checking():
raise NotSupportedError(
'SQLite schema editor cannot be used while foreign key '
'constraint checks are enabled. Make sure to disable them '
'before entering a transaction.atomic() context because '
'SQLite does not support disabling them in the middle of '
'a multi-statement transaction.'
)
return super().__enter__()
def __exit__(self, exc_type, exc_value, traceback):
self.connection.check_constraints()
super().__exit__(exc_type, exc_value, traceback)
self.connection.enable_constraint_checking()
def quote_value(self, value):
# The backend "mostly works" without this function and there are use
# cases for compiling Python without the sqlite3 libraries (e.g.
# security hardening).
try:
import sqlite3
value = sqlite3.adapt(value)
except ImportError:
pass
except sqlite3.ProgrammingError:
pass
# Manual emulation of SQLite parameter quoting
if isinstance(value, bool):
return str(int(value))
elif isinstance(value, (Decimal, float, int)):
return str(value)
elif isinstance(value, str):
return "'%s'" % value.replace("\'", "\'\'")
elif value is None:
return "NULL"
elif isinstance(value, (bytes, bytearray, memoryview)):
# Bytes are only allowed for BLOB fields, encoded as string
# literals containing hexadecimal data and preceded by a single "X"
# character.
return "X'%s'" % value.hex()
else:
raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))
def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
"""
Return whether or not the provided table name is referenced by another
one. If `column_name` is specified, only references pointing to that
column are considered. If `ignore_self` is True, self-referential
constraints are ignored.
"""
with self.connection.cursor() as cursor:
for other_table in self.connection.introspection.get_table_list(cursor):
if ignore_self and other_table.name == table_name:
continue
constraints = self.connection.introspection._get_foreign_key_constraints(cursor, other_table.name)
for constraint in constraints.values():
constraint_table, constraint_column = constraint['foreign_key']
if (constraint_table == table_name and
(column_name is None or constraint_column == column_name)):
return True
return False
def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
if (not self.connection.features.supports_atomic_references_rename and
disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)):
if self.connection.in_atomic_block:
raise NotSupportedError((
'Renaming the %r table while in a transaction is not '
'supported on SQLite < 3.26 because it would break referential '
'integrity. Try adding `atomic = False` to the Migration class.'
) % old_db_table)
self.connection.enable_constraint_checking()
super().alter_db_table(model, old_db_table, new_db_table)
self.connection.disable_constraint_checking()
else:
super().alter_db_table(model, old_db_table, new_db_table)
def alter_field(self, model, old_field, new_field, strict=False):
if not self._field_should_be_altered(old_field, new_field):
return
old_field_name = old_field.name
table_name = model._meta.db_table
_, old_column_name = old_field.get_attname_column()
if (new_field.name != old_field_name and
not self.connection.features.supports_atomic_references_rename and
self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
if self.connection.in_atomic_block:
raise NotSupportedError((
'Renaming the %r.%r column while in a transaction is not '
'supported on SQLite < 3.26 because it would break referential '
'integrity. Try adding `atomic = False` to the Migration class.'
) % (model._meta.db_table, old_field_name))
with atomic(self.connection.alias):
super().alter_field(model, old_field, new_field, strict=strict)
# Follow SQLite's documented procedure for performing changes
# that don't affect the on-disk content.
# https://sqlite.org/lang_altertable.html#otheralter
with self.connection.cursor() as cursor:
schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
cursor.execute('PRAGMA writable_schema = 1')
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
new_column_name = new_field.get_attname_column()[1]
search = references_template % old_column_name
replacement = references_template % new_column_name
cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
cursor.execute('PRAGMA writable_schema = 0')
# The integrity check will raise an exception and rollback
# the transaction if the sqlite_master updates corrupt the
# database.
cursor.execute('PRAGMA integrity_check')
# Perform a VACUUM to refresh the database representation from
# the sqlite_master table.
with self.connection.cursor() as cursor:
cursor.execute('VACUUM')
else:
super().alter_field(model, old_field, new_field, strict=strict)
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
"""
Shortcut to transform a model from old_model into new_model
This follows the correct procedure to perform non-rename or column
addition operations based on SQLite's documentation
https://www.sqlite.org/lang_altertable.html#caution
The essential steps are:
1. Create a table with the updated definition called "new__app_model"
2. Copy the data from the existing "app_model" table to the new table
3. Drop the "app_model" table
4. Rename the "new__app_model" table to "app_model"
5. Restore any index of the previous "app_model" table.
"""
# Self-referential fields must be recreated rather than copied from
# the old model to ensure their remote_field.field_name doesn't refer
# to an altered field.
def is_self_referential(f):
return f.is_relation and f.remote_field.model is model
# Work out the new fields dict / mapping
body = {
f.name: f.clone() if is_self_referential(f) else f
for f in model._meta.local_concrete_fields
}
# Since mapping might mix column names and default values,
# its values must be already quoted.
mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
# This maps field names (not columns) for things like unique_together
rename_mapping = {}
# If any of the new or altered fields is introducing a new PK,
# remove the old one
restore_pk_field = None
if getattr(create_field, 'primary_key', False) or (
alter_field and getattr(alter_field[1], 'primary_key', False)):
for name, field in list(body.items()):
if field.primary_key:
field.primary_key = False
restore_pk_field = field
if field.auto_created:
del body[name]
del mapping[field.column]
# Add in any created fields
if create_field:
body[create_field.name] = create_field
# Choose a default and insert it into the copy map
if not create_field.many_to_many and create_field.concrete:
mapping[create_field.column] = self.quote_value(
self.effective_default(create_field)
)
# Add in any altered fields
if alter_field:
old_field, new_field = alter_field
body.pop(old_field.name, None)
mapping.pop(old_field.column, None)
body[new_field.name] = new_field
if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
'col': self.quote_name(old_field.column),
'default': self.quote_value(self.effective_default(new_field))
}
mapping[new_field.column] = case_sql
else:
mapping[new_field.column] = self.quote_name(old_field.column)
rename_mapping[old_field.name] = new_field.name
# Remove any deleted fields
if delete_field:
del body[delete_field.name]
del mapping[delete_field.column]
# Remove any implicit M2M tables
if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
return self.delete_model(delete_field.remote_field.through)
# Work inside a new app registry
apps = Apps()
# Work out the new value of unique_together, taking renames into
# account
unique_together = [
[rename_mapping.get(n, n) for n in unique]
for unique in model._meta.unique_together
]
# Work out the new value for index_together, taking renames into
# account
index_together = [
[rename_mapping.get(n, n) for n in index]
for index in model._meta.index_together
]
indexes = model._meta.indexes
if delete_field:
indexes = [
index for index in indexes
if delete_field.name not in index.fields
]
constraints = list(model._meta.constraints)
# Provide isolated instances of the fields to the new model body so
# that the existing model's internals aren't interfered with when
# the dummy model is constructed.
body_copy = copy.deepcopy(body)
# Construct a new model with the new fields to allow self referential
# primary key to resolve to. This model won't ever be materialized as a
# table and solely exists for foreign key reference resolution purposes.
# This wouldn't be required if the schema editor was operating on model
# states instead of rendered models.
meta_contents = {
'app_label': model._meta.app_label,
'db_table': model._meta.db_table,
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
body_copy['Meta'] = meta
body_copy['__module__'] = model.__module__
type(model._meta.object_name, model.__bases__, body_copy)
# Construct a model with a renamed table name.
body_copy = copy.deepcopy(body)
meta_contents = {
'app_label': model._meta.app_label,
'db_table': 'new__%s' % strip_quotes(model._meta.db_table),
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
body_copy['Meta'] = meta
body_copy['__module__'] = model.__module__
new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy)
# Create a new table with the updated schema.
self.create_model(new_model)
# Copy data from the old table into the new table
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_model._meta.db_table),
', '.join(self.quote_name(x) for x in mapping),
', '.join(mapping.values()),
self.quote_name(model._meta.db_table),
))
# Delete the old table to make way for the new
self.delete_model(model, handle_autom2m=False)
# Rename the new table to take way for the old
self.alter_db_table(
new_model, new_model._meta.db_table, model._meta.db_table,
disable_constraints=False,
)
# Run deferred SQL on correct table
for sql in self.deferred_sql:
self.execute(sql)
self.deferred_sql = []
# Fix any PK-removed field
if restore_pk_field:
restore_pk_field.primary_key = True
def delete_model(self, model, handle_autom2m=True):
if handle_autom2m:
super().delete_model(model)
else:
# Delete the table (and only that)
self.execute(self.sql_delete_table % {
"table": self.quote_name(model._meta.db_table),
})
# Remove all deferred statements referencing the deleted table.
for sql in list(self.deferred_sql):
if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
self.deferred_sql.remove(sql)
def add_field(self, model, field):
"""
Create a field on a model. Usually involves adding a column, but may
involve adding a table instead (for M2M fields).
"""
# Special-case implicit M2M tables
if field.many_to_many and field.remote_field.through._meta.auto_created:
return self.create_model(field.remote_field.through)
self._remake_table(model, create_field=field)
def remove_field(self, model, field):
"""
Remove a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table.
"""
# M2M fields are a special case
if field.many_to_many:
# For implicit M2M tables, delete the auto-created table
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
# For explicit "through" M2M fields, do nothing
# For everything else, remake.
else:
# It might not actually have a column behind it
if field.db_parameters(connection=self.connection)['type'] is None:
return
self._remake_table(model, delete_field=field)
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
"""Perform a "physical" (non-ManyToMany) field update."""
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
# changed and there aren't any constraints.
if (self.connection.features.can_alter_table_rename_column and
old_field.column != new_field.column and
self.column_sql(model, old_field) == self.column_sql(model, new_field) and
not (old_field.remote_field and old_field.db_constraint or
new_field.remote_field and new_field.db_constraint)):
return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
# Alter by remaking table
self._remake_table(model, alter_field=(old_field, new_field))
# Rebuild tables with FKs pointing to this field.
if new_field.unique and old_type != new_type:
related_models = set()
opts = new_field.model._meta
for remote_field in opts.related_objects:
# Ignore self-relationship since the table was already rebuilt.
if remote_field.related_model == model:
continue
if not remote_field.many_to_many:
if remote_field.field_name == new_field.name:
related_models.add(remote_field.related_model)
elif new_field.primary_key and remote_field.through._meta.auto_created:
related_models.add(remote_field.through)
if new_field.primary_key:
for many_to_many in opts.many_to_many:
# Ignore self-relationship since the table was already rebuilt.
if many_to_many.related_model == model:
continue
if many_to_many.remote_field.through._meta.auto_created:
related_models.add(many_to_many.remote_field.through)
for related_model in related_models:
self._remake_table(related_model)
def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""
if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
# The field name didn't change, but some options did; we have to propagate this altering.
self._remake_table(
old_field.remote_field.through,
alter_field=(
# We need the field that points to the target model, so we can tell alter_field to change it -
# this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
),
)
return
# Make a new through table
self.create_model(new_field.remote_field.through)
# Copy the data across
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_field.remote_field.through._meta.db_table),
', '.join([
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]),
', '.join([
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]),
self.quote_name(old_field.remote_field.through._meta.db_table),
))
# Delete the old through table
self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition or
constraint.contains_expressions or
constraint.include or
constraint.deferrable
):
super().add_constraint(model, constraint)
else:
self._remake_table(model)
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition or
constraint.contains_expressions or
constraint.include or
constraint.deferrable
):
super().remove_constraint(model, constraint)
else:
self._remake_table(model)
def _collate_sql(self, collation):
return 'COLLATE ' + collation

View File

@@ -0,0 +1,263 @@
import datetime
import decimal
import functools
import hashlib
import logging
import time
from contextlib import contextmanager
from django.db import NotSupportedError
from django.utils.dateparse import parse_time
logger = logging.getLogger('django.db.backends')
class CursorWrapper:
def __init__(self, cursor, db):
self.cursor = cursor
self.db = db
WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
def __getattr__(self, attr):
cursor_attr = getattr(self.cursor, attr)
if attr in CursorWrapper.WRAP_ERROR_ATTRS:
return self.db.wrap_database_errors(cursor_attr)
else:
return cursor_attr
def __iter__(self):
with self.db.wrap_database_errors:
yield from self.cursor
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
# Close instead of passing through to avoid backend-specific behavior
# (#17671). Catch errors liberally because errors in cleanup code
# aren't useful.
try:
self.close()
except self.db.Database.Error:
pass
# The following methods cannot be implemented in __getattr__, because the
# code must run when the method is invoked, not just when it is accessed.
def callproc(self, procname, params=None, kparams=None):
# Keyword parameters for callproc aren't supported in PEP 249, but the
# database driver may support them (e.g. cx_Oracle).
if kparams is not None and not self.db.features.supports_callproc_kwargs:
raise NotSupportedError(
'Keyword parameters for callproc are not supported on this '
'database backend.'
)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
if params is None and kparams is None:
return self.cursor.callproc(procname)
elif kparams is None:
return self.cursor.callproc(procname, params)
else:
params = params or ()
return self.cursor.callproc(procname, params, kparams)
def execute(self, sql, params=None):
return self._execute_with_wrappers(sql, params, many=False, executor=self._execute)
def executemany(self, sql, param_list):
return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany)
def _execute_with_wrappers(self, sql, params, many, executor):
context = {'connection': self.db, 'cursor': self}
for wrapper in reversed(self.db.execute_wrappers):
executor = functools.partial(wrapper, executor)
return executor(sql, params, many, context)
def _execute(self, sql, params, *ignored_wrapper_args):
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
if params is None:
# params default might be backend specific.
return self.cursor.execute(sql)
else:
return self.cursor.execute(sql, params)
def _executemany(self, sql, param_list, *ignored_wrapper_args):
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
return self.cursor.executemany(sql, param_list)
class CursorDebugWrapper(CursorWrapper):
# XXX callproc isn't instrumented at this time.
def execute(self, sql, params=None):
with self.debug_sql(sql, params, use_last_executed_query=True):
return super().execute(sql, params)
def executemany(self, sql, param_list):
with self.debug_sql(sql, param_list, many=True):
return super().executemany(sql, param_list)
@contextmanager
def debug_sql(self, sql=None, params=None, use_last_executed_query=False, many=False):
start = time.monotonic()
try:
yield
finally:
stop = time.monotonic()
duration = stop - start
if use_last_executed_query:
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
try:
times = len(params) if many else ''
except TypeError:
# params could be an iterator.
times = '?'
self.db.queries_log.append({
'sql': '%s times: %s' % (times, sql) if many else sql,
'time': '%.3f' % duration,
})
logger.debug(
'(%.3f) %s; args=%s; alias=%s',
duration,
sql,
params,
self.db.alias,
extra={'duration': duration, 'sql': sql, 'params': params, 'alias': self.db.alias},
)
def split_tzname_delta(tzname):
"""
Split a time zone name into a 3-tuple of (name, sign, offset).
"""
for sign in ['+', '-']:
if sign in tzname:
name, offset = tzname.rsplit(sign, 1)
if offset and parse_time(offset):
return name, sign, offset
return tzname, None, None
###############################################
# Converters from database (string) to Python #
###############################################
def typecast_date(s):
return datetime.date(*map(int, s.split('-'))) if s else None # return None if s is null
def typecast_time(s): # does NOT store time zone information
if not s:
return None
hour, minutes, seconds = s.split(':')
if '.' in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split('.')
else:
microseconds = '0'
return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6]))
def typecast_timestamp(s): # does NOT store time zone information
# "2005-07-29 15:48:00.590358-05"
# "2005-07-29 09:56:00-05"
if not s:
return None
if ' ' not in s:
return typecast_date(s)
d, t = s.split()
# Remove timezone information.
if '-' in t:
t, _ = t.split('-', 1)
elif '+' in t:
t, _ = t.split('+', 1)
dates = d.split('-')
times = t.split(':')
seconds = times[2]
if '.' in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split('.')
else:
microseconds = '0'
return datetime.datetime(
int(dates[0]), int(dates[1]), int(dates[2]),
int(times[0]), int(times[1]), int(seconds),
int((microseconds + '000000')[:6])
)
###############################################
# Converters from Python to database (string) #
###############################################
def split_identifier(identifier):
"""
Split an SQL identifier into a two element tuple of (namespace, name).
The identifier could be a table, column, or sequence name might be prefixed
by a namespace.
"""
try:
namespace, name = identifier.split('"."')
except ValueError:
namespace, name = '', identifier
return namespace.strip('"'), name.strip('"')
def truncate_name(identifier, length=None, hash_len=4):
"""
Shorten an SQL identifier to a repeatable mangled version with the given
length.
If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
truncate the table portion only.
"""
namespace, name = split_identifier(identifier)
if length is None or len(name) <= length:
return identifier
digest = names_digest(name, length=hash_len)
return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest)
def names_digest(*args, length):
"""
Generate a 32-bit digest of a set of arguments that can be used to shorten
identifying names.
"""
h = hashlib.md5()
for arg in args:
h.update(arg.encode())
return h.hexdigest()[:length]
def format_number(value, max_digits, decimal_places):
"""
Format a number into a string with the requisite number of digits and
decimal places.
"""
if value is None:
return None
context = decimal.getcontext().copy()
if max_digits is not None:
context.prec = max_digits
if decimal_places is not None:
value = value.quantize(decimal.Decimal(1).scaleb(-decimal_places), context=context)
else:
context.traps[decimal.Rounded] = 1
value = context.create_decimal(value)
return "{:f}".format(value)
def strip_quotes(table_name):
"""
Strip quotes off of quoted table names to make them safe for use in index
names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
scheme) becomes 'USER"."TABLE'.
"""
has_quotes = table_name.startswith('"') and table_name.endswith('"')
return table_name[1:-1] if has_quotes else table_name

View File

@@ -0,0 +1,2 @@
from .migration import Migration, swappable_dependency # NOQA
from .operations import * # NOQA

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
from django.db import DatabaseError
class AmbiguityError(Exception):
"""More than one migration matches a name prefix."""
pass
class BadMigrationError(Exception):
"""There's a bad migration (unreadable/bad format/etc.)."""
pass
class CircularDependencyError(Exception):
"""There's an impossible-to-resolve circular dependency."""
pass
class InconsistentMigrationHistory(Exception):
"""An applied migration has some of its dependencies not applied."""
pass
class InvalidBasesError(ValueError):
"""A model's base classes can't be resolved."""
pass
class IrreversibleError(RuntimeError):
"""An irreversible migration is about to be reversed."""
pass
class NodeNotFoundError(LookupError):
"""An attempt on a node is made that is not available in the graph."""
def __init__(self, message, node, origin=None):
self.message = message
self.origin = origin
self.node = node
def __str__(self):
return self.message
def __repr__(self):
return "NodeNotFoundError(%r)" % (self.node,)
class MigrationSchemaMissing(DatabaseError):
pass
class InvalidMigrationPlan(ValueError):
pass

View File

@@ -0,0 +1,381 @@
from django.apps.registry import apps as global_apps
from django.db import migrations, router
from .exceptions import InvalidMigrationPlan
from .loader import MigrationLoader
from .recorder import MigrationRecorder
from .state import ProjectState
class MigrationExecutor:
"""
End-to-end migration execution - load migrations and run them up or down
to a specified set of targets.
"""
def __init__(self, connection, progress_callback=None):
self.connection = connection
self.loader = MigrationLoader(self.connection)
self.recorder = MigrationRecorder(self.connection)
self.progress_callback = progress_callback
def migration_plan(self, targets, clean_start=False):
"""
Given a set of targets, return a list of (Migration instance, backwards?).
"""
plan = []
if clean_start:
applied = {}
else:
applied = dict(self.loader.applied_migrations)
for target in targets:
# If the target is (app_label, None), that means unmigrate everything
if target[1] is None:
for root in self.loader.graph.root_nodes():
if root[0] == target[0]:
for migration in self.loader.graph.backwards_plan(root):
if migration in applied:
plan.append((self.loader.graph.nodes[migration], True))
applied.pop(migration)
# If the migration is already applied, do backwards mode,
# otherwise do forwards mode.
elif target in applied:
# If the target is missing, it's likely a replaced migration.
# Reload the graph without replacements.
if (
self.loader.replace_migrations and
target not in self.loader.graph.node_map
):
self.loader.replace_migrations = False
self.loader.build_graph()
return self.migration_plan(targets, clean_start=clean_start)
# Don't migrate backwards all the way to the target node (that
# may roll back dependencies in other apps that don't need to
# be rolled back); instead roll back through target's immediate
# child(ren) in the same app, and no further.
next_in_app = sorted(
n for n in
self.loader.graph.node_map[target].children
if n[0] == target[0]
)
for node in next_in_app:
for migration in self.loader.graph.backwards_plan(node):
if migration in applied:
plan.append((self.loader.graph.nodes[migration], True))
applied.pop(migration)
else:
for migration in self.loader.graph.forwards_plan(target):
if migration not in applied:
plan.append((self.loader.graph.nodes[migration], False))
applied[migration] = self.loader.graph.nodes[migration]
return plan
def _create_project_state(self, with_applied_migrations=False):
"""
Create a project state including all the applications without
migrations and applied migrations if with_applied_migrations=True.
"""
state = ProjectState(real_apps=self.loader.unmigrated_apps)
if with_applied_migrations:
# Create the forwards plan Django would follow on an empty database
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
applied_migrations = {
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
for migration, _ in full_plan:
if migration in applied_migrations:
migration.mutate_state(state, preserve=False)
return state
def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
"""
Migrate the database up to the given targets.
Django first needs to create all project states before a migration is
(un)applied and in a second step run all the database operations.
"""
# The django_migrations table must be present to record applied
# migrations.
self.recorder.ensure_schema()
if plan is None:
plan = self.migration_plan(targets)
# Create the forwards plan Django would follow on an empty database
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
all_forwards = all(not backwards for mig, backwards in plan)
all_backwards = all(backwards for mig, backwards in plan)
if not plan:
if state is None:
# The resulting state should include applied migrations.
state = self._create_project_state(with_applied_migrations=True)
elif all_forwards == all_backwards:
# This should only happen if there's a mixed plan
raise InvalidMigrationPlan(
"Migration plans with both forwards and backwards migrations "
"are not supported. Please split your migration process into "
"separate plans of only forwards OR backwards migrations.",
plan
)
elif all_forwards:
if state is None:
# The resulting state should still include applied migrations.
state = self._create_project_state(with_applied_migrations=True)
state = self._migrate_all_forwards(state, plan, full_plan, fake=fake, fake_initial=fake_initial)
else:
# No need to check for `elif all_backwards` here, as that condition
# would always evaluate to true.
state = self._migrate_all_backwards(plan, full_plan, fake=fake)
self.check_replacements()
return state
def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
"""
Take a list of 2-tuples of the form (migration instance, False) and
apply them in the order they occur in the full_plan.
"""
migrations_to_run = {m[0] for m in plan}
for migration, _ in full_plan:
if not migrations_to_run:
# We remove every migration that we applied from these sets so
# that we can bail out once the last migration has been applied
# and don't always run until the very end of the migration
# process.
break
if migration in migrations_to_run:
if 'apps' not in state.__dict__:
if self.progress_callback:
self.progress_callback("render_start")
state.apps # Render all -- performance critical
if self.progress_callback:
self.progress_callback("render_success")
state = self.apply_migration(state, migration, fake=fake, fake_initial=fake_initial)
migrations_to_run.remove(migration)
return state
def _migrate_all_backwards(self, plan, full_plan, fake):
"""
Take a list of 2-tuples of the form (migration instance, True) and
unapply them in reverse order they occur in the full_plan.
Since unapplying a migration requires the project state prior to that
migration, Django will compute the migration states before each of them
in a first run over the plan and then unapply them in a second run over
the plan.
"""
migrations_to_run = {m[0] for m in plan}
# Holds all migration states prior to the migrations being unapplied
states = {}
state = self._create_project_state()
applied_migrations = {
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
if self.progress_callback:
self.progress_callback("render_start")
for migration, _ in full_plan:
if not migrations_to_run:
# We remove every migration that we applied from this set so
# that we can bail out once the last migration has been applied
# and don't always run until the very end of the migration
# process.
break
if migration in migrations_to_run:
if 'apps' not in state.__dict__:
state.apps # Render all -- performance critical
# The state before this migration
states[migration] = state
# The old state keeps as-is, we continue with the new state
state = migration.mutate_state(state, preserve=True)
migrations_to_run.remove(migration)
elif migration in applied_migrations:
# Only mutate the state if the migration is actually applied
# to make sure the resulting state doesn't include changes
# from unrelated migrations.
migration.mutate_state(state, preserve=False)
if self.progress_callback:
self.progress_callback("render_success")
for migration, _ in plan:
self.unapply_migration(states[migration], migration, fake=fake)
applied_migrations.remove(migration)
# Generate the post migration state by starting from the state before
# the last migration is unapplied and mutating it to include all the
# remaining applied migrations.
last_unapplied_migration = plan[-1][0]
state = states[last_unapplied_migration]
for index, (migration, _) in enumerate(full_plan):
if migration == last_unapplied_migration:
for migration, _ in full_plan[index:]:
if migration in applied_migrations:
migration.mutate_state(state, preserve=False)
break
return state
def apply_migration(self, state, migration, fake=False, fake_initial=False):
"""Run a migration forwards."""
migration_recorded = False
if self.progress_callback:
self.progress_callback("apply_start", migration, fake)
if not fake:
if fake_initial:
# Test to see if this is an already-applied initial migration
applied, state = self.detect_soft_applied(state, migration)
if applied:
fake = True
if not fake:
# Alright, do it normally
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
state = migration.apply(state, schema_editor)
if not schema_editor.deferred_sql:
self.record_migration(migration)
migration_recorded = True
if not migration_recorded:
self.record_migration(migration)
# Report progress
if self.progress_callback:
self.progress_callback("apply_success", migration, fake)
return state
def record_migration(self, migration):
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_applied(app_label, name)
else:
self.recorder.record_applied(migration.app_label, migration.name)
def unapply_migration(self, state, migration, fake=False):
"""Run a migration backwards."""
if self.progress_callback:
self.progress_callback("unapply_start", migration, fake)
if not fake:
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
state = migration.unapply(state, schema_editor)
# For replacement migrations, also record individual statuses.
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_unapplied(app_label, name)
self.recorder.record_unapplied(migration.app_label, migration.name)
# Report progress
if self.progress_callback:
self.progress_callback("unapply_success", migration, fake)
return state
def check_replacements(self):
"""
Mark replacement migrations applied if their replaced set all are.
Do this unconditionally on every migrate, rather than just when
migrations are applied or unapplied, to correctly handle the case
when a new squash migration is pushed to a deployment that already had
all its replaced migrations applied. In this case no new migration will
be applied, but the applied state of the squashed migration must be
maintained.
"""
applied = self.recorder.applied_migrations()
for key, migration in self.loader.replacements.items():
all_applied = all(m in applied for m in migration.replaces)
if all_applied and key not in applied:
self.recorder.record_applied(*key)
def detect_soft_applied(self, project_state, migration):
"""
Test whether a migration has been implicitly applied - that the
tables or columns it would create exist. This is intended only for use
on initial migrations (as it only looks for CreateModel and AddField).
"""
def should_skip_detecting_model(migration, model):
"""
No need to detect tables for proxy models, unmanaged models, or
models that can't be migrated on the current database.
"""
return (
model._meta.proxy or not model._meta.managed or not
router.allow_migrate(
self.connection.alias, migration.app_label,
model_name=model._meta.model_name,
)
)
if migration.initial is None:
# Bail if the migration isn't the first one in its app
if any(app == migration.app_label for app, name in migration.dependencies):
return False, project_state
elif migration.initial is False:
# Bail if it's NOT an initial migration
return False, project_state
if project_state is None:
after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
else:
after_state = migration.mutate_state(project_state)
apps = after_state.apps
found_create_model_migration = False
found_add_field_migration = False
fold_identifier_case = self.connection.features.ignores_table_name_case
with self.connection.cursor() as cursor:
existing_table_names = set(self.connection.introspection.table_names(cursor))
if fold_identifier_case:
existing_table_names = {name.casefold() for name in existing_table_names}
# Make sure all create model and add field operations are done
for operation in migration.operations:
if isinstance(operation, migrations.CreateModel):
model = apps.get_model(migration.app_label, operation.name)
if model._meta.swapped:
# We have to fetch the model to test with from the
# main app cache, as it's not a direct dependency.
model = global_apps.get_model(model._meta.swapped)
if should_skip_detecting_model(migration, model):
continue
db_table = model._meta.db_table
if fold_identifier_case:
db_table = db_table.casefold()
if db_table not in existing_table_names:
return False, project_state
found_create_model_migration = True
elif isinstance(operation, migrations.AddField):
model = apps.get_model(migration.app_label, operation.model_name)
if model._meta.swapped:
# We have to fetch the model to test with from the
# main app cache, as it's not a direct dependency.
model = global_apps.get_model(model._meta.swapped)
if should_skip_detecting_model(migration, model):
continue
table = model._meta.db_table
field = model._meta.get_field(operation.name)
# Handle implicit many-to-many tables created by AddField.
if field.many_to_many:
through_db_table = field.remote_field.through._meta.db_table
if fold_identifier_case:
through_db_table = through_db_table.casefold()
if through_db_table not in existing_table_names:
return False, project_state
else:
found_add_field_migration = True
continue
with self.connection.cursor() as cursor:
columns = self.connection.introspection.get_table_description(cursor, table)
for column in columns:
field_column = field.column
column_name = column.name
if fold_identifier_case:
column_name = column_name.casefold()
field_column = field_column.casefold()
if column_name == field_column:
found_add_field_migration = True
break
else:
return False, project_state
# If we get this far and we found at least one CreateModel or AddField migration,
# the migration is considered implicitly applied.
return (found_create_model_migration or found_add_field_migration), after_state

View File

@@ -0,0 +1,319 @@
from functools import total_ordering
from django.db.migrations.state import ProjectState
from .exceptions import CircularDependencyError, NodeNotFoundError
@total_ordering
class Node:
"""
A single node in the migration graph. Contains direct links to adjacent
nodes in either direction.
"""
def __init__(self, key):
self.key = key
self.children = set()
self.parents = set()
def __eq__(self, other):
return self.key == other
def __lt__(self, other):
return self.key < other
def __hash__(self):
return hash(self.key)
def __getitem__(self, item):
return self.key[item]
def __str__(self):
return str(self.key)
def __repr__(self):
return '<%s: (%r, %r)>' % (self.__class__.__name__, self.key[0], self.key[1])
def add_child(self, child):
self.children.add(child)
def add_parent(self, parent):
self.parents.add(parent)
class DummyNode(Node):
"""
A node that doesn't correspond to a migration file on disk.
(A squashed migration that was removed, for example.)
After the migration graph is processed, all dummy nodes should be removed.
If there are any left, a nonexistent dependency error is raised.
"""
def __init__(self, key, origin, error_message):
super().__init__(key)
self.origin = origin
self.error_message = error_message
def raise_error(self):
raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
class MigrationGraph:
"""
Represent the digraph of all migrations in a project.
Each migration is a node, and each dependency is an edge. There are
no implicit dependencies between numbered migrations - the numbering is
merely a convention to aid file listing. Every new numbered migration
has a declared dependency to the previous number, meaning that VCS
branch merges can be detected and resolved.
Migrations files can be marked as replacing another set of migrations -
this is to support the "squash" feature. The graph handler isn't responsible
for these; instead, the code to load them in here should examine the
migration files and if the replaced migrations are all either unapplied
or not present, it should ignore the replaced ones, load in just the
replacing migration, and repoint any dependencies that pointed to the
replaced migrations to point to the replacing one.
A node should be a tuple: (app_path, migration_name). The tree special-cases
things within an app - namely, root nodes and leaf nodes ignore dependencies
to other apps.
"""
def __init__(self):
self.node_map = {}
self.nodes = {}
def add_node(self, key, migration):
assert key not in self.node_map
node = Node(key)
self.node_map[key] = node
self.nodes[key] = migration
def add_dummy_node(self, key, origin, error_message):
node = DummyNode(key, origin, error_message)
self.node_map[key] = node
self.nodes[key] = None
def add_dependency(self, migration, child, parent, skip_validation=False):
"""
This may create dummy nodes if they don't yet exist. If
`skip_validation=True`, validate_consistency() should be called
afterward.
"""
if child not in self.nodes:
error_message = (
"Migration %s dependencies reference nonexistent"
" child node %r" % (migration, child)
)
self.add_dummy_node(child, migration, error_message)
if parent not in self.nodes:
error_message = (
"Migration %s dependencies reference nonexistent"
" parent node %r" % (migration, parent)
)
self.add_dummy_node(parent, migration, error_message)
self.node_map[child].add_parent(self.node_map[parent])
self.node_map[parent].add_child(self.node_map[child])
if not skip_validation:
self.validate_consistency()
def remove_replaced_nodes(self, replacement, replaced):
"""
Remove each of the `replaced` nodes (when they exist). Any
dependencies that were referencing them are changed to reference the
`replacement` node instead.
"""
# Cast list of replaced keys to set to speed up lookup later.
replaced = set(replaced)
try:
replacement_node = self.node_map[replacement]
except KeyError as err:
raise NodeNotFoundError(
"Unable to find replacement node %r. It was either never added"
" to the migration graph, or has been removed." % (replacement,),
replacement
) from err
for replaced_key in replaced:
self.nodes.pop(replaced_key, None)
replaced_node = self.node_map.pop(replaced_key, None)
if replaced_node:
for child in replaced_node.children:
child.parents.remove(replaced_node)
# We don't want to create dependencies between the replaced
# node and the replacement node as this would lead to
# self-referencing on the replacement node at a later iteration.
if child.key not in replaced:
replacement_node.add_child(child)
child.add_parent(replacement_node)
for parent in replaced_node.parents:
parent.children.remove(replaced_node)
# Again, to avoid self-referencing.
if parent.key not in replaced:
replacement_node.add_parent(parent)
parent.add_child(replacement_node)
def remove_replacement_node(self, replacement, replaced):
"""
The inverse operation to `remove_replaced_nodes`. Almost. Remove the
replacement node `replacement` and remap its child nodes to `replaced`
- the list of nodes it would have replaced. Don't remap its parent
nodes as they are expected to be correct already.
"""
self.nodes.pop(replacement, None)
try:
replacement_node = self.node_map.pop(replacement)
except KeyError as err:
raise NodeNotFoundError(
"Unable to remove replacement node %r. It was either never added"
" to the migration graph, or has been removed already." % (replacement,),
replacement
) from err
replaced_nodes = set()
replaced_nodes_parents = set()
for key in replaced:
replaced_node = self.node_map.get(key)
if replaced_node:
replaced_nodes.add(replaced_node)
replaced_nodes_parents |= replaced_node.parents
# We're only interested in the latest replaced node, so filter out
# replaced nodes that are parents of other replaced nodes.
replaced_nodes -= replaced_nodes_parents
for child in replacement_node.children:
child.parents.remove(replacement_node)
for replaced_node in replaced_nodes:
replaced_node.add_child(child)
child.add_parent(replaced_node)
for parent in replacement_node.parents:
parent.children.remove(replacement_node)
# NOTE: There is no need to remap parent dependencies as we can
# assume the replaced nodes already have the correct ancestry.
def validate_consistency(self):
"""Ensure there are no dummy nodes remaining in the graph."""
[n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
def forwards_plan(self, target):
"""
Given a node, return a list of which previous nodes (dependencies) must
be applied, ending with the node itself. This is the list you would
follow if applying the migrations to a database.
"""
if target not in self.nodes:
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
return self.iterative_dfs(self.node_map[target])
def backwards_plan(self, target):
"""
Given a node, return a list of which dependent nodes (dependencies)
must be unapplied, ending with the node itself. This is the list you
would follow if removing the migrations from a database.
"""
if target not in self.nodes:
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
return self.iterative_dfs(self.node_map[target], forwards=False)
def iterative_dfs(self, start, forwards=True):
"""Iterative depth-first search for finding dependencies."""
visited = []
visited_set = set()
stack = [(start, False)]
while stack:
node, processed = stack.pop()
if node in visited_set:
pass
elif processed:
visited_set.add(node)
visited.append(node.key)
else:
stack.append((node, True))
stack += [(n, False) for n in sorted(node.parents if forwards else node.children)]
return visited
def root_nodes(self, app=None):
"""
Return all root nodes - that is, nodes with no dependencies inside
their app. These are the starting point for an app.
"""
roots = set()
for node in self.nodes:
if all(key[0] != node[0] for key in self.node_map[node].parents) and (not app or app == node[0]):
roots.add(node)
return sorted(roots)
def leaf_nodes(self, app=None):
"""
Return all leaf nodes - that is, nodes with no dependents in their app.
These are the "most current" version of an app's schema.
Having more than one per app is technically an error, but one that
gets handled further up, in the interactive command - it's usually the
result of a VCS merge and needs some user input.
"""
leaves = set()
for node in self.nodes:
if all(key[0] != node[0] for key in self.node_map[node].children) and (not app or app == node[0]):
leaves.add(node)
return sorted(leaves)
def ensure_not_cyclic(self):
# Algo from GvR:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
todo = set(self.nodes)
while todo:
node = todo.pop()
stack = [node]
while stack:
top = stack[-1]
for child in self.node_map[top].children:
# Use child.key instead of child to speed up the frequent
# hashing.
node = child.key
if node in stack:
cycle = stack[stack.index(node):]
raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle))
if node in todo:
stack.append(node)
todo.remove(node)
break
else:
node = stack.pop()
def __str__(self):
return 'Graph: %s nodes, %s edges' % self._nodes_and_edges()
def __repr__(self):
nodes, edges = self._nodes_and_edges()
return '<%s: nodes=%s, edges=%s>' % (self.__class__.__name__, nodes, edges)
def _nodes_and_edges(self):
return len(self.nodes), sum(len(node.parents) for node in self.node_map.values())
def _generate_plan(self, nodes, at_end):
plan = []
for node in nodes:
for migration in self.forwards_plan(node):
if migration not in plan and (at_end or migration not in nodes):
plan.append(migration)
return plan
def make_state(self, nodes=None, at_end=True, real_apps=None):
"""
Given a migration node or nodes, return a complete ProjectState for it.
If at_end is False, return the state before the migration has run.
If nodes is not provided, return the overall most current project state.
"""
if nodes is None:
nodes = list(self.leaf_nodes())
if not nodes:
return ProjectState()
if not isinstance(nodes[0], tuple):
nodes = [nodes]
plan = self._generate_plan(nodes, at_end)
project_state = ProjectState(real_apps=real_apps)
for node in plan:
project_state = self.nodes[node].mutate_state(project_state, preserve=False)
return project_state
def __contains__(self, node):
return node in self.nodes

View File

@@ -0,0 +1,356 @@
import pkgutil
import sys
from importlib import import_module, reload
from django.apps import apps
from django.conf import settings
from django.db.migrations.graph import MigrationGraph
from django.db.migrations.recorder import MigrationRecorder
from .exceptions import (
AmbiguityError, BadMigrationError, InconsistentMigrationHistory,
NodeNotFoundError,
)
MIGRATIONS_MODULE_NAME = 'migrations'
class MigrationLoader:
"""
Load migration files from disk and their status from the database.
Migration files are expected to live in the "migrations" directory of
an app. Their names are entirely unimportant from a code perspective,
but will probably follow the 1234_name.py convention.
On initialization, this class will scan those directories, and open and
read the Python files, looking for a class called Migration, which should
inherit from django.db.migrations.Migration. See
django.db.migrations.migration for what that looks like.
Some migrations will be marked as "replacing" another set of migrations.
These are loaded into a separate set of migrations away from the main ones.
If all the migrations they replace are either unapplied or missing from
disk, then they are injected into the main set, replacing the named migrations.
Any dependency pointers to the replaced migrations are re-pointed to the
new migration.
This does mean that this class MUST also talk to the database as well as
to disk, but this is probably fine. We're already not just operating
in memory.
"""
def __init__(
self, connection, load=True, ignore_no_migrations=False,
replace_migrations=True,
):
self.connection = connection
self.disk_migrations = None
self.applied_migrations = None
self.ignore_no_migrations = ignore_no_migrations
self.replace_migrations = replace_migrations
if load:
self.build_graph()
@classmethod
def migrations_module(cls, app_label):
"""
Return the path to the migrations module for the specified app_label
and a boolean indicating if the module is specified in
settings.MIGRATION_MODULE.
"""
if app_label in settings.MIGRATION_MODULES:
return settings.MIGRATION_MODULES[app_label], True
else:
app_package_name = apps.get_app_config(app_label).name
return '%s.%s' % (app_package_name, MIGRATIONS_MODULE_NAME), False
def load_disk(self):
"""Load the migrations from all INSTALLED_APPS from disk."""
self.disk_migrations = {}
self.unmigrated_apps = set()
self.migrated_apps = set()
for app_config in apps.get_app_configs():
# Get the migrations module directory
module_name, explicit = self.migrations_module(app_config.label)
if module_name is None:
self.unmigrated_apps.add(app_config.label)
continue
was_loaded = module_name in sys.modules
try:
module = import_module(module_name)
except ModuleNotFoundError as e:
if (
(explicit and self.ignore_no_migrations) or
(not explicit and MIGRATIONS_MODULE_NAME in e.name.split('.'))
):
self.unmigrated_apps.add(app_config.label)
continue
raise
else:
# Module is not a package (e.g. migrations.py).
if not hasattr(module, '__path__'):
self.unmigrated_apps.add(app_config.label)
continue
# Empty directories are namespaces. Namespace packages have no
# __file__ and don't use a list for __path__. See
# https://docs.python.org/3/reference/import.html#namespace-packages
if (
getattr(module, '__file__', None) is None and
not isinstance(module.__path__, list)
):
self.unmigrated_apps.add(app_config.label)
continue
# Force a reload if it's already loaded (tests need this)
if was_loaded:
reload(module)
self.migrated_apps.add(app_config.label)
migration_names = {
name for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
if not is_pkg and name[0] not in '_~'
}
# Load migrations
for migration_name in migration_names:
migration_path = '%s.%s' % (module_name, migration_name)
try:
migration_module = import_module(migration_path)
except ImportError as e:
if 'bad magic number' in str(e):
raise ImportError(
"Couldn't import %r as it appears to be a stale "
".pyc file." % migration_path
) from e
else:
raise
if not hasattr(migration_module, "Migration"):
raise BadMigrationError(
"Migration %s in app %s has no Migration class" % (migration_name, app_config.label)
)
self.disk_migrations[app_config.label, migration_name] = migration_module.Migration(
migration_name,
app_config.label,
)
def get_migration(self, app_label, name_prefix):
"""Return the named migration or raise NodeNotFoundError."""
return self.graph.nodes[app_label, name_prefix]
def get_migration_by_prefix(self, app_label, name_prefix):
"""
Return the migration(s) which match the given app label and name_prefix.
"""
# Do the search
results = []
for migration_app_label, migration_name in self.disk_migrations:
if migration_app_label == app_label and migration_name.startswith(name_prefix):
results.append((migration_app_label, migration_name))
if len(results) > 1:
raise AmbiguityError(
"There is more than one migration for '%s' with the prefix '%s'" % (app_label, name_prefix)
)
elif not results:
raise KeyError(
f"There is no migration for '{app_label}' with the prefix "
f"'{name_prefix}'"
)
else:
return self.disk_migrations[results[0]]
def check_key(self, key, current_app):
if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
return key
# Special-case __first__, which means "the first migration" for
# migrated apps, and is ignored for unmigrated apps. It allows
# makemigrations to declare dependencies on apps before they even have
# migrations.
if key[0] == current_app:
# Ignore __first__ references to the same app (#22325)
return
if key[0] in self.unmigrated_apps:
# This app isn't migrated, but something depends on it.
# The models will get auto-added into the state, though
# so we're fine.
return
if key[0] in self.migrated_apps:
try:
if key[1] == "__first__":
return self.graph.root_nodes(key[0])[0]
else: # "__latest__"
return self.graph.leaf_nodes(key[0])[0]
except IndexError:
if self.ignore_no_migrations:
return None
else:
raise ValueError("Dependency on app with no migrations: %s" % key[0])
raise ValueError("Dependency on unknown app: %s" % key[0])
def add_internal_dependencies(self, key, migration):
"""
Internal dependencies need to be added first to ensure `__first__`
dependencies find the correct root node.
"""
for parent in migration.dependencies:
# Ignore __first__ references to the same app.
if parent[0] == key[0] and parent[1] != '__first__':
self.graph.add_dependency(migration, key, parent, skip_validation=True)
def add_external_dependencies(self, key, migration):
for parent in migration.dependencies:
# Skip internal dependencies
if key[0] == parent[0]:
continue
parent = self.check_key(parent, key[0])
if parent is not None:
self.graph.add_dependency(migration, key, parent, skip_validation=True)
for child in migration.run_before:
child = self.check_key(child, key[0])
if child is not None:
self.graph.add_dependency(migration, child, key, skip_validation=True)
def build_graph(self):
"""
Build a migration dependency graph using both the disk and database.
You'll need to rebuild the graph if you apply migrations. This isn't
usually a problem as generally migration stuff runs in a one-shot process.
"""
# Load disk data
self.load_disk()
# Load database data
if self.connection is None:
self.applied_migrations = {}
else:
recorder = MigrationRecorder(self.connection)
self.applied_migrations = recorder.applied_migrations()
# To start, populate the migration graph with nodes for ALL migrations
# and their dependencies. Also make note of replacing migrations at this step.
self.graph = MigrationGraph()
self.replacements = {}
for key, migration in self.disk_migrations.items():
self.graph.add_node(key, migration)
# Replacing migrations.
if migration.replaces:
self.replacements[key] = migration
for key, migration in self.disk_migrations.items():
# Internal (same app) dependencies.
self.add_internal_dependencies(key, migration)
# Add external dependencies now that the internal ones have been resolved.
for key, migration in self.disk_migrations.items():
self.add_external_dependencies(key, migration)
# Carry out replacements where possible and if enabled.
if self.replace_migrations:
for key, migration in self.replacements.items():
# Get applied status of each of this migration's replacement
# targets.
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
# The replacing migration is only marked as applied if all of
# its replacement targets are.
if all(applied_statuses):
self.applied_migrations[key] = migration
else:
self.applied_migrations.pop(key, None)
# A replacing migration can be used if either all or none of
# its replacement targets have been applied.
if all(applied_statuses) or (not any(applied_statuses)):
self.graph.remove_replaced_nodes(key, migration.replaces)
else:
# This replacing migration cannot be used because it is
# partially applied. Remove it from the graph and remap
# dependencies to it (#25945).
self.graph.remove_replacement_node(key, migration.replaces)
# Ensure the graph is consistent.
try:
self.graph.validate_consistency()
except NodeNotFoundError as exc:
# Check if the missing node could have been replaced by any squash
# migration but wasn't because the squash migration was partially
# applied before. In that case raise a more understandable exception
# (#23556).
# Get reverse replacements.
reverse_replacements = {}
for key, migration in self.replacements.items():
for replaced in migration.replaces:
reverse_replacements.setdefault(replaced, set()).add(key)
# Try to reraise exception with more detail.
if exc.node in reverse_replacements:
candidates = reverse_replacements.get(exc.node, set())
is_replaced = any(candidate in self.graph.nodes for candidate in candidates)
if not is_replaced:
tries = ', '.join('%s.%s' % c for c in candidates)
raise NodeNotFoundError(
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
"Django tried to replace migration {1}.{2} with any of [{3}] "
"but wasn't able to because some of the replaced migrations "
"are already applied.".format(
exc.origin, exc.node[0], exc.node[1], tries
),
exc.node
) from exc
raise
self.graph.ensure_not_cyclic()
def check_consistent_history(self, connection):
"""
Raise InconsistentMigrationHistory if any applied migrations have
unapplied dependencies.
"""
recorder = MigrationRecorder(connection)
applied = recorder.applied_migrations()
for migration in applied:
# If the migration is unknown, skip it.
if migration not in self.graph.nodes:
continue
for parent in self.graph.node_map[migration].parents:
if parent not in applied:
# Skip unapplied squashed migrations that have all of their
# `replaces` applied.
if parent in self.replacements:
if all(m in applied for m in self.replacements[parent].replaces):
continue
raise InconsistentMigrationHistory(
"Migration {}.{} is applied before its dependency "
"{}.{} on database '{}'.".format(
migration[0], migration[1], parent[0], parent[1],
connection.alias,
)
)
def detect_conflicts(self):
"""
Look through the loaded graph and detect any conflicts - apps
with more than one leaf migration. Return a dict of the app labels
that conflict with the migration names that conflict.
"""
seen_apps = {}
conflicting_apps = set()
for app_label, migration_name in self.graph.leaf_nodes():
if app_label in seen_apps:
conflicting_apps.add(app_label)
seen_apps.setdefault(app_label, set()).add(migration_name)
return {app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps}
def project_state(self, nodes=None, at_end=True):
"""
Return a ProjectState object representing the most recent state
that the loaded migrations represent.
See graph.make_state() for the meaning of "nodes" and "at_end".
"""
return self.graph.make_state(nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps)
def collect_sql(self, plan):
"""
Take a migration plan and return a list of collected SQL statements
that represent the best-efforts version of that plan.
"""
statements = []
state = None
for migration, backwards in plan:
with self.connection.schema_editor(collect_sql=True, atomic=migration.atomic) as schema_editor:
if state is None:
state = self.project_state((migration.app_label, migration.name), at_end=False)
if not backwards:
state = migration.apply(state, schema_editor, collect_sql=True)
else:
state = migration.unapply(state, schema_editor, collect_sql=True)
statements.extend(schema_editor.collected_sql)
return statements

View File

@@ -0,0 +1,218 @@
from django.db.migrations.utils import get_migration_name_timestamp
from django.db.transaction import atomic
from .exceptions import IrreversibleError
class Migration:
"""
The base class for all migrations.
Migration files will import this from django.db.migrations.Migration
and subclass it as a class called Migration. It will have one or more
of the following attributes:
- operations: A list of Operation instances, probably from django.db.migrations.operations
- dependencies: A list of tuples of (app_path, migration_name)
- run_before: A list of tuples of (app_path, migration_name)
- replaces: A list of migration_names
Note that all migrations come out of migrations and into the Loader or
Graph as instances, having been initialized with their app label and name.
"""
# Operations to apply during this migration, in order.
operations = []
# Other migrations that should be run before this migration.
# Should be a list of (app, migration_name).
dependencies = []
# Other migrations that should be run after this one (i.e. have
# this migration added to their dependencies). Useful to make third-party
# apps' migrations run after your AUTH_USER replacement, for example.
run_before = []
# Migration names in this app that this migration replaces. If this is
# non-empty, this migration will only be applied if all these migrations
# are not applied.
replaces = []
# Is this an initial migration? Initial migrations are skipped on
# --fake-initial if the table or fields already exist. If None, check if
# the migration has any dependencies to determine if there are dependencies
# to tell if db introspection needs to be done. If True, always perform
# introspection. If False, never perform introspection.
initial = None
# Whether to wrap the whole migration in a transaction. Only has an effect
# on database backends which support transactional DDL.
atomic = True
def __init__(self, name, app_label):
self.name = name
self.app_label = app_label
# Copy dependencies & other attrs as we might mutate them at runtime
self.operations = list(self.__class__.operations)
self.dependencies = list(self.__class__.dependencies)
self.run_before = list(self.__class__.run_before)
self.replaces = list(self.__class__.replaces)
def __eq__(self, other):
return (
isinstance(other, Migration) and
self.name == other.name and
self.app_label == other.app_label
)
def __repr__(self):
return "<Migration %s.%s>" % (self.app_label, self.name)
def __str__(self):
return "%s.%s" % (self.app_label, self.name)
def __hash__(self):
return hash("%s.%s" % (self.app_label, self.name))
def mutate_state(self, project_state, preserve=True):
"""
Take a ProjectState and return a new one with the migration's
operations applied to it. Preserve the original object state by
default and return a mutated state from a copy.
"""
new_state = project_state
if preserve:
new_state = project_state.clone()
for operation in self.operations:
operation.state_forwards(self.app_label, new_state)
return new_state
def apply(self, project_state, schema_editor, collect_sql=False):
"""
Take a project_state representing all migrations prior to this one
and a schema_editor for a live database and apply the migration
in a forwards order.
Return the resulting project state for efficient reuse by following
Migrations.
"""
for operation in self.operations:
# If this operation cannot be represented as SQL, place a comment
# there instead
if collect_sql:
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
schema_editor.collected_sql.append(
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
)
schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
continue
# Save the state before the operation has run
old_state = project_state.clone()
operation.state_forwards(self.app_label, project_state)
# Run the operation
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
else:
# Normal behaviour
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
return project_state
def unapply(self, project_state, schema_editor, collect_sql=False):
"""
Take a project_state representing all migrations prior to this one
and a schema_editor for a live database and apply the migration
in a reverse order.
The backwards migration process consists of two phases:
1. The intermediate states from right before the first until right
after the last operation inside this migration are preserved.
2. The operations are applied in reverse order using the states
recorded in step 1.
"""
# Construct all the intermediate states we need for a reverse migration
to_run = []
new_state = project_state
# Phase 1
for operation in self.operations:
# If it's irreversible, error out
if not operation.reversible:
raise IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
# Preserve new state from previous run to not tamper the same state
# over all operations
new_state = new_state.clone()
old_state = new_state.clone()
operation.state_forwards(self.app_label, new_state)
to_run.insert(0, (operation, old_state, new_state))
# Phase 2
for operation, to_state, from_state in to_run:
if collect_sql:
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
schema_editor.collected_sql.append(
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
)
schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
continue
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
else:
# Normal behaviour
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
return project_state
def suggest_name(self):
"""
Suggest a name for the operations this migration might represent. Names
are not guaranteed to be unique, but put some effort into the fallback
name to avoid VCS conflicts if possible.
"""
if self.initial:
return 'initial'
raw_fragments = [op.migration_name_fragment for op in self.operations]
fragments = [name for name in raw_fragments if name]
if not fragments or len(fragments) != len(self.operations):
return 'auto_%s' % get_migration_name_timestamp()
name = fragments[0]
for fragment in fragments[1:]:
new_name = f'{name}_{fragment}'
if len(new_name) > 52:
name = f'{name}_and_more'
break
name = new_name
return name
class SwappableTuple(tuple):
"""
Subclass of tuple so Django can tell this was originally a swappable
dependency when it reads the migration file.
"""
def __new__(cls, value, setting):
self = tuple.__new__(cls, value)
self.setting = setting
return self
def swappable_dependency(value):
"""Turn a setting value into a dependency."""
return SwappableTuple((value.split(".", 1)[0], "__first__"), value)

View File

@@ -0,0 +1,17 @@
from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
RemoveIndex, RenameModel,
)
from .special import RunPython, RunSQL, SeparateDatabaseAndState
__all__ = [
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
'AddConstraint', 'RemoveConstraint',
'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
'AlterOrderWithRespectTo', 'AlterModelManagers',
]

View File

@@ -0,0 +1,140 @@
from django.db import router
class Operation:
"""
Base class for migration operations.
It's responsible for both mutating the in-memory model state
(see db/migrations/state.py) to represent what it performs, as well
as actually performing it against a live database.
Note that some operations won't modify memory state at all (e.g. data
copying operations), and some will need their modifications to be
optionally specified by the user (e.g. custom Python code snippets)
Due to the way this class deals with deconstruction, it should be
considered immutable.
"""
# If this migration can be run in reverse.
# Some operations are impossible to reverse, like deleting data.
reversible = True
# Can this migration be represented as SQL? (things like RunPython cannot)
reduces_to_sql = True
# Should this operation be forced as atomic even on backends with no
# DDL transaction support (i.e., does it have no DDL, like RunPython)
atomic = False
# Should this operation be considered safe to elide and optimize across?
elidable = False
serialization_expand_args = []
def __new__(cls, *args, **kwargs):
# We capture the arguments to make returning them trivial
self = object.__new__(cls)
self._constructor_args = (args, kwargs)
return self
def deconstruct(self):
"""
Return a 3-tuple of class import path (or just name if it lives
under django.db.migrations), positional arguments, and keyword
arguments.
"""
return (
self.__class__.__name__,
self._constructor_args[0],
self._constructor_args[1],
)
def state_forwards(self, app_label, state):
"""
Take the state from the previous migration, and mutate it
so that it matches what this migration would perform.
"""
raise NotImplementedError('subclasses of Operation must provide a state_forwards() method')
def database_forwards(self, app_label, schema_editor, from_state, to_state):
"""
Perform the mutation on the database schema in the normal
(forwards) direction.
"""
raise NotImplementedError('subclasses of Operation must provide a database_forwards() method')
def database_backwards(self, app_label, schema_editor, from_state, to_state):
"""
Perform the mutation on the database schema in the reverse
direction - e.g. if this were CreateModel, it would in fact
drop the model's table.
"""
raise NotImplementedError('subclasses of Operation must provide a database_backwards() method')
def describe(self):
"""
Output a brief summary of what the action does.
"""
return "%s: %s" % (self.__class__.__name__, self._constructor_args)
@property
def migration_name_fragment(self):
"""
A filename part suitable for automatically naming a migration
containing this operation, or None if not applicable.
"""
return None
def references_model(self, name, app_label):
"""
Return True if there is a chance this operation references the given
model name (as a string), with an app label for accuracy.
Used for optimization. If in doubt, return True;
returning a false positive will merely make the optimizer a little
less efficient, while returning a false negative may result in an
unusable optimized migration.
"""
return True
def references_field(self, model_name, name, app_label):
"""
Return True if there is a chance this operation references the given
field name, with an app label for accuracy.
Used for optimization. If in doubt, return True.
"""
return self.references_model(model_name, app_label)
def allow_migrate_model(self, connection_alias, model):
"""
Return whether or not a model may be migrated.
This is a thin wrapper around router.allow_migrate_model() that
preemptively rejects any proxy, swapped out, or unmanaged model.
"""
if not model._meta.can_migrate(connection_alias):
return False
return router.allow_migrate_model(connection_alias, model)
def reduce(self, operation, app_label):
"""
Return either a list of operations the actual operation should be
replaced with or a boolean that indicates whether or not the specified
operation can be optimized across.
"""
if self.elidable:
return [operation]
elif operation.elidable:
return [self]
return False
def __repr__(self):
return "<%s %s%s>" % (
self.__class__.__name__,
", ".join(map(repr, self._constructor_args[0])),
",".join(" %s=%r" % x for x in self._constructor_args[1].items()),
)

View File

@@ -0,0 +1,341 @@
from django.db.migrations.utils import field_references
from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property
from .base import Operation
class FieldOperation(Operation):
def __init__(self, model_name, name, field=None):
self.model_name = model_name
self.name = name
self.field = field
@cached_property
def model_name_lower(self):
return self.model_name.lower()
@cached_property
def name_lower(self):
return self.name.lower()
def is_same_model_operation(self, operation):
return self.model_name_lower == operation.model_name_lower
def is_same_field_operation(self, operation):
return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
def references_model(self, name, app_label):
name_lower = name.lower()
if name_lower == self.model_name_lower:
return True
if self.field:
return bool(field_references(
(app_label, self.model_name_lower), self.field, (app_label, name_lower)
))
return False
def references_field(self, model_name, name, app_label):
model_name_lower = model_name.lower()
# Check if this operation locally references the field.
if model_name_lower == self.model_name_lower:
if name == self.name:
return True
elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
return True
# Check if this operation remotely references the field.
if self.field is None:
return False
return bool(field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, model_name_lower),
name,
))
def reduce(self, operation, app_label):
return (
super().reduce(operation, app_label) or
not operation.references_field(self.model_name, self.name, app_label)
)
class AddField(FieldOperation):
"""Add a field to a model."""
def __init__(self, model_name, name, field, preserve_default=True):
self.preserve_default = preserve_default
super().__init__(model_name, name, field)
def deconstruct(self):
kwargs = {
'model_name': self.model_name,
'name': self.name,
'field': self.field,
}
if self.preserve_default is not True:
kwargs['preserve_default'] = self.preserve_default
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.add_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
field = to_model._meta.get_field(self.name)
if not self.preserve_default:
field.default = self.field.default
schema_editor.add_field(
from_model,
field,
)
if not self.preserve_default:
field.default = NOT_PROVIDED
def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
def describe(self):
return "Add field %s to %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return '%s_%s' % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
if isinstance(operation, AlterField):
return [
AddField(
model_name=self.model_name,
name=operation.name,
field=operation.field,
),
]
elif isinstance(operation, RemoveField):
return []
elif isinstance(operation, RenameField):
return [
AddField(
model_name=self.model_name,
name=operation.new_name,
field=self.field,
),
]
return super().reduce(operation, app_label)
class RemoveField(FieldOperation):
"""Remove a field from a model."""
def deconstruct(self):
kwargs = {
'model_name': self.model_name,
'name': self.name,
}
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.remove_field(app_label, self.model_name_lower, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
def describe(self):
return "Remove field %s from %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return 'remove_%s_%s' % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
from .models import DeleteModel
if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
return [operation]
return super().reduce(operation, app_label)
class AlterField(FieldOperation):
"""
Alter a field's database column (e.g. null, max_length) to the provided
new field.
"""
def __init__(self, model_name, name, field, preserve_default=True):
self.preserve_default = preserve_default
super().__init__(model_name, name, field)
def deconstruct(self):
kwargs = {
'model_name': self.model_name,
'name': self.name,
'field': self.field,
}
if self.preserve_default is not True:
kwargs['preserve_default'] = self.preserve_default
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
from_field = from_model._meta.get_field(self.name)
to_field = to_model._meta.get_field(self.name)
if not self.preserve_default:
to_field.default = self.field.default
schema_editor.alter_field(from_model, from_field, to_field)
if not self.preserve_default:
to_field.default = NOT_PROVIDED
def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.database_forwards(app_label, schema_editor, from_state, to_state)
def describe(self):
return "Alter field %s on %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return 'alter_%s_%s' % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
return [operation]
elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
return [
operation,
AlterField(
model_name=self.model_name,
name=operation.new_name,
field=self.field,
),
]
return super().reduce(operation, app_label)
class RenameField(FieldOperation):
"""Rename a field on the model. Might affect db_column too."""
def __init__(self, model_name, old_name, new_name):
self.old_name = old_name
self.new_name = new_name
super().__init__(model_name, old_name)
@cached_property
def old_name_lower(self):
return self.old_name.lower()
@cached_property
def new_name_lower(self):
return self.new_name.lower()
def deconstruct(self):
kwargs = {
'model_name': self.model_name,
'old_name': self.old_name,
'new_name': self.new_name,
}
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
schema_editor.alter_field(
from_model,
from_model._meta.get_field(self.old_name),
to_model._meta.get_field(self.new_name),
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.model_name)
schema_editor.alter_field(
from_model,
from_model._meta.get_field(self.new_name),
to_model._meta.get_field(self.old_name),
)
def describe(self):
return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
@property
def migration_name_fragment(self):
return 'rename_%s_%s_%s' % (
self.old_name_lower,
self.model_name_lower,
self.new_name_lower,
)
def references_field(self, model_name, name, app_label):
return self.references_model(model_name, app_label) and (
name.lower() == self.old_name_lower or
name.lower() == self.new_name_lower
)
def reduce(self, operation, app_label):
if (isinstance(operation, RenameField) and
self.is_same_model_operation(operation) and
self.new_name_lower == operation.old_name_lower):
return [
RenameField(
self.model_name,
self.old_name,
operation.new_name,
),
]
# Skip `FieldOperation.reduce` as we want to run `references_field`
# against self.old_name and self.new_name.
return (
super(FieldOperation, self).reduce(operation, app_label) or
not (
operation.references_field(self.model_name, self.old_name, app_label) or
operation.references_field(self.model_name, self.new_name, app_label)
)
)

View File

@@ -0,0 +1,884 @@
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.state import ModelState
from django.db.migrations.utils import field_references, resolve_relation
from django.db.models.options import normalize_together
from django.utils.functional import cached_property
from .fields import (
AddField, AlterField, FieldOperation, RemoveField, RenameField,
)
def _check_for_duplicates(arg_name, objs):
used_vals = set()
for val in objs:
if val in used_vals:
raise ValueError(
"Found duplicate value %s in CreateModel %s argument." % (val, arg_name)
)
used_vals.add(val)
class ModelOperation(Operation):
def __init__(self, name):
self.name = name
@cached_property
def name_lower(self):
return self.name.lower()
def references_model(self, name, app_label):
return name.lower() == self.name_lower
def reduce(self, operation, app_label):
return (
super().reduce(operation, app_label) or
not operation.references_model(self.name, app_label)
)
class CreateModel(ModelOperation):
"""Create a model's table."""
serialization_expand_args = ['fields', 'options', 'managers']
def __init__(self, name, fields, options=None, bases=None, managers=None):
self.fields = fields
self.options = options or {}
self.bases = bases or (models.Model,)
self.managers = managers or []
super().__init__(name)
# Sanity-check that there are no duplicated field names, bases, or
# manager names
_check_for_duplicates('fields', (name for name, _ in self.fields))
_check_for_duplicates('bases', (
base._meta.label_lower if hasattr(base, '_meta') else
base.lower() if isinstance(base, str) else base
for base in self.bases
))
_check_for_duplicates('managers', (name for name, _ in self.managers))
def deconstruct(self):
kwargs = {
'name': self.name,
'fields': self.fields,
}
if self.options:
kwargs['options'] = self.options
if self.bases and self.bases != (models.Model,):
kwargs['bases'] = self.bases
if self.managers and self.managers != [('objects', models.Manager())]:
kwargs['managers'] = self.managers
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.add_model(ModelState(
app_label,
self.name,
list(self.fields),
dict(self.options),
tuple(self.bases),
list(self.managers),
))
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.create_model(model)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.delete_model(model)
def describe(self):
return "Create %smodel %s" % ("proxy " if self.options.get("proxy", False) else "", self.name)
@property
def migration_name_fragment(self):
return self.name_lower
def references_model(self, name, app_label):
name_lower = name.lower()
if name_lower == self.name_lower:
return True
# Check we didn't inherit from the model
reference_model_tuple = (app_label, name_lower)
for base in self.bases:
if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
resolve_relation(base, app_label) == reference_model_tuple):
return True
# Check we have no FKs/M2Ms with it
for _name, field in self.fields:
if field_references((app_label, self.name_lower), field, reference_model_tuple):
return True
return False
def reduce(self, operation, app_label):
if (isinstance(operation, DeleteModel) and
self.name_lower == operation.name_lower and
not self.options.get("proxy", False)):
return []
elif isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower:
return [
CreateModel(
operation.new_name,
fields=self.fields,
options=self.options,
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower:
options = {**self.options, **operation.options}
for key in operation.ALTER_OPTION_KEYS:
if key not in operation.options:
options.pop(key, None)
return [
CreateModel(
self.name,
fields=self.fields,
options=options,
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, AlterTogetherOptionOperation) and self.name_lower == operation.name_lower:
return [
CreateModel(
self.name,
fields=self.fields,
options={**self.options, **{operation.option_name: operation.option_value}},
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, AlterOrderWithRespectTo) and self.name_lower == operation.name_lower:
return [
CreateModel(
self.name,
fields=self.fields,
options={**self.options, 'order_with_respect_to': operation.order_with_respect_to},
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower:
if isinstance(operation, AddField):
return [
CreateModel(
self.name,
fields=self.fields + [(operation.name, operation.field)],
options=self.options,
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, AlterField):
return [
CreateModel(
self.name,
fields=[
(n, operation.field if n == operation.name else v)
for n, v in self.fields
],
options=self.options,
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, RemoveField):
options = self.options.copy()
for option_name in ('unique_together', 'index_together'):
option = options.pop(option_name, None)
if option:
option = set(filter(bool, (
tuple(f for f in fields if f != operation.name_lower) for fields in option
)))
if option:
options[option_name] = option
order_with_respect_to = options.get('order_with_respect_to')
if order_with_respect_to == operation.name_lower:
del options['order_with_respect_to']
return [
CreateModel(
self.name,
fields=[
(n, v)
for n, v in self.fields
if n.lower() != operation.name_lower
],
options=options,
bases=self.bases,
managers=self.managers,
),
]
elif isinstance(operation, RenameField):
options = self.options.copy()
for option_name in ('unique_together', 'index_together'):
option = options.get(option_name)
if option:
options[option_name] = {
tuple(operation.new_name if f == operation.old_name else f for f in fields)
for fields in option
}
order_with_respect_to = options.get('order_with_respect_to')
if order_with_respect_to == operation.old_name:
options['order_with_respect_to'] = operation.new_name
return [
CreateModel(
self.name,
fields=[
(operation.new_name if n == operation.old_name else n, v)
for n, v in self.fields
],
options=options,
bases=self.bases,
managers=self.managers,
),
]
return super().reduce(operation, app_label)
class DeleteModel(ModelOperation):
"""Drop a model's table."""
def deconstruct(self):
kwargs = {
'name': self.name,
}
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.remove_model(app_label, self.name_lower)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.delete_model(model)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.create_model(model)
def references_model(self, name, app_label):
# The deleted model could be referencing the specified model through
# related fields.
return True
def describe(self):
return "Delete model %s" % self.name
@property
def migration_name_fragment(self):
return 'delete_%s' % self.name_lower
class RenameModel(ModelOperation):
"""Rename a model."""
def __init__(self, old_name, new_name):
self.old_name = old_name
self.new_name = new_name
super().__init__(old_name)
@cached_property
def old_name_lower(self):
return self.old_name.lower()
@cached_property
def new_name_lower(self):
return self.new_name.lower()
def deconstruct(self):
kwargs = {
'old_name': self.old_name,
'new_name': self.new_name,
}
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.rename_model(app_label, self.old_name, self.new_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
old_model = from_state.apps.get_model(app_label, self.old_name)
# Move the main table
schema_editor.alter_db_table(
new_model,
old_model._meta.db_table,
new_model._meta.db_table,
)
# Alter the fields pointing to us
for related_object in old_model._meta.related_objects:
if related_object.related_model == old_model:
model = new_model
related_key = (app_label, self.new_name_lower)
else:
model = related_object.related_model
related_key = (
related_object.related_model._meta.app_label,
related_object.related_model._meta.model_name,
)
to_field = to_state.apps.get_model(
*related_key
)._meta.get_field(related_object.field.name)
schema_editor.alter_field(
model,
related_object.field,
to_field,
)
# Rename M2M fields whose name is based on this model's name.
fields = zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many)
for (old_field, new_field) in fields:
# Skip self-referential fields as these are renamed above.
if new_field.model == new_field.related_model or not new_field.remote_field.through._meta.auto_created:
continue
# Rename the M2M table that's based on this model's name.
old_m2m_model = old_field.remote_field.through
new_m2m_model = new_field.remote_field.through
schema_editor.alter_db_table(
new_m2m_model,
old_m2m_model._meta.db_table,
new_m2m_model._meta.db_table,
)
# Rename the column in the M2M table that's based on this
# model's name.
schema_editor.alter_field(
new_m2m_model,
old_m2m_model._meta.get_field(old_model._meta.model_name),
new_m2m_model._meta.get_field(new_model._meta.model_name),
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
self.new_name, self.old_name = self.old_name, self.new_name
self.database_forwards(app_label, schema_editor, from_state, to_state)
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
self.new_name, self.old_name = self.old_name, self.new_name
def references_model(self, name, app_label):
return (
name.lower() == self.old_name_lower or
name.lower() == self.new_name_lower
)
def describe(self):
return "Rename model %s to %s" % (self.old_name, self.new_name)
@property
def migration_name_fragment(self):
return 'rename_%s_%s' % (self.old_name_lower, self.new_name_lower)
def reduce(self, operation, app_label):
if (isinstance(operation, RenameModel) and
self.new_name_lower == operation.old_name_lower):
return [
RenameModel(
self.old_name,
operation.new_name,
),
]
# Skip `ModelOperation.reduce` as we want to run `references_model`
# against self.new_name.
return (
super(ModelOperation, self).reduce(operation, app_label) or
not operation.references_model(self.new_name, app_label)
)
class ModelOptionOperation(ModelOperation):
def reduce(self, operation, app_label):
if isinstance(operation, (self.__class__, DeleteModel)) and self.name_lower == operation.name_lower:
return [operation]
return super().reduce(operation, app_label)
class AlterModelTable(ModelOptionOperation):
"""Rename a model's table."""
def __init__(self, name, table):
self.table = table
super().__init__(name)
def deconstruct(self):
kwargs = {
'name': self.name,
'table': self.table,
}
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(app_label, self.name_lower, {'db_table': self.table})
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
old_model = from_state.apps.get_model(app_label, self.name)
schema_editor.alter_db_table(
new_model,
old_model._meta.db_table,
new_model._meta.db_table,
)
# Rename M2M fields whose name is based on this model's db_table
for (old_field, new_field) in zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many):
if new_field.remote_field.through._meta.auto_created:
schema_editor.alter_db_table(
new_field.remote_field.through,
old_field.remote_field.through._meta.db_table,
new_field.remote_field.through._meta.db_table,
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
return self.database_forwards(app_label, schema_editor, from_state, to_state)
def describe(self):
return "Rename table for %s to %s" % (
self.name,
self.table if self.table is not None else "(default)"
)
@property
def migration_name_fragment(self):
return 'alter_%s_table' % self.name_lower
class AlterTogetherOptionOperation(ModelOptionOperation):
option_name = None
def __init__(self, name, option_value):
if option_value:
option_value = set(normalize_together(option_value))
setattr(self, self.option_name, option_value)
super().__init__(name)
@cached_property
def option_value(self):
return getattr(self, self.option_name)
def deconstruct(self):
kwargs = {
'name': self.name,
self.option_name: self.option_value,
}
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(
app_label,
self.name_lower,
{self.option_name: self.option_value},
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
old_model = from_state.apps.get_model(app_label, self.name)
alter_together = getattr(schema_editor, 'alter_%s' % self.option_name)
alter_together(
new_model,
getattr(old_model._meta, self.option_name, set()),
getattr(new_model._meta, self.option_name, set()),
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
return self.database_forwards(app_label, schema_editor, from_state, to_state)
def references_field(self, model_name, name, app_label):
return (
self.references_model(model_name, app_label) and
(
not self.option_value or
any((name in fields) for fields in self.option_value)
)
)
def describe(self):
return "Alter %s for %s (%s constraint(s))" % (self.option_name, self.name, len(self.option_value or ''))
@property
def migration_name_fragment(self):
return 'alter_%s_%s' % (self.name_lower, self.option_name)
class AlterUniqueTogether(AlterTogetherOptionOperation):
"""
Change the value of unique_together to the target one.
Input value of unique_together must be a set of tuples.
"""
option_name = 'unique_together'
def __init__(self, name, unique_together):
super().__init__(name, unique_together)
class AlterIndexTogether(AlterTogetherOptionOperation):
"""
Change the value of index_together to the target one.
Input value of index_together must be a set of tuples.
"""
option_name = "index_together"
def __init__(self, name, index_together):
super().__init__(name, index_together)
class AlterOrderWithRespectTo(ModelOptionOperation):
"""Represent a change with the order_with_respect_to option."""
option_name = 'order_with_respect_to'
def __init__(self, name, order_with_respect_to):
self.order_with_respect_to = order_with_respect_to
super().__init__(name)
def deconstruct(self):
kwargs = {
'name': self.name,
'order_with_respect_to': self.order_with_respect_to,
}
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(
app_label,
self.name_lower,
{self.option_name: self.order_with_respect_to},
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.name)
# Remove a field if we need to
if from_model._meta.order_with_respect_to and not to_model._meta.order_with_respect_to:
schema_editor.remove_field(from_model, from_model._meta.get_field("_order"))
# Add a field if we need to (altering the column is untouched as
# it's likely a rename)
elif to_model._meta.order_with_respect_to and not from_model._meta.order_with_respect_to:
field = to_model._meta.get_field("_order")
if not field.has_default():
field.default = 0
schema_editor.add_field(
from_model,
field,
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.database_forwards(app_label, schema_editor, from_state, to_state)
def references_field(self, model_name, name, app_label):
return (
self.references_model(model_name, app_label) and
(
self.order_with_respect_to is None or
name == self.order_with_respect_to
)
)
def describe(self):
return "Set order_with_respect_to on %s to %s" % (self.name, self.order_with_respect_to)
@property
def migration_name_fragment(self):
return 'alter_%s_order_with_respect_to' % self.name_lower
class AlterModelOptions(ModelOptionOperation):
"""
Set new model options that don't directly affect the database schema
(like verbose_name, permissions, ordering). Python code in migrations
may still need them.
"""
# Model options we want to compare and preserve in an AlterModelOptions op
ALTER_OPTION_KEYS = [
"base_manager_name",
"default_manager_name",
"default_related_name",
"get_latest_by",
"managed",
"ordering",
"permissions",
"default_permissions",
"select_on_save",
"verbose_name",
"verbose_name_plural",
]
def __init__(self, name, options):
self.options = options
super().__init__(name)
def deconstruct(self):
kwargs = {
'name': self.name,
'options': self.options,
}
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(
app_label,
self.name_lower,
self.options,
self.ALTER_OPTION_KEYS,
)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
def database_backwards(self, app_label, schema_editor, from_state, to_state):
pass
def describe(self):
return "Change Meta options on %s" % self.name
@property
def migration_name_fragment(self):
return 'alter_%s_options' % self.name_lower
class AlterModelManagers(ModelOptionOperation):
"""Alter the model's managers."""
serialization_expand_args = ['managers']
def __init__(self, name, managers):
self.managers = managers
super().__init__(name)
def deconstruct(self):
return (
self.__class__.__qualname__,
[self.name, self.managers],
{}
)
def state_forwards(self, app_label, state):
state.alter_model_managers(app_label, self.name_lower, self.managers)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
def database_backwards(self, app_label, schema_editor, from_state, to_state):
pass
def describe(self):
return "Change managers on %s" % self.name
@property
def migration_name_fragment(self):
return 'alter_%s_managers' % self.name_lower
class IndexOperation(Operation):
option_name = 'indexes'
@cached_property
def model_name_lower(self):
return self.model_name.lower()
class AddIndex(IndexOperation):
"""Add an index on a model."""
def __init__(self, model_name, index):
self.model_name = model_name
if not index.name:
raise ValueError(
"Indexes passed to AddIndex operations require a name "
"argument. %r doesn't have one." % index
)
self.index = index
def state_forwards(self, app_label, state):
state.add_index(app_label, self.model_name_lower, self.index)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.add_index(model, self.index)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.remove_index(model, self.index)
def deconstruct(self):
kwargs = {
'model_name': self.model_name,
'index': self.index,
}
return (
self.__class__.__qualname__,
[],
kwargs,
)
def describe(self):
if self.index.expressions:
return 'Create index %s on %s on model %s' % (
self.index.name,
', '.join([str(expression) for expression in self.index.expressions]),
self.model_name,
)
return 'Create index %s on field(s) %s of model %s' % (
self.index.name,
', '.join(self.index.fields),
self.model_name,
)
@property
def migration_name_fragment(self):
return '%s_%s' % (self.model_name_lower, self.index.name.lower())
class RemoveIndex(IndexOperation):
"""Remove an index from a model."""
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
state.remove_index(app_label, self.model_name_lower, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
from_model_state = from_state.models[app_label, self.model_name_lower]
index = from_model_state.get_index_by_name(self.name)
schema_editor.remove_index(model, index)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
to_model_state = to_state.models[app_label, self.model_name_lower]
index = to_model_state.get_index_by_name(self.name)
schema_editor.add_index(model, index)
def deconstruct(self):
kwargs = {
'model_name': self.model_name,
'name': self.name,
}
return (
self.__class__.__qualname__,
[],
kwargs,
)
def describe(self):
return 'Remove index %s from %s' % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
class AddConstraint(IndexOperation):
option_name = 'constraints'
def __init__(self, model_name, constraint):
self.model_name = model_name
self.constraint = constraint
def state_forwards(self, app_label, state):
state.add_constraint(app_label, self.model_name_lower, self.constraint)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.add_constraint(model, self.constraint)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.remove_constraint(model, self.constraint)
def deconstruct(self):
return self.__class__.__name__, [], {
'model_name': self.model_name,
'constraint': self.constraint,
}
def describe(self):
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
@property
def migration_name_fragment(self):
return '%s_%s' % (self.model_name_lower, self.constraint.name.lower())
class RemoveConstraint(IndexOperation):
option_name = 'constraints'
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
state.remove_constraint(app_label, self.model_name_lower, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
from_model_state = from_state.models[app_label, self.model_name_lower]
constraint = from_model_state.get_constraint_by_name(self.name)
schema_editor.remove_constraint(model, constraint)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, model):
to_model_state = to_state.models[app_label, self.model_name_lower]
constraint = to_model_state.get_constraint_by_name(self.name)
schema_editor.add_constraint(model, constraint)
def deconstruct(self):
return self.__class__.__name__, [], {
'model_name': self.model_name,
'name': self.name,
}
def describe(self):
return 'Remove constraint %s from model %s' % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())

View File

@@ -0,0 +1,203 @@
from django.db import router
from .base import Operation
class SeparateDatabaseAndState(Operation):
"""
Take two lists of operations - ones that will be used for the database,
and ones that will be used for the state change. This allows operations
that don't support state change to have it applied, or have operations
that affect the state or not the database, or so on.
"""
serialization_expand_args = ['database_operations', 'state_operations']
def __init__(self, database_operations=None, state_operations=None):
self.database_operations = database_operations or []
self.state_operations = state_operations or []
def deconstruct(self):
kwargs = {}
if self.database_operations:
kwargs['database_operations'] = self.database_operations
if self.state_operations:
kwargs['state_operations'] = self.state_operations
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
for state_operation in self.state_operations:
state_operation.state_forwards(app_label, state)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
# We calculate state separately in here since our state functions aren't useful
for database_operation in self.database_operations:
to_state = from_state.clone()
database_operation.state_forwards(app_label, to_state)
database_operation.database_forwards(app_label, schema_editor, from_state, to_state)
from_state = to_state
def database_backwards(self, app_label, schema_editor, from_state, to_state):
# We calculate state separately in here since our state functions aren't useful
to_states = {}
for dbop in self.database_operations:
to_states[dbop] = to_state
to_state = to_state.clone()
dbop.state_forwards(app_label, to_state)
# to_state now has the states of all the database_operations applied
# which is the from_state for the backwards migration of the last
# operation.
for database_operation in reversed(self.database_operations):
from_state = to_state
to_state = to_states[database_operation]
database_operation.database_backwards(app_label, schema_editor, from_state, to_state)
def describe(self):
return "Custom state/database change combination"
class RunSQL(Operation):
"""
Run some raw SQL. A reverse SQL statement may be provided.
Also accept a list of operations that represent the state change effected
by this SQL change, in case it's custom column/table creation/deletion.
"""
noop = ''
def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False):
self.sql = sql
self.reverse_sql = reverse_sql
self.state_operations = state_operations or []
self.hints = hints or {}
self.elidable = elidable
def deconstruct(self):
kwargs = {
'sql': self.sql,
}
if self.reverse_sql is not None:
kwargs['reverse_sql'] = self.reverse_sql
if self.state_operations:
kwargs['state_operations'] = self.state_operations
if self.hints:
kwargs['hints'] = self.hints
return (
self.__class__.__qualname__,
[],
kwargs
)
@property
def reversible(self):
return self.reverse_sql is not None
def state_forwards(self, app_label, state):
for state_operation in self.state_operations:
state_operation.state_forwards(app_label, state)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
self._run_sql(schema_editor, self.sql)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_sql is None:
raise NotImplementedError("You cannot reverse this operation")
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
self._run_sql(schema_editor, self.reverse_sql)
def describe(self):
return "Raw SQL operation"
def _run_sql(self, schema_editor, sqls):
if isinstance(sqls, (list, tuple)):
for sql in sqls:
params = None
if isinstance(sql, (list, tuple)):
elements = len(sql)
if elements == 2:
sql, params = sql
else:
raise ValueError("Expected a 2-tuple but got %d" % elements)
schema_editor.execute(sql, params=params)
elif sqls != RunSQL.noop:
statements = schema_editor.connection.ops.prepare_sql_script(sqls)
for statement in statements:
schema_editor.execute(statement, params=None)
class RunPython(Operation):
"""
Run Python code in a context suitable for doing versioned ORM operations.
"""
reduces_to_sql = False
def __init__(self, code, reverse_code=None, atomic=None, hints=None, elidable=False):
self.atomic = atomic
# Forwards code
if not callable(code):
raise ValueError("RunPython must be supplied with a callable")
self.code = code
# Reverse code
if reverse_code is None:
self.reverse_code = None
else:
if not callable(reverse_code):
raise ValueError("RunPython must be supplied with callable arguments")
self.reverse_code = reverse_code
self.hints = hints or {}
self.elidable = elidable
def deconstruct(self):
kwargs = {
'code': self.code,
}
if self.reverse_code is not None:
kwargs['reverse_code'] = self.reverse_code
if self.atomic is not None:
kwargs['atomic'] = self.atomic
if self.hints:
kwargs['hints'] = self.hints
return (
self.__class__.__qualname__,
[],
kwargs
)
@property
def reversible(self):
return self.reverse_code is not None
def state_forwards(self, app_label, state):
# RunPython objects have no state effect. To add some, combine this
# with SeparateDatabaseAndState.
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
# RunPython has access to all models. Ensure that all models are
# reloaded in case any are delayed.
from_state.clear_delayed_apps_cache()
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
# We now execute the Python code in a context that contains a 'models'
# object, representing the versioned models as an app registry.
# We could try to override the global cache, but then people will still
# use direct imports, so we go with a documentation approach instead.
self.code(from_state.apps, schema_editor)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_code is None:
raise NotImplementedError("You cannot reverse this operation")
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
self.reverse_code(from_state.apps, schema_editor)
def describe(self):
return "Raw Python operation"
@staticmethod
def noop(apps, schema_editor):
return None

View File

@@ -0,0 +1,69 @@
class MigrationOptimizer:
"""
Power the optimization process, where you provide a list of Operations
and you are returned a list of equal or shorter length - operations
are merged into one if possible.
For example, a CreateModel and an AddField can be optimized into a
new CreateModel, and CreateModel and DeleteModel can be optimized into
nothing.
"""
def optimize(self, operations, app_label):
"""
Main optimization entry point. Pass in a list of Operation instances,
get out a new list of Operation instances.
Unfortunately, due to the scope of the optimization (two combinable
operations might be separated by several hundred others), this can't be
done as a peephole optimization with checks/output implemented on
the Operations themselves; instead, the optimizer looks at each
individual operation and scans forwards in the list to see if there
are any matches, stopping at boundaries - operations which can't
be optimized over (RunSQL, operations on the same field/model, etc.)
The inner loop is run until the starting list is the same as the result
list, and then the result is returned. This means that operation
optimization must be stable and always return an equal or shorter list.
"""
# Internal tracking variable for test assertions about # of loops
if app_label is None:
raise TypeError('app_label must be a str.')
self._iterations = 0
while True:
result = self.optimize_inner(operations, app_label)
self._iterations += 1
if result == operations:
return result
operations = result
def optimize_inner(self, operations, app_label):
"""Inner optimization loop."""
new_operations = []
for i, operation in enumerate(operations):
right = True # Should we reduce on the right or on the left.
# Compare it to each operation after it
for j, other in enumerate(operations[i + 1:]):
result = operation.reduce(other, app_label)
if isinstance(result, list):
in_between = operations[i + 1:i + j + 1]
if right:
new_operations.extend(in_between)
new_operations.extend(result)
elif all(op.reduce(other, app_label) is True for op in in_between):
# Perform a left reduction if all of the in-between
# operations can optimize through other.
new_operations.extend(result)
new_operations.extend(in_between)
else:
# Otherwise keep trying.
new_operations.append(operation)
break
new_operations.extend(operations[i + j + 2:])
return new_operations
elif not result:
# Can't perform a right reduction.
right = False
else:
new_operations.append(operation)
return new_operations

View File

@@ -0,0 +1,245 @@
import datetime
import importlib
import os
import sys
from django.apps import apps
from django.db.models import NOT_PROVIDED
from django.utils import timezone
from .loader import MigrationLoader
class MigrationQuestioner:
"""
Give the autodetector responses to questions it might have.
This base class has a built-in noninteractive mode, but the
interactive subclass is what the command-line arguments will use.
"""
def __init__(self, defaults=None, specified_apps=None, dry_run=None):
self.defaults = defaults or {}
self.specified_apps = specified_apps or set()
self.dry_run = dry_run
def ask_initial(self, app_label):
"""Should we create an initial migration for the app?"""
# If it was specified on the command line, definitely true
if app_label in self.specified_apps:
return True
# Otherwise, we look to see if it has a migrations module
# without any Python files in it, apart from __init__.py.
# Apps from the new app template will have these; the Python
# file check will ensure we skip South ones.
try:
app_config = apps.get_app_config(app_label)
except LookupError: # It's a fake app.
return self.defaults.get("ask_initial", False)
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
if migrations_import_path is None:
# It's an application with migrations disabled.
return self.defaults.get("ask_initial", False)
try:
migrations_module = importlib.import_module(migrations_import_path)
except ImportError:
return self.defaults.get("ask_initial", False)
else:
if getattr(migrations_module, "__file__", None):
filenames = os.listdir(os.path.dirname(migrations_module.__file__))
elif hasattr(migrations_module, "__path__"):
if len(migrations_module.__path__) > 1:
return False
filenames = os.listdir(list(migrations_module.__path__)[0])
return not any(x.endswith(".py") for x in filenames if x != "__init__.py")
def ask_not_null_addition(self, field_name, model_name):
"""Adding a NOT NULL field to a model."""
# None means quit
return None
def ask_not_null_alteration(self, field_name, model_name):
"""Changing a NULL field to NOT NULL."""
# None means quit
return None
def ask_rename(self, model_name, old_name, new_name, field_instance):
"""Was this field really renamed?"""
return self.defaults.get("ask_rename", False)
def ask_rename_model(self, old_model_state, new_model_state):
"""Was this model really renamed?"""
return self.defaults.get("ask_rename_model", False)
def ask_merge(self, app_label):
"""Should these migrations really be merged?"""
return self.defaults.get("ask_merge", False)
def ask_auto_now_add_addition(self, field_name, model_name):
"""Adding an auto_now_add field to a model."""
# None means quit
return None
class InteractiveMigrationQuestioner(MigrationQuestioner):
def _boolean_input(self, question, default=None):
result = input("%s " % question)
if not result and default is not None:
return default
while not result or result[0].lower() not in "yn":
result = input("Please answer yes or no: ")
return result[0].lower() == "y"
def _choice_input(self, question, choices):
print(question)
for i, choice in enumerate(choices):
print(" %s) %s" % (i + 1, choice))
result = input("Select an option: ")
while True:
try:
value = int(result)
except ValueError:
pass
else:
if 0 < value <= len(choices):
return value
result = input("Please select a valid option: ")
def _ask_default(self, default=''):
"""
Prompt for a default value.
The ``default`` argument allows providing a custom default value (as a
string) which will be shown to the user and used as the return value
if the user doesn't provide any other input.
"""
print('Please enter the default value as valid Python.')
if default:
print(
f"Accept the default '{default}' by pressing 'Enter' or "
f"provide another value."
)
print(
'The datetime and django.utils.timezone modules are available, so '
'it is possible to provide e.g. timezone.now as a value.'
)
print("Type 'exit' to exit this prompt")
while True:
if default:
prompt = "[default: {}] >>> ".format(default)
else:
prompt = ">>> "
code = input(prompt)
if not code and default:
code = default
if not code:
print("Please enter some code, or 'exit' (without quotes) to exit.")
elif code == "exit":
sys.exit(1)
else:
try:
return eval(code, {}, {'datetime': datetime, 'timezone': timezone})
except (SyntaxError, NameError) as e:
print("Invalid input: %s" % e)
def ask_not_null_addition(self, field_name, model_name):
"""Adding a NOT NULL field to a model."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to add a non-nullable field '{field_name}' "
f"to {model_name} without specifying a default. This is "
f"because the database needs something to populate existing "
f"rows.\n"
f"Please select a fix:",
[
("Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"),
'Quit and manually define a default value in models.py.',
]
)
if choice == 2:
sys.exit(3)
else:
return self._ask_default()
return None
def ask_not_null_alteration(self, field_name, model_name):
"""Changing a NULL field to NOT NULL."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to change a nullable field '{field_name}' "
f"on {model_name} to non-nullable without providing a "
f"default. This is because the database needs something to "
f"populate existing rows.\n"
f"Please select a fix:",
[
("Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"),
'Ignore for now. Existing rows that contain NULL values '
'will have to be handled manually, for example with a '
'RunPython or RunSQL operation.',
'Quit and manually define a default value in models.py.',
]
)
if choice == 2:
return NOT_PROVIDED
elif choice == 3:
sys.exit(3)
else:
return self._ask_default()
return None
def ask_rename(self, model_name, old_name, new_name, field_instance):
"""Was this field really renamed?"""
msg = 'Was %s.%s renamed to %s.%s (a %s)? [y/N]'
return self._boolean_input(msg % (model_name, old_name, model_name, new_name,
field_instance.__class__.__name__), False)
def ask_rename_model(self, old_model_state, new_model_state):
"""Was this model really renamed?"""
msg = 'Was the model %s.%s renamed to %s? [y/N]'
return self._boolean_input(msg % (old_model_state.app_label, old_model_state.name,
new_model_state.name), False)
def ask_merge(self, app_label):
return self._boolean_input(
"\nMerging will only work if the operations printed above do not conflict\n" +
"with each other (working on different fields or models)\n" +
'Should these migration branches be merged? [y/N]',
False,
)
def ask_auto_now_add_addition(self, field_name, model_name):
"""Adding an auto_now_add field to a model."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to add the field '{field_name}' with "
f"'auto_now_add=True' to {model_name} without providing a "
f"default. This is because the database needs something to "
f"populate existing rows.\n",
[
'Provide a one-off default now which will be set on all '
'existing rows',
'Quit and manually define a default value in models.py.',
]
)
if choice == 2:
sys.exit(3)
else:
return self._ask_default(default='timezone.now')
return None
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
def ask_not_null_addition(self, field_name, model_name):
# We can't ask the user, so act like the user aborted.
sys.exit(3)
def ask_not_null_alteration(self, field_name, model_name):
# We can't ask the user, so set as not provided.
return NOT_PROVIDED
def ask_auto_now_add_addition(self, field_name, model_name):
# We can't ask the user, so act like the user aborted.
sys.exit(3)

View File

@@ -0,0 +1,96 @@
from django.apps.registry import Apps
from django.db import DatabaseError, models
from django.utils.functional import classproperty
from django.utils.timezone import now
from .exceptions import MigrationSchemaMissing
class MigrationRecorder:
"""
Deal with storing migration records in the database.
Because this table is actually itself used for dealing with model
creation, it's the one thing we can't do normally via migrations.
We manually handle table creation/schema updating (using schema backend)
and then have a floating model to do queries with.
If a migration is unapplied its row is removed from the table. Having
a row in the table always means a migration is applied.
"""
_migration_class = None
@classproperty
def Migration(cls):
"""
Lazy load to avoid AppRegistryNotReady if installed apps import
MigrationRecorder.
"""
if cls._migration_class is None:
class Migration(models.Model):
app = models.CharField(max_length=255)
name = models.CharField(max_length=255)
applied = models.DateTimeField(default=now)
class Meta:
apps = Apps()
app_label = 'migrations'
db_table = 'django_migrations'
def __str__(self):
return 'Migration %s for %s' % (self.name, self.app)
cls._migration_class = Migration
return cls._migration_class
def __init__(self, connection):
self.connection = connection
@property
def migration_qs(self):
return self.Migration.objects.using(self.connection.alias)
def has_table(self):
"""Return True if the django_migrations table exists."""
with self.connection.cursor() as cursor:
tables = self.connection.introspection.table_names(cursor)
return self.Migration._meta.db_table in tables
def ensure_schema(self):
"""Ensure the table exists and has the correct schema."""
# If the table's there, that's fine - we've never changed its schema
# in the codebase.
if self.has_table():
return
# Make the table
try:
with self.connection.schema_editor() as editor:
editor.create_model(self.Migration)
except DatabaseError as exc:
raise MigrationSchemaMissing("Unable to create the django_migrations table (%s)" % exc)
def applied_migrations(self):
"""
Return a dict mapping (app_name, migration_name) to Migration instances
for all applied migrations.
"""
if self.has_table():
return {(migration.app, migration.name): migration for migration in self.migration_qs}
else:
# If the django_migrations table doesn't exist, then no migrations
# are applied.
return {}
def record_applied(self, app, name):
"""Record that a migration was applied."""
self.ensure_schema()
self.migration_qs.create(app=app, name=name)
def record_unapplied(self, app, name):
"""Record that a migration was unapplied."""
self.ensure_schema()
self.migration_qs.filter(app=app, name=name).delete()
def flush(self):
"""Delete all migration records. Useful for testing migrations."""
self.migration_qs.all().delete()

View File

@@ -0,0 +1,357 @@
import builtins
import collections.abc
import datetime
import decimal
import enum
import functools
import math
import os
import pathlib
import re
import types
import uuid
from django.conf import SettingsReference
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
from django.utils.functional import LazyObject, Promise
from django.utils.timezone import utc
from django.utils.version import get_docs_version
class BaseSerializer:
def __init__(self, value):
self.value = value
def serialize(self):
raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.')
class BaseSequenceSerializer(BaseSerializer):
def _format(self):
raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.')
def serialize(self):
imports = set()
strings = []
for item in self.value:
item_string, item_imports = serializer_factory(item).serialize()
imports.update(item_imports)
strings.append(item_string)
value = self._format()
return value % (", ".join(strings)), imports
class BaseSimpleSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), set()
class ChoicesSerializer(BaseSerializer):
def serialize(self):
return serializer_factory(self.value.value).serialize()
class DateTimeSerializer(BaseSerializer):
"""For datetime.*, except datetime.datetime."""
def serialize(self):
return repr(self.value), {'import datetime'}
class DatetimeDatetimeSerializer(BaseSerializer):
"""For datetime.datetime."""
def serialize(self):
if self.value.tzinfo is not None and self.value.tzinfo != utc:
self.value = self.value.astimezone(utc)
imports = ["import datetime"]
if self.value.tzinfo is not None:
imports.append("from django.utils.timezone import utc")
return repr(self.value).replace('datetime.timezone.utc', 'utc'), set(imports)
class DecimalSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), {"from decimal import Decimal"}
class DeconstructableSerializer(BaseSerializer):
@staticmethod
def serialize_deconstructed(path, args, kwargs):
name, imports = DeconstructableSerializer._serialize_path(path)
strings = []
for arg in args:
arg_string, arg_imports = serializer_factory(arg).serialize()
strings.append(arg_string)
imports.update(arg_imports)
for kw, arg in sorted(kwargs.items()):
arg_string, arg_imports = serializer_factory(arg).serialize()
imports.update(arg_imports)
strings.append("%s=%s" % (kw, arg_string))
return "%s(%s)" % (name, ", ".join(strings)), imports
@staticmethod
def _serialize_path(path):
module, name = path.rsplit(".", 1)
if module == "django.db.models":
imports = {"from django.db import models"}
name = "models.%s" % name
else:
imports = {"import %s" % module}
name = path
return name, imports
def serialize(self):
return self.serialize_deconstructed(*self.value.deconstruct())
class DictionarySerializer(BaseSerializer):
def serialize(self):
imports = set()
strings = []
for k, v in sorted(self.value.items()):
k_string, k_imports = serializer_factory(k).serialize()
v_string, v_imports = serializer_factory(v).serialize()
imports.update(k_imports)
imports.update(v_imports)
strings.append((k_string, v_string))
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
class EnumSerializer(BaseSerializer):
def serialize(self):
enum_class = self.value.__class__
module = enum_class.__module__
return (
'%s.%s[%r]' % (module, enum_class.__qualname__, self.value.name),
{'import %s' % module},
)
class FloatSerializer(BaseSimpleSerializer):
def serialize(self):
if math.isnan(self.value) or math.isinf(self.value):
return 'float("{}")'.format(self.value), set()
return super().serialize()
class FrozensetSerializer(BaseSequenceSerializer):
def _format(self):
return "frozenset([%s])"
class FunctionTypeSerializer(BaseSerializer):
def serialize(self):
if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type):
klass = self.value.__self__
module = klass.__module__
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module}
# Further error checking
if self.value.__name__ == '<lambda>':
raise ValueError("Cannot serialize function: lambda")
if self.value.__module__ is None:
raise ValueError("Cannot serialize function %r: No module" % self.value)
module_name = self.value.__module__
if '<' not in self.value.__qualname__: # Qualname can include <locals>
return '%s.%s' % (module_name, self.value.__qualname__), {'import %s' % self.value.__module__}
raise ValueError(
'Could not find function %s in %s.\n' % (self.value.__name__, module_name)
)
class FunctoolsPartialSerializer(BaseSerializer):
def serialize(self):
# Serialize functools.partial() arguments
func_string, func_imports = serializer_factory(self.value.func).serialize()
args_string, args_imports = serializer_factory(self.value.args).serialize()
keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize()
# Add any imports needed by arguments
imports = {'import functools', *func_imports, *args_imports, *keywords_imports}
return (
'functools.%s(%s, *%s, **%s)' % (
self.value.__class__.__name__,
func_string,
args_string,
keywords_string,
),
imports,
)
class IterableSerializer(BaseSerializer):
def serialize(self):
imports = set()
strings = []
for item in self.value:
item_string, item_imports = serializer_factory(item).serialize()
imports.update(item_imports)
strings.append(item_string)
# When len(strings)==0, the empty iterable should be serialized as
# "()", not "(,)" because (,) is invalid Python syntax.
value = "(%s)" if len(strings) != 1 else "(%s,)"
return value % (", ".join(strings)), imports
class ModelFieldSerializer(DeconstructableSerializer):
def serialize(self):
attr_name, path, args, kwargs = self.value.deconstruct()
return self.serialize_deconstructed(path, args, kwargs)
class ModelManagerSerializer(DeconstructableSerializer):
def serialize(self):
as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
if as_manager:
name, imports = self._serialize_path(qs_path)
return "%s.as_manager()" % name, imports
else:
return self.serialize_deconstructed(manager_path, args, kwargs)
class OperationSerializer(BaseSerializer):
def serialize(self):
from django.db.migrations.writer import OperationWriter
string, imports = OperationWriter(self.value, indentation=0).serialize()
# Nested operation, trailing comma is handled in upper OperationWriter._write()
return string.rstrip(','), imports
class PathLikeSerializer(BaseSerializer):
def serialize(self):
return repr(os.fspath(self.value)), {}
class PathSerializer(BaseSerializer):
def serialize(self):
# Convert concrete paths to pure paths to avoid issues with migrations
# generated on one platform being used on a different platform.
prefix = 'Pure' if isinstance(self.value, pathlib.Path) else ''
return 'pathlib.%s%r' % (prefix, self.value), {'import pathlib'}
class RegexSerializer(BaseSerializer):
def serialize(self):
regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize()
# Turn off default implicit flags (e.g. re.U) because regexes with the
# same implicit and explicit flags aren't equal.
flags = self.value.flags ^ re.compile('').flags
regex_flags, flag_imports = serializer_factory(flags).serialize()
imports = {'import re', *pattern_imports, *flag_imports}
args = [regex_pattern]
if flags:
args.append(regex_flags)
return "re.compile(%s)" % ', '.join(args), imports
class SequenceSerializer(BaseSequenceSerializer):
def _format(self):
return "[%s]"
class SetSerializer(BaseSequenceSerializer):
def _format(self):
# Serialize as a set literal except when value is empty because {}
# is an empty dict.
return '{%s}' if self.value else 'set(%s)'
class SettingsReferenceSerializer(BaseSerializer):
def serialize(self):
return "settings.%s" % self.value.setting_name, {"from django.conf import settings"}
class TupleSerializer(BaseSequenceSerializer):
def _format(self):
# When len(value)==0, the empty tuple should be serialized as "()",
# not "(,)" because (,) is invalid Python syntax.
return "(%s)" if len(self.value) != 1 else "(%s,)"
class TypeSerializer(BaseSerializer):
def serialize(self):
special_cases = [
(models.Model, "models.Model", ['from django.db import models']),
(type(None), 'type(None)', []),
]
for case, string, imports in special_cases:
if case is self.value:
return string, set(imports)
if hasattr(self.value, "__module__"):
module = self.value.__module__
if module == builtins.__name__:
return self.value.__name__, set()
else:
return "%s.%s" % (module, self.value.__qualname__), {"import %s" % module}
class UUIDSerializer(BaseSerializer):
def serialize(self):
return "uuid.%s" % repr(self.value), {"import uuid"}
class Serializer:
_registry = {
# Some of these are order-dependent.
frozenset: FrozensetSerializer,
list: SequenceSerializer,
set: SetSerializer,
tuple: TupleSerializer,
dict: DictionarySerializer,
models.Choices: ChoicesSerializer,
enum.Enum: EnumSerializer,
datetime.datetime: DatetimeDatetimeSerializer,
(datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
SettingsReference: SettingsReferenceSerializer,
float: FloatSerializer,
(bool, int, type(None), bytes, str, range): BaseSimpleSerializer,
decimal.Decimal: DecimalSerializer,
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
(types.FunctionType, types.BuiltinFunctionType, types.MethodType): FunctionTypeSerializer,
collections.abc.Iterable: IterableSerializer,
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
uuid.UUID: UUIDSerializer,
pathlib.PurePath: PathSerializer,
os.PathLike: PathLikeSerializer,
}
@classmethod
def register(cls, type_, serializer):
if not issubclass(serializer, BaseSerializer):
raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
cls._registry[type_] = serializer
@classmethod
def unregister(cls, type_):
cls._registry.pop(type_)
def serializer_factory(value):
if isinstance(value, Promise):
value = str(value)
elif isinstance(value, LazyObject):
# The unwrapped value is returned as the first item of the arguments
# tuple.
value = value.__reduce__()[1][0]
if isinstance(value, models.Field):
return ModelFieldSerializer(value)
if isinstance(value, models.manager.BaseManager):
return ModelManagerSerializer(value)
if isinstance(value, Operation):
return OperationSerializer(value)
if isinstance(value, type):
return TypeSerializer(value)
# Anything that knows how to deconstruct itself.
if hasattr(value, 'deconstruct'):
return DeconstructableSerializer(value)
for type_, serializer_cls in Serializer._registry.items():
if isinstance(value, type_):
return serializer_cls(value)
raise ValueError(
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
"topics/migrations/#migration-serializing" % (value, get_docs_version())
)

View File

@@ -0,0 +1,904 @@
import copy
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from django.apps import AppConfig
from django.apps.registry import Apps, apps as global_apps
from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.migrations.utils import field_is_referenced, get_references
from django.db.models import NOT_PROVIDED
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.db.models.options import DEFAULT_NAMES, normalize_together
from django.db.models.utils import make_model_tuple
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
from django.utils.version import get_docs_version
from .exceptions import InvalidBasesError
from .utils import resolve_relation
def _get_app_label_and_model_name(model, app_label=''):
if isinstance(model, str):
split = model.split('.', 1)
return tuple(split) if len(split) == 2 else (app_label, split[0])
else:
return model._meta.app_label, model._meta.model_name
def _get_related_models(m):
"""Return all models that have a direct relationship to the given model."""
related_models = [
subclass for subclass in m.__subclasses__()
if issubclass(subclass, models.Model)
]
related_fields_models = set()
for f in m._meta.get_fields(include_parents=True, include_hidden=True):
if f.is_relation and f.related_model is not None and not isinstance(f.related_model, str):
related_fields_models.add(f.model)
related_models.append(f.related_model)
# Reverse accessors of foreign keys to proxy models are attached to their
# concrete proxied model.
opts = m._meta
if opts.proxy and m in related_fields_models:
related_models.append(opts.concrete_model)
return related_models
def get_related_models_tuples(model):
"""
Return a list of typical (app_label, model_name) tuples for all related
models for the given model.
"""
return {
(rel_mod._meta.app_label, rel_mod._meta.model_name)
for rel_mod in _get_related_models(model)
}
def get_related_models_recursive(model):
"""
Return all models that have a direct or indirect relationship
to the given model.
Relationships are either defined by explicit relational fields, like
ForeignKey, ManyToManyField or OneToOneField, or by inheriting from another
model (a superclass is related to its subclasses, but not vice versa). Note,
however, that a model inheriting from a concrete model is also related to
its superclass through the implicit *_ptr OneToOneField on the subclass.
"""
seen = set()
queue = _get_related_models(model)
for rel_mod in queue:
rel_app_label, rel_model_name = rel_mod._meta.app_label, rel_mod._meta.model_name
if (rel_app_label, rel_model_name) in seen:
continue
seen.add((rel_app_label, rel_model_name))
queue.extend(_get_related_models(rel_mod))
return seen - {(model._meta.app_label, model._meta.model_name)}
class ProjectState:
"""
Represent the entire project's overall state. This is the item that is
passed around - do it here rather than at the app level so that cross-app
FKs/etc. resolve properly.
"""
def __init__(self, models=None, real_apps=None):
self.models = models or {}
# Apps to include from main registry, usually unmigrated ones
if real_apps is None:
real_apps = set()
else:
assert isinstance(real_apps, set)
self.real_apps = real_apps
self.is_delayed = False
# {remote_model_key: {model_key: {field_name: field}}}
self._relations = None
@property
def relations(self):
if self._relations is None:
self.resolve_fields_and_relations()
return self._relations
def add_model(self, model_state):
model_key = model_state.app_label, model_state.name_lower
self.models[model_key] = model_state
if self._relations is not None:
self.resolve_model_relations(model_key)
if 'apps' in self.__dict__: # hasattr would cache the property
self.reload_model(*model_key)
def remove_model(self, app_label, model_name):
model_key = app_label, model_name
del self.models[model_key]
if self._relations is not None:
self._relations.pop(model_key, None)
# Call list() since _relations can change size during iteration.
for related_model_key, model_relations in list(self._relations.items()):
model_relations.pop(model_key, None)
if not model_relations:
del self._relations[related_model_key]
if 'apps' in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(*model_key)
# Need to do this explicitly since unregister_model() doesn't clear
# the cache automatically (#24513)
self.apps.clear_cache()
def rename_model(self, app_label, old_name, new_name):
# Add a new model.
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
renamed_model = self.models[app_label, old_name_lower].clone()
renamed_model.name = new_name
self.models[app_label, new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, old_name_lower)
new_remote_model = f'{app_label}.{new_name}'
to_reload = set()
for model_state, name, field, reference in get_references(self, old_model_tuple):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
if self._relations is not None:
old_name_key = app_label, old_name_lower
new_name_key = app_label, new_name_lower
if old_name_key in self._relations:
self._relations[new_name_key] = self._relations.pop(old_name_key)
for model_relations in self._relations.values():
if old_name_key in model_relations:
model_relations[new_name_key] = model_relations.pop(old_name_key)
# Reload models related to old model before removing the old model.
self.reload_models(to_reload, delay=True)
# Remove the old model.
self.remove_model(app_label, old_name_lower)
self.reload_model(app_label, new_name_lower, delay=True)
def alter_model_options(self, app_label, model_name, options, option_keys=None):
model_state = self.models[app_label, model_name]
model_state.options = {**model_state.options, **options}
if option_keys:
for key in option_keys:
if key not in options:
model_state.options.pop(key, False)
self.reload_model(app_label, model_name, delay=True)
def alter_model_managers(self, app_label, model_name, managers):
model_state = self.models[app_label, model_name]
model_state.managers = list(managers)
self.reload_model(app_label, model_name, delay=True)
def _append_option(self, app_label, model_name, option_name, obj):
model_state = self.models[app_label, model_name]
model_state.options[option_name] = [*model_state.options[option_name], obj]
self.reload_model(app_label, model_name, delay=True)
def _remove_option(self, app_label, model_name, option_name, obj_name):
model_state = self.models[app_label, model_name]
objs = model_state.options[option_name]
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
self.reload_model(app_label, model_name, delay=True)
def add_index(self, app_label, model_name, index):
self._append_option(app_label, model_name, 'indexes', index)
def remove_index(self, app_label, model_name, index_name):
self._remove_option(app_label, model_name, 'indexes', index_name)
def add_constraint(self, app_label, model_name, constraint):
self._append_option(app_label, model_name, 'constraints', constraint)
def remove_constraint(self, app_label, model_name, constraint_name):
self._remove_option(app_label, model_name, 'constraints', constraint_name)
def add_field(self, app_label, model_name, name, field, preserve_default):
# If preserve default is off, don't use the default for future state.
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_key = app_label, model_name
self.models[model_key].fields[name] = field
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, field)
# Delay rendering of relationships if it's not a relational field.
delay = not field.is_relation
self.reload_model(*model_key, delay=delay)
def remove_field(self, app_label, model_name, name):
model_key = app_label, model_name
model_state = self.models[model_key]
old_field = model_state.fields.pop(name)
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, old_field)
# Delay rendering of relationships if it's not a relational field.
delay = not old_field.is_relation
self.reload_model(*model_key, delay=delay)
def alter_field(self, app_label, model_name, name, field, preserve_default):
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_key = app_label, model_name
fields = self.models[model_key].fields
if self._relations is not None:
old_field = fields.pop(name)
if old_field.is_relation:
self.resolve_model_field_relations(model_key, name, old_field)
fields[name] = field
if field.is_relation:
self.resolve_model_field_relations(model_key, name, field)
else:
fields[name] = field
# TODO: investigate if old relational fields must be reloaded or if
# it's sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = (
not field.is_relation and
not field_is_referenced(self, model_key, (name, field))
)
self.reload_model(*model_key, delay=delay)
def rename_field(self, app_label, model_name, old_name, new_name):
model_key = app_label, model_name
model_state = self.models[model_key]
# Rename the field.
fields = model_state.fields
try:
found = fields.pop(old_name)
except KeyError:
raise FieldDoesNotExist(
f"{app_label}.{model_name} has no field named '{old_name}'"
)
fields[new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, 'from_fields', None)
if from_fields:
field.from_fields = tuple([
new_name if from_field_name == old_name else from_field_name
for from_field_name in from_fields
])
# Fix index/unique_together to refer to the new field.
options = model_state.options
for option in ('index_together', 'unique_together'):
if option in options:
options[option] = [
[new_name if n == old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(self, model_key, (old_name, found))
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, 'field_name', None) == old_name:
remote_field.field_name = new_name
if to_fields:
field.to_fields = tuple([
new_name if to_field_name == old_name else to_field_name
for to_field_name in to_fields
])
if self._relations is not None:
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
for to_model in self._relations.values():
if old_name_lower in to_model[model_key]:
field = to_model[model_key].pop(old_name_lower)
field.name = new_name_lower
to_model[model_key][new_name_lower] = field
self.reload_model(*model_key, delay=delay)
def _find_reload_model(self, app_label, model_name, delay=False):
if delay:
self.is_delayed = True
related_models = set()
try:
old_model = self.apps.get_model(app_label, model_name)
except LookupError:
pass
else:
# Get all relations to and from the old model before reloading,
# as _meta.apps may change
if delay:
related_models = get_related_models_tuples(old_model)
else:
related_models = get_related_models_recursive(old_model)
# Get all outgoing references from the model to be rendered
model_state = self.models[(app_label, model_name)]
# Directly related models are the models pointed to by ForeignKeys,
# OneToOneFields, and ManyToManyFields.
direct_related_models = set()
for field in model_state.fields.values():
if field.is_relation:
if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
continue
rel_app_label, rel_model_name = _get_app_label_and_model_name(field.related_model, app_label)
direct_related_models.add((rel_app_label, rel_model_name.lower()))
# For all direct related models recursively get all related models.
related_models.update(direct_related_models)
for rel_app_label, rel_model_name in direct_related_models:
try:
rel_model = self.apps.get_model(rel_app_label, rel_model_name)
except LookupError:
pass
else:
if delay:
related_models.update(get_related_models_tuples(rel_model))
else:
related_models.update(get_related_models_recursive(rel_model))
# Include the model itself
related_models.add((app_label, model_name))
return related_models
def reload_model(self, app_label, model_name, delay=False):
if 'apps' in self.__dict__: # hasattr would cache the property
related_models = self._find_reload_model(app_label, model_name, delay)
self._reload(related_models)
def reload_models(self, models, delay=True):
if 'apps' in self.__dict__: # hasattr would cache the property
related_models = set()
for app_label, model_name in models:
related_models.update(self._find_reload_model(app_label, model_name, delay))
self._reload(related_models)
def _reload(self, related_models):
# Unregister all related models
with self.apps.bulk_update():
for rel_app_label, rel_model_name in related_models:
self.apps.unregister_model(rel_app_label, rel_model_name)
states_to_be_rendered = []
# Gather all models states of those models that will be rerendered.
# This includes:
# 1. All related models of unmigrated apps
for model_state in self.apps.real_models:
if (model_state.app_label, model_state.name_lower) in related_models:
states_to_be_rendered.append(model_state)
# 2. All related models of migrated apps
for rel_app_label, rel_model_name in related_models:
try:
model_state = self.models[rel_app_label, rel_model_name]
except KeyError:
pass
else:
states_to_be_rendered.append(model_state)
# Render all models
self.apps.render_multiple(states_to_be_rendered)
def update_model_field_relation(
self, model, model_key, field_name, field, concretes,
):
remote_model_key = resolve_relation(model, *model_key)
if remote_model_key[0] not in self.real_apps and remote_model_key in concretes:
remote_model_key = concretes[remote_model_key]
relations_to_remote_model = self._relations[remote_model_key]
if field_name in self.models[model_key].fields:
# The assert holds because it's a new relation, or an altered
# relation, in which case references have been removed by
# alter_field().
assert field_name not in relations_to_remote_model[model_key]
relations_to_remote_model[model_key][field_name] = field
else:
del relations_to_remote_model[model_key][field_name]
if not relations_to_remote_model[model_key]:
del relations_to_remote_model[model_key]
def resolve_model_field_relations(
self, model_key, field_name, field, concretes=None,
):
remote_field = field.remote_field
if not remote_field:
return
if concretes is None:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
self.update_model_field_relation(
remote_field.model, model_key, field_name, field, concretes,
)
through = getattr(remote_field, 'through', None)
if not through:
return
self.update_model_field_relation(through, model_key, field_name, field, concretes)
def resolve_model_relations(self, model_key, concretes=None):
if concretes is None:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
model_state = self.models[model_key]
for field_name, field in model_state.fields.items():
self.resolve_model_field_relations(model_key, field_name, field, concretes)
def resolve_fields_and_relations(self):
# Resolve fields.
for model_state in self.models.values():
for field_name, field in model_state.fields.items():
field.name = field_name
# Resolve relations.
# {remote_model_key: {model_key: {field_name: field}}}
self._relations = defaultdict(partial(defaultdict, dict))
concretes, proxies = self._get_concrete_models_mapping_and_proxy_models()
for model_key in concretes:
self.resolve_model_relations(model_key, concretes)
for model_key in proxies:
self._relations[model_key] = self._relations[concretes[model_key]]
def get_concrete_model_key(self, model):
concrete_models_mapping, _ = self._get_concrete_models_mapping_and_proxy_models()
model_key = make_model_tuple(model)
return concrete_models_mapping[model_key]
def _get_concrete_models_mapping_and_proxy_models(self):
concrete_models_mapping = {}
proxy_models = {}
# Split models to proxy and concrete models.
for model_key, model_state in self.models.items():
if model_state.options.get('proxy'):
proxy_models[model_key] = model_state
# Find a concrete model for the proxy.
concrete_models_mapping[model_key] = self._find_concrete_model_from_proxy(
proxy_models, model_state,
)
else:
concrete_models_mapping[model_key] = model_key
return concrete_models_mapping, proxy_models
def _find_concrete_model_from_proxy(self, proxy_models, model_state):
for base in model_state.bases:
if not (isinstance(base, str) or issubclass(base, models.Model)):
continue
base_key = make_model_tuple(base)
base_state = proxy_models.get(base_key)
if not base_state:
# Concrete model found, stop looking at bases.
return base_key
return self._find_concrete_model_from_proxy(proxy_models, base_state)
def clone(self):
"""Return an exact copy of this ProjectState."""
new_state = ProjectState(
models={k: v.clone() for k, v in self.models.items()},
real_apps=self.real_apps,
)
if 'apps' in self.__dict__:
new_state.apps = self.apps.clone()
new_state.is_delayed = self.is_delayed
return new_state
def clear_delayed_apps_cache(self):
if self.is_delayed and 'apps' in self.__dict__:
del self.__dict__['apps']
@cached_property
def apps(self):
return StateApps(self.real_apps, self.models)
@classmethod
def from_apps(cls, apps):
"""Take an Apps and return a ProjectState matching it."""
app_models = {}
for model in apps.get_models(include_swapped=True):
model_state = ModelState.from_model(model)
app_models[(model_state.app_label, model_state.name_lower)] = model_state
return cls(app_models)
def __eq__(self, other):
return self.models == other.models and self.real_apps == other.real_apps
class AppConfigStub(AppConfig):
"""Stub of an AppConfig. Only provides a label and a dict of models."""
def __init__(self, label):
self.apps = None
self.models = {}
# App-label and app-name are not the same thing, so technically passing
# in the label here is wrong. In practice, migrations don't care about
# the app name, but we need something unique, and the label works fine.
self.label = label
self.name = label
def import_models(self):
self.models = self.apps.all_models[self.label]
class StateApps(Apps):
"""
Subclass of the global Apps registry class to better handle dynamic model
additions and removals.
"""
def __init__(self, real_apps, models, ignore_swappable=False):
# Any apps in self.real_apps should have all their models included
# in the render. We don't use the original model instances as there
# are some variables that refer to the Apps object.
# FKs/M2Ms from real apps are also not included as they just
# mess things up with partial states (due to lack of dependencies)
self.real_models = []
for app_label in real_apps:
app = global_apps.get_app_config(app_label)
for model in app.get_models():
self.real_models.append(ModelState.from_model(model, exclude_rels=True))
# Populate the app registry with a stub for each application.
app_labels = {model_state.app_label for model_state in models.values()}
app_configs = [AppConfigStub(label) for label in sorted([*real_apps, *app_labels])]
super().__init__(app_configs)
# These locks get in the way of copying as implemented in clone(),
# which is called whenever Django duplicates a StateApps before
# updating it.
self._lock = None
self.ready_event = None
self.render_multiple([*models.values(), *self.real_models])
# There shouldn't be any operations pending at this point.
from django.core.checks.model_checks import _check_lazy_references
ignore = {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
errors = _check_lazy_references(self, ignore=ignore)
if errors:
raise ValueError("\n".join(error.msg for error in errors))
@contextmanager
def bulk_update(self):
# Avoid clearing each model's cache for each change. Instead, clear
# all caches when we're finished updating the model instances.
ready = self.ready
self.ready = False
try:
yield
finally:
self.ready = ready
self.clear_cache()
def render_multiple(self, model_states):
# We keep trying to render the models in a loop, ignoring invalid
# base errors, until the size of the unrendered models doesn't
# decrease by at least one, meaning there's a base dependency loop/
# missing base.
if not model_states:
return
# Prevent that all model caches are expired for each render.
with self.bulk_update():
unrendered_models = model_states
while unrendered_models:
new_unrendered_models = []
for model in unrendered_models:
try:
model.render(self)
except InvalidBasesError:
new_unrendered_models.append(model)
if len(new_unrendered_models) == len(unrendered_models):
raise InvalidBasesError(
"Cannot resolve bases for %r\nThis can happen if you are inheriting models from an "
"app with migrations (e.g. contrib.auth)\n in an app with no migrations; see "
"https://docs.djangoproject.com/en/%s/topics/migrations/#dependencies "
"for more" % (new_unrendered_models, get_docs_version())
)
unrendered_models = new_unrendered_models
def clone(self):
"""Return a clone of this registry."""
clone = StateApps([], {})
clone.all_models = copy.deepcopy(self.all_models)
clone.app_configs = copy.deepcopy(self.app_configs)
# Set the pointer to the correct app registry.
for app_config in clone.app_configs.values():
app_config.apps = clone
# No need to actually clone them, they'll never change
clone.real_models = self.real_models
return clone
def register_model(self, app_label, model):
self.all_models[app_label][model._meta.model_name] = model
if app_label not in self.app_configs:
self.app_configs[app_label] = AppConfigStub(app_label)
self.app_configs[app_label].apps = self
self.app_configs[app_label].models[model._meta.model_name] = model
self.do_pending_operations(model)
self.clear_cache()
def unregister_model(self, app_label, model_name):
try:
del self.all_models[app_label][model_name]
del self.app_configs[app_label].models[model_name]
except KeyError:
pass
class ModelState:
"""
Represent a Django Model. Don't use the actual Model class as it's not
designed to have its options changed - instead, mutate this one and then
render it into a Model as required.
Note that while you are allowed to mutate .fields, you are not allowed
to mutate the Field instances inside there themselves - you must instead
assign new ones, as these are not detached during a clone.
"""
def __init__(self, app_label, name, fields, options=None, bases=None, managers=None):
self.app_label = app_label
self.name = name
self.fields = dict(fields)
self.options = options or {}
self.options.setdefault('indexes', [])
self.options.setdefault('constraints', [])
self.bases = bases or (models.Model,)
self.managers = managers or []
for name, field in self.fields.items():
# Sanity-check that fields are NOT already bound to a model.
if hasattr(field, 'model'):
raise ValueError(
'ModelState.fields cannot be bound to a model - "%s" is.' % name
)
# Sanity-check that relation fields are NOT referring to a model class.
if field.is_relation and hasattr(field.related_model, '_meta'):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.to" does. '
'Use a string reference instead.' % name
)
if field.many_to_many and hasattr(field.remote_field.through, '_meta'):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.through" does. '
'Use a string reference instead.' % name
)
# Sanity-check that indexes have their name set.
for index in self.options['indexes']:
if not index.name:
raise ValueError(
"Indexes passed to ModelState require a name attribute. "
"%r doesn't have one." % index
)
@cached_property
def name_lower(self):
return self.name.lower()
def get_field(self, field_name):
field_name = (
self.options['order_with_respect_to']
if field_name == '_order'
else field_name
)
return self.fields[field_name]
@classmethod
def from_model(cls, model, exclude_rels=False):
"""Given a model, return a ModelState representing it."""
# Deconstruct the fields
fields = []
for field in model._meta.local_fields:
if getattr(field, "remote_field", None) and exclude_rels:
continue
if isinstance(field, models.OrderWrt):
continue
name = field.name
try:
fields.append((name, field.clone()))
except TypeError as e:
raise TypeError("Couldn't reconstruct field %s on %s: %s" % (
name,
model._meta.label,
e,
))
if not exclude_rels:
for field in model._meta.local_many_to_many:
name = field.name
try:
fields.append((name, field.clone()))
except TypeError as e:
raise TypeError("Couldn't reconstruct m2m field %s on %s: %s" % (
name,
model._meta.object_name,
e,
))
# Extract the options
options = {}
for name in DEFAULT_NAMES:
# Ignore some special options
if name in ["apps", "app_label"]:
continue
elif name in model._meta.original_attrs:
if name == "unique_together":
ut = model._meta.original_attrs["unique_together"]
options[name] = set(normalize_together(ut))
elif name == "index_together":
it = model._meta.original_attrs["index_together"]
options[name] = set(normalize_together(it))
elif name == "indexes":
indexes = [idx.clone() for idx in model._meta.indexes]
for index in indexes:
if not index.name:
index.set_name_with_model(model)
options['indexes'] = indexes
elif name == 'constraints':
options['constraints'] = [con.clone() for con in model._meta.constraints]
else:
options[name] = model._meta.original_attrs[name]
# If we're ignoring relationships, remove all field-listing model
# options (that option basically just means "make a stub model")
if exclude_rels:
for key in ["unique_together", "index_together", "order_with_respect_to"]:
if key in options:
del options[key]
# Private fields are ignored, so remove options that refer to them.
elif options.get('order_with_respect_to') in {field.name for field in model._meta.private_fields}:
del options['order_with_respect_to']
def flatten_bases(model):
bases = []
for base in model.__bases__:
if hasattr(base, "_meta") and base._meta.abstract:
bases.extend(flatten_bases(base))
else:
bases.append(base)
return bases
# We can't rely on __mro__ directly because we only want to flatten
# abstract models and not the whole tree. However by recursing on
# __bases__ we may end up with duplicates and ordering issues, we
# therefore discard any duplicates and reorder the bases according
# to their index in the MRO.
flattened_bases = sorted(set(flatten_bases(model)), key=lambda x: model.__mro__.index(x))
# Make our record
bases = tuple(
(
base._meta.label_lower
if hasattr(base, "_meta") else
base
)
for base in flattened_bases
)
# Ensure at least one base inherits from models.Model
if not any((isinstance(base, str) or issubclass(base, models.Model)) for base in bases):
bases = (models.Model,)
managers = []
manager_names = set()
default_manager_shim = None
for manager in model._meta.managers:
if manager.name in manager_names:
# Skip overridden managers.
continue
elif manager.use_in_migrations:
# Copy managers usable in migrations.
new_manager = copy.copy(manager)
new_manager._set_creation_counter()
elif manager is model._base_manager or manager is model._default_manager:
# Shim custom managers used as default and base managers.
new_manager = models.Manager()
new_manager.model = manager.model
new_manager.name = manager.name
if manager is model._default_manager:
default_manager_shim = new_manager
else:
continue
manager_names.add(manager.name)
managers.append((manager.name, new_manager))
# Ignore a shimmed default manager called objects if it's the only one.
if managers == [('objects', default_manager_shim)]:
managers = []
# Construct the new ModelState
return cls(
model._meta.app_label,
model._meta.object_name,
fields,
options,
bases,
managers,
)
def construct_managers(self):
"""Deep-clone the managers using deconstruction."""
# Sort all managers by their creation counter
sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
for mgr_name, manager in sorted_managers:
as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
if as_manager:
qs_class = import_string(qs_path)
yield mgr_name, qs_class.as_manager()
else:
manager_class = import_string(manager_path)
yield mgr_name, manager_class(*args, **kwargs)
def clone(self):
"""Return an exact copy of this ModelState."""
return self.__class__(
app_label=self.app_label,
name=self.name,
fields=dict(self.fields),
# Since options are shallow-copied here, operations such as
# AddIndex must replace their option (e.g 'indexes') rather
# than mutating it.
options=dict(self.options),
bases=self.bases,
managers=list(self.managers),
)
def render(self, apps):
"""Create a Model object from our current state into the given apps."""
# First, make a Meta object
meta_contents = {'app_label': self.app_label, 'apps': apps, **self.options}
meta = type("Meta", (), meta_contents)
# Then, work out our bases
try:
bases = tuple(
(apps.get_model(base) if isinstance(base, str) else base)
for base in self.bases
)
except LookupError:
raise InvalidBasesError("Cannot resolve one or more bases from %r" % (self.bases,))
# Clone fields for the body, add other bits.
body = {name: field.clone() for name, field in self.fields.items()}
body['Meta'] = meta
body['__module__'] = "__fake__"
# Restore managers
body.update(self.construct_managers())
# Then, make a Model object (apps.register_model is called in __new__)
return type(self.name, bases, body)
def get_index_by_name(self, name):
for index in self.options['indexes']:
if index.name == name:
return index
raise ValueError("No index named %s on model %s" % (name, self.name))
def get_constraint_by_name(self, name):
for constraint in self.options['constraints']:
if constraint.name == name:
return constraint
raise ValueError('No constraint named %s on model %s' % (name, self.name))
def __repr__(self):
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
def __eq__(self, other):
return (
(self.app_label == other.app_label) and
(self.name == other.name) and
(len(self.fields) == len(other.fields)) and
all(
k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
for (k1, f1), (k2, f2) in zip(
sorted(self.fields.items()),
sorted(other.fields.items()),
)
) and
(self.options == other.options) and
(self.bases == other.bases) and
(self.managers == other.managers)
)

View File

@@ -0,0 +1,118 @@
import datetime
import re
from collections import namedtuple
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
FieldReference = namedtuple('FieldReference', 'to through')
COMPILED_REGEX_TYPE = type(re.compile(''))
class RegexObject:
def __init__(self, obj):
self.pattern = obj.pattern
self.flags = obj.flags
def __eq__(self, other):
return self.pattern == other.pattern and self.flags == other.flags
def get_migration_name_timestamp():
return datetime.datetime.now().strftime("%Y%m%d_%H%M")
def resolve_relation(model, app_label=None, model_name=None):
"""
Turn a model class or model reference string and return a model tuple.
app_label and model_name are used to resolve the scope of recursive and
unscoped model relationship.
"""
if isinstance(model, str):
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
if app_label is None or model_name is None:
raise TypeError(
'app_label and model_name must be provided to resolve '
'recursive relationships.'
)
return app_label, model_name
if '.' in model:
app_label, model_name = model.split('.', 1)
return app_label, model_name.lower()
if app_label is None:
raise TypeError(
'app_label must be provided to resolve unscoped model '
'relationships.'
)
return app_label, model.lower()
return model._meta.app_label, model._meta.model_name
def field_references(
model_tuple,
field,
reference_model_tuple,
reference_field_name=None,
reference_field=None,
):
"""
Return either False or a FieldReference if `field` references provided
context.
False positives can be returned if `reference_field_name` is provided
without `reference_field` because of the introspection limitation it
incurs. This should not be an issue when this function is used to determine
whether or not an optimization can take place.
"""
remote_field = field.remote_field
if not remote_field:
return False
references_to = None
references_through = None
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
to_fields = getattr(field, 'to_fields', None)
if (
reference_field_name is None or
# Unspecified to_field(s).
to_fields is None or
# Reference to primary key.
(None in to_fields and (reference_field is None or reference_field.primary_key)) or
# Reference to field.
reference_field_name in to_fields
):
references_to = (remote_field, to_fields)
through = getattr(remote_field, 'through', None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
through_fields = remote_field.through_fields
if (
reference_field_name is None or
# Unspecified through_fields.
through_fields is None or
# Reference to field.
reference_field_name in through_fields
):
references_through = (remote_field, through_fields)
if not (references_to or references_through):
return False
return FieldReference(references_to, references_through)
def get_references(state, model_tuple, field_tuple=()):
"""
Generator of (model_state, name, field, reference) referencing
provided context.
If field_tuple is provided only references to this particular field of
model_tuple will be generated.
"""
for state_model_tuple, model_state in state.models.items():
for name, field in model_state.fields.items():
reference = field_references(state_model_tuple, field, model_tuple, *field_tuple)
if reference:
yield model_state, name, field, reference
def field_is_referenced(state, model_tuple, field_tuple):
"""Return whether `field_tuple` is referenced by any state models."""
return next(get_references(state, model_tuple, field_tuple), None) is not None

View File

@@ -0,0 +1,300 @@
import os
import re
from importlib import import_module
from django import get_version
from django.apps import apps
# SettingsReference imported for backwards compatibility in Django 2.2.
from django.conf import SettingsReference # NOQA
from django.db import migrations
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.serializer import Serializer, serializer_factory
from django.utils.inspect import get_func_args
from django.utils.module_loading import module_dir
from django.utils.timezone import now
class OperationWriter:
def __init__(self, operation, indentation=2):
self.operation = operation
self.buff = []
self.indentation = indentation
def serialize(self):
def _write(_arg_name, _arg_value):
if (_arg_name in self.operation.serialization_expand_args and
isinstance(_arg_value, (list, tuple, dict))):
if isinstance(_arg_value, dict):
self.feed('%s={' % _arg_name)
self.indent()
for key, value in _arg_value.items():
key_string, key_imports = MigrationWriter.serialize(key)
arg_string, arg_imports = MigrationWriter.serialize(value)
args = arg_string.splitlines()
if len(args) > 1:
self.feed('%s: %s' % (key_string, args[0]))
for arg in args[1:-1]:
self.feed(arg)
self.feed('%s,' % args[-1])
else:
self.feed('%s: %s,' % (key_string, arg_string))
imports.update(key_imports)
imports.update(arg_imports)
self.unindent()
self.feed('},')
else:
self.feed('%s=[' % _arg_name)
self.indent()
for item in _arg_value:
arg_string, arg_imports = MigrationWriter.serialize(item)
args = arg_string.splitlines()
if len(args) > 1:
for arg in args[:-1]:
self.feed(arg)
self.feed('%s,' % args[-1])
else:
self.feed('%s,' % arg_string)
imports.update(arg_imports)
self.unindent()
self.feed('],')
else:
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
args = arg_string.splitlines()
if len(args) > 1:
self.feed('%s=%s' % (_arg_name, args[0]))
for arg in args[1:-1]:
self.feed(arg)
self.feed('%s,' % args[-1])
else:
self.feed('%s=%s,' % (_arg_name, arg_string))
imports.update(arg_imports)
imports = set()
name, args, kwargs = self.operation.deconstruct()
operation_args = get_func_args(self.operation.__init__)
# See if this operation is in django.db.migrations. If it is,
# We can just use the fact we already have that imported,
# otherwise, we need to add an import for the operation class.
if getattr(migrations, name, None) == self.operation.__class__:
self.feed('migrations.%s(' % name)
else:
imports.add('import %s' % (self.operation.__class__.__module__))
self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
self.indent()
for i, arg in enumerate(args):
arg_value = arg
arg_name = operation_args[i]
_write(arg_name, arg_value)
i = len(args)
# Only iterate over remaining arguments
for arg_name in operation_args[i:]:
if arg_name in kwargs: # Don't sort to maintain signature order
arg_value = kwargs[arg_name]
_write(arg_name, arg_value)
self.unindent()
self.feed('),')
return self.render(), imports
def indent(self):
self.indentation += 1
def unindent(self):
self.indentation -= 1
def feed(self, line):
self.buff.append(' ' * (self.indentation * 4) + line)
def render(self):
return '\n'.join(self.buff)
class MigrationWriter:
"""
Take a Migration instance and is able to produce the contents
of the migration file from it.
"""
def __init__(self, migration, include_header=True):
self.migration = migration
self.include_header = include_header
self.needs_manual_porting = False
def as_string(self):
"""Return a string of the file contents."""
items = {
"replaces_str": "",
"initial_str": "",
}
imports = set()
# Deconstruct operations
operations = []
for operation in self.migration.operations:
operation_string, operation_imports = OperationWriter(operation).serialize()
imports.update(operation_imports)
operations.append(operation_string)
items["operations"] = "\n".join(operations) + "\n" if operations else ""
# Format dependencies and write out swappable dependencies right
dependencies = []
for dependency in self.migration.dependencies:
if dependency[0] == "__setting__":
dependencies.append(" migrations.swappable_dependency(settings.%s)," % dependency[1])
imports.add("from django.conf import settings")
else:
dependencies.append(" %s," % self.serialize(dependency)[0])
items["dependencies"] = "\n".join(dependencies) + "\n" if dependencies else ""
# Format imports nicely, swapping imports of functions from migration files
# for comments
migration_imports = set()
for line in list(imports):
if re.match(r"^import (.*)\.\d+[^\s]*$", line):
migration_imports.add(line.split("import")[1].strip())
imports.remove(line)
self.needs_manual_porting = True
# django.db.migrations is always used, but models import may not be.
# If models import exists, merge it with migrations import.
if "from django.db import models" in imports:
imports.discard("from django.db import models")
imports.add("from django.db import migrations, models")
else:
imports.add("from django.db import migrations")
# Sort imports by the package / module to be imported (the part after
# "from" in "from ... import ..." or after "import" in "import ...").
sorted_imports = sorted(imports, key=lambda i: i.split()[1])
items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
if migration_imports:
items["imports"] += (
"\n\n# Functions from the following migrations need manual "
"copying.\n# Move them and any dependencies into this file, "
"then update the\n# RunPython operations to refer to the local "
"versions:\n# %s"
) % "\n# ".join(sorted(migration_imports))
# If there's a replaces, make a string for it
if self.migration.replaces:
items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
# Hinting that goes into comment
if self.include_header:
items['migration_header'] = MIGRATION_HEADER_TEMPLATE % {
'version': get_version(),
'timestamp': now().strftime("%Y-%m-%d %H:%M"),
}
else:
items['migration_header'] = ""
if self.migration.initial:
items['initial_str'] = "\n initial = True\n"
return MIGRATION_TEMPLATE % items
@property
def basedir(self):
migrations_package_name, _ = MigrationLoader.migrations_module(self.migration.app_label)
if migrations_package_name is None:
raise ValueError(
"Django can't create migrations for app '%s' because "
"migrations have been disabled via the MIGRATION_MODULES "
"setting." % self.migration.app_label
)
# See if we can import the migrations module directly
try:
migrations_module = import_module(migrations_package_name)
except ImportError:
pass
else:
try:
return module_dir(migrations_module)
except ValueError:
pass
# Alright, see if it's a direct submodule of the app
app_config = apps.get_app_config(self.migration.app_label)
maybe_app_name, _, migrations_package_basename = migrations_package_name.rpartition(".")
if app_config.name == maybe_app_name:
return os.path.join(app_config.path, migrations_package_basename)
# In case of using MIGRATION_MODULES setting and the custom package
# doesn't exist, create one, starting from an existing package
existing_dirs, missing_dirs = migrations_package_name.split("."), []
while existing_dirs:
missing_dirs.insert(0, existing_dirs.pop(-1))
try:
base_module = import_module(".".join(existing_dirs))
except (ImportError, ValueError):
continue
else:
try:
base_dir = module_dir(base_module)
except ValueError:
continue
else:
break
else:
raise ValueError(
"Could not locate an appropriate location to create "
"migrations package %s. Make sure the toplevel "
"package exists and can be imported." %
migrations_package_name)
final_dir = os.path.join(base_dir, *missing_dirs)
os.makedirs(final_dir, exist_ok=True)
for missing_dir in missing_dirs:
base_dir = os.path.join(base_dir, missing_dir)
with open(os.path.join(base_dir, "__init__.py"), "w"):
pass
return final_dir
@property
def filename(self):
return "%s.py" % self.migration.name
@property
def path(self):
return os.path.join(self.basedir, self.filename)
@classmethod
def serialize(cls, value):
return serializer_factory(value).serialize()
@classmethod
def register_serializer(cls, type_, serializer):
Serializer.register(type_, serializer)
@classmethod
def unregister_serializer(cls, type_):
Serializer.unregister(type_)
MIGRATION_HEADER_TEMPLATE = """\
# Generated by Django %(version)s on %(timestamp)s
"""
MIGRATION_TEMPLATE = """\
%(migration_header)s%(imports)s
class Migration(migrations.Migration):
%(replaces_str)s%(initial_str)s
dependencies = [
%(dependencies)s\
]
operations = [
%(operations)s\
]
"""

View File

@@ -0,0 +1,52 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals
from django.db.models.aggregates import * # NOQA
from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import * # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL,
ProtectedError, RestrictedError,
)
from django.db.models.enums import * # NOQA
from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import (
Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window,
WindowFrame,
)
from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.json import JSONField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.indexes import * # NOQA
from django.db.models.indexes import __all__ as indexes_all
from django.db.models.lookups import Lookup, Transform
from django.db.models.manager import Manager
from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects
from django.db.models.query_utils import FilteredRelation, Q
# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip
from django.db.models.fields.related import ( # isort:skip
ForeignKey, ForeignObject, OneToOneField, ManyToManyField,
ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel,
)
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
__all__ += [
'ObjectDoesNotExist', 'signals',
'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
'SET_NULL', 'ProtectedError', 'RestrictedError',
'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value',
'ValueRange', 'When',
'Window', 'WindowFrame',
'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform',
'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects',
'DEFERRED', 'Model', 'FilteredRelation',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel',
]

View File

@@ -0,0 +1,165 @@
"""
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import (
FixDurationInputMixin, NumericOutputFieldMixin,
)
__all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
]
class Aggregate(Func):
template = '%(function)s(%(distinct)s%(expressions)s)'
contains_aggregate = True
name = None
filter_template = '%s FILTER (WHERE %%(filter)s)'
window_compatible = True
allow_distinct = False
empty_result_set_value = None
def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_result_set_value is not None:
raise TypeError(f'{self.__class__.__name__} does not allow default.')
self.distinct = distinct
self.filter = filter
self.default = default
super().__init__(*expressions, **extra)
def get_source_fields(self):
# Don't return the filter expression since it's not a source field.
return [e._output_field_or_none for e in super().get_source_expressions()]
def get_source_expressions(self):
source_expressions = super().get_source_expressions()
if self.filter:
return source_expressions + [self.filter]
return source_expressions
def set_source_expressions(self, exprs):
self.filter = self.filter and exprs.pop()
return super().set_source_expressions(exprs)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if not summarize:
# Call Aggregate.get_source_expressions() to avoid
# returning self.filter and including that in this loop.
expressions = super(Aggregate, c).get_source_expressions()
for index, expr in enumerate(expressions):
if expr.contains_aggregate:
before_resolved = self.get_source_expressions()[index]
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
if (default := c.default) is None:
return c
if hasattr(default, 'resolve_expression'):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
c.default = None # Reset the default argument before wrapping.
return Coalesce(c, default, output_field=c._output_field_or_none)
@property
def default_alias(self):
expressions = self.get_source_expressions()
if len(expressions) == 1 and hasattr(expressions[0], 'name'):
return '%s__%s' % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self, alias=None):
return []
def as_sql(self, compiler, connection, **extra_context):
extra_context['distinct'] = 'DISTINCT ' if self.distinct else ''
if self.filter:
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
template = self.filter_template % extra_context.get('template', self.template)
sql, params = super().as_sql(
compiler, connection, template=template, filter=filter_sql,
**extra_context
)
return sql, params + filter_params
else:
copy = self.copy()
copy.filter = None
source_expressions = copy.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
return super().as_sql(compiler, connection, **extra_context)
def _get_repr_options(self):
options = super()._get_repr_options()
if self.distinct:
options['distinct'] = self.distinct
if self.filter:
options['filter'] = self.filter
return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = 'AVG'
name = 'Avg'
allow_distinct = True
class Count(Aggregate):
function = 'COUNT'
name = 'Count'
output_field = IntegerField()
allow_distinct = True
empty_result_set_value = 0
def __init__(self, expression, filter=None, **extra):
if expression == '*':
expression = Star()
if isinstance(expression, Star) and filter is not None:
raise ValueError('Star cannot be used with filter. Please specify a field.')
super().__init__(expression, filter=filter, **extra)
class Max(Aggregate):
function = 'MAX'
name = 'Max'
class Min(Aggregate):
function = 'MIN'
name = 'Min'
class StdDev(NumericOutputFieldMixin, Aggregate):
name = 'StdDev'
def __init__(self, expression, sample=False, **extra):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super().__init__(expression, **extra)
def _get_repr_options(self):
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
class Sum(FixDurationInputMixin, Aggregate):
function = 'SUM'
name = 'Sum'
allow_distinct = True
class Variance(NumericOutputFieldMixin, Aggregate):
name = 'Variance'
def __init__(self, expression, sample=False, **extra):
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super().__init__(expression, **extra)
def _get_repr_options(self):
return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""
Constants used across the ORM in general.
"""
# Separator used to split filter strings apart.
LOOKUP_SEP = '__'

View File

@@ -0,0 +1,255 @@
from enum import Enum
from django.db.models.expressions import ExpressionList, F
from django.db.models.indexes import IndexExpression
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
__all__ = ['CheckConstraint', 'Deferrable', 'UniqueConstraint']
class BaseConstraint:
def __init__(self, name):
self.name = name
@property
def contains_expressions(self):
return False
def constraint_sql(self, model, schema_editor):
raise NotImplementedError('This method must be implemented by a subclass.')
def create_sql(self, model, schema_editor):
raise NotImplementedError('This method must be implemented by a subclass.')
def remove_sql(self, model, schema_editor):
raise NotImplementedError('This method must be implemented by a subclass.')
def deconstruct(self):
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.constraints', 'django.db.models')
return (path, (), {'name': self.name})
def clone(self):
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)
class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name):
self.check = check
if not getattr(check, 'conditional', False):
raise TypeError(
'CheckConstraint.check must be a Q instance or boolean '
'expression.'
)
super().__init__(name)
def _get_check_sql(self, model, schema_editor):
query = Query(model=model, alias_cols=False)
where = query.build_where(self.check)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def constraint_sql(self, model, schema_editor):
check = self._get_check_sql(model, schema_editor)
return schema_editor._check_sql(self.name, check)
def create_sql(self, model, schema_editor):
check = self._get_check_sql(model, schema_editor)
return schema_editor._create_check_sql(model, self.name, check)
def remove_sql(self, model, schema_editor):
return schema_editor._delete_check_sql(model, self.name)
def __repr__(self):
return '<%s: check=%s name=%s>' % (
self.__class__.__qualname__,
self.check,
repr(self.name),
)
def __eq__(self, other):
if isinstance(other, CheckConstraint):
return self.name == other.name and self.check == other.check
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
kwargs['check'] = self.check
return path, args, kwargs
class Deferrable(Enum):
DEFERRED = 'deferred'
IMMEDIATE = 'immediate'
# A similar format was proposed for Python 3.10.
def __repr__(self):
return f'{self.__class__.__qualname__}.{self._name_}'
class UniqueConstraint(BaseConstraint):
def __init__(
self,
*expressions,
fields=(),
name=None,
condition=None,
deferrable=None,
include=None,
opclasses=(),
):
if not name:
raise ValueError('A unique constraint must be named.')
if not expressions and not fields:
raise ValueError(
'At least one field or expression is required to define a '
'unique constraint.'
)
if expressions and fields:
raise ValueError(
'UniqueConstraint.fields and expressions are mutually exclusive.'
)
if not isinstance(condition, (type(None), Q)):
raise ValueError('UniqueConstraint.condition must be a Q instance.')
if condition and deferrable:
raise ValueError(
'UniqueConstraint with conditions cannot be deferred.'
)
if include and deferrable:
raise ValueError(
'UniqueConstraint with include fields cannot be deferred.'
)
if opclasses and deferrable:
raise ValueError(
'UniqueConstraint with opclasses cannot be deferred.'
)
if expressions and deferrable:
raise ValueError(
'UniqueConstraint with expressions cannot be deferred.'
)
if expressions and opclasses:
raise ValueError(
'UniqueConstraint.opclasses cannot be used with expressions. '
'Use django.contrib.postgres.indexes.OpClass() instead.'
)
if not isinstance(deferrable, (type(None), Deferrable)):
raise ValueError(
'UniqueConstraint.deferrable must be a Deferrable instance.'
)
if not isinstance(include, (type(None), list, tuple)):
raise ValueError('UniqueConstraint.include must be a list or tuple.')
if not isinstance(opclasses, (list, tuple)):
raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
if opclasses and len(fields) != len(opclasses):
raise ValueError(
'UniqueConstraint.fields and UniqueConstraint.opclasses must '
'have the same number of elements.'
)
self.fields = tuple(fields)
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
self.opclasses = opclasses
self.expressions = tuple(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
super().__init__(name)
@property
def contains_expressions(self):
return bool(self.expressions)
def _get_condition_sql(self, model, schema_editor):
if self.condition is None:
return None
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def _get_index_expressions(self, model, schema_editor):
if not self.expressions:
return None
index_expressions = []
for expression in self.expressions:
index_expression = IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
return ExpressionList(*index_expressions).resolve_expression(
Query(model, alias_cols=False),
)
def constraint_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._unique_sql(
model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include,
opclasses=self.opclasses, expressions=expressions,
)
def create_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._create_unique_sql(
model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include,
opclasses=self.opclasses, expressions=expressions,
)
def remove_sql(self, model, schema_editor):
condition = self._get_condition_sql(model, schema_editor)
include = [model._meta.get_field(field_name).column for field_name in self.include]
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._delete_unique_sql(
model, self.name, condition=condition, deferrable=self.deferrable,
include=include, opclasses=self.opclasses, expressions=expressions,
)
def __repr__(self):
return '<%s:%s%s%s%s%s%s%s>' % (
self.__class__.__qualname__,
'' if not self.fields else ' fields=%s' % repr(self.fields),
'' if not self.expressions else ' expressions=%s' % repr(self.expressions),
' name=%s' % repr(self.name),
'' if self.condition is None else ' condition=%s' % self.condition,
'' if self.deferrable is None else ' deferrable=%r' % self.deferrable,
'' if not self.include else ' include=%s' % repr(self.include),
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
)
def __eq__(self, other):
if isinstance(other, UniqueConstraint):
return (
self.name == other.name and
self.fields == other.fields and
self.condition == other.condition and
self.deferrable == other.deferrable and
self.include == other.include and
self.opclasses == other.opclasses and
self.expressions == other.expressions
)
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fields:
kwargs['fields'] = self.fields
if self.condition:
kwargs['condition'] = self.condition
if self.deferrable:
kwargs['deferrable'] = self.deferrable
if self.include:
kwargs['include'] = self.include
if self.opclasses:
kwargs['opclasses'] = self.opclasses
return path, self.expressions, kwargs

View File

@@ -0,0 +1,449 @@
from collections import Counter, defaultdict
from functools import partial
from itertools import chain
from operator import attrgetter
from django.db import IntegrityError, connections, transaction
from django.db.models import query_utils, signals, sql
class ProtectedError(IntegrityError):
def __init__(self, msg, protected_objects):
self.protected_objects = protected_objects
super().__init__(msg, protected_objects)
class RestrictedError(IntegrityError):
def __init__(self, msg, restricted_objects):
self.restricted_objects = restricted_objects
super().__init__(msg, restricted_objects)
def CASCADE(collector, field, sub_objs, using):
collector.collect(
sub_objs, source=field.remote_field.model, source_attr=field.name,
nullable=field.null, fail_on_restricted=False,
)
if field.null and not connections[using].features.can_defer_constraint_checks:
collector.add_field_update(field, None, sub_objs)
def PROTECT(collector, field, sub_objs, using):
raise ProtectedError(
"Cannot delete some instances of model '%s' because they are "
"referenced through a protected foreign key: '%s.%s'" % (
field.remote_field.model.__name__, sub_objs[0].__class__.__name__, field.name
),
sub_objs
)
def RESTRICT(collector, field, sub_objs, using):
collector.add_restricted_objects(field, sub_objs)
collector.add_dependency(field.remote_field.model, field.model)
def SET(value):
if callable(value):
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value(), sub_objs)
else:
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value, sub_objs)
set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {})
return set_on_delete
def SET_NULL(collector, field, sub_objs, using):
collector.add_field_update(field, None, sub_objs)
def SET_DEFAULT(collector, field, sub_objs, using):
collector.add_field_update(field, field.get_default(), sub_objs)
def DO_NOTHING(collector, field, sub_objs, using):
pass
def get_candidate_relations_to_delete(opts):
# The candidate relations are the ones that come from N-1 and 1-1 relations.
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
return (
f for f in opts.get_fields(include_hidden=True)
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
)
class Collector:
def __init__(self, using):
self.using = using
# Initially, {model: {instances}}, later values become lists.
self.data = defaultdict(set)
# {model: {(field, value): {instances}}}
self.field_updates = defaultdict(partial(defaultdict, set))
# {model: {field: {instances}}}
self.restricted_objects = defaultdict(partial(defaultdict, set))
# fast_deletes is a list of queryset-likes that can be deleted without
# fetching the objects into memory.
self.fast_deletes = []
# Tracks deletion-order dependency for databases without transactions
# or ability to defer constraint checks. Only concrete model classes
# should be included, as the dependencies exist only between actual
# database tables; proxy models are represented here by their concrete
# parent.
self.dependencies = defaultdict(set) # {model: {models}}
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
"""
Add 'objs' to the collection of objects to be deleted. If the call is
the result of a cascade, 'source' should be the model that caused it,
and 'nullable' should be set to True if the relation can be null.
Return a list of all objects that were not already collected.
"""
if not objs:
return []
new_objs = []
model = objs[0].__class__
instances = self.data[model]
for obj in objs:
if obj not in instances:
new_objs.append(obj)
instances.update(new_objs)
# Nullable relationships can be ignored -- they are nulled out before
# deleting, and therefore do not affect the order in which objects have
# to be deleted.
if source is not None and not nullable:
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
return new_objs
def add_dependency(self, model, dependency, reverse_dependency=False):
if reverse_dependency:
model, dependency = dependency, model
self.dependencies[model._meta.concrete_model].add(dependency._meta.concrete_model)
self.data.setdefault(dependency, self.data.default_factory())
def add_field_update(self, field, value, objs):
"""
Schedule a field update. 'objs' must be a homogeneous iterable
collection of model instances (e.g. a QuerySet).
"""
if not objs:
return
model = objs[0].__class__
self.field_updates[model][field, value].update(objs)
def add_restricted_objects(self, field, objs):
if objs:
model = objs[0].__class__
self.restricted_objects[model][field].update(objs)
def clear_restricted_objects_from_set(self, model, objs):
if model in self.restricted_objects:
self.restricted_objects[model] = {
field: items - objs
for field, items in self.restricted_objects[model].items()
}
def clear_restricted_objects_from_queryset(self, model, qs):
if model in self.restricted_objects:
objs = set(qs.filter(pk__in=[
obj.pk
for objs in self.restricted_objects[model].values() for obj in objs
]))
self.clear_restricted_objects_from_set(model, objs)
def _has_signal_listeners(self, model):
return (
signals.pre_delete.has_listeners(model) or
signals.post_delete.has_listeners(model)
)
def can_fast_delete(self, objs, from_field=None):
"""
Determine if the objects in the given queryset-like or single object
can be fast-deleted. This can be done if there are no cascades, no
parents and no signal listeners for the object class.
The 'from_field' tells where we are coming from - we need this to
determine if the objects are in fact to be deleted. Allow also
skipping parent -> child -> parent chain preventing fast delete of
the child.
"""
if from_field and from_field.remote_field.on_delete is not CASCADE:
return False
if hasattr(objs, '_meta'):
model = objs._meta.model
elif hasattr(objs, 'model') and hasattr(objs, '_raw_delete'):
model = objs.model
else:
return False
if self._has_signal_listeners(model):
return False
# The use of from_field comes from the need to avoid cascade back to
# parent when parent delete is cascading to child.
opts = model._meta
return (
all(link == from_field for link in opts.concrete_model._meta.parents.values()) and
# Foreign keys pointing to this model.
all(
related.field.remote_field.on_delete is DO_NOTHING
for related in get_candidate_relations_to_delete(opts)
) and (
# Something like generic foreign key.
not any(hasattr(field, 'bulk_related_objects') for field in opts.private_fields)
)
)
def get_del_batches(self, objs, fields):
"""
Return the objs in suitably sized batches for the used connection.
"""
field_names = [field.name for field in fields]
conn_batch_size = max(
connections[self.using].ops.bulk_batch_size(field_names, objs), 1)
if len(objs) > conn_batch_size:
return [objs[i:i + conn_batch_size]
for i in range(0, len(objs), conn_batch_size)]
else:
return [objs]
def collect(self, objs, source=None, nullable=False, collect_related=True,
source_attr=None, reverse_dependency=False, keep_parents=False,
fail_on_restricted=True):
"""
Add 'objs' to the collection of objects to be deleted as well as all
parent instances. 'objs' must be a homogeneous iterable collection of
model instances (e.g. a QuerySet). If 'collect_related' is True,
related objects will be handled by their respective on_delete handler.
If the call is the result of a cascade, 'source' should be the model
that caused it and 'nullable' should be set to True, if the relation
can be null.
If 'reverse_dependency' is True, 'source' will be deleted before the
current model, rather than after. (Needed for cascading to parent
models, the one case in which the cascade follows the forwards
direction of an FK rather than the reverse direction.)
If 'keep_parents' is True, data of parent model's will be not deleted.
If 'fail_on_restricted' is False, error won't be raised even if it's
prohibited to delete such objects due to RESTRICT, that defers
restricted object checking in recursive calls where the top-level call
may need to collect more objects to determine whether restricted ones
can be deleted.
"""
if self.can_fast_delete(objs):
self.fast_deletes.append(objs)
return
new_objs = self.add(objs, source, nullable,
reverse_dependency=reverse_dependency)
if not new_objs:
return
model = new_objs[0].__class__
if not keep_parents:
# Recursively collect concrete model's parent models, but not their
# related objects. These will be found by meta.get_fields()
concrete_model = model._meta.concrete_model
for ptr in concrete_model._meta.parents.values():
if ptr:
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
self.collect(parent_objs, source=model,
source_attr=ptr.remote_field.related_name,
collect_related=False,
reverse_dependency=True,
fail_on_restricted=False)
if not collect_related:
return
if keep_parents:
parents = set(model._meta.get_parent_list())
model_fast_deletes = defaultdict(list)
protected_objects = defaultdict(list)
for related in get_candidate_relations_to_delete(model._meta):
# Preserve parent reverse relationships if keep_parents=True.
if keep_parents and related.model in parents:
continue
field = related.field
if field.remote_field.on_delete == DO_NOTHING:
continue
related_model = related.related_model
if self.can_fast_delete(related_model, from_field=field):
model_fast_deletes[related_model].append(field)
continue
batches = self.get_del_batches(new_objs, [field])
for batch in batches:
sub_objs = self.related_objects(related_model, [field], batch)
# Non-referenced fields can be deferred if no signal receivers
# are connected for the related model as they'll never be
# exposed to the user. Skip field deferring when some
# relationships are select_related as interactions between both
# features are hard to get right. This should only happen in
# the rare cases where .related_objects is overridden anyway.
if not (sub_objs.query.select_related or self._has_signal_listeners(related_model)):
referenced_fields = set(chain.from_iterable(
(rf.attname for rf in rel.field.foreign_related_fields)
for rel in get_candidate_relations_to_delete(related_model._meta)
))
sub_objs = sub_objs.only(*tuple(referenced_fields))
if sub_objs:
try:
field.remote_field.on_delete(self, field, sub_objs, self.using)
except ProtectedError as error:
key = "'%s.%s'" % (field.model.__name__, field.name)
protected_objects[key] += error.protected_objects
if protected_objects:
raise ProtectedError(
'Cannot delete some instances of model %r because they are '
'referenced through protected foreign keys: %s.' % (
model.__name__,
', '.join(protected_objects),
),
set(chain.from_iterable(protected_objects.values())),
)
for related_model, related_fields in model_fast_deletes.items():
batches = self.get_del_batches(new_objs, related_fields)
for batch in batches:
sub_objs = self.related_objects(related_model, related_fields, batch)
self.fast_deletes.append(sub_objs)
for field in model._meta.private_fields:
if hasattr(field, 'bulk_related_objects'):
# It's something like generic foreign key.
sub_objs = field.bulk_related_objects(new_objs, self.using)
self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False)
if fail_on_restricted:
# Raise an error if collected restricted objects (RESTRICT) aren't
# candidates for deletion also collected via CASCADE.
for related_model, instances in self.data.items():
self.clear_restricted_objects_from_set(related_model, instances)
for qs in self.fast_deletes:
self.clear_restricted_objects_from_queryset(qs.model, qs)
if self.restricted_objects.values():
restricted_objects = defaultdict(list)
for related_model, fields in self.restricted_objects.items():
for field, objs in fields.items():
if objs:
key = "'%s.%s'" % (related_model.__name__, field.name)
restricted_objects[key] += objs
if restricted_objects:
raise RestrictedError(
'Cannot delete some instances of model %r because '
'they are referenced through restricted foreign keys: '
'%s.' % (
model.__name__,
', '.join(restricted_objects),
),
set(chain.from_iterable(restricted_objects.values())),
)
def related_objects(self, related_model, related_fields, objs):
"""
Get a QuerySet of the related model to objs via related fields.
"""
predicate = query_utils.Q(
*(
(f'{related_field.name}__in', objs)
for related_field in related_fields
),
_connector=query_utils.Q.OR,
)
return related_model._base_manager.using(self.using).filter(predicate)
def instances_with_model(self):
for model, instances in self.data.items():
for obj in instances:
yield model, obj
def sort(self):
sorted_models = []
concrete_models = set()
models = list(self.data)
while len(sorted_models) < len(models):
found = False
for model in models:
if model in sorted_models:
continue
dependencies = self.dependencies.get(model._meta.concrete_model)
if not (dependencies and dependencies.difference(concrete_models)):
sorted_models.append(model)
concrete_models.add(model._meta.concrete_model)
found = True
if not found:
return
self.data = {model: self.data[model] for model in sorted_models}
def delete(self):
# sort instance collections
for model, instances in self.data.items():
self.data[model] = sorted(instances, key=attrgetter("pk"))
# if possible, bring the models in an order suitable for databases that
# don't support transactions or cannot defer constraint checks until the
# end of a transaction.
self.sort()
# number of objects deleted for each model label
deleted_counter = Counter()
# Optimize for the case with a single obj and no dependencies
if len(self.data) == 1 and len(instances) == 1:
instance = list(instances)[0]
if self.can_fast_delete(instance):
with transaction.mark_for_rollback_on_error(self.using):
count = sql.DeleteQuery(model).delete_batch([instance.pk], self.using)
setattr(instance, model._meta.pk.attname, None)
return count, {model._meta.label: count}
with transaction.atomic(using=self.using, savepoint=False):
# send pre_delete signals
for model, obj in self.instances_with_model():
if not model._meta.auto_created:
signals.pre_delete.send(
sender=model, instance=obj, using=self.using
)
# fast deletes
for qs in self.fast_deletes:
count = qs._raw_delete(using=self.using)
if count:
deleted_counter[qs.model._meta.label] += count
# update fields
for model, instances_for_fieldvalues in self.field_updates.items():
for (field, value), instances in instances_for_fieldvalues.items():
query = sql.UpdateQuery(model)
query.update_batch([obj.pk for obj in instances],
{field.name: value}, self.using)
# reverse instance collections
for instances in self.data.values():
instances.reverse()
# delete instances
for model, instances in self.data.items():
query = sql.DeleteQuery(model)
pk_list = [obj.pk for obj in instances]
count = query.delete_batch(pk_list, self.using)
if count:
deleted_counter[model._meta.label] += count
if not model._meta.auto_created:
for obj in instances:
signals.post_delete.send(
sender=model, instance=obj, using=self.using
)
# update collected instances
for instances_for_fieldvalues in self.field_updates.values():
for (field, value), instances in instances_for_fieldvalues.items():
for obj in instances:
setattr(obj, field.attname, value)
for model, instances in self.data.items():
for instance in instances:
setattr(instance, model._meta.pk.attname, None)
return sum(deleted_counter.values()), dict(deleted_counter)

View File

@@ -0,0 +1,91 @@
import enum
from types import DynamicClassAttribute
from django.utils.functional import Promise
__all__ = ['Choices', 'IntegerChoices', 'TextChoices']
class ChoicesMeta(enum.EnumMeta):
"""A metaclass for creating a enum choices."""
def __new__(metacls, classname, bases, classdict, **kwds):
labels = []
for key in classdict._member_names:
value = classdict[key]
if (
isinstance(value, (list, tuple)) and
len(value) > 1 and
isinstance(value[-1], (Promise, str))
):
*value, label = value
value = tuple(value)
else:
label = key.replace('_', ' ').title()
labels.append(label)
# Use dict.__setitem__() to suppress defenses against double
# assignment in enum's classdict.
dict.__setitem__(classdict, key, value)
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
for member, label in zip(cls.__members__.values(), labels):
member._label_ = label
return enum.unique(cls)
def __contains__(cls, member):
if not isinstance(member, enum.Enum):
# Allow non-enums to match against member values.
return any(x.value == member for x in cls)
return super().__contains__(member)
@property
def names(cls):
empty = ['__empty__'] if hasattr(cls, '__empty__') else []
return empty + [member.name for member in cls]
@property
def choices(cls):
empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else []
return empty + [(member.value, member.label) for member in cls]
@property
def labels(cls):
return [label for _, label in cls.choices]
@property
def values(cls):
return [value for value, _ in cls.choices]
class Choices(enum.Enum, metaclass=ChoicesMeta):
"""Class for creating enumerated choices."""
@DynamicClassAttribute
def label(self):
return self._label_
@property
def do_not_call_in_templates(self):
return True
def __str__(self):
"""
Use value when cast to str, so that Choices set as model instance
attributes are rendered as expected in templates and similar contexts.
"""
return str(self.value)
# A similar format was proposed for Python 3.10.
def __repr__(self):
return f'{self.__class__.__qualname__}.{self._name_}'
class IntegerChoices(int, Choices):
"""Class for creating enumerated integer choices."""
pass
class TextChoices(str, Choices):
"""Class for creating enumerated string choices."""
def _generate_next_value_(name, start, count, last_values):
return name

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,481 @@
import datetime
import posixpath
from django import forms
from django.core import checks
from django.core.files.base import File
from django.core.files.images import ImageFile
from django.core.files.storage import Storage, default_storage
from django.core.files.utils import validate_file_name
from django.db.models import signals
from django.db.models.fields import Field
from django.db.models.query_utils import DeferredAttribute
from django.utils.translation import gettext_lazy as _
class FieldFile(File):
def __init__(self, instance, field, name):
super().__init__(None, name)
self.instance = instance
self.field = field
self.storage = field.storage
self._committed = True
def __eq__(self, other):
# Older code may be expecting FileField values to be simple strings.
# By overriding the == operator, it can remain backwards compatibility.
if hasattr(other, 'name'):
return self.name == other.name
return self.name == other
def __hash__(self):
return hash(self.name)
# The standard File contains most of the necessary properties, but
# FieldFiles can be instantiated without a name, so that needs to
# be checked for here.
def _require_file(self):
if not self:
raise ValueError("The '%s' attribute has no file associated with it." % self.field.name)
def _get_file(self):
self._require_file()
if getattr(self, '_file', None) is None:
self._file = self.storage.open(self.name, 'rb')
return self._file
def _set_file(self, file):
self._file = file
def _del_file(self):
del self._file
file = property(_get_file, _set_file, _del_file)
@property
def path(self):
self._require_file()
return self.storage.path(self.name)
@property
def url(self):
self._require_file()
return self.storage.url(self.name)
@property
def size(self):
self._require_file()
if not self._committed:
return self.file.size
return self.storage.size(self.name)
def open(self, mode='rb'):
self._require_file()
if getattr(self, '_file', None) is None:
self.file = self.storage.open(self.name, mode)
else:
self.file.open(mode)
return self
# open() doesn't alter the file's contents, but it does reset the pointer
open.alters_data = True
# In addition to the standard File API, FieldFiles have extra methods
# to further manipulate the underlying file, as well as update the
# associated model instance.
def save(self, name, content, save=True):
name = self.field.generate_filename(self.instance, name)
self.name = self.storage.save(name, content, max_length=self.field.max_length)
setattr(self.instance, self.field.attname, self.name)
self._committed = True
# Save the object because it has changed, unless save is False
if save:
self.instance.save()
save.alters_data = True
def delete(self, save=True):
if not self:
return
# Only close the file if it's already open, which we know by the
# presence of self._file
if hasattr(self, '_file'):
self.close()
del self.file
self.storage.delete(self.name)
self.name = None
setattr(self.instance, self.field.attname, self.name)
self._committed = False
if save:
self.instance.save()
delete.alters_data = True
@property
def closed(self):
file = getattr(self, '_file', None)
return file is None or file.closed
def close(self):
file = getattr(self, '_file', None)
if file is not None:
file.close()
def __getstate__(self):
# FieldFile needs access to its associated model field, an instance and
# the file's name. Everything else will be restored later, by
# FileDescriptor below.
return {
'name': self.name,
'closed': False,
'_committed': True,
'_file': None,
'instance': self.instance,
'field': self.field,
}
def __setstate__(self, state):
self.__dict__.update(state)
self.storage = self.field.storage
class FileDescriptor(DeferredAttribute):
"""
The descriptor for the file attribute on the model instance. Return a
FieldFile when accessed so you can write code like::
>>> from myapp.models import MyModel
>>> instance = MyModel.objects.get(pk=1)
>>> instance.file.size
Assign a file object on assignment so you can do::
>>> with open('/path/to/hello.world') as f:
... instance.file = File(f)
"""
def __get__(self, instance, cls=None):
if instance is None:
return self
# This is slightly complicated, so worth an explanation.
# instance.file`needs to ultimately return some instance of `File`,
# probably a subclass. Additionally, this returned object needs to have
# the FieldFile API so that users can easily do things like
# instance.file.path and have that delegated to the file storage engine.
# Easy enough if we're strict about assignment in __set__, but if you
# peek below you can see that we're not. So depending on the current
# value of the field we have to dynamically construct some sort of
# "thing" to return.
# The instance dict contains whatever was originally assigned
# in __set__.
file = super().__get__(instance, cls)
# If this value is a string (instance.file = "path/to/file") or None
# then we simply wrap it with the appropriate attribute class according
# to the file field. [This is FieldFile for FileFields and
# ImageFieldFile for ImageFields; it's also conceivable that user
# subclasses might also want to subclass the attribute class]. This
# object understands how to convert a path to a file, and also how to
# handle None.
if isinstance(file, str) or file is None:
attr = self.field.attr_class(instance, self.field, file)
instance.__dict__[self.field.attname] = attr
# Other types of files may be assigned as well, but they need to have
# the FieldFile interface added to them. Thus, we wrap any other type of
# File inside a FieldFile (well, the field's attr_class, which is
# usually FieldFile).
elif isinstance(file, File) and not isinstance(file, FieldFile):
file_copy = self.field.attr_class(instance, self.field, file.name)
file_copy.file = file
file_copy._committed = False
instance.__dict__[self.field.attname] = file_copy
# Finally, because of the (some would say boneheaded) way pickle works,
# the underlying FieldFile might not actually itself have an associated
# file. So we need to reset the details of the FieldFile in those cases.
elif isinstance(file, FieldFile) and not hasattr(file, 'field'):
file.instance = instance
file.field = self.field
file.storage = self.field.storage
# Make sure that the instance is correct.
elif isinstance(file, FieldFile) and instance is not file.instance:
file.instance = instance
# That was fun, wasn't it?
return instance.__dict__[self.field.attname]
def __set__(self, instance, value):
instance.__dict__[self.field.attname] = value
class FileField(Field):
# The class to wrap instance attributes in. Accessing the file object off
# the instance will always return an instance of attr_class.
attr_class = FieldFile
# The descriptor to use for accessing the attribute off of the class.
descriptor_class = FileDescriptor
description = _("File")
def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs):
self._primary_key_set_explicitly = 'primary_key' in kwargs
self.storage = storage or default_storage
if callable(self.storage):
# Hold a reference to the callable for deconstruct().
self._storage_callable = self.storage
self.storage = self.storage()
if not isinstance(self.storage, Storage):
raise TypeError(
"%s.storage must be a subclass/instance of %s.%s"
% (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__)
)
self.upload_to = upload_to
kwargs.setdefault('max_length', 100)
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
return [
*super().check(**kwargs),
*self._check_primary_key(),
*self._check_upload_to(),
]
def _check_primary_key(self):
if self._primary_key_set_explicitly:
return [
checks.Error(
"'primary_key' is not a valid argument for a %s." % self.__class__.__name__,
obj=self,
id='fields.E201',
)
]
else:
return []
def _check_upload_to(self):
if isinstance(self.upload_to, str) and self.upload_to.startswith('/'):
return [
checks.Error(
"%s's 'upload_to' argument must be a relative path, not an "
"absolute path." % self.__class__.__name__,
obj=self,
id='fields.E202',
hint='Remove the leading slash.',
)
]
else:
return []
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 100:
del kwargs["max_length"]
kwargs['upload_to'] = self.upload_to
if self.storage is not default_storage:
kwargs['storage'] = getattr(self, '_storage_callable', self.storage)
return name, path, args, kwargs
def get_internal_type(self):
return "FileField"
def get_prep_value(self, value):
value = super().get_prep_value(value)
# Need to convert File objects provided via a form to string for database insertion
if value is None:
return None
return str(value)
def pre_save(self, model_instance, add):
file = super().pre_save(model_instance, add)
if file and not file._committed:
# Commit the file to storage prior to saving the model
file.save(file.name, file.file, save=False)
return file
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, **kwargs)
setattr(cls, self.attname, self.descriptor_class(self))
def generate_filename(self, instance, filename):
"""
Apply (if callable) or prepend (if a string) upload_to to the filename,
then delegate further processing of the name to the storage backend.
Until the storage layer, all file paths are expected to be Unix style
(with forward slashes).
"""
if callable(self.upload_to):
filename = self.upload_to(instance, filename)
else:
dirname = datetime.datetime.now().strftime(str(self.upload_to))
filename = posixpath.join(dirname, filename)
filename = validate_file_name(filename, allow_relative_path=True)
return self.storage.generate_filename(filename)
def save_form_data(self, instance, data):
# Important: None means "no change", other false value means "clear"
# This subtle distinction (rather than a more explicit marker) is
# needed because we need to consume values that are also sane for a
# regular (non Model-) Form to find in its cleaned_data dictionary.
if data is not None:
# This value will be converted to str and stored in the
# database, so leaving False as-is is not acceptable.
setattr(instance, self.name, data or '')
def formfield(self, **kwargs):
return super().formfield(**{
'form_class': forms.FileField,
'max_length': self.max_length,
**kwargs,
})
class ImageFileDescriptor(FileDescriptor):
"""
Just like the FileDescriptor, but for ImageFields. The only difference is
assigning the width/height to the width_field/height_field, if appropriate.
"""
def __set__(self, instance, value):
previous_file = instance.__dict__.get(self.field.attname)
super().__set__(instance, value)
# To prevent recalculating image dimensions when we are instantiating
# an object from the database (bug #11084), only update dimensions if
# the field had a value before this assignment. Since the default
# value for FileField subclasses is an instance of field.attr_class,
# previous_file will only be None when we are called from
# Model.__init__(). The ImageField.update_dimension_fields method
# hooked up to the post_init signal handles the Model.__init__() cases.
# Assignment happening outside of Model.__init__() will trigger the
# update right here.
if previous_file is not None:
self.field.update_dimension_fields(instance, force=True)
class ImageFieldFile(ImageFile, FieldFile):
def delete(self, save=True):
# Clear the image dimensions cache
if hasattr(self, '_dimensions_cache'):
del self._dimensions_cache
super().delete(save)
class ImageField(FileField):
attr_class = ImageFieldFile
descriptor_class = ImageFileDescriptor
description = _("Image")
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
self.width_field, self.height_field = width_field, height_field
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
return [
*super().check(**kwargs),
*self._check_image_library_installed(),
]
def _check_image_library_installed(self):
try:
from PIL import Image # NOQA
except ImportError:
return [
checks.Error(
'Cannot use ImageField because Pillow is not installed.',
hint=('Get Pillow at https://pypi.org/project/Pillow/ '
'or run command "python -m pip install Pillow".'),
obj=self,
id='fields.E210',
)
]
else:
return []
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.width_field:
kwargs['width_field'] = self.width_field
if self.height_field:
kwargs['height_field'] = self.height_field
return name, path, args, kwargs
def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, **kwargs)
# Attach update_dimension_fields so that dimension fields declared
# after their corresponding image field don't stay cleared by
# Model.__init__, see bug #11196.
# Only run post-initialization dimension update on non-abstract models
if not cls._meta.abstract:
signals.post_init.connect(self.update_dimension_fields, sender=cls)
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
"""
Update field's width and height fields, if defined.
This method is hooked up to model's post_init signal to update
dimensions after instantiating a model instance. However, dimensions
won't be updated if the dimensions fields are already populated. This
avoids unnecessary recalculation when loading an object from the
database.
Dimensions can be forced to update with force=True, which is how
ImageFileDescriptor.__set__ calls this method.
"""
# Nothing to update if the field doesn't have dimension fields or if
# the field is deferred.
has_dimension_fields = self.width_field or self.height_field
if not has_dimension_fields or self.attname not in instance.__dict__:
return
# getattr will call the ImageFileDescriptor's __get__ method, which
# coerces the assigned value into an instance of self.attr_class
# (ImageFieldFile in this case).
file = getattr(instance, self.attname)
# Nothing to update if we have no file and not being forced to update.
if not file and not force:
return
dimension_fields_filled = not(
(self.width_field and not getattr(instance, self.width_field)) or
(self.height_field and not getattr(instance, self.height_field))
)
# When both dimension fields have values, we are most likely loading
# data from the database or updating an image field that already had
# an image stored. In the first case, we don't want to update the
# dimension fields because we are already getting their values from the
# database. In the second case, we do want to update the dimensions
# fields and will skip this return because force will be True since we
# were called from ImageFileDescriptor.__set__.
if dimension_fields_filled and not force:
return
# file should be an instance of ImageFieldFile or should be None.
if file:
width = file.width
height = file.height
else:
# No file, so clear dimensions fields.
width = None
height = None
# Update the width and height fields.
if self.width_field:
setattr(instance, self.width_field, width)
if self.height_field:
setattr(instance, self.height_field, height)
def formfield(self, **kwargs):
return super().formfield(**{
'form_class': forms.ImageField,
**kwargs,
})

View File

@@ -0,0 +1,538 @@
import json
from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import lookups
from django.db.models.lookups import PostgresOperatorLookup, Transform
from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
__all__ = ['JSONField']
class JSONField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
description = _('A JSON object')
default_error_messages = {
'invalid': _('Value must be valid JSON.'),
}
_default_hint = ('dict', '{}')
def __init__(
self, verbose_name=None, name=None, encoder=None, decoder=None,
**kwargs,
):
if encoder and not callable(encoder):
raise ValueError('The encoder parameter must be a callable object.')
if decoder and not callable(decoder):
raise ValueError('The decoder parameter must be a callable object.')
self.encoder = encoder
self.decoder = decoder
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
errors = super().check(**kwargs)
databases = kwargs.get('databases') or []
errors.extend(self._check_supported(databases))
return errors
def _check_supported(self, databases):
errors = []
for db in databases:
if not router.allow_migrate_model(db, self.model):
continue
connection = connections[db]
if (
self.model._meta.required_db_vendor and
self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
'supports_json_field' in self.model._meta.required_db_features or
connection.features.supports_json_field
):
errors.append(
checks.Error(
'%s does not support JSONFields.'
% connection.display_name,
obj=self.model,
id='fields.E180',
)
)
return errors
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.encoder is not None:
kwargs['encoder'] = self.encoder
if self.decoder is not None:
kwargs['decoder'] = self.decoder
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
if value is None:
return value
# Some backends (SQLite at least) extract non-string values in their
# SQL datatypes.
if isinstance(expression, KeyTransform) and not isinstance(value, str):
return value
try:
return json.loads(value, cls=self.decoder)
except json.JSONDecodeError:
return value
def get_internal_type(self):
return 'JSONField'
def get_prep_value(self, value):
if value is None:
return value
return json.dumps(value, cls=self.encoder)
def get_transform(self, name):
transform = super().get_transform(name)
if transform:
return transform
return KeyTransformFactory(name)
def validate(self, value, model_instance):
super().validate(value, model_instance)
try:
json.dumps(value, cls=self.encoder)
except TypeError:
raise exceptions.ValidationError(
self.error_messages['invalid'],
code='invalid',
params={'value': value},
)
def value_to_string(self, obj):
return self.value_from_object(obj)
def formfield(self, **kwargs):
return super().formfield(**{
'form_class': forms.JSONField,
'encoder': self.encoder,
'decoder': self.decoder,
**kwargs,
})
def compile_json_path(key_transforms, include_root=True):
path = ['$'] if include_root else []
for key_transform in key_transforms:
try:
num = int(key_transform)
except ValueError: # non-integer
path.append('.')
path.append(json.dumps(key_transform))
else:
path.append('[%s]' % num)
return ''.join(path)
class DataContains(PostgresOperatorLookup):
lookup_name = 'contains'
postgres_operator = '@>'
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
'contains lookup is not supported on this database backend.'
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup):
lookup_name = 'contained_by'
postgres_operator = '<@'
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
'contained_by lookup is not supported on this database backend.'
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(rhs_params) + tuple(lhs_params)
return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
class HasKeyLookup(PostgresOperatorLookup):
logical_operator = None
def as_sql(self, compiler, connection, template=None):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
lhs_json_path = compile_json_path(lhs_key_transforms)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
lhs_json_path = '$'
sql = template % lhs
# Process JSON path from the right-hand side.
rhs = self.rhs
rhs_params = []
if not isinstance(rhs, (list, tuple)):
rhs = [rhs]
for key in rhs:
if isinstance(key, KeyTransform):
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
else:
rhs_key_transforms = [key]
rhs_params.append('%s%s' % (
lhs_json_path,
compile_json_path(rhs_key_transforms, include_root=False),
))
# Add condition for each key.
if self.logical_operator:
sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
return sql, tuple(lhs_params) + tuple(rhs_params)
def as_mysql(self, compiler, connection):
return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
def as_oracle(self, compiler, connection):
sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle.
return sql % tuple(params), []
def as_postgresql(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
for key in rhs_key_transforms[:-1]:
self.lhs = KeyTransform(key, self.lhs)
self.rhs = rhs_key_transforms[-1]
return super().as_postgresql(compiler, connection)
def as_sqlite(self, compiler, connection):
return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
class HasKey(HasKeyLookup):
lookup_name = 'has_key'
postgres_operator = '?'
prepare_rhs = False
class HasKeys(HasKeyLookup):
lookup_name = 'has_keys'
postgres_operator = '?&'
logical_operator = ' AND '
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
lookup_name = 'has_any_keys'
postgres_operator = '?|'
logical_operator = ' OR '
class CaseInsensitiveMixin:
"""
Mixin to allow case-insensitive comparison of JSON values on MySQL.
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'mysql':
return 'LOWER(%s)' % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == 'mysql':
return 'LOWER(%s)' % rhs, rhs_params
return rhs, rhs_params
class JSONExact(lookups.Exact):
can_use_none_as_rhs = True
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
if rhs == '%s' and rhs_params == [None]:
rhs_params = ['null']
if connection.vendor == 'mysql':
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs = rhs % tuple(func)
return rhs, rhs_params
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
pass
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
postgres_operator = '->'
postgres_nested_operator = '#>'
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
self.key_name = str(key_name)
def preprocess_lhs(self, compiler, connection):
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, KeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs, params = compiler.compile(previous)
if connection.vendor == 'oracle':
# Escape string-formatting.
key_transforms = [key.replace('%', '%%') for key in key_transforms]
return lhs, params, key_transforms
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return (
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
((lhs, json_path) * 2)
), tuple(params) * 2
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
if len(key_transforms) > 1:
sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
return sql, tuple(params) + (key_transforms,)
try:
lookup = int(self.key_name)
except ValueError:
lookup = self.key_name
return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
datatype_values = ','.join([
repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
])
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
class KeyTextTransform(KeyTransform):
postgres_operator = '->>'
postgres_nested_operator = '#>>'
class KeyTransformTextLookupMixin:
"""
Mixin for combining with a lookup expecting a text lhs from a JSONField
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
key values to text and performing the lookup on the resulting
representation.
"""
def __init__(self, key_transform, *args, **kwargs):
if not isinstance(key_transform, KeyTransform):
raise TypeError(
'Transform should be an instance of KeyTransform in order to '
'use this lookup.'
)
key_text_transform = KeyTextTransform(
key_transform.key_name, *key_transform.source_expressions,
**key_transform.extra,
)
super().__init__(key_text_transform, *args, **kwargs)
class KeyTransformIsNull(lookups.IsNull):
# key__isnull=False is the same as has_key='key'
def as_oracle(self, compiler, connection):
sql, params = HasKey(
self.lhs.lhs,
self.lhs.key_name,
).as_oracle(compiler, connection)
if not self.rhs:
return sql, params
# Column doesn't have a key or IS NULL.
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
def as_sqlite(self, compiler, connection):
template = 'JSON_TYPE(%s, %%s) IS NULL'
if not self.rhs:
template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
compiler,
connection,
template=template,
)
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
compiler, connection, sql, param,
)
if (
not hasattr(param, 'as_sql') and
not connection.features.has_native_json_field
):
if connection.vendor == 'oracle':
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
sql = sql % 'JSON_QUERY'
else:
sql = sql % 'JSON_VALUE'
elif connection.vendor == 'mysql' or (
connection.vendor == 'sqlite' and
params[0] not in connection.ops.jsonfield_datatype_values
):
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
sql = 'JSON_UNQUOTE(%s)' % sql
return sql, params
class KeyTransformExact(JSONExact):
def process_rhs(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == 'oracle':
func = []
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
for value in rhs_params:
value = json.loads(value)
if isinstance(value, (list, dict)):
func.append(sql % 'JSON_QUERY')
else:
func.append(sql % 'JSON_VALUE')
rhs = rhs % tuple(func)
elif connection.vendor == 'sqlite':
func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
func.append('%s')
else:
func.append("JSON_EXTRACT(%s, '$')")
rhs = rhs % tuple(func)
return rhs, rhs_params
def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs_params == ['null']:
# Field has key and it's NULL.
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
return (
'%s AND %s' % (has_key_sql, is_null_sql),
tuple(has_key_params) + tuple(is_null_params),
)
return super().as_sql(compiler, connection)
class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
pass
class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
pass
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
pass
class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
pass
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
pass
class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
pass
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
pass
class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
pass
class KeyTransformNumericLookupMixin:
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if not connection.features.has_native_json_field:
rhs_params = [json.loads(value) for value in rhs_params]
return rhs, rhs_params
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
pass
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
pass
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
pass
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
pass
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
def __init__(self, key_name):
self.key_name = key_name
def __call__(self, *args, **kwargs):
return KeyTransform(self.key_name, *args, **kwargs)

View File

@@ -0,0 +1,56 @@
from django.core import checks
NOT_PROVIDED = object()
class FieldCacheMixin:
"""Provide an API for working with the model's fields value cache."""
def get_cache_name(self):
raise NotImplementedError
def get_cached_value(self, instance, default=NOT_PROVIDED):
cache_name = self.get_cache_name()
try:
return instance._state.fields_cache[cache_name]
except KeyError:
if default is NOT_PROVIDED:
raise
return default
def is_cached(self, instance):
return self.get_cache_name() in instance._state.fields_cache
def set_cached_value(self, instance, value):
instance._state.fields_cache[self.get_cache_name()] = value
def delete_cached_value(self, instance):
del instance._state.fields_cache[self.get_cache_name()]
class CheckFieldDefaultMixin:
_default_hint = ('<valid default>', '<invalid default>')
def _check_default(self):
if self.has_default() and self.default is not None and not callable(self.default):
return [
checks.Warning(
"%s default should be a callable instead of an instance "
"so that it's not shared between all field instances." % (
self.__class__.__name__,
),
hint=(
'Use a callable instead, e.g., use `%s` instead of '
'`%s`.' % self._default_hint
),
obj=self,
id='fields.E010',
)
]
else:
return []
def check(self, **kwargs):
errors = super().check(**kwargs)
errors.extend(self._check_default())
return errors

View File

@@ -0,0 +1,18 @@
"""
Field-like classes that aren't really fields. It's easier to use objects that
have the same attributes as fields sometimes (avoids a lot of special casing).
"""
from django.db.models import fields
class OrderWrt(fields.IntegerField):
"""
A proxy for the _order database field that is used when
Meta.order_with_respect_to is specified.
"""
def __init__(self, *args, **kwargs):
kwargs['name'] = '_order'
kwargs['editable'] = False
super().__init__(*args, **kwargs)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,159 @@
from django.db.models.lookups import (
Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
LessThanOrEqual,
)
class MultiColSource:
contains_aggregate = False
def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
self.output_field = self.field
def __repr__(self):
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias),
self.targets, self.sources, self.field)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def resolve_expression(self, *args, **kwargs):
return self
def get_normalized_value(value, lhs):
from django.db.models import Model
if isinstance(value, Model):
value_list = []
sources = lhs.output_field.get_path_info()[-1].target_fields
for source in sources:
while not isinstance(value, source.model) and source.remote_field:
source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
try:
value_list.append(getattr(value, source.attname))
except AttributeError:
# A case like Restaurant.objects.filter(place=restaurant_instance),
# where place is a OneToOneField and the primary key of Restaurant.
return (value.pk,)
return tuple(value_list)
if not isinstance(value, tuple):
return (value,)
return value
class RelatedIn(In):
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
# If we get here, we are dealing with single-column relations.
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
# We need to run the related field's get_prep_value(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if hasattr(self.lhs.output_field, 'get_path_info'):
# Run the target field's get_prep_value. We can safely assume there is
# only one as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need to be compiled to
# SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import (
AND, OR, SubqueryConstraint, WhereNode,
)
root_constraint = WhereNode(connector=OR)
if self.rhs_is_direct_value():
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
lookup_class = target.get_lookup('exact')
lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
root_constraint.add(
SubqueryConstraint(
self.lhs.alias, [target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources], self.rhs),
AND)
return root_constraint.as_sql(compiler, connection)
else:
if (not getattr(self.rhs, 'has_select_fields', True) and
not getattr(self.lhs.field.target_field, 'primary_key', False)):
self.rhs.clear_select_clause()
if (getattr(self.lhs.output_field, 'primary_key', False) and
self.lhs.output_field.model == self.rhs.model):
# A case like Restaurant.objects.filter(place__in=restaurant_qs),
# where place is a OneToOneField and the primary key of
# Restaurant.
target_field = self.lhs.field.name
else:
target_field = self.lhs.field.target_field.name
self.rhs.add_fields([target_field], True)
return super().as_sql(compiler, connection)
class RelatedLookupMixin:
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
# If we get here, we are dealing with single-column relations.
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
# We need to run the related field's get_prep_value(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
# Get the target field. We can safely assume there is only one
# as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
self.rhs = target_field.get_prep_value(self.rhs)
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import AND, WhereNode
root_constraint = WhereNode()
for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
lookup_class(target.get_col(self.lhs.alias, source), val), AND)
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
class RelatedExact(RelatedLookupMixin, Exact):
pass
class RelatedLessThan(RelatedLookupMixin, LessThan):
pass
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
pass
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
pass
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
pass
class RelatedIsNull(RelatedLookupMixin, IsNull):
pass

View File

@@ -0,0 +1,330 @@
"""
"Rel objects" for related fields.
"Rel objects" (for lack of a better name) carry information about the relation
modeled by a related field and provide some utility functions. They're stored
in the ``remote_field`` attribute of the field.
They also act as reverse fields for the purposes of the Meta API because
they're the closest concept currently available.
"""
from django.core import exceptions
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
from . import BLANK_CHOICE_DASH
from .mixins import FieldCacheMixin
class ForeignObjectRel(FieldCacheMixin):
"""
Used by ForeignObject to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
"""
# Field flags
auto_created = True
concrete = False
editable = False
is_relation = True
# Reverse relations are always nullable (Django can't enforce that a
# foreign key on the related model points to this model).
null = True
empty_strings_allowed = False
def __init__(self, field, to, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
self.field = field
self.model = to
self.related_name = related_name
self.related_query_name = related_query_name
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
self.parent_link = parent_link
self.on_delete = on_delete
self.symmetrical = False
self.multiple = True
# Some of the following cached_properties can't be initialized in
# __init__ as the field doesn't have its model yet. Calling these methods
# before field.contribute_to_class() has been called will result in
# AttributeError
@cached_property
def hidden(self):
return self.is_hidden()
@cached_property
def name(self):
return self.field.related_query_name()
@property
def remote_field(self):
return self.field
@property
def target_field(self):
"""
When filtering against this relation, return the field on the remote
model against which the filtering should happen.
"""
target_fields = self.get_path_info()[-1].target_fields
if len(target_fields) > 1:
raise exceptions.FieldError("Can't use target_field for multicolumn relations.")
return target_fields[0]
@cached_property
def related_model(self):
if not self.field.model:
raise AttributeError(
"This property can't be accessed before self.field.contribute_to_class has been called.")
return self.field.model
@cached_property
def many_to_many(self):
return self.field.many_to_many
@cached_property
def many_to_one(self):
return self.field.one_to_many
@cached_property
def one_to_many(self):
return self.field.many_to_one
@cached_property
def one_to_one(self):
return self.field.one_to_one
def get_lookup(self, lookup_name):
return self.field.get_lookup(lookup_name)
def get_internal_type(self):
return self.field.get_internal_type()
@property
def db_type(self):
return self.field.db_type
def __repr__(self):
return '<%s: %s.%s>' % (
type(self).__name__,
self.related_model._meta.app_label,
self.related_model._meta.model_name,
)
@property
def identity(self):
return (
self.field,
self.model,
self.related_name,
self.related_query_name,
make_hashable(self.limit_choices_to),
self.parent_link,
self.on_delete,
self.symmetrical,
self.multiple,
)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(self.identity)
def get_choices(
self, include_blank=True, blank_choice=BLANK_CHOICE_DASH,
limit_choices_to=None, ordering=(),
):
"""
Return choices with a default blank choices included, for use
as <select> choices for this field.
Analog of django.db.models.fields.Field.get_choices(), provided
initially for utilization by RelatedFieldListFilter.
"""
limit_choices_to = limit_choices_to or self.limit_choices_to
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
if ordering:
qs = qs.order_by(*ordering)
return (blank_choice if include_blank else []) + [
(x.pk, str(x)) for x in qs
]
def is_hidden(self):
"""Should the related object be hidden?"""
return bool(self.related_name) and self.related_name[-1] == '+'
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias)
def set_field_name(self):
"""
Set the related field's name, this is not available until later stages
of app loading, so set_field_name is called from
set_attributes_from_rel()
"""
# By default foreign object doesn't relate to any remote field (for
# example custom multicolumn joins currently have no remote field).
self.field_name = None
def get_accessor_name(self, model=None):
# This method encapsulates the logic that decides what name to give an
# accessor descriptor that retrieves related many-to-one or
# many-to-many objects. It uses the lowercased object_name + "_set",
# but this can be overridden with the "related_name" option. Due to
# backwards compatibility ModelForms need to be able to provide an
# alternate model. See BaseInlineFormSet.get_default_prefix().
opts = model._meta if model else self.related_model._meta
model = model or self.related_model
if self.multiple:
# If this is a symmetrical m2m relation on self, there is no reverse accessor.
if self.symmetrical and model == self.model:
return None
if self.related_name:
return self.related_name
return opts.model_name + ('_set' if self.multiple else '')
def get_path_info(self, filtered_relation=None):
return self.field.get_reverse_path_info(filtered_relation)
def get_cache_name(self):
"""
Return the name of the cache key to use for storing an instance of the
forward model on the reverse model.
"""
return self.get_accessor_name()
class ManyToOneRel(ForeignObjectRel):
"""
Used by the ForeignKey field to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
Note: Because we somewhat abuse the Rel objects by using them as reverse
fields we get the funny situation where
``ManyToOneRel.many_to_one == False`` and
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
ManyToOneRel class is a private API and there is work underway to turn
reverse relations into actual fields.
"""
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
super().__init__(
field, to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
parent_link=parent_link,
on_delete=on_delete,
)
self.field_name = field_name
def __getstate__(self):
state = self.__dict__.copy()
state.pop('related_model', None)
return state
@property
def identity(self):
return super().identity + (self.field_name,)
def get_related_field(self):
"""
Return the Field in the 'to' object to which this relationship is tied.
"""
field = self.model._meta.get_field(self.field_name)
if not field.concrete:
raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name)
return field
def set_field_name(self):
self.field_name = self.field_name or self.model._meta.pk.name
class OneToOneRel(ManyToOneRel):
"""
Used by OneToOneField to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
"""
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
super().__init__(
field, to, field_name,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
parent_link=parent_link,
on_delete=on_delete,
)
self.multiple = False
class ManyToManyRel(ForeignObjectRel):
"""
Used by ManyToManyField to store information about the relation.
``_meta.get_fields()`` returns this class to provide access to the field
flags for the reverse relation.
"""
def __init__(self, field, to, related_name=None, related_query_name=None,
limit_choices_to=None, symmetrical=True, through=None,
through_fields=None, db_constraint=True):
super().__init__(
field, to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
)
if through and not db_constraint:
raise ValueError("Can't supply a through model and db_constraint=False")
self.through = through
if through_fields and not through:
raise ValueError("Cannot specify through_fields without a through model")
self.through_fields = through_fields
self.symmetrical = symmetrical
self.db_constraint = db_constraint
@property
def identity(self):
return super().identity + (
self.through,
make_hashable(self.through_fields),
self.db_constraint,
)
def get_related_field(self):
"""
Return the field in the 'to' object to which this relationship is tied.
Provided for symmetry with ManyToOneRel.
"""
opts = self.through._meta
if self.through_fields:
field = opts.get_field(self.through_fields[0])
else:
for field in opts.fields:
rel = getattr(field, 'remote_field', None)
if rel and rel.model == self.model:
break
return field.foreign_related_fields[0]

View File

@@ -0,0 +1,46 @@
from .comparison import (
Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf,
)
from .datetime import (
Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear,
ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek,
ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour,
TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek,
TruncYear,
)
from .math import (
Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, Exp, Floor, Ln, Log,
Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan,
)
from .text import (
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
RTrim, StrIndex, Substr, Trim, Upper,
)
from .window import (
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
PercentRank, Rank, RowNumber,
)
__all__ = [
# comparison and conversion
'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf',
# datetime
'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay',
'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc',
'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear',
# math
'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 'Ceil', 'Cos', 'Cot', 'Degrees',
'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random',
'Round', 'Sign', 'Sin', 'Sqrt', 'Tan',
# text
'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
'Trim', 'Upper',
# window
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]

View File

@@ -0,0 +1,193 @@
"""Database functions that do comparisons or type conversions."""
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields.json import JSONField
from django.utils.regex_helper import _lazy_re_compile
class Cast(Func):
"""Coerce an expression to a new field type."""
function = 'CAST'
template = '%(function)s(%(expressions)s AS %(db_type)s)'
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
extra_context['db_type'] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
db_type = self.output_field.db_type(connection)
if db_type in {'datetime', 'time'}:
# Use strftime as datetime/time don't keep fractional seconds.
template = 'strftime(%%s, %(expressions)s)'
sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
params.insert(0, format_string)
return sql, params
elif db_type == 'date':
template = 'date(%(expressions)s)'
return super().as_sql(compiler, connection, template=template, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
template = None
output_type = self.output_field.get_internal_type()
# MySQL doesn't support explicit cast to float.
if output_type == 'FloatField':
template = '(%(expressions)s + 0.0)'
# MariaDB doesn't support explicit cast to JSON.
elif output_type == 'JSONField' and connection.mysql_is_mariadb:
template = "JSON_EXTRACT(%(expressions)s, '$')"
return self.as_sql(compiler, connection, template=template, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'JSONField':
# Oracle doesn't support explicit cast to JSON.
template = "JSON_QUERY(%(expressions)s, '$')"
return super().as_sql(compiler, connection, template=template, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Coalesce(Func):
"""Return, from left to right, the first non-null expression."""
function = 'COALESCE'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Coalesce must take at least two expressions')
super().__init__(*expressions, **extra)
@property
def empty_result_set_value(self):
for expression in self.get_source_expressions():
result = expression.empty_result_set_value
if result is NotImplemented or result is not None:
return result
return None
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == 'TextField':
clone = self.copy()
clone.set_source_expressions([
Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
])
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Collate(Func):
function = 'COLLATE'
template = '%(expressions)s %(function)s %(collation)s'
# Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
collation_re = _lazy_re_compile(r'^[\w\-]+$')
def __init__(self, expression, collation):
if not (collation and self.collation_re.match(collation)):
raise ValueError('Invalid collation name: %r.' % collation)
self.collation = collation
super().__init__(expression)
def as_sql(self, compiler, connection, **extra_context):
extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
return super().as_sql(compiler, connection, **extra_context)
class Greatest(Func):
"""
Return the maximum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, the maximum not-null expression is returned.
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
"""
function = 'GREATEST'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Greatest must take at least two expressions')
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
class JSONObject(Func):
function = 'JSON_OBJECT'
output_field = JSONField()
def __init__(self, **fields):
expressions = []
for key, value in fields.items():
expressions.extend((Value(key), value))
super().__init__(*expressions)
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.has_json_object_function:
raise NotSupportedError(
'JSONObject() is not supported on this database backend.'
)
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
return self.as_sql(
compiler,
connection,
function='JSONB_BUILD_OBJECT',
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
class ArgJoiner:
def join(self, args):
args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
return ', '.join(args)
return self.as_sql(
compiler,
connection,
arg_joiner=ArgJoiner(),
template='%(function)s(%(expressions)s RETURNING CLOB)',
**extra_context,
)
class Least(Func):
"""
Return the minimum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, return the minimum not-null expression.
On MySQL, Oracle, and SQLite, if any expression is null, return null.
"""
function = 'LEAST'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Least must take at least two expressions')
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
class NullIf(Func):
function = 'NULLIF'
arity = 2
def as_oracle(self, compiler, connection, **extra_context):
expression1 = self.get_source_expressions()[0]
if isinstance(expression1, Value) and expression1.value is None:
raise ValueError('Oracle does not allow Value(None) for expression1.')
return super().as_sql(compiler, connection, **extra_context)

View File

@@ -0,0 +1,339 @@
from datetime import datetime
from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
)
from django.db.models.lookups import (
Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
)
from django.utils import timezone
class TimezoneMixin:
tzinfo = None
def get_tzname(self):
# Timezone conversions must happen to the input datetime *before*
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
# database as 2016-01-01 01:00:00 +00:00. Any results should be
# based on the input datetime not the stored datetime.
tzname = None
if settings.USE_TZ:
if self.tzinfo is None:
tzname = timezone.get_current_timezone_name()
else:
tzname = timezone._get_timezone_name(self.tzinfo)
return tzname
class Extract(TimezoneMixin, Transform):
lookup_name = None
output_field = IntegerField()
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
raise ValueError('lookup_name must be provided')
self.tzinfo = tzinfo
super().__init__(expression, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
elif self.tzinfo is not None:
raise ValueError('tzinfo can only be used with DateTimeField.')
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
raise ValueError('Extract requires native DurationField database support.')
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
# resolve_expression has already validated the output_field so this
# assert should never be hit.
assert False, "Tried to Extract from an invalid type."
return sql, params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = getattr(copy.lhs, 'output_field', None)
if field is None:
return copy
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
raise ValueError(
'Extract input expression must be DateField, DateTimeField, '
'TimeField, or DurationField.'
)
# Passing dates to functions expecting datetimes is most likely a mistake.
if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
raise ValueError(
"Cannot extract time component '%s' from DateField '%s'." % (copy.lookup_name, field.name)
)
if (
isinstance(field, DurationField) and
copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
):
raise ValueError(
"Cannot extract component '%s' from DurationField '%s'."
% (copy.lookup_name, field.name)
)
return copy
class ExtractYear(Extract):
lookup_name = 'year'
class ExtractIsoYear(Extract):
"""Return the ISO-8601 week-numbering year."""
lookup_name = 'iso_year'
class ExtractMonth(Extract):
lookup_name = 'month'
class ExtractDay(Extract):
lookup_name = 'day'
class ExtractWeek(Extract):
"""
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
week.
"""
lookup_name = 'week'
class ExtractWeekDay(Extract):
"""
Return Sunday=1 through Saturday=7.
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
lookup_name = 'week_day'
class ExtractIsoWeekDay(Extract):
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
lookup_name = 'iso_week_day'
class ExtractQuarter(Extract):
lookup_name = 'quarter'
class ExtractHour(Extract):
lookup_name = 'hour'
class ExtractMinute(Extract):
lookup_name = 'minute'
class ExtractSecond(Extract):
lookup_name = 'second'
DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)
TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)
ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)
ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)
class Now(Func):
template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
class TruncBase(TimezoneMixin, Transform):
kind = None
tzinfo = None
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(self, expression, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
self.tzinfo = tzinfo
self.is_dst = is_dst
super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
inner_sql, inner_params = compiler.compile(self.lhs)
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError('tzinfo can only be used with DateTimeField.')
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
else:
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
return sql, inner_params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
if not isinstance(field, (DateField, TimeField)):
raise TypeError(
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
)
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
output_field = class_output_field or copy.output_field
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
if type(field) == DateField and (
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
raise ValueError("Cannot truncate DateField '%s' to %s." % (
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField) or
copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
raise ValueError("Cannot truncate TimeField '%s' to %s." % (
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
return copy
def convert_value(self, value, expression, connection):
if isinstance(self.output_field, DateTimeField):
if not settings.USE_TZ:
pass
elif value is not None:
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
'Database returned an invalid datetime value. Are time '
'zone definitions for your database installed?'
)
elif isinstance(value, datetime):
if value is None:
pass
elif isinstance(self.output_field, DateField):
value = value.date()
elif isinstance(self.output_field, TimeField):
value = value.time()
return value
class Trunc(TruncBase):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
self.kind = kind
super().__init__(
expression, output_field=output_field, tzinfo=tzinfo,
is_dst=is_dst, **extra
)
class TruncYear(TruncBase):
kind = 'year'
class TruncQuarter(TruncBase):
kind = 'quarter'
class TruncMonth(TruncBase):
kind = 'month'
class TruncWeek(TruncBase):
"""Truncate to midnight on the Monday of the week."""
kind = 'week'
class TruncDay(TruncBase):
kind = 'day'
class TruncDate(TruncBase):
kind = 'date'
lookup_name = 'date'
output_field = DateField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs)
tzname = self.get_tzname()
sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
return sql, lhs_params
class TruncTime(TruncBase):
kind = 'time'
lookup_name = 'time'
output_field = TimeField()
def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time.
lhs, lhs_params = compiler.compile(self.lhs)
tzname = self.get_tzname()
sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
return sql, lhs_params
class TruncHour(TruncBase):
kind = 'hour'
class TruncMinute(TruncBase):
kind = 'minute'
class TruncSecond(TruncBase):
kind = 'second'
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)

View File

@@ -0,0 +1,197 @@
import math
from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
FixDecimalInputMixin, NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform
class Abs(Transform):
function = 'ABS'
lookup_name = 'abs'
class ACos(NumericOutputFieldMixin, Transform):
function = 'ACOS'
lookup_name = 'acos'
class ASin(NumericOutputFieldMixin, Transform):
function = 'ASIN'
lookup_name = 'asin'
class ATan(NumericOutputFieldMixin, Transform):
function = 'ATAN'
lookup_name = 'atan'
class ATan2(NumericOutputFieldMixin, Func):
function = 'ATAN2'
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0):
return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
# Cast integers to float to avoid inconsistent/buggy behavior if the
# arguments are mixed between integer and float or decimal.
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
clone = self.copy()
clone.set_source_expressions([
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
else expression for expression in self.get_source_expressions()[::-1]
])
return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform):
function = 'CEILING'
lookup_name = 'ceil'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
class Cos(NumericOutputFieldMixin, Transform):
function = 'COS'
lookup_name = 'cos'
class Cot(NumericOutputFieldMixin, Transform):
function = 'COT'
lookup_name = 'cot'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
class Degrees(NumericOutputFieldMixin, Transform):
function = 'DEGREES'
lookup_name = 'degrees'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection,
template='((%%(expressions)s) * 180 / %s)' % math.pi,
**extra_context
)
class Exp(NumericOutputFieldMixin, Transform):
function = 'EXP'
lookup_name = 'exp'
class Floor(Transform):
function = 'FLOOR'
lookup_name = 'floor'
class Ln(NumericOutputFieldMixin, Transform):
function = 'LN'
lookup_name = 'ln'
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'LOG'
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, 'spatialite', False):
return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b).
clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[::-1])
return clone.as_sql(compiler, connection, **extra_context)
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'MOD'
arity = 2
class Pi(NumericOutputFieldMixin, Func):
function = 'PI'
arity = 0
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
class Power(NumericOutputFieldMixin, Func):
function = 'POWER'
arity = 2
class Radians(NumericOutputFieldMixin, Transform):
function = 'RADIANS'
lookup_name = 'radians'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection,
template='((%%(expressions)s) * %s / 180)' % math.pi,
**extra_context
)
class Random(NumericOutputFieldMixin, Func):
function = 'RANDOM'
arity = 0
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='RAND', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='RAND', **extra_context)
def get_group_by_cols(self, alias=None):
return []
class Round(FixDecimalInputMixin, Transform):
function = 'ROUND'
lookup_name = 'round'
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
super().__init__(expression, precision, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
raise ValueError('SQLite does not support negative precision.')
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
source = self.get_source_expressions()[0]
return source.output_field
class Sign(Transform):
function = 'SIGN'
lookup_name = 'sign'
class Sin(NumericOutputFieldMixin, Transform):
function = 'SIN'
lookup_name = 'sin'
class Sqrt(NumericOutputFieldMixin, Transform):
function = 'SQRT'
lookup_name = 'sqrt'
class Tan(NumericOutputFieldMixin, Transform):
function = 'TAN'
lookup_name = 'tan'

View File

@@ -0,0 +1,52 @@
import sys
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast
class FixDecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
# - LOG(double, double)
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions([
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
else expression for expression in self.get_source_expressions()
])
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s AS SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import (
IntervalToSeconds, SecondsToInterval,
)
return compiler.compile(
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):
source_fields = self.get_source_fields()
if any(isinstance(s, DecimalField) for s in source_fields):
return DecimalField()
if any(isinstance(s, IntegerField) for s in source_fields):
return FloatField()
return super()._resolve_output_field() if source_fields else FloatField()

View File

@@ -0,0 +1,323 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import CharField, IntegerField
from django.db.models.functions import Coalesce
from django.db.models.lookups import Transform
class MySQLSHA2Mixin:
def as_mysql(self, compiler, connection, **extra_content):
return super().as_sql(
compiler,
connection,
template='SHA2(%%(expressions)s, %s)' % self.function[3:],
**extra_content,
)
class OracleHashMixin:
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template=(
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
),
**extra_context,
)
class PostgreSQLSHAMixin:
def as_postgresql(self, compiler, connection, **extra_content):
return super().as_sql(
compiler,
connection,
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
function=self.function.lower(),
**extra_content,
)
class Chr(Transform):
function = 'CHR'
lookup_name = 'chr'
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function='CHAR',
template='%(function)s(%(expressions)s USING utf16)',
**extra_context
)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection,
template='%(function)s(%(expressions)s USING NCHAR_CS)',
**extra_context
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
class ConcatPair(Func):
"""
Concatenate two arguments together. This is used by `Concat` because not
all backend databases support more than two arguments.
"""
function = 'CONCAT'
def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler, connection, template='%(expressions)s', arg_joiner=' || ',
**extra_context
)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql(
compiler, connection, function='CONCAT_WS',
template="%(function)s('', %(expressions)s)",
**extra_context
)
def coalesce(self):
# null on either side results in null for expression, wrap with coalesce
c = self.copy()
c.set_source_expressions([
Coalesce(expression, Value('')) for expression in c.get_source_expressions()
])
return c
class Concat(Func):
"""
Concatenate text fields together. Backends that result in an entire
null expression when any arguments are null will wrap each argument in
coalesce functions to ensure a non-null result.
"""
function = None
template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Concat must take at least two expressions')
paired = self._paired(expressions)
super().__init__(paired, **extra)
def _paired(self, expressions):
# wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))
class Left(Func):
function = 'LEFT'
arity = 2
output_field = CharField()
def __init__(self, expression, length, **extra):
"""
expression: the name of a field, or an expression returning a string
length: the number of characters to return from the start of the string
"""
if not hasattr(length, 'resolve_expression'):
if length < 1:
raise ValueError("'length' must be greater than 0.")
super().__init__(expression, length, **extra)
def get_substr(self):
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
def as_oracle(self, compiler, connection, **extra_context):
return self.get_substr().as_oracle(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
class Length(Transform):
"""Return the number of characters in the expression."""
function = 'LENGTH'
lookup_name = 'length'
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
class Lower(Transform):
function = 'LOWER'
lookup_name = 'lower'
class LPad(Func):
function = 'LPAD'
output_field = CharField()
def __init__(self, expression, length, fill_text=Value(' '), **extra):
if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra)
class LTrim(Transform):
function = 'LTRIM'
lookup_name = 'ltrim'
class MD5(OracleHashMixin, Transform):
function = 'MD5'
lookup_name = 'md5'
class Ord(Transform):
function = 'ASCII'
lookup_name = 'ord'
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='ORD', **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
class Repeat(Func):
function = 'REPEAT'
output_field = CharField()
def __init__(self, expression, number, **extra):
if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions
length = None if number is None else Length(expression) * number
rpad = RPad(expression, length, expression)
return rpad.as_sql(compiler, connection, **extra_context)
class Replace(Func):
function = 'REPLACE'
def __init__(self, expression, text, replacement=Value(''), **extra):
super().__init__(expression, text, replacement, **extra)
class Reverse(Transform):
function = 'REVERSE'
lookup_name = 'reverse'
def as_oracle(self, compiler, connection, **extra_context):
# REVERSE in Oracle is undocumented and doesn't support multi-byte
# strings. Use a special subquery instead.
return super().as_sql(
compiler, connection,
template=(
'(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
'(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
'GROUP BY %(expressions)s)'
),
**extra_context
)
class Right(Left):
function = 'RIGHT'
def get_substr(self):
return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
class RPad(LPad):
function = 'RPAD'
class RTrim(Transform):
function = 'RTRIM'
lookup_name = 'rtrim'
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA1'
lookup_name = 'sha1'
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
function = 'SHA224'
lookup_name = 'sha224'
def as_oracle(self, compiler, connection, **extra_context):
raise NotSupportedError('SHA224 is not supported on Oracle.')
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA256'
lookup_name = 'sha256'
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA384'
lookup_name = 'sha384'
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA512'
lookup_name = 'sha512'
class StrIndex(Func):
"""
Return a positive integer corresponding to the 1-indexed position of the
first occurrence of a substring inside another string, or 0 if the
substring is not found.
"""
function = 'INSTR'
arity = 2
output_field = IntegerField()
def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
class Substr(Func):
function = 'SUBSTRING'
output_field = CharField()
def __init__(self, expression, pos, length=None, **extra):
"""
expression: the name of a field, or an expression returning a string
pos: an integer > 0, or an expression returning an integer
length: an optional number of characters to return
"""
if not hasattr(pos, 'resolve_expression'):
if pos < 1:
raise ValueError("'pos' must be greater than 0")
expressions = [expression, pos]
if length is not None:
expressions.append(length)
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
class Trim(Transform):
function = 'TRIM'
lookup_name = 'trim'
class Upper(Transform):
function = 'UPPER'
lookup_name = 'upper'

View File

@@ -0,0 +1,108 @@
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
__all__ = [
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]
class CumeDist(Func):
function = 'CUME_DIST'
output_field = FloatField()
window_compatible = True
class DenseRank(Func):
function = 'DENSE_RANK'
output_field = IntegerField()
window_compatible = True
class FirstValue(Func):
arity = 1
function = 'FIRST_VALUE'
window_compatible = True
class LagLeadFunction(Func):
window_compatible = True
def __init__(self, expression, offset=1, default=None, **extra):
if expression is None:
raise ValueError(
'%s requires a non-null source expression.' %
self.__class__.__name__
)
if offset is None or offset <= 0:
raise ValueError(
'%s requires a positive integer for the offset.' %
self.__class__.__name__
)
args = (expression, offset)
if default is not None:
args += (default,)
super().__init__(*args, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Lag(LagLeadFunction):
function = 'LAG'
class LastValue(Func):
arity = 1
function = 'LAST_VALUE'
window_compatible = True
class Lead(LagLeadFunction):
function = 'LEAD'
class NthValue(Func):
function = 'NTH_VALUE'
window_compatible = True
def __init__(self, expression, nth=1, **extra):
if expression is None:
raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
if nth is None or nth <= 0:
raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
super().__init__(expression, nth, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Ntile(Func):
function = 'NTILE'
output_field = IntegerField()
window_compatible = True
def __init__(self, num_buckets=1, **extra):
if num_buckets <= 0:
raise ValueError('num_buckets must be greater than 0.')
super().__init__(num_buckets, **extra)
class PercentRank(Func):
function = 'PERCENT_RANK'
output_field = FloatField()
window_compatible = True
class Rank(Func):
function = 'RANK'
output_field = IntegerField()
window_compatible = True
class RowNumber(Func):
function = 'ROW_NUMBER'
output_field = IntegerField()
window_compatible = True

View File

@@ -0,0 +1,270 @@
from django.db.backends.utils import names_digest, split_identifier
from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
from django.db.models.functions import Collate
from django.db.models.query_utils import Q
from django.db.models.sql import Query
from django.utils.functional import partition
__all__ = ['Index']
class Index:
suffix = 'idx'
# The max length of the name of the index (restricted to 30 for
# cross-database compatibility with Oracle)
max_name_length = 30
def __init__(
self,
*expressions,
fields=(),
name=None,
db_tablespace=None,
opclasses=(),
condition=None,
include=None,
):
if opclasses and not name:
raise ValueError('An index must be named to use opclasses.')
if not isinstance(condition, (type(None), Q)):
raise ValueError('Index.condition must be a Q instance.')
if condition and not name:
raise ValueError('An index must be named to use condition.')
if not isinstance(fields, (list, tuple)):
raise ValueError('Index.fields must be a list or tuple.')
if not isinstance(opclasses, (list, tuple)):
raise ValueError('Index.opclasses must be a list or tuple.')
if not expressions and not fields:
raise ValueError(
'At least one field or expression is required to define an '
'index.'
)
if expressions and fields:
raise ValueError(
'Index.fields and expressions are mutually exclusive.',
)
if expressions and not name:
raise ValueError('An index must be named to use expressions.')
if expressions and opclasses:
raise ValueError(
'Index.opclasses cannot be used with expressions. Use '
'django.contrib.postgres.indexes.OpClass() instead.'
)
if opclasses and len(fields) != len(opclasses):
raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
if fields and not all(isinstance(field, str) for field in fields):
raise ValueError('Index.fields must contain only strings with field names.')
if include and not name:
raise ValueError('A covering index must be named.')
if not isinstance(include, (type(None), list, tuple)):
raise ValueError('Index.include must be a list or tuple.')
self.fields = list(fields)
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
self.fields_orders = [
(field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '')
for field_name in self.fields
]
self.name = name or ''
self.db_tablespace = db_tablespace
self.opclasses = opclasses
self.condition = condition
self.include = tuple(include) if include else ()
self.expressions = tuple(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
@property
def contains_expressions(self):
return bool(self.expressions)
def _get_condition_sql(self, model, schema_editor):
if self.condition is None:
return None
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def create_sql(self, model, schema_editor, using='', **kwargs):
include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor)
if self.expressions:
index_expressions = []
for expression in self.expressions:
index_expression = IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
expressions = ExpressionList(*index_expressions).resolve_expression(
Query(model, alias_cols=False),
)
fields = None
col_suffixes = None
else:
fields = [
model._meta.get_field(field_name)
for field_name, _ in self.fields_orders
]
col_suffixes = [order[1] for order in self.fields_orders]
expressions = None
return schema_editor._create_index_sql(
model, fields=fields, name=self.name, using=using,
db_tablespace=self.db_tablespace, col_suffixes=col_suffixes,
opclasses=self.opclasses, condition=condition, include=include,
expressions=expressions, **kwargs,
)
def remove_sql(self, model, schema_editor, **kwargs):
return schema_editor._delete_index_sql(model, self.name, **kwargs)
def deconstruct(self):
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.indexes', 'django.db.models')
kwargs = {'name': self.name}
if self.fields:
kwargs['fields'] = self.fields
if self.db_tablespace is not None:
kwargs['db_tablespace'] = self.db_tablespace
if self.opclasses:
kwargs['opclasses'] = self.opclasses
if self.condition:
kwargs['condition'] = self.condition
if self.include:
kwargs['include'] = self.include
return (path, self.expressions, kwargs)
def clone(self):
"""Create a copy of this Index."""
_, args, kwargs = self.deconstruct()
return self.__class__(*args, **kwargs)
def set_name_with_model(self, model):
"""
Generate a unique name for the index.
The name is divided into 3 parts - table name (12 chars), field name
(8 chars) and unique hash + suffix (10 chars). Each part is made to
fit its size by truncating the excess length.
"""
_, table_name = split_identifier(model._meta.db_table)
column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders]
column_names_with_order = [
(('-%s' if order else '%s') % column_name)
for column_name, (field_name, order) in zip(column_names, self.fields_orders)
]
# The length of the parts of the name is based on the default max
# length of 30 characters.
hash_data = [table_name] + column_names_with_order + [self.suffix]
self.name = '%s_%s_%s' % (
table_name[:11],
column_names[0][:7],
'%s_%s' % (names_digest(*hash_data, length=6), self.suffix),
)
if len(self.name) > self.max_name_length:
raise ValueError(
'Index too long for multiple database support. Is self.suffix '
'longer than 3 characters?'
)
if self.name[0] == '_' or self.name[0].isdigit():
self.name = 'D%s' % self.name[1:]
def __repr__(self):
return '<%s:%s%s%s%s%s%s%s>' % (
self.__class__.__qualname__,
'' if not self.fields else ' fields=%s' % repr(self.fields),
'' if not self.expressions else ' expressions=%s' % repr(self.expressions),
'' if not self.name else ' name=%s' % repr(self.name),
''
if self.db_tablespace is None
else ' db_tablespace=%s' % repr(self.db_tablespace),
'' if self.condition is None else ' condition=%s' % self.condition,
'' if not self.include else ' include=%s' % repr(self.include),
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
)
def __eq__(self, other):
if self.__class__ == other.__class__:
return self.deconstruct() == other.deconstruct()
return NotImplemented
class IndexExpression(Func):
"""Order and wrap expressions for CREATE INDEX statements."""
template = '%(expressions)s'
wrapper_classes = (OrderBy, Collate)
def set_wrapper_classes(self, connection=None):
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
if connection and connection.features.collate_as_index_expression:
self.wrapper_classes = tuple([
wrapper_cls
for wrapper_cls in self.wrapper_classes
if wrapper_cls is not Collate
])
@classmethod
def register_wrappers(cls, *wrapper_classes):
cls.wrapper_classes = wrapper_classes
def resolve_expression(
self,
query=None,
allow_joins=True,
reuse=None,
summarize=False,
for_save=False,
):
expressions = list(self.flatten())
# Split expressions and wrappers.
index_expressions, wrappers = partition(
lambda e: isinstance(e, self.wrapper_classes),
expressions,
)
wrapper_types = [type(wrapper) for wrapper in wrappers]
if len(wrapper_types) != len(set(wrapper_types)):
raise ValueError(
"Multiple references to %s can't be used in an indexed "
"expression." % ', '.join([
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
])
)
if expressions[1:len(wrappers) + 1] != wrappers:
raise ValueError(
'%s must be topmost expressions in an indexed expression.'
% ', '.join([
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
])
)
# Wrap expressions in parentheses if they are not column references.
root_expression = index_expressions[1]
resolve_root_expression = root_expression.resolve_expression(
query,
allow_joins,
reuse,
summarize,
for_save,
)
if not isinstance(resolve_root_expression, Col):
root_expression = Func(root_expression, template='(%(expressions)s)')
if wrappers:
# Order wrappers and set their expressions.
wrappers = sorted(
wrappers,
key=lambda w: self.wrapper_classes.index(type(w)),
)
wrappers = [wrapper.copy() for wrapper in wrappers]
for i, wrapper in enumerate(wrappers[:-1]):
wrapper.set_source_expressions([wrappers[i + 1]])
# Set the root expression on the deepest wrapper.
wrappers[-1].set_source_expressions([root_expression])
self.set_source_expressions([wrappers[0]])
else:
# Use the root expression, if there are no wrappers.
self.set_source_expressions([root_expression])
return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
def as_sqlite(self, compiler, connection, **extra_context):
# Casting to numeric is unnecessary.
return self.as_sql(compiler, connection, **extra_context)

View File

@@ -0,0 +1,687 @@
import itertools
import math
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.fields import (
BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
class Lookup(Expression):
lookup_name = None
prepare_rhs = True
can_use_none_as_rhs = False
def __init__(self, lhs, rhs):
self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup()
self.lhs = self.get_prep_lhs()
if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if bilateral_transforms:
# Warn the user as soon as possible if they are trying to apply
# a bilateral transformation on a nested QuerySet: that won't work.
from django.db.models.sql.query import ( # avoid circular import
Query,
)
if isinstance(rhs, Query):
raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value):
for transform in self.bilateral_transforms:
value = transform(value)
return value
def __repr__(self):
return f'{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})'
def batch_process_rhs(self, compiler, connection, rhs=None):
if rhs is None:
rhs = self.rhs
if self.bilateral_transforms:
sqls, sqls_params = [], []
for p in rhs:
value = Value(p, output_field=self.lhs.output_field)
value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
sql, sql_params = compiler.compile(value)
sqls.append(sql)
sqls_params.extend(sql_params)
else:
_, params = self.get_db_prep_lookup(rhs, connection)
sqls, sqls_params = ['%s'] * len(params), params
return sqls, sqls_params
def get_source_expressions(self):
if self.rhs_is_direct_value():
return [self.lhs]
return [self.lhs, self.rhs]
def set_source_expressions(self, new_exprs):
if len(new_exprs) == 1:
self.lhs = new_exprs[0]
else:
self.lhs, self.rhs = new_exprs
def get_prep_lookup(self):
if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'):
return self.rhs
if hasattr(self.lhs, 'output_field'):
if hasattr(self.lhs.output_field, 'get_prep_value'):
return self.lhs.output_field.get_prep_value(self.rhs)
elif self.rhs_is_direct_value():
return Value(self.rhs)
return self.rhs
def get_prep_lhs(self):
if hasattr(self.lhs, 'resolve_expression'):
return self.lhs
return Value(self.lhs)
def get_db_prep_lookup(self, value, connection):
return ('%s', [value])
def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
if hasattr(lhs, 'resolve_expression'):
lhs = lhs.resolve_expression(compiler.query)
sql, params = compiler.compile(lhs)
if isinstance(lhs, Lookup):
# Wrapped in parentheses to respect operator precedence.
sql = f'({sql})'
return sql, params
def process_rhs(self, compiler, connection):
value = self.rhs
if self.bilateral_transforms:
if self.rhs_is_direct_value():
# Do not call get_db_prep_lookup here as the value will be
# transformed before being used for lookup
value = Value(value, output_field=self.lhs.output_field)
value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
if hasattr(value, 'as_sql'):
sql, params = compiler.compile(value)
# Ensure expression is wrapped in parentheses to respect operator
# precedence but avoid double wrapping as it can be misinterpreted
# on some backends (e.g. subqueries on SQLite).
if sql and sql[0] != '(':
sql = '(%s)' % sql
return sql, params
else:
return self.get_db_prep_lookup(value, connection)
def rhs_is_direct_value(self):
return not hasattr(self.rhs, 'as_sql')
def get_group_by_cols(self, alias=None):
cols = []
for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols())
return cols
def as_oracle(self, compiler, connection):
# Oracle doesn't allow EXISTS() and filters to be compared to another
# expression unless they're wrapped in a CASE WHEN.
wrapped = False
exprs = []
for expr in (self.lhs, self.rhs):
if connection.ops.conditional_expression_supported_in_where_clause(expr):
expr = Case(When(expr, then=True), default=False)
wrapped = True
exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self
return lookup.as_sql(compiler, connection)
@cached_property
def output_field(self):
return BooleanField()
@property
def identity(self):
return self.__class__, self.lhs, self.rhs
def __eq__(self, other):
if not isinstance(other, Lookup):
return NotImplemented
return self.identity == other.identity
def __hash__(self):
return hash(make_hashable(self.identity))
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy()
c.is_summary = summarize
c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
return sql, params
class Transform(RegisterLookupMixin, Func):
"""
RegisterLookupMixin() is first so that get_lookup() and get_transform()
first examine self and then check output_field.
"""
bilateral = False
arity = 1
@property
def lhs(self):
return self.get_source_expressions()[0]
def get_bilateral_transforms(self):
if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if self.bilateral:
bilateral_transforms.append(self.__class__)
return bilateral_transforms
class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None):
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
field_internal_type = self.lhs.output_field.get_internal_type()
db_type = self.lhs.output_field.db_type(connection=connection)
lhs_sql = connection.ops.field_cast_sql(
db_type, field_internal_type) % lhs_sql
lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
return lhs_sql, list(params)
def as_sql(self, compiler, connection):
lhs_sql, params = self.process_lhs(compiler, connection)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
return '%s %s' % (lhs_sql, rhs_sql), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
class FieldGetDbPrepValueMixin:
"""
Some lookups require Field.get_db_prep_value() to be called on their
inputs.
"""
get_db_prep_lookup_value_is_iterable = False
def get_db_prep_lookup(self, value, connection):
# For relational fields, use the 'target_field' attribute of the
# output_field.
field = getattr(self.lhs.output_field, 'target_field', None)
get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
return (
'%s',
[get_db_prep_value(v, connection, prepared=True) for v in value]
if self.get_db_prep_lookup_value_is_iterable else
[get_db_prep_value(value, connection, prepared=True)]
)
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
"""
Some lookups require Field.get_db_prep_value() to be called on each value
in an iterable.
"""
get_db_prep_lookup_value_is_iterable = True
def get_prep_lookup(self):
if hasattr(self.rhs, 'resolve_expression'):
return self.rhs
prepared_values = []
for rhs_value in self.rhs:
if hasattr(rhs_value, 'resolve_expression'):
# An expression will be handled by the database but can coexist
# alongside real values.
pass
elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
prepared_values.append(rhs_value)
return prepared_values
def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value():
# rhs should be an iterable of values. Use batch_process_rhs()
# to prepare/transform those values.
return self.batch_process_rhs(compiler, connection)
else:
return super().process_rhs(compiler, connection)
def resolve_expression_parameter(self, compiler, connection, sql, param):
params = [param]
if hasattr(param, 'resolve_expression'):
param = param.resolve_expression(compiler.query)
if hasattr(param, 'as_sql'):
sql, params = compiler.compile(param)
return sql, params
def batch_process_rhs(self, compiler, connection, rhs=None):
pre_processed = super().batch_process_rhs(compiler, connection, rhs)
# The params list may contain expressions which compile to a
# sql/param pair. Zip them to get sql and param pairs that refer to the
# same argument and attempt to replace them with the result of
# compiling the param step.
sql, params = zip(*(
self.resolve_expression_parameter(compiler, connection, sql, param)
for sql, param in zip(*pre_processed)
))
params = itertools.chain.from_iterable(params)
return sql, tuple(params)
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
"""Lookup defined by operators on PostgreSQL."""
postgres_operator = None
def as_postgresql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return '%s %s %s' % (lhs, self.postgres_operator, rhs), params
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'exact'
def process_rhs(self, compiler, connection):
from django.db.models.sql.query import Query
if isinstance(self.rhs, Query):
if self.rhs.has_limit_one():
if not self.rhs.has_select_fields:
self.rhs.clear_select_clause()
self.rhs.add_fields(['pk'])
else:
raise ValueError(
'The QuerySet value for an exact lookup must be limited to '
'one result using slicing.'
)
return super().process_rhs(compiler, connection)
def as_sql(self, compiler, connection):
# Avoid comparison against direct rhs if lhs is a boolean value. That
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
# "WHERE boolean_field = True" when allowed.
if (
isinstance(self.rhs, bool) and
getattr(self.lhs, 'conditional', False) and
connection.ops.conditional_expression_supported_in_where_clause(self.lhs)
):
lhs_sql, params = self.process_lhs(compiler, connection)
template = '%s' if self.rhs else 'NOT %s'
return template % lhs_sql, params
return super().as_sql(compiler, connection)
@Field.register_lookup
class IExact(BuiltinLookup):
lookup_name = 'iexact'
prepare_rhs = False
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if params:
params[0] = connection.ops.prep_for_iexact_query(params[0])
return rhs, params
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'gt'
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'gte'
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'lt'
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'lte'
class IntegerFieldFloatRounding:
"""
Allow floats to work as query values for IntegerField. Without this, the
decimal portion of the float would always be discarded.
"""
def get_prep_lookup(self):
if isinstance(self.rhs, float):
self.rhs = math.ceil(self.rhs)
return super().get_prep_lookup()
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
pass
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
pass
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'in'
def process_rhs(self, compiler, connection):
db_rhs = getattr(self.rhs, '_db', None)
if db_rhs is not None and db_rhs != connection.alias:
raise ValueError(
"Subqueries aren't allowed across different databases. Force "
"the inner query to be evaluated using `list(inner_query)`."
)
if self.rhs_is_direct_value():
# Remove None from the list as NULL is never equal to anything.
try:
rhs = OrderedSet(self.rhs)
rhs.discard(None)
except TypeError: # Unhashable items in self.rhs
rhs = [r for r in self.rhs if r is not None]
if not rhs:
raise EmptyResultSet
# rhs should be an iterable; use batch_process_rhs() to
# prepare/transform those values.
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
placeholder = '(' + ', '.join(sqls) + ')'
return (placeholder, sqls_params)
else:
from django.db.models.sql.query import ( # avoid circular import
Query,
)
if isinstance(self.rhs, Query):
query = self.rhs
query.clear_ordering(clear_default=True)
if not query.has_select_fields:
query.clear_select_clause()
query.add_fields(['pk'])
return super().process_rhs(compiler, connection)
def get_group_by_cols(self, alias=None):
cols = self.lhs.get_group_by_cols()
if hasattr(self.rhs, 'get_group_by_cols'):
if not getattr(self.rhs, 'has_select_fields', True):
self.rhs.clear_select_clause()
self.rhs.add_fields(['pk'])
cols.extend(self.rhs.get_group_by_cols())
return cols
def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs
def as_sql(self, compiler, connection):
max_in_list_size = connection.ops.max_in_list_size()
if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
return self.split_parameter_list_as_sql(compiler, connection)
return super().as_sql(compiler, connection)
def split_parameter_list_as_sql(self, compiler, connection):
# This is a special case for databases which limit the number of
# elements which can appear in an 'IN' clause.
max_in_list_size = connection.ops.max_in_list_size()
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
in_clause_elements = ['(']
params = []
for offset in range(0, len(rhs_params), max_in_list_size):
if offset > 0:
in_clause_elements.append(' OR ')
in_clause_elements.append('%s IN (' % lhs)
params.extend(lhs_params)
sqls = rhs[offset: offset + max_in_list_size]
sqls_params = rhs_params[offset: offset + max_in_list_size]
param_group = ', '.join(sqls)
in_clause_elements.append(param_group)
in_clause_elements.append(')')
params.extend(sqls_params)
in_clause_elements.append(')')
return ''.join(in_clause_elements), params
class PatternLookup(BuiltinLookup):
param_pattern = '%%%s%%'
prepare_rhs = False
def get_rhs_op(self, connection, rhs):
# Assume we are in startswith. We need to produce SQL like:
# col LIKE %s, ['thevalue%']
# For python values we can (and should) do that directly in Python,
# but if the value is for example reference to other column, then
# we need to add the % pattern match to the lookup by something like
# col LIKE othercol || '%%'
# So, for Python values we don't need any special pattern, but for
# SQL reference values or SQL transformations we need the correct
# pattern added.
if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
return pattern.format(rhs)
else:
return super().get_rhs_op(connection, rhs)
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
return rhs, params
@Field.register_lookup
class Contains(PatternLookup):
lookup_name = 'contains'
@Field.register_lookup
class IContains(Contains):
lookup_name = 'icontains'
@Field.register_lookup
class StartsWith(PatternLookup):
lookup_name = 'startswith'
param_pattern = '%s%%'
@Field.register_lookup
class IStartsWith(StartsWith):
lookup_name = 'istartswith'
@Field.register_lookup
class EndsWith(PatternLookup):
lookup_name = 'endswith'
param_pattern = '%%%s'
@Field.register_lookup
class IEndsWith(EndsWith):
lookup_name = 'iendswith'
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'range'
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
@Field.register_lookup
class IsNull(BuiltinLookup):
lookup_name = 'isnull'
prepare_rhs = False
def as_sql(self, compiler, connection):
if not isinstance(self.rhs, bool):
raise ValueError(
'The QuerySet value for an isnull lookup must be True or '
'False.'
)
sql, params = compiler.compile(self.lhs)
if self.rhs:
return "%s IS NULL" % sql, params
else:
return "%s IS NOT NULL" % sql, params
@Field.register_lookup
class Regex(BuiltinLookup):
lookup_name = 'regex'
prepare_rhs = False
def as_sql(self, compiler, connection):
if self.lookup_name in connection.operators:
return super().as_sql(compiler, connection)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params
@Field.register_lookup
class IRegex(Regex):
lookup_name = 'iregex'
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
from django.db.models.functions import ExtractIsoYear
iso_year = isinstance(self.lhs, ExtractIsoYear)
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
year, iso_year=iso_year,
)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(
year, iso_year=iso_year,
)
return bounds
def as_sql(self, compiler, connection):
# Avoid the extract operation if the rhs is a direct value to allow
# indexes to be used.
if self.rhs_is_direct_value():
# Skip the extract part by directly using the originating field,
# that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, _ = self.process_rhs(compiler, connection)
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, self.rhs)
params.extend(self.get_bound_params(start, finish))
return '%s %s' % (lhs_sql, rhs_sql), params
return super().as_sql(compiler, connection)
def get_direct_rhs_sql(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
def get_bound_params(self, start, finish):
raise NotImplementedError(
'subclasses of YearLookup must provide a get_bound_params() method'
)
class YearExact(YearLookup, Exact):
def get_direct_rhs_sql(self, connection, rhs):
return 'BETWEEN %s AND %s'
def get_bound_params(self, start, finish):
return (start, finish)
class YearGt(YearLookup, GreaterThan):
def get_bound_params(self, start, finish):
return (finish,)
class YearGte(YearLookup, GreaterThanOrEqual):
def get_bound_params(self, start, finish):
return (start,)
class YearLt(YearLookup, LessThan):
def get_bound_params(self, start, finish):
return (start,)
class YearLte(YearLookup, LessThanOrEqual):
def get_bound_params(self, start, finish):
return (finish,)
class UUIDTextMixin:
"""
Strip hyphens from a value when filtering a UUIDField on backends without
a native datatype for UUID.
"""
def process_rhs(self, qn, connection):
if not connection.features.has_native_uuid_field:
from django.db.models.functions import Replace
if self.rhs_is_direct_value():
self.rhs = Value(self.rhs)
self.rhs = Replace(self.rhs, Value('-'), Value(''), output_field=CharField())
rhs, params = super().process_rhs(qn, connection)
return rhs, params
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
pass
@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
pass
@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
pass
@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
pass
@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
pass
@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
pass
@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
pass

View File

@@ -0,0 +1,203 @@
import copy
import inspect
from importlib import import_module
from django.db import router
from django.db.models.query import QuerySet
class BaseManager:
# To retain order, track each time a Manager instance is created.
creation_counter = 0
# Set to True for the 'objects' managers that are automatically created.
auto_created = False
#: If set to True the manager will be serialized into migrations and will
#: thus be available in e.g. RunPython operations.
use_in_migrations = False
def __new__(cls, *args, **kwargs):
# Capture the arguments to make returning them trivial.
obj = super().__new__(cls)
obj._constructor_args = (args, kwargs)
return obj
def __init__(self):
super().__init__()
self._set_creation_counter()
self.model = None
self.name = None
self._db = None
self._hints = {}
def __str__(self):
"""Return "app_label.model_label.manager_name"."""
return '%s.%s' % (self.model._meta.label, self.name)
def __class_getitem__(cls, *args, **kwargs):
return cls
def deconstruct(self):
"""
Return a 5-tuple of the form (as_manager (True), manager_class,
queryset_class, args, kwargs).
Raise a ValueError if the manager is dynamically generated.
"""
qs_class = self._queryset_class
if getattr(self, '_built_with_as_manager', False):
# using MyQuerySet.as_manager()
return (
True, # as_manager
None, # manager_class
'%s.%s' % (qs_class.__module__, qs_class.__name__), # qs_class
None, # args
None, # kwargs
)
else:
module_name = self.__module__
name = self.__class__.__name__
# Make sure it's actually there and not an inner class
module = import_module(module_name)
if not hasattr(module, name):
raise ValueError(
"Could not find manager %s in %s.\n"
"Please note that you need to inherit from managers you "
"dynamically generated with 'from_queryset()'."
% (name, module_name)
)
return (
False, # as_manager
'%s.%s' % (module_name, name), # manager_class
None, # qs_class
self._constructor_args[0], # args
self._constructor_args[1], # kwargs
)
def check(self, **kwargs):
return []
@classmethod
def _get_queryset_methods(cls, queryset_class):
def create_method(name, method):
def manager_method(self, *args, **kwargs):
return getattr(self.get_queryset(), name)(*args, **kwargs)
manager_method.__name__ = method.__name__
manager_method.__doc__ = method.__doc__
return manager_method
new_methods = {}
for name, method in inspect.getmembers(queryset_class, predicate=inspect.isfunction):
# Only copy missing methods.
if hasattr(cls, name):
continue
# Only copy public methods or methods with the attribute `queryset_only=False`.
queryset_only = getattr(method, 'queryset_only', None)
if queryset_only or (queryset_only is None and name.startswith('_')):
continue
# Copy the method onto the manager.
new_methods[name] = create_method(name, method)
return new_methods
@classmethod
def from_queryset(cls, queryset_class, class_name=None):
if class_name is None:
class_name = '%sFrom%s' % (cls.__name__, queryset_class.__name__)
return type(class_name, (cls,), {
'_queryset_class': queryset_class,
**cls._get_queryset_methods(queryset_class),
})
def contribute_to_class(self, cls, name):
self.name = self.name or name
self.model = cls
setattr(cls, name, ManagerDescriptor(self))
cls._meta.add_manager(self)
def _set_creation_counter(self):
"""
Set the creation counter value for this instance and increment the
class-level copy.
"""
self.creation_counter = BaseManager.creation_counter
BaseManager.creation_counter += 1
def db_manager(self, using=None, hints=None):
obj = copy.copy(self)
obj._db = using or self._db
obj._hints = hints or self._hints
return obj
@property
def db(self):
return self._db or router.db_for_read(self.model, **self._hints)
#######################
# PROXIES TO QUERYSET #
#######################
def get_queryset(self):
"""
Return a new QuerySet object. Subclasses can override this method to
customize the behavior of the Manager.
"""
return self._queryset_class(model=self.model, using=self._db, hints=self._hints)
def all(self):
# We can't proxy this method through the `QuerySet` like we do for the
# rest of the `QuerySet` methods. This is because `QuerySet.all()`
# works by creating a "copy" of the current queryset and in making said
# copy, all the cached `prefetch_related` lookups are lost. See the
# implementation of `RelatedManager.get_queryset()` for a better
# understanding of how this comes into play.
return self.get_queryset()
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self._constructor_args == other._constructor_args
)
def __hash__(self):
return id(self)
class Manager(BaseManager.from_queryset(QuerySet)):
pass
class ManagerDescriptor:
def __init__(self, manager):
self.manager = manager
def __get__(self, instance, cls=None):
if instance is not None:
raise AttributeError("Manager isn't accessible via %s instances" % cls.__name__)
if cls._meta.abstract:
raise AttributeError("Manager isn't available; %s is abstract" % (
cls._meta.object_name,
))
if cls._meta.swapped:
raise AttributeError(
"Manager isn't available; '%s' has been swapped for '%s'" % (
cls._meta.label,
cls._meta.swapped,
)
)
return cls._meta.managers_map[self.manager.name]
class EmptyManager(Manager):
def __init__(self, model):
super().__init__()
self.model = model
def get_queryset(self):
return super().get_queryset().none()

Some files were not shown because too many files have changed in this diff Show More