Add project files.
42
venv/Lib/site-packages/django/db/__init__.py
Normal file
@@ -0,0 +1,42 @@
from django.core import signals
from django.db.utils import (
    DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY, ConnectionHandler,
    ConnectionRouter, DatabaseError, DataError, Error, IntegrityError,
    InterfaceError, InternalError, NotSupportedError, OperationalError,
    ProgrammingError,
)
from django.utils.connection import ConnectionProxy

__all__ = [
    'connection', 'connections', 'router', 'DatabaseError', 'IntegrityError',
    'InternalError', 'ProgrammingError', 'DataError', 'NotSupportedError',
    'Error', 'InterfaceError', 'OperationalError', 'DEFAULT_DB_ALIAS',
    'DJANGO_VERSION_PICKLE_KEY',
]

connections = ConnectionHandler()

router = ConnectionRouter()

# For backwards compatibility. Prefer connections['default'] instead.
connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)


# Register an event to reset saved queries when a Django request is started.
def reset_queries(**kwargs):
    for conn in connections.all():
        conn.queries_log.clear()


signals.request_started.connect(reset_queries)


# Register an event to reset transaction state and close connections past
# their lifetime.
def close_old_connections(**kwargs):
    for conn in connections.all():
        conn.close_if_unusable_or_obsolete()


signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)
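For context, a minimal sketch of how this module is typically consumed from application code; it assumes a project whose settings define a 'default' database, and it is not part of the file above.

from django.db import connection, connections, reset_queries

# `connection` is the ConnectionProxy for connections[DEFAULT_DB_ALIAS].
with connection.cursor() as cursor:
    cursor.execute("SELECT 1")
    print(cursor.fetchone())

# With DEBUG=True, executed queries are collected per connection.
print(connections['default'].queries)
reset_queries()  # normally triggered automatically by request_started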
687
venv/Lib/site-packages/django/db/backends/base/base.py
Normal file
@@ -0,0 +1,687 @@
import _thread
import copy
import threading
import time
import warnings
from collections import deque
from contextlib import contextmanager

try:
    import zoneinfo
except ImportError:
    from backports import zoneinfo

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseErrorWrapper
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property

NO_DB_ALIAS = '__no_db__'


# RemovedInDjango50Warning
def timezone_constructor(tzname):
    if settings.USE_DEPRECATED_PYTZ:
        import pytz
        return pytz.timezone(tzname)
    return zoneinfo.ZoneInfo(tzname)


class BaseDatabaseWrapper:
    """Represent a database connection."""
    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = 'unknown'
    display_name = 'unknown'
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    queries_limit = 9000

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Django settings modules.
        self.settings_dict = settings_dict
        self.alias = alias
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # ie. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func) tuple, where sids is a set of the
        # active savepoint IDs when this function was registered.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Django uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Django uses may be constrained by the requirements of other users of
        the database.
        """
        if not settings.USE_TZ:
            return None
        elif self.settings_dict['TIME_ZONE'] is None:
            return timezone.utc
        else:
            return timezone_constructor(self.settings_dict['TIME_ZONE'])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection.
        """
        if not settings.USE_TZ:
            return settings.TIME_ZONE
        elif self.settings_dict['TIME_ZONE'] is None:
            return 'UTC'
        else:
            return self.settings_dict['TIME_ZONE']

    @property
    def queries_logged(self):
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                "Limit for query logging exceeded, only the last {} queries "
                "will be returned.".format(self.queries_log.maxlen))
        return list(self.queries_log)

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_connection_params() method')

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_new_connection() method')

    def init_connection_state(self):
        """Initialize the database connection settings."""
        raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method')

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method')

    # ##### Backend-specific methods for creating connections #####

    @async_unsafe
    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # Check for invalid configurations.
        self.check_settings()
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.needs_rollback = False
        # Reset parameters defining when to close the connection
        max_age = self.settings_dict['CONN_MAX_AGE']
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict['AUTOCOMMIT'])
        self.init_connection_state()
        connection_created.send(sender=self.__class__, connection=self)

        self.run_on_commit = []

    def check_settings(self):
        if self.settings_dict['TIME_ZONE'] is not None and not settings.USE_TZ:
            raise ImproperlyConfigured(
                "Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
                % self.alias
            )

    @async_unsafe
    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    @async_unsafe
    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    @async_unsafe
    def commit(self):
        """Commit a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    @async_unsafe
    def rollback(self):
        """Roll back a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        self.run_on_commit = []

    @async_unsafe
    def close(self):
        """Close the connection to the database."""
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    @async_unsafe
    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        tid = str(thread_ident).replace('-', '')

        self.savepoint_state += 1
        sid = "s%s_x%d" % (tid, self.savepoint_state)

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid

    @async_unsafe
    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func) for (sids, func) in self.run_on_commit if sid not in sids
        ]

    @async_unsafe
    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)

    @async_unsafe
    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a _set_autocommit() method')

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.
        """
        self.validate_no_atomic_block()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit and not autocommit and
            hasattr(self, '_start_transaction_under_autocommit')
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        else:
            self._set_autocommit(autocommit)

        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block.")
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block.")
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active.")

    def validate_no_broken_transaction(self):
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block.")

    # ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Disable foreign key constraint checking.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Django from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method")

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def inc_thread_sharing(self):
        with self._thread_sharing_lock:
            self._thread_sharing_count += 1

    def dec_thread_sharing(self):
        with self._thread_sharing_lock:
            if self._thread_sharing_count <= 0:
                raise RuntimeError('Cannot decrement the thread sharing count below zero.')
            self._thread_sharing_count -= 1

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `inc_thread_sharing()`
        method). Raise an exception if the validation fails.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '%s' was created in thread id %s and this is "
                "thread id %s."
                % (self.alias, self._thread_ident, _thread.get_ident())
            )

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Django's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        conn = self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                'The SchemaEditorClass attribute of this database wrapper is still None')
        return self.SchemaEditorClass(self, *args, **kwargs)

    def on_commit(self, func):
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func))
        elif not self.get_autocommit():
            raise TransactionManagementError('on_commit() cannot be used in manual transaction management')
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            func()

    def run_and_clear_commit_hooks(self):
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        while current_run_on_commit:
            sids, func = current_run_on_commit.pop(0)
            func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()

    def copy(self, alias=None):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        if alias is None:
            alias = self.alias
        return type(self)(settings_dict, alias)
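As a usage note for the execute_wrappers stack initialized above, here is a small sketch of the documented contract: each wrapper receives execute, sql, params, many, context and must call execute() itself. The log_sql function is illustrative only.

from django.db import connection

def log_sql(execute, sql, params, many, context):
    # `context` carries the connection and cursor; pass everything through.
    print('running:', sql)
    return execute(sql, params, many, context)

with connection.execute_wrapper(log_sql):
    with connection.cursor() as cursor:
        cursor.execute('SELECT 1')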
25
venv/Lib/site-packages/django/db/backends/base/client.py
Normal file
@@ -0,0 +1,25 @@
import os
import subprocess


class BaseDatabaseClient:
    """Encapsulate backend-specific methods for opening a client shell."""
    # This should be a string representing the name of the executable
    # (e.g., "psql"). Subclasses must override this.
    executable_name = None

    def __init__(self, connection):
        # connection is an instance of BaseDatabaseWrapper.
        self.connection = connection

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        raise NotImplementedError(
            'subclasses of BaseDatabaseClient must provide a '
            'settings_to_cmd_args_env() method or override a runshell().'
        )

    def runshell(self, parameters):
        args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters)
        env = {**os.environ, **env} if env else None
        subprocess.run(args, env=env, check=True)
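A hypothetical subclass sketch showing the contract runshell() relies on: settings_to_cmd_args_env() returns an (args, env) pair, and any returned env is merged over os.environ before subprocess.run(..., check=True). The 'exampledb' executable and EXAMPLEDB_PASSWORD variable are made up for the sketch.

from django.db.backends.base.client import BaseDatabaseClient

class ExampleDatabaseClient(BaseDatabaseClient):
    executable_name = 'exampledb'

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        # Build the command line from the connection settings plus any extra
        # arguments passed to the dbshell management command.
        args = [cls.executable_name, settings_dict['NAME'], *parameters]
        env = {'EXAMPLEDB_PASSWORD': settings_dict.get('PASSWORD', '')}
        return args, env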
342
venv/Lib/site-packages/django/db/backends/base/creation.py
Normal file
@@ -0,0 +1,342 @@
import os
import sys
from io import StringIO
from unittest import expectedFailure, skip

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
from django.db.transaction import atomic
from django.utils.module_loading import import_string

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'


class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """
    def __init__(self, connection):
        self.connection = connection

    def _nodb_cursor(self):
        return self.connection._nodb_cursor()

    def log(self, msg):
        sys.stderr.write(msg + os.linesep)

    def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False):
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = 'Creating'
            if keepdb:
                action = "Using existing"

            self.log('%s test database for alias %s...' % (
                action,
                self._get_database_display_str(verbosity, test_database_name),
            ))

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        try:
            if self.connection.settings_dict['TEST']['MIGRATE'] is False:
                # Disable migrations for all apps.
                old_migration_modules = settings.MIGRATION_MODULES
                settings.MIGRATION_MODULES = {
                    app.label: None
                    for app in apps.get_app_configs()
                }
            # We report migrate messages at one level lower than that
            # requested. This ensures we don't get flooded with messages during
            # testing (unless you really ask to be flooded).
            call_command(
                'migrate',
                verbosity=max(verbosity - 1, 0),
                interactive=False,
                database=self.connection.alias,
                run_syncdb=True,
            )
        finally:
            if self.connection.settings_dict['TEST']['MIGRATE'] is False:
                settings.MIGRATION_MODULES = old_migration_modules

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so people
        # who are testing on databases without transactions or who are using
        # a TransactionTestCase still get a clean database on every test run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        call_command('createcachetable', database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        if os.environ.get('RUNNING_DJANGOS_TEST_SUITE') == 'true':
            self.mark_expected_failures_and_skips()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given.
        """
        self.connection.settings_dict['NAME'] = primary_settings_dict['NAME']

    def serialize_db_to_string(self):
        """
        Serialize all data in the database into a JSON string.
        Designed only for test runner usage; will not handle large
        amounts of data.
        """
        # Iteratively return every object for all models to serialize.
        def get_objects():
            from django.db.migrations.loader import MigrationLoader
            loader = MigrationLoader(self.connection)
            for app_config in apps.get_app_configs():
                if (
                    app_config.models_module is not None and
                    app_config.label in loader.migrated_apps and
                    app_config.name not in settings.TEST_NON_SERIALIZED_APPS
                ):
                    for model in app_config.get_models():
                        if (
                            model._meta.can_migrate(self.connection) and
                            router.allow_migrate_model(self.connection.alias, model)
                        ):
                            queryset = model._base_manager.using(
                                self.connection.alias,
                            ).order_by(model._meta.pk.name)
                            yield from queryset.iterator()
        # Serialize to a string
        out = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=out)
        return out.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Reload the database with data from a string generated by
        the serialize_db_to_string() method.
        """
        data = StringIO(data)
        table_names = set()
        # Load data in a transaction to handle forward references and cycles.
        with atomic(using=self.connection.alias):
            # Disable constraint checks, because some databases (MySQL) doesn't
            # support deferred checks.
            with self.connection.constraint_checks_disabled():
                for obj in serializers.deserialize('json', data, using=self.connection.alias):
                    obj.save()
                    table_names.add(obj.object.__class__._meta.db_table)
            # Manually check for any invalid keys that might have been added,
            # because constraint checks were disabled.
            self.connection.check_constraints(table_names=table_names)

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.
        """
        return "'%s'%s" % (
            self.connection.alias,
            (" ('%s')" % database_name) if verbosity >= 2 else '',
        )

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.
        """
        if self.connection.settings_dict['TEST']['NAME']:
            return self.connection.settings_dict['TEST']['NAME']
        return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        cursor.execute('CREATE DATABASE %(dbname)s %(suffix)s' % parameters)

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test db tables.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            'dbname': self.connection.ops.quote_name(test_database_name),
            'suffix': self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as e:
                # if we want to keep the db, then no need to do any of the below,
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                self.log('Got an error creating the test database: %s' % e)
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name)
                if autoclobber or confirm == 'yes':
                    try:
                        if verbosity >= 1:
                            self.log('Destroying old test database for alias %s...' % (
                                self._get_database_display_str(verbosity, test_database_name),
                            ))
                        cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
                        self._execute_create_test_db(cursor, test_db_params, keepdb)
                    except Exception as e:
                        self.log('Got an error recreating the test database: %s' % e)
                        sys.exit(2)
                else:
                    self.log('Tests cancelled.')
                    sys.exit(1)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.
        """
        source_database_name = self.connection.settings_dict['NAME']

        if verbosity >= 1:
            action = 'Cloning test database'
            if keepdb:
                action = 'Using existing clone'
            self.log('%s for alias %s...' % (
                action,
                self._get_database_display_str(verbosity, source_database_name),
            ))

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {**orig_settings_dict, 'NAME': '{}_{}'.format(orig_settings_dict['NAME'], suffix)}

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes.")

    def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, suffix=None):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict['NAME']
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)['NAME']

        if verbosity >= 1:
            action = 'Destroying'
            if keepdb:
                action = 'Preserving'
            self.log('%s test database for alias %s...' % (
                action,
                self._get_database_display_str(verbosity, test_database_name),
            ))

        # if we want to preserve the database
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - remove the test db tables.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute("DROP DATABASE %s"
                           % self.connection.ops.quote_name(test_database_name))

    def mark_expected_failures_and_skips(self):
        """
        Mark tests in Django's test suite which are expected failures on this
        database and test which should be skipped on this database.
        """
        for test_name in self.connection.features.django_test_expected_failures:
            test_case_name, _, test_method_name = test_name.rpartition('.')
            test_app = test_name.split('.')[0]
            # Importing a test app that isn't installed raises RuntimeError.
            if test_app in settings.INSTALLED_APPS:
                test_case = import_string(test_case_name)
                test_method = getattr(test_case, test_method_name)
                setattr(test_case, test_method_name, expectedFailure(test_method))
        for reason, tests in self.connection.features.django_test_skips.items():
            for test_name in tests:
                test_case_name, _, test_method_name = test_name.rpartition('.')
                test_app = test_name.split('.')[0]
                # Importing a test app that isn't installed raises RuntimeError.
                if test_app in settings.INSTALLED_APPS:
                    test_case = import_string(test_case_name)
                    test_method = getattr(test_case, test_method_name)
                    setattr(test_case, test_method_name, skip(reason)(test_method))

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the test table creation statements.
        """
        return ''

    def test_db_signature(self):
        """
        Return a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        accordingly to the RDBMS particularities.
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict['HOST'],
            settings_dict['PORT'],
            settings_dict['ENGINE'],
            self._get_test_db_name(),
        )
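The TEST sub-dictionary consulted by _get_test_db_name() and create_test_db() above comes from the project's DATABASES setting; a sketch with illustrative values:

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql',
        'NAME': 'app',
        'TEST': {
            'NAME': 'test_app_custom',  # otherwise TEST_DATABASE_PREFIX + NAME is used
            'MIGRATE': False,  # skip migrations; tables are created via run_syncdb
        },
    },
}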
367
venv/Lib/site-packages/django/db/backends/base/features.py
Normal file
@@ -0,0 +1,367 @@
|
||||
from django.db import ProgrammingError
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class BaseDatabaseFeatures:
|
||||
gis_enabled = False
|
||||
# Oracle can't group by LOB (large object) data types.
|
||||
allows_group_by_lob = True
|
||||
allows_group_by_pk = False
|
||||
allows_group_by_selected_pks = False
|
||||
empty_fetchmany_value = []
|
||||
update_can_self_select = True
|
||||
|
||||
# Does the backend distinguish between '' and None?
|
||||
interprets_empty_strings_as_nulls = False
|
||||
|
||||
# Does the backend allow inserting duplicate NULL rows in a nullable
|
||||
# unique field? All core backends implement this correctly, but other
|
||||
# databases such as SQL Server do not.
|
||||
supports_nullable_unique_constraints = True
|
||||
|
||||
# Does the backend allow inserting duplicate rows when a unique_together
|
||||
# constraint exists and some fields are nullable but not all of them?
|
||||
supports_partially_nullable_unique_constraints = True
|
||||
# Does the backend support initially deferrable unique constraints?
|
||||
supports_deferrable_unique_constraints = False
|
||||
|
||||
can_use_chunked_reads = True
|
||||
can_return_columns_from_insert = False
|
||||
can_return_rows_from_bulk_insert = False
|
||||
has_bulk_insert = True
|
||||
uses_savepoints = True
|
||||
can_release_savepoints = False
|
||||
|
||||
# If True, don't use integer foreign keys referring to, e.g., positive
|
||||
# integer primary keys.
|
||||
related_fields_match_type = False
|
||||
allow_sliced_subqueries_with_in = True
|
||||
has_select_for_update = False
|
||||
has_select_for_update_nowait = False
|
||||
has_select_for_update_skip_locked = False
|
||||
has_select_for_update_of = False
|
||||
has_select_for_no_key_update = False
|
||||
# Does the database's SELECT FOR UPDATE OF syntax require a column rather
|
||||
# than a table?
|
||||
select_for_update_of_column = False
|
||||
|
||||
# Does the default test database allow multiple connections?
|
||||
# Usually an indication that the test database is in-memory
|
||||
test_db_allows_multiple_connections = True
|
||||
|
||||
# Can an object be saved without an explicit primary key?
|
||||
supports_unspecified_pk = False
|
||||
|
||||
# Can a fixture contain forward references? i.e., are
|
||||
# FK constraints checked at the end of transaction, or
|
||||
# at the end of each save operation?
|
||||
supports_forward_references = True
|
||||
|
||||
# Does the backend truncate names properly when they are too long?
|
||||
truncates_names = False
|
||||
|
||||
# Is there a REAL datatype in addition to floats/doubles?
|
||||
has_real_datatype = False
|
||||
supports_subqueries_in_group_by = True
|
||||
|
||||
# Does the backend ignore unnecessary ORDER BY clauses in subqueries?
|
||||
ignores_unnecessary_order_by_in_subqueries = True
|
||||
|
||||
# Is there a true datatype for uuid?
|
||||
has_native_uuid_field = False
|
||||
|
||||
# Is there a true datatype for timedeltas?
|
||||
has_native_duration_field = False
|
||||
|
||||
# Does the database driver supports same type temporal data subtraction
|
||||
# by returning the type used to store duration field?
|
||||
supports_temporal_subtraction = False
|
||||
|
||||
# Does the __regex lookup support backreferencing and grouping?
|
||||
supports_regex_backreferencing = True
|
||||
|
||||
# Can date/datetime lookups be performed using a string?
|
||||
supports_date_lookup_using_string = True
|
||||
|
||||
# Can datetimes with timezones be used?
|
||||
supports_timezones = True
|
||||
|
||||
# Does the database have a copy of the zoneinfo database?
|
||||
has_zoneinfo_database = True
|
||||
|
||||
# When performing a GROUP BY, is an ORDER BY NULL required
|
||||
# to remove any ordering?
|
||||
requires_explicit_null_ordering_when_grouping = False
|
||||
|
||||
# Does the backend order NULL values as largest or smallest?
|
||||
nulls_order_largest = False
|
||||
|
||||
# Does the backend support NULLS FIRST and NULLS LAST in ORDER BY?
|
||||
supports_order_by_nulls_modifier = True
|
||||
|
||||
# Does the backend orders NULLS FIRST by default?
|
||||
order_by_nulls_first = False
|
||||
|
||||
# The database's limit on the number of query parameters.
|
||||
max_query_params = None
|
||||
|
||||
# Can an object have an autoincrement primary key of 0?
|
||||
allows_auto_pk_0 = True
|
||||
|
||||
# Do we need to NULL a ForeignKey out, or can the constraint check be
|
||||
# deferred
|
||||
can_defer_constraint_checks = False
|
||||
|
||||
# date_interval_sql can properly handle mixed Date/DateTime fields and timedeltas
|
||||
supports_mixed_date_datetime_comparisons = True
|
||||
|
||||
# Does the backend support tablespaces? Default to False because it isn't
|
||||
# in the SQL standard.
|
||||
supports_tablespaces = False
|
||||
|
||||
# Does the backend reset sequences between tests?
|
||||
supports_sequence_reset = True
|
||||
|
||||
# Can the backend introspect the default value of a column?
|
||||
can_introspect_default = True
|
||||
|
||||
# Confirm support for introspected foreign keys
|
||||
# Every database can do this reliably, except MySQL,
|
||||
# which can't do it for MyISAM tables
|
||||
can_introspect_foreign_keys = True
|
||||
|
||||
# Map fields which some backends may not be able to differentiate to the
|
||||
# field it's introspected as.
|
||||
introspected_field_types = {
|
||||
'AutoField': 'AutoField',
|
||||
'BigAutoField': 'BigAutoField',
|
||||
'BigIntegerField': 'BigIntegerField',
|
||||
'BinaryField': 'BinaryField',
|
||||
'BooleanField': 'BooleanField',
|
||||
'CharField': 'CharField',
|
||||
'DurationField': 'DurationField',
|
||||
'GenericIPAddressField': 'GenericIPAddressField',
|
||||
'IntegerField': 'IntegerField',
|
||||
'PositiveBigIntegerField': 'PositiveBigIntegerField',
|
||||
'PositiveIntegerField': 'PositiveIntegerField',
|
||||
'PositiveSmallIntegerField': 'PositiveSmallIntegerField',
|
||||
'SmallAutoField': 'SmallAutoField',
|
||||
'SmallIntegerField': 'SmallIntegerField',
|
||||
'TimeField': 'TimeField',
|
||||
}
|
||||
|
||||
# Can the backend introspect the column order (ASC/DESC) for indexes?
|
||||
supports_index_column_ordering = True
|
||||
|
||||
# Does the backend support introspection of materialized views?
|
||||
can_introspect_materialized_views = False
|
||||
|
||||
# Support for the DISTINCT ON clause
|
||||
can_distinct_on_fields = False
|
||||
|
||||
# Does the backend prevent running SQL queries in broken transactions?
|
||||
atomic_transactions = True
|
||||
|
||||
# Can we roll back DDL in a transaction?
|
||||
can_rollback_ddl = False
|
||||
|
||||
# Does it support operations requiring references rename in a transaction?
|
||||
supports_atomic_references_rename = True
|
||||
|
||||
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
|
||||
supports_combined_alters = False
|
||||
|
||||
# Does it support foreign keys?
|
||||
supports_foreign_keys = True
|
||||
|
||||
# Can it create foreign key constraints inline when adding columns?
|
||||
can_create_inline_fk = True
|
||||
|
||||
# Does it automatically index foreign keys?
|
||||
indexes_foreign_keys = True
|
||||
|
||||
# Does it support CHECK constraints?
|
||||
supports_column_check_constraints = True
|
||||
supports_table_check_constraints = True
|
||||
# Does the backend support introspection of CHECK constraints?
|
||||
can_introspect_check_constraints = True
|
||||
|
||||
# Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
|
||||
# parameter passing? Note this can be provided by the backend even if not
|
||||
# supported by the Python driver
|
||||
supports_paramstyle_pyformat = True
|
||||
|
||||
# Does the backend require literal defaults, rather than parameterized ones?
|
||||
requires_literal_defaults = False
|
||||
|
||||
# Does the backend require a connection reset after each material schema change?
|
||||
connection_persists_old_columns = False
|
||||
|
||||
# What kind of error does the backend throw when accessing closed cursor?
|
||||
closed_cursor_error_class = ProgrammingError
|
||||
|
||||
# Does 'a' LIKE 'A' match?
|
||||
has_case_insensitive_like = True
|
||||
|
||||
# Suffix for backends that don't support "SELECT xxx;" queries.
|
||||
bare_select_suffix = ''
|
||||
|
||||
# If NULL is implied on columns without needing to be explicitly specified
|
||||
implied_column_null = False
|
||||
|
||||
# Does the backend support "select for update" queries with limit (and offset)?
|
||||
supports_select_for_update_with_limit = True
|
||||
|
||||
# Does the backend ignore null expressions in GREATEST and LEAST queries unless
|
||||
# every expression is null?
|
||||
greatest_least_ignores_nulls = False
|
||||
|
||||
# Can the backend clone databases for parallel test execution?
|
||||
# Defaults to False to allow third-party backends to opt-in.
|
||||
can_clone_databases = False
|
||||
|
||||
# Does the backend consider table names with different casing to
|
||||
# be equal?
|
||||
ignores_table_name_case = False
|
||||
|
||||
# Place FOR UPDATE right after FROM clause. Used on MSSQL.
|
||||
for_update_after_from = False
|
||||
|
||||
# Combinatorial flags
|
||||
supports_select_union = True
|
||||
supports_select_intersection = True
|
||||
supports_select_difference = True
|
||||
supports_slicing_ordering_in_compound = False
|
||||
supports_parentheses_in_compound = True
|
||||
|
||||
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
|
||||
# expressions?
|
||||
supports_aggregate_filter_clause = False
|
||||
|
||||
# Does the backend support indexing a TextField?
|
||||
supports_index_on_text_field = True
|
||||
|
||||
# Does the backend support window expressions (expression OVER (...))?
|
||||
supports_over_clause = False
|
||||
supports_frame_range_fixed_distance = False
|
||||
only_supports_unbounded_with_preceding_and_following = False
|
||||
|
||||
# Does the backend support CAST with precision?
|
||||
supports_cast_with_precision = True
|
||||
|
||||
# How many second decimals does the database return when casting a value to
|
||||
# a type with time?
|
||||
time_cast_precision = 6
|
||||
|
||||
# SQL to create a procedure for use by the Django test suite. The
|
||||
# functionality of the procedure isn't important.
|
||||
create_test_procedure_without_params_sql = None
|
||||
create_test_procedure_with_int_param_sql = None
|
||||
|
||||
# Does the backend support keyword parameters for cursor.callproc()?
|
||||
supports_callproc_kwargs = False
|
||||
|
||||
# What formats does the backend EXPLAIN syntax support?
|
||||
supported_explain_formats = set()
|
||||
|
||||
# Does DatabaseOperations.explain_query_prefix() raise ValueError if
|
||||
# unknown kwargs are passed to QuerySet.explain()?
|
||||
validates_explain_options = True
|
||||
|
||||
# Does the backend support the default parameter in lead() and lag()?
|
||||
supports_default_in_lead_lag = True
|
||||
|
||||
    # Does the backend support ignoring constraint or uniqueness errors during
    # INSERT?
    supports_ignore_conflicts = True

    # Does this backend require casting the results of CASE expressions used
    # in UPDATE statements to ensure the expression has the correct type?
    requires_casted_case_in_updates = False

    # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
    supports_partial_indexes = True
    supports_functions_in_partial_indexes = True
    # Does the backend support covering indexes (CREATE INDEX ... INCLUDE ...)?
    supports_covering_indexes = False
    # Does the backend support indexes on expressions?
    supports_expression_indexes = True
    # Does the backend treat COLLATE as an indexed expression?
    collate_as_index_expression = False

    # Does the database allow more than one constraint or index on the same
    # field(s)?
    allows_multiple_constraints_on_same_fields = True

    # Does the backend support boolean expressions in SELECT and GROUP BY
    # clauses?
    supports_boolean_expr_in_select_clause = True

    # Does the backend support JSONField?
    supports_json_field = True
    # Can the backend introspect a JSONField?
    can_introspect_json_field = True
    # Does the backend support primitives in JSONField?
    supports_primitives_in_json_field = True
    # Is there a true datatype for JSON?
    has_native_json_field = False
    # Does the backend use PostgreSQL-style JSON operators like '->'?
    has_json_operators = False
    # Does the backend support __contains and __contained_by lookups for
    # a JSONField?
    supports_json_field_contains = True
    # Does value__d__contains={'f': 'g'} (without a list around the dict) match
    # {'d': [{'f': 'g'}]}?
    json_key_contains_list_matching_requires_list = False
    # Does the backend support JSONObject() database function?
    has_json_object_function = True

    # Does the backend support column collations?
    supports_collation_on_charfield = True
    supports_collation_on_textfield = True
    # Does the backend support non-deterministic collations?
    supports_non_deterministic_collations = True

    # Collation names for use by the Django test suite.
    test_collations = {
        'ci': None,  # Case-insensitive.
        'cs': None,  # Case-sensitive.
        'non_default': None,  # Non-default.
        'swedish_ci': None,  # Swedish case-insensitive.
    }
    # SQL template override for tests.aggregation.tests.NowUTC
    test_now_utc_template = None

    # A set of dotted paths to tests in Django's test suite that are expected
    # to fail on this database.
    django_test_expected_failures = set()
    # A map of reasons to sets of dotted paths to tests in Django's test suite
    # that should be skipped for this database.
    django_test_skips = {}

    def __init__(self, connection):
        self.connection = connection

    @cached_property
    def supports_explaining_query_execution(self):
        """Does this backend support explaining query execution?"""
        return self.connection.ops.explain_prefix is not None

    @cached_property
    def supports_transactions(self):
        """Confirm support for transactions."""
        with self.connection.cursor() as cursor:
            cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
            self.connection.set_autocommit(False)
            cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
            self.connection.rollback()
            self.connection.set_autocommit(True)
            cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
            count, = cursor.fetchone()
            cursor.execute('DROP TABLE ROLLBACK_TEST')
        return count == 0

    def allows_group_by_selected_pks_on_model(self, model):
        if not self.allows_group_by_selected_pks:
            return False
        return model._meta.managed
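# --- Example (not part of this file): a hedged sketch of how these flags are
# consumed. The ORM and schema editor read them off connection.features, so a
# hypothetical "frobdb" third-party backend would subclass BaseDatabaseFeatures
# and override only the flags that differ from the defaults above.
from django.db.backends.base.features import BaseDatabaseFeatures


class FrobDatabaseFeatures(BaseDatabaseFeatures):
    # No native JSON column type; JSONField is stored as text and
    # containment lookups are unavailable.
    supports_json_field = True
    has_native_json_field = False
    supports_json_field_contains = False
    # Partial and covering indexes are not available.
    supports_partial_indexes = False
    supports_covering_indexes = False
    # Collations can only be declared on varchar columns.
    supports_collation_on_charfield = True
    supports_collation_on_textfield = False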
194
venv/Lib/site-packages/django/db/backends/base/introspection.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from collections import namedtuple
|
||||
|
||||
# Structure returned by DatabaseIntrospection.get_table_list()
|
||||
TableInfo = namedtuple('TableInfo', ['name', 'type'])
|
||||
|
||||
# Structure returned by the DB-API cursor.description interface (PEP 249)
|
||||
FieldInfo = namedtuple(
|
||||
'FieldInfo',
|
||||
'name type_code display_size internal_size precision scale null_ok '
|
||||
'default collation'
|
||||
)
|
||||
|
||||
|
||||
class BaseDatabaseIntrospection:
|
||||
"""Encapsulate backend-specific introspection utilities."""
|
||||
data_types_reverse = {}
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
"""
|
||||
Hook for a database backend to use the cursor description to
|
||||
match a Django field type to a database column.
|
||||
|
||||
For Oracle, the column data_type on its own is insufficient to
|
||||
distinguish between a FloatField and IntegerField, for example.
|
||||
"""
|
||||
return self.data_types_reverse[data_type]
|
||||
|
||||
def identifier_converter(self, name):
|
||||
"""
|
||||
Apply a conversion to the identifier for the purposes of comparison.
|
||||
|
||||
The default identifier converter is for case sensitive comparison.
|
||||
"""
|
||||
return name
|
||||
|
||||
def table_names(self, cursor=None, include_views=False):
|
||||
"""
|
||||
Return a list of names of all tables that exist in the database.
|
||||
Sort the returned table list by Python's default sorting. Do NOT use
|
||||
the database's ORDER BY here to avoid subtle differences in sorting
|
||||
order between databases.
|
||||
"""
|
||||
def get_names(cursor):
|
||||
return sorted(ti.name for ti in self.get_table_list(cursor)
|
||||
if include_views or ti.type == 't')
|
||||
if cursor is None:
|
||||
with self.connection.cursor() as cursor:
|
||||
return get_names(cursor)
|
||||
return get_names(cursor)
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""
|
||||
Return an unsorted list of TableInfo named tuples of all tables and
|
||||
views that exist in the database.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method')
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'subclasses of BaseDatabaseIntrospection may require a '
|
||||
'get_table_description() method.'
|
||||
)
|
||||
|
||||
def get_migratable_models(self):
|
||||
from django.apps import apps
|
||||
from django.db import router
|
||||
return (
|
||||
model
|
||||
for app_config in apps.get_app_configs()
|
||||
for model in router.get_migratable_models(app_config, self.connection.alias)
|
||||
if model._meta.can_migrate(self.connection)
|
||||
)
|
||||
|
||||
def django_table_names(self, only_existing=False, include_views=True):
|
||||
"""
|
||||
Return a list of all table names that have associated Django models and
|
||||
are in INSTALLED_APPS.
|
||||
|
||||
If only_existing is True, include only the tables in the database.
|
||||
"""
|
||||
tables = set()
|
||||
for model in self.get_migratable_models():
|
||||
if not model._meta.managed:
|
||||
continue
|
||||
tables.add(model._meta.db_table)
|
||||
tables.update(
|
||||
f.m2m_db_table() for f in model._meta.local_many_to_many
|
||||
if f.remote_field.through._meta.managed
|
||||
)
|
||||
tables = list(tables)
|
||||
if only_existing:
|
||||
existing_tables = set(self.table_names(include_views=include_views))
|
||||
tables = [
|
||||
t
|
||||
for t in tables
|
||||
if self.identifier_converter(t) in existing_tables
|
||||
]
|
||||
return tables
|
||||
|
||||
def installed_models(self, tables):
|
||||
"""
|
||||
Return a set of all models represented by the provided list of table
|
||||
names.
|
||||
"""
|
||||
tables = set(map(self.identifier_converter, tables))
|
||||
return {
|
||||
m for m in self.get_migratable_models()
|
||||
if self.identifier_converter(m._meta.db_table) in tables
|
||||
}
|
||||
|
||||
def sequence_list(self):
|
||||
"""
|
||||
Return a list of information about all DB sequences for all models in
|
||||
all apps.
|
||||
"""
|
||||
sequence_list = []
|
||||
with self.connection.cursor() as cursor:
|
||||
for model in self.get_migratable_models():
|
||||
if not model._meta.managed:
|
||||
continue
|
||||
if model._meta.swapped:
|
||||
continue
|
||||
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
|
||||
for f in model._meta.local_many_to_many:
|
||||
# If this is an m2m using an intermediate table,
|
||||
# we don't need to reset the sequence.
|
||||
if f.remote_field.through._meta.auto_created:
|
||||
sequence = self.get_sequences(cursor, f.m2m_db_table())
|
||||
sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
|
||||
return sequence_list
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
"""
|
||||
Return a list of introspected sequences for table_name. Each sequence
|
||||
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
|
||||
'name' key can be added if the backend supports named sequences.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of
|
||||
{field_name: (field_name_other_table, other_table)} representing all
|
||||
relationships to the given table.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'subclasses of BaseDatabaseIntrospection may require a '
|
||||
'get_relations() method.'
|
||||
)
|
||||
|
||||
def get_key_columns(self, cursor, table_name):
|
||||
"""
|
||||
Backends can override this to return a list of:
|
||||
(column_name, referenced_table_name, referenced_column_name)
|
||||
for all key columns in given table.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_key_columns() method')
|
||||
|
||||
def get_primary_key_column(self, cursor, table_name):
|
||||
"""
|
||||
Return the name of the primary key column for the given table.
|
||||
"""
|
||||
for constraint in self.get_constraints(cursor, table_name).values():
|
||||
if constraint['primary_key']:
|
||||
return constraint['columns'][0]
|
||||
return None
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index)
|
||||
across one or more columns.
|
||||
|
||||
Return a dict mapping constraint names to their attributes,
|
||||
where attributes is a dict with keys:
|
||||
* columns: List of columns this covers
|
||||
* primary_key: True if primary key, False otherwise
|
||||
* unique: True if this is a unique constraint, False otherwise
|
||||
* foreign_key: (table, column) of target, or None
|
||||
* check: True if check constraint, False otherwise
|
||||
* index: True if index, False otherwise.
|
||||
* orders: The order (ASC/DESC) defined for the columns of indexes
|
||||
* type: The type of the index (btree, hash, etc.)
|
||||
|
||||
Some backends may return special constraint names that don't exist
|
||||
if they don't name constraints of a certain type (e.g. SQLite)
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_constraints() method')
|
||||
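# --- Example (not part of this file): rough usage of the introspection API on
# an open connection; "myapp_book" is an illustrative table name, not anything
# defined above.
from django.db import connection

with connection.cursor() as cursor:
    tables = connection.introspection.table_names(cursor)  # sorted table names
    if 'myapp_book' in tables:
        # FieldInfo named tuples following the PEP 249 cursor.description layout.
        for column in connection.introspection.get_table_description(cursor, 'myapp_book'):
            print(column.name, column.type_code, column.null_ok)
        pk_column = connection.introspection.get_primary_key_column(cursor, 'myapp_book')
        constraints = connection.introspection.get_constraints(cursor, 'myapp_book')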
709
venv/Lib/site-packages/django/db/backends/base/operations.py
Normal file
@@ -0,0 +1,709 @@
|
||||
import datetime
|
||||
import decimal
|
||||
from importlib import import_module
|
||||
|
||||
import sqlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import NotSupportedError, transaction
|
||||
from django.db.backends import utils
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
|
||||
class BaseDatabaseOperations:
|
||||
"""
|
||||
Encapsulate backend-specific differences, such as the way a backend
|
||||
performs ordering or calculates the ID of a recently-inserted row.
|
||||
"""
|
||||
compiler_module = "django.db.models.sql.compiler"
|
||||
|
||||
# Integer field safe ranges by `internal_type` as documented
|
||||
# in docs/ref/models/fields.txt.
|
||||
integer_field_ranges = {
|
||||
'SmallIntegerField': (-32768, 32767),
|
||||
'IntegerField': (-2147483648, 2147483647),
|
||||
'BigIntegerField': (-9223372036854775808, 9223372036854775807),
|
||||
'PositiveBigIntegerField': (0, 9223372036854775807),
|
||||
'PositiveSmallIntegerField': (0, 32767),
|
||||
'PositiveIntegerField': (0, 2147483647),
|
||||
'SmallAutoField': (-32768, 32767),
|
||||
'AutoField': (-2147483648, 2147483647),
|
||||
'BigAutoField': (-9223372036854775808, 9223372036854775807),
|
||||
}
|
||||
set_operators = {
|
||||
'union': 'UNION',
|
||||
'intersection': 'INTERSECT',
|
||||
'difference': 'EXCEPT',
|
||||
}
|
||||
# Mapping of Field.get_internal_type() (typically the model field's class
|
||||
# name) to the data type to use for the Cast() function, if different from
|
||||
# DatabaseWrapper.data_types.
|
||||
cast_data_types = {}
|
||||
# CharField data type if the max_length argument isn't provided.
|
||||
cast_char_field_without_max_length = None
|
||||
|
||||
# Start and end points for window expressions.
|
||||
PRECEDING = 'PRECEDING'
|
||||
FOLLOWING = 'FOLLOWING'
|
||||
UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
|
||||
UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
|
||||
CURRENT_ROW = 'CURRENT ROW'
|
||||
|
||||
# Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported.
|
||||
explain_prefix = None
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self._cache = None
|
||||
|
||||
def autoinc_sql(self, table, column):
|
||||
"""
|
||||
Return any SQL needed to support auto-incrementing primary keys, or
|
||||
None if no SQL is necessary.
|
||||
|
||||
This SQL is executed when a table is created.
|
||||
"""
|
||||
return None
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""
|
||||
Return the maximum allowed batch size for the backend. The fields
|
||||
are the fields going to be inserted in the batch, the objs contains
|
||||
all the objects to be inserted.
|
||||
"""
|
||||
return len(objs)
|
||||
|
||||
def cache_key_culling_sql(self):
|
||||
"""
|
||||
Return an SQL query that retrieves the first cache key greater than the
|
||||
n smallest.
|
||||
|
||||
This is used by the 'db' cache backend to determine where to start
|
||||
culling.
|
||||
"""
|
||||
return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s"
|
||||
|
||||
def unification_cast_sql(self, output_field):
|
||||
"""
|
||||
Given a field instance, return the SQL that casts the result of a union
|
||||
to that type. The resulting string should contain a '%s' placeholder
|
||||
for the expression being cast.
|
||||
"""
|
||||
return '%s'
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
extracts a value from the given date field field_name.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
truncates the given date or datetime field field_name to a date object
|
||||
with only the given specificity.
|
||||
|
||||
If `tzname` is provided, the given value is truncated in a specific
|
||||
timezone.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
"""
|
||||
Return the SQL to cast a datetime value to date value.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'subclasses of BaseDatabaseOperations may require a '
|
||||
'datetime_cast_date_sql() method.'
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
"""
|
||||
Return the SQL to cast a datetime value to time value.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method')
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
|
||||
'second', return the SQL that extracts a value from the given
|
||||
datetime field field_name.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method')
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
|
||||
'second', return the SQL that truncates the given datetime field
|
||||
field_name to a datetime object with only the given specificity.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
|
||||
that truncates the given time or datetime field field_name to a time
|
||||
object with only the given specificity.
|
||||
|
||||
If `tzname` is provided, the given value is truncated in a specific
|
||||
timezone.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
|
||||
that extracts a value from the given time field field_name.
|
||||
"""
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
|
||||
def deferrable_sql(self):
|
||||
"""
|
||||
Return the SQL to make a constraint "initially deferred" during a
|
||||
CREATE TABLE statement.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def distinct_sql(self, fields, params):
|
||||
"""
|
||||
Return an SQL DISTINCT clause which removes duplicate rows from the
|
||||
result set. If any fields are given, only check the given fields for
|
||||
duplicates.
|
||||
"""
|
||||
if fields:
|
||||
raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
|
||||
else:
|
||||
return ['DISTINCT'], []
|
||||
|
||||
def fetch_returned_insert_columns(self, cursor, returning_params):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the newly created data.
|
||||
"""
|
||||
return cursor.fetchone()
|
||||
|
||||
def field_cast_sql(self, db_type, internal_type):
|
||||
"""
|
||||
Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
|
||||
(e.g. 'GenericIPAddressField'), return the SQL to cast it before using
|
||||
it in a WHERE statement. The resulting string should contain a '%s'
|
||||
placeholder for the column being searched against.
|
||||
"""
|
||||
return '%s'
|
||||
|
||||
def force_no_ordering(self):
|
||||
"""
|
||||
Return a list used in the "ORDER BY" clause to force no ordering at
|
||||
all. Return an empty list to include nothing in the ordering.
|
||||
"""
|
||||
return []
|
||||
|
||||
def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
|
||||
"""
|
||||
Return the FOR UPDATE SQL clause to lock rows for an update operation.
|
||||
"""
|
||||
return 'FOR%s UPDATE%s%s%s' % (
|
||||
' NO KEY' if no_key else '',
|
||||
' OF %s' % ', '.join(of) if of else '',
|
||||
' NOWAIT' if nowait else '',
|
||||
' SKIP LOCKED' if skip_locked else '',
|
||||
)
|
||||
|
||||
def _get_limit_offset_params(self, low_mark, high_mark):
|
||||
offset = low_mark or 0
|
||||
if high_mark is not None:
|
||||
return (high_mark - offset), offset
|
||||
elif offset:
|
||||
return self.connection.ops.no_limit_value(), offset
|
||||
return None, offset
|
||||
|
||||
def limit_offset_sql(self, low_mark, high_mark):
|
||||
"""Return LIMIT/OFFSET SQL clause."""
|
||||
limit, offset = self._get_limit_offset_params(low_mark, high_mark)
|
||||
return ' '.join(sql for sql in (
|
||||
('LIMIT %d' % limit) if limit else None,
|
||||
('OFFSET %d' % offset) if offset else None,
|
||||
) if sql)
|
||||
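# --- Example (not part of this file): how QuerySet slice bounds flow through
# the two helpers above on a configured backend; "Book" is an illustrative model.
from django.db import connection

ops = connection.ops  # a concrete BaseDatabaseOperations subclass
ops.limit_offset_sql(0, 10)    # Book.objects.all()[:10]  -> 'LIMIT 10'
ops.limit_offset_sql(5, 15)    # Book.objects.all()[5:15] -> 'LIMIT 10 OFFSET 5'
ops.limit_offset_sql(5, None)  # Book.objects.all()[5:]   -> 'LIMIT <no_limit_value()> OFFSET 5'
                               # (18446744073709551615 on MySQL, for instance)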
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
"""
|
||||
Return a string of the query last executed by the given cursor, with
|
||||
placeholders replaced with actual values.
|
||||
|
||||
`sql` is the raw query containing placeholders and `params` is the
|
||||
sequence of parameters. These are used by default, but this method
|
||||
exists for database backends to provide a better implementation
|
||||
according to their own quoting schemes.
|
||||
"""
|
||||
# Convert params to contain string values.
|
||||
def to_string(s):
|
||||
return force_str(s, strings_only=True, errors='replace')
|
||||
if isinstance(params, (list, tuple)):
|
||||
u_params = tuple(to_string(val) for val in params)
|
||||
elif params is None:
|
||||
u_params = ()
|
||||
else:
|
||||
u_params = {to_string(k): to_string(v) for k, v in params.items()}
|
||||
|
||||
return "QUERY = %r - PARAMS = %r" % (sql, u_params)
|
||||
|
||||
def last_insert_id(self, cursor, table_name, pk_name):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT statement into
|
||||
a table that has an auto-incrementing ID, return the newly created ID.
|
||||
|
||||
`pk_name` is the name of the primary-key column.
|
||||
"""
|
||||
return cursor.lastrowid
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
"""
|
||||
Return the string to use in a query when performing lookups
|
||||
("contains", "like", etc.). It should contain a '%s' placeholder for
|
||||
the column being searched against.
|
||||
"""
|
||||
return "%s"
|
||||
|
||||
def max_in_list_size(self):
|
||||
"""
|
||||
Return the maximum number of items that can be passed in a single 'IN'
|
||||
list condition, or None if the backend does not impose a limit.
|
||||
"""
|
||||
return None
|
||||
|
||||
def max_name_length(self):
|
||||
"""
|
||||
Return the maximum length of table and column names, or None if there
|
||||
is no limit.
|
||||
"""
|
||||
return None
|
||||
|
||||
def no_limit_value(self):
|
||||
"""
|
||||
Return the value to use for the LIMIT when we want "LIMIT
|
||||
infinity". Return None if the limit clause can be omitted in this case.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a no_limit_value() method')
|
||||
|
||||
def pk_default_value(self):
|
||||
"""
|
||||
Return the value to use during an INSERT statement to specify that
|
||||
the field should use its default value.
|
||||
"""
|
||||
return 'DEFAULT'
|
||||
|
||||
def prepare_sql_script(self, sql):
|
||||
"""
|
||||
Take an SQL script that may contain multiple lines and return a list
|
||||
of statements to feed to successive cursor.execute() calls.
|
||||
|
||||
Since few databases are able to process raw SQL scripts in a single
|
||||
cursor.execute() call and PEP 249 doesn't talk about this use case,
|
||||
the default implementation is conservative.
|
||||
"""
|
||||
return [
|
||||
sqlparse.format(statement, strip_comments=True)
|
||||
for statement in sqlparse.split(sql) if statement
|
||||
]
|
||||
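# --- Example (not part of this file): the conservative splitting performed by
# prepare_sql_script() on a small multi-statement script.
import sqlparse

script = """
-- seed data
INSERT INTO foo (a) VALUES (1);
INSERT INTO foo (a) VALUES (2);
"""
statements = [
    sqlparse.format(statement, strip_comments=True)
    for statement in sqlparse.split(script) if statement
]
# Two INSERT statements, with the leading comment stripped; each element can
# then be passed to a separate cursor.execute() call.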
|
||||
def process_clob(self, value):
|
||||
"""
|
||||
Return the value of a CLOB column, for backends that return a locator
|
||||
object that requires additional processing.
|
||||
"""
|
||||
return value
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
"""
|
||||
For backends that support returning columns as part of an insert query,
|
||||
return the SQL and params to append to the INSERT query. The returned
|
||||
fragment should contain a format string to hold the appropriate column.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compiler(self, compiler_name):
|
||||
"""
|
||||
Return the SQLCompiler class corresponding to the given name,
|
||||
in the namespace corresponding to the `compiler_module` attribute
|
||||
on this backend.
|
||||
"""
|
||||
if self._cache is None:
|
||||
self._cache = import_module(self.compiler_module)
|
||||
return getattr(self._cache, compiler_name)
|
||||
|
||||
def quote_name(self, name):
|
||||
"""
|
||||
Return a quoted version of the given table, index, or column name. Do
|
||||
not quote the given name if it's already been quoted.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a quote_name() method')
|
||||
|
||||
def regex_lookup(self, lookup_type):
|
||||
"""
|
||||
Return the string to use in a query when performing regular expression
|
||||
lookups (using "regex" or "iregex"). It should contain a '%s'
|
||||
placeholder for the column being searched against.
|
||||
|
||||
If the feature is not supported (or part of it is not supported), raise
|
||||
NotImplementedError.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a regex_lookup() method')
|
||||
|
||||
def savepoint_create_sql(self, sid):
|
||||
"""
|
||||
Return the SQL for starting a new savepoint. Only required if the
|
||||
"uses_savepoints" feature is True. The "sid" parameter is a string
|
||||
for the savepoint id.
|
||||
"""
|
||||
return "SAVEPOINT %s" % self.quote_name(sid)
|
||||
|
||||
def savepoint_commit_sql(self, sid):
|
||||
"""
|
||||
Return the SQL for committing the given savepoint.
|
||||
"""
|
||||
return "RELEASE SAVEPOINT %s" % self.quote_name(sid)
|
||||
|
||||
def savepoint_rollback_sql(self, sid):
|
||||
"""
|
||||
Return the SQL for rolling back the given savepoint.
|
||||
"""
|
||||
return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)
|
||||
|
||||
def set_time_zone_sql(self):
|
||||
"""
|
||||
Return the SQL that will set the connection's time zone.
|
||||
|
||||
Return '' if the backend doesn't support time zones.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
"""
|
||||
Return a list of SQL statements required to remove all data from
|
||||
the given database tables (without actually removing the tables
|
||||
themselves).
|
||||
|
||||
The `style` argument is a Style object as returned by either
|
||||
color_style() or no_style() in django.core.management.color.
|
||||
|
||||
If `reset_sequences` is True, the list includes SQL statements required
|
||||
to reset the sequences.
|
||||
|
||||
The `allow_cascade` argument determines whether truncation may cascade
|
||||
to tables with foreign keys pointing to the tables being truncated.
|
||||
PostgreSQL requires a cascade even if these tables are empty.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations must provide an sql_flush() method')
|
||||
|
||||
def execute_sql_flush(self, sql_list):
|
||||
"""Execute a list of SQL statements to flush the database."""
|
||||
with transaction.atomic(
|
||||
using=self.connection.alias,
|
||||
savepoint=self.connection.features.can_rollback_ddl,
|
||||
):
|
||||
with self.connection.cursor() as cursor:
|
||||
for sql in sql_list:
|
||||
cursor.execute(sql)
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
"""
|
||||
Return a list of the SQL statements required to reset sequences
|
||||
passed in `sequences`.
|
||||
|
||||
The `style` argument is a Style object as returned by either
|
||||
color_style() or no_style() in django.core.management.color.
|
||||
"""
|
||||
return []
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
"""
|
||||
Return a list of the SQL statements required to reset sequences for
|
||||
the given models.
|
||||
|
||||
The `style` argument is a Style object as returned by either
|
||||
color_style() or no_style() in django.core.management.color.
|
||||
"""
|
||||
return [] # No sequence reset required by default.
|
||||
|
||||
def start_transaction_sql(self):
|
||||
"""Return the SQL statement required to start a transaction."""
|
||||
return "BEGIN;"
|
||||
|
||||
def end_transaction_sql(self, success=True):
|
||||
"""Return the SQL statement required to end a transaction."""
|
||||
if not success:
|
||||
return "ROLLBACK;"
|
||||
return "COMMIT;"
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
"""
|
||||
Return the SQL that will be used in a query to define the tablespace.
|
||||
|
||||
Return '' if the backend doesn't support tablespaces.
|
||||
|
||||
If `inline` is True, append the SQL to a row; otherwise append it to
|
||||
the entire CREATE TABLE or CREATE INDEX statement.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def prep_for_like_query(self, x):
|
||||
"""Prepare a value for use in a LIKE query."""
|
||||
return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
|
||||
|
||||
# Same as prep_for_like_query(), but called for "iexact" matches, which
|
||||
# need not necessarily be implemented using "LIKE" in the backend.
|
||||
prep_for_iexact_query = prep_for_like_query
|
||||
|
||||
def validate_autopk_value(self, value):
|
||||
"""
|
||||
Certain backends do not accept some values for "serial" fields
|
||||
(for example zero in MySQL). Raise a ValueError if the value is
|
||||
invalid, otherwise return the validated value.
|
||||
"""
|
||||
return value
|
||||
|
||||
def adapt_unknown_value(self, value):
|
||||
"""
|
||||
Transform a value to something compatible with the backend driver.
|
||||
|
||||
This method only depends on the type of the value. It's designed for
|
||||
cases where the target type isn't known, such as .raw() SQL queries.
|
||||
As a consequence it may not work perfectly in all circumstances.
|
||||
"""
|
||||
if isinstance(value, datetime.datetime): # must be before date
|
||||
return self.adapt_datetimefield_value(value)
|
||||
elif isinstance(value, datetime.date):
|
||||
return self.adapt_datefield_value(value)
|
||||
elif isinstance(value, datetime.time):
|
||||
return self.adapt_timefield_value(value)
|
||||
elif isinstance(value, decimal.Decimal):
|
||||
return self.adapt_decimalfield_value(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
"""
|
||||
Transform a date value to an object compatible with what is expected
|
||||
by the backend driver for date columns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
"""
|
||||
Transform a datetime value to an object compatible with what is expected
|
||||
by the backend driver for datetime columns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
"""
|
||||
Transform a time value to an object compatible with what is expected
|
||||
by the backend driver for time columns.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("Django does not support timezone-aware times.")
|
||||
return str(value)
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
"""
|
||||
Transform a decimal.Decimal value to an object compatible with what is
|
||||
expected by the backend driver for decimal (numeric) columns.
|
||||
"""
|
||||
return utils.format_number(value, max_digits, decimal_places)
|
||||
|
||||
def adapt_ipaddressfield_value(self, value):
|
||||
"""
|
||||
Transform a string representation of an IP address into the expected
|
||||
type for the backend driver.
|
||||
"""
|
||||
return value or None
|
||||
|
||||
def year_lookup_bounds_for_date_field(self, value, iso_year=False):
|
||||
"""
|
||||
Return a two-element list with the lower and upper bound to be used
|
||||
with a BETWEEN operator to query a DateField value using a year
|
||||
lookup.
|
||||
|
||||
`value` is an int, containing the looked-up year.
|
||||
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
|
||||
"""
|
||||
if iso_year:
|
||||
first = datetime.date.fromisocalendar(value, 1, 1)
|
||||
second = (
|
||||
datetime.date.fromisocalendar(value + 1, 1, 1) -
|
||||
datetime.timedelta(days=1)
|
||||
)
|
||||
else:
|
||||
first = datetime.date(value, 1, 1)
|
||||
second = datetime.date(value, 12, 31)
|
||||
first = self.adapt_datefield_value(first)
|
||||
second = self.adapt_datefield_value(second)
|
||||
return [first, second]
|
||||
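# --- Example (not part of this file): bounds produced for a filter such as
# pub_date__year=2021, assuming the default adapt_datefield_value() above.
from django.db import connection

ops = connection.ops  # a concrete BaseDatabaseOperations subclass
ops.year_lookup_bounds_for_date_field(2021)
# -> ['2021-01-01', '2021-12-31']
ops.year_lookup_bounds_for_date_field(2021, iso_year=True)
# -> ['2021-01-04', '2022-01-02']; ISO week-numbering year 2021 starts on
#    Monday 2021-01-04 and ends the day before ISO year 2022 begins.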
|
||||
def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
|
||||
"""
|
||||
Return a two-element list with the lower and upper bound to be used
|
||||
with a BETWEEN operator to query a DateTimeField value using a year
|
||||
lookup.
|
||||
|
||||
`value` is an int, containing the looked-up year.
|
||||
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
|
||||
"""
|
||||
if iso_year:
|
||||
first = datetime.datetime.fromisocalendar(value, 1, 1)
|
||||
second = (
|
||||
datetime.datetime.fromisocalendar(value + 1, 1, 1) -
|
||||
datetime.timedelta(microseconds=1)
|
||||
)
|
||||
else:
|
||||
first = datetime.datetime(value, 1, 1)
|
||||
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
|
||||
if settings.USE_TZ:
|
||||
tz = timezone.get_current_timezone()
|
||||
first = timezone.make_aware(first, tz)
|
||||
second = timezone.make_aware(second, tz)
|
||||
first = self.adapt_datetimefield_value(first)
|
||||
second = self.adapt_datetimefield_value(second)
|
||||
return [first, second]
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
"""
|
||||
Return a list of functions needed to convert field data.
|
||||
|
||||
Some field types on some backends do not provide data in the correct
|
||||
format; this is the hook for converter functions.
|
||||
"""
|
||||
return []
|
||||
|
||||
def convert_durationfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
return datetime.timedelta(0, 0, value)
|
||||
|
||||
def check_expression_support(self, expression):
|
||||
"""
|
||||
Check that the backend supports the provided expression.
|
||||
|
||||
This is used on specific backends to rule out known expressions
|
||||
that have problematic or nonexistent implementations. If the
|
||||
expression has a known problem, the backend should raise
|
||||
NotSupportedError.
|
||||
"""
|
||||
pass
|
||||
|
||||
def conditional_expression_supported_in_where_clause(self, expression):
|
||||
"""
|
||||
Return True, if the conditional expression is supported in the WHERE
|
||||
clause.
|
||||
"""
|
||||
return True
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
"""
|
||||
Combine a list of subexpressions into a single expression, using
|
||||
the provided connecting operator. This is required because operators
|
||||
can vary between backends (e.g., Oracle with %% and &) and between
|
||||
subexpression types (e.g., date expressions).
|
||||
"""
|
||||
conn = ' %s ' % connector
|
||||
return conn.join(sub_expressions)
|
||||
|
||||
def combine_duration_expression(self, connector, sub_expressions):
|
||||
return self.combine_expression(connector, sub_expressions)
|
||||
|
||||
def binary_placeholder_sql(self, value):
|
||||
"""
|
||||
Some backends require special syntax to insert binary content (MySQL
|
||||
for example uses '_binary %s').
|
||||
"""
|
||||
return '%s'
|
||||
|
||||
def modify_insert_params(self, placeholder, params):
|
||||
"""
|
||||
Allow modification of insert parameters. Needed for Oracle Spatial
|
||||
backend due to #10888.
|
||||
"""
|
||||
return params
|
||||
|
||||
def integer_field_range(self, internal_type):
|
||||
"""
|
||||
Given an integer field internal type (e.g. 'PositiveIntegerField'),
|
||||
return a tuple of the (min_value, max_value) form representing the
|
||||
range of the column type bound to the field.
|
||||
"""
|
||||
return self.integer_field_ranges[internal_type]
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
if self.connection.features.supports_temporal_subtraction:
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
return '(%s - %s)' % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
|
||||
raise NotSupportedError("This backend does not support %s subtraction." % internal_type)
|
||||
|
||||
def window_frame_start(self, start):
|
||||
if isinstance(start, int):
|
||||
if start < 0:
|
||||
return '%d %s' % (abs(start), self.PRECEDING)
|
||||
elif start == 0:
|
||||
return self.CURRENT_ROW
|
||||
elif start is None:
|
||||
return self.UNBOUNDED_PRECEDING
|
||||
raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
|
||||
|
||||
def window_frame_end(self, end):
|
||||
if isinstance(end, int):
|
||||
if end == 0:
|
||||
return self.CURRENT_ROW
|
||||
elif end > 0:
|
||||
return '%d %s' % (end, self.FOLLOWING)
|
||||
elif end is None:
|
||||
return self.UNBOUNDED_FOLLOWING
|
||||
raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
|
||||
|
||||
def window_frame_rows_start_end(self, start=None, end=None):
|
||||
"""
|
||||
Return SQL for start and end points in an OVER clause window frame.
|
||||
"""
|
||||
if not self.connection.features.supports_over_clause:
|
||||
raise NotSupportedError('This backend does not support window expressions.')
|
||||
return self.window_frame_start(start), self.window_frame_end(end)
|
||||
|
||||
def window_frame_range_start_end(self, start=None, end=None):
|
||||
start_, end_ = self.window_frame_rows_start_end(start, end)
|
||||
if (
|
||||
self.connection.features.only_supports_unbounded_with_preceding_and_following and
|
||||
((start and start < 0) or (end and end > 0))
|
||||
):
|
||||
raise NotSupportedError(
|
||||
'%s only supports UNBOUNDED together with PRECEDING and '
|
||||
'FOLLOWING.' % self.connection.display_name
|
||||
)
|
||||
return start_, end_
|
||||
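# --- Example (not part of this file): translations performed by
# window_frame_rows_start_end() on a backend with supports_over_clause = True.
from django.db import connection

ops = connection.ops
ops.window_frame_rows_start_end(None, None)  # ('UNBOUNDED PRECEDING', 'UNBOUNDED FOLLOWING')
ops.window_frame_rows_start_end(-3, 0)       # ('3 PRECEDING', 'CURRENT ROW')
ops.window_frame_rows_start_end(0, 2)        # ('CURRENT ROW', '2 FOLLOWING')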
|
||||
def explain_query_prefix(self, format=None, **options):
|
||||
if not self.connection.features.supports_explaining_query_execution:
|
||||
raise NotSupportedError('This backend does not support explaining query execution.')
|
||||
if format:
|
||||
supported_formats = self.connection.features.supported_explain_formats
|
||||
normalized_format = format.upper()
|
||||
if normalized_format not in supported_formats:
|
||||
msg = '%s is not a recognized format.' % normalized_format
|
||||
if supported_formats:
|
||||
msg += ' Allowed formats: %s' % ', '.join(sorted(supported_formats))
|
||||
raise ValueError(msg)
|
||||
if options:
|
||||
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
|
||||
return self.explain_prefix
|
||||
|
||||
def insert_statement(self, ignore_conflicts=False):
|
||||
return 'INSERT INTO'
|
||||
|
||||
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
|
||||
return ''
|
||||
1393
venv/Lib/site-packages/django/db/backends/base/schema.py
Normal file
File diff suppressed because it is too large
25
venv/Lib/site-packages/django/db/backends/base/validation.py
Normal file
@@ -0,0 +1,25 @@
|
||||
class BaseDatabaseValidation:
|
||||
"""Encapsulate backend-specific validation."""
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def check(self, **kwargs):
|
||||
return []
|
||||
|
||||
def check_field(self, field, **kwargs):
|
||||
errors = []
|
||||
# Backends may implement a check_field_type() method.
|
||||
if (hasattr(self, 'check_field_type') and
|
||||
# Ignore any related fields.
|
||||
not getattr(field, 'remote_field', None)):
|
||||
# Ignore fields with unsupported features.
|
||||
db_supports_all_required_features = all(
|
||||
getattr(self.connection.features, feature, False)
|
||||
for feature in field.model._meta.required_db_features
|
||||
)
|
||||
if db_supports_all_required_features:
|
||||
field_type = field.db_type(self.connection)
|
||||
# Ignore non-concrete fields.
|
||||
if field_type is not None:
|
||||
errors.extend(self.check_field_type(field, field_type))
|
||||
return errors
|
||||
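# --- Example (not part of this file): a hedged sketch of the optional
# check_field_type() hook mentioned above, for a hypothetical backend that
# silently ignores indexes on long varchar columns.
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation


class FrobDatabaseValidation(BaseDatabaseValidation):
    def check_field_type(self, field, field_type):
        errors = []
        if (field_type.startswith('varchar') and field.db_index and
                field.max_length is not None and field.max_length > 255):
            errors.append(checks.Warning(
                'FrobDB ignores indexes on varchar columns longer than 255 characters.',
                obj=field,
                id='frobdb.W001',
            ))
        return errors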
232
venv/Lib/site-packages/django/db/backends/ddl_references.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Helpers to manipulate deferred DDL statements that might need to be adjusted or
|
||||
discarded when executing a migration.
|
||||
"""
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class Reference:
|
||||
"""Base class that defines the reference interface."""
|
||||
|
||||
def references_table(self, table):
|
||||
"""
|
||||
Return whether or not this instance references the specified table.
|
||||
"""
|
||||
return False
|
||||
|
||||
def references_column(self, table, column):
|
||||
"""
|
||||
Return whether or not this instance references the specified column.
|
||||
"""
|
||||
return False
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
"""
|
||||
Rename all references to the old_table to the new_table.
|
||||
"""
|
||||
pass
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
"""
|
||||
Rename all references to the old_column to the new_column.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s %r>' % (self.__class__.__name__, str(self))
|
||||
|
||||
def __str__(self):
|
||||
raise NotImplementedError('Subclasses must define how they should be converted to string.')
|
||||
|
||||
|
||||
class Table(Reference):
|
||||
"""Hold a reference to a table."""
|
||||
|
||||
def __init__(self, table, quote_name):
|
||||
self.table = table
|
||||
self.quote_name = quote_name
|
||||
|
||||
def references_table(self, table):
|
||||
return self.table == table
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
if self.table == old_table:
|
||||
self.table = new_table
|
||||
|
||||
def __str__(self):
|
||||
return self.quote_name(self.table)
|
||||
|
||||
|
||||
class TableColumns(Table):
|
||||
"""Base class for references to multiple columns of a table."""
|
||||
|
||||
def __init__(self, table, columns):
|
||||
self.table = table
|
||||
self.columns = columns
|
||||
|
||||
def references_column(self, table, column):
|
||||
return self.table == table and column in self.columns
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
if self.table == table:
|
||||
for index, column in enumerate(self.columns):
|
||||
if column == old_column:
|
||||
self.columns[index] = new_column
|
||||
|
||||
|
||||
class Columns(TableColumns):
|
||||
"""Hold a reference to one or many columns."""
|
||||
|
||||
def __init__(self, table, columns, quote_name, col_suffixes=()):
|
||||
self.quote_name = quote_name
|
||||
self.col_suffixes = col_suffixes
|
||||
super().__init__(table, columns)
|
||||
|
||||
def __str__(self):
|
||||
def col_str(column, idx):
|
||||
col = self.quote_name(column)
|
||||
try:
|
||||
suffix = self.col_suffixes[idx]
|
||||
if suffix:
|
||||
col = '{} {}'.format(col, suffix)
|
||||
except IndexError:
|
||||
pass
|
||||
return col
|
||||
|
||||
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
|
||||
|
||||
|
||||
class IndexName(TableColumns):
|
||||
"""Hold a reference to an index name."""
|
||||
|
||||
def __init__(self, table, columns, suffix, create_index_name):
|
||||
self.suffix = suffix
|
||||
self.create_index_name = create_index_name
|
||||
super().__init__(table, columns)
|
||||
|
||||
def __str__(self):
|
||||
return self.create_index_name(self.table, self.columns, self.suffix)
|
||||
|
||||
|
||||
class IndexColumns(Columns):
|
||||
def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
|
||||
self.opclasses = opclasses
|
||||
super().__init__(table, columns, quote_name, col_suffixes)
|
||||
|
||||
def __str__(self):
|
||||
def col_str(column, idx):
|
||||
# Index.__init__() guarantees that self.opclasses is the same
|
||||
# length as self.columns.
|
||||
col = '{} {}'.format(self.quote_name(column), self.opclasses[idx])
|
||||
try:
|
||||
suffix = self.col_suffixes[idx]
|
||||
if suffix:
|
||||
col = '{} {}'.format(col, suffix)
|
||||
except IndexError:
|
||||
pass
|
||||
return col
|
||||
|
||||
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
|
||||
|
||||
|
||||
class ForeignKeyName(TableColumns):
|
||||
"""Hold a reference to a foreign key name."""
|
||||
|
||||
def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
|
||||
self.to_reference = TableColumns(to_table, to_columns)
|
||||
self.suffix_template = suffix_template
|
||||
self.create_fk_name = create_fk_name
|
||||
super().__init__(from_table, from_columns,)
|
||||
|
||||
def references_table(self, table):
|
||||
return super().references_table(table) or self.to_reference.references_table(table)
|
||||
|
||||
def references_column(self, table, column):
|
||||
return (
|
||||
super().references_column(table, column) or
|
||||
self.to_reference.references_column(table, column)
|
||||
)
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
super().rename_table_references(old_table, new_table)
|
||||
self.to_reference.rename_table_references(old_table, new_table)
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
super().rename_column_references(table, old_column, new_column)
|
||||
self.to_reference.rename_column_references(table, old_column, new_column)
|
||||
|
||||
def __str__(self):
|
||||
suffix = self.suffix_template % {
|
||||
'to_table': self.to_reference.table,
|
||||
'to_column': self.to_reference.columns[0],
|
||||
}
|
||||
return self.create_fk_name(self.table, self.columns, suffix)
|
||||
|
||||
|
||||
class Statement(Reference):
|
||||
"""
|
||||
Statement template and formatting parameters container.
|
||||
|
||||
Allows keeping a reference to a statement without interpolating identifiers
|
||||
that might have to be adjusted if they're referencing a table or column
|
||||
that is removed.
|
||||
"""
|
||||
def __init__(self, template, **parts):
|
||||
self.template = template
|
||||
self.parts = parts
|
||||
|
||||
def references_table(self, table):
|
||||
return any(
|
||||
hasattr(part, 'references_table') and part.references_table(table)
|
||||
for part in self.parts.values()
|
||||
)
|
||||
|
||||
def references_column(self, table, column):
|
||||
return any(
|
||||
hasattr(part, 'references_column') and part.references_column(table, column)
|
||||
for part in self.parts.values()
|
||||
)
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
for part in self.parts.values():
|
||||
if hasattr(part, 'rename_table_references'):
|
||||
part.rename_table_references(old_table, new_table)
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
for part in self.parts.values():
|
||||
if hasattr(part, 'rename_column_references'):
|
||||
part.rename_column_references(table, old_column, new_column)
|
||||
|
||||
def __str__(self):
|
||||
return self.template % self.parts
|
||||
|
||||
|
||||
class Expressions(TableColumns):
|
||||
def __init__(self, table, expressions, compiler, quote_value):
|
||||
self.compiler = compiler
|
||||
self.expressions = expressions
|
||||
self.quote_value = quote_value
|
||||
columns = [col.target.column for col in self.compiler.query._gen_cols([self.expressions])]
|
||||
super().__init__(table, columns)
|
||||
|
||||
def rename_table_references(self, old_table, new_table):
|
||||
if self.table != old_table:
|
||||
return
|
||||
self.expressions = self.expressions.relabeled_clone({old_table: new_table})
|
||||
super().rename_table_references(old_table, new_table)
|
||||
|
||||
def rename_column_references(self, table, old_column, new_column):
|
||||
if self.table != table:
|
||||
return
|
||||
expressions = deepcopy(self.expressions)
|
||||
self.columns = []
|
||||
for col in self.compiler.query._gen_cols([expressions]):
|
||||
if col.target.column == old_column:
|
||||
col.target.column = new_column
|
||||
self.columns.append(col.target.column)
|
||||
self.expressions = expressions
|
||||
|
||||
def __str__(self):
|
||||
sql, params = self.compiler.compile(self.expressions)
|
||||
params = map(self.quote_value, params)
|
||||
return sql % tuple(params)
|
||||
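# --- Example (not part of this file): a hedged sketch of how a schema editor
# keeps a deferred statement adjustable; quote_name below stands in for the
# backend's real quoting function, and the table/column names are illustrative.
from django.db.backends.ddl_references import Columns, Statement, Table


def quote_name(name):
    return '"%s"' % name


table = Table('myapp_book', quote_name)
columns = Columns('myapp_book', ['title'], quote_name)
stmt = Statement(
    'CREATE INDEX %(name)s ON %(table)s (%(columns)s)',
    name='myapp_book_title_idx',
    table=table,
    columns=columns,
)
str(stmt)  # 'CREATE INDEX myapp_book_title_idx ON "myapp_book" ("title")'
stmt.rename_table_references('myapp_book', 'myapp_press')
str(stmt)  # the table reference now renders as "myapp_press"; the column list
           # is unchanged because it belongs to the same renamed table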
73
venv/Lib/site-packages/django/db/backends/dummy/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Dummy database backend for Django.
|
||||
|
||||
Django uses this if the database ENGINE setting is empty (None or empty string).
|
||||
|
||||
Each of these API functions, except connection.close(), raises
|
||||
ImproperlyConfigured.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.dummy.features import DummyDatabaseFeatures
|
||||
|
||||
|
||||
def complain(*args, **kwargs):
|
||||
raise ImproperlyConfigured("settings.DATABASES is improperly configured. "
|
||||
"Please supply the ENGINE value. Check "
|
||||
"settings documentation for more details.")
|
||||
|
||||
|
||||
def ignore(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
quote_name = complain
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
runshell = complain
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
create_test_db = ignore
|
||||
destroy_test_db = ignore
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
get_table_list = complain
|
||||
get_table_description = complain
|
||||
get_relations = complain
|
||||
get_indexes = complain
|
||||
get_key_columns = complain
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
operators = {}
|
||||
# Override the base class implementations with null
|
||||
# implementations. Anything that tries to actually
# do something calls complain(); anything that tries
# to roll back or undo something calls ignore().
|
||||
_cursor = complain
|
||||
ensure_connection = complain
|
||||
_commit = complain
|
||||
_rollback = ignore
|
||||
_close = ignore
|
||||
_savepoint = ignore
|
||||
_savepoint_commit = complain
|
||||
_savepoint_rollback = ignore
|
||||
_set_autocommit = complain
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DummyDatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
|
||||
def is_usable(self):
|
||||
return True
|
||||
6
venv/Lib/site-packages/django/db/backends/dummy/features.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
|
||||
|
||||
class DummyDatabaseFeatures(BaseDatabaseFeatures):
|
||||
supports_transactions = False
|
||||
uses_savepoints = False
|
||||
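# --- Example (not part of this file): what the dummy backend does at runtime
# when DATABASES['default']['ENGINE'] is left empty.
from django.core.exceptions import ImproperlyConfigured
from django.db import connection

try:
    connection.cursor()  # ensure_connection/_cursor are the complain() stub
except ImproperlyConfigured as exc:
    print(exc)  # "settings.DATABASES is improperly configured. ..."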
405
venv/Lib/site-packages/django/db/backends/mysql/base.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
MySQL database backend for Django.
|
||||
|
||||
Requires mysqlclient: https://pypi.org/project/mysqlclient/
|
||||
"""
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import IntegrityError
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
try:
|
||||
import MySQLdb as Database
|
||||
except ImportError as err:
|
||||
raise ImproperlyConfigured(
|
||||
'Error loading MySQLdb module.\n'
|
||||
'Did you install mysqlclient?'
|
||||
) from err
|
||||
|
||||
from MySQLdb.constants import CLIENT, FIELD_TYPE
|
||||
from MySQLdb.converters import conversions
|
||||
|
||||
# Some of these import MySQLdb, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient
|
||||
from .creation import DatabaseCreation
|
||||
from .features import DatabaseFeatures
|
||||
from .introspection import DatabaseIntrospection
|
||||
from .operations import DatabaseOperations
|
||||
from .schema import DatabaseSchemaEditor
|
||||
from .validation import DatabaseValidation
|
||||
|
||||
version = Database.version_info
|
||||
if version < (1, 4, 0):
|
||||
raise ImproperlyConfigured('mysqlclient 1.4.0 or newer is required; you have %s.' % Database.__version__)
|
||||
|
||||
|
||||
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
|
||||
# terms of actual behavior as they are signed and include days -- and Django
|
||||
# expects time.
|
||||
django_conversions = {
|
||||
**conversions,
|
||||
**{FIELD_TYPE.TIME: backend_utils.typecast_time},
|
||||
}
|
||||
|
||||
# This should match the numerical portion of the version numbers (we can treat
|
||||
# versions like 5.0.24 and 5.0.24a as the same).
|
||||
server_version_re = _lazy_re_compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
|
||||
|
||||
|
||||
class CursorWrapper:
|
||||
"""
|
||||
A thin wrapper around MySQLdb's normal cursor class that catches particular
|
||||
exception instances and reraises them with the correct types.
|
||||
|
||||
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
|
||||
to the particular underlying representation returned by Connection.cursor().
|
||||
"""
|
||||
codes_for_integrityerror = (
|
||||
1048, # Column cannot be null
|
||||
1690, # BIGINT UNSIGNED value is out of range
|
||||
3819, # CHECK constraint is violated
|
||||
4025, # CHECK constraint failed
|
||||
)
|
||||
|
||||
def __init__(self, cursor):
|
||||
self.cursor = cursor
|
||||
|
||||
def execute(self, query, args=None):
|
||||
try:
|
||||
# args is None means no string interpolation
|
||||
return self.cursor.execute(query, args)
|
||||
except Database.OperationalError as e:
|
||||
# Map some error codes to IntegrityError, since they seem to be
|
||||
# misclassified and Django would prefer the more logical place.
|
||||
if e.args[0] in self.codes_for_integrityerror:
|
||||
raise IntegrityError(*tuple(e.args))
|
||||
raise
|
||||
|
||||
def executemany(self, query, args):
|
||||
try:
|
||||
return self.cursor.executemany(query, args)
|
||||
except Database.OperationalError as e:
|
||||
# Map some error codes to IntegrityError, since they seem to be
|
||||
# misclassified and Django would prefer the more logical place.
|
||||
if e.args[0] in self.codes_for_integrityerror:
|
||||
raise IntegrityError(*tuple(e.args))
|
||||
raise
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.cursor, attr)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = 'mysql'
|
||||
# This dictionary maps Field objects to their associated MySQL column
|
||||
# types, as strings. Column-type strings can contain format strings; they'll
|
||||
# be interpolated against the values of Field.__dict__ before being output.
|
||||
# If a column type is set to None, it won't be included in the output.
|
||||
data_types = {
|
||||
'AutoField': 'integer AUTO_INCREMENT',
|
||||
'BigAutoField': 'bigint AUTO_INCREMENT',
|
||||
'BinaryField': 'longblob',
|
||||
'BooleanField': 'bool',
|
||||
'CharField': 'varchar(%(max_length)s)',
|
||||
'DateField': 'date',
|
||||
'DateTimeField': 'datetime(6)',
|
||||
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
|
||||
'DurationField': 'bigint',
|
||||
'FileField': 'varchar(%(max_length)s)',
|
||||
'FilePathField': 'varchar(%(max_length)s)',
|
||||
'FloatField': 'double precision',
|
||||
'IntegerField': 'integer',
|
||||
'BigIntegerField': 'bigint',
|
||||
'IPAddressField': 'char(15)',
|
||||
'GenericIPAddressField': 'char(39)',
|
||||
'JSONField': 'json',
|
||||
'OneToOneField': 'integer',
|
||||
'PositiveBigIntegerField': 'bigint UNSIGNED',
|
||||
'PositiveIntegerField': 'integer UNSIGNED',
|
||||
'PositiveSmallIntegerField': 'smallint UNSIGNED',
|
||||
'SlugField': 'varchar(%(max_length)s)',
|
||||
'SmallAutoField': 'smallint AUTO_INCREMENT',
|
||||
'SmallIntegerField': 'smallint',
|
||||
'TextField': 'longtext',
|
||||
'TimeField': 'time(6)',
|
||||
'UUIDField': 'char(32)',
|
||||
}
|
||||
|
||||
# For these data types:
|
||||
# - MySQL < 8.0.13 and MariaDB < 10.2.1 don't accept default values and
|
||||
# implicitly treat them as nullable
|
||||
# - all versions of MySQL and MariaDB don't support full width database
|
||||
# indexes
|
||||
_limited_data_types = (
|
||||
'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text',
|
||||
'mediumtext', 'longtext', 'json',
|
||||
)
|
||||
|
||||
operators = {
|
||||
'exact': '= %s',
|
||||
'iexact': 'LIKE %s',
|
||||
'contains': 'LIKE BINARY %s',
|
||||
'icontains': 'LIKE %s',
|
||||
'gt': '> %s',
|
||||
'gte': '>= %s',
|
||||
'lt': '< %s',
|
||||
'lte': '<= %s',
|
||||
'startswith': 'LIKE BINARY %s',
|
||||
'endswith': 'LIKE BINARY %s',
|
||||
'istartswith': 'LIKE %s',
|
||||
'iendswith': 'LIKE %s',
|
||||
}
|
||||
|
||||
# The patterns below are used to generate SQL pattern lookup clauses when
|
||||
# the right-hand side of the lookup isn't a raw string (it might be an expression
|
||||
# or the result of a bilateral transformation).
|
||||
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
|
||||
# escaped on database side.
|
||||
#
|
||||
# Note: we use str.format() here for readability as '%' is used as a wildcard for
|
||||
# the LIKE operator.
|
||||
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
|
||||
pattern_ops = {
|
||||
'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
|
||||
'icontains': "LIKE CONCAT('%%', {}, '%%')",
|
||||
'startswith': "LIKE BINARY CONCAT({}, '%%')",
|
||||
'istartswith': "LIKE CONCAT({}, '%%')",
|
||||
'endswith': "LIKE BINARY CONCAT('%%', {})",
|
||||
'iendswith': "LIKE CONCAT('%%', {})",
|
||||
}
|
||||
|
||||
isolation_levels = {
|
||||
'read uncommitted',
|
||||
'read committed',
|
||||
'repeatable read',
|
||||
'serializable',
|
||||
}
|
||||
|
||||
Database = Database
|
||||
SchemaEditorClass = DatabaseSchemaEditor
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
validation_class = DatabaseValidation
|
||||
|
||||
def get_connection_params(self):
kwargs = {
'conv': django_conversions,
'charset': 'utf8',
}
settings_dict = self.settings_dict
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['database'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['password'] = settings_dict['PASSWORD']
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs['client_flag'] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
options = settings_dict['OPTIONS'].copy()
isolation_level = options.pop('isolation_level', 'read committed')
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
"Use one of %s, or None." % (
isolation_level,
', '.join("'%s'" % s for s in sorted(self.isolation_levels))
))
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
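For context, a minimal sketch of the settings that feed get_connection_params() above; the database name, credentials and host are placeholders, not values taken from this commit. The 'isolation_level' key is popped out of OPTIONS and validated against isolation_levels; everything else left in OPTIONS is passed straight through to MySQLdb.connect().

# settings.py (sketch, placeholder values)
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.mysql',
        'NAME': 'example_db',
        'USER': 'example_user',
        'PASSWORD': 'example_password',
        'HOST': '127.0.0.1',
        'PORT': '3306',
        # Validated against isolation_levels above; None suppresses the
        # explicit SET SESSION TRANSACTION ISOLATION LEVEL statement.
        'OPTIONS': {'isolation_level': 'repeatable read'},
    }
}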
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
connection = Database.connect(**conn_params)
|
||||
# bytes encoder in mysqlclient doesn't work and was added only to
|
||||
# prevent KeyErrors in Django < 2.0. We can remove this workaround when
|
||||
# mysqlclient 2.1 becomes the minimal mysqlclient supported by Django.
|
||||
# See https://github.com/PyMySQL/mysqlclient/issues/489
|
||||
if connection.encoders.get(bytes) is bytes:
|
||||
connection.encoders.pop(bytes)
|
||||
return connection
|
||||
|
||||
def init_connection_state(self):
|
||||
assignments = []
|
||||
if self.features.is_sql_auto_is_null_enabled:
|
||||
# SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
|
||||
# a recently inserted row will return when the field is tested
|
||||
# for NULL. Disabling this brings this aspect of MySQL in line
|
||||
# with SQL standards.
|
||||
assignments.append('SET SQL_AUTO_IS_NULL = 0')
|
||||
|
||||
if self.isolation_level:
|
||||
assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper())
|
||||
|
||||
if assignments:
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute('; '.join(assignments))
|
||||
|
||||
@async_unsafe
|
||||
def create_cursor(self, name=None):
|
||||
cursor = self.connection.cursor()
|
||||
return CursorWrapper(cursor)
|
||||
|
||||
def _rollback(self):
|
||||
try:
|
||||
BaseDatabaseWrapper._rollback(self)
|
||||
except Database.NotSupportedError:
|
||||
pass
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
self.connection.autocommit(autocommit)
|
||||
|
||||
def disable_constraint_checking(self):
|
||||
"""
|
||||
Disable foreign key checks, primarily for use in adding rows with
|
||||
forward references. Always return True to indicate constraint checks
|
||||
need to be re-enabled.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute('SET foreign_key_checks=0')
|
||||
return True
|
||||
|
||||
def enable_constraint_checking(self):
|
||||
"""
|
||||
Re-enable foreign key checks after they have been disabled.
|
||||
"""
|
||||
# Override needs_rollback in case constraint_checks_disabled is
|
||||
# nested inside transaction.atomic.
|
||||
self.needs_rollback, needs_rollback = False, self.needs_rollback
|
||||
try:
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute('SET foreign_key_checks=1')
|
||||
finally:
|
||||
self.needs_rollback = needs_rollback
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check each table name in `table_names` for rows with invalid foreign
|
||||
key references. This method is intended to be used in conjunction with
|
||||
`disable_constraint_checking()` and `enable_constraint_checking()`, to
|
||||
determine if rows with invalid references were entered while constraint
|
||||
checks were off.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
table_names = self.introspection.table_names(cursor)
|
||||
for table_name in table_names:
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
|
||||
if not primary_key_column_name:
|
||||
continue
|
||||
key_columns = self.introspection.get_key_columns(cursor, table_name)
|
||||
for column_name, referenced_table_name, referenced_column_name in key_columns:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
|
||||
LEFT JOIN `%s` as REFERRED
|
||||
ON (REFERRING.`%s` = REFERRED.`%s`)
|
||||
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
|
||||
""" % (
|
||||
primary_key_column_name, column_name, table_name,
|
||||
referenced_table_name, column_name, referenced_column_name,
|
||||
column_name, referenced_column_name,
|
||||
)
|
||||
)
|
||||
for bad_row in cursor.fetchall():
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an invalid "
|
||||
"foreign key: %s.%s contains a value '%s' that does not "
|
||||
"have a corresponding value in %s.%s."
|
||||
% (
|
||||
table_name, bad_row[0], table_name, column_name,
|
||||
bad_row[1], referenced_table_name, referenced_column_name,
|
||||
)
|
||||
)
|
||||
|
||||
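A hedged sketch of how the three constraint-checking helpers above are meant to be combined when inserting rows with forward references; load_rows() and its create_row callable are hypothetical, not part of this commit.

from django.db import connection, transaction

def load_rows(rows, create_row):
    with transaction.atomic():
        connection.disable_constraint_checking()
        try:
            for row in rows:
                create_row(row)
        finally:
            connection.enable_constraint_checking()
        # Raises IntegrityError for any row whose foreign key still has
        # no matching referenced row.
        connection.check_constraints()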
def is_usable(self):
|
||||
try:
|
||||
self.connection.ping()
|
||||
except Database.Error:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def display_name(self):
|
||||
return 'MariaDB' if self.mysql_is_mariadb else 'MySQL'
|
||||
|
||||
@cached_property
|
||||
def data_type_check_constraints(self):
|
||||
if self.features.supports_column_check_constraints:
|
||||
check_constraints = {
|
||||
'PositiveBigIntegerField': '`%(column)s` >= 0',
|
||||
'PositiveIntegerField': '`%(column)s` >= 0',
|
||||
'PositiveSmallIntegerField': '`%(column)s` >= 0',
|
||||
}
|
||||
if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
|
||||
# MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
|
||||
# a check constraint.
|
||||
check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)'
|
||||
return check_constraints
|
||||
return {}
|
||||
|
||||
@cached_property
|
||||
def mysql_server_data(self):
|
||||
with self.temporary_connection() as cursor:
|
||||
# Select some server variables and test if the time zone
|
||||
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
|
||||
# timezone isn't loaded into the mysql.time_zone table.
|
||||
cursor.execute("""
|
||||
SELECT VERSION(),
|
||||
@@sql_mode,
|
||||
@@default_storage_engine,
|
||||
@@sql_auto_is_null,
|
||||
@@lower_case_table_names,
|
||||
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
'version': row[0],
|
||||
'sql_mode': row[1],
|
||||
'default_storage_engine': row[2],
|
||||
'sql_auto_is_null': bool(row[3]),
|
||||
'lower_case_table_names': bool(row[4]),
|
||||
'has_zoneinfo_database': bool(row[5]),
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def mysql_server_info(self):
|
||||
return self.mysql_server_data['version']
|
||||
|
||||
@cached_property
|
||||
def mysql_version(self):
|
||||
match = server_version_re.match(self.mysql_server_info)
|
||||
if not match:
|
||||
raise Exception('Unable to determine MySQL version from version string %r' % self.mysql_server_info)
|
||||
return tuple(int(x) for x in match.groups())
|
||||
|
||||
@cached_property
|
||||
def mysql_is_mariadb(self):
|
||||
return 'mariadb' in self.mysql_server_info.lower()
|
||||
|
||||
@cached_property
|
||||
def sql_mode(self):
|
||||
sql_mode = self.mysql_server_data['sql_mode']
|
||||
return set(sql_mode.split(',') if sql_mode else ())
|
||||
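The cached properties above can be read from any configured project; a small sketch (assuming a working 'default' MySQL or MariaDB connection) that prints what the backend detected:

from django.db import connection

def describe_backend():
    data = connection.mysql_server_data  # opens a temporary connection once
    print(connection.display_name, connection.mysql_version)
    print('sql_mode:', sorted(connection.sql_mode))
    print('default storage engine:', data['default_storage_engine'])
    print('time zone tables loaded:', data['has_zoneinfo_database'])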
60
venv/Lib/site-packages/django/db/backends/mysql/client.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
executable_name = 'mysql'
|
||||
|
||||
@classmethod
|
||||
def settings_to_cmd_args_env(cls, settings_dict, parameters):
|
||||
args = [cls.executable_name]
|
||||
env = None
|
||||
database = settings_dict['OPTIONS'].get(
|
||||
'database',
|
||||
settings_dict['OPTIONS'].get('db', settings_dict['NAME']),
|
||||
)
|
||||
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
|
||||
password = settings_dict['OPTIONS'].get(
|
||||
'password',
|
||||
settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD'])
|
||||
)
|
||||
host = settings_dict['OPTIONS'].get('host', settings_dict['HOST'])
|
||||
port = settings_dict['OPTIONS'].get('port', settings_dict['PORT'])
|
||||
server_ca = settings_dict['OPTIONS'].get('ssl', {}).get('ca')
|
||||
client_cert = settings_dict['OPTIONS'].get('ssl', {}).get('cert')
|
||||
client_key = settings_dict['OPTIONS'].get('ssl', {}).get('key')
|
||||
defaults_file = settings_dict['OPTIONS'].get('read_default_file')
|
||||
charset = settings_dict['OPTIONS'].get('charset')
|
||||
# Seems to be no good way to set sql_mode with CLI.
|
||||
|
||||
if defaults_file:
|
||||
args += ["--defaults-file=%s" % defaults_file]
|
||||
if user:
|
||||
args += ["--user=%s" % user]
|
||||
if password:
|
||||
# The MYSQL_PWD environment variable usage is discouraged per
|
||||
# MySQL's documentation due to the possibility of exposure through
|
||||
# `ps` on old Unix flavors but --password suffers from the same
|
||||
# flaw on even more systems. Usage of an environment variable also
|
||||
# prevents password exposure if the subprocess.run(check=True) call
|
||||
# raises a CalledProcessError since the string representation of
|
||||
# the latter includes all of the provided `args`.
|
||||
env = {'MYSQL_PWD': password}
|
||||
if host:
|
||||
if '/' in host:
|
||||
args += ["--socket=%s" % host]
|
||||
else:
|
||||
args += ["--host=%s" % host]
|
||||
if port:
|
||||
args += ["--port=%s" % port]
|
||||
if server_ca:
|
||||
args += ["--ssl-ca=%s" % server_ca]
|
||||
if client_cert:
|
||||
args += ["--ssl-cert=%s" % client_cert]
|
||||
if client_key:
|
||||
args += ["--ssl-key=%s" % client_key]
|
||||
if charset:
|
||||
args += ['--default-character-set=%s' % charset]
|
||||
if database:
|
||||
args += [database]
|
||||
args.extend(parameters)
|
||||
return args, env
|
||||
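settings_to_cmd_args_env() above is the backend half of `manage.py dbshell`; a hedged example of the argv/env pair it produces for a TCP connection (all values are placeholders):

from django.db.backends.mysql.client import DatabaseClient

settings_dict = {
    'NAME': 'example_db',
    'USER': 'example_user',
    'PASSWORD': 'example_password',
    'HOST': '127.0.0.1',
    'PORT': '3306',
    'OPTIONS': {},
}
args, env = DatabaseClient.settings_to_cmd_args_env(settings_dict, [])
# args == ['mysql', '--user=example_user', '--host=127.0.0.1',
#          '--port=3306', 'example_db']
# env  == {'MYSQL_PWD': 'example_password'}  (keeps the password off argv)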
71
venv/Lib/site-packages/django/db/backends/mysql/compiler.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.expressions import Col
|
||||
from django.db.models.sql import compiler
|
||||
|
||||
|
||||
class SQLCompiler(compiler.SQLCompiler):
|
||||
def as_subquery_condition(self, alias, columns, compiler):
|
||||
qn = compiler.quote_name_unless_alias
|
||||
qn2 = self.connection.ops.quote_name
|
||||
sql, params = self.as_sql()
|
||||
return '(%s) IN (%s)' % (', '.join('%s.%s' % (qn(alias), qn2(column)) for column in columns), sql), params
|
||||
|
||||
|
||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
||||
def as_sql(self):
|
||||
# Prefer the non-standard DELETE FROM syntax over the SQL generated by
|
||||
# the SQLDeleteCompiler's default implementation when multiple tables
|
||||
# are involved since MySQL/MariaDB will generate a more efficient query
|
||||
# plan than when using a subquery.
|
||||
where, having = self.query.where.split_having()
|
||||
if self.single_alias or having:
|
||||
# DELETE FROM cannot be used when filtering against aggregates
|
||||
# since it doesn't allow for GROUP BY and HAVING clauses.
|
||||
return super().as_sql()
|
||||
result = [
|
||||
'DELETE %s FROM' % self.quote_name_unless_alias(
|
||||
self.query.get_initial_alias()
|
||||
)
|
||||
]
|
||||
from_sql, from_params = self.get_from_clause()
|
||||
result.extend(from_sql)
|
||||
where_sql, where_params = self.compile(where)
|
||||
if where_sql:
|
||||
result.append('WHERE %s' % where_sql)
|
||||
return ' '.join(result), tuple(from_params) + tuple(where_params)
|
||||
|
||||
|
||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
||||
def as_sql(self):
|
||||
update_query, update_params = super().as_sql()
|
||||
# MySQL and MariaDB support UPDATE ... ORDER BY syntax.
|
||||
if self.query.order_by:
|
||||
order_by_sql = []
|
||||
order_by_params = []
|
||||
db_table = self.query.get_meta().db_table
|
||||
try:
|
||||
for resolved, (sql, params, _) in self.get_order_by():
|
||||
if (
|
||||
isinstance(resolved.expression, Col) and
|
||||
resolved.expression.alias != db_table
|
||||
):
|
||||
# Ignore ordering if it contains joined fields, because
|
||||
# they cannot be used in the ORDER BY clause.
|
||||
raise FieldError
|
||||
order_by_sql.append(sql)
|
||||
order_by_params.extend(params)
|
||||
update_query += ' ORDER BY ' + ', '.join(order_by_sql)
|
||||
update_params += tuple(order_by_params)
|
||||
except FieldError:
|
||||
# Ignore ordering if it contains annotations, because they're
|
||||
# removed in .update() and cannot be resolved.
|
||||
pass
|
||||
return update_query, update_params
|
||||
|
||||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
||||
pass
|
||||
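SQLUpdateCompiler above forwards queryset ordering into the UPDATE statement, which is mainly useful when shifting values in a unique column; a hedged sketch against a hypothetical Item model with a unique integer 'position' field:

from django.db.models import F

# Item is the hypothetical model described above. Renumbering from the
# highest position down avoids intermediate unique-key collisions; on this
# backend the ordering survives as "UPDATE ... ORDER BY `position` DESC".
Item.objects.order_by('-position').update(position=F('position') + 1)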
68
venv/Lib/site-packages/django/db/backends/mysql/creation.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
|
||||
from .client import DatabaseClient
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
|
||||
def sql_table_creation_suffix(self):
|
||||
suffix = []
|
||||
test_settings = self.connection.settings_dict['TEST']
|
||||
if test_settings['CHARSET']:
|
||||
suffix.append('CHARACTER SET %s' % test_settings['CHARSET'])
|
||||
if test_settings['COLLATION']:
|
||||
suffix.append('COLLATE %s' % test_settings['COLLATION'])
|
||||
return ' '.join(suffix)
|
||||
|
||||
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
|
||||
try:
|
||||
super()._execute_create_test_db(cursor, parameters, keepdb)
|
||||
except Exception as e:
|
||||
if len(e.args) < 1 or e.args[0] != 1007:
|
||||
# All errors except "database exists" (1007) cancel tests.
|
||||
self.log('Got an error creating the test database: %s' % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
raise
|
||||
|
||||
def _clone_test_db(self, suffix, verbosity, keepdb=False):
|
||||
source_database_name = self.connection.settings_dict['NAME']
|
||||
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
|
||||
test_db_params = {
|
||||
'dbname': self.connection.ops.quote_name(target_database_name),
|
||||
'suffix': self.sql_table_creation_suffix(),
|
||||
}
|
||||
with self._nodb_cursor() as cursor:
|
||||
try:
|
||||
self._execute_create_test_db(cursor, test_db_params, keepdb)
|
||||
except Exception:
|
||||
if keepdb:
|
||||
# If the database should be kept, skip everything else.
|
||||
return
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying old test database for alias %s...' % (
|
||||
self._get_database_display_str(verbosity, target_database_name),
|
||||
))
|
||||
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
|
||||
self._execute_create_test_db(cursor, test_db_params, keepdb)
|
||||
except Exception as e:
|
||||
self.log('Got an error recreating the test database: %s' % e)
|
||||
sys.exit(2)
|
||||
self._clone_db(source_database_name, target_database_name)
|
||||
|
||||
def _clone_db(self, source_database_name, target_database_name):
|
||||
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, [])
|
||||
dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name]
|
||||
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
|
||||
load_cmd = cmd_args
|
||||
load_cmd[-1] = target_database_name
|
||||
|
||||
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc:
|
||||
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env):
|
||||
# Allow dump_proc to receive a SIGPIPE if the load process exits.
|
||||
dump_proc.stdout.close()
|
||||
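The clone path above is what `manage.py test --parallel` drives, piping mysqldump into mysql once per worker database; a small sketch (assuming the placeholder settings from earlier) of how the clone names are derived:

from django.db import connection

creation = connection.creation
# Worker N gets a copy of the test database named '<NAME>_<N>',
# e.g. 'test_example_db_1' once the test prefix has been applied.
print(creation.get_test_db_clone_settings('1')['NAME'])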
268
venv/Lib/site-packages/django/db/backends/mysql/features.py
Normal file
@@ -0,0 +1,268 @@
|
||||
import operator
|
||||
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
empty_fetchmany_value = ()
|
||||
allows_group_by_pk = True
|
||||
related_fields_match_type = True
|
||||
# MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
|
||||
allow_sliced_subqueries_with_in = False
|
||||
has_select_for_update = True
|
||||
supports_forward_references = False
|
||||
supports_regex_backreferencing = False
|
||||
supports_date_lookup_using_string = False
|
||||
supports_timezones = False
|
||||
requires_explicit_null_ordering_when_grouping = True
|
||||
can_release_savepoints = True
|
||||
atomic_transactions = False
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
supports_select_intersection = False
|
||||
supports_select_difference = False
|
||||
supports_slicing_ordering_in_compound = True
|
||||
supports_index_on_text_field = False
|
||||
has_case_insensitive_like = False
|
||||
create_test_procedure_without_params_sql = """
|
||||
CREATE PROCEDURE test_procedure ()
|
||||
BEGIN
|
||||
DECLARE V_I INTEGER;
|
||||
SET V_I = 1;
|
||||
END;
|
||||
"""
|
||||
create_test_procedure_with_int_param_sql = """
|
||||
CREATE PROCEDURE test_procedure (P_I INTEGER)
|
||||
BEGIN
|
||||
DECLARE V_I INTEGER;
|
||||
SET V_I = P_I;
|
||||
END;
|
||||
"""
|
||||
# Neither MySQL nor MariaDB support partial indexes.
|
||||
supports_partial_indexes = False
|
||||
# COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
|
||||
# indexed expression.
|
||||
collate_as_index_expression = True
|
||||
|
||||
supports_order_by_nulls_modifier = False
|
||||
order_by_nulls_first = True
|
||||
|
||||
@cached_property
|
||||
def test_collations(self):
|
||||
charset = 'utf8'
|
||||
if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 6):
|
||||
# utf8 is an alias for utf8mb3 in MariaDB 10.6+.
|
||||
charset = 'utf8mb3'
|
||||
return {
|
||||
'ci': f'{charset}_general_ci',
|
||||
'non_default': f'{charset}_esperanto_ci',
|
||||
'swedish_ci': f'{charset}_swedish_ci',
|
||||
}
|
||||
|
||||
test_now_utc_template = 'UTC_TIMESTAMP'
|
||||
|
||||
@cached_property
|
||||
def django_test_skips(self):
|
||||
skips = {
|
||||
"This doesn't work on MySQL.": {
|
||||
'db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround',
|
||||
'db_functions.comparison.test_least.LeastTests.test_coalesce_workaround',
|
||||
},
|
||||
'Running on MySQL requires utf8mb4 encoding (#18392).': {
|
||||
'model_fields.test_textfield.TextFieldTests.test_emoji',
|
||||
'model_fields.test_charfield.TestCharField.test_emoji',
|
||||
},
|
||||
"MySQL doesn't support functional indexes on a function that "
|
||||
"returns JSON": {
|
||||
'schema.tests.SchemaTests.test_func_index_json_key_transform',
|
||||
},
|
||||
"MySQL supports multiplying and dividing DurationFields by a "
|
||||
"scalar value but it's not implemented (#25287).": {
|
||||
'expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide',
|
||||
},
|
||||
}
|
||||
if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
|
||||
skips.update({
|
||||
'GROUP BY optimization does not work properly when '
|
||||
'ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.': {
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued',
|
||||
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
|
||||
},
|
||||
})
|
||||
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
|
||||
skips.update({
|
||||
'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
|
||||
},
|
||||
'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
|
||||
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
|
||||
},
|
||||
})
|
||||
if (
|
||||
self.connection.mysql_is_mariadb and
|
||||
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2)
|
||||
):
|
||||
skips.update({
|
||||
'https://jira.mariadb.org/browse/MDEV-19598': {
|
||||
'schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key',
|
||||
},
|
||||
})
|
||||
if (
|
||||
self.connection.mysql_is_mariadb and
|
||||
(10, 4, 12) < self.connection.mysql_version < (10, 5)
|
||||
):
|
||||
skips.update({
|
||||
'https://jira.mariadb.org/browse/MDEV-22775': {
|
||||
'schema.tests.SchemaTests.test_alter_pk_with_self_referential_field',
|
||||
},
|
||||
})
|
||||
if not self.supports_explain_analyze:
|
||||
skips.update({
|
||||
'MariaDB and MySQL >= 8.0.18 specific.': {
|
||||
'queries.test_explain.ExplainTests.test_mysql_analyze',
|
||||
},
|
||||
})
|
||||
return skips
|
||||
|
||||
@cached_property
|
||||
def _mysql_storage_engine(self):
|
||||
"Internal method used in Django tests. Don't rely on this from your code"
|
||||
return self.connection.mysql_server_data['default_storage_engine']
|
||||
|
||||
@cached_property
|
||||
def allows_auto_pk_0(self):
|
||||
"""
|
||||
Autoincrement primary key can be set to 0 if it doesn't generate new
|
||||
autoincrement values.
|
||||
"""
|
||||
return 'NO_AUTO_VALUE_ON_ZERO' in self.connection.sql_mode
|
||||
|
||||
@cached_property
|
||||
def update_can_self_select(self):
|
||||
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 3, 2)
|
||||
|
||||
@cached_property
|
||||
def can_introspect_foreign_keys(self):
|
||||
"Confirm support for introspected foreign keys"
|
||||
return self._mysql_storage_engine != 'MyISAM'
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
'BinaryField': 'TextField',
|
||||
'BooleanField': 'IntegerField',
|
||||
'DurationField': 'BigIntegerField',
|
||||
'GenericIPAddressField': 'CharField',
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_insert(self):
|
||||
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 5, 0)
|
||||
|
||||
can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))
|
||||
|
||||
@cached_property
|
||||
def has_zoneinfo_database(self):
|
||||
return self.connection.mysql_server_data['has_zoneinfo_database']
|
||||
|
||||
@cached_property
|
||||
def is_sql_auto_is_null_enabled(self):
|
||||
return self.connection.mysql_server_data['sql_auto_is_null']
|
||||
|
||||
@cached_property
|
||||
def supports_over_clause(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return True
|
||||
return self.connection.mysql_version >= (8, 0, 2)
|
||||
|
||||
supports_frame_range_fixed_distance = property(operator.attrgetter('supports_over_clause'))
|
||||
|
||||
@cached_property
|
||||
def supports_column_check_constraints(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 2, 1)
|
||||
return self.connection.mysql_version >= (8, 0, 16)
|
||||
|
||||
supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints'))
|
||||
|
||||
@cached_property
|
||||
def can_introspect_check_constraints(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
version = self.connection.mysql_version
|
||||
return (version >= (10, 2, 22) and version < (10, 3)) or version >= (10, 3, 10)
|
||||
return self.connection.mysql_version >= (8, 0, 16)
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_skip_locked(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 6)
|
||||
return self.connection.mysql_version >= (8, 0, 1)
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_nowait(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 3, 0)
|
||||
return self.connection.mysql_version >= (8, 0, 1)
|
||||
|
||||
@cached_property
|
||||
def has_select_for_update_of(self):
|
||||
return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)
|
||||
|
||||
@cached_property
|
||||
def supports_explain_analyze(self):
|
||||
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (8, 0, 18)
|
||||
|
||||
@cached_property
|
||||
def supported_explain_formats(self):
|
||||
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
|
||||
# backends.
|
||||
formats = {'JSON', 'TEXT', 'TRADITIONAL'}
|
||||
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 16):
|
||||
formats.add('TREE')
|
||||
return formats
|
||||
|
||||
@cached_property
|
||||
def supports_transactions(self):
|
||||
"""
|
||||
All storage engines except MyISAM support transactions.
|
||||
"""
|
||||
return self._mysql_storage_engine != 'MyISAM'
|
||||
|
||||
@cached_property
|
||||
def ignores_table_name_case(self):
|
||||
return self.connection.mysql_server_data['lower_case_table_names']
|
||||
|
||||
@cached_property
|
||||
def supports_default_in_lead_lag(self):
|
||||
# To be added in https://jira.mariadb.org/browse/MDEV-12981.
|
||||
return not self.connection.mysql_is_mariadb
|
||||
|
||||
@cached_property
|
||||
def supports_json_field(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.connection.mysql_version >= (10, 2, 7)
|
||||
return self.connection.mysql_version >= (5, 7, 8)
|
||||
|
||||
@cached_property
|
||||
def can_introspect_json_field(self):
|
||||
if self.connection.mysql_is_mariadb:
|
||||
return self.supports_json_field and self.can_introspect_check_constraints
|
||||
return self.supports_json_field
|
||||
|
||||
@cached_property
|
||||
def supports_index_column_ordering(self):
|
||||
return (
|
||||
not self.connection.mysql_is_mariadb and
|
||||
self.connection.mysql_version >= (8, 0, 1)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def supports_expression_indexes(self):
|
||||
return (
|
||||
not self.connection.mysql_is_mariadb and
|
||||
self.connection.mysql_version >= (8, 0, 13)
|
||||
)
|
||||
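These flags are meant to be consumed through connection.features rather than re-deriving version checks in application code; a minimal sketch of gating a test on the JSON support detected above:

from django.db import connection
from django.test import TestCase, skipUnlessDBFeature

class JSONSupportSmokeTest(TestCase):
    @skipUnlessDBFeature('supports_json_field')
    def test_backend_reports_json_support(self):
        # Mirrors supports_json_field: MySQL >= 5.7.8 or MariaDB >= 10.2.7.
        self.assertTrue(connection.features.supports_json_field)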
309
venv/Lib/site-packages/django/db/backends/mysql/introspection.py
Normal file
@@ -0,0 +1,309 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlparse
|
||||
from MySQLdb.constants import FIELD_TYPE
|
||||
|
||||
from django.db.backends.base.introspection import (
|
||||
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
|
||||
)
|
||||
from django.db.models import Index
|
||||
from django.utils.datastructures import OrderedSet
|
||||
|
||||
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint'))
|
||||
InfoLine = namedtuple(
|
||||
'InfoLine',
|
||||
'col_name data_type max_len num_prec num_scale extra column_default '
|
||||
'collation is_unsigned'
|
||||
)
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
data_types_reverse = {
|
||||
FIELD_TYPE.BLOB: 'TextField',
|
||||
FIELD_TYPE.CHAR: 'CharField',
|
||||
FIELD_TYPE.DECIMAL: 'DecimalField',
|
||||
FIELD_TYPE.NEWDECIMAL: 'DecimalField',
|
||||
FIELD_TYPE.DATE: 'DateField',
|
||||
FIELD_TYPE.DATETIME: 'DateTimeField',
|
||||
FIELD_TYPE.DOUBLE: 'FloatField',
|
||||
FIELD_TYPE.FLOAT: 'FloatField',
|
||||
FIELD_TYPE.INT24: 'IntegerField',
|
||||
FIELD_TYPE.JSON: 'JSONField',
|
||||
FIELD_TYPE.LONG: 'IntegerField',
|
||||
FIELD_TYPE.LONGLONG: 'BigIntegerField',
|
||||
FIELD_TYPE.SHORT: 'SmallIntegerField',
|
||||
FIELD_TYPE.STRING: 'CharField',
|
||||
FIELD_TYPE.TIME: 'TimeField',
|
||||
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
|
||||
FIELD_TYPE.TINY: 'IntegerField',
|
||||
FIELD_TYPE.TINY_BLOB: 'TextField',
|
||||
FIELD_TYPE.MEDIUM_BLOB: 'TextField',
|
||||
FIELD_TYPE.LONG_BLOB: 'TextField',
|
||||
FIELD_TYPE.VAR_STRING: 'CharField',
|
||||
}
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
field_type = super().get_field_type(data_type, description)
|
||||
if 'auto_increment' in description.extra:
|
||||
if field_type == 'IntegerField':
|
||||
return 'AutoField'
|
||||
elif field_type == 'BigIntegerField':
|
||||
return 'BigAutoField'
|
||||
elif field_type == 'SmallIntegerField':
|
||||
return 'SmallAutoField'
|
||||
if description.is_unsigned:
|
||||
if field_type == 'BigIntegerField':
|
||||
return 'PositiveBigIntegerField'
|
||||
elif field_type == 'IntegerField':
|
||||
return 'PositiveIntegerField'
|
||||
elif field_type == 'SmallIntegerField':
|
||||
return 'PositiveSmallIntegerField'
|
||||
# JSON data type is an alias for LONGTEXT in MariaDB, use check
|
||||
# constraints clauses to introspect JSONField.
|
||||
if description.has_json_constraint:
|
||||
return 'JSONField'
|
||||
return field_type
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
cursor.execute("SHOW FULL TABLES")
|
||||
return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1]))
|
||||
for row in cursor.fetchall()]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
json_constraints = {}
|
||||
if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field:
|
||||
# JSON data type is an alias for LONGTEXT in MariaDB, select
|
||||
# JSON_VALID() constraints to introspect JSONField.
|
||||
cursor.execute("""
|
||||
SELECT c.constraint_name AS column_name
|
||||
FROM information_schema.check_constraints AS c
|
||||
WHERE
|
||||
c.table_name = %s AND
|
||||
LOWER(c.check_clause) = 'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
|
||||
c.constraint_schema = DATABASE()
|
||||
""", [table_name])
|
||||
json_constraints = {row[0] for row in cursor.fetchall()}
|
||||
# A default collation for the given table.
|
||||
cursor.execute("""
|
||||
SELECT table_collation
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = %s
|
||||
""", [table_name])
|
||||
row = cursor.fetchone()
|
||||
default_column_collation = row[0] if row else ''
|
||||
# information_schema database gives more accurate results for some figures:
|
||||
# - varchar length returned by cursor.description is an internal length,
|
||||
# not visible length (#5725)
|
||||
# - precision and scale (for decimal fields) (#5014)
|
||||
# - auto_increment is not available in cursor.description
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
column_name, data_type, character_maximum_length,
|
||||
numeric_precision, numeric_scale, extra, column_default,
|
||||
CASE
|
||||
WHEN collation_name = %s THEN NULL
|
||||
ELSE collation_name
|
||||
END AS collation_name,
|
||||
CASE
|
||||
WHEN column_type LIKE '%% unsigned' THEN 1
|
||||
ELSE 0
|
||||
END AS is_unsigned
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = %s AND table_schema = DATABASE()
|
||||
""", [default_column_collation, table_name])
|
||||
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
|
||||
|
||||
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
|
||||
|
||||
def to_int(i):
|
||||
return int(i) if i is not None else i
|
||||
|
||||
fields = []
|
||||
for line in cursor.description:
|
||||
info = field_info[line[0]]
|
||||
fields.append(FieldInfo(
|
||||
*line[:3],
|
||||
to_int(info.max_len) or line[3],
|
||||
to_int(info.num_prec) or line[4],
|
||||
to_int(info.num_scale) or line[5],
|
||||
line[6],
|
||||
info.column_default,
|
||||
info.collation,
|
||||
info.extra,
|
||||
info.is_unsigned,
|
||||
line[0] in json_constraints,
|
||||
))
|
||||
return fields
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
for field_info in self.get_table_description(cursor, table_name):
|
||||
if 'auto_increment' in field_info.extra:
|
||||
# MySQL allows only one auto-increment column per table.
|
||||
return [{'table': table_name, 'column': field_info.name}]
|
||||
return []
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all relationships to the given table.
|
||||
"""
|
||||
constraints = self.get_key_columns(cursor, table_name)
|
||||
relations = {}
|
||||
for my_fieldname, other_table, other_field in constraints:
|
||||
relations[my_fieldname] = (other_field, other_table)
|
||||
return relations
|
||||
|
||||
def get_key_columns(self, cursor, table_name):
|
||||
"""
|
||||
Return a list of (column_name, referenced_table_name, referenced_column_name)
|
||||
for all key columns in the given table.
|
||||
"""
|
||||
key_columns = []
|
||||
cursor.execute("""
|
||||
SELECT column_name, referenced_table_name, referenced_column_name
|
||||
FROM information_schema.key_column_usage
|
||||
WHERE table_name = %s
|
||||
AND table_schema = DATABASE()
|
||||
AND referenced_table_name IS NOT NULL
|
||||
AND referenced_column_name IS NOT NULL""", [table_name])
|
||||
key_columns.extend(cursor.fetchall())
|
||||
return key_columns
|
||||
|
||||
def get_storage_engine(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve the storage engine for a given table. Return the default
|
||||
storage engine if the table doesn't exist.
|
||||
"""
|
||||
cursor.execute("""
|
||||
SELECT engine
|
||||
FROM information_schema.tables
|
||||
WHERE
|
||||
table_name = %s AND
|
||||
table_schema = DATABASE()
|
||||
""", [table_name])
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
return self.connection.features._mysql_storage_engine
|
||||
return result[0]
|
||||
|
||||
def _parse_constraint_columns(self, check_clause, columns):
|
||||
check_columns = OrderedSet()
|
||||
statement = sqlparse.parse(check_clause)[0]
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
for token in tokens:
|
||||
if (
|
||||
token.ttype == sqlparse.tokens.Name and
|
||||
self.connection.ops.quote_name(token.value) == token.value and
|
||||
token.value[1:-1] in columns
|
||||
):
|
||||
check_columns.add(token.value[1:-1])
|
||||
return check_columns
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Get the actual constraint names and columns
|
||||
name_query = """
|
||||
SELECT kc.`constraint_name`, kc.`column_name`,
|
||||
kc.`referenced_table_name`, kc.`referenced_column_name`,
|
||||
c.`constraint_type`
|
||||
FROM
|
||||
information_schema.key_column_usage AS kc,
|
||||
information_schema.table_constraints AS c
|
||||
WHERE
|
||||
kc.table_schema = DATABASE() AND
|
||||
c.table_schema = kc.table_schema AND
|
||||
c.constraint_name = kc.constraint_name AND
|
||||
c.constraint_type != 'CHECK' AND
|
||||
kc.table_name = %s
|
||||
ORDER BY kc.`ordinal_position`
|
||||
"""
|
||||
cursor.execute(name_query, [table_name])
|
||||
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
|
||||
if constraint not in constraints:
|
||||
constraints[constraint] = {
|
||||
'columns': OrderedSet(),
|
||||
'primary_key': kind == 'PRIMARY KEY',
|
||||
'unique': kind in {'PRIMARY KEY', 'UNIQUE'},
|
||||
'index': False,
|
||||
'check': False,
|
||||
'foreign_key': (ref_table, ref_column) if ref_column else None,
|
||||
}
|
||||
if self.connection.features.supports_index_column_ordering:
|
||||
constraints[constraint]['orders'] = []
|
||||
constraints[constraint]['columns'].add(column)
|
||||
# Add check constraints.
|
||||
if self.connection.features.can_introspect_check_constraints:
|
||||
unnamed_constraints_index = 0
|
||||
columns = {info.name for info in self.get_table_description(cursor, table_name)}
|
||||
if self.connection.mysql_is_mariadb:
|
||||
type_query = """
|
||||
SELECT c.constraint_name, c.check_clause
|
||||
FROM information_schema.check_constraints AS c
|
||||
WHERE
|
||||
c.constraint_schema = DATABASE() AND
|
||||
c.table_name = %s
|
||||
"""
|
||||
else:
|
||||
type_query = """
|
||||
SELECT cc.constraint_name, cc.check_clause
|
||||
FROM
|
||||
information_schema.check_constraints AS cc,
|
||||
information_schema.table_constraints AS tc
|
||||
WHERE
|
||||
cc.constraint_schema = DATABASE() AND
|
||||
tc.table_schema = cc.constraint_schema AND
|
||||
cc.constraint_name = tc.constraint_name AND
|
||||
tc.constraint_type = 'CHECK' AND
|
||||
tc.table_name = %s
|
||||
"""
|
||||
cursor.execute(type_query, [table_name])
|
||||
for constraint, check_clause in cursor.fetchall():
|
||||
constraint_columns = self._parse_constraint_columns(check_clause, columns)
|
||||
# Ensure uniqueness of unnamed constraints. Unnamed unique
|
||||
# and check columns constraints have the same name as
|
||||
# a column.
|
||||
if set(constraint_columns) == {constraint}:
|
||||
unnamed_constraints_index += 1
|
||||
constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index
|
||||
constraints[constraint] = {
|
||||
'columns': constraint_columns,
|
||||
'primary_key': False,
|
||||
'unique': False,
|
||||
'index': False,
|
||||
'check': True,
|
||||
'foreign_key': None,
|
||||
}
|
||||
# Now add in the indexes
|
||||
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
|
||||
for table, non_unique, index, colseq, column, order, type_ in [
|
||||
x[:6] + (x[10],) for x in cursor.fetchall()
|
||||
]:
|
||||
if index not in constraints:
|
||||
constraints[index] = {
|
||||
'columns': OrderedSet(),
|
||||
'primary_key': False,
|
||||
'unique': not non_unique,
|
||||
'check': False,
|
||||
'foreign_key': None,
|
||||
}
|
||||
if self.connection.features.supports_index_column_ordering:
|
||||
constraints[index]['orders'] = []
|
||||
constraints[index]['index'] = True
|
||||
constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower()
|
||||
constraints[index]['columns'].add(column)
|
||||
if self.connection.features.supports_index_column_ordering:
|
||||
constraints[index]['orders'].append('DESC' if order == 'D' else 'ASC')
|
||||
# Convert the sorted sets to lists
|
||||
for constraint in constraints.values():
|
||||
constraint['columns'] = list(constraint['columns'])
|
||||
return constraints
|
||||
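A hedged sketch of driving the introspection API above from a shell session; it only assumes that at least one table exists in the default database:

from django.db import connection

with connection.cursor() as cursor:
    tables = connection.introspection.table_names(cursor)
    if tables:
        table = tables[0]
        for column in connection.introspection.get_table_description(cursor, table):
            print(column.name, column.type_code, column.extra, column.is_unsigned)
        print(connection.introspection.get_constraints(cursor, table))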
378
venv/Lib/site-packages/django/db/backends/mysql/operations.py
Normal file
@@ -0,0 +1,378 @@
|
||||
import uuid
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.utils import split_tzname_delta
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
compiler_module = "django.db.backends.mysql.compiler"
|
||||
|
||||
# MySQL stores positive fields as UNSIGNED ints.
|
||||
integer_field_ranges = {
|
||||
**BaseDatabaseOperations.integer_field_ranges,
|
||||
'PositiveSmallIntegerField': (0, 65535),
|
||||
'PositiveIntegerField': (0, 4294967295),
|
||||
'PositiveBigIntegerField': (0, 18446744073709551615),
|
||||
}
|
||||
cast_data_types = {
|
||||
'AutoField': 'signed integer',
|
||||
'BigAutoField': 'signed integer',
|
||||
'SmallAutoField': 'signed integer',
|
||||
'CharField': 'char(%(max_length)s)',
|
||||
'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)',
|
||||
'TextField': 'char',
|
||||
'IntegerField': 'signed integer',
|
||||
'BigIntegerField': 'signed integer',
|
||||
'SmallIntegerField': 'signed integer',
|
||||
'PositiveBigIntegerField': 'unsigned integer',
|
||||
'PositiveIntegerField': 'unsigned integer',
|
||||
'PositiveSmallIntegerField': 'unsigned integer',
|
||||
'DurationField': 'signed integer',
|
||||
}
|
||||
cast_char_field_without_max_length = 'char'
|
||||
explain_prefix = 'EXPLAIN'
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
|
||||
if lookup_type == 'week_day':
|
||||
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
|
||||
return "DAYOFWEEK(%s)" % field_name
|
||||
elif lookup_type == 'iso_week_day':
|
||||
# WEEKDAY() returns an integer, 0-6, Monday=0.
|
||||
return "WEEKDAY(%s) + 1" % field_name
|
||||
elif lookup_type == 'week':
|
||||
# Override the value of default_week_format for consistency with
|
||||
# other database backends.
|
||||
# Mode 3: Monday, 1-53, with 4 or more days this year.
|
||||
return "WEEK(%s, 3)" % field_name
|
||||
elif lookup_type == 'iso_year':
|
||||
# Get the year part from the YEARWEEK function, which returns a
|
||||
# number as year * 100 + week.
|
||||
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
|
||||
else:
|
||||
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
|
||||
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
fields = {
|
||||
'year': '%%Y-01-01',
|
||||
'month': '%%Y-%%m-01',
|
||||
} # Use double percents to escape.
|
||||
if lookup_type in fields:
|
||||
format_str = fields[lookup_type]
|
||||
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
|
||||
elif lookup_type == 'quarter':
|
||||
return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % (
|
||||
field_name, field_name
|
||||
)
|
||||
elif lookup_type == 'week':
|
||||
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (
|
||||
field_name, field_name
|
||||
)
|
||||
else:
|
||||
return "DATE(%s)" % (field_name)
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
return f'{sign}{offset}' if offset else tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
|
||||
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
|
||||
field_name,
|
||||
self.connection.timezone_name,
|
||||
self._prepare_tzname_delta(tzname),
|
||||
)
|
||||
return field_name
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "DATE(%s)" % field_name
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "TIME(%s)" % field_name
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
|
||||
format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
|
||||
format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
|
||||
if lookup_type == 'quarter':
|
||||
return (
|
||||
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
|
||||
"INTERVAL QUARTER({field_name}) QUARTER - " +
|
||||
"INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
|
||||
).format(field_name=field_name)
|
||||
if lookup_type == 'week':
|
||||
return (
|
||||
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
|
||||
"INTERVAL WEEKDAY({field_name}) DAY), "
|
||||
"'%%Y-%%m-%%d 00:00:00') AS DATETIME)"
|
||||
).format(field_name=field_name)
|
||||
try:
|
||||
i = fields.index(lookup_type) + 1
|
||||
except ValueError:
|
||||
sql = field_name
|
||||
else:
|
||||
format_str = ''.join(format[:i] + format_def[i:])
|
||||
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
|
||||
return sql
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
fields = {
|
||||
'hour': '%%H:00:00',
|
||||
'minute': '%%H:%%i:00',
|
||||
'second': '%%H:%%i:%%s',
|
||||
} # Use double percents to escape.
|
||||
if lookup_type in fields:
|
||||
format_str = fields[lookup_type]
|
||||
return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str)
|
||||
else:
|
||||
return "TIME(%s)" % (field_name)
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the tuple of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
return 'INTERVAL %s MICROSECOND' % sql
|
||||
|
||||
def force_no_ordering(self):
|
||||
"""
|
||||
"ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
|
||||
columns. If no ordering would otherwise be applied, we don't want any
|
||||
implicit sorting going on.
|
||||
"""
|
||||
return [(None, ("NULL", [], False))]
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
return value
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# With MySQLdb, cursor objects have an (undocumented) "_executed"
|
||||
# attribute where the exact query sent to the database is saved.
|
||||
# See MySQLdb/cursors.py in the source distribution.
|
||||
# MySQLdb returns string, PyMySQL bytes.
|
||||
return force_str(getattr(cursor, '_executed', None), errors='replace')
|
||||
|
||||
def no_limit_value(self):
|
||||
# 2**64 - 1, as recommended by the MySQL documentation
|
||||
return 18446744073709551615
|
||||
|
||||
def quote_name(self, name):
|
||||
if name.startswith("`") and name.endswith("`"):
|
||||
return name # Quoting once is enough.
|
||||
return "`%s`" % name
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
# MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
|
||||
# statement.
|
||||
if not fields:
|
||||
return '', ()
|
||||
columns = [
|
||||
'%s.%s' % (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
) for field in fields
|
||||
]
|
||||
return 'RETURNING %s' % ', '.join(columns), ()
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if not tables:
|
||||
return []
|
||||
|
||||
sql = ['SET FOREIGN_KEY_CHECKS = 0;']
|
||||
if reset_sequences:
|
||||
# It's faster to TRUNCATE tables that require a sequence reset
|
||||
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
|
||||
sql.extend(
|
||||
'%s %s;' % (
|
||||
style.SQL_KEYWORD('TRUNCATE'),
|
||||
style.SQL_FIELD(self.quote_name(table_name)),
|
||||
) for table_name in tables
|
||||
)
|
||||
else:
|
||||
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
|
||||
# and preserves sequences.
|
||||
sql.extend(
|
||||
'%s %s %s;' % (
|
||||
style.SQL_KEYWORD('DELETE'),
|
||||
style.SQL_KEYWORD('FROM'),
|
||||
style.SQL_FIELD(self.quote_name(table_name)),
|
||||
) for table_name in tables
|
||||
)
|
||||
sql.append('SET FOREIGN_KEY_CHECKS = 1;')
|
||||
return sql
|
||||
|
||||
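sql_flush() above is the backend half of `manage.py flush`/`sqlflush`; a sketch of calling it directly with placeholder table names:

from django.core.management.color import no_style
from django.db import connection

statements = connection.ops.sql_flush(
    no_style(),
    ['app_book', 'app_author'],
    reset_sequences=True,
)
# ['SET FOREIGN_KEY_CHECKS = 0;',
#  'TRUNCATE `app_book`;',
#  'TRUNCATE `app_author`;',
#  'SET FOREIGN_KEY_CHECKS = 1;']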
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
return [
|
||||
'%s %s %s %s = 1;' % (
|
||||
style.SQL_KEYWORD('ALTER'),
|
||||
style.SQL_KEYWORD('TABLE'),
|
||||
style.SQL_FIELD(self.quote_name(sequence_info['table'])),
|
||||
style.SQL_FIELD('AUTO_INCREMENT'),
|
||||
) for sequence_info in sequences
|
||||
]
|
||||
|
||||
def validate_autopk_value(self, value):
|
||||
# Zero in AUTO_INCREMENT field does not work without the
|
||||
# NO_AUTO_VALUE_ON_ZERO SQL mode.
|
||||
if value == 0 and not self.connection.features.allows_auto_pk_0:
|
||||
raise ValueError('The database backend does not accept 0 as a '
|
||||
'value for AutoField.')
|
||||
return value
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.")
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("MySQL backend does not support timezone-aware times.")
|
||||
|
||||
return value.isoformat(timespec='microseconds')
|
||||
|
||||
def max_name_length(self):
|
||||
return 64
|
||||
|
||||
def pk_default_value(self):
|
||||
return 'NULL'
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
|
||||
values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
|
||||
return "VALUES " + values_sql
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
if connector == '^':
|
||||
return 'POW(%s)' % ','.join(sub_expressions)
|
||||
# Convert the result to a signed integer since MySQL's binary operators
|
||||
# return an unsigned integer.
|
||||
elif connector in ('&', '|', '<<', '#'):
|
||||
connector = '^' if connector == '#' else connector
|
||||
return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions)
|
||||
elif connector == '>>':
|
||||
lhs, rhs = sub_expressions
|
||||
return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == 'BooleanField':
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
elif internal_type == 'DateTimeField':
|
||||
if settings.USE_TZ:
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == 'UUIDField':
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
return converters
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
if value in (0, 1):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
def binary_placeholder_sql(self, value):
|
||||
return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
if internal_type == 'TimeField':
|
||||
if self.connection.mysql_is_mariadb:
|
||||
# MariaDB includes the microsecond component in TIME_TO_SEC as
|
||||
# a decimal. MySQL returns an integer without microseconds.
|
||||
return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % {
|
||||
'lhs': lhs_sql, 'rhs': rhs_sql
|
||||
}, (*lhs_params, *rhs_params)
|
||||
return (
|
||||
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
|
||||
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
|
||||
) % {'lhs': lhs_sql, 'rhs': rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2
|
||||
params = (*rhs_params, *lhs_params)
|
||||
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
|
||||
|
||||
def explain_query_prefix(self, format=None, **options):
|
||||
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
|
||||
if format and format.upper() == 'TEXT':
|
||||
format = 'TRADITIONAL'
|
||||
elif not format and 'TREE' in self.connection.features.supported_explain_formats:
|
||||
# Use TREE by default (if supported) as it's more informative.
|
||||
format = 'TREE'
|
||||
analyze = options.pop('analyze', False)
|
||||
prefix = super().explain_query_prefix(format, **options)
|
||||
if analyze and self.connection.features.supports_explain_analyze:
|
||||
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
|
||||
prefix = 'ANALYZE' if self.connection.mysql_is_mariadb else prefix + ' ANALYZE'
|
||||
if format and not (analyze and not self.connection.mysql_is_mariadb):
|
||||
# Only MariaDB supports the analyze option with formats.
|
||||
prefix += ' FORMAT=%s' % format
|
||||
return prefix
|
||||
|
||||
def regex_lookup(self, lookup_type):
|
||||
# REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
|
||||
# doesn't exist in MySQL 5.x or in MariaDB.
|
||||
if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb:
|
||||
if lookup_type == 'regex':
|
||||
return '%s REGEXP BINARY %s'
|
||||
return '%s REGEXP %s'
|
||||
|
||||
match_option = 'c' if lookup_type == 'regex' else 'i'
|
||||
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
|
||||
|
||||
def insert_statement(self, ignore_conflicts=False):
|
||||
return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
lookup = '%s'
|
||||
if internal_type == 'JSONField':
|
||||
if self.connection.mysql_is_mariadb or lookup_type in (
|
||||
'iexact', 'contains', 'icontains', 'startswith', 'istartswith',
|
||||
'endswith', 'iendswith', 'regex', 'iregex',
|
||||
):
|
||||
lookup = 'JSON_UNQUOTE(%s)'
|
||||
return lookup
|
||||
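explain_query_prefix() in this file backs QuerySet.explain(); a hedged example against a hypothetical Book model, keeping in mind the comments above: MariaDB runs ANALYZE instead of EXPLAIN ANALYZE, and only MariaDB combines ANALYZE with an explicit FORMAT.

from example_app.models import Book  # hypothetical model

# Defaults to FORMAT=TREE on MySQL >= 8.0.16; analyze=True needs
# MySQL >= 8.0.18 or MariaDB.
plan = Book.objects.filter(title__icontains='django').explain(analyze=True)
print(plan)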
160
venv/Lib/site-packages/django/db/backends/mysql/schema.py
Normal file
@@ -0,0 +1,160 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import NOT_PROVIDED


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

    sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"

    sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
    sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
    sql_alter_column_type = "MODIFY %(column)s %(type)s"
    sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
    sql_alter_column_no_default_null = 'ALTER COLUMN %(column)s SET DEFAULT NULL'

    # No 'CASCADE' which works as a no-op in MySQL but is undocumented
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"

    sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
    sql_create_column_inline_fk = (
        ', ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) '
        'REFERENCES %(to_table)s(%(to_column)s)'
    )
    sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"

    sql_delete_index = "DROP INDEX %(name)s ON %(table)s"

    sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
    sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"

    sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s'

    @property
    def sql_delete_check(self):
        if self.connection.mysql_is_mariadb:
            # The name of the column check constraint is the same as the field
            # name on MariaDB. Adding IF EXISTS clause prevents migrations
            # crash. Constraint is removed during a "MODIFY" column statement.
            return 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s'
        return 'ALTER TABLE %(table)s DROP CHECK %(name)s'

    @property
    def sql_rename_column(self):
        # MariaDB >= 10.5.2 and MySQL >= 8.0.4 support an
        # "ALTER TABLE ... RENAME COLUMN" statement.
        if self.connection.mysql_is_mariadb:
            if self.connection.mysql_version >= (10, 5, 2):
                return super().sql_rename_column
        elif self.connection.mysql_version >= (8, 0, 4):
            return super().sql_rename_column
        return 'ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s'

    def quote_value(self, value):
        self.connection.ensure_connection()
        if isinstance(value, str):
            value = value.replace('%', '%%')
        # MySQLdb escapes to string, PyMySQL to bytes.
        quoted = self.connection.connection.escape(value, self.connection.connection.encoders)
        if isinstance(value, str) and isinstance(quoted, bytes):
            quoted = quoted.decode()
        return quoted

    def _is_limited_data_type(self, field):
        db_type = field.db_type(self.connection)
        return db_type is not None and db_type.lower() in self.connection._limited_data_types

    def skip_default(self, field):
        if not self._supports_limited_data_type_defaults:
            return self._is_limited_data_type(field)
        return False

    def skip_default_on_alter(self, field):
        if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
            # MySQL doesn't support defaults for BLOB and TEXT in the
            # ALTER COLUMN statement.
            return True
        return False

    @property
    def _supports_limited_data_type_defaults(self):
        # MariaDB >= 10.2.1 and MySQL >= 8.0.13 supports defaults for BLOB
        # and TEXT.
        if self.connection.mysql_is_mariadb:
            return self.connection.mysql_version >= (10, 2, 1)
        return self.connection.mysql_version >= (8, 0, 13)

    def _column_default_sql(self, field):
        if (
            not self.connection.mysql_is_mariadb and
            self._supports_limited_data_type_defaults and
            self._is_limited_data_type(field)
        ):
            # MySQL supports defaults for BLOB and TEXT columns only if the
            # default value is written as an expression i.e. in parentheses.
            return '(%s)'
        return super()._column_default_sql(field)

    def add_field(self, model, field):
        super().add_field(model, field)

        # Simulate the effect of a one-off default.
        # field.default may be unhashable, so a set isn't used for "in" check.
        if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
            effective_default = self.effective_default(field)
            self.execute('UPDATE %(table)s SET %(column)s = %%s' % {
                'table': self.quote_name(model._meta.db_table),
                'column': self.quote_name(field.column),
            }, [effective_default])

    def _field_should_be_indexed(self, model, field):
        if not super()._field_should_be_indexed(model, field):
            return False

        storage = self.connection.introspection.get_storage_engine(
            self.connection.cursor(), model._meta.db_table
        )
        # No need to create an index for ForeignKey fields except if
        # db_constraint=False because the index from that constraint won't be
        # created.
        if (storage == "InnoDB" and
                field.get_internal_type() == 'ForeignKey' and
                field.db_constraint):
            return False
        return not self._is_limited_data_type(field)

    def _delete_composed_index(self, model, fields, *args):
        """
        MySQL can remove an implicit FK index on a field when that field is
        covered by another index like a unique_together. "covered" here means
        that the more complex index starts like the simpler one.
        http://bugs.mysql.com/bug.php?id=37910 / Django ticket #24757
        We check here before removing the [unique|index]_together if we have to
        recreate a FK index.
        """
        first_field = model._meta.get_field(fields[0])
        if first_field.get_internal_type() == 'ForeignKey':
            constraint_names = self._constraint_names(model, [first_field.column], index=True)
            if not constraint_names:
                self.execute(
                    self._create_index_sql(model, fields=[first_field], suffix='')
                )
        return super()._delete_composed_index(model, fields, *args)

    def _set_field_new_type_null_status(self, field, new_type):
        """
        Keep the null property of the old field. If it has changed, it will be
        handled separately.
        """
        if field.null:
            new_type += " NULL"
        else:
            new_type += " NOT NULL"
        return new_type

    def _alter_column_type_sql(self, model, old_field, new_field, new_type):
        new_type = self._set_field_new_type_null_status(old_field, new_type)
        return super()._alter_column_type_sql(model, old_field, new_field, new_type)

    def _rename_field_sql(self, table, old_field, new_field, new_type):
        new_type = self._set_field_new_type_null_status(old_field, new_type)
        return super()._rename_field_sql(table, old_field, new_field, new_type)
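One consequence of the skip_default()/add_field() pair above: on MySQL versions without BLOB/TEXT default support, a text column's one-off default is emulated with a follow-up UPDATE instead of a DEFAULT clause. A minimal sketch of the statements a migration would be expected to run, assuming MySQL < 8.0.13 and an invented table/field name:

# Hypothetical illustration, assuming MySQL < 8.0.13 and a TextField with
# default='' being added to a table named `app_article` (names invented):
#   1) ALTER TABLE `app_article` ADD COLUMN `body` longtext ...   -- DEFAULT clause skipped
#   2) UPDATE `app_article` SET `body` = %s   (params: [''])      -- one-off default applied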
@@ -0,0 +1,69 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
from django.utils.version import get_docs_version


class DatabaseValidation(BaseDatabaseValidation):
    def check(self, **kwargs):
        issues = super().check(**kwargs)
        issues.extend(self._check_sql_mode(**kwargs))
        return issues

    def _check_sql_mode(self, **kwargs):
        if not (self.connection.sql_mode & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}):
            return [checks.Warning(
                "%s Strict Mode is not set for database connection '%s'"
                % (self.connection.display_name, self.connection.alias),
                hint=(
                    "%s's Strict Mode fixes many data integrity problems in "
                    "%s, such as data truncation upon insertion, by "
                    "escalating warnings into errors. It is strongly "
                    "recommended you activate it. See: "
                    "https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode"
                    % (
                        self.connection.display_name,
                        self.connection.display_name,
                        get_docs_version(),
                    ),
                ),
                id='mysql.W002',
            )]
        return []

    def check_field_type(self, field, field_type):
        """
        MySQL has the following field length restriction:
        No character (varchar) fields can have a length exceeding 255
        characters if they have a unique index on them.
        MySQL doesn't support a database index on some data types.
        """
        errors = []
        if (field_type.startswith('varchar') and field.unique and
                (field.max_length is None or int(field.max_length) > 255)):
            errors.append(
                checks.Warning(
                    '%s may not allow unique CharFields to have a max_length '
                    '> 255.' % self.connection.display_name,
                    obj=field,
                    hint=(
                        'See: https://docs.djangoproject.com/en/%s/ref/'
                        'databases/#mysql-character-fields' % get_docs_version()
                    ),
                    id='mysql.W003',
                )
            )

        if field.db_index and field_type.lower() in self.connection._limited_data_types:
            errors.append(
                checks.Warning(
                    '%s does not support a database index on %s columns.'
                    % (self.connection.display_name, field_type),
                    hint=(
                        "An index won't be created. Silence this warning if "
                        "you don't care about it."
                    ),
                    obj=field,
                    id='fields.W162',
                )
            )
        return errors
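The mysql.W002 check above fires whenever neither strict mode flag is in the connection's sql_mode set. The documented way to satisfy it is an init_command in the database OPTIONS; a minimal settings sketch (database name and credentials are placeholders):

# settings.py -- placeholder names; only the OPTIONS key matters here.
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.mysql',
        'NAME': 'mydb',
        'OPTIONS': {
            # Escalate truncation warnings into errors so mysql.W002 no longer fires.
            'init_command': "SET sql_mode='STRICT_TRANS_TABLES'",
        },
    }
}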
554
venv/Lib/site-packages/django/db/backends/oracle/base.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""
|
||||
Oracle database backend for Django.
|
||||
|
||||
Requires cx_Oracle: https://oracle.github.io/python-cx_Oracle/
|
||||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import os
|
||||
import platform
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import IntegrityError
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.encoding import force_bytes, force_str
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
def _setup_environment(environ):
|
||||
# Cygwin requires some special voodoo to set the environment variables
|
||||
# properly so that Oracle will see them.
|
||||
if platform.system().upper().startswith('CYGWIN'):
|
||||
try:
|
||||
import ctypes
|
||||
except ImportError as e:
|
||||
raise ImproperlyConfigured("Error loading ctypes: %s; "
|
||||
"the Oracle backend requires ctypes to "
|
||||
"operate correctly under Cygwin." % e)
|
||||
kernel32 = ctypes.CDLL('kernel32')
|
||||
for name, value in environ:
|
||||
kernel32.SetEnvironmentVariableA(name, value)
|
||||
else:
|
||||
os.environ.update(environ)
|
||||
|
||||
|
||||
_setup_environment([
|
||||
# Oracle takes client-side character set encoding from the environment.
|
||||
('NLS_LANG', '.AL32UTF8'),
|
||||
# This prevents Unicode from getting mangled by getting encoded into the
|
||||
# potentially non-Unicode database character set.
|
||||
('ORA_NCHAR_LITERAL_REPLACE', 'TRUE'),
|
||||
])
|
||||
|
||||
|
||||
try:
|
||||
import cx_Oracle as Database
|
||||
except ImportError as e:
|
||||
raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
|
||||
|
||||
# Some of these import cx_Oracle, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient # NOQA
|
||||
from .creation import DatabaseCreation # NOQA
|
||||
from .features import DatabaseFeatures # NOQA
|
||||
from .introspection import DatabaseIntrospection # NOQA
|
||||
from .operations import DatabaseOperations # NOQA
|
||||
from .schema import DatabaseSchemaEditor # NOQA
|
||||
from .utils import Oracle_datetime, dsn # NOQA
|
||||
from .validation import DatabaseValidation # NOQA
|
||||
|
||||
|
||||
@contextmanager
|
||||
def wrap_oracle_errors():
|
||||
try:
|
||||
yield
|
||||
except Database.DatabaseError as e:
|
||||
# cx_Oracle raises a cx_Oracle.DatabaseError exception with the
|
||||
# following attributes and values:
|
||||
# code = 2091
|
||||
# message = 'ORA-02091: transaction rolled back
|
||||
# 'ORA-02291: integrity constraint (TEST_DJANGOTEST.SYS
|
||||
# _C00102056) violated - parent key not found'
|
||||
# or:
|
||||
# 'ORA-00001: unique constraint (DJANGOTEST.DEFERRABLE_
|
||||
# PINK_CONSTRAINT) violated
|
||||
# Convert that case to Django's IntegrityError exception.
|
||||
x = e.args[0]
|
||||
if (
|
||||
hasattr(x, 'code') and
|
||||
hasattr(x, 'message') and
|
||||
x.code == 2091 and
|
||||
('ORA-02291' in x.message or 'ORA-00001' in x.message)
|
||||
):
|
||||
raise IntegrityError(*tuple(e.args))
|
||||
raise
|
||||
|
||||
|
||||
class _UninitializedOperatorsDescriptor:
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
# If connection.operators is looked up before a connection has been
|
||||
# created, transparently initialize connection.operators to avert an
|
||||
# AttributeError.
|
||||
if instance is None:
|
||||
raise AttributeError("operators not available as class attribute")
|
||||
# Creating a cursor will initialize the operators.
|
||||
instance.cursor().close()
|
||||
return instance.__dict__['operators']
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = 'oracle'
|
||||
display_name = 'Oracle'
|
||||
# This dictionary maps Field objects to their associated Oracle column
|
||||
# types, as strings. Column-type strings can contain format strings; they'll
|
||||
# be interpolated against the values of Field.__dict__ before being output.
|
||||
# If a column type is set to None, it won't be included in the output.
|
||||
#
|
||||
# Any format strings starting with "qn_" are quoted before being used in the
|
||||
# output (the "qn_" prefix is stripped before the lookup is performed.
|
||||
data_types = {
|
||||
'AutoField': 'NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY',
|
||||
'BigAutoField': 'NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY',
|
||||
'BinaryField': 'BLOB',
|
||||
'BooleanField': 'NUMBER(1)',
|
||||
'CharField': 'NVARCHAR2(%(max_length)s)',
|
||||
'DateField': 'DATE',
|
||||
'DateTimeField': 'TIMESTAMP',
|
||||
'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
|
||||
'DurationField': 'INTERVAL DAY(9) TO SECOND(6)',
|
||||
'FileField': 'NVARCHAR2(%(max_length)s)',
|
||||
'FilePathField': 'NVARCHAR2(%(max_length)s)',
|
||||
'FloatField': 'DOUBLE PRECISION',
|
||||
'IntegerField': 'NUMBER(11)',
|
||||
'JSONField': 'NCLOB',
|
||||
'BigIntegerField': 'NUMBER(19)',
|
||||
'IPAddressField': 'VARCHAR2(15)',
|
||||
'GenericIPAddressField': 'VARCHAR2(39)',
|
||||
'OneToOneField': 'NUMBER(11)',
|
||||
'PositiveBigIntegerField': 'NUMBER(19)',
|
||||
'PositiveIntegerField': 'NUMBER(11)',
|
||||
'PositiveSmallIntegerField': 'NUMBER(11)',
|
||||
'SlugField': 'NVARCHAR2(%(max_length)s)',
|
||||
'SmallAutoField': 'NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY',
|
||||
'SmallIntegerField': 'NUMBER(11)',
|
||||
'TextField': 'NCLOB',
|
||||
'TimeField': 'TIMESTAMP',
|
||||
'URLField': 'VARCHAR2(%(max_length)s)',
|
||||
'UUIDField': 'VARCHAR2(32)',
|
||||
}
|
||||
data_type_check_constraints = {
|
||||
'BooleanField': '%(qn_column)s IN (0,1)',
|
||||
'JSONField': '%(qn_column)s IS JSON',
|
||||
'PositiveBigIntegerField': '%(qn_column)s >= 0',
|
||||
'PositiveIntegerField': '%(qn_column)s >= 0',
|
||||
'PositiveSmallIntegerField': '%(qn_column)s >= 0',
|
||||
}
|
||||
|
||||
# Oracle doesn't support a database index on these columns.
|
||||
_limited_data_types = ('clob', 'nclob', 'blob')
|
||||
|
||||
operators = _UninitializedOperatorsDescriptor()
|
||||
|
||||
_standard_operators = {
|
||||
'exact': '= %s',
|
||||
'iexact': '= UPPER(%s)',
|
||||
'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
|
||||
'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
|
||||
'gt': '> %s',
|
||||
'gte': '>= %s',
|
||||
'lt': '< %s',
|
||||
'lte': '<= %s',
|
||||
'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
|
||||
'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
|
||||
'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
|
||||
'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
|
||||
}
|
||||
|
||||
_likec_operators = {
|
||||
**_standard_operators,
|
||||
'contains': "LIKEC %s ESCAPE '\\'",
|
||||
'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
|
||||
'startswith': "LIKEC %s ESCAPE '\\'",
|
||||
'endswith': "LIKEC %s ESCAPE '\\'",
|
||||
'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
|
||||
'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
|
||||
}
|
||||
|
||||
# The patterns below are used to generate SQL pattern lookup clauses when
|
||||
# the right-hand side of the lookup isn't a raw string (it might be an expression
|
||||
# or the result of a bilateral transformation).
|
||||
# In those cases, special characters for LIKE operators (e.g. \, %, _)
|
||||
# should be escaped on the database side.
|
||||
#
|
||||
# Note: we use str.format() here for readability as '%' is used as a wildcard for
|
||||
# the LIKE operator.
|
||||
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
|
||||
_pattern_ops = {
|
||||
'contains': "'%%' || {} || '%%'",
|
||||
'icontains': "'%%' || UPPER({}) || '%%'",
|
||||
'startswith': "{} || '%%'",
|
||||
'istartswith': "UPPER({}) || '%%'",
|
||||
'endswith': "'%%' || {}",
|
||||
'iendswith': "'%%' || UPPER({})",
|
||||
}
|
||||
|
||||
_standard_pattern_ops = {k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
|
||||
" ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
|
||||
for k, v in _pattern_ops.items()}
|
||||
_likec_pattern_ops = {k: "LIKEC " + v + " ESCAPE '\\'"
|
||||
for k, v in _pattern_ops.items()}
|
||||
|
||||
Database = Database
|
||||
SchemaEditorClass = DatabaseSchemaEditor
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
validation_class = DatabaseValidation
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
|
||||
self.features.can_return_columns_from_insert = use_returning_into
|
||||
|
||||
def get_connection_params(self):
|
||||
conn_params = self.settings_dict['OPTIONS'].copy()
|
||||
if 'use_returning_into' in conn_params:
|
||||
del conn_params['use_returning_into']
|
||||
return conn_params
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
return Database.connect(
|
||||
user=self.settings_dict['USER'],
|
||||
password=self.settings_dict['PASSWORD'],
|
||||
dsn=dsn(self.settings_dict),
|
||||
**conn_params,
|
||||
)
|
||||
|
||||
def init_connection_state(self):
|
||||
cursor = self.create_cursor()
|
||||
# Set the territory first. The territory overrides NLS_DATE_FORMAT
|
||||
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
|
||||
# these are set in single statement it isn't clear what is supposed
|
||||
# to happen.
|
||||
cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
|
||||
# Set Oracle date to ANSI date format. This only needs to execute
|
||||
# once when we create a new connection. We also set the Territory
|
||||
# to 'AMERICA' which forces Sunday to evaluate to a '1' in
|
||||
# TO_CHAR().
|
||||
cursor.execute(
|
||||
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
|
||||
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" +
|
||||
(" TIME_ZONE = 'UTC'" if settings.USE_TZ else '')
|
||||
)
|
||||
cursor.close()
|
||||
if 'operators' not in self.__dict__:
|
||||
# Ticket #14149: Check whether our LIKE implementation will
|
||||
# work for this connection or we need to fall back on LIKEC.
|
||||
# This check is performed only once per DatabaseWrapper
|
||||
# instance per thread, since subsequent connections will use
|
||||
# the same settings.
|
||||
cursor = self.create_cursor()
|
||||
try:
|
||||
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
|
||||
% self._standard_operators['contains'],
|
||||
['X'])
|
||||
except Database.DatabaseError:
|
||||
self.operators = self._likec_operators
|
||||
self.pattern_ops = self._likec_pattern_ops
|
||||
else:
|
||||
self.operators = self._standard_operators
|
||||
self.pattern_ops = self._standard_pattern_ops
|
||||
cursor.close()
|
||||
self.connection.stmtcachesize = 20
|
||||
# Ensure all changes are preserved even when AUTOCOMMIT is False.
|
||||
if not self.get_autocommit():
|
||||
self.commit()
|
||||
|
||||
@async_unsafe
|
||||
def create_cursor(self, name=None):
|
||||
return FormatStylePlaceholderCursor(self.connection)
|
||||
|
||||
def _commit(self):
|
||||
if self.connection is not None:
|
||||
with wrap_oracle_errors():
|
||||
return self.connection.commit()
|
||||
|
||||
# Oracle doesn't support releasing savepoints. But we fake them when query
|
||||
# logging is enabled to keep query counts consistent with other backends.
|
||||
def _savepoint_commit(self, sid):
|
||||
if self.queries_logged:
|
||||
self.queries_log.append({
|
||||
'sql': '-- RELEASE SAVEPOINT %s (faked)' % self.ops.quote_name(sid),
|
||||
'time': '0.000',
|
||||
})
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
with self.wrap_database_errors:
|
||||
self.connection.autocommit = autocommit
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check constraints by setting them to immediate. Return them to deferred
|
||||
afterward.
|
||||
"""
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
|
||||
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
|
||||
|
||||
def is_usable(self):
|
||||
try:
|
||||
self.connection.ping()
|
||||
except Database.Error:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def cx_oracle_version(self):
|
||||
return tuple(int(x) for x in Database.version.split('.'))
|
||||
|
||||
@cached_property
|
||||
def oracle_version(self):
|
||||
with self.temporary_connection():
|
||||
return tuple(int(x) for x in self.connection.version.split('.'))
|
||||
|
||||
|
||||
class OracleParam:
|
||||
"""
|
||||
Wrapper object for formatting parameters for Oracle. If the string
|
||||
representation of the value is large enough (greater than 4000 characters)
|
||||
the input size needs to be set as CLOB. Alternatively, if the parameter
|
||||
has an `input_size` attribute, then the value of the `input_size` attribute
|
||||
will be used instead. Otherwise, no input size will be set for the
|
||||
parameter when executing the query.
|
||||
"""
|
||||
|
||||
def __init__(self, param, cursor, strings_only=False):
|
||||
# With raw SQL queries, datetimes can reach this function
|
||||
# without being converted by DateTimeField.get_db_prep_value.
|
||||
if settings.USE_TZ and (isinstance(param, datetime.datetime) and
|
||||
not isinstance(param, Oracle_datetime)):
|
||||
param = Oracle_datetime.from_datetime(param)
|
||||
|
||||
string_size = 0
|
||||
# Oracle doesn't recognize True and False correctly.
|
||||
if param is True:
|
||||
param = 1
|
||||
elif param is False:
|
||||
param = 0
|
||||
if hasattr(param, 'bind_parameter'):
|
||||
self.force_bytes = param.bind_parameter(cursor)
|
||||
elif isinstance(param, (Database.Binary, datetime.timedelta)):
|
||||
self.force_bytes = param
|
||||
else:
|
||||
# To transmit to the database, we need Unicode if supported
|
||||
# To get size right, we must consider bytes.
|
||||
self.force_bytes = force_str(param, cursor.charset, strings_only)
|
||||
if isinstance(self.force_bytes, str):
|
||||
# We could optimize by only converting up to 4000 bytes here
|
||||
string_size = len(force_bytes(param, cursor.charset, strings_only))
|
||||
if hasattr(param, 'input_size'):
|
||||
# If parameter has `input_size` attribute, use that.
|
||||
self.input_size = param.input_size
|
||||
elif string_size > 4000:
|
||||
# Mark any string param greater than 4000 characters as a CLOB.
|
||||
self.input_size = Database.CLOB
|
||||
elif isinstance(param, datetime.datetime):
|
||||
self.input_size = Database.TIMESTAMP
|
||||
else:
|
||||
self.input_size = None
|
||||
|
||||
|
||||
class VariableWrapper:
|
||||
"""
|
||||
An adapter class for cursor variables that prevents the wrapped object
|
||||
from being converted into a string when used to instantiate an OracleParam.
|
||||
This can be used generally for any other object that should be passed into
|
||||
Cursor.execute as-is.
|
||||
"""
|
||||
|
||||
def __init__(self, var):
|
||||
self.var = var
|
||||
|
||||
def bind_parameter(self, cursor):
|
||||
return self.var
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.var, key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == 'var':
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
setattr(self.var, key, value)
|
||||
|
||||
|
||||
class FormatStylePlaceholderCursor:
|
||||
"""
|
||||
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
||||
style. This fixes it -- but note that if you want to use a literal "%s" in
|
||||
a query, you'll need to use "%%s".
|
||||
"""
|
||||
charset = 'utf-8'
|
||||
|
||||
def __init__(self, connection):
|
||||
self.cursor = connection.cursor()
|
||||
self.cursor.outputtypehandler = self._output_type_handler
|
||||
|
||||
@staticmethod
|
||||
def _output_number_converter(value):
|
||||
return decimal.Decimal(value) if '.' in value else int(value)
|
||||
|
||||
@staticmethod
|
||||
def _get_decimal_converter(precision, scale):
|
||||
if scale == 0:
|
||||
return int
|
||||
context = decimal.Context(prec=precision)
|
||||
quantize_value = decimal.Decimal(1).scaleb(-scale)
|
||||
return lambda v: decimal.Decimal(v).quantize(quantize_value, context=context)
|
||||
|
||||
@staticmethod
|
||||
def _output_type_handler(cursor, name, defaultType, length, precision, scale):
|
||||
"""
|
||||
Called for each db column fetched from cursors. Return numbers as the
|
||||
appropriate Python type.
|
||||
"""
|
||||
if defaultType == Database.NUMBER:
|
||||
if scale == -127:
|
||||
if precision == 0:
|
||||
# NUMBER column: decimal-precision floating point.
|
||||
# This will normally be an integer from a sequence,
|
||||
# but it could be a decimal value.
|
||||
outconverter = FormatStylePlaceholderCursor._output_number_converter
|
||||
else:
|
||||
# FLOAT column: binary-precision floating point.
|
||||
# This comes from FloatField columns.
|
||||
outconverter = float
|
||||
elif precision > 0:
|
||||
# NUMBER(p,s) column: decimal-precision fixed point.
|
||||
# This comes from IntegerField and DecimalField columns.
|
||||
outconverter = FormatStylePlaceholderCursor._get_decimal_converter(precision, scale)
|
||||
else:
|
||||
# No type information. This normally comes from a
|
||||
# mathematical expression in the SELECT list. Guess int
|
||||
# or Decimal based on whether it has a decimal point.
|
||||
outconverter = FormatStylePlaceholderCursor._output_number_converter
|
||||
return cursor.var(
|
||||
Database.STRING,
|
||||
size=255,
|
||||
arraysize=cursor.arraysize,
|
||||
outconverter=outconverter,
|
||||
)
|
||||
|
||||
def _format_params(self, params):
|
||||
try:
|
||||
return {k: OracleParam(v, self, True) for k, v in params.items()}
|
||||
except AttributeError:
|
||||
return tuple(OracleParam(p, self, True) for p in params)
|
||||
|
||||
def _guess_input_sizes(self, params_list):
|
||||
# Try dict handling; if that fails, treat as sequence
|
||||
if hasattr(params_list[0], 'keys'):
|
||||
sizes = {}
|
||||
for params in params_list:
|
||||
for k, value in params.items():
|
||||
if value.input_size:
|
||||
sizes[k] = value.input_size
|
||||
if sizes:
|
||||
self.setinputsizes(**sizes)
|
||||
else:
|
||||
# It's not a list of dicts; it's a list of sequences
|
||||
sizes = [None] * len(params_list[0])
|
||||
for params in params_list:
|
||||
for i, value in enumerate(params):
|
||||
if value.input_size:
|
||||
sizes[i] = value.input_size
|
||||
if sizes:
|
||||
self.setinputsizes(*sizes)
|
||||
|
||||
def _param_generator(self, params):
|
||||
# Try dict handling; if that fails, treat as sequence
|
||||
if hasattr(params, 'items'):
|
||||
return {k: v.force_bytes for k, v in params.items()}
|
||||
else:
|
||||
return [p.force_bytes for p in params]
|
||||
|
||||
def _fix_for_params(self, query, params, unify_by_values=False):
|
||||
# cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
|
||||
# it does want a trailing ';' but not a trailing '/'. However, these
|
||||
# characters must be included in the original query in case the query
|
||||
# is being passed to SQL*Plus.
|
||||
if query.endswith(';') or query.endswith('/'):
|
||||
query = query[:-1]
|
||||
if params is None:
|
||||
params = []
|
||||
elif hasattr(params, 'keys'):
|
||||
# Handle params as dict
|
||||
args = {k: ":%s" % k for k in params}
|
||||
query = query % args
|
||||
elif unify_by_values and params:
|
||||
# Handle params as a dict with unified query parameters by their
|
||||
# values. It can be used only in single query execute() because
|
||||
# executemany() shares the formatted query with each of the params
|
||||
# list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75]
|
||||
# params_dict = {0.75: ':arg0', 2: ':arg1', 'sth': ':arg2'}
|
||||
# args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
|
||||
# params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
|
||||
params_dict = {
|
||||
param: ':arg%d' % i
|
||||
for i, param in enumerate(dict.fromkeys(params))
|
||||
}
|
||||
args = [params_dict[param] for param in params]
|
||||
params = {value: key for key, value in params_dict.items()}
|
||||
query = query % tuple(args)
|
||||
else:
|
||||
# Handle params as sequence
|
||||
args = [(':arg%d' % i) for i in range(len(params))]
|
||||
query = query % tuple(args)
|
||||
return query, self._format_params(params)
|
||||
|
||||
def execute(self, query, params=None):
|
||||
query, params = self._fix_for_params(query, params, unify_by_values=True)
|
||||
self._guess_input_sizes([params])
|
||||
with wrap_oracle_errors():
|
||||
return self.cursor.execute(query, self._param_generator(params))
|
||||
|
||||
def executemany(self, query, params=None):
|
||||
if not params:
|
||||
# No params given, nothing to do
|
||||
return None
|
||||
# uniform treatment for sequences and iterables
|
||||
params_iter = iter(params)
|
||||
query, firstparams = self._fix_for_params(query, next(params_iter))
|
||||
# we build a list of formatted params; as we're going to traverse it
|
||||
# more than once, we can't make it lazy by using a generator
|
||||
formatted = [firstparams] + [self._format_params(p) for p in params_iter]
|
||||
self._guess_input_sizes(formatted)
|
||||
with wrap_oracle_errors():
|
||||
return self.cursor.executemany(query, [self._param_generator(p) for p in formatted])
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.cursor.close()
|
||||
except Database.InterfaceError:
|
||||
# already closed
|
||||
pass
|
||||
|
||||
def var(self, *args):
|
||||
return VariableWrapper(self.cursor.var(*args))
|
||||
|
||||
def arrayvar(self, *args):
|
||||
return VariableWrapper(self.cursor.arrayvar(*args))
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.cursor, attr)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
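The placeholder rewriting performed by _fix_for_params() above is the crux of this cursor: Django-style '%s' markers become named ':argN' binds, deduplicated by value for a single execute(). A minimal sketch of the transformation with an invented query, following the worked example already given in the source comments:

# Hypothetical illustration of the rewrite done by _fix_for_params(query, params, unify_by_values=True):
#   query  = 'SELECT * FROM "AUTHOR" WHERE "AGE" > %s AND "RATING" > %s'
#   params = [30, 30]
# The duplicated value is bound once:
#   query  -> 'SELECT * FROM "AUTHOR" WHERE "AGE" > :arg0 AND "RATING" > :arg0'
#   params -> {':arg0': 30}   (each value then wrapped in OracleParam by _format_params)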
27
venv/Lib/site-packages/django/db/backends/oracle/client.py
Normal file
@@ -0,0 +1,27 @@
import shutil

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = 'sqlplus'
    wrapper_name = 'rlwrap'

    @staticmethod
    def connect_string(settings_dict):
        from django.db.backends.oracle.utils import dsn

        return '%s/"%s"@%s' % (
            settings_dict['USER'],
            settings_dict['PASSWORD'],
            dsn(settings_dict),
        )

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name, '-L', cls.connect_string(settings_dict)]
        wrapper_path = shutil.which(cls.wrapper_name)
        if wrapper_path:
            args = [wrapper_path, *args]
        args.extend(parameters)
        return args, None
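For reference, a minimal sketch of the command line this builds for `manage.py dbshell`, assuming rlwrap is installed and using invented credentials and an invented DSN:

# Hypothetical values -- USER/PASSWORD/DSN are placeholders.
# settings_dict = {'USER': 'scott', 'PASSWORD': 'tiger', ...}
# DatabaseClient.settings_to_cmd_args_env(settings_dict, []) ->
#   (['/usr/bin/rlwrap', 'sqlplus', '-L', 'scott/"tiger"@xe'], None)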
400
venv/Lib/site-packages/django/db/backends/oracle/creation.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import sys
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
TEST_DATABASE_PREFIX = 'test_'
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
|
||||
@cached_property
|
||||
def _maindb_connection(self):
|
||||
"""
|
||||
This is analogous to other backends' `_nodb_connection` property,
|
||||
which allows access to an "administrative" connection which can
|
||||
be used to manage the test databases.
|
||||
For Oracle, the only connection that can be used for that purpose
|
||||
is the main (non-test) connection.
|
||||
"""
|
||||
settings_dict = settings.DATABASES[self.connection.alias]
|
||||
user = settings_dict.get('SAVED_USER') or settings_dict['USER']
|
||||
password = settings_dict.get('SAVED_PASSWORD') or settings_dict['PASSWORD']
|
||||
settings_dict = {**settings_dict, 'USER': user, 'PASSWORD': password}
|
||||
DatabaseWrapper = type(self.connection)
|
||||
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
|
||||
|
||||
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
|
||||
parameters = self._get_test_db_params()
|
||||
with self._maindb_connection.cursor() as cursor:
|
||||
if self._test_database_create():
|
||||
try:
|
||||
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
|
||||
except Exception as e:
|
||||
if 'ORA-01543' not in str(e):
|
||||
# All errors except "tablespace already exists" cancel tests
|
||||
self.log('Got an error creating the test database: %s' % e)
|
||||
sys.exit(2)
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"It appears the test database, %s, already exists. "
|
||||
"Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
|
||||
if autoclobber or confirm == 'yes':
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
|
||||
try:
|
||||
self._execute_test_db_destruction(cursor, parameters, verbosity)
|
||||
except DatabaseError as e:
|
||||
if 'ORA-29857' in str(e):
|
||||
self._handle_objects_preventing_db_destruction(cursor, parameters,
|
||||
verbosity, autoclobber)
|
||||
else:
|
||||
# Ran into a database error that isn't about leftover objects in the tablespace
|
||||
self.log('Got an error destroying the old test database: %s' % e)
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
self.log('Got an error destroying the old test database: %s' % e)
|
||||
sys.exit(2)
|
||||
try:
|
||||
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
|
||||
except Exception as e:
|
||||
self.log('Got an error recreating the test database: %s' % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log('Tests cancelled.')
|
||||
sys.exit(1)
|
||||
|
||||
if self._test_user_create():
|
||||
if verbosity >= 1:
|
||||
self.log('Creating test user...')
|
||||
try:
|
||||
self._create_test_user(cursor, parameters, verbosity, keepdb)
|
||||
except Exception as e:
|
||||
if 'ORA-01920' not in str(e):
|
||||
# All errors except "user already exists" cancel tests
|
||||
self.log('Got an error creating the test user: %s' % e)
|
||||
sys.exit(2)
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"It appears the test user, %s, already exists. Type "
|
||||
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
|
||||
if autoclobber or confirm == 'yes':
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying old test user...')
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
if verbosity >= 1:
|
||||
self.log('Creating test user...')
|
||||
self._create_test_user(cursor, parameters, verbosity, keepdb)
|
||||
except Exception as e:
|
||||
self.log('Got an error recreating the test user: %s' % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log('Tests cancelled.')
|
||||
sys.exit(1)
|
||||
self._maindb_connection.close() # done with main user -- test user and tablespaces created
|
||||
self._switch_to_test_user(parameters)
|
||||
return self.connection.settings_dict['NAME']
|
||||
|
||||
def _switch_to_test_user(self, parameters):
|
||||
"""
|
||||
Switch to the user that's used for creating the test database.
|
||||
|
||||
Oracle doesn't have the concept of separate databases under the same
|
||||
user, so a separate user is used; see _create_test_db(). The main user
|
||||
is also needed for cleanup when testing is completed, so save its
|
||||
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
|
||||
"""
|
||||
real_settings = settings.DATABASES[self.connection.alias]
|
||||
real_settings['SAVED_USER'] = self.connection.settings_dict['SAVED_USER'] = \
|
||||
self.connection.settings_dict['USER']
|
||||
real_settings['SAVED_PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] = \
|
||||
self.connection.settings_dict['PASSWORD']
|
||||
real_test_settings = real_settings['TEST']
|
||||
test_settings = self.connection.settings_dict['TEST']
|
||||
real_test_settings['USER'] = real_settings['USER'] = test_settings['USER'] = \
|
||||
self.connection.settings_dict['USER'] = parameters['user']
|
||||
real_settings['PASSWORD'] = self.connection.settings_dict['PASSWORD'] = parameters['password']
|
||||
|
||||
def set_as_test_mirror(self, primary_settings_dict):
|
||||
"""
|
||||
Set this database up to be used in testing as a mirror of a primary
|
||||
database whose settings are given.
|
||||
"""
|
||||
self.connection.settings_dict['USER'] = primary_settings_dict['USER']
|
||||
self.connection.settings_dict['PASSWORD'] = primary_settings_dict['PASSWORD']
|
||||
|
||||
def _handle_objects_preventing_db_destruction(self, cursor, parameters, verbosity, autoclobber):
|
||||
# There are objects in the test tablespace which prevent dropping it
|
||||
# The easy fix is to drop the test user -- but are we allowed to do so?
|
||||
self.log(
|
||||
'There are objects in the old test database which prevent its destruction.\n'
|
||||
'If they belong to the test user, deleting the user will allow the test '
|
||||
'database to be recreated.\n'
|
||||
'Otherwise, you will need to find and remove each of these objects, '
|
||||
'or use a different tablespace.\n'
|
||||
)
|
||||
if self._test_user_create():
|
||||
if not autoclobber:
|
||||
confirm = input("Type 'yes' to delete user %s: " % parameters['user'])
|
||||
if autoclobber or confirm == 'yes':
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying old test user...')
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
except Exception as e:
|
||||
self.log('Got an error destroying the test user: %s' % e)
|
||||
sys.exit(2)
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
|
||||
self._execute_test_db_destruction(cursor, parameters, verbosity)
|
||||
except Exception as e:
|
||||
self.log('Got an error destroying the test database: %s' % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log('Tests cancelled -- test database cannot be recreated.')
|
||||
sys.exit(1)
|
||||
else:
|
||||
self.log("Django is configured to use pre-existing test user '%s',"
|
||||
" and will not attempt to delete it." % parameters['user'])
|
||||
self.log('Tests cancelled -- test database cannot be recreated.')
|
||||
sys.exit(1)
|
||||
|
||||
def _destroy_test_db(self, test_database_name, verbosity=1):
|
||||
"""
|
||||
Destroy a test database, prompting the user for confirmation if the
|
||||
database already exists. Return the name of the test database created.
|
||||
"""
|
||||
self.connection.settings_dict['USER'] = self.connection.settings_dict['SAVED_USER']
|
||||
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
|
||||
self.connection.close()
|
||||
parameters = self._get_test_db_params()
|
||||
with self._maindb_connection.cursor() as cursor:
|
||||
if self._test_user_create():
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying test user...')
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
if self._test_database_create():
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying test database tables...')
|
||||
self._execute_test_db_destruction(cursor, parameters, verbosity)
|
||||
self._maindb_connection.close()
|
||||
|
||||
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
|
||||
if verbosity >= 2:
|
||||
self.log('_create_test_db(): dbname = %s' % parameters['user'])
|
||||
if self._test_database_oracle_managed_files():
|
||||
statements = [
|
||||
"""
|
||||
CREATE TABLESPACE %(tblspace)s
|
||||
DATAFILE SIZE %(size)s
|
||||
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
|
||||
""",
|
||||
"""
|
||||
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
|
||||
TEMPFILE SIZE %(size_tmp)s
|
||||
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
|
||||
""",
|
||||
]
|
||||
else:
|
||||
statements = [
|
||||
"""
|
||||
CREATE TABLESPACE %(tblspace)s
|
||||
DATAFILE '%(datafile)s' SIZE %(size)s REUSE
|
||||
AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
|
||||
""",
|
||||
"""
|
||||
CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
|
||||
TEMPFILE '%(datafile_tmp)s' SIZE %(size_tmp)s REUSE
|
||||
AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
|
||||
""",
|
||||
]
|
||||
# Ignore "tablespace already exists" error when keepdb is on.
|
||||
acceptable_ora_err = 'ORA-01543' if keepdb else None
|
||||
self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
|
||||
|
||||
def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
|
||||
if verbosity >= 2:
|
||||
self.log('_create_test_user(): username = %s' % parameters['user'])
|
||||
statements = [
|
||||
"""CREATE USER %(user)s
|
||||
IDENTIFIED BY "%(password)s"
|
||||
DEFAULT TABLESPACE %(tblspace)s
|
||||
TEMPORARY TABLESPACE %(tblspace_temp)s
|
||||
QUOTA UNLIMITED ON %(tblspace)s
|
||||
""",
|
||||
"""GRANT CREATE SESSION,
|
||||
CREATE TABLE,
|
||||
CREATE SEQUENCE,
|
||||
CREATE PROCEDURE,
|
||||
CREATE TRIGGER
|
||||
TO %(user)s""",
|
||||
]
|
||||
# Ignore "user already exists" error when keepdb is on
|
||||
acceptable_ora_err = 'ORA-01920' if keepdb else None
|
||||
success = self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
|
||||
# If the password was randomly generated, change the user accordingly.
|
||||
if not success and self._test_settings_get('PASSWORD') is None:
|
||||
set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
|
||||
self._execute_statements(cursor, [set_password], parameters, verbosity)
|
||||
# Most test suites can be run without "create view" and
|
||||
# "create materialized view" privileges. But some need it.
|
||||
for object_type in ('VIEW', 'MATERIALIZED VIEW'):
|
||||
extra = 'GRANT CREATE %(object_type)s TO %(user)s'
|
||||
parameters['object_type'] = object_type
|
||||
success = self._execute_allow_fail_statements(cursor, [extra], parameters, verbosity, 'ORA-01031')
|
||||
if not success and verbosity >= 2:
|
||||
self.log('Failed to grant CREATE %s permission to test user. This may be ok.' % object_type)
|
||||
|
||||
def _execute_test_db_destruction(self, cursor, parameters, verbosity):
|
||||
if verbosity >= 2:
|
||||
self.log('_execute_test_db_destruction(): dbname=%s' % parameters['user'])
|
||||
statements = [
|
||||
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
|
||||
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
|
||||
]
|
||||
self._execute_statements(cursor, statements, parameters, verbosity)
|
||||
|
||||
def _destroy_test_user(self, cursor, parameters, verbosity):
|
||||
if verbosity >= 2:
|
||||
self.log('_destroy_test_user(): user=%s' % parameters['user'])
|
||||
self.log('Be patient. This can take some time...')
|
||||
statements = [
|
||||
'DROP USER %(user)s CASCADE',
|
||||
]
|
||||
self._execute_statements(cursor, statements, parameters, verbosity)
|
||||
|
||||
def _execute_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False):
|
||||
for template in statements:
|
||||
stmt = template % parameters
|
||||
if verbosity >= 2:
|
||||
print(stmt)
|
||||
try:
|
||||
cursor.execute(stmt)
|
||||
except Exception as err:
|
||||
if (not allow_quiet_fail) or verbosity >= 2:
|
||||
self.log('Failed (%s)' % (err))
|
||||
raise
|
||||
|
||||
def _execute_allow_fail_statements(self, cursor, statements, parameters, verbosity, acceptable_ora_err):
|
||||
"""
|
||||
Execute statements which are allowed to fail silently if the Oracle
|
||||
error code given by `acceptable_ora_err` is raised. Return True if the
|
||||
statements execute without an exception, or False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Statement can fail when acceptable_ora_err is not None
|
||||
allow_quiet_fail = acceptable_ora_err is not None and len(acceptable_ora_err) > 0
|
||||
self._execute_statements(cursor, statements, parameters, verbosity, allow_quiet_fail=allow_quiet_fail)
|
||||
return True
|
||||
except DatabaseError as err:
|
||||
description = str(err)
|
||||
if acceptable_ora_err is None or acceptable_ora_err not in description:
|
||||
raise
|
||||
return False
|
||||
|
||||
def _get_test_db_params(self):
|
||||
return {
|
||||
'dbname': self._test_database_name(),
|
||||
'user': self._test_database_user(),
|
||||
'password': self._test_database_passwd(),
|
||||
'tblspace': self._test_database_tblspace(),
|
||||
'tblspace_temp': self._test_database_tblspace_tmp(),
|
||||
'datafile': self._test_database_tblspace_datafile(),
|
||||
'datafile_tmp': self._test_database_tblspace_tmp_datafile(),
|
||||
'maxsize': self._test_database_tblspace_maxsize(),
|
||||
'maxsize_tmp': self._test_database_tblspace_tmp_maxsize(),
|
||||
'size': self._test_database_tblspace_size(),
|
||||
'size_tmp': self._test_database_tblspace_tmp_size(),
|
||||
'extsize': self._test_database_tblspace_extsize(),
|
||||
'extsize_tmp': self._test_database_tblspace_tmp_extsize(),
|
||||
}
|
||||
|
||||
def _test_settings_get(self, key, default=None, prefixed=None):
|
||||
"""
|
||||
Return a value from the test settings dict, or a given default, or a
|
||||
prefixed entry from the main settings dict.
|
||||
"""
|
||||
settings_dict = self.connection.settings_dict
|
||||
val = settings_dict['TEST'].get(key, default)
|
||||
if val is None and prefixed:
|
||||
val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
|
||||
return val
|
||||
|
||||
def _test_database_name(self):
|
||||
return self._test_settings_get('NAME', prefixed='NAME')
|
||||
|
||||
def _test_database_create(self):
|
||||
return self._test_settings_get('CREATE_DB', default=True)
|
||||
|
||||
def _test_user_create(self):
|
||||
return self._test_settings_get('CREATE_USER', default=True)
|
||||
|
||||
def _test_database_user(self):
|
||||
return self._test_settings_get('USER', prefixed='USER')
|
||||
|
||||
def _test_database_passwd(self):
|
||||
password = self._test_settings_get('PASSWORD')
|
||||
if password is None and self._test_user_create():
|
||||
# Oracle passwords are limited to 30 chars and can't contain symbols.
|
||||
password = get_random_string(30)
|
||||
return password
|
||||
|
||||
def _test_database_tblspace(self):
|
||||
return self._test_settings_get('TBLSPACE', prefixed='USER')
|
||||
|
||||
def _test_database_tblspace_tmp(self):
|
||||
settings_dict = self.connection.settings_dict
|
||||
return settings_dict['TEST'].get('TBLSPACE_TMP',
|
||||
TEST_DATABASE_PREFIX + settings_dict['USER'] + '_temp')
|
||||
|
||||
def _test_database_tblspace_datafile(self):
|
||||
tblspace = '%s.dbf' % self._test_database_tblspace()
|
||||
return self._test_settings_get('DATAFILE', default=tblspace)
|
||||
|
||||
def _test_database_tblspace_tmp_datafile(self):
|
||||
tblspace = '%s.dbf' % self._test_database_tblspace_tmp()
|
||||
return self._test_settings_get('DATAFILE_TMP', default=tblspace)
|
||||
|
||||
def _test_database_tblspace_maxsize(self):
|
||||
return self._test_settings_get('DATAFILE_MAXSIZE', default='500M')
|
||||
|
||||
def _test_database_tblspace_tmp_maxsize(self):
|
||||
return self._test_settings_get('DATAFILE_TMP_MAXSIZE', default='500M')
|
||||
|
||||
def _test_database_tblspace_size(self):
|
||||
return self._test_settings_get('DATAFILE_SIZE', default='50M')
|
||||
|
||||
def _test_database_tblspace_tmp_size(self):
|
||||
return self._test_settings_get('DATAFILE_TMP_SIZE', default='50M')
|
||||
|
||||
def _test_database_tblspace_extsize(self):
|
||||
return self._test_settings_get('DATAFILE_EXTSIZE', default='25M')
|
||||
|
||||
def _test_database_tblspace_tmp_extsize(self):
|
||||
return self._test_settings_get('DATAFILE_TMP_EXTSIZE', default='25M')
|
||||
|
||||
def _test_database_oracle_managed_files(self):
|
||||
return self._test_settings_get('ORACLE_MANAGED_FILES', default=False)
|
||||
|
||||
def _get_test_db_name(self):
|
||||
"""
|
||||
Return the 'production' DB name to get the test DB creation machinery
|
||||
to work. This isn't a great deal in this case because DB names as
|
||||
handled by Django don't have real counterparts in Oracle.
|
||||
"""
|
||||
return self.connection.settings_dict['NAME']
|
||||
|
||||
def test_db_signature(self):
|
||||
settings_dict = self.connection.settings_dict
|
||||
return (
|
||||
settings_dict['HOST'],
|
||||
settings_dict['PORT'],
|
||||
settings_dict['ENGINE'],
|
||||
settings_dict['NAME'],
|
||||
self._test_database_user(),
|
||||
)
|
||||
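All of the _test_settings_get() helpers above read their knobs from the 'TEST' sub-dict of the database settings, falling back to 'test_'-prefixed values or the hard-coded sizes shown. A minimal sketch of overriding a few of them (values are illustrative, taken from the defaults in the code above):

# settings.py fragment -- illustrative values only.
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.oracle',
        # ...
        'TEST': {
            'USER': 'test_scott',            # default: 'test_' + USER
            'TBLSPACE': 'test_scott',        # default: 'test_' + USER
            'DATAFILE_SIZE': '50M',
            'DATAFILE_MAXSIZE': '500M',
            'ORACLE_MANAGED_FILES': False,   # True: omit the explicit DATAFILE clauses
        },
    }
}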
120
venv/Lib/site-packages/django/db/backends/oracle/features.py
Normal file
@@ -0,0 +1,120 @@
from django.db import DatabaseError, InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property


class DatabaseFeatures(BaseDatabaseFeatures):
    # Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
    # BLOB" when grouping by LOBs (#24096).
    allows_group_by_lob = False
    interprets_empty_strings_as_nulls = True
    has_select_for_update = True
    has_select_for_update_nowait = True
    has_select_for_update_skip_locked = True
    has_select_for_update_of = True
    select_for_update_of_column = True
    can_return_columns_from_insert = True
    supports_subqueries_in_group_by = False
    ignores_unnecessary_order_by_in_subqueries = False
    supports_transactions = True
    supports_timezones = False
    has_native_duration_field = True
    can_defer_constraint_checks = True
    supports_partially_nullable_unique_constraints = False
    supports_deferrable_unique_constraints = True
    truncates_names = True
    supports_tablespaces = True
    supports_sequence_reset = False
    can_introspect_materialized_views = True
    atomic_transactions = False
    supports_combined_alters = False
    nulls_order_largest = True
    requires_literal_defaults = True
    closed_cursor_error_class = InterfaceError
    bare_select_suffix = " FROM DUAL"
    # select for update with limit can be achieved on Oracle, but not with the current backend.
    supports_select_for_update_with_limit = False
    supports_temporal_subtraction = True
    # Oracle doesn't ignore quoted identifiers case but the current backend
    # does by uppercasing all identifiers.
    ignores_table_name_case = True
    supports_index_on_text_field = False
    has_case_insensitive_like = False
    create_test_procedure_without_params_sql = """
        CREATE PROCEDURE "TEST_PROCEDURE" AS
            V_I INTEGER;
        BEGIN
            V_I := 1;
        END;
    """
    create_test_procedure_with_int_param_sql = """
        CREATE PROCEDURE "TEST_PROCEDURE" (P_I INTEGER) AS
            V_I INTEGER;
        BEGIN
            V_I := P_I;
        END;
    """
    supports_callproc_kwargs = True
    supports_over_clause = True
    supports_frame_range_fixed_distance = True
    supports_ignore_conflicts = False
    max_query_params = 2**16 - 1
    supports_partial_indexes = False
    supports_slicing_ordering_in_compound = True
    allows_multiple_constraints_on_same_fields = False
    supports_boolean_expr_in_select_clause = False
    supports_primitives_in_json_field = False
    supports_json_field_contains = False
    supports_collation_on_textfield = False
    test_collations = {
        'ci': 'BINARY_CI',
        'cs': 'BINARY',
        'non_default': 'SWEDISH_CI',
        'swedish_ci': 'SWEDISH_CI',
    }
    test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"

    django_test_skips = {
        "Oracle doesn't support SHA224.": {
            'db_functions.text.test_sha224.SHA224Tests.test_basic',
            'db_functions.text.test_sha224.SHA224Tests.test_transform',
        },
        "Oracle doesn't support bitwise XOR.": {
            'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor',
            'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null',
        },
        "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": {
            'expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering',
        },
        'Raises ORA-00600: internal error code.': {
            'model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery',
        },
    }
    django_test_expected_failures = {
        # A bug in Django/cx_Oracle with respect to string handling (#23843).
        'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions',
        'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions',
    }

    @cached_property
    def introspected_field_types(self):
        return {
            **super().introspected_field_types,
            'GenericIPAddressField': 'CharField',
            'PositiveBigIntegerField': 'BigIntegerField',
            'PositiveIntegerField': 'IntegerField',
            'PositiveSmallIntegerField': 'IntegerField',
            'SmallIntegerField': 'IntegerField',
            'TimeField': 'DateTimeField',
        }

    @cached_property
    def supports_collation_on_charfield(self):
        with self.connection.cursor() as cursor:
            try:
                cursor.execute("SELECT CAST('a' AS VARCHAR2(4001)) FROM dual")
            except DatabaseError as e:
                if e.args[0].code == 910:
                    return False
                raise
            return True
@@ -0,0 +1,22 @@
from django.db.models import DecimalField, DurationField, Func


class IntervalToSeconds(Func):
    function = ''
    template = """
    EXTRACT(day from %(expressions)s) * 86400 +
    EXTRACT(hour from %(expressions)s) * 3600 +
    EXTRACT(minute from %(expressions)s) * 60 +
    EXTRACT(second from %(expressions)s)
    """

    def __init__(self, expression, *, output_field=None, **extra):
        super().__init__(expression, output_field=output_field or DecimalField(), **extra)


class SecondsToInterval(Func):
    function = 'NUMTODSINTERVAL'
    template = "%(function)s(%(expressions)s, 'SECOND')"

    def __init__(self, expression, *, output_field=None, **extra):
        super().__init__(expression, output_field=output_field or DurationField(), **extra)
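These two Func expressions bridge Oracle's INTERVAL DAY TO SECOND values and plain numbers. A minimal usage sketch with an invented model (model and field names are placeholders, not part of the source):

# Hypothetical model/field names -- illustration only.
# from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
#
# Race.objects.annotate(
#     duration_secs=IntervalToSeconds('duration'),   # DurationField -> Decimal seconds
#     delay=SecondsToInterval('grace_seconds'),      # number -> INTERVAL via NUMTODSINTERVAL
# )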
@@ -0,0 +1,336 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import cx_Oracle
|
||||
|
||||
from django.db import models
|
||||
from django.db.backends.base.introspection import (
|
||||
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
|
||||
)
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json'))
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
cache_bust_counter = 1
|
||||
|
||||
# Maps type objects to Django Field types.
|
||||
@cached_property
|
||||
def data_types_reverse(self):
|
||||
if self.connection.cx_oracle_version < (8,):
|
||||
return {
|
||||
cx_Oracle.BLOB: 'BinaryField',
|
||||
cx_Oracle.CLOB: 'TextField',
|
||||
cx_Oracle.DATETIME: 'DateField',
|
||||
cx_Oracle.FIXED_CHAR: 'CharField',
|
||||
cx_Oracle.FIXED_NCHAR: 'CharField',
|
||||
cx_Oracle.INTERVAL: 'DurationField',
|
||||
cx_Oracle.NATIVE_FLOAT: 'FloatField',
|
||||
cx_Oracle.NCHAR: 'CharField',
|
||||
cx_Oracle.NCLOB: 'TextField',
|
||||
cx_Oracle.NUMBER: 'DecimalField',
|
||||
cx_Oracle.STRING: 'CharField',
|
||||
cx_Oracle.TIMESTAMP: 'DateTimeField',
|
||||
}
|
||||
else:
|
||||
return {
|
||||
cx_Oracle.DB_TYPE_DATE: 'DateField',
|
||||
cx_Oracle.DB_TYPE_BINARY_DOUBLE: 'FloatField',
|
||||
cx_Oracle.DB_TYPE_BLOB: 'BinaryField',
|
||||
cx_Oracle.DB_TYPE_CHAR: 'CharField',
|
||||
cx_Oracle.DB_TYPE_CLOB: 'TextField',
|
||||
cx_Oracle.DB_TYPE_INTERVAL_DS: 'DurationField',
|
||||
cx_Oracle.DB_TYPE_NCHAR: 'CharField',
|
||||
cx_Oracle.DB_TYPE_NCLOB: 'TextField',
|
||||
cx_Oracle.DB_TYPE_NVARCHAR: 'CharField',
|
||||
cx_Oracle.DB_TYPE_NUMBER: 'DecimalField',
|
||||
cx_Oracle.DB_TYPE_TIMESTAMP: 'DateTimeField',
|
||||
cx_Oracle.DB_TYPE_VARCHAR: 'CharField',
|
||||
}
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
if data_type == cx_Oracle.NUMBER:
|
||||
precision, scale = description[4:6]
|
||||
if scale == 0:
|
||||
if precision > 11:
|
||||
return 'BigAutoField' if description.is_autofield else 'BigIntegerField'
|
||||
elif 1 < precision < 6 and description.is_autofield:
|
||||
return 'SmallAutoField'
|
||||
elif precision == 1:
|
||||
return 'BooleanField'
|
||||
elif description.is_autofield:
|
||||
return 'AutoField'
|
||||
else:
|
||||
return 'IntegerField'
|
||||
elif scale == -127:
|
||||
return 'FloatField'
|
||||
elif data_type == cx_Oracle.NCLOB and description.is_json:
|
||||
return 'JSONField'
|
||||
|
||||
return super().get_field_type(data_type, description)
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
cursor.execute("""
|
||||
SELECT table_name, 't'
|
||||
FROM user_tables
|
||||
WHERE
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM user_mviews
|
||||
WHERE user_mviews.mview_name = user_tables.table_name
|
||||
)
|
||||
UNION ALL
|
||||
SELECT view_name, 'v' FROM user_views
|
||||
UNION ALL
|
||||
SELECT mview_name, 'v' FROM user_mviews
|
||||
""")
|
||||
return [TableInfo(self.identifier_converter(row[0]), row[1]) for row in cursor.fetchall()]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
# user_tab_columns gives data default for columns
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
user_tab_cols.column_name,
|
||||
user_tab_cols.data_default,
|
||||
CASE
|
||||
WHEN user_tab_cols.collation = user_tables.default_collation
|
||||
THEN NULL
|
||||
ELSE user_tab_cols.collation
|
||||
END collation,
|
||||
CASE
|
||||
WHEN user_tab_cols.char_used IS NULL
|
||||
THEN user_tab_cols.data_length
|
||||
ELSE user_tab_cols.char_length
|
||||
END as internal_size,
|
||||
CASE
|
||||
WHEN user_tab_cols.identity_column = 'YES' THEN 1
|
||||
ELSE 0
|
||||
END as is_autofield,
|
||||
CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1
|
||||
FROM user_json_columns
|
||||
WHERE
|
||||
user_json_columns.table_name = user_tab_cols.table_name AND
|
||||
user_json_columns.column_name = user_tab_cols.column_name
|
||||
)
|
||||
THEN 1
|
||||
ELSE 0
|
||||
END as is_json
|
||||
FROM user_tab_cols
|
||||
LEFT OUTER JOIN
|
||||
user_tables ON user_tables.table_name = user_tab_cols.table_name
|
||||
WHERE user_tab_cols.table_name = UPPER(%s)
|
||||
""", [table_name])
|
||||
field_map = {
|
||||
column: (internal_size, default if default != 'NULL' else None, collation, is_autofield, is_json)
|
||||
for column, default, collation, internal_size, is_autofield, is_json in cursor.fetchall()
|
||||
}
|
||||
self.cache_bust_counter += 1
|
||||
cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
|
||||
self.connection.ops.quote_name(table_name),
|
||||
self.cache_bust_counter))
|
||||
description = []
|
||||
for desc in cursor.description:
|
||||
name = desc[0]
|
||||
internal_size, default, collation, is_autofield, is_json = field_map[name]
|
||||
name = name % {} # cx_Oracle, for some reason, doubles percent signs.
|
||||
description.append(FieldInfo(
|
||||
self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0,
|
||||
desc[5] or 0, *desc[6:], default, collation, is_autofield, is_json,
|
||||
))
|
||||
return description
|
||||
|
||||
def identifier_converter(self, name):
|
||||
"""Identifier comparison is case insensitive under Oracle."""
|
||||
return name.lower()
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
user_tab_identity_cols.sequence_name,
|
||||
user_tab_identity_cols.column_name
|
||||
FROM
|
||||
user_tab_identity_cols,
|
||||
user_constraints,
|
||||
user_cons_columns cols
|
||||
WHERE
|
||||
user_constraints.constraint_name = cols.constraint_name
|
||||
AND user_constraints.table_name = user_tab_identity_cols.table_name
|
||||
AND cols.column_name = user_tab_identity_cols.column_name
|
||||
AND user_constraints.constraint_type = 'P'
|
||||
AND user_tab_identity_cols.table_name = UPPER(%s)
|
||||
""", [table_name])
|
||||
# Oracle allows only one identity column per table.
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return [{
|
||||
'name': self.identifier_converter(row[0]),
|
||||
'table': self.identifier_converter(table_name),
|
||||
'column': self.identifier_converter(row[1]),
|
||||
}]
|
||||
# To keep backward compatibility for AutoFields that aren't Oracle
|
||||
# identity columns.
|
||||
for f in table_fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
return [{'table': table_name, 'column': f.column}]
|
||||
return []
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all relationships to the given table.
|
||||
"""
|
||||
table_name = table_name.upper()
|
||||
cursor.execute("""
|
||||
SELECT ca.column_name, cb.table_name, cb.column_name
|
||||
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
|
||||
WHERE user_constraints.table_name = %s AND
|
||||
user_constraints.constraint_name = ca.constraint_name AND
|
||||
user_constraints.r_constraint_name = cb.constraint_name AND
|
||||
ca.position = cb.position""", [table_name])
|
||||
|
||||
return {
|
||||
self.identifier_converter(field_name): (
|
||||
self.identifier_converter(rel_field_name),
|
||||
self.identifier_converter(rel_table_name),
|
||||
) for field_name, rel_table_name, rel_field_name in cursor.fetchall()
|
||||
}
|
||||
|
||||
def get_key_columns(self, cursor, table_name):
|
||||
cursor.execute("""
|
||||
SELECT ccol.column_name, rcol.table_name AS referenced_table, rcol.column_name AS referenced_column
|
||||
FROM user_constraints c
|
||||
JOIN user_cons_columns ccol
|
||||
ON ccol.constraint_name = c.constraint_name
|
||||
JOIN user_cons_columns rcol
|
||||
ON rcol.constraint_name = c.r_constraint_name
|
||||
WHERE c.table_name = %s AND c.constraint_type = 'R'""", [table_name.upper()])
|
||||
return [
|
||||
tuple(self.identifier_converter(cell) for cell in row)
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def get_primary_key_column(self, cursor, table_name):
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
cols.column_name
|
||||
FROM
|
||||
user_constraints,
|
||||
user_cons_columns cols
|
||||
WHERE
|
||||
user_constraints.constraint_name = cols.constraint_name AND
|
||||
user_constraints.constraint_type = 'P' AND
|
||||
user_constraints.table_name = UPPER(%s) AND
|
||||
cols.position = 1
|
||||
""", [table_name])
|
||||
row = cursor.fetchone()
|
||||
return self.identifier_converter(row[0]) if row else None
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Loop over the constraints, getting PKs, uniques, and checks
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
user_constraints.constraint_name,
|
||||
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
|
||||
CASE user_constraints.constraint_type
|
||||
WHEN 'P' THEN 1
|
||||
ELSE 0
|
||||
END AS is_primary_key,
|
||||
CASE
|
||||
WHEN user_constraints.constraint_type IN ('P', 'U') THEN 1
|
||||
ELSE 0
|
||||
END AS is_unique,
|
||||
CASE user_constraints.constraint_type
|
||||
WHEN 'C' THEN 1
|
||||
ELSE 0
|
||||
END AS is_check_constraint
|
||||
FROM
|
||||
user_constraints
|
||||
LEFT OUTER JOIN
|
||||
user_cons_columns cols ON user_constraints.constraint_name = cols.constraint_name
|
||||
WHERE
|
||||
user_constraints.constraint_type = ANY('P', 'U', 'C')
|
||||
AND user_constraints.table_name = UPPER(%s)
|
||||
GROUP BY user_constraints.constraint_name, user_constraints.constraint_type
|
||||
""", [table_name])
|
||||
for constraint, columns, pk, unique, check in cursor.fetchall():
|
||||
constraint = self.identifier_converter(constraint)
|
||||
constraints[constraint] = {
|
||||
'columns': columns.split(','),
|
||||
'primary_key': pk,
|
||||
'unique': unique,
|
||||
'foreign_key': None,
|
||||
'check': check,
|
||||
'index': unique, # All uniques come with an index
|
||||
}
|
||||
# Foreign key constraints
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
cons.constraint_name,
|
||||
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
|
||||
LOWER(rcols.table_name),
|
||||
LOWER(rcols.column_name)
|
||||
FROM
|
||||
user_constraints cons
|
||||
INNER JOIN
|
||||
user_cons_columns rcols ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
|
||||
LEFT OUTER JOIN
|
||||
user_cons_columns cols ON cons.constraint_name = cols.constraint_name
|
||||
WHERE
|
||||
cons.constraint_type = 'R' AND
|
||||
cons.table_name = UPPER(%s)
|
||||
GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name
|
||||
""", [table_name])
|
||||
for constraint, columns, other_table, other_column in cursor.fetchall():
|
||||
constraint = self.identifier_converter(constraint)
|
||||
constraints[constraint] = {
|
||||
'primary_key': False,
|
||||
'unique': False,
|
||||
'foreign_key': (other_table, other_column),
|
||||
'check': False,
|
||||
'index': False,
|
||||
'columns': columns.split(','),
|
||||
}
|
||||
# Now get indexes
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
ind.index_name,
|
||||
LOWER(ind.index_type),
|
||||
LOWER(ind.uniqueness),
|
||||
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.column_position),
|
||||
LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
|
||||
FROM
|
||||
user_ind_columns cols, user_indexes ind
|
||||
WHERE
|
||||
cols.table_name = UPPER(%s) AND
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM user_constraints cons
|
||||
WHERE ind.index_name = cons.index_name
|
||||
) AND cols.index_name = ind.index_name
|
||||
GROUP BY ind.index_name, ind.index_type, ind.uniqueness
|
||||
""", [table_name])
|
||||
for constraint, type_, unique, columns, orders in cursor.fetchall():
|
||||
constraint = self.identifier_converter(constraint)
|
||||
constraints[constraint] = {
|
||||
'primary_key': False,
|
||||
'unique': unique == 'unique',
|
||||
'foreign_key': None,
|
||||
'check': False,
|
||||
'index': True,
|
||||
'type': 'idx' if type_ == 'normal' else type_,
|
||||
'columns': columns.split(','),
|
||||
'orders': orders.split(','),
|
||||
}
|
||||
return constraints
|
||||
647
venv/Lib/site-packages/django/db/backends/oracle/operations.py
Normal file
@@ -0,0 +1,647 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError, NotSupportedError
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.backends.utils import (
|
||||
split_tzname_delta, strip_quotes, truncate_name,
|
||||
)
|
||||
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models.sql.where import WhereNode
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_bytes, force_str
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
from .base import Database
|
||||
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
# Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
|
||||
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
|
||||
# SmallAutoField, to preserve backward compatibility.
|
||||
integer_field_ranges = {
|
||||
'SmallIntegerField': (-99999999999, 99999999999),
|
||||
'IntegerField': (-99999999999, 99999999999),
|
||||
'BigIntegerField': (-9999999999999999999, 9999999999999999999),
|
||||
'PositiveBigIntegerField': (0, 9999999999999999999),
|
||||
'PositiveSmallIntegerField': (0, 99999999999),
|
||||
'PositiveIntegerField': (0, 99999999999),
|
||||
'SmallAutoField': (-99999, 99999),
|
||||
'AutoField': (-99999999999, 99999999999),
|
||||
'BigAutoField': (-9999999999999999999, 9999999999999999999),
|
||||
}
|
||||
set_operators = {**BaseDatabaseOperations.set_operators, 'difference': 'MINUS'}
|
||||
|
||||
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
|
||||
_sequence_reset_sql = """
|
||||
DECLARE
|
||||
table_value integer;
|
||||
seq_value integer;
|
||||
seq_name user_tab_identity_cols.sequence_name%%TYPE;
|
||||
BEGIN
|
||||
BEGIN
|
||||
SELECT sequence_name INTO seq_name FROM user_tab_identity_cols
|
||||
WHERE table_name = '%(table_name)s' AND
|
||||
column_name = '%(column_name)s';
|
||||
EXCEPTION WHEN NO_DATA_FOUND THEN
|
||||
seq_name := '%(no_autofield_sequence_name)s';
|
||||
END;
|
||||
|
||||
SELECT NVL(MAX(%(column)s), 0) INTO table_value FROM %(table)s;
|
||||
SELECT NVL(last_number - cache_size, 0) INTO seq_value FROM user_sequences
|
||||
WHERE sequence_name = seq_name;
|
||||
WHILE table_value > seq_value LOOP
|
||||
EXECUTE IMMEDIATE 'SELECT "'||seq_name||'".nextval FROM DUAL'
|
||||
INTO seq_value;
|
||||
END LOOP;
|
||||
END;
|
||||
/"""
|
||||
|
||||
# Oracle doesn't support string without precision; use the max string size.
|
||||
cast_char_field_without_max_length = 'NVARCHAR2(2000)'
|
||||
cast_data_types = {
|
||||
'AutoField': 'NUMBER(11)',
|
||||
'BigAutoField': 'NUMBER(19)',
|
||||
'SmallAutoField': 'NUMBER(5)',
|
||||
'TextField': cast_char_field_without_max_length,
|
||||
}
|
||||
|
||||
def cache_key_culling_sql(self):
|
||||
return 'SELECT cache_key FROM %s ORDER BY cache_key OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY'
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
if lookup_type == 'week_day':
|
||||
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
|
||||
return "TO_CHAR(%s, 'D')" % field_name
|
||||
elif lookup_type == 'iso_week_day':
|
||||
return "TO_CHAR(%s - 1, 'D')" % field_name
|
||||
elif lookup_type == 'week':
|
||||
# IW = ISO week number
|
||||
return "TO_CHAR(%s, 'IW')" % field_name
|
||||
elif lookup_type == 'quarter':
|
||||
return "TO_CHAR(%s, 'Q')" % field_name
|
||||
elif lookup_type == 'iso_year':
|
||||
return "TO_CHAR(%s, 'IYYY')" % field_name
|
||||
else:
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
|
||||
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
if lookup_type in ('year', 'month'):
|
||||
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
|
||||
elif lookup_type == 'quarter':
|
||||
return "TRUNC(%s, 'Q')" % field_name
|
||||
elif lookup_type == 'week':
|
||||
return "TRUNC(%s, 'IW')" % field_name
|
||||
else:
|
||||
return "TRUNC(%s)" % field_name
|
||||
|
||||
# Oracle crashes with "ORA-03113: end-of-file on communication channel"
|
||||
# if the time zone name is passed as a parameter. Use interpolation instead.
|
||||
# https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
|
||||
# This regexp matches all time zone names from the zoneinfo database.
|
||||
_tzname_re = _lazy_re_compile(r'^[\w/:+-]+$')
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
return f'{sign}{offset}' if offset else tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
if not (settings.USE_TZ and tzname):
|
||||
return field_name
|
||||
if not self._tzname_re.match(tzname):
|
||||
raise ValueError("Invalid time zone name: %s" % tzname)
|
||||
# Convert from connection timezone to the local time, returning
|
||||
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
|
||||
# TIME ZONE details.
|
||||
if self.connection.timezone_name != tzname:
|
||||
return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % (
|
||||
field_name,
|
||||
self.connection.timezone_name,
|
||||
self._prepare_tzname_delta(tzname),
|
||||
)
|
||||
return field_name
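# Illustrative example (added): with USE_TZ enabled, a connection time zone
# of 'UTC' and tzname='Europe/Paris', a hypothetical column "created" becomes
# "CAST((FROM_TZ(created, 'UTC') AT TIME ZONE 'Europe/Paris') AS TIMESTAMP)";
# otherwise the column name is returned unchanged.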
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return 'TRUNC(%s)' % field_name
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
# Since `TimeField` values are stored as TIMESTAMP, change to the
|
||||
# default date and convert the field to the specified timezone.
|
||||
convert_datetime_sql = (
|
||||
"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), "
|
||||
"'YYYY-MM-DD HH24:MI:SS.FF')"
|
||||
) % self._convert_field_to_tz(field_name, tzname)
|
||||
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % (
|
||||
field_name, convert_datetime_sql,
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
if lookup_type in ('year', 'month'):
|
||||
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
|
||||
elif lookup_type == 'quarter':
|
||||
sql = "TRUNC(%s, 'Q')" % field_name
|
||||
elif lookup_type == 'week':
|
||||
sql = "TRUNC(%s, 'IW')" % field_name
|
||||
elif lookup_type == 'day':
|
||||
sql = "TRUNC(%s)" % field_name
|
||||
elif lookup_type == 'hour':
|
||||
sql = "TRUNC(%s, 'HH24')" % field_name
|
||||
elif lookup_type == 'minute':
|
||||
sql = "TRUNC(%s, 'MI')" % field_name
|
||||
else:
|
||||
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
|
||||
return sql
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
# The implementation is similar to `datetime_trunc_sql` as both
|
||||
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
|
||||
# the date part of the latter is ignored.
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
if lookup_type == 'hour':
|
||||
sql = "TRUNC(%s, 'HH24')" % field_name
|
||||
elif lookup_type == 'minute':
|
||||
sql = "TRUNC(%s, 'MI')" % field_name
|
||||
elif lookup_type == 'second':
|
||||
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
|
||||
return sql
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type in ['JSONField', 'TextField']:
|
||||
converters.append(self.convert_textfield_value)
|
||||
elif internal_type == 'BinaryField':
|
||||
converters.append(self.convert_binaryfield_value)
|
||||
elif internal_type == 'BooleanField':
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
elif internal_type == 'DateTimeField':
|
||||
if settings.USE_TZ:
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == 'DateField':
|
||||
converters.append(self.convert_datefield_value)
|
||||
elif internal_type == 'TimeField':
|
||||
converters.append(self.convert_timefield_value)
|
||||
elif internal_type == 'UUIDField':
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
# Oracle stores empty strings as null. If the field accepts the empty
|
||||
# string, undo this to adhere to the Django convention of using
|
||||
# the empty string instead of null.
|
||||
if expression.output_field.empty_strings_allowed:
|
||||
converters.append(
|
||||
self.convert_empty_bytes
|
||||
if internal_type == 'BinaryField' else
|
||||
self.convert_empty_string
|
||||
)
|
||||
return converters
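# Illustrative note (added): for fields whose empty_strings_allowed is True
# (e.g. CharField), a NULL read back from Oracle is converted to '' by
# convert_empty_string(), and to b'' for BinaryField, matching the Django
# convention described in the comment above.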
|
||||
|
||||
def convert_textfield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.LOB):
|
||||
value = value.read()
|
||||
return value
|
||||
|
||||
def convert_binaryfield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.LOB):
|
||||
value = force_bytes(value.read())
|
||||
return value
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
if value in (0, 1):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
||||
# cx_Oracle always returns datetime.datetime objects for
|
||||
# DATE and TIMESTAMP columns, but Django wants to see a
|
||||
# python datetime.date, .time, or .datetime.
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_datefield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.Timestamp):
|
||||
value = value.date()
|
||||
return value
|
||||
|
||||
def convert_timefield_value(self, value, expression, connection):
|
||||
if isinstance(value, Database.Timestamp):
|
||||
value = value.time()
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def convert_empty_string(value, expression, connection):
|
||||
return '' if value is None else value
|
||||
|
||||
@staticmethod
|
||||
def convert_empty_bytes(value, expression, connection):
|
||||
return b'' if value is None else value
|
||||
|
||||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
||||
def fetch_returned_insert_columns(self, cursor, returning_params):
|
||||
columns = []
|
||||
for param in returning_params:
|
||||
value = param.get_value()
|
||||
if value == []:
|
||||
raise DatabaseError(
|
||||
'The database did not return a new row id. Probably '
|
||||
'"ORA-1403: no data found" was raised internally but was '
|
||||
'hidden by the Oracle OCI library (see '
|
||||
'https://code.djangoproject.com/ticket/28859).'
|
||||
)
|
||||
columns.append(value[0])
|
||||
return tuple(columns)
|
||||
|
||||
def field_cast_sql(self, db_type, internal_type):
|
||||
if db_type and db_type.endswith('LOB') and internal_type != 'JSONField':
|
||||
return "DBMS_LOB.SUBSTR(%s)"
|
||||
else:
|
||||
return "%s"
|
||||
|
||||
def no_limit_value(self):
|
||||
return None
|
||||
|
||||
def limit_offset_sql(self, low_mark, high_mark):
|
||||
fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
|
||||
return ' '.join(sql for sql in (
|
||||
('OFFSET %d ROWS' % offset) if offset else None,
|
||||
('FETCH FIRST %d ROWS ONLY' % fetch) if fetch else None,
|
||||
) if sql)
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# https://cx-oracle.readthedocs.io/en/latest/cursor.html#Cursor.statement
|
||||
# The DB API definition does not define this attribute.
|
||||
statement = cursor.statement
|
||||
# Unlike Psycopg's `query` and MySQLdb's `_executed`, cx_Oracle's
|
||||
# `statement` doesn't contain the query parameters. Substitute
|
||||
# parameters manually.
|
||||
if isinstance(params, (tuple, list)):
|
||||
for i, param in enumerate(params):
|
||||
statement = statement.replace(':arg%d' % i, force_str(param, errors='replace'))
|
||||
elif isinstance(params, dict):
|
||||
for key, param in params.items():
|
||||
statement = statement.replace(':%s' % key, force_str(param, errors='replace'))
|
||||
return statement
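# Illustrative example (added, hypothetical statement): cx_Oracle renders
# positional parameters as :arg0, :arg1, ..., so for cursor.statement
# 'SELECT * FROM "AUTH_USER" WHERE "ID" = :arg0' and params [42] this
# returns 'SELECT * FROM "AUTH_USER" WHERE "ID" = 42'.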
|
||||
|
||||
def last_insert_id(self, cursor, table_name, pk_name):
|
||||
sq_name = self._get_sequence_name(cursor, strip_quotes(table_name), pk_name)
|
||||
cursor.execute('"%s".currval' % sq_name)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
|
||||
return "UPPER(%s)"
|
||||
if internal_type == 'JSONField' and lookup_type == 'exact':
|
||||
return 'DBMS_LOB.SUBSTR(%s)'
|
||||
return "%s"
|
||||
|
||||
def max_in_list_size(self):
|
||||
return 1000
|
||||
|
||||
def max_name_length(self):
|
||||
return 30
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def prep_for_iexact_query(self, x):
|
||||
return x
|
||||
|
||||
def process_clob(self, value):
|
||||
if value is None:
|
||||
return ''
|
||||
return value.read()
|
||||
|
||||
def quote_name(self, name):
|
||||
# SQL92 requires delimited (quoted) names to be case-sensitive. When
|
||||
# not quoted, Oracle has case-insensitive behavior for identifiers, but
|
||||
# always defaults to uppercase.
|
||||
# We simplify things by making Oracle identifiers always uppercase.
|
||||
if not name.startswith('"') and not name.endswith('"'):
|
||||
name = '"%s"' % truncate_name(name, self.max_name_length())
|
||||
# Oracle puts the query text into a (query % args) construct, so % signs
|
||||
# in names need to be escaped. The '%%' will be collapsed back to '%' at
|
||||
# that stage so we aren't really making the name longer here.
|
||||
name = name.replace('%', '%%')
|
||||
return name.upper()
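# Illustrative examples (added): quote_name('django_content_type') returns
# '"DJANGO_CONTENT_TYPE"'; names longer than max_name_length() (30) are first
# shortened by truncate_name(), and any '%' in a name is doubled so later
# query interpolation leaves it intact.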
|
||||
|
||||
def regex_lookup(self, lookup_type):
|
||||
if lookup_type == 'regex':
|
||||
match_option = "'c'"
|
||||
else:
|
||||
match_option = "'i'"
|
||||
return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
if not fields:
|
||||
return '', ()
|
||||
field_names = []
|
||||
params = []
|
||||
for field in fields:
|
||||
field_names.append('%s.%s' % (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
))
|
||||
params.append(InsertVar(field))
|
||||
return 'RETURNING %s INTO %s' % (
|
||||
', '.join(field_names),
|
||||
', '.join(['%s'] * len(params)),
|
||||
), tuple(params)
|
||||
|
||||
def __foreign_key_constraints(self, table_name, recursive):
|
||||
with self.connection.cursor() as cursor:
|
||||
if recursive:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
user_tables.table_name, rcons.constraint_name
|
||||
FROM
|
||||
user_tables
|
||||
JOIN
|
||||
user_constraints cons
|
||||
ON (user_tables.table_name = cons.table_name AND cons.constraint_type = ANY('P', 'U'))
|
||||
LEFT JOIN
|
||||
user_constraints rcons
|
||||
ON (user_tables.table_name = rcons.table_name AND rcons.constraint_type = 'R')
|
||||
START WITH user_tables.table_name = UPPER(%s)
|
||||
CONNECT BY NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
|
||||
GROUP BY
|
||||
user_tables.table_name, rcons.constraint_name
|
||||
HAVING user_tables.table_name != UPPER(%s)
|
||||
ORDER BY MAX(level) DESC
|
||||
""", (table_name, table_name))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
cons.table_name, cons.constraint_name
|
||||
FROM
|
||||
user_constraints cons
|
||||
WHERE
|
||||
cons.constraint_type = 'R'
|
||||
AND cons.table_name = UPPER(%s)
|
||||
""", (table_name,))
|
||||
return cursor.fetchall()
|
||||
|
||||
@cached_property
|
||||
def _foreign_key_constraints(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__foreign_key_constraints)
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if not tables:
|
||||
return []
|
||||
|
||||
truncated_tables = {table.upper() for table in tables}
|
||||
constraints = set()
|
||||
# Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE foreign
|
||||
# keys which Django doesn't define. Emulate the PostgreSQL behavior
|
||||
# which truncates all dependent tables by manually retrieving all
|
||||
# foreign key constraints and resolving dependencies.
|
||||
for table in tables:
|
||||
for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade):
|
||||
if allow_cascade:
|
||||
truncated_tables.add(foreign_table)
|
||||
constraints.add((foreign_table, constraint))
|
||||
sql = [
|
||||
'%s %s %s %s %s %s %s %s;' % (
|
||||
style.SQL_KEYWORD('ALTER'),
|
||||
style.SQL_KEYWORD('TABLE'),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
style.SQL_KEYWORD('DISABLE'),
|
||||
style.SQL_KEYWORD('CONSTRAINT'),
|
||||
style.SQL_FIELD(self.quote_name(constraint)),
|
||||
style.SQL_KEYWORD('KEEP'),
|
||||
style.SQL_KEYWORD('INDEX'),
|
||||
) for table, constraint in constraints
|
||||
] + [
|
||||
'%s %s %s;' % (
|
||||
style.SQL_KEYWORD('TRUNCATE'),
|
||||
style.SQL_KEYWORD('TABLE'),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
) for table in truncated_tables
|
||||
] + [
|
||||
'%s %s %s %s %s %s;' % (
|
||||
style.SQL_KEYWORD('ALTER'),
|
||||
style.SQL_KEYWORD('TABLE'),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
style.SQL_KEYWORD('ENABLE'),
|
||||
style.SQL_KEYWORD('CONSTRAINT'),
|
||||
style.SQL_FIELD(self.quote_name(constraint)),
|
||||
) for table, constraint in constraints
|
||||
]
|
||||
if reset_sequences:
|
||||
sequences = [
|
||||
sequence
|
||||
for sequence in self.connection.introspection.sequence_list()
|
||||
if sequence['table'].upper() in truncated_tables
|
||||
]
|
||||
# Since we've just deleted all the rows, running our sequence ALTER
|
||||
# code will reset the sequence to 0.
|
||||
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
|
||||
return sql
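# Illustrative shape of the output (added, with hypothetical table and
# constraint names): for allow_cascade=True this returns statements such as
#   ALTER TABLE "CHILD" DISABLE CONSTRAINT "CHILD_PARENT_FK" KEEP INDEX;
#   TRUNCATE TABLE "PARENT";
#   TRUNCATE TABLE "CHILD";
#   ALTER TABLE "CHILD" ENABLE CONSTRAINT "CHILD_PARENT_FK";
# followed by the sequence-reset statements when reset_sequences=True.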
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
sql = []
|
||||
for sequence_info in sequences:
|
||||
no_autofield_sequence_name = self._get_no_autofield_sequence_name(sequence_info['table'])
|
||||
table = self.quote_name(sequence_info['table'])
|
||||
column = self.quote_name(sequence_info['column'] or 'id')
|
||||
query = self._sequence_reset_sql % {
|
||||
'no_autofield_sequence_name': no_autofield_sequence_name,
|
||||
'table': table,
|
||||
'column': column,
|
||||
'table_name': strip_quotes(table),
|
||||
'column_name': strip_quotes(column),
|
||||
}
|
||||
sql.append(query)
|
||||
return sql
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
output = []
|
||||
query = self._sequence_reset_sql
|
||||
for model in model_list:
|
||||
for f in model._meta.local_fields:
|
||||
if isinstance(f, AutoField):
|
||||
no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table)
|
||||
table = self.quote_name(model._meta.db_table)
|
||||
column = self.quote_name(f.column)
|
||||
output.append(query % {
|
||||
'no_autofield_sequence_name': no_autofield_sequence_name,
|
||||
'table': table,
|
||||
'column': column,
|
||||
'table_name': strip_quotes(table),
|
||||
'column_name': strip_quotes(column),
|
||||
})
|
||||
# Only one AutoField is allowed per model, so don't
|
||||
# continue to loop
|
||||
break
|
||||
return output
|
||||
|
||||
def start_transaction_sql(self):
|
||||
return ''
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
if inline:
|
||||
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
|
||||
else:
|
||||
return "TABLESPACE %s" % self.quote_name(tablespace)
|
||||
|
||||
def adapt_datefield_value(self, value):
|
||||
"""
|
||||
Transform a date value to an object compatible with what is expected
|
||||
by the backend driver for date columns.
|
||||
The default implementation transforms the date to text, but that is not
|
||||
necessary for Oracle.
|
||||
"""
|
||||
return value
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
"""
|
||||
Transform a datetime value to an object compatible with what is expected
|
||||
by the backend driver for datetime columns.
|
||||
|
||||
If a naive datetime is passed, assume it is in UTC. Normally Django's
models.DateTimeField ensures that, when USE_TZ is True, the datetime
passed in is timezone-aware.
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# cx_Oracle doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError("Oracle backend does not support timezone-aware datetimes when USE_TZ is False.")
|
||||
|
||||
return Oracle_datetime.from_datetime(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return datetime.datetime.strptime(value, '%H:%M:%S')
|
||||
|
||||
# Oracle doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("Oracle backend does not support timezone-aware times.")
|
||||
|
||||
return Oracle_datetime(1900, 1, 1, value.hour, value.minute,
|
||||
value.second, value.microsecond)
|
||||
|
||||
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
|
||||
return value
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
lhs, rhs = sub_expressions
|
||||
if connector == '%%':
|
||||
return 'MOD(%s)' % ','.join(sub_expressions)
|
||||
elif connector == '&':
|
||||
return 'BITAND(%s)' % ','.join(sub_expressions)
|
||||
elif connector == '|':
|
||||
return 'BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s' % {'lhs': lhs, 'rhs': rhs}
|
||||
elif connector == '<<':
|
||||
return '(%(lhs)s * POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
|
||||
elif connector == '>>':
|
||||
return 'FLOOR(%(lhs)s / POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
|
||||
elif connector == '^':
|
||||
return 'POWER(%s)' % ','.join(sub_expressions)
|
||||
elif connector == '#':
|
||||
raise NotSupportedError('Bitwise XOR is not supported in Oracle.')
|
||||
return super().combine_expression(connector, sub_expressions)
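# Illustrative examples (added): with sub_expressions ('a', 'b'),
#   '&'  -> 'BITAND(a,b)'
#   '|'  -> 'BITAND(-a-1,b)+a'   (OR emulated via BITAND)
#   '<<' -> '(a * POWER(2, b))'
#   '^'  -> 'POWER(a,b)'
# while '#' (XOR) raises NotSupportedError.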
|
||||
|
||||
def _get_no_autofield_sequence_name(self, table):
|
||||
"""
|
||||
Manually created sequence name to keep backward compatibility for
|
||||
AutoFields that aren't Oracle identity columns.
|
||||
"""
|
||||
name_length = self.max_name_length() - 3
|
||||
return '%s_SQ' % truncate_name(strip_quotes(table), name_length).upper()
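# Illustrative example (added): for a table named 'blog_post' this returns
# 'BLOG_POST_SQ'; longer names are first shortened (and hashed) by
# truncate_name() so the result fits within 30 characters.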
|
||||
|
||||
def _get_sequence_name(self, cursor, table, pk_name):
|
||||
cursor.execute("""
|
||||
SELECT sequence_name
|
||||
FROM user_tab_identity_cols
|
||||
WHERE table_name = UPPER(%s)
|
||||
AND column_name = UPPER(%s)""", [table, pk_name])
|
||||
row = cursor.fetchone()
|
||||
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
query = []
|
||||
for row in placeholder_rows:
|
||||
select = []
|
||||
for i, placeholder in enumerate(row):
|
||||
# A model without any fields has fields=[None].
|
||||
if fields[i]:
|
||||
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
|
||||
placeholder = BulkInsertMapper.types.get(internal_type, '%s') % placeholder
|
||||
# Add column aliases to the first select to avoid "ORA-00918:
|
||||
# column ambiguously defined" when two or more columns in the
|
||||
# first select have the same value.
|
||||
if not query:
|
||||
placeholder = '%s col_%s' % (placeholder, i)
|
||||
select.append(placeholder)
|
||||
query.append('SELECT %s FROM DUAL' % ', '.join(select))
|
||||
# Bulk insert to tables with Oracle identity columns causes Oracle to
|
||||
# add sequence.nextval to it. Sequence.nextval cannot be used with the
|
||||
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
|
||||
return 'SELECT * FROM (%s)' % ' UNION ALL '.join(query)
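# Illustrative example (added): for two rows of two plain (unmapped) columns
# the generated SQL is roughly
#   SELECT * FROM (SELECT %s col_0, %s col_1 FROM DUAL
#                  UNION ALL SELECT %s, %s FROM DUAL)
# where '%s' stands for whatever placeholders were passed in placeholder_rows.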
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
if internal_type == 'DateField':
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), params
|
||||
return super().subtract_temporals(internal_type, lhs, rhs)
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""Oracle restricts the number of parameters in a query."""
|
||||
if fields:
|
||||
return self.connection.features.max_query_params // len(fields)
|
||||
return len(objs)
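# Illustrative arithmetic (added): assuming features.max_query_params were
# 65535, inserting objects with 5 concrete fields would be batched as
# 65535 // 5 = 13107 rows per INSERT.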
|
||||
|
||||
def conditional_expression_supported_in_where_clause(self, expression):
|
||||
"""
|
||||
Oracle supports only EXISTS(...) or filters in the WHERE clause, others
|
||||
must be compared with True.
|
||||
"""
|
||||
if isinstance(expression, (Exists, Lookup, WhereNode)):
|
||||
return True
|
||||
if isinstance(expression, ExpressionWrapper) and expression.conditional:
|
||||
return self.conditional_expression_supported_in_where_clause(expression.expression)
|
||||
if isinstance(expression, RawSQL) and expression.conditional:
|
||||
return True
|
||||
return False
|
||||
211
venv/Lib/site-packages/django/db/backends/oracle/schema.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import copy
|
||||
import datetime
|
||||
import re
|
||||
|
||||
from django.db import DatabaseError
|
||||
from django.db.backends.base.schema import (
|
||||
BaseDatabaseSchemaEditor, _related_non_m2m_objects,
|
||||
)
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
|
||||
sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s"
|
||||
sql_alter_column_type = "MODIFY %(column)s %(type)s"
|
||||
sql_alter_column_null = "MODIFY %(column)s NULL"
|
||||
sql_alter_column_not_null = "MODIFY %(column)s NOT NULL"
|
||||
sql_alter_column_default = "MODIFY %(column)s DEFAULT %(default)s"
|
||||
sql_alter_column_no_default = "MODIFY %(column)s DEFAULT NULL"
|
||||
sql_alter_column_no_default_null = sql_alter_column_no_default
|
||||
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
|
||||
|
||||
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
|
||||
sql_create_column_inline_fk = 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
|
||||
sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
|
||||
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
|
||||
|
||||
def quote_value(self, value):
|
||||
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
|
||||
return "'%s'" % value
|
||||
elif isinstance(value, str):
|
||||
return "'%s'" % value.replace("\'", "\'\'").replace('%', '%%')
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return "'%s'" % value.hex()
|
||||
elif isinstance(value, bool):
|
||||
return "1" if value else "0"
|
||||
else:
|
||||
return str(value)
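# Illustrative examples (added): quote_value("O'Reilly") -> "'O''Reilly'",
# quote_value(True) -> "1", quote_value(b'\x01\x02') -> "'0102'", and
# quote_value(datetime.date(2021, 1, 1)) -> "'2021-01-01'".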
|
||||
|
||||
def remove_field(self, model, field):
|
||||
# If the column is an identity column, drop the identity before
|
||||
# removing the field.
|
||||
if self._is_identity_column(model._meta.db_table, field.column):
|
||||
self._drop_identity(model._meta.db_table, field.column)
|
||||
super().remove_field(model, field)
|
||||
|
||||
def delete_model(self, model):
|
||||
# Run superclass action
|
||||
super().delete_model(model)
|
||||
# Clean up manually created sequence.
|
||||
self.execute("""
|
||||
DECLARE
|
||||
i INTEGER;
|
||||
BEGIN
|
||||
SELECT COUNT(1) INTO i FROM USER_SEQUENCES
|
||||
WHERE SEQUENCE_NAME = '%(sq_name)s';
|
||||
IF i = 1 THEN
|
||||
EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
|
||||
END IF;
|
||||
END;
|
||||
/""" % {'sq_name': self.connection.ops._get_no_autofield_sequence_name(model._meta.db_table)})
|
||||
|
||||
def alter_field(self, model, old_field, new_field, strict=False):
|
||||
try:
|
||||
super().alter_field(model, old_field, new_field, strict)
|
||||
except DatabaseError as e:
|
||||
description = str(e)
|
||||
# If we're changing type to an unsupported type we need a
|
||||
# SQLite-ish workaround
|
||||
if 'ORA-22858' in description or 'ORA-22859' in description:
|
||||
self._alter_field_type_workaround(model, old_field, new_field)
|
||||
# If an identity column is changing to a non-numeric type, drop the
|
||||
# identity first.
|
||||
elif 'ORA-30675' in description:
|
||||
self._drop_identity(model._meta.db_table, old_field.column)
|
||||
self.alter_field(model, old_field, new_field, strict)
|
||||
# If a primary key column is changing to an identity column, drop
|
||||
# the primary key first.
|
||||
elif 'ORA-30673' in description and old_field.primary_key:
|
||||
self._delete_primary_key(model, strict=True)
|
||||
self._alter_field_type_workaround(model, old_field, new_field)
|
||||
else:
|
||||
raise
|
||||
|
||||
def _alter_field_type_workaround(self, model, old_field, new_field):
|
||||
"""
|
||||
Oracle refuses to change a column from certain types to certain other types.
|
||||
What we need to do instead is:
|
||||
- Add a nullable version of the desired field with a temporary name. If
|
||||
the new column is an auto field, then the temporary column can't be
|
||||
nullable.
|
||||
- Update the table to transfer values from old to new
|
||||
- Drop old column
|
||||
- Rename the new column and possibly drop the nullable property
|
||||
"""
|
||||
# Make a new field that's like the new one but with a temporary
|
||||
# column name.
|
||||
new_temp_field = copy.deepcopy(new_field)
|
||||
new_temp_field.null = (new_field.get_internal_type() not in ('AutoField', 'BigAutoField', 'SmallAutoField'))
|
||||
new_temp_field.column = self._generate_temp_name(new_field.column)
|
||||
# Add it
|
||||
self.add_field(model, new_temp_field)
|
||||
# Explicit data type conversion
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf
|
||||
# /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
|
||||
new_value = self.quote_name(old_field.column)
|
||||
old_type = old_field.db_type(self.connection)
|
||||
if re.match('^N?CLOB', old_type):
|
||||
new_value = "TO_CHAR(%s)" % new_value
|
||||
old_type = 'VARCHAR2'
|
||||
if re.match('^N?VARCHAR2', old_type):
|
||||
new_internal_type = new_field.get_internal_type()
|
||||
if new_internal_type == 'DateField':
|
||||
new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
|
||||
elif new_internal_type == 'DateTimeField':
|
||||
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
|
||||
elif new_internal_type == 'TimeField':
|
||||
# TimeField values are stored as TIMESTAMP with a 1900-01-01 date part.
|
||||
new_value = "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
|
||||
# Transfer values across
|
||||
self.execute("UPDATE %s set %s=%s" % (
|
||||
self.quote_name(model._meta.db_table),
|
||||
self.quote_name(new_temp_field.column),
|
||||
new_value,
|
||||
))
|
||||
# Drop the old field
|
||||
self.remove_field(model, old_field)
|
||||
# Rename and possibly make the new field NOT NULL
|
||||
super().alter_field(model, new_temp_field, new_field)
|
||||
# Recreate foreign key (if necessary) because the old field is not
|
||||
# passed to the alter_field() and data types of new_temp_field and
|
||||
# new_field always match.
|
||||
new_type = new_field.db_type(self.connection)
|
||||
if (
|
||||
(old_field.primary_key and new_field.primary_key) or
|
||||
(old_field.unique and new_field.unique)
|
||||
) and old_type != new_type:
|
||||
for _, rel in _related_non_m2m_objects(new_temp_field, new_field):
|
||||
if rel.field.db_constraint:
|
||||
self.execute(self._create_fk_sql(rel.related_model, rel.field, '_fk'))
|
||||
|
||||
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
|
||||
auto_field_types = {'AutoField', 'BigAutoField', 'SmallAutoField'}
|
||||
# Drop the identity if migrating away from AutoField.
|
||||
if (
|
||||
old_field.get_internal_type() in auto_field_types and
|
||||
new_field.get_internal_type() not in auto_field_types and
|
||||
self._is_identity_column(model._meta.db_table, new_field.column)
|
||||
):
|
||||
self._drop_identity(model._meta.db_table, new_field.column)
|
||||
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
|
||||
|
||||
def normalize_name(self, name):
|
||||
"""
|
||||
Get the properly shortened and uppercased identifier as returned by
|
||||
quote_name() but without the quotes.
|
||||
"""
|
||||
nn = self.quote_name(name)
|
||||
if nn[0] == '"' and nn[-1] == '"':
|
||||
nn = nn[1:-1]
|
||||
return nn
|
||||
|
||||
def _generate_temp_name(self, for_name):
|
||||
"""Generate temporary names for workarounds that need temp columns."""
|
||||
suffix = hex(hash(for_name)).upper()[1:]
|
||||
return self.normalize_name(for_name + "_" + suffix)
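# Illustrative note (added): for a column 'name' this produces something like
# 'NAME_<hex digits>', where the suffix comes from Python's hash() and is
# therefore not stable across processes (it depends on PYTHONHASHSEED);
# normalize_name() then shortens and uppercases the result.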
|
||||
|
||||
def prepare_default(self, value):
|
||||
return self.quote_value(value)
|
||||
|
||||
def _field_should_be_indexed(self, model, field):
|
||||
create_index = super()._field_should_be_indexed(model, field)
|
||||
db_type = field.db_type(self.connection)
|
||||
if db_type is not None and db_type.lower() in self.connection._limited_data_types:
|
||||
return False
|
||||
return create_index
|
||||
|
||||
def _unique_should_be_added(self, old_field, new_field):
|
||||
return (
|
||||
super()._unique_should_be_added(old_field, new_field) and
|
||||
not self._field_became_primary_key(old_field, new_field)
|
||||
)
|
||||
|
||||
def _is_identity_column(self, table_name, column_name):
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END
|
||||
FROM user_tab_cols
|
||||
WHERE table_name = %s AND
|
||||
column_name = %s
|
||||
""", [self.normalize_name(table_name), self.normalize_name(column_name)])
|
||||
row = cursor.fetchone()
|
||||
return row[0] if row else False
|
||||
|
||||
def _drop_identity(self, table_name, column_name):
|
||||
self.execute('ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY' % {
|
||||
'table': self.quote_name(table_name),
|
||||
'column': self.quote_name(column_name),
|
||||
})
|
||||
|
||||
def _get_default_collation(self, table_name):
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT default_collation FROM user_tables WHERE table_name = %s
|
||||
""", [self.normalize_name(table_name)])
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def _alter_column_collation_sql(self, model, new_field, new_type, new_collation):
|
||||
if new_collation is None:
|
||||
new_collation = self._get_default_collation(model._meta.db_table)
|
||||
return super()._alter_column_collation_sql(model, new_field, new_type, new_collation)
|
||||
90
venv/Lib/site-packages/django/db/backends/oracle/utils.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import datetime
|
||||
|
||||
from .base import Database
|
||||
|
||||
|
||||
class InsertVar:
|
||||
"""
|
||||
A late-binding cursor variable that can be passed to Cursor.execute
|
||||
as a parameter, in order to receive the id of the row created by an
|
||||
insert statement.
|
||||
"""
|
||||
types = {
|
||||
'AutoField': int,
|
||||
'BigAutoField': int,
|
||||
'SmallAutoField': int,
|
||||
'IntegerField': int,
|
||||
'BigIntegerField': int,
|
||||
'SmallIntegerField': int,
|
||||
'PositiveBigIntegerField': int,
|
||||
'PositiveSmallIntegerField': int,
|
||||
'PositiveIntegerField': int,
|
||||
'FloatField': Database.NATIVE_FLOAT,
|
||||
'DateTimeField': Database.TIMESTAMP,
|
||||
'DateField': Database.Date,
|
||||
'DecimalField': Database.NUMBER,
|
||||
}
|
||||
|
||||
def __init__(self, field):
|
||||
internal_type = getattr(field, 'target_field', field).get_internal_type()
|
||||
self.db_type = self.types.get(internal_type, str)
|
||||
self.bound_param = None
|
||||
|
||||
def bind_parameter(self, cursor):
|
||||
self.bound_param = cursor.cursor.var(self.db_type)
|
||||
return self.bound_param
|
||||
|
||||
def get_value(self):
|
||||
return self.bound_param.getvalue()
|
||||
|
||||
|
||||
class Oracle_datetime(datetime.datetime):
|
||||
"""
|
||||
A datetime object, with an additional class attribute
|
||||
to tell cx_Oracle to save the microseconds too.
|
||||
"""
|
||||
input_size = Database.TIMESTAMP
|
||||
|
||||
@classmethod
|
||||
def from_datetime(cls, dt):
|
||||
return Oracle_datetime(
|
||||
dt.year, dt.month, dt.day,
|
||||
dt.hour, dt.minute, dt.second, dt.microsecond,
|
||||
)
|
||||
|
||||
|
||||
class BulkInsertMapper:
|
||||
BLOB = 'TO_BLOB(%s)'
|
||||
CLOB = 'TO_CLOB(%s)'
|
||||
DATE = 'TO_DATE(%s)'
|
||||
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
|
||||
NUMBER = 'TO_NUMBER(%s)'
|
||||
TIMESTAMP = 'TO_TIMESTAMP(%s)'
|
||||
|
||||
types = {
|
||||
'AutoField': NUMBER,
|
||||
'BigAutoField': NUMBER,
|
||||
'BigIntegerField': NUMBER,
|
||||
'BinaryField': BLOB,
|
||||
'BooleanField': NUMBER,
|
||||
'DateField': DATE,
|
||||
'DateTimeField': TIMESTAMP,
|
||||
'DecimalField': NUMBER,
|
||||
'DurationField': INTERVAL,
|
||||
'FloatField': NUMBER,
|
||||
'IntegerField': NUMBER,
|
||||
'PositiveBigIntegerField': NUMBER,
|
||||
'PositiveIntegerField': NUMBER,
|
||||
'PositiveSmallIntegerField': NUMBER,
|
||||
'SmallAutoField': NUMBER,
|
||||
'SmallIntegerField': NUMBER,
|
||||
'TextField': CLOB,
|
||||
'TimeField': TIMESTAMP,
|
||||
}
|
||||
|
||||
|
||||
def dsn(settings_dict):
|
||||
if settings_dict['PORT']:
|
||||
host = settings_dict['HOST'].strip() or 'localhost'
|
||||
return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME'])
|
||||
return settings_dict['NAME']
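# Illustrative examples (added, hypothetical values): with
# {'HOST': 'db.example.com', 'PORT': '1521', 'NAME': 'ORCLPDB1'} this returns
# Database.makedsn('db.example.com', 1521, 'ORCLPDB1'); with no PORT set it
# simply returns NAME, e.g. a TNS alias or an EZCONNECT string.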
|
||||
@@ -0,0 +1,22 @@
|
||||
from django.core import checks
|
||||
from django.db.backends.base.validation import BaseDatabaseValidation
|
||||
|
||||
|
||||
class DatabaseValidation(BaseDatabaseValidation):
|
||||
def check_field_type(self, field, field_type):
|
||||
"""Oracle doesn't support a database index on some data types."""
|
||||
errors = []
|
||||
if field.db_index and field_type.lower() in self.connection._limited_data_types:
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
'Oracle does not support a database index on %s columns.'
|
||||
% field_type,
|
||||
hint=(
|
||||
"An index won't be created. Silence this warning if "
|
||||
"you don't care about it."
|
||||
),
|
||||
obj=field,
|
||||
id='fields.W162',
|
||||
)
|
||||
)
|
||||
return errors
|
||||
353
venv/Lib/site-packages/django/db/backends/postgresql/base.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
PostgreSQL database backend for Django.
|
||||
|
||||
Requires psycopg 2: https://www.psycopg.org/
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import DatabaseError as WrappedDatabaseError, connections
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.utils import (
|
||||
CursorDebugWrapper as BaseCursorDebugWrapper,
|
||||
)
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.safestring import SafeString
|
||||
from django.utils.version import get_version_tuple
|
||||
|
||||
try:
|
||||
import psycopg2 as Database
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
except ImportError as e:
|
||||
raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
|
||||
|
||||
|
||||
def psycopg2_version():
|
||||
version = psycopg2.__version__.split(' ', 1)[0]
|
||||
return get_version_tuple(version)
|
||||
|
||||
|
||||
PSYCOPG2_VERSION = psycopg2_version()
|
||||
|
||||
if PSYCOPG2_VERSION < (2, 5, 4):
|
||||
raise ImproperlyConfigured("psycopg2_version 2.5.4 or newer is required; you have %s" % psycopg2.__version__)
|
||||
|
||||
|
||||
# Some of these import psycopg2, so import them after checking if it's installed.
|
||||
from .client import DatabaseClient # NOQA
|
||||
from .creation import DatabaseCreation # NOQA
|
||||
from .features import DatabaseFeatures # NOQA
|
||||
from .introspection import DatabaseIntrospection # NOQA
|
||||
from .operations import DatabaseOperations # NOQA
|
||||
from .schema import DatabaseSchemaEditor # NOQA
|
||||
|
||||
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
|
||||
psycopg2.extras.register_uuid()
|
||||
|
||||
# Register support for inet[] manually so we don't have to handle the Inet()
|
||||
# object on load all the time.
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
    (INETARRAY_OID,),
    'INETARRAY',
    psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = 'postgresql'
    display_name = 'PostgreSQL'
    # This dictionary maps Field objects to their associated PostgreSQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        'AutoField': 'serial',
        'BigAutoField': 'bigserial',
        'BinaryField': 'bytea',
        'BooleanField': 'boolean',
        'CharField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'timestamp with time zone',
        'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
        'DurationField': 'interval',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'double precision',
        'IntegerField': 'integer',
        'BigIntegerField': 'bigint',
        'IPAddressField': 'inet',
        'GenericIPAddressField': 'inet',
        'JSONField': 'jsonb',
        'OneToOneField': 'integer',
        'PositiveBigIntegerField': 'bigint',
        'PositiveIntegerField': 'integer',
        'PositiveSmallIntegerField': 'smallint',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallAutoField': 'smallserial',
        'SmallIntegerField': 'smallint',
        'TextField': 'text',
        'TimeField': 'time',
        'UUIDField': 'uuid',
    }
    data_type_check_constraints = {
        'PositiveBigIntegerField': '"%(column)s" >= 0',
        'PositiveIntegerField': '"%(column)s" >= 0',
        'PositiveSmallIntegerField': '"%(column)s" >= 0',
    }
    operators = {
        'exact': '= %s',
        'iexact': '= UPPER(%s)',
        'contains': 'LIKE %s',
        'icontains': 'LIKE UPPER(%s)',
        'regex': '~ %s',
        'iregex': '~* %s',
        'gt': '> %s',
        'gte': '>= %s',
        'lt': '< %s',
        'lte': '<= %s',
        'startswith': 'LIKE %s',
        'endswith': 'LIKE %s',
        'istartswith': 'LIKE UPPER(%s)',
        'iendswith': 'LIKE UPPER(%s)',
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
    pattern_ops = {
        'contains': "LIKE '%%' || {} || '%%'",
        'icontains': "LIKE '%%' || UPPER({}) || '%%'",
        'startswith': "LIKE {} || '%%'",
        'istartswith': "LIKE UPPER({}) || '%%'",
        'endswith': "LIKE '%%' || {}",
        'iendswith': "LIKE '%%' || UPPER({})",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    # PostgreSQL backend-specific attributes.
    _named_cursor_idx = 0

    def get_connection_params(self):
        settings_dict = self.settings_dict
        # None may be used to connect to the default 'postgres' db
        if (
            settings_dict['NAME'] == '' and
            not settings_dict.get('OPTIONS', {}).get('service')
        ):
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME or OPTIONS['service'] value."
            )
        if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
            raise ImproperlyConfigured(
                "The database name '%s' (%d characters) is longer than "
                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                "in settings.DATABASES." % (
                    settings_dict['NAME'],
                    len(settings_dict['NAME']),
                    self.ops.max_name_length(),
                )
            )
        conn_params = {}
        if settings_dict['NAME']:
            conn_params = {
                'database': settings_dict['NAME'],
                **settings_dict['OPTIONS'],
            }
        elif settings_dict['NAME'] is None:
            # Connect to the default 'postgres' db.
            settings_dict.get('OPTIONS', {}).pop('service', None)
            conn_params = {'database': 'postgres', **settings_dict['OPTIONS']}
        else:
            conn_params = {**settings_dict['OPTIONS']}

        conn_params.pop('isolation_level', None)
        if settings_dict['USER']:
            conn_params['user'] = settings_dict['USER']
        if settings_dict['PASSWORD']:
            conn_params['password'] = settings_dict['PASSWORD']
        if settings_dict['HOST']:
            conn_params['host'] = settings_dict['HOST']
        if settings_dict['PORT']:
            conn_params['port'] = settings_dict['PORT']
        return conn_params

    @async_unsafe
    def get_new_connection(self, conn_params):
        connection = Database.connect(**conn_params)

        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict['OPTIONS']
        try:
            self.isolation_level = options['isolation_level']
        except KeyError:
            self.isolation_level = connection.isolation_level
        else:
            # Set the isolation level to the value from OPTIONS.
            if self.isolation_level != connection.isolation_level:
                connection.set_session(isolation_level=self.isolation_level)
        # Register dummy loads() to avoid a round trip from psycopg2's decode
        # to json.dumps() to json.loads(), when using a custom decoder in
        # JSONField.
        psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
        return connection

    def ensure_timezone(self):
        if self.connection is None:
            return False
        conn_timezone_name = self.connection.get_parameter_status('TimeZone')
        timezone_name = self.timezone_name
        if timezone_name and conn_timezone_name != timezone_name:
            with self.connection.cursor() as cursor:
                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
            return True
        return False

    def init_connection_state(self):
        self.connection.set_client_encoding('UTF8')

        timezone_changed = self.ensure_timezone()
        if timezone_changed:
            # Commit after setting the time zone (see #17062)
            if not self.get_autocommit():
                self.connection.commit()

    @async_unsafe
    def create_cursor(self, name=None):
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
        else:
            cursor = self.connection.cursor()
        cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
        return cursor

    def tzinfo_factory(self, offset):
        return self.timezone

    @async_unsafe
    def chunked_cursor(self):
        self._named_cursor_idx += 1
        # Get the current async task
        # Note that right now this is behind @async_unsafe, so this is
        # unreachable, but in future we'll start loosening this restriction.
        # For now, it's here so that every use of "threading" is
        # also async-compatible.
        try:
            current_task = asyncio.current_task()
        except RuntimeError:
            current_task = None
        # Current task can be none even if the current_task call didn't error
        if current_task:
            task_ident = str(id(current_task))
        else:
            task_ident = 'sync'
        # Use that and the thread ident to get a unique name
        return self._cursor(
            name='_django_curs_%d_%s_%d' % (
                # Avoid reusing name in other threads / tasks
                threading.current_thread().ident,
                task_ident,
                self._named_cursor_idx,
            )
        )

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.
        """
        with self.cursor() as cursor:
            cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
            cursor.execute('SET CONSTRAINTS ALL DEFERRED')

    def is_usable(self):
        try:
            # Use a psycopg cursor directly, bypassing Django's utilities.
            with self.connection.cursor() as cursor:
                cursor.execute('SELECT 1')
        except Database.Error:
            return False
        else:
            return True

    @contextmanager
    def _nodb_cursor(self):
        cursor = None
        try:
            with super()._nodb_cursor() as cursor:
                yield cursor
        except (Database.DatabaseError, WrappedDatabaseError):
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Django will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Django was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning
            )
            for connection in connections.all():
                if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
                    conn = self.__class__(
                        {**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
                        alias=self.alias,
                    )
                    try:
                        with conn.cursor() as cursor:
                            yield cursor
                    finally:
                        conn.close()
                    break
            else:
                raise

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return self.connection.server_version

    def make_debug_cursor(self, cursor):
        return CursorDebugWrapper(cursor, self)


class CursorDebugWrapper(BaseCursorDebugWrapper):
    def copy_expert(self, sql, file, *args):
        with self.debug_sql(sql):
            return self.cursor.copy_expert(sql, file, *args)

    def copy_to(self, file, table, *args, **kwargs):
        with self.debug_sql(sql='COPY %s TO STDOUT' % table):
            return self.cursor.copy_to(file, table, *args, **kwargs)
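# Illustrative sketch (not part of the upstream file): with DEBUG enabled, a
# COPY issued through the debug cursor is still logged via debug_sql(). The
# table name "myapp_table" below is a hypothetical example.
#
#   from io import StringIO
#   buf = StringIO()
#   with connection.cursor() as cursor:
#       cursor.copy_expert('COPY myapp_table TO STDOUT WITH CSV', buf)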
@@ -0,0 +1,64 @@
import signal

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = 'psql'

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name]
        options = settings_dict.get('OPTIONS', {})

        host = settings_dict.get('HOST')
        port = settings_dict.get('PORT')
        dbname = settings_dict.get('NAME')
        user = settings_dict.get('USER')
        passwd = settings_dict.get('PASSWORD')
        passfile = options.get('passfile')
        service = options.get('service')
        sslmode = options.get('sslmode')
        sslrootcert = options.get('sslrootcert')
        sslcert = options.get('sslcert')
        sslkey = options.get('sslkey')

        if not dbname and not service:
            # Connect to the default 'postgres' db.
            dbname = 'postgres'
        if user:
            args += ['-U', user]
        if host:
            args += ['-h', host]
        if port:
            args += ['-p', str(port)]
        if dbname:
            args += [dbname]
        args.extend(parameters)

        env = {}
        if passwd:
            env['PGPASSWORD'] = str(passwd)
        if service:
            env['PGSERVICE'] = str(service)
        if sslmode:
            env['PGSSLMODE'] = str(sslmode)
        if sslrootcert:
            env['PGSSLROOTCERT'] = str(sslrootcert)
        if sslcert:
            env['PGSSLCERT'] = str(sslcert)
        if sslkey:
            env['PGSSLKEY'] = str(sslkey)
        if passfile:
            env['PGPASSFILE'] = str(passfile)
        return args, (env or None)

    def runshell(self, parameters):
        sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            # Allow SIGINT to pass to psql to abort queries.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            # Restore the original SIGINT handler.
            signal.signal(signal.SIGINT, sigint_handler)
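# Illustrative sketch (not part of the upstream file): for a settings dict
# like the hypothetical one below, settings_to_cmd_args_env() builds the psql
# command line and environment as follows.
#
#   args, env = DatabaseClient.settings_to_cmd_args_env(
#       {'NAME': 'mydb', 'USER': 'me', 'PASSWORD': 's3cret',
#        'HOST': 'localhost', 'PORT': '5432', 'OPTIONS': {}},
#       [],
#   )
#   # args -> ['psql', '-U', 'me', '-h', 'localhost', '-p', '5432', 'mydb']
#   # env  -> {'PGPASSWORD': 's3cret'}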
@@ -0,0 +1,80 @@
import sys

from psycopg2 import errorcodes

from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.utils import strip_quotes


class DatabaseCreation(BaseDatabaseCreation):

    def _quote_name(self, name):
        return self.connection.ops.quote_name(name)

    def _get_database_create_suffix(self, encoding=None, template=None):
        suffix = ""
        if encoding:
            suffix += " ENCODING '{}'".format(encoding)
        if template:
            suffix += " TEMPLATE {}".format(self._quote_name(template))
        return suffix and "WITH" + suffix

    def sql_table_creation_suffix(self):
        test_settings = self.connection.settings_dict['TEST']
        if test_settings.get('COLLATION') is not None:
            raise ImproperlyConfigured(
                'PostgreSQL does not support collation setting at database '
                'creation time.'
            )
        return self._get_database_create_suffix(
            encoding=test_settings['CHARSET'],
            template=test_settings.get('TEMPLATE'),
        )

    def _database_exists(self, cursor, database_name):
        cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
        return cursor.fetchone() is not None

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        try:
            if keepdb and self._database_exists(cursor, parameters['dbname']):
                # If the database should be kept and it already exists, don't
                # try to create a new one.
                return
            super()._execute_create_test_db(cursor, parameters, keepdb)
        except Exception as e:
            if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
                # All errors except "database already exists" cancel tests.
                self.log('Got an error creating the test database: %s' % e)
                sys.exit(2)
            elif not keepdb:
                # If the database should be kept, ignore "database already
                # exists".
                raise

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
        # to the template database.
        self.connection.close()

        source_database_name = self.connection.settings_dict['NAME']
        target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
        test_db_params = {
            'dbname': self._quote_name(target_database_name),
            'suffix': self._get_database_create_suffix(template=source_database_name),
        }
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception:
                try:
                    if verbosity >= 1:
                        self.log('Destroying old test database for alias %s...' % (
                            self._get_database_display_str(verbosity, target_database_name),
                        ))
                    cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
                    self._execute_create_test_db(cursor, test_db_params, keepdb)
                except Exception as e:
                    self.log('Got an error cloning the test database: %s' % e)
                    sys.exit(2)
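# Illustrative sketch (not part of the upstream file): given the code above,
# a TEST CHARSET of 'UTF8' and no template would produce a suffix like:
#
#   self._get_database_create_suffix(encoding='UTF8')
#   # -> "WITH ENCODING 'UTF8'"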
@@ -0,0 +1,97 @@
import operator

from django.db import InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property


class DatabaseFeatures(BaseDatabaseFeatures):
    allows_group_by_selected_pks = True
    can_return_columns_from_insert = True
    can_return_rows_from_bulk_insert = True
    has_real_datatype = True
    has_native_uuid_field = True
    has_native_duration_field = True
    has_native_json_field = True
    can_defer_constraint_checks = True
    has_select_for_update = True
    has_select_for_update_nowait = True
    has_select_for_update_of = True
    has_select_for_update_skip_locked = True
    has_select_for_no_key_update = True
    can_release_savepoints = True
    supports_tablespaces = True
    supports_transactions = True
    can_introspect_materialized_views = True
    can_distinct_on_fields = True
    can_rollback_ddl = True
    supports_combined_alters = True
    nulls_order_largest = True
    closed_cursor_error_class = InterfaceError
    has_case_insensitive_like = False
    greatest_least_ignores_nulls = True
    can_clone_databases = True
    supports_temporal_subtraction = True
    supports_slicing_ordering_in_compound = True
    create_test_procedure_without_params_sql = """
        CREATE FUNCTION test_procedure () RETURNS void AS $$
        DECLARE
            V_I INTEGER;
        BEGIN
            V_I := 1;
        END;
    $$ LANGUAGE plpgsql;"""
    create_test_procedure_with_int_param_sql = """
        CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$
        DECLARE
            V_I INTEGER;
        BEGIN
            V_I := P_I;
        END;
    $$ LANGUAGE plpgsql;"""
    requires_casted_case_in_updates = True
    supports_over_clause = True
    only_supports_unbounded_with_preceding_and_following = True
    supports_aggregate_filter_clause = True
    supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
    validates_explain_options = False  # A query will error on invalid options.
    supports_deferrable_unique_constraints = True
    has_json_operators = True
    json_key_contains_list_matching_requires_list = True
    test_collations = {
        'non_default': 'sv-x-icu',
        'swedish_ci': 'sv-x-icu',
    }
    test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"

    django_test_skips = {
        'opclasses are PostgreSQL only.': {
            'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses',
        },
    }

    @cached_property
    def introspected_field_types(self):
        return {
            **super().introspected_field_types,
            'PositiveBigIntegerField': 'BigIntegerField',
            'PositiveIntegerField': 'IntegerField',
            'PositiveSmallIntegerField': 'SmallIntegerField',
        }

    @cached_property
    def is_postgresql_11(self):
        return self.connection.pg_version >= 110000

    @cached_property
    def is_postgresql_12(self):
        return self.connection.pg_version >= 120000

    @cached_property
    def is_postgresql_13(self):
        return self.connection.pg_version >= 130000

    has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11'))
    supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
    supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
    supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
@@ -0,0 +1,234 @@
from django.db.backends.base.introspection import (
    BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
from django.db.models import Index


class DatabaseIntrospection(BaseDatabaseIntrospection):
    # Maps type codes to Django Field types.
    data_types_reverse = {
        16: 'BooleanField',
        17: 'BinaryField',
        20: 'BigIntegerField',
        21: 'SmallIntegerField',
        23: 'IntegerField',
        25: 'TextField',
        700: 'FloatField',
        701: 'FloatField',
        869: 'GenericIPAddressField',
        1042: 'CharField',  # blank-padded
        1043: 'CharField',
        1082: 'DateField',
        1083: 'TimeField',
        1114: 'DateTimeField',
        1184: 'DateTimeField',
        1186: 'DurationField',
        1266: 'TimeField',
        1700: 'DecimalField',
        2950: 'UUIDField',
        3802: 'JSONField',
    }
    # A hook for subclasses.
    index_default_access_method = 'btree'

    ignored_tables = []

    def get_field_type(self, data_type, description):
        field_type = super().get_field_type(data_type, description)
        if description.default and 'nextval' in description.default:
            if field_type == 'IntegerField':
                return 'AutoField'
            elif field_type == 'BigIntegerField':
                return 'BigAutoField'
            elif field_type == 'SmallIntegerField':
                return 'SmallAutoField'
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        cursor.execute("""
            SELECT c.relname,
            CASE WHEN c.relispartition THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """)
        return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        # Query the pg_catalog tables as cursor.description does not reliably
        # return the nullable property and information_schema.columns does not
        # contain details of materialized views.
        cursor.execute("""
            SELECT
                a.attname AS column_name,
                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation
            FROM pg_attribute a
            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
            LEFT JOIN pg_collation co ON a.attcollation = co.oid
            JOIN pg_type t ON a.atttypid = t.oid
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND c.relname = %s
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """, [table_name])
        field_map = {line[0]: line[1:] for line in cursor.fetchall()}
        cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
        return [
            FieldInfo(
                line.name,
                line.type_code,
                line.display_size,
                line.internal_size,
                line.precision,
                line.scale,
                *field_map[line.name],
            )
            for line in cursor.description
        ]

    def get_sequences(self, cursor, table_name, table_fields=()):
        cursor.execute("""
            SELECT s.relname as sequence_name, col.attname
            FROM pg_class s
                JOIN pg_namespace sn ON sn.oid = s.relnamespace
                JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid = 'pg_class'::regclass
                JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
                JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
                JOIN pg_class tbl ON tbl.oid = ad.adrelid
            WHERE s.relkind = 'S'
              AND d.deptype in ('a', 'n')
              AND pg_catalog.pg_table_is_visible(tbl.oid)
              AND tbl.relname = %s
        """, [table_name])
        return [
            {'name': row[0], 'table': table_name, 'column': row[1]}
            for row in cursor.fetchall()
        ]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all relationships to the given table.
        """
        return {row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)}

    def get_key_columns(self, cursor, table_name):
        cursor.execute("""
            SELECT a1.attname, c2.relname, a2.attname
            FROM pg_constraint con
            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
            LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
            LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
            WHERE
                c1.relname = %s AND
                con.contype = 'f' AND
                c1.relnamespace = c2.relnamespace AND
                pg_catalog.pg_table_is_visible(c1.oid)
        """, [table_name])
        return cursor.fetchall()

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.
        """
        constraints = {}
        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        cursor.execute("""
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """, [table_name])
        for constraint, columns, kind, used_cols, options in cursor.fetchall():
            constraints[constraint] = {
                "columns": columns,
                "primary_key": kind == "p",
                "unique": kind in ["p", "u"],
                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": options,
            }
        # Now get indexes
        cursor.execute("""
            SELECT
                indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
                array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM pg_index i, unnest(i.indkey, i.indoption) WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """, [self.index_default_access_method, table_name])
        for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
            if index not in constraints:
                basic_index = (
                    type_ == self.index_default_access_method and
                    # '_btree' references
                    # django.contrib.postgres.indexes.BTreeIndex.suffix.
                    not index.endswith('_btree') and options is None
                )
                constraints[index] = {
                    "columns": columns if columns != [None] else [],
                    "orders": orders if orders != [None] else [],
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                    "type": Index.suffix if basic_index else type_,
                    "definition": definition,
                    "options": options,
                }
        return constraints
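# Illustrative sketch (not part of the upstream file): get_relations() maps
# each local foreign-key column to its target column and table. For a
# hypothetical "book" table with an "author_id" FK to "author" it would
# return something like:
#
#   {'author_id': ('id', 'author')}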
@@ -0,0 +1,276 @@
from psycopg2.extras import Inet

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta


class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = 'varchar'
    explain_prefix = 'EXPLAIN'
    cast_data_types = {
        'AutoField': 'integer',
        'BigAutoField': 'bigint',
        'SmallAutoField': 'smallint',
    }

    def unification_cast_sql(self, output_field):
        internal_type = output_field.get_internal_type()
        if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
            # PostgreSQL will resolve a union as type 'text' if input types are
            # 'unknown'.
            # https://www.postgresql.org/docs/current/typeconv-union-case.html
            # These fields cannot be implicitly cast back in the default
            # PostgreSQL configuration so we need to explicitly cast them.
            # We must also remove components of the type within brackets:
            # varchar(255) -> varchar.
            return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
        return '%s'

    def date_extract_sql(self, lookup_type, field_name):
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
        if lookup_type == 'week_day':
            # For consistency across backends, we return Sunday=1, Saturday=7.
            return "EXTRACT('dow' FROM %s) + 1" % field_name
        elif lookup_type == 'iso_week_day':
            return "EXTRACT('isodow' FROM %s)" % field_name
        elif lookup_type == 'iso_year':
            return "EXTRACT('isoyear' FROM %s)" % field_name
        else:
            return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)

    def date_trunc_sql(self, lookup_type, field_name, tzname=None):
        field_name = self._convert_field_to_tz(field_name, tzname)
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
        return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)

    def _prepare_tzname_delta(self, tzname):
        tzname, sign, offset = split_tzname_delta(tzname)
        if offset:
            sign = '-' if sign == '+' else '+'
            return f'{tzname}{sign}{offset}'
        return tzname

    def _convert_field_to_tz(self, field_name, tzname):
        if tzname and settings.USE_TZ:
            field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
        return field_name

    def datetime_cast_date_sql(self, field_name, tzname):
        field_name = self._convert_field_to_tz(field_name, tzname)
        return '(%s)::date' % field_name

    def datetime_cast_time_sql(self, field_name, tzname):
        field_name = self._convert_field_to_tz(field_name, tzname)
        return '(%s)::time' % field_name

    def datetime_extract_sql(self, lookup_type, field_name, tzname):
        field_name = self._convert_field_to_tz(field_name, tzname)
        return self.date_extract_sql(lookup_type, field_name)

    def datetime_trunc_sql(self, lookup_type, field_name, tzname):
        field_name = self._convert_field_to_tz(field_name, tzname)
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
        return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)

    def time_trunc_sql(self, lookup_type, field_name, tzname=None):
        field_name = self._convert_field_to_tz(field_name, tzname)
        return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)

    def deferrable_sql(self):
        return " DEFERRABLE INITIALLY DEFERRED"

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the tuple of returned data.
        """
        return cursor.fetchall()

    def lookup_cast(self, lookup_type, internal_type=None):
        lookup = '%s'

        # Cast text lookups to text to allow things like filter(x__contains=4)
        if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
                           'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
            if internal_type in ('IPAddressField', 'GenericIPAddressField'):
                lookup = "HOST(%s)"
            elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
                lookup = '%s::citext'
            else:
                lookup = "%s::text"

        # Use UPPER(x) for case-insensitive lookups; it's faster.
        if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
            lookup = 'UPPER(%s)' % lookup

        return lookup

    def no_limit_value(self):
        return None

    def prepare_sql_script(self, sql):
        return [sql]

    def quote_name(self, name):
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def set_time_zone_sql(self):
        return "SET TIME ZONE %s"

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        if not tables:
            return []

        # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
        # to truncate tables referenced by a foreign key in any other table.
        sql_parts = [
            style.SQL_KEYWORD('TRUNCATE'),
            ', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
        ]
        if reset_sequences:
            sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY'))
        if allow_cascade:
            sql_parts.append(style.SQL_KEYWORD('CASCADE'))
        return ['%s;' % ' '.join(sql_parts)]

    def sequence_reset_by_name_sql(self, style, sequences):
        # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
        # to reset sequence indices
        sql = []
        for sequence_info in sequences:
            table_name = sequence_info['table']
            # 'id' will be the case if it's an m2m using an autogenerated
            # intermediate table (see BaseDatabaseIntrospection.sequence_list).
            column_name = sequence_info['column'] or 'id'
            sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
                style.SQL_KEYWORD('SELECT'),
                style.SQL_TABLE(self.quote_name(table_name)),
                style.SQL_FIELD(column_name),
            ))
        return sql

    def tablespace_sql(self, tablespace, inline=False):
        if inline:
            return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
        else:
            return "TABLESPACE %s" % self.quote_name(tablespace)

    def sequence_reset_sql(self, style, model_list):
        from django.db import models
        output = []
        qn = self.quote_name
        for model in model_list:
            # Use `coalesce` to set the sequence for each model to the max pk value if there are records,
            # or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
            # if there are records (as the max pk value is already in use), otherwise set it to false.
            # Use pg_get_serial_sequence to get the underlying sequence name from the table name
            # and column name (available since PostgreSQL 8)

            for f in model._meta.local_fields:
                if isinstance(f, models.AutoField):
                    output.append(
                        "%s setval(pg_get_serial_sequence('%s','%s'), "
                        "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
                            style.SQL_KEYWORD('SELECT'),
                            style.SQL_TABLE(qn(model._meta.db_table)),
                            style.SQL_FIELD(f.column),
                            style.SQL_FIELD(qn(f.column)),
                            style.SQL_FIELD(qn(f.column)),
                            style.SQL_KEYWORD('IS NOT'),
                            style.SQL_KEYWORD('FROM'),
                            style.SQL_TABLE(qn(model._meta.db_table)),
                        )
                    )
                    break  # Only one AutoField is allowed per model, so don't bother continuing.
        return output

    def prep_for_iexact_query(self, x):
        return x

    def max_name_length(self):
        """
        Return the maximum length of an identifier.

        The maximum length of an identifier is 63 by default, but can be
        changed by recompiling PostgreSQL after editing the NAMEDATALEN
        macro in src/include/pg_config_manual.h.

        This implementation returns 63, but can be overridden by a custom
        database backend that inherits most of its behavior from this one.
        """
        return 63

    def distinct_sql(self, fields, params):
        if fields:
            params = [param for param_list in params for param in param_list]
            return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
        else:
            return ['DISTINCT'], []

    def last_executed_query(self, cursor, sql, params):
        # https://www.psycopg.org/docs/cursor.html#cursor.query
        # The query attribute is a Psycopg extension to the DB API 2.0.
        if cursor.query is not None:
            return cursor.query.decode()
        return None

    def return_insert_columns(self, fields):
        if not fields:
            return '', ()
        columns = [
            '%s.%s' % (
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            ) for field in fields
        ]
        return 'RETURNING %s' % ', '.join(columns), ()

    def bulk_insert_sql(self, fields, placeholder_rows):
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
        return "VALUES " + values_sql

    def adapt_datefield_value(self, value):
        return value

    def adapt_datetimefield_value(self, value):
        return value

    def adapt_timefield_value(self, value):
        return value

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        return value

    def adapt_ipaddressfield_value(self, value):
        if value:
            return Inet(value)
        return None

    def subtract_temporals(self, internal_type, lhs, rhs):
        if internal_type == 'DateField':
            lhs_sql, lhs_params = lhs
            rhs_sql, rhs_params = rhs
            params = (*lhs_params, *rhs_params)
            return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params
        return super().subtract_temporals(internal_type, lhs, rhs)

    def explain_query_prefix(self, format=None, **options):
        prefix = super().explain_query_prefix(format)
        extra = {}
        if format:
            extra['FORMAT'] = format
        if options:
            extra.update({
                name.upper(): 'true' if value else 'false'
                for name, value in options.items()
            })
        if extra:
            prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
        return prefix

    def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
        return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
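# Illustrative sketch (not part of the upstream file): based on the
# explain_query_prefix() code above, a format plus one boolean option would
# compose into something like:
#
#   ops.explain_query_prefix(format='JSON', analyze=True)
#   # -> "EXPLAIN (FORMAT JSON, ANALYZE true)"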
238
venv/Lib/site-packages/django/db/backends/postgresql/schema.py
Normal file
238
venv/Lib/site-packages/django/db/backends/postgresql/schema.py
Normal file
@@ -0,0 +1,238 @@
import psycopg2

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
from django.db.backends.utils import strip_quotes


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

    sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
    sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
    sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'

    sql_create_index = (
        'CREATE INDEX %(name)s ON %(table)s%(using)s '
        '(%(columns)s)%(include)s%(extra)s%(condition)s'
    )
    sql_create_index_concurrently = (
        'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s '
        '(%(columns)s)%(include)s%(extra)s%(condition)s'
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
        '; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE'
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"

    sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'

    def quote_value(self, value):
        if isinstance(value, str):
            value = value.replace('%', '%%')
        adapted = psycopg2.extensions.adapt(value)
        if hasattr(adapted, 'encoding'):
            adapted.encoding = 'utf8'
        # getquoted() returns a quoted bytestring of the adapted value.
        return adapted.getquoted().decode()

    def _field_indexes_sql(self, model, field):
        output = super()._field_indexes_sql(model, field)
        like_index_statement = self._create_like_index_sql(model, field)
        if like_index_statement is not None:
            output.append(like_index_statement)
        return output

    def _field_data_type(self, field):
        if field.is_relation:
            return field.rel_db_type(self.connection)
        return self.connection.data_types.get(
            field.get_internal_type(),
            field.db_type(self.connection),
        )

    def _field_base_data_types(self, field):
        # Yield base data types for array fields.
        if field.base_field.get_internal_type() == 'ArrayField':
            yield from self._field_base_data_types(field.base_field)
        else:
            yield self._field_data_type(field.base_field)

    def _create_like_index_sql(self, model, field):
        """
        Return the statement to create an index with varchar operator pattern
        when the column type is 'varchar' or 'text', otherwise return None.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is not None and (field.db_index or field.unique):
            # Fields with database column types of `varchar` and `text` need
            # a second index that specifies their operator class, which is
            # needed when performing correct LIKE queries outside the
            # C locale. See #12234.
            #
            # The same doesn't apply to array fields such as varchar[size]
            # and text[size], so skip them.
            if '[' in db_type:
                return None
            if db_type.startswith('varchar'):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix='_like',
                    opclasses=['varchar_pattern_ops'],
                )
            elif db_type.startswith('text'):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix='_like',
                    opclasses=['text_pattern_ops'],
                )
        return None

    def _alter_column_type_sql(self, model, old_field, new_field, new_type):
        self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
        # Cast when data type changed.
        using_sql = ' USING %(column)s::%(type)s'
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type:
            # Compare base data types for array fields.
            if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)):
                self.sql_alter_column_type += using_sql
        elif self._field_data_type(old_field) != self._field_data_type(new_field):
            self.sql_alter_column_type += using_sql
        # Make ALTER TYPE with SERIAL make sense.
        table = strip_quotes(model._meta.db_table)
        serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
        if new_type.lower() in serial_fields_map:
            column = strip_quotes(new_field.column)
            sequence_name = "%s_%s_seq" % (table, column)
            return (
                (
                    self.sql_alter_column_type % {
                        "column": self.quote_name(column),
                        "type": serial_fields_map[new_type.lower()],
                    },
                    [],
                ),
                [
                    (
                        self.sql_delete_sequence % {
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    ),
                    (
                        self.sql_create_sequence % {
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    ),
                    (
                        self.sql_alter_column % {
                            "table": self.quote_name(table),
                            "changes": self.sql_alter_column_default % {
                                "column": self.quote_name(column),
                                "default": "nextval('%s')" % self.quote_name(sequence_name),
                            }
                        },
                        [],
                    ),
                    (
                        self.sql_set_sequence_max % {
                            "table": self.quote_name(table),
                            "column": self.quote_name(column),
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    ),
                    (
                        self.sql_set_sequence_owner % {
                            'table': self.quote_name(table),
                            'column': self.quote_name(column),
                            'sequence': self.quote_name(sequence_name),
                        },
                        [],
                    ),
                ],
            )
        elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map:
            # Drop the sequence if migrating away from AutoField.
            column = strip_quotes(new_field.column)
            sequence_name = '%s_%s_seq' % (table, column)
            fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type)
            return fragment, [
                (
                    self.sql_delete_sequence % {
                        'sequence': self.quote_name(sequence_name),
                    },
                    [],
                ),
            ]
        else:
            return super()._alter_column_type_sql(model, old_field, new_field, new_type)

    def _alter_field(self, model, old_field, new_field, old_type, new_type,
                     old_db_params, new_db_params, strict=False):
        # Drop indexes on varchar/text/citext columns that are changing to a
        # different type.
        if (old_field.db_index or old_field.unique) and (
            (old_type.startswith('varchar') and not new_type.startswith('varchar')) or
            (old_type.startswith('text') and not new_type.startswith('text')) or
            (old_type.startswith('citext') and not new_type.startswith('citext'))
        ):
            index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
            self.execute(self._delete_index_sql(model, index_name))

        super()._alter_field(
            model, old_field, new_field, old_type, new_type, old_db_params,
            new_db_params, strict,
        )
        # Added an index? Create any PostgreSQL-specific indexes.
        if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
                (not old_field.unique and new_field.unique)):
            like_index_statement = self._create_like_index_sql(model, new_field)
            if like_index_statement is not None:
                self.execute(like_index_statement)

        # Removed an index? Drop any PostgreSQL-specific indexes.
        if old_field.unique and not (new_field.db_index or new_field.unique):
            index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
            self.execute(self._delete_index_sql(model, index_to_remove))

    def _index_columns(self, table, columns, col_suffixes, opclasses):
        if opclasses:
            return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
        return super()._index_columns(table, columns, col_suffixes, opclasses)

    def add_index(self, model, index, concurrently=False):
        self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)

    def remove_index(self, model, index, concurrently=False):
        self.execute(index.remove_sql(model, self, concurrently=concurrently))

    def _delete_index_sql(self, model, name, sql=None, concurrently=False):
        sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
        return super()._delete_index_sql(model, name, sql)

    def _create_index_sql(
        self, model, *, fields=None, name=None, suffix='', using='',
        db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
        condition=None, concurrently=False, include=None, expressions=None,
    ):
        sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
        return super()._create_index_sql(
            model, fields=fields, name=name, suffix=suffix, using=using,
            db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql,
            opclasses=opclasses, condition=condition, include=include,
            expressions=expressions,
        )
3
venv/Lib/site-packages/django/db/backends/signals.py
Normal file
3
venv/Lib/site-packages/django/db/backends/signals.py
Normal file
@@ -0,0 +1,3 @@
from django.dispatch import Signal

connection_created = Signal()
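# Illustrative sketch (not part of the upstream file): backends send this
# signal after opening each new database connection, and user code can
# subscribe to it, e.g.:
#
#   from django.db.backends.signals import connection_created
#
#   def on_connect(sender, connection, **kwargs):
#       ...  # e.g. run a per-connection PRAGMA or SET statement
#
#   connection_created.connect(on_connect)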
621
venv/Lib/site-packages/django/db/backends/sqlite3/base.py
Normal file
621
venv/Lib/site-packages/django/db/backends/sqlite3/base.py
Normal file
@@ -0,0 +1,621 @@
"""
SQLite backend for the sqlite3 module in the standard library.
"""
import datetime
import decimal
import functools
import hashlib
import math
import operator
import random
import re
import statistics
import warnings
from itertools import chain
from sqlite3 import dbapi2 as Database

from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import (
    BaseDatabaseWrapper, timezone_constructor,
)
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.regex_helper import _lazy_re_compile

from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor


def decoder(conv_func):
    """
    Convert bytestrings from Python's sqlite3 interface to a regular string.
    """
    return lambda s: conv_func(s.decode())


def none_guard(func):
    """
    Decorator that returns None if any of the arguments to the decorated
    function are None. Many SQL functions return NULL if any of their arguments
    are NULL. This decorator simplifies the implementation of this for the
    custom functions registered below.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return None if None in args else func(*args, **kwargs)
    return wrapper


def list_aggregate(function):
    """
    Return an aggregate class that accumulates values in a list and applies
    the provided function to the data.
    """
    return type('ListAggregate', (list,), {'finalize': function, 'step': list.append})


def check_sqlite_version():
    if Database.sqlite_version_info < (3, 9, 0):
        raise ImproperlyConfigured(
            'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version
        )


check_sqlite_version()

Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))

Database.register_adapter(decimal.Decimal, str)


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = 'sqlite'
    display_name = 'SQLite'
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        'AutoField': 'integer',
        'BigAutoField': 'integer',
        'BinaryField': 'BLOB',
        'BooleanField': 'bool',
        'CharField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'datetime',
        'DecimalField': 'decimal',
        'DurationField': 'bigint',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'real',
        'IntegerField': 'integer',
        'BigIntegerField': 'bigint',
        'IPAddressField': 'char(15)',
        'GenericIPAddressField': 'char(39)',
        'JSONField': 'text',
        'OneToOneField': 'integer',
        'PositiveBigIntegerField': 'bigint unsigned',
        'PositiveIntegerField': 'integer unsigned',
        'PositiveSmallIntegerField': 'smallint unsigned',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallAutoField': 'integer',
        'SmallIntegerField': 'smallint',
        'TextField': 'text',
        'TimeField': 'time',
        'UUIDField': 'char(32)',
    }
    data_type_check_constraints = {
        'PositiveBigIntegerField': '"%(column)s" >= 0',
        'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
        'PositiveIntegerField': '"%(column)s" >= 0',
        'PositiveSmallIntegerField': '"%(column)s" >= 0',
    }
    data_types_suffix = {
        'AutoField': 'AUTOINCREMENT',
        'BigAutoField': 'AUTOINCREMENT',
        'SmallAutoField': 'AUTOINCREMENT',
    }
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See https://www.sqlite.org/lang_expr.html for an explanation.
    operators = {
        'exact': '= %s',
        'iexact': "LIKE %s ESCAPE '\\'",
        'contains': "LIKE %s ESCAPE '\\'",
        'icontains': "LIKE %s ESCAPE '\\'",
        'regex': 'REGEXP %s',
        'iregex': "REGEXP '(?i)' || %s",
        'gt': '> %s',
        'gte': '>= %s',
        'lt': '< %s',
        'lte': '<= %s',
        'startswith': "LIKE %s ESCAPE '\\'",
        'endswith': "LIKE %s ESCAPE '\\'",
        'istartswith': "LIKE %s ESCAPE '\\'",
        'iendswith': "LIKE %s ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'",
        'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
        'startswith': r"LIKE {} || '%%' ESCAPE '\'",
        'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'",
        'endswith': r"LIKE '%%' || {} ESCAPE '\'",
        'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_connection_params(self):
        settings_dict = self.settings_dict
        if not settings_dict['NAME']:
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME value.")
        kwargs = {
            'database': settings_dict['NAME'],
            'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
            **settings_dict['OPTIONS'],
        }
        # Always allow the underlying SQLite connection to be shareable
        # between multiple threads. The safe-guarding will be handled at a
        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
        # property. This is necessary as the shareability is disabled by
        # default in pysqlite and it cannot be changed once a connection is
        # opened.
        if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
            warnings.warn(
                'The `check_same_thread` option was provided and set to '
                'True. It will be overridden with False. Use the '
                '`DatabaseWrapper.allow_thread_sharing` property instead '
                'for controlling thread shareability.',
                RuntimeWarning
            )
        kwargs.update({'check_same_thread': False, 'uri': True})
        return kwargs

    @async_unsafe
    def get_new_connection(self, conn_params):
        conn = Database.connect(**conn_params)
        create_deterministic_function = functools.partial(
            conn.create_function,
            deterministic=True,
        )
        create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
        create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
        create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
        create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
        create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
        create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
        create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
        create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
        create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
        create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
        create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
        create_deterministic_function('regexp', 2, _sqlite_regexp)
        create_deterministic_function('ACOS', 1, none_guard(math.acos))
        create_deterministic_function('ASIN', 1, none_guard(math.asin))
        create_deterministic_function('ATAN', 1, none_guard(math.atan))
        create_deterministic_function('ATAN2', 2, none_guard(math.atan2))
        create_deterministic_function('BITXOR', 2, none_guard(operator.xor))
        create_deterministic_function('CEILING', 1, none_guard(math.ceil))
        create_deterministic_function('COS', 1, none_guard(math.cos))
        create_deterministic_function('COT', 1, none_guard(lambda x: 1 / math.tan(x)))
        create_deterministic_function('DEGREES', 1, none_guard(math.degrees))
        create_deterministic_function('EXP', 1, none_guard(math.exp))
        create_deterministic_function('FLOOR', 1, none_guard(math.floor))
        create_deterministic_function('LN', 1, none_guard(math.log))
        create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x)))
        create_deterministic_function('LPAD', 3, _sqlite_lpad)
        create_deterministic_function('MD5', 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest()))
        create_deterministic_function('MOD', 2, none_guard(math.fmod))
        create_deterministic_function('PI', 0, lambda: math.pi)
        create_deterministic_function('POWER', 2, none_guard(operator.pow))
        create_deterministic_function('RADIANS', 1, none_guard(math.radians))
        create_deterministic_function('REPEAT', 2, none_guard(operator.mul))
        create_deterministic_function('REVERSE', 1, none_guard(lambda x: x[::-1]))
        create_deterministic_function('RPAD', 3, _sqlite_rpad)
        create_deterministic_function('SHA1', 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest()))
        create_deterministic_function('SHA224', 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest()))
        create_deterministic_function('SHA256', 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest()))
        create_deterministic_function('SHA384', 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest()))
        create_deterministic_function('SHA512', 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest()))
        create_deterministic_function('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0)))
        create_deterministic_function('SIN', 1, none_guard(math.sin))
        create_deterministic_function('SQRT', 1, none_guard(math.sqrt))
        create_deterministic_function('TAN', 1, none_guard(math.tan))
        # Don't use the built-in RANDOM() function because it returns a value
        # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
        conn.create_function('RAND', 0, random.random)
        conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev))
        conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev))
        conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance))
        conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance))
        conn.execute('PRAGMA foreign_keys = ON')
        return conn

    def init_connection_state(self):
        pass

    def create_cursor(self, name=None):
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    @async_unsafe
    def close(self):
        self.validate_thread_sharing()
        # If database is in memory, closing the connection destroys the
        # database. To prevent accidental data loss, ignore close requests on
|
||||
# an in-memory db.
|
||||
if not self.is_in_memory_db():
|
||||
BaseDatabaseWrapper.close(self)
|
||||
|
||||
def _savepoint_allowed(self):
|
||||
# When 'isolation_level' is not None, sqlite3 commits before each
|
||||
# savepoint; it's a bug. When it is None, savepoints don't make sense
|
||||
# because autocommit is enabled. The only exception is inside 'atomic'
|
||||
# blocks. To work around that bug, on SQLite, 'atomic' starts a
|
||||
# transaction explicitly rather than simply disable autocommit.
|
||||
return self.in_atomic_block
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
if autocommit:
|
||||
level = None
|
||||
else:
|
||||
# sqlite3's internal default is ''. It's different from None.
|
||||
# See Modules/_sqlite/connection.c.
|
||||
level = ''
|
||||
# 'isolation_level' is a misleading API.
|
||||
# SQLite always runs at the SERIALIZABLE isolation level.
|
||||
with self.wrap_database_errors:
|
||||
self.connection.isolation_level = level
|
||||
|
||||
def disable_constraint_checking(self):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute('PRAGMA foreign_keys = OFF')
|
||||
# Foreign key constraints cannot be turned off while in a multi-
|
||||
# statement transaction. Fetch the current state of the pragma
|
||||
# to determine if constraints are effectively disabled.
|
||||
enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0]
|
||||
return not bool(enabled)
|
||||
|
||||
def enable_constraint_checking(self):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute('PRAGMA foreign_keys = ON')
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check each table name in `table_names` for rows with invalid foreign
|
||||
key references. This method is intended to be used in conjunction with
|
||||
`disable_constraint_checking()` and `enable_constraint_checking()`, to
|
||||
determine if rows with invalid references were entered while constraint
|
||||
checks were off.
|
||||
"""
|
||||
if self.features.supports_pragma_foreign_key_check:
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
|
||||
else:
|
||||
violations = chain.from_iterable(
|
||||
cursor.execute(
|
||||
'PRAGMA foreign_key_check(%s)'
|
||||
% self.ops.quote_name(table_name)
|
||||
).fetchall()
|
||||
for table_name in table_names
|
||||
)
|
||||
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
|
||||
for table_name, rowid, referenced_table_name, foreign_key_index in violations:
|
||||
foreign_key = cursor.execute(
|
||||
'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name)
|
||||
).fetchall()[foreign_key_index]
|
||||
column_name, referenced_column_name = foreign_key[3:5]
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
|
||||
primary_key_value, bad_value = cursor.execute(
|
||||
'SELECT %s, %s FROM %s WHERE rowid = %%s' % (
|
||||
self.ops.quote_name(primary_key_column_name),
|
||||
self.ops.quote_name(column_name),
|
||||
self.ops.quote_name(table_name),
|
||||
),
|
||||
(rowid,),
|
||||
).fetchone()
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an "
|
||||
"invalid foreign key: %s.%s contains a value '%s' that "
|
||||
"does not have a corresponding value in %s.%s." % (
|
||||
table_name, primary_key_value, table_name, column_name,
|
||||
bad_value, referenced_table_name, referenced_column_name
|
||||
)
|
||||
)
|
||||
else:
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
table_names = self.introspection.table_names(cursor)
|
||||
for table_name in table_names:
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
|
||||
if not primary_key_column_name:
|
||||
continue
|
||||
key_columns = self.introspection.get_key_columns(cursor, table_name)
|
||||
for column_name, referenced_table_name, referenced_column_name in key_columns:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
|
||||
LEFT JOIN `%s` as REFERRED
|
||||
ON (REFERRING.`%s` = REFERRED.`%s`)
|
||||
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
|
||||
"""
|
||||
% (
|
||||
primary_key_column_name, column_name, table_name,
|
||||
referenced_table_name, column_name, referenced_column_name,
|
||||
column_name, referenced_column_name,
|
||||
)
|
||||
)
|
||||
for bad_row in cursor.fetchall():
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an "
|
||||
"invalid foreign key: %s.%s contains a value '%s' that "
|
||||
"does not have a corresponding value in %s.%s." % (
|
||||
table_name, bad_row[0], table_name, column_name,
|
||||
bad_row[1], referenced_table_name, referenced_column_name,
|
||||
)
|
||||
)
|
||||
|
||||
def is_usable(self):
|
||||
return True
|
||||
|
||||
def _start_transaction_under_autocommit(self):
|
||||
"""
|
||||
Start a transaction explicitly in autocommit mode.
|
||||
|
||||
Staying in autocommit mode works around a bug of sqlite3 that breaks
|
||||
savepoints when autocommit is disabled.
|
||||
"""
|
||||
self.cursor().execute("BEGIN")
|
||||
|
||||
def is_in_memory_db(self):
|
||||
return self.creation.is_in_memory_db(self.settings_dict['NAME'])
|
||||
|
||||
|
||||
FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s')
|
||||
|
||||
|
||||
class SQLiteCursorWrapper(Database.Cursor):
    """
    Django uses "format" style placeholders, but pysqlite2 uses "qmark" style.
    This fixes it -- but note that if you want to use a literal "%s" in a query,
    you'll need to use "%%s".
    """
    def execute(self, query, params=None):
        if params is None:
            return Database.Cursor.execute(self, query)
        query = self.convert_query(query)
        return Database.Cursor.execute(self, query, params)

    def executemany(self, query, param_list):
        query = self.convert_query(query)
        return Database.Cursor.executemany(self, query, param_list)

    def convert_query(self, query):
        return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
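
# Illustrative sketch (not part of the original source): how convert_query()
# above rewrites Django's "format" placeholders into sqlite3's "qmark" style.
# The query text is hypothetical.
#
#   convert_query("SELECT * FROM item WHERE name = %s AND discount LIKE '10%%'")
#   -> "SELECT * FROM item WHERE name = ? AND discount LIKE '10%'"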
|
||||
|
||||
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = backend_utils.typecast_timestamp(dt)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if conn_tzname:
|
||||
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
|
||||
if tzname is not None and tzname != conn_tzname:
|
||||
tzname, sign, offset = backend_utils.split_tzname_delta(tzname)
|
||||
if offset:
|
||||
hours, minutes = offset.split(':')
|
||||
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
|
||||
dt += offset_delta if sign == '+' else -offset_delta
|
||||
dt = timezone.localtime(dt, timezone_constructor(tzname))
|
||||
return dt
|
||||
|
||||
|
||||
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
    dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt is None:
        return None
    if lookup_type == 'year':
        return "%i-01-01" % dt.year
    elif lookup_type == 'quarter':
        month_in_quarter = dt.month - (dt.month - 1) % 3
        return '%i-%02i-01' % (dt.year, month_in_quarter)
    elif lookup_type == 'month':
        return "%i-%02i-01" % (dt.year, dt.month)
    elif lookup_type == 'week':
        dt = dt - datetime.timedelta(days=dt.weekday())
        return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
    elif lookup_type == 'day':
        return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
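
# Worked example (added for illustration, not part of the original source):
# for the 'quarter' branch above, a date of 2021-08-17 gives
#     month_in_quarter = 8 - (8 - 1) % 3 = 8 - 1 = 7
# so the value is truncated to '2021-07-01', the first day of the third quarter.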
|
||||
|
||||
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
if dt is None:
|
||||
return None
|
||||
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt_parsed is None:
|
||||
try:
|
||||
dt = backend_utils.typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
else:
|
||||
dt = dt_parsed
|
||||
if lookup_type == 'hour':
|
||||
return "%02i:00:00" % dt.hour
|
||||
elif lookup_type == 'minute':
|
||||
return "%02i:%02i:00" % (dt.hour, dt.minute)
|
||||
elif lookup_type == 'second':
|
||||
return "%02i:%02i:%02i" % (dt.hour, dt.minute, dt.second)
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.date().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.time().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == 'week_day':
|
||||
return (dt.isoweekday() % 7) + 1
|
||||
elif lookup_type == 'iso_week_day':
|
||||
return dt.isoweekday()
|
||||
elif lookup_type == 'week':
|
||||
return dt.isocalendar()[1]
|
||||
elif lookup_type == 'quarter':
|
||||
return math.ceil(dt.month / 3)
|
||||
elif lookup_type == 'iso_year':
|
||||
return dt.isocalendar()[0]
|
||||
else:
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == 'year':
|
||||
return "%i-01-01 00:00:00" % dt.year
|
||||
elif lookup_type == 'quarter':
|
||||
month_in_quarter = dt.month - (dt.month - 1) % 3
|
||||
return '%i-%02i-01 00:00:00' % (dt.year, month_in_quarter)
|
||||
elif lookup_type == 'month':
|
||||
return "%i-%02i-01 00:00:00" % (dt.year, dt.month)
|
||||
elif lookup_type == 'week':
|
||||
dt = dt - datetime.timedelta(days=dt.weekday())
|
||||
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
|
||||
elif lookup_type == 'day':
|
||||
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
|
||||
elif lookup_type == 'hour':
|
||||
return "%i-%02i-%02i %02i:00:00" % (dt.year, dt.month, dt.day, dt.hour)
|
||||
elif lookup_type == 'minute':
|
||||
return "%i-%02i-%02i %02i:%02i:00" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)
|
||||
elif lookup_type == 'second':
|
||||
return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
|
||||
|
||||
|
||||
def _sqlite_time_extract(lookup_type, dt):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = backend_utils.typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_prepare_dtdelta_param(conn, param):
|
||||
if conn in ['+', '-']:
|
||||
if isinstance(param, int):
|
||||
return datetime.timedelta(0, 0, param)
|
||||
else:
|
||||
return backend_utils.typecast_timestamp(param)
|
||||
return param
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_format_dtdelta(conn, lhs, rhs):
|
||||
"""
|
||||
LHS and RHS can be either:
|
||||
- An integer number of microseconds
|
||||
- A string representing a datetime
|
||||
- A scalar value, e.g. float
|
||||
"""
|
||||
conn = conn.strip()
|
||||
try:
|
||||
real_lhs = _sqlite_prepare_dtdelta_param(conn, lhs)
|
||||
real_rhs = _sqlite_prepare_dtdelta_param(conn, rhs)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if conn == '+':
|
||||
# typecast_timestamp returns a date or a datetime without timezone.
|
||||
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
|
||||
out = str(real_lhs + real_rhs)
|
||||
elif conn == '-':
|
||||
out = str(real_lhs - real_rhs)
|
||||
elif conn == '*':
|
||||
out = real_lhs * real_rhs
|
||||
else:
|
||||
out = real_lhs / real_rhs
|
||||
return out
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_time_diff(lhs, rhs):
|
||||
left = backend_utils.typecast_time(lhs)
|
||||
right = backend_utils.typecast_time(rhs)
|
||||
return (
|
||||
(left.hour * 60 * 60 * 1000000) +
|
||||
(left.minute * 60 * 1000000) +
|
||||
(left.second * 1000000) +
|
||||
(left.microsecond) -
|
||||
(right.hour * 60 * 60 * 1000000) -
|
||||
(right.minute * 60 * 1000000) -
|
||||
(right.second * 1000000) -
|
||||
(right.microsecond)
|
||||
)
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_timestamp_diff(lhs, rhs):
|
||||
left = backend_utils.typecast_timestamp(lhs)
|
||||
right = backend_utils.typecast_timestamp(rhs)
|
||||
return duration_microseconds(left - right)
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_regexp(re_pattern, re_string):
|
||||
return bool(re.search(re_pattern, str(re_string)))
|
||||
|
||||
|
||||
@none_guard
def _sqlite_lpad(text, length, fill_text):
    if len(text) >= length:
        return text[:length]
    return (fill_text * length)[:length - len(text)] + text


@none_guard
def _sqlite_rpad(text, length, fill_text):
    return (text + fill_text * length)[:length]
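
# Worked examples (added for illustration, not part of the original source):
#     _sqlite_lpad('abc', 6, '*')    -> '***abc'   (pad on the left up to 6 chars)
#     _sqlite_lpad('abcdef', 4, '*') -> 'abcd'     (already long enough: truncate)
#     _sqlite_rpad('abc', 6, '*')    -> 'abc***'   (pad on the right up to 6 chars)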
10
venv/Lib/site-packages/django/db/backends/sqlite3/client.py
Normal file
10
venv/Lib/site-packages/django/db/backends/sqlite3/client.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
executable_name = 'sqlite3'
|
||||
|
||||
@classmethod
|
||||
def settings_to_cmd_args_env(cls, settings_dict, parameters):
|
||||
args = [cls.executable_name, settings_dict['NAME'], *parameters]
|
||||
return args, None
|
||||
103
venv/Lib/site-packages/django/db/backends/sqlite3/creation.py
Normal file
103
venv/Lib/site-packages/django/db/backends/sqlite3/creation.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
|
||||
    @staticmethod
    def is_in_memory_db(database_name):
        return not isinstance(database_name, Path) and (
            database_name == ':memory:' or 'mode=memory' in database_name
        )
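
    # Illustrative examples (added, not part of the original source):
    #     is_in_memory_db(':memory:')                          -> True
    #     is_in_memory_db('file:memorydb_default?mode=memory') -> True
    #     is_in_memory_db('db.sqlite3')                        -> False
    #     is_in_memory_db(Path('db.sqlite3'))                  -> False  (Path objects always name on-disk files)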
|
||||
def _get_test_db_name(self):
|
||||
test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:'
|
||||
if test_database_name == ':memory:':
|
||||
return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias
|
||||
return test_database_name
|
||||
|
||||
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
|
||||
test_database_name = self._get_test_db_name()
|
||||
|
||||
if keepdb:
|
||||
return test_database_name
|
||||
if not self.is_in_memory_db(test_database_name):
|
||||
# Erase the old test database
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying old test database for alias %s...' % (
|
||||
self._get_database_display_str(verbosity, test_database_name),
|
||||
))
|
||||
if os.access(test_database_name, os.F_OK):
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"Type 'yes' if you would like to try deleting the test "
|
||||
"database '%s', or 'no' to cancel: " % test_database_name
|
||||
)
|
||||
if autoclobber or confirm == 'yes':
|
||||
try:
|
||||
os.remove(test_database_name)
|
||||
except Exception as e:
|
||||
self.log('Got an error deleting the old test database: %s' % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log('Tests cancelled.')
|
||||
sys.exit(1)
|
||||
return test_database_name
|
||||
|
||||
    def get_test_db_clone_settings(self, suffix):
        orig_settings_dict = self.connection.settings_dict
        source_database_name = orig_settings_dict['NAME']
        if self.is_in_memory_db(source_database_name):
            return orig_settings_dict
        else:
            root, ext = os.path.splitext(orig_settings_dict['NAME'])
            return {**orig_settings_dict, 'NAME': '{}_{}{}'.format(root, suffix, ext)}
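
    # Illustrative example (added, not part of the original source): cloning a
    # file-based test database named 'test_db.sqlite3' with suffix 2 yields a
    # settings dict whose NAME is 'test_db_2.sqlite3'; in-memory databases are
    # returned unchanged because, as noted below, forking already copies them.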
|
||||
def _clone_test_db(self, suffix, verbosity, keepdb=False):
|
||||
source_database_name = self.connection.settings_dict['NAME']
|
||||
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
|
||||
# Forking automatically makes a copy of an in-memory database.
|
||||
if not self.is_in_memory_db(source_database_name):
|
||||
# Erase the old test database
|
||||
if os.access(target_database_name, os.F_OK):
|
||||
if keepdb:
|
||||
return
|
||||
if verbosity >= 1:
|
||||
self.log('Destroying old test database for alias %s...' % (
|
||||
self._get_database_display_str(verbosity, target_database_name),
|
||||
))
|
||||
try:
|
||||
os.remove(target_database_name)
|
||||
except Exception as e:
|
||||
self.log('Got an error deleting the old test database: %s' % e)
|
||||
sys.exit(2)
|
||||
try:
|
||||
shutil.copy(source_database_name, target_database_name)
|
||||
except Exception as e:
|
||||
self.log('Got an error cloning the test database: %s' % e)
|
||||
sys.exit(2)
|
||||
|
||||
def _destroy_test_db(self, test_database_name, verbosity):
|
||||
if test_database_name and not self.is_in_memory_db(test_database_name):
|
||||
# Remove the SQLite database file
|
||||
os.remove(test_database_name)
|
||||
|
||||
def test_db_signature(self):
|
||||
"""
|
||||
Return a tuple that uniquely identifies a test database.
|
||||
|
||||
This takes into account the special cases of ":memory:" and "" for
|
||||
SQLite since the databases will be distinct despite having the same
|
||||
TEST NAME. See https://www.sqlite.org/inmemorydb.html
|
||||
"""
|
||||
test_database_name = self._get_test_db_name()
|
||||
sig = [self.connection.settings_dict['NAME']]
|
||||
if self.is_in_memory_db(test_database_name):
|
||||
sig.append(self.connection.alias)
|
||||
else:
|
||||
sig.append(test_database_name)
|
||||
return tuple(sig)
|
||||
126
venv/Lib/site-packages/django/db/backends/sqlite3/features.py
Normal file
126
venv/Lib/site-packages/django/db/backends/sqlite3/features.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import operator
|
||||
import platform
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.db.utils import OperationalError
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Database
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
# SQLite can read from a cursor since SQLite 3.6.5, subject to the caveat
|
||||
# that statements within a connection aren't isolated from each other. See
|
||||
# https://sqlite.org/isolation.html.
|
||||
can_use_chunked_reads = True
|
||||
test_db_allows_multiple_connections = False
|
||||
supports_unspecified_pk = True
|
||||
supports_timezones = False
|
||||
max_query_params = 999
|
||||
supports_mixed_date_datetime_comparisons = False
|
||||
supports_transactions = True
|
||||
atomic_transactions = False
|
||||
can_rollback_ddl = True
|
||||
can_create_inline_fk = False
|
||||
supports_paramstyle_pyformat = False
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
ignores_table_name_case = True
|
||||
supports_cast_with_precision = False
|
||||
time_cast_precision = 3
|
||||
can_release_savepoints = True
|
||||
# Is "ALTER TABLE ... RENAME COLUMN" supported?
|
||||
can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0)
|
||||
supports_parentheses_in_compound = False
|
||||
# Deferred constraint checks can be emulated on SQLite < 3.20 but not in a
|
||||
# reasonably performant way.
|
||||
supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0)
|
||||
can_defer_constraint_checks = supports_pragma_foreign_key_check
|
||||
supports_functions_in_partial_indexes = Database.sqlite_version_info >= (3, 15, 0)
|
||||
supports_over_clause = Database.sqlite_version_info >= (3, 25, 0)
|
||||
supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
|
||||
supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
|
||||
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
|
||||
order_by_nulls_first = True
|
||||
supports_json_field_contains = False
|
||||
test_collations = {
|
||||
'ci': 'nocase',
|
||||
'cs': 'binary',
|
||||
'non_default': 'nocase',
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def django_test_skips(self):
|
||||
skips = {
|
||||
'SQLite stores values rounded to 15 significant digits.': {
|
||||
'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding',
|
||||
},
|
||||
'SQLite naively remakes the table on field alteration.': {
|
||||
'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops',
|
||||
'schema.tests.SchemaTests.test_unique_and_reverse_m2m',
|
||||
'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries',
|
||||
'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references',
|
||||
},
|
||||
"SQLite doesn't have a constraint.": {
|
||||
'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
|
||||
},
|
||||
"SQLite doesn't support negative precision for ROUND().": {
|
||||
'db_functions.math.test_round.RoundTests.test_null_with_negative_precision',
|
||||
'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision',
|
||||
'db_functions.math.test_round.RoundTests.test_float_with_negative_precision',
|
||||
'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision',
|
||||
},
|
||||
}
|
||||
if Database.sqlite_version_info < (3, 27):
|
||||
skips.update({
|
||||
'Nondeterministic failure on SQLite < 3.27.': {
|
||||
'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank',
|
||||
},
|
||||
})
|
||||
if self.connection.is_in_memory_db():
|
||||
skips.update({
|
||||
"the sqlite backend's close() method is a no-op when using an "
|
||||
"in-memory database": {
|
||||
'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections',
|
||||
'servers.tests.LiveServerTestCloseConnectionTest.test_closes_connections',
|
||||
},
|
||||
})
|
||||
return skips
|
||||
|
||||
@cached_property
|
||||
def supports_atomic_references_rename(self):
|
||||
# SQLite 3.28.0 bundled with MacOS 10.15 does not support renaming
|
||||
# references atomically.
|
||||
if platform.mac_ver()[0].startswith('10.15.') and Database.sqlite_version_info == (3, 28, 0):
|
||||
return False
|
||||
return Database.sqlite_version_info >= (3, 26, 0)
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
'BigAutoField': 'AutoField',
|
||||
'DurationField': 'BigIntegerField',
|
||||
'GenericIPAddressField': 'CharField',
|
||||
'SmallAutoField': 'AutoField',
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def supports_json_field(self):
|
||||
with self.connection.cursor() as cursor:
|
||||
try:
|
||||
with transaction.atomic(self.connection.alias):
|
||||
cursor.execute('SELECT JSON(\'{"a": "b"}\')')
|
||||
except OperationalError:
|
||||
return False
|
||||
return True
|
||||
|
||||
can_introspect_json_field = property(operator.attrgetter('supports_json_field'))
|
||||
has_json_object_function = property(operator.attrgetter('supports_json_field'))
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_insert(self):
|
||||
return Database.sqlite_version_info >= (3, 35)
|
||||
|
||||
can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))
|
||||
@@ -0,0 +1,470 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlparse
|
||||
|
||||
from django.db.backends.base.introspection import (
|
||||
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
|
||||
)
|
||||
from django.db.models import Index
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint'))
|
||||
|
||||
field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')


def get_field_size(name):
    """ Extract the size number from a "varchar(11)" type name """
    m = field_size_re.search(name)
    return int(m[1]) if m else None
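
# Illustrative examples (added, not part of the original source):
#     get_field_size('varchar(30)') -> 30
#     get_field_size('char (11)')   -> 11
#     get_field_size('text')        -> None  (no explicit size in the type name)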
|
||||
|
||||
# This light wrapper "fakes" a dictionary interface, because some SQLite data
|
||||
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
|
||||
# as a simple dictionary lookup.
|
||||
class FlexibleFieldLookupDict:
|
||||
# Maps SQL types to Django Field types. Some of the SQL types have multiple
|
||||
# entries here because SQLite allows for anything and doesn't normalize the
|
||||
# field type; it uses whatever was given.
|
||||
base_data_types_reverse = {
|
||||
'bool': 'BooleanField',
|
||||
'boolean': 'BooleanField',
|
||||
'smallint': 'SmallIntegerField',
|
||||
'smallint unsigned': 'PositiveSmallIntegerField',
|
||||
'smallinteger': 'SmallIntegerField',
|
||||
'int': 'IntegerField',
|
||||
'integer': 'IntegerField',
|
||||
'bigint': 'BigIntegerField',
|
||||
'integer unsigned': 'PositiveIntegerField',
|
||||
'bigint unsigned': 'PositiveBigIntegerField',
|
||||
'decimal': 'DecimalField',
|
||||
'real': 'FloatField',
|
||||
'text': 'TextField',
|
||||
'char': 'CharField',
|
||||
'varchar': 'CharField',
|
||||
'blob': 'BinaryField',
|
||||
'date': 'DateField',
|
||||
'datetime': 'DateTimeField',
|
||||
'time': 'TimeField',
|
||||
}
|
||||
|
||||
    def __getitem__(self, key):
        key = key.lower().split('(', 1)[0].strip()
        return self.base_data_types_reverse[key]
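
    # Illustrative examples (added, not part of the original source):
    #     self['VARCHAR(30)']      -> 'CharField'            (size stripped, lowercased)
    #     self['integer unsigned'] -> 'PositiveIntegerField'
    #     self['geometry']         -> raises KeyError        (type unknown to this mapping)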
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
data_types_reverse = FlexibleFieldLookupDict()
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
field_type = super().get_field_type(data_type, description)
|
||||
if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}:
|
||||
# No support for BigAutoField or SmallAutoField as SQLite treats
|
||||
# all integer primary keys as signed 64-bit integers.
|
||||
return 'AutoField'
|
||||
if description.has_json_constraint:
|
||||
return 'JSONField'
|
||||
return field_type
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
# Skip the sqlite_sequence system table used for autoincrement key
|
||||
# generation.
|
||||
cursor.execute("""
|
||||
SELECT name, type FROM sqlite_master
|
||||
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
|
||||
ORDER BY name""")
|
||||
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name))
|
||||
table_info = cursor.fetchall()
|
||||
collations = self._get_column_collations(cursor, table_name)
|
||||
json_columns = set()
|
||||
if self.connection.features.can_introspect_json_field:
|
||||
for line in table_info:
|
||||
column = line[1]
|
||||
json_constraint_sql = '%%json_valid("%s")%%' % column
|
||||
has_json_constraint = cursor.execute("""
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE
|
||||
type = 'table' AND
|
||||
name = %s AND
|
||||
sql LIKE %s
|
||||
""", [table_name, json_constraint_sql]).fetchone()
|
||||
if has_json_constraint:
|
||||
json_columns.add(column)
|
||||
return [
|
||||
FieldInfo(
|
||||
name, data_type, None, get_field_size(data_type), None, None,
|
||||
not notnull, default, collations.get(name), pk == 1, name in json_columns
|
||||
)
|
||||
for cid, name, data_type, notnull, default, pk in table_info
|
||||
]
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
pk_col = self.get_primary_key_column(cursor, table_name)
|
||||
return [{'table': table_name, 'column': pk_col}]
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all relationships to the given table.
|
||||
"""
|
||||
# Dictionary of relations to return
|
||||
relations = {}
|
||||
|
||||
# Schema for this table
|
||||
cursor.execute(
|
||||
"SELECT sql, type FROM sqlite_master "
|
||||
"WHERE tbl_name = %s AND type IN ('table', 'view')",
|
||||
[table_name]
|
||||
)
|
||||
create_sql, table_type = cursor.fetchone()
|
||||
if table_type == 'view':
|
||||
# It might be a view, then no results will be returned
|
||||
return relations
|
||||
results = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
|
||||
|
||||
# Walk through and look for references to other tables. SQLite doesn't
|
||||
# really have enforced references, but since it echoes out the SQL used
|
||||
# to create the table we can look for REFERENCES statements used there.
|
||||
for field_desc in results.split(','):
|
||||
field_desc = field_desc.strip()
|
||||
if field_desc.startswith("UNIQUE"):
|
||||
continue
|
||||
|
||||
m = re.search(r'references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
|
||||
if not m:
|
||||
continue
|
||||
table, column = [s.strip('"') for s in m.groups()]
|
||||
|
||||
if field_desc.startswith("FOREIGN KEY"):
|
||||
# Find name of the target FK field
|
||||
m = re.match(r'FOREIGN KEY\s*\(([^\)]*)\).*', field_desc, re.I)
|
||||
field_name = m[1].strip('"')
|
||||
else:
|
||||
field_name = field_desc.split()[0].strip('"')
|
||||
|
||||
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
|
||||
result = cursor.fetchall()[0]
|
||||
other_table_results = result[0].strip()
|
||||
li, ri = other_table_results.index('('), other_table_results.rindex(')')
|
||||
other_table_results = other_table_results[li + 1:ri]
|
||||
|
||||
for other_desc in other_table_results.split(','):
|
||||
other_desc = other_desc.strip()
|
||||
if other_desc.startswith('UNIQUE'):
|
||||
continue
|
||||
|
||||
other_name = other_desc.split(' ', 1)[0].strip('"')
|
||||
if other_name == column:
|
||||
relations[field_name] = (other_name, table)
|
||||
break
|
||||
|
||||
return relations
|
||||
|
||||
def get_key_columns(self, cursor, table_name):
|
||||
"""
|
||||
Return a list of (column_name, referenced_table_name, referenced_column_name)
|
||||
for all key columns in given table.
|
||||
"""
|
||||
key_columns = []
|
||||
|
||||
# Schema for this table
|
||||
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
|
||||
results = cursor.fetchone()[0].strip()
|
||||
results = results[results.index('(') + 1:results.rindex(')')]
|
||||
|
||||
# Walk through and look for references to other tables. SQLite doesn't
|
||||
# really have enforced references, but since it echoes out the SQL used
|
||||
# to create the table we can look for REFERENCES statements used there.
|
||||
for field_index, field_desc in enumerate(results.split(',')):
|
||||
field_desc = field_desc.strip()
|
||||
if field_desc.startswith("UNIQUE"):
|
||||
continue
|
||||
|
||||
m = re.search(r'"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I)
|
||||
if not m:
|
||||
continue
|
||||
|
||||
# This will append (column_name, referenced_table_name, referenced_column_name) to key_columns
|
||||
key_columns.append(tuple(s.strip('"') for s in m.groups()))
|
||||
|
||||
return key_columns
|
||||
|
||||
def get_primary_key_column(self, cursor, table_name):
|
||||
"""Return the column name of the primary key for the given table."""
|
||||
# Don't use PRAGMA because that causes issues with some transactions
|
||||
cursor.execute(
|
||||
"SELECT sql, type FROM sqlite_master "
|
||||
"WHERE tbl_name = %s AND type IN ('table', 'view')",
|
||||
[table_name]
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise ValueError("Table %s does not exist" % table_name)
|
||||
create_sql, table_type = row
|
||||
if table_type == 'view':
|
||||
# Views don't have a primary key.
|
||||
return None
|
||||
fields_sql = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
|
||||
for field_desc in fields_sql.split(','):
|
||||
field_desc = field_desc.strip()
|
||||
m = re.match(r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc)
|
||||
if m:
|
||||
return m[1] if m[1] else m[2]
|
||||
return None
|
||||
|
||||
def _get_foreign_key_constraints(self, cursor, table_name):
|
||||
constraints = {}
|
||||
cursor.execute('PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name))
|
||||
for row in cursor.fetchall():
|
||||
# Remaining on_update/on_delete/match values are of no interest.
|
||||
id_, _, table, from_, to = row[:5]
|
||||
constraints['fk_%d' % id_] = {
|
||||
'columns': [from_],
|
||||
'primary_key': False,
|
||||
'unique': False,
|
||||
'foreign_key': (table, to),
|
||||
'check': False,
|
||||
'index': False,
|
||||
}
|
||||
return constraints
|
||||
|
||||
def _parse_column_or_constraint_definition(self, tokens, columns):
|
||||
token = None
|
||||
is_constraint_definition = None
|
||||
field_name = None
|
||||
constraint_name = None
|
||||
unique = False
|
||||
unique_columns = []
|
||||
check = False
|
||||
check_columns = []
|
||||
braces_deep = 0
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, '('):
|
||||
braces_deep += 1
|
||||
elif token.match(sqlparse.tokens.Punctuation, ')'):
|
||||
braces_deep -= 1
|
||||
if braces_deep < 0:
|
||||
# End of columns and constraints for table definition.
|
||||
break
|
||||
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
|
||||
# End of current column or constraint definition.
|
||||
break
|
||||
# Detect column or constraint definition by first token.
|
||||
if is_constraint_definition is None:
|
||||
is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
|
||||
if is_constraint_definition:
|
||||
continue
|
||||
if is_constraint_definition:
|
||||
# Detect constraint name by second token.
|
||||
if constraint_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
constraint_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
constraint_name = token.value[1:-1]
|
||||
# Start constraint columns parsing after UNIQUE keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
|
||||
unique = True
|
||||
unique_braces_deep = braces_deep
|
||||
elif unique:
|
||||
if unique_braces_deep == braces_deep:
|
||||
if unique_columns:
|
||||
# Stop constraint parsing.
|
||||
unique = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
unique_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
unique_columns.append(token.value[1:-1])
|
||||
else:
|
||||
# Detect field name by first token.
|
||||
if field_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
field_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
field_name = token.value[1:-1]
|
||||
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
|
||||
unique_columns = [field_name]
|
||||
# Start constraint columns parsing after CHECK keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
|
||||
check = True
|
||||
check_braces_deep = braces_deep
|
||||
elif check:
|
||||
if check_braces_deep == braces_deep:
|
||||
if check_columns:
|
||||
# Stop constraint parsing.
|
||||
check = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
if token.value in columns:
|
||||
check_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
if token.value[1:-1] in columns:
|
||||
check_columns.append(token.value[1:-1])
|
||||
unique_constraint = {
|
||||
'unique': True,
|
||||
'columns': unique_columns,
|
||||
'primary_key': False,
|
||||
'foreign_key': None,
|
||||
'check': False,
|
||||
'index': False,
|
||||
} if unique_columns else None
|
||||
check_constraint = {
|
||||
'check': True,
|
||||
'columns': check_columns,
|
||||
'primary_key': False,
|
||||
'unique': False,
|
||||
'foreign_key': None,
|
||||
'index': False,
|
||||
} if check_columns else None
|
||||
return constraint_name, unique_constraint, check_constraint, token
|
||||
|
||||
def _parse_table_constraints(self, sql, columns):
|
||||
# Check constraint parsing is based of SQLite syntax diagram.
|
||||
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
constraints = {}
|
||||
unnamed_constrains_index = 0
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
# Go to columns and constraint definition
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, '('):
|
||||
break
|
||||
# Parse columns and constraint definition
|
||||
while True:
|
||||
constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
|
||||
if unique:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = unique
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
|
||||
if check:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = check
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
|
||||
if end_token.match(sqlparse.tokens.Punctuation, ')'):
|
||||
break
|
||||
return constraints
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Find inline check constraints.
|
||||
try:
|
||||
table_schema = cursor.execute(
|
||||
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % (
|
||||
self.connection.ops.quote_name(table_name),
|
||||
)
|
||||
).fetchone()[0]
|
||||
except TypeError:
|
||||
# table_name is a view.
|
||||
pass
|
||||
else:
|
||||
columns = {info.name for info in self.get_table_description(cursor, table_name)}
|
||||
constraints.update(self._parse_table_constraints(table_schema, columns))
|
||||
|
||||
# Get the index info
|
||||
cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
|
||||
for row in cursor.fetchall():
|
||||
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
|
||||
# columns. Discard last 2 columns if there.
|
||||
number, index, unique = row[:3]
|
||||
cursor.execute(
|
||||
"SELECT sql FROM sqlite_master "
|
||||
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
# There's at most one row.
|
||||
sql, = cursor.fetchone() or (None,)
|
||||
# Inline constraints are already detected in
|
||||
# _parse_table_constraints(). The reasons to avoid fetching inline
|
||||
# constraints from `PRAGMA index_list` are:
|
||||
# - Inline constraints can have a different name and information
|
||||
# than what `PRAGMA index_list` gives.
|
||||
# - Not all inline constraints may appear in `PRAGMA index_list`.
|
||||
if not sql:
|
||||
# An inline constraint
|
||||
continue
|
||||
# Get the index info for that index
|
||||
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
|
||||
for index_rank, column_rank, column in cursor.fetchall():
|
||||
if index not in constraints:
|
||||
constraints[index] = {
|
||||
"columns": [],
|
||||
"primary_key": False,
|
||||
"unique": bool(unique),
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": True,
|
||||
}
|
||||
constraints[index]['columns'].append(column)
|
||||
# Add type and column orders for indexes
|
||||
if constraints[index]['index']:
|
||||
# SQLite doesn't support any index type other than b-tree
|
||||
constraints[index]['type'] = Index.suffix
|
||||
orders = self._get_index_columns_orders(sql)
|
||||
if orders is not None:
|
||||
constraints[index]['orders'] = orders
|
||||
# Get the PK
|
||||
pk_column = self.get_primary_key_column(cursor, table_name)
|
||||
if pk_column:
|
||||
# SQLite doesn't actually give a name to the PK constraint,
|
||||
# so we invent one. This is fine, as the SQLite backend never
|
||||
# deletes PK constraints by name, as you can't delete constraints
|
||||
# in SQLite; we remake the table with a new PK instead.
|
||||
constraints["__primary__"] = {
|
||||
"columns": [pk_column],
|
||||
"primary_key": True,
|
||||
"unique": False, # It's not actually a unique constraint.
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
constraints.update(self._get_foreign_key_constraints(cursor, table_name))
|
||||
return constraints
|
||||
|
||||
def _get_index_columns_orders(self, sql):
|
||||
tokens = sqlparse.parse(sql)[0]
|
||||
for token in tokens:
|
||||
if isinstance(token, sqlparse.sql.Parenthesis):
|
||||
columns = str(token).strip('()').split(', ')
|
||||
return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns]
|
||||
return None
|
||||
|
||||
def _get_column_collations(self, cursor, table_name):
|
||||
row = cursor.execute("""
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table' AND name = %s
|
||||
""", [table_name]).fetchone()
|
||||
if not row:
|
||||
return {}
|
||||
|
||||
sql = row[0]
|
||||
columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ')
|
||||
collations = {}
|
||||
for column in columns:
|
||||
tokens = column[1:].split()
|
||||
column_name = tokens[0].strip('"')
|
||||
for index, token in enumerate(tokens):
|
||||
if token == 'COLLATE':
|
||||
collation = tokens[index + 1]
|
||||
break
|
||||
else:
|
||||
collation = None
|
||||
collations[column_name] = collation
|
||||
return collations
|
||||
386
venv/Lib/site-packages/django/db/backends/sqlite3/operations.py
Normal file
386
venv/Lib/site-packages/django/db/backends/sqlite3/operations.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DatabaseError, NotSupportedError, models
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.models.expressions import Col
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
cast_char_field_without_max_length = 'text'
|
||||
cast_data_types = {
|
||||
'DateField': 'TEXT',
|
||||
'DateTimeField': 'TEXT',
|
||||
}
|
||||
explain_prefix = 'EXPLAIN QUERY PLAN'
|
||||
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
|
||||
# SQLite. Use JSON_TYPE() instead.
|
||||
jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
|
||||
|
||||
    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)
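
    # Worked example (added for illustration, not part of the original source):
    # with the default limit of 999 query parameters, inserting objects that each
    # provide 4 fields gives a batch size of 999 // 4 = 249 rows per INSERT,
    # while single-field inserts are capped at 500 rows by the compound SELECT limit.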
|
||||
def check_expression_support(self, expression):
|
||||
bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
|
||||
bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
|
||||
if isinstance(expression, bad_aggregates):
|
||||
for expr in expression.get_source_expressions():
|
||||
try:
|
||||
output_field = expr.output_field
|
||||
except (AttributeError, FieldError):
|
||||
# Not every subexpression has an output_field which is fine
|
||||
# to ignore.
|
||||
pass
|
||||
else:
|
||||
if isinstance(output_field, bad_fields):
|
||||
raise NotSupportedError(
|
||||
'You cannot use Sum, Avg, StdDev, and Variance '
|
||||
'aggregations on date/time fields in sqlite3 '
|
||||
'since date/time is saved as text.'
|
||||
)
|
||||
if (
|
||||
isinstance(expression, models.Aggregate) and
|
||||
expression.distinct and
|
||||
len(expression.source_expressions) > 1
|
||||
):
|
||||
raise NotSupportedError(
|
||||
"SQLite doesn't support DISTINCT on aggregate functions "
|
||||
"accepting multiple arguments."
|
||||
)
|
||||
|
||||
    def date_extract_sql(self, lookup_type, field_name):
        """
        Support EXTRACT with a user-defined function django_date_extract()
        that's registered in connect(). Use single quotes because this is a
        string and could otherwise cause a collision with a field name.
        """
        return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
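
    # Illustrative example (added, not part of the original source): extracting
    # the week day from a hypothetical "app_item"."created" column produces SQL such as
    #     django_date_extract('week_day', "app_item"."created")
    # which SQLite evaluates through the Python function registered under that
    # name in get_new_connection().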
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the list of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
"""Do nothing since formatting is handled in the custom function."""
|
||||
return sql
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_date_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_time_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def _convert_tznames_to_sql(self, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
|
||||
return 'NULL', 'NULL'
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
return 'django_datetime_cast_date(%s, %s, %s)' % (
|
||||
field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
return 'django_datetime_cast_time(%s, %s, %s)' % (
|
||||
field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_extract('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def _quote_params_for_last_executed_query(self, params):
|
||||
"""
|
||||
Only for last_executed_query! Don't use this to execute SQL queries!
|
||||
"""
|
||||
# This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
|
||||
# number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
|
||||
# number of return values, default = 2000). Since Python's sqlite3
|
||||
# module doesn't expose the get_limit() C API, assume the default
|
||||
# limits are in effect and split the work in batches if needed.
|
||||
BATCH_SIZE = 999
|
||||
if len(params) > BATCH_SIZE:
|
||||
results = ()
|
||||
for index in range(0, len(params), BATCH_SIZE):
|
||||
chunk = params[index:index + BATCH_SIZE]
|
||||
results += self._quote_params_for_last_executed_query(chunk)
|
||||
return results
|
||||
|
||||
sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
|
||||
# Bypass Django's wrappers and use the underlying sqlite3 connection
|
||||
# to avoid logging this query - it would trigger infinite recursion.
|
||||
cursor = self.connection.connection.cursor()
|
||||
# Native sqlite3 cursors cannot be used as context managers.
|
||||
try:
|
||||
return cursor.execute(sql, params).fetchone()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
|
||||
# pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
|
||||
# Unfortunately there is no way to reach self->statement from Python,
|
||||
# so we quote and substitute parameters manually.
|
||||
if params:
|
||||
if isinstance(params, (list, tuple)):
|
||||
params = self._quote_params_for_last_executed_query(params)
|
||||
else:
|
||||
values = tuple(params.values())
|
||||
values = self._quote_params_for_last_executed_query(values)
|
||||
params = dict(zip(params, values))
|
||||
return sql % params
|
||||
# For consistency with SQLiteCursorWrapper.execute(), just return sql
|
||||
# when there are no parameters. See #13648 and #17158.
|
||||
else:
|
||||
return sql
|
||||
|
||||
def quote_name(self, name):
|
||||
if name.startswith('"') and name.endswith('"'):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def no_limit_value(self):
|
||||
return -1
|
||||
|
||||
def __references_graph(self, table_name):
|
||||
query = """
|
||||
WITH tables AS (
|
||||
SELECT %s name
|
||||
UNION
|
||||
SELECT sqlite_master.name
|
||||
FROM sqlite_master
|
||||
JOIN tables ON (sql REGEXP %s || tables.name || %s)
|
||||
) SELECT name FROM tables;
|
||||
"""
|
||||
params = (
|
||||
table_name,
|
||||
r'(?i)\s+references\s+("|\')?',
|
||||
r'("|\')?\s*\(',
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
results = cursor.execute(query, params)
|
||||
return [row[0] for row in results.fetchall()]
|
||||
|
||||
@cached_property
|
||||
def _references_graph(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__references_graph)
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if tables and allow_cascade:
|
||||
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
||||
# referencing the tables to be flushed.
|
||||
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
|
||||
sql = ['%s %s %s;' % (
|
||||
style.SQL_KEYWORD('DELETE'),
|
||||
style.SQL_KEYWORD('FROM'),
|
||||
style.SQL_FIELD(self.quote_name(table))
|
||||
) for table in tables]
|
||||
if reset_sequences:
|
||||
sequences = [{'table': table} for table in tables]
|
||||
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
|
||||
return sql
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
if not sequences:
|
||||
return []
|
||||
return [
|
||||
'%s %s %s %s = 0 %s %s %s (%s);' % (
|
||||
style.SQL_KEYWORD('UPDATE'),
|
||||
style.SQL_TABLE(self.quote_name('sqlite_sequence')),
|
||||
style.SQL_KEYWORD('SET'),
|
||||
style.SQL_FIELD(self.quote_name('seq')),
|
||||
style.SQL_KEYWORD('WHERE'),
|
||||
style.SQL_FIELD(self.quote_name('name')),
|
||||
style.SQL_KEYWORD('IN'),
|
||||
', '.join([
|
||||
"'%s'" % sequence_info['table'] for sequence_info in sequences
|
||||
]),
|
||||
),
|
||||
]
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")
|
||||
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("SQLite backend does not support timezone-aware times.")
|
||||
|
||||
return str(value)
|
||||
|
||||
    def get_db_converters(self, expression):
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == 'DateTimeField':
            converters.append(self.convert_datetimefield_value)
        elif internal_type == 'DateField':
            converters.append(self.convert_datefield_value)
        elif internal_type == 'TimeField':
            converters.append(self.convert_timefield_value)
        elif internal_type == 'DecimalField':
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == 'UUIDField':
            converters.append(self.convert_uuidfield_value)
        elif internal_type == 'BooleanField':
            converters.append(self.convert_booleanfield_value)
        return converters
    def convert_datetimefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value
    def get_decimalfield_converter(self, expression):
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
        else:
            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)
        return converter
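# Illustrative sketch (not part of the diff above): the converter built by
# get_decimalfield_converter trims float noise to 15 significant digits and,
# for plain column references, quantizes to the field's decimal_places. A
# standalone equivalent with assumed values:
import decimal

create_decimal = decimal.Context(prec=15).create_decimal_from_float
quantize_value = decimal.Decimal(1).scaleb(-2)  # decimal_places=2 is assumed

raw = 1.1 + 2.2  # stored by SQLite as a float: 3.3000000000000003
print(create_decimal(raw))                           # 3.30000000000000
print(create_decimal(raw).quantize(quantize_value))  # 3.30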
    def convert_uuidfield_value(self, value, expression, connection):
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        return " UNION ALL ".join(
            "SELECT %s" % ", ".join(row)
            for row in placeholder_rows
        )
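# Illustrative sketch (not part of the diff above): bulk_insert_sql compounds
# one SELECT per row of placeholders with UNION ALL. The placeholder rows
# below are assumed inputs.
placeholder_rows = [['%s', '%s'], ['%s', '%s']]
sql = " UNION ALL ".join("SELECT %s" % ", ".join(row) for row in placeholder_rows)
print(sql)  # SELECT %s, %s UNION ALL SELECT %s, %s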
    def combine_expression(self, connector, sub_expressions):
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == '^':
            return 'POWER(%s)' % ','.join(sub_expressions)
        elif connector == '#':
            return 'BITXOR(%s)' % ','.join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        if connector not in ['+', '-', '*', '/']:
            raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
        fn_params = ["'%s'" % connector] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError('Too many params for timedelta operations.')
        return "django_format_dtdelta(%s)" % ', '.join(fn_params)
    def integer_field_range(self, internal_type):
        # SQLite doesn't enforce any integer constraints
        return (None, None)

    def subtract_temporals(self, internal_type, lhs, rhs):
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == 'TimeField':
            return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
        return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params

    def insert_statement(self, ignore_conflicts=False):
        return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)

    def return_insert_columns(self, fields):
        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
        if not fields:
            return '', ()
        columns = [
            '%s.%s' % (
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            ) for field in fields
        ]
        return 'RETURNING %s' % ', '.join(columns), ()
444
venv/Lib/site-packages/django/db/backends/sqlite3/schema.py
Normal file
@@ -0,0 +1,444 @@
|
||||
import copy
|
||||
from decimal import Decimal
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import NotSupportedError
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.backends.ddl_references import Statement
|
||||
from django.db.backends.utils import strip_quotes
|
||||
from django.db.models import UniqueConstraint
|
||||
from django.db.transaction import atomic
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
|
||||
sql_delete_table = "DROP TABLE %(table)s"
|
||||
sql_create_fk = None
|
||||
sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
|
||||
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
|
||||
sql_delete_unique = "DROP INDEX %(name)s"
|
||||
|
||||
def __enter__(self):
|
||||
# Some SQLite schema alterations need foreign key constraints to be
|
||||
# disabled. Enforce it here for the duration of the schema edition.
|
||||
if not self.connection.disable_constraint_checking():
|
||||
raise NotSupportedError(
|
||||
'SQLite schema editor cannot be used while foreign key '
|
||||
'constraint checks are enabled. Make sure to disable them '
|
||||
'before entering a transaction.atomic() context because '
|
||||
'SQLite does not support disabling them in the middle of '
|
||||
'a multi-statement transaction.'
|
||||
)
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.connection.check_constraints()
|
||||
super().__exit__(exc_type, exc_value, traceback)
|
||||
self.connection.enable_constraint_checking()
|
||||
|
||||
def quote_value(self, value):
|
||||
# The backend "mostly works" without this function and there are use
|
||||
# cases for compiling Python without the sqlite3 libraries (e.g.
|
||||
# security hardening).
|
||||
try:
|
||||
import sqlite3
|
||||
value = sqlite3.adapt(value)
|
||||
except ImportError:
|
||||
pass
|
||||
except sqlite3.ProgrammingError:
|
||||
pass
|
||||
# Manual emulation of SQLite parameter quoting
|
||||
if isinstance(value, bool):
|
||||
return str(int(value))
|
||||
elif isinstance(value, (Decimal, float, int)):
|
||||
return str(value)
|
||||
elif isinstance(value, str):
|
||||
return "'%s'" % value.replace("\'", "\'\'")
|
||||
elif value is None:
|
||||
return "NULL"
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# Bytes are only allowed for BLOB fields, encoded as string
|
||||
# literals containing hexadecimal data and preceded by a single "X"
|
||||
# character.
|
||||
return "X'%s'" % value.hex()
|
||||
else:
|
||||
raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))
|
||||
|
||||
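# Illustrative sketch (not part of the diff above): the manual fallback in
# quote_value renders Python values as SQLite literals. A few assumed inputs
# and the literals this logic produces:
from decimal import Decimal

def quote(value):
    if isinstance(value, bool):
        return str(int(value))
    elif isinstance(value, (Decimal, float, int)):
        return str(value)
    elif isinstance(value, str):
        return "'%s'" % value.replace("'", "''")
    elif value is None:
        return "NULL"
    elif isinstance(value, (bytes, bytearray, memoryview)):
        return "X'%s'" % value.hex()
    raise ValueError(value)

print(quote(True))        # 1
print(quote("O'Reilly"))  # 'O''Reilly'
print(quote(b'\x00\xff')) # X'00ff'
print(quote(None))        # NULL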
def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
|
||||
"""
|
||||
Return whether or not the provided table name is referenced by another
|
||||
one. If `column_name` is specified, only references pointing to that
|
||||
column are considered. If `ignore_self` is True, self-referential
|
||||
constraints are ignored.
|
||||
"""
|
||||
with self.connection.cursor() as cursor:
|
||||
for other_table in self.connection.introspection.get_table_list(cursor):
|
||||
if ignore_self and other_table.name == table_name:
|
||||
continue
|
||||
constraints = self.connection.introspection._get_foreign_key_constraints(cursor, other_table.name)
|
||||
for constraint in constraints.values():
|
||||
constraint_table, constraint_column = constraint['foreign_key']
|
||||
if (constraint_table == table_name and
|
||||
(column_name is None or constraint_column == column_name)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
|
||||
if (not self.connection.features.supports_atomic_references_rename and
|
||||
disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)):
|
||||
if self.connection.in_atomic_block:
|
||||
raise NotSupportedError((
|
||||
'Renaming the %r table while in a transaction is not '
|
||||
'supported on SQLite < 3.26 because it would break referential '
|
||||
'integrity. Try adding `atomic = False` to the Migration class.'
|
||||
) % old_db_table)
|
||||
self.connection.enable_constraint_checking()
|
||||
super().alter_db_table(model, old_db_table, new_db_table)
|
||||
self.connection.disable_constraint_checking()
|
||||
else:
|
||||
super().alter_db_table(model, old_db_table, new_db_table)
|
||||
|
||||
def alter_field(self, model, old_field, new_field, strict=False):
|
||||
if not self._field_should_be_altered(old_field, new_field):
|
||||
return
|
||||
old_field_name = old_field.name
|
||||
table_name = model._meta.db_table
|
||||
_, old_column_name = old_field.get_attname_column()
|
||||
if (new_field.name != old_field_name and
|
||||
not self.connection.features.supports_atomic_references_rename and
|
||||
self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
|
||||
if self.connection.in_atomic_block:
|
||||
raise NotSupportedError((
|
||||
'Renaming the %r.%r column while in a transaction is not '
|
||||
'supported on SQLite < 3.26 because it would break referential '
|
||||
'integrity. Try adding `atomic = False` to the Migration class.'
|
||||
) % (model._meta.db_table, old_field_name))
|
||||
with atomic(self.connection.alias):
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
# Follow SQLite's documented procedure for performing changes
|
||||
# that don't affect the on-disk content.
|
||||
# https://sqlite.org/lang_altertable.html#otheralter
|
||||
with self.connection.cursor() as cursor:
|
||||
schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
|
||||
cursor.execute('PRAGMA writable_schema = 1')
|
||||
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
|
||||
new_column_name = new_field.get_attname_column()[1]
|
||||
search = references_template % old_column_name
|
||||
replacement = references_template % new_column_name
|
||||
cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
|
||||
cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
|
||||
cursor.execute('PRAGMA writable_schema = 0')
|
||||
# The integrity check will raise an exception and rollback
|
||||
# the transaction if the sqlite_master updates corrupt the
|
||||
# database.
|
||||
cursor.execute('PRAGMA integrity_check')
|
||||
# Perform a VACUUM to refresh the database representation from
|
||||
# the sqlite_master table.
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute('VACUUM')
|
||||
else:
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
|
||||
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
|
||||
"""
|
||||
Shortcut to transform a model from old_model into new_model
|
||||
|
||||
This follows the correct procedure to perform non-rename or column
|
||||
addition operations based on SQLite's documentation
|
||||
|
||||
https://www.sqlite.org/lang_altertable.html#caution
|
||||
|
||||
The essential steps are:
|
||||
1. Create a table with the updated definition called "new__app_model"
|
||||
2. Copy the data from the existing "app_model" table to the new table
|
||||
3. Drop the "app_model" table
|
||||
4. Rename the "new__app_model" table to "app_model"
|
||||
5. Restore any index of the previous "app_model" table.
|
||||
"""
|
||||
# Self-referential fields must be recreated rather than copied from
|
||||
# the old model to ensure their remote_field.field_name doesn't refer
|
||||
# to an altered field.
|
||||
def is_self_referential(f):
|
||||
return f.is_relation and f.remote_field.model is model
|
||||
# Work out the new fields dict / mapping
|
||||
body = {
|
||||
f.name: f.clone() if is_self_referential(f) else f
|
||||
for f in model._meta.local_concrete_fields
|
||||
}
|
||||
# Since mapping might mix column names and default values,
|
||||
# its values must be already quoted.
|
||||
mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
|
||||
# This maps field names (not columns) for things like unique_together
|
||||
rename_mapping = {}
|
||||
# If any of the new or altered fields is introducing a new PK,
|
||||
# remove the old one
|
||||
restore_pk_field = None
|
||||
if getattr(create_field, 'primary_key', False) or (
|
||||
alter_field and getattr(alter_field[1], 'primary_key', False)):
|
||||
for name, field in list(body.items()):
|
||||
if field.primary_key:
|
||||
field.primary_key = False
|
||||
restore_pk_field = field
|
||||
if field.auto_created:
|
||||
del body[name]
|
||||
del mapping[field.column]
|
||||
# Add in any created fields
|
||||
if create_field:
|
||||
body[create_field.name] = create_field
|
||||
# Choose a default and insert it into the copy map
|
||||
if not create_field.many_to_many and create_field.concrete:
|
||||
mapping[create_field.column] = self.quote_value(
|
||||
self.effective_default(create_field)
|
||||
)
|
||||
# Add in any altered fields
|
||||
if alter_field:
|
||||
old_field, new_field = alter_field
|
||||
body.pop(old_field.name, None)
|
||||
mapping.pop(old_field.column, None)
|
||||
body[new_field.name] = new_field
|
||||
if old_field.null and not new_field.null:
|
||||
case_sql = "coalesce(%(col)s, %(default)s)" % {
|
||||
'col': self.quote_name(old_field.column),
|
||||
'default': self.quote_value(self.effective_default(new_field))
|
||||
}
|
||||
mapping[new_field.column] = case_sql
|
||||
else:
|
||||
mapping[new_field.column] = self.quote_name(old_field.column)
|
||||
rename_mapping[old_field.name] = new_field.name
|
||||
# Remove any deleted fields
|
||||
if delete_field:
|
||||
del body[delete_field.name]
|
||||
del mapping[delete_field.column]
|
||||
# Remove any implicit M2M tables
|
||||
if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
|
||||
return self.delete_model(delete_field.remote_field.through)
|
||||
# Work inside a new app registry
|
||||
apps = Apps()
|
||||
|
||||
# Work out the new value of unique_together, taking renames into
|
||||
# account
|
||||
unique_together = [
|
||||
[rename_mapping.get(n, n) for n in unique]
|
||||
for unique in model._meta.unique_together
|
||||
]
|
||||
|
||||
# Work out the new value for index_together, taking renames into
|
||||
# account
|
||||
index_together = [
|
||||
[rename_mapping.get(n, n) for n in index]
|
||||
for index in model._meta.index_together
|
||||
]
|
||||
|
||||
indexes = model._meta.indexes
|
||||
if delete_field:
|
||||
indexes = [
|
||||
index for index in indexes
|
||||
if delete_field.name not in index.fields
|
||||
]
|
||||
|
||||
constraints = list(model._meta.constraints)
|
||||
|
||||
# Provide isolated instances of the fields to the new model body so
|
||||
# that the existing model's internals aren't interfered with when
|
||||
# the dummy model is constructed.
|
||||
body_copy = copy.deepcopy(body)
|
||||
|
||||
# Construct a new model with the new fields to allow self referential
|
||||
# primary key to resolve to. This model won't ever be materialized as a
|
||||
# table and solely exists for foreign key reference resolution purposes.
|
||||
# This wouldn't be required if the schema editor was operating on model
|
||||
# states instead of rendered models.
|
||||
meta_contents = {
|
||||
'app_label': model._meta.app_label,
|
||||
'db_table': model._meta.db_table,
|
||||
'unique_together': unique_together,
|
||||
'index_together': index_together,
|
||||
'indexes': indexes,
|
||||
'constraints': constraints,
|
||||
'apps': apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
body_copy['Meta'] = meta
|
||||
body_copy['__module__'] = model.__module__
|
||||
type(model._meta.object_name, model.__bases__, body_copy)
|
||||
|
||||
# Construct a model with a renamed table name.
|
||||
body_copy = copy.deepcopy(body)
|
||||
meta_contents = {
|
||||
'app_label': model._meta.app_label,
|
||||
'db_table': 'new__%s' % strip_quotes(model._meta.db_table),
|
||||
'unique_together': unique_together,
|
||||
'index_together': index_together,
|
||||
'indexes': indexes,
|
||||
'constraints': constraints,
|
||||
'apps': apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
body_copy['Meta'] = meta
|
||||
body_copy['__module__'] = model.__module__
|
||||
new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy)
|
||||
|
||||
# Create a new table with the updated schema.
|
||||
self.create_model(new_model)
|
||||
|
||||
# Copy data from the old table into the new table
|
||||
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
|
||||
self.quote_name(new_model._meta.db_table),
|
||||
', '.join(self.quote_name(x) for x in mapping),
|
||||
', '.join(mapping.values()),
|
||||
self.quote_name(model._meta.db_table),
|
||||
))
|
||||
|
||||
# Delete the old table to make way for the new
|
||||
self.delete_model(model, handle_autom2m=False)
|
||||
|
||||
# Rename the new table to take the place of the old one
|
||||
self.alter_db_table(
|
||||
new_model, new_model._meta.db_table, model._meta.db_table,
|
||||
disable_constraints=False,
|
||||
)
|
||||
|
||||
# Run deferred SQL on correct table
|
||||
for sql in self.deferred_sql:
|
||||
self.execute(sql)
|
||||
self.deferred_sql = []
|
||||
# Fix any PK-removed field
|
||||
if restore_pk_field:
|
||||
restore_pk_field.primary_key = True
|
||||
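# Illustrative sketch (not part of the diff above): the statements _remake_table
# ends up issuing for a hypothetical "app_model" table, following the five steps
# from the docstring. Table and column names are assumptions for the demo.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE app_model (id INTEGER PRIMARY KEY, name TEXT);
    INSERT INTO app_model VALUES (1, 'x');

    -- 1. Create the table with the updated definition.
    CREATE TABLE new__app_model (id INTEGER PRIMARY KEY, name TEXT, rank INTEGER);
    -- 2. Copy the data across (a newly added column gets its quoted default).
    INSERT INTO new__app_model (id, name, rank) SELECT id, name, 0 FROM app_model;
    -- 3. Drop the old table, then 4. rename the new one into place.
    DROP TABLE app_model;
    ALTER TABLE new__app_model RENAME TO app_model;
    -- 5. Indexes from the old table would be recreated here via deferred SQL.
""")
print(conn.execute("SELECT * FROM app_model").fetchall())  # [(1, 'x', 0)]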
|
||||
def delete_model(self, model, handle_autom2m=True):
|
||||
if handle_autom2m:
|
||||
super().delete_model(model)
|
||||
else:
|
||||
# Delete the table (and only that)
|
||||
self.execute(self.sql_delete_table % {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
})
|
||||
# Remove all deferred statements referencing the deleted table.
|
||||
for sql in list(self.deferred_sql):
|
||||
if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
|
||||
self.deferred_sql.remove(sql)
|
||||
|
||||
def add_field(self, model, field):
|
||||
"""
|
||||
Create a field on a model. Usually involves adding a column, but may
|
||||
involve adding a table instead (for M2M fields).
|
||||
"""
|
||||
# Special-case implicit M2M tables
|
||||
if field.many_to_many and field.remote_field.through._meta.auto_created:
|
||||
return self.create_model(field.remote_field.through)
|
||||
self._remake_table(model, create_field=field)
|
||||
|
||||
def remove_field(self, model, field):
|
||||
"""
|
||||
Remove a field from a model. Usually involves deleting a column,
|
||||
but for M2Ms may involve deleting a table.
|
||||
"""
|
||||
# M2M fields are a special case
|
||||
if field.many_to_many:
|
||||
# For implicit M2M tables, delete the auto-created table
|
||||
if field.remote_field.through._meta.auto_created:
|
||||
self.delete_model(field.remote_field.through)
|
||||
# For explicit "through" M2M fields, do nothing
|
||||
# For everything else, remake.
|
||||
else:
|
||||
# It might not actually have a column behind it
|
||||
if field.db_parameters(connection=self.connection)['type'] is None:
|
||||
return
|
||||
self._remake_table(model, delete_field=field)
|
||||
|
||||
def _alter_field(self, model, old_field, new_field, old_type, new_type,
|
||||
old_db_params, new_db_params, strict=False):
|
||||
"""Perform a "physical" (non-ManyToMany) field update."""
|
||||
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
|
||||
# changed and there aren't any constraints.
|
||||
if (self.connection.features.can_alter_table_rename_column and
|
||||
old_field.column != new_field.column and
|
||||
self.column_sql(model, old_field) == self.column_sql(model, new_field) and
|
||||
not (old_field.remote_field and old_field.db_constraint or
|
||||
new_field.remote_field and new_field.db_constraint)):
|
||||
return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
|
||||
# Alter by remaking table
|
||||
self._remake_table(model, alter_field=(old_field, new_field))
|
||||
# Rebuild tables with FKs pointing to this field.
|
||||
if new_field.unique and old_type != new_type:
|
||||
related_models = set()
|
||||
opts = new_field.model._meta
|
||||
for remote_field in opts.related_objects:
|
||||
# Ignore self-relationship since the table was already rebuilt.
|
||||
if remote_field.related_model == model:
|
||||
continue
|
||||
if not remote_field.many_to_many:
|
||||
if remote_field.field_name == new_field.name:
|
||||
related_models.add(remote_field.related_model)
|
||||
elif new_field.primary_key and remote_field.through._meta.auto_created:
|
||||
related_models.add(remote_field.through)
|
||||
if new_field.primary_key:
|
||||
for many_to_many in opts.many_to_many:
|
||||
# Ignore self-relationship since the table was already rebuilt.
|
||||
if many_to_many.related_model == model:
|
||||
continue
|
||||
if many_to_many.remote_field.through._meta.auto_created:
|
||||
related_models.add(many_to_many.remote_field.through)
|
||||
for related_model in related_models:
|
||||
self._remake_table(related_model)
|
||||
|
||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||
"""Alter M2Ms to repoint their to= endpoints."""
|
||||
if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
|
||||
# The field name didn't change, but some options did; we have to propagate this altering.
|
||||
self._remake_table(
|
||||
old_field.remote_field.through,
|
||||
alter_field=(
|
||||
# We need the field that points to the target model, so we can tell alter_field to change it -
|
||||
# this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
|
||||
old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
|
||||
new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
# Make a new through table
|
||||
self.create_model(new_field.remote_field.through)
|
||||
# Copy the data across
|
||||
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
|
||||
self.quote_name(new_field.remote_field.through._meta.db_table),
|
||||
', '.join([
|
||||
"id",
|
||||
new_field.m2m_column_name(),
|
||||
new_field.m2m_reverse_name(),
|
||||
]),
|
||||
', '.join([
|
||||
"id",
|
||||
old_field.m2m_column_name(),
|
||||
old_field.m2m_reverse_name(),
|
||||
]),
|
||||
self.quote_name(old_field.remote_field.through._meta.db_table),
|
||||
))
|
||||
# Delete the old through table
|
||||
self.delete_model(old_field.remote_field.through)
|
||||
|
||||
def add_constraint(self, model, constraint):
|
||||
if isinstance(constraint, UniqueConstraint) and (
|
||||
constraint.condition or
|
||||
constraint.contains_expressions or
|
||||
constraint.include or
|
||||
constraint.deferrable
|
||||
):
|
||||
super().add_constraint(model, constraint)
|
||||
else:
|
||||
self._remake_table(model)
|
||||
|
||||
def remove_constraint(self, model, constraint):
|
||||
if isinstance(constraint, UniqueConstraint) and (
|
||||
constraint.condition or
|
||||
constraint.contains_expressions or
|
||||
constraint.include or
|
||||
constraint.deferrable
|
||||
):
|
||||
super().remove_constraint(model, constraint)
|
||||
else:
|
||||
self._remake_table(model)
|
||||
|
||||
def _collate_sql(self, collation):
|
||||
return 'COLLATE ' + collation
|
||||
263
venv/Lib/site-packages/django/db/backends/utils.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import NotSupportedError
|
||||
from django.utils.dateparse import parse_time
|
||||
|
||||
logger = logging.getLogger('django.db.backends')
|
||||
|
||||
|
||||
class CursorWrapper:
|
||||
def __init__(self, cursor, db):
|
||||
self.cursor = cursor
|
||||
self.db = db
|
||||
|
||||
WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
|
||||
|
||||
def __getattr__(self, attr):
|
||||
cursor_attr = getattr(self.cursor, attr)
|
||||
if attr in CursorWrapper.WRAP_ERROR_ATTRS:
|
||||
return self.db.wrap_database_errors(cursor_attr)
|
||||
else:
|
||||
return cursor_attr
|
||||
|
||||
def __iter__(self):
|
||||
with self.db.wrap_database_errors:
|
||||
yield from self.cursor
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
# Close instead of passing through to avoid backend-specific behavior
|
||||
# (#17671). Catch errors liberally because errors in cleanup code
|
||||
# aren't useful.
|
||||
try:
|
||||
self.close()
|
||||
except self.db.Database.Error:
|
||||
pass
|
||||
|
||||
# The following methods cannot be implemented in __getattr__, because the
|
||||
# code must run when the method is invoked, not just when it is accessed.
|
||||
|
||||
def callproc(self, procname, params=None, kparams=None):
|
||||
# Keyword parameters for callproc aren't supported in PEP 249, but the
|
||||
# database driver may support them (e.g. cx_Oracle).
|
||||
if kparams is not None and not self.db.features.supports_callproc_kwargs:
|
||||
raise NotSupportedError(
|
||||
'Keyword parameters for callproc are not supported on this '
|
||||
'database backend.'
|
||||
)
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
if params is None and kparams is None:
|
||||
return self.cursor.callproc(procname)
|
||||
elif kparams is None:
|
||||
return self.cursor.callproc(procname, params)
|
||||
else:
|
||||
params = params or ()
|
||||
return self.cursor.callproc(procname, params, kparams)
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
return self._execute_with_wrappers(sql, params, many=False, executor=self._execute)
|
||||
|
||||
def executemany(self, sql, param_list):
|
||||
return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany)
|
||||
|
||||
def _execute_with_wrappers(self, sql, params, many, executor):
|
||||
context = {'connection': self.db, 'cursor': self}
|
||||
for wrapper in reversed(self.db.execute_wrappers):
|
||||
executor = functools.partial(wrapper, executor)
|
||||
return executor(sql, params, many, context)
|
||||
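# Illustrative sketch (not part of the diff above): _execute_with_wrappers
# folds the registered wrappers around the innermost executor with
# functools.partial. A minimal standalone version with a fake executor:
import functools

def log_wrapper(execute, sql, params, many, context):
    print('about to run:', sql)
    return execute(sql, params, many, context)

def innermost(sql, params, many, context):
    return 'executed %s' % sql

executor = innermost
for wrapper in reversed([log_wrapper]):  # stands in for self.db.execute_wrappers
    executor = functools.partial(wrapper, executor)
print(executor('SELECT 1', None, False, {}))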
|
||||
def _execute(self, sql, params, *ignored_wrapper_args):
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
if params is None:
|
||||
# params default might be backend specific.
|
||||
return self.cursor.execute(sql)
|
||||
else:
|
||||
return self.cursor.execute(sql, params)
|
||||
|
||||
def _executemany(self, sql, param_list, *ignored_wrapper_args):
|
||||
self.db.validate_no_broken_transaction()
|
||||
with self.db.wrap_database_errors:
|
||||
return self.cursor.executemany(sql, param_list)
|
||||
|
||||
|
||||
class CursorDebugWrapper(CursorWrapper):
|
||||
|
||||
# XXX callproc isn't instrumented at this time.
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
with self.debug_sql(sql, params, use_last_executed_query=True):
|
||||
return super().execute(sql, params)
|
||||
|
||||
def executemany(self, sql, param_list):
|
||||
with self.debug_sql(sql, param_list, many=True):
|
||||
return super().executemany(sql, param_list)
|
||||
|
||||
@contextmanager
|
||||
def debug_sql(self, sql=None, params=None, use_last_executed_query=False, many=False):
|
||||
start = time.monotonic()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
stop = time.monotonic()
|
||||
duration = stop - start
|
||||
if use_last_executed_query:
|
||||
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
|
||||
try:
|
||||
times = len(params) if many else ''
|
||||
except TypeError:
|
||||
# params could be an iterator.
|
||||
times = '?'
|
||||
self.db.queries_log.append({
|
||||
'sql': '%s times: %s' % (times, sql) if many else sql,
|
||||
'time': '%.3f' % duration,
|
||||
})
|
||||
logger.debug(
|
||||
'(%.3f) %s; args=%s; alias=%s',
|
||||
duration,
|
||||
sql,
|
||||
params,
|
||||
self.db.alias,
|
||||
extra={'duration': duration, 'sql': sql, 'params': params, 'alias': self.db.alias},
|
||||
)
|
||||
|
||||
|
||||
def split_tzname_delta(tzname):
|
||||
"""
|
||||
Split a time zone name into a 3-tuple of (name, sign, offset).
|
||||
"""
|
||||
for sign in ['+', '-']:
|
||||
if sign in tzname:
|
||||
name, offset = tzname.rsplit(sign, 1)
|
||||
if offset and parse_time(offset):
|
||||
return name, sign, offset
|
||||
return tzname, None, None
|
||||
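# Illustrative sketch (not part of the diff above): how split_tzname_delta
# breaks a zone name with an explicit offset apart (assumes Django is
# importable; the zone names are just examples).
from django.db.backends.utils import split_tzname_delta

print(split_tzname_delta('Asia/Kolkata'))             # ('Asia/Kolkata', None, None)
print(split_tzname_delta('America/Sao_Paulo-03:00'))  # ('America/Sao_Paulo', '-', '03:00')
print(split_tzname_delta('UTC+05:30'))                # ('UTC', '+', '05:30')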
|
||||
|
||||
###############################################
|
||||
# Converters from database (string) to Python #
|
||||
###############################################
|
||||
|
||||
def typecast_date(s):
|
||||
return datetime.date(*map(int, s.split('-'))) if s else None # return None if s is null
|
||||
|
||||
|
||||
def typecast_time(s): # does NOT store time zone information
|
||||
if not s:
|
||||
return None
|
||||
hour, minutes, seconds = s.split(':')
|
||||
if '.' in seconds: # check whether seconds have a fractional part
|
||||
seconds, microseconds = seconds.split('.')
|
||||
else:
|
||||
microseconds = '0'
|
||||
return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6]))
|
||||
|
||||
|
||||
def typecast_timestamp(s): # does NOT store time zone information
|
||||
# "2005-07-29 15:48:00.590358-05"
|
||||
# "2005-07-29 09:56:00-05"
|
||||
if not s:
|
||||
return None
|
||||
if ' ' not in s:
|
||||
return typecast_date(s)
|
||||
d, t = s.split()
|
||||
# Remove timezone information.
|
||||
if '-' in t:
|
||||
t, _ = t.split('-', 1)
|
||||
elif '+' in t:
|
||||
t, _ = t.split('+', 1)
|
||||
dates = d.split('-')
|
||||
times = t.split(':')
|
||||
seconds = times[2]
|
||||
if '.' in seconds: # check whether seconds have a fractional part
|
||||
seconds, microseconds = seconds.split('.')
|
||||
else:
|
||||
microseconds = '0'
|
||||
return datetime.datetime(
|
||||
int(dates[0]), int(dates[1]), int(dates[2]),
|
||||
int(times[0]), int(times[1]), int(seconds),
|
||||
int((microseconds + '000000')[:6])
|
||||
)
|
||||
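# Illustrative sketch (not part of the diff above): the typecast helpers parse
# the string forms SQLite hands back (assumes Django is importable).
from django.db.backends.utils import typecast_time, typecast_timestamp

print(typecast_time('15:48:00.590358'))
# datetime.time(15, 48, 0, 590358)
print(typecast_timestamp('2005-07-29 15:48:00.590358-05'))
# datetime.datetime(2005, 7, 29, 15, 48, 0, 590358) -- the offset is dropped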
|
||||
|
||||
###############################################
|
||||
# Converters from Python to database (string) #
|
||||
###############################################
|
||||
|
||||
def split_identifier(identifier):
|
||||
"""
|
||||
Split an SQL identifier into a two element tuple of (namespace, name).
|
||||
|
||||
The identifier could be a table, column, or sequence name, and it might be
prefixed by a namespace.
|
||||
"""
|
||||
try:
|
||||
namespace, name = identifier.split('"."')
|
||||
except ValueError:
|
||||
namespace, name = '', identifier
|
||||
return namespace.strip('"'), name.strip('"')
|
||||
|
||||
|
||||
def truncate_name(identifier, length=None, hash_len=4):
|
||||
"""
|
||||
Shorten an SQL identifier to a repeatable mangled version with the given
|
||||
length.
|
||||
|
||||
If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
|
||||
truncate the table portion only.
|
||||
"""
|
||||
namespace, name = split_identifier(identifier)
|
||||
|
||||
if length is None or len(name) <= length:
|
||||
return identifier
|
||||
|
||||
digest = names_digest(name, length=hash_len)
|
||||
return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest)
|
||||
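# Illustrative sketch (not part of the diff above): truncate_name keeps short
# identifiers intact and mangles long ones to the requested length with a
# trailing hash (assumes Django is importable; the names are made up).
from django.db.backends.utils import truncate_name

print(truncate_name('short_table', length=30))  # short_table
mangled = truncate_name('a_very_long_generated_table_name', length=20)
print(mangled, len(mangled))  # e.g. 'a_very_long_gene18f2' 20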
|
||||
|
||||
def names_digest(*args, length):
|
||||
"""
|
||||
Generate a 32-bit digest of a set of arguments that can be used to shorten
|
||||
identifying names.
|
||||
"""
|
||||
h = hashlib.md5()
|
||||
for arg in args:
|
||||
h.update(arg.encode())
|
||||
return h.hexdigest()[:length]
|
||||
|
||||
|
||||
def format_number(value, max_digits, decimal_places):
|
||||
"""
|
||||
Format a number into a string with the requisite number of digits and
|
||||
decimal places.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
context = decimal.getcontext().copy()
|
||||
if max_digits is not None:
|
||||
context.prec = max_digits
|
||||
if decimal_places is not None:
|
||||
value = value.quantize(decimal.Decimal(1).scaleb(-decimal_places), context=context)
|
||||
else:
|
||||
context.traps[decimal.Rounded] = 1
|
||||
value = context.create_decimal(value)
|
||||
return "{:f}".format(value)
|
||||
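# Illustrative sketch (not part of the diff above): format_number quantizes to
# the field's decimal_places within max_digits (the values below are assumed).
from decimal import Decimal
from django.db.backends.utils import format_number

print(format_number(Decimal('3.14159'), max_digits=5, decimal_places=2))  # 3.14
print(format_number(Decimal('300'), max_digits=5, decimal_places=2))      # 300.00
print(format_number(None, 5, 2))                                           # None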
|
||||
|
||||
def strip_quotes(table_name):
|
||||
"""
|
||||
Strip quotes off of quoted table names to make them safe for use in index
|
||||
names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
|
||||
scheme) becomes 'USER"."TABLE'.
|
||||
"""
|
||||
has_quotes = table_name.startswith('"') and table_name.endswith('"')
|
||||
return table_name[1:-1] if has_quotes else table_name
|
||||
2
venv/Lib/site-packages/django/db/migrations/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .migration import Migration, swappable_dependency  # NOQA
from .operations import *  # NOQA
1370
venv/Lib/site-packages/django/db/migrations/autodetector.py
Normal file
File diff suppressed because it is too large
54
venv/Lib/site-packages/django/db/migrations/exceptions.py
Normal file
@@ -0,0 +1,54 @@
from django.db import DatabaseError


class AmbiguityError(Exception):
    """More than one migration matches a name prefix."""
    pass


class BadMigrationError(Exception):
    """There's a bad migration (unreadable/bad format/etc.)."""
    pass


class CircularDependencyError(Exception):
    """There's an impossible-to-resolve circular dependency."""
    pass


class InconsistentMigrationHistory(Exception):
    """An applied migration has some of its dependencies not applied."""
    pass


class InvalidBasesError(ValueError):
    """A model's base classes can't be resolved."""
    pass


class IrreversibleError(RuntimeError):
    """An irreversible migration is about to be reversed."""
    pass


class NodeNotFoundError(LookupError):
    """An attempt on a node is made that is not available in the graph."""

    def __init__(self, message, node, origin=None):
        self.message = message
        self.origin = origin
        self.node = node

    def __str__(self):
        return self.message

    def __repr__(self):
        return "NodeNotFoundError(%r)" % (self.node,)


class MigrationSchemaMissing(DatabaseError):
    pass


class InvalidMigrationPlan(ValueError):
    pass
381
venv/Lib/site-packages/django/db/migrations/executor.py
Normal file
@@ -0,0 +1,381 @@
|
||||
from django.apps.registry import apps as global_apps
|
||||
from django.db import migrations, router
|
||||
|
||||
from .exceptions import InvalidMigrationPlan
|
||||
from .loader import MigrationLoader
|
||||
from .recorder import MigrationRecorder
|
||||
from .state import ProjectState
|
||||
|
||||
|
||||
class MigrationExecutor:
|
||||
"""
|
||||
End-to-end migration execution - load migrations and run them up or down
|
||||
to a specified set of targets.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, progress_callback=None):
|
||||
self.connection = connection
|
||||
self.loader = MigrationLoader(self.connection)
|
||||
self.recorder = MigrationRecorder(self.connection)
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
def migration_plan(self, targets, clean_start=False):
|
||||
"""
|
||||
Given a set of targets, return a list of (Migration instance, backwards?).
|
||||
"""
|
||||
plan = []
|
||||
if clean_start:
|
||||
applied = {}
|
||||
else:
|
||||
applied = dict(self.loader.applied_migrations)
|
||||
for target in targets:
|
||||
# If the target is (app_label, None), that means unmigrate everything
|
||||
if target[1] is None:
|
||||
for root in self.loader.graph.root_nodes():
|
||||
if root[0] == target[0]:
|
||||
for migration in self.loader.graph.backwards_plan(root):
|
||||
if migration in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], True))
|
||||
applied.pop(migration)
|
||||
# If the migration is already applied, do backwards mode,
|
||||
# otherwise do forwards mode.
|
||||
elif target in applied:
|
||||
# If the target is missing, it's likely a replaced migration.
|
||||
# Reload the graph without replacements.
|
||||
if (
|
||||
self.loader.replace_migrations and
|
||||
target not in self.loader.graph.node_map
|
||||
):
|
||||
self.loader.replace_migrations = False
|
||||
self.loader.build_graph()
|
||||
return self.migration_plan(targets, clean_start=clean_start)
|
||||
# Don't migrate backwards all the way to the target node (that
|
||||
# may roll back dependencies in other apps that don't need to
|
||||
# be rolled back); instead roll back through target's immediate
|
||||
# child(ren) in the same app, and no further.
|
||||
next_in_app = sorted(
|
||||
n for n in
|
||||
self.loader.graph.node_map[target].children
|
||||
if n[0] == target[0]
|
||||
)
|
||||
for node in next_in_app:
|
||||
for migration in self.loader.graph.backwards_plan(node):
|
||||
if migration in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], True))
|
||||
applied.pop(migration)
|
||||
else:
|
||||
for migration in self.loader.graph.forwards_plan(target):
|
||||
if migration not in applied:
|
||||
plan.append((self.loader.graph.nodes[migration], False))
|
||||
applied[migration] = self.loader.graph.nodes[migration]
|
||||
return plan
|
||||
|
||||
def _create_project_state(self, with_applied_migrations=False):
|
||||
"""
|
||||
Create a project state including all the applications without
|
||||
migrations and applied migrations if with_applied_migrations=True.
|
||||
"""
|
||||
state = ProjectState(real_apps=self.loader.unmigrated_apps)
|
||||
if with_applied_migrations:
|
||||
# Create the forwards plan Django would follow on an empty database
|
||||
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
|
||||
applied_migrations = {
|
||||
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
|
||||
if key in self.loader.graph.nodes
|
||||
}
|
||||
for migration, _ in full_plan:
|
||||
if migration in applied_migrations:
|
||||
migration.mutate_state(state, preserve=False)
|
||||
return state
|
||||
|
||||
def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
|
||||
"""
|
||||
Migrate the database up to the given targets.
|
||||
|
||||
Django first needs to create all project states before a migration is
|
||||
(un)applied and in a second step run all the database operations.
|
||||
"""
|
||||
# The django_migrations table must be present to record applied
|
||||
# migrations.
|
||||
self.recorder.ensure_schema()
|
||||
|
||||
if plan is None:
|
||||
plan = self.migration_plan(targets)
|
||||
# Create the forwards plan Django would follow on an empty database
|
||||
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
|
||||
|
||||
all_forwards = all(not backwards for mig, backwards in plan)
|
||||
all_backwards = all(backwards for mig, backwards in plan)
|
||||
|
||||
if not plan:
|
||||
if state is None:
|
||||
# The resulting state should include applied migrations.
|
||||
state = self._create_project_state(with_applied_migrations=True)
|
||||
elif all_forwards == all_backwards:
|
||||
# This should only happen if there's a mixed plan
|
||||
raise InvalidMigrationPlan(
|
||||
"Migration plans with both forwards and backwards migrations "
|
||||
"are not supported. Please split your migration process into "
|
||||
"separate plans of only forwards OR backwards migrations.",
|
||||
plan
|
||||
)
|
||||
elif all_forwards:
|
||||
if state is None:
|
||||
# The resulting state should still include applied migrations.
|
||||
state = self._create_project_state(with_applied_migrations=True)
|
||||
state = self._migrate_all_forwards(state, plan, full_plan, fake=fake, fake_initial=fake_initial)
|
||||
else:
|
||||
# No need to check for `elif all_backwards` here, as that condition
|
||||
# would always evaluate to true.
|
||||
state = self._migrate_all_backwards(plan, full_plan, fake=fake)
|
||||
|
||||
self.check_replacements()
|
||||
|
||||
return state
|
||||
|
||||
def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
|
||||
"""
|
||||
Take a list of 2-tuples of the form (migration instance, False) and
|
||||
apply them in the order they occur in the full_plan.
|
||||
"""
|
||||
migrations_to_run = {m[0] for m in plan}
|
||||
for migration, _ in full_plan:
|
||||
if not migrations_to_run:
|
||||
# We remove every migration that we applied from these sets so
|
||||
# that we can bail out once the last migration has been applied
|
||||
# and don't always run until the very end of the migration
|
||||
# process.
|
||||
break
|
||||
if migration in migrations_to_run:
|
||||
if 'apps' not in state.__dict__:
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_start")
|
||||
state.apps # Render all -- performance critical
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_success")
|
||||
state = self.apply_migration(state, migration, fake=fake, fake_initial=fake_initial)
|
||||
migrations_to_run.remove(migration)
|
||||
|
||||
return state
|
||||
|
||||
def _migrate_all_backwards(self, plan, full_plan, fake):
|
||||
"""
|
||||
Take a list of 2-tuples of the form (migration instance, True) and
|
||||
unapply them in reverse order they occur in the full_plan.
|
||||
|
||||
Since unapplying a migration requires the project state prior to that
|
||||
migration, Django will compute the migration states before each of them
|
||||
in a first run over the plan and then unapply them in a second run over
|
||||
the plan.
|
||||
"""
|
||||
migrations_to_run = {m[0] for m in plan}
|
||||
# Holds all migration states prior to the migrations being unapplied
|
||||
states = {}
|
||||
state = self._create_project_state()
|
||||
applied_migrations = {
|
||||
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
|
||||
if key in self.loader.graph.nodes
|
||||
}
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_start")
|
||||
for migration, _ in full_plan:
|
||||
if not migrations_to_run:
|
||||
# We remove every migration that we applied from this set so
|
||||
# that we can bail out once the last migration has been applied
|
||||
# and don't always run until the very end of the migration
|
||||
# process.
|
||||
break
|
||||
if migration in migrations_to_run:
|
||||
if 'apps' not in state.__dict__:
|
||||
state.apps # Render all -- performance critical
|
||||
# The state before this migration
|
||||
states[migration] = state
|
||||
# The old state keeps as-is, we continue with the new state
|
||||
state = migration.mutate_state(state, preserve=True)
|
||||
migrations_to_run.remove(migration)
|
||||
elif migration in applied_migrations:
|
||||
# Only mutate the state if the migration is actually applied
|
||||
# to make sure the resulting state doesn't include changes
|
||||
# from unrelated migrations.
|
||||
migration.mutate_state(state, preserve=False)
|
||||
if self.progress_callback:
|
||||
self.progress_callback("render_success")
|
||||
|
||||
for migration, _ in plan:
|
||||
self.unapply_migration(states[migration], migration, fake=fake)
|
||||
applied_migrations.remove(migration)
|
||||
|
||||
# Generate the post migration state by starting from the state before
|
||||
# the last migration is unapplied and mutating it to include all the
|
||||
# remaining applied migrations.
|
||||
last_unapplied_migration = plan[-1][0]
|
||||
state = states[last_unapplied_migration]
|
||||
for index, (migration, _) in enumerate(full_plan):
|
||||
if migration == last_unapplied_migration:
|
||||
for migration, _ in full_plan[index:]:
|
||||
if migration in applied_migrations:
|
||||
migration.mutate_state(state, preserve=False)
|
||||
break
|
||||
|
||||
return state
|
||||
|
||||
def apply_migration(self, state, migration, fake=False, fake_initial=False):
|
||||
"""Run a migration forwards."""
|
||||
migration_recorded = False
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_start", migration, fake)
|
||||
if not fake:
|
||||
if fake_initial:
|
||||
# Test to see if this is an already-applied initial migration
|
||||
applied, state = self.detect_soft_applied(state, migration)
|
||||
if applied:
|
||||
fake = True
|
||||
if not fake:
|
||||
# Alright, do it normally
|
||||
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
|
||||
state = migration.apply(state, schema_editor)
|
||||
if not schema_editor.deferred_sql:
|
||||
self.record_migration(migration)
|
||||
migration_recorded = True
|
||||
if not migration_recorded:
|
||||
self.record_migration(migration)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("apply_success", migration, fake)
|
||||
return state
|
||||
|
||||
def record_migration(self, migration):
|
||||
# For replacement migrations, record individual statuses
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_applied(app_label, name)
|
||||
else:
|
||||
self.recorder.record_applied(migration.app_label, migration.name)
|
||||
|
||||
def unapply_migration(self, state, migration, fake=False):
|
||||
"""Run a migration backwards."""
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_start", migration, fake)
|
||||
if not fake:
|
||||
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
|
||||
state = migration.unapply(state, schema_editor)
|
||||
# For replacement migrations, also record individual statuses.
|
||||
if migration.replaces:
|
||||
for app_label, name in migration.replaces:
|
||||
self.recorder.record_unapplied(app_label, name)
|
||||
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||
# Report progress
|
||||
if self.progress_callback:
|
||||
self.progress_callback("unapply_success", migration, fake)
|
||||
return state
|
||||
|
||||
def check_replacements(self):
|
||||
"""
|
||||
Mark replacement migrations applied if their replaced set all are.
|
||||
|
||||
Do this unconditionally on every migrate, rather than just when
|
||||
migrations are applied or unapplied, to correctly handle the case
|
||||
when a new squash migration is pushed to a deployment that already had
|
||||
all its replaced migrations applied. In this case no new migration will
|
||||
be applied, but the applied state of the squashed migration must be
|
||||
maintained.
|
||||
"""
|
||||
applied = self.recorder.applied_migrations()
|
||||
for key, migration in self.loader.replacements.items():
|
||||
all_applied = all(m in applied for m in migration.replaces)
|
||||
if all_applied and key not in applied:
|
||||
self.recorder.record_applied(*key)
|
||||
|
||||
def detect_soft_applied(self, project_state, migration):
|
||||
"""
|
||||
Test whether a migration has been implicitly applied - that the
|
||||
tables or columns it would create exist. This is intended only for use
|
||||
on initial migrations (as it only looks for CreateModel and AddField).
|
||||
"""
|
||||
def should_skip_detecting_model(migration, model):
|
||||
"""
|
||||
No need to detect tables for proxy models, unmanaged models, or
|
||||
models that can't be migrated on the current database.
|
||||
"""
|
||||
return (
|
||||
model._meta.proxy or not model._meta.managed or not
|
||||
router.allow_migrate(
|
||||
self.connection.alias, migration.app_label,
|
||||
model_name=model._meta.model_name,
|
||||
)
|
||||
)
|
||||
|
||||
if migration.initial is None:
|
||||
# Bail if the migration isn't the first one in its app
|
||||
if any(app == migration.app_label for app, name in migration.dependencies):
|
||||
return False, project_state
|
||||
elif migration.initial is False:
|
||||
# Bail if it's NOT an initial migration
|
||||
return False, project_state
|
||||
|
||||
if project_state is None:
|
||||
after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
|
||||
else:
|
||||
after_state = migration.mutate_state(project_state)
|
||||
apps = after_state.apps
|
||||
found_create_model_migration = False
|
||||
found_add_field_migration = False
|
||||
fold_identifier_case = self.connection.features.ignores_table_name_case
|
||||
with self.connection.cursor() as cursor:
|
||||
existing_table_names = set(self.connection.introspection.table_names(cursor))
|
||||
if fold_identifier_case:
|
||||
existing_table_names = {name.casefold() for name in existing_table_names}
|
||||
# Make sure all create model and add field operations are done
|
||||
for operation in migration.operations:
|
||||
if isinstance(operation, migrations.CreateModel):
|
||||
model = apps.get_model(migration.app_label, operation.name)
|
||||
if model._meta.swapped:
|
||||
# We have to fetch the model to test with from the
|
||||
# main app cache, as it's not a direct dependency.
|
||||
model = global_apps.get_model(model._meta.swapped)
|
||||
if should_skip_detecting_model(migration, model):
|
||||
continue
|
||||
db_table = model._meta.db_table
|
||||
if fold_identifier_case:
|
||||
db_table = db_table.casefold()
|
||||
if db_table not in existing_table_names:
|
||||
return False, project_state
|
||||
found_create_model_migration = True
|
||||
elif isinstance(operation, migrations.AddField):
|
||||
model = apps.get_model(migration.app_label, operation.model_name)
|
||||
if model._meta.swapped:
|
||||
# We have to fetch the model to test with from the
|
||||
# main app cache, as it's not a direct dependency.
|
||||
model = global_apps.get_model(model._meta.swapped)
|
||||
if should_skip_detecting_model(migration, model):
|
||||
continue
|
||||
|
||||
table = model._meta.db_table
|
||||
field = model._meta.get_field(operation.name)
|
||||
|
||||
# Handle implicit many-to-many tables created by AddField.
|
||||
if field.many_to_many:
|
||||
through_db_table = field.remote_field.through._meta.db_table
|
||||
if fold_identifier_case:
|
||||
through_db_table = through_db_table.casefold()
|
||||
if through_db_table not in existing_table_names:
|
||||
return False, project_state
|
||||
else:
|
||||
found_add_field_migration = True
|
||||
continue
|
||||
with self.connection.cursor() as cursor:
|
||||
columns = self.connection.introspection.get_table_description(cursor, table)
|
||||
for column in columns:
|
||||
field_column = field.column
|
||||
column_name = column.name
|
||||
if fold_identifier_case:
|
||||
column_name = column_name.casefold()
|
||||
field_column = field_column.casefold()
|
||||
if column_name == field_column:
|
||||
found_add_field_migration = True
|
||||
break
|
||||
else:
|
||||
return False, project_state
|
||||
# If we get this far and we found at least one CreateModel or AddField migration,
|
||||
# the migration is considered implicitly applied.
|
||||
return (found_create_model_migration or found_add_field_migration), after_state
|
||||
319
venv/Lib/site-packages/django/db/migrations/graph.py
Normal file
@@ -0,0 +1,319 @@
|
||||
from functools import total_ordering
|
||||
|
||||
from django.db.migrations.state import ProjectState
|
||||
|
||||
from .exceptions import CircularDependencyError, NodeNotFoundError
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Node:
|
||||
"""
|
||||
A single node in the migration graph. Contains direct links to adjacent
|
||||
nodes in either direction.
|
||||
"""
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.children = set()
|
||||
self.parents = set()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.key < other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.key)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.key[item]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.key)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: (%r, %r)>' % (self.__class__.__name__, self.key[0], self.key[1])
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.add(child)
|
||||
|
||||
def add_parent(self, parent):
|
||||
self.parents.add(parent)
|
||||
|
||||
|
||||
class DummyNode(Node):
|
||||
"""
|
||||
A node that doesn't correspond to a migration file on disk.
|
||||
(A squashed migration that was removed, for example.)
|
||||
|
||||
After the migration graph is processed, all dummy nodes should be removed.
|
||||
If there are any left, a nonexistent dependency error is raised.
|
||||
"""
|
||||
def __init__(self, key, origin, error_message):
|
||||
super().__init__(key)
|
||||
self.origin = origin
|
||||
self.error_message = error_message
|
||||
|
||||
def raise_error(self):
|
||||
raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
|
||||
|
||||
|
||||
class MigrationGraph:
|
||||
"""
|
||||
Represent the digraph of all migrations in a project.
|
||||
|
||||
Each migration is a node, and each dependency is an edge. There are
|
||||
no implicit dependencies between numbered migrations - the numbering is
|
||||
merely a convention to aid file listing. Every new numbered migration
|
||||
has a declared dependency to the previous number, meaning that VCS
|
||||
branch merges can be detected and resolved.
|
||||
|
||||
Migrations files can be marked as replacing another set of migrations -
|
||||
this is to support the "squash" feature. The graph handler isn't responsible
|
||||
for these; instead, the code to load them in here should examine the
|
||||
migration files and if the replaced migrations are all either unapplied
|
||||
or not present, it should ignore the replaced ones, load in just the
|
||||
replacing migration, and repoint any dependencies that pointed to the
|
||||
replaced migrations to point to the replacing one.
|
||||
|
||||
A node should be a tuple: (app_path, migration_name). The tree special-cases
|
||||
things within an app - namely, root nodes and leaf nodes ignore dependencies
|
||||
to other apps.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.node_map = {}
|
||||
self.nodes = {}
|
||||
|
||||
def add_node(self, key, migration):
|
||||
assert key not in self.node_map
|
||||
node = Node(key)
|
||||
self.node_map[key] = node
|
||||
self.nodes[key] = migration
|
||||
|
||||
def add_dummy_node(self, key, origin, error_message):
|
||||
node = DummyNode(key, origin, error_message)
|
||||
self.node_map[key] = node
|
||||
self.nodes[key] = None
|
||||
|
||||
def add_dependency(self, migration, child, parent, skip_validation=False):
|
||||
"""
|
||||
This may create dummy nodes if they don't yet exist. If
|
||||
`skip_validation=True`, validate_consistency() should be called
|
||||
afterward.
|
||||
"""
|
||||
if child not in self.nodes:
|
||||
error_message = (
|
||||
"Migration %s dependencies reference nonexistent"
|
||||
" child node %r" % (migration, child)
|
||||
)
|
||||
self.add_dummy_node(child, migration, error_message)
|
||||
if parent not in self.nodes:
|
||||
error_message = (
|
||||
"Migration %s dependencies reference nonexistent"
|
||||
" parent node %r" % (migration, parent)
|
||||
)
|
||||
self.add_dummy_node(parent, migration, error_message)
|
||||
self.node_map[child].add_parent(self.node_map[parent])
|
||||
self.node_map[parent].add_child(self.node_map[child])
|
||||
if not skip_validation:
|
||||
self.validate_consistency()
|
||||
|
||||
def remove_replaced_nodes(self, replacement, replaced):
|
||||
"""
|
||||
Remove each of the `replaced` nodes (when they exist). Any
|
||||
dependencies that were referencing them are changed to reference the
|
||||
`replacement` node instead.
|
||||
"""
|
||||
# Cast list of replaced keys to set to speed up lookup later.
|
||||
replaced = set(replaced)
|
||||
try:
|
||||
replacement_node = self.node_map[replacement]
|
||||
except KeyError as err:
|
||||
raise NodeNotFoundError(
|
||||
"Unable to find replacement node %r. It was either never added"
|
||||
" to the migration graph, or has been removed." % (replacement,),
|
||||
replacement
|
||||
) from err
|
||||
for replaced_key in replaced:
|
||||
self.nodes.pop(replaced_key, None)
|
||||
replaced_node = self.node_map.pop(replaced_key, None)
|
||||
if replaced_node:
|
||||
for child in replaced_node.children:
|
||||
child.parents.remove(replaced_node)
|
||||
# We don't want to create dependencies between the replaced
|
||||
# node and the replacement node as this would lead to
|
||||
# self-referencing on the replacement node at a later iteration.
|
||||
if child.key not in replaced:
|
||||
replacement_node.add_child(child)
|
||||
child.add_parent(replacement_node)
|
||||
for parent in replaced_node.parents:
|
||||
parent.children.remove(replaced_node)
|
||||
# Again, to avoid self-referencing.
|
||||
if parent.key not in replaced:
|
||||
replacement_node.add_parent(parent)
|
||||
parent.add_child(replacement_node)
|
||||
|
||||
def remove_replacement_node(self, replacement, replaced):
|
||||
"""
|
||||
The inverse operation to `remove_replaced_nodes`. Almost. Remove the
|
||||
replacement node `replacement` and remap its child nodes to `replaced`
|
||||
- the list of nodes it would have replaced. Don't remap its parent
|
||||
nodes as they are expected to be correct already.
|
||||
"""
|
||||
self.nodes.pop(replacement, None)
|
||||
try:
|
||||
replacement_node = self.node_map.pop(replacement)
|
||||
except KeyError as err:
|
||||
raise NodeNotFoundError(
|
||||
"Unable to remove replacement node %r. It was either never added"
|
||||
" to the migration graph, or has been removed already." % (replacement,),
|
||||
replacement
|
||||
) from err
|
||||
replaced_nodes = set()
|
||||
replaced_nodes_parents = set()
|
||||
for key in replaced:
|
||||
replaced_node = self.node_map.get(key)
|
||||
if replaced_node:
|
||||
replaced_nodes.add(replaced_node)
|
||||
replaced_nodes_parents |= replaced_node.parents
|
||||
# We're only interested in the latest replaced node, so filter out
|
||||
# replaced nodes that are parents of other replaced nodes.
|
||||
replaced_nodes -= replaced_nodes_parents
|
||||
for child in replacement_node.children:
|
||||
child.parents.remove(replacement_node)
|
||||
for replaced_node in replaced_nodes:
|
||||
replaced_node.add_child(child)
|
||||
child.add_parent(replaced_node)
|
||||
for parent in replacement_node.parents:
|
||||
parent.children.remove(replacement_node)
|
||||
# NOTE: There is no need to remap parent dependencies as we can
|
||||
# assume the replaced nodes already have the correct ancestry.
|
||||
|
||||
def validate_consistency(self):
|
||||
"""Ensure there are no dummy nodes remaining in the graph."""
|
||||
[n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
|
||||
|
||||
def forwards_plan(self, target):
|
||||
"""
|
||||
Given a node, return a list of which previous nodes (dependencies) must
|
||||
be applied, ending with the node itself. This is the list you would
|
||||
follow if applying the migrations to a database.
|
||||
"""
|
||||
if target not in self.nodes:
|
||||
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
|
||||
return self.iterative_dfs(self.node_map[target])
|
||||
|
||||
def backwards_plan(self, target):
|
||||
"""
|
||||
Given a node, return a list of which dependent nodes (dependencies)
|
||||
must be unapplied, ending with the node itself. This is the list you
|
||||
would follow if removing the migrations from a database.
|
||||
"""
|
||||
if target not in self.nodes:
|
||||
raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
|
||||
return self.iterative_dfs(self.node_map[target], forwards=False)
|
||||
|
||||
def iterative_dfs(self, start, forwards=True):
|
||||
"""Iterative depth-first search for finding dependencies."""
|
||||
visited = []
|
||||
visited_set = set()
|
||||
stack = [(start, False)]
|
||||
while stack:
|
||||
node, processed = stack.pop()
|
||||
if node in visited_set:
|
||||
pass
|
||||
elif processed:
|
||||
visited_set.add(node)
|
||||
visited.append(node.key)
|
||||
else:
|
||||
stack.append((node, True))
|
||||
stack += [(n, False) for n in sorted(node.parents if forwards else node.children)]
|
||||
return visited
|
||||
|
||||
def root_nodes(self, app=None):
|
||||
"""
|
||||
Return all root nodes - that is, nodes with no dependencies inside
|
||||
their app. These are the starting point for an app.
|
||||
"""
|
||||
roots = set()
|
||||
for node in self.nodes:
|
||||
if all(key[0] != node[0] for key in self.node_map[node].parents) and (not app or app == node[0]):
|
||||
roots.add(node)
|
||||
return sorted(roots)
|
||||
|
||||
def leaf_nodes(self, app=None):
|
||||
"""
|
||||
Return all leaf nodes - that is, nodes with no dependents in their app.
|
||||
These are the "most current" version of an app's schema.
|
||||
Having more than one per app is technically an error, but one that
|
||||
gets handled further up, in the interactive command - it's usually the
|
||||
result of a VCS merge and needs some user input.
|
||||
"""
|
||||
leaves = set()
|
||||
for node in self.nodes:
|
||||
if all(key[0] != node[0] for key in self.node_map[node].children) and (not app or app == node[0]):
|
||||
leaves.add(node)
|
||||
return sorted(leaves)
|
||||
|
||||
def ensure_not_cyclic(self):
|
||||
# Algo from GvR:
|
||||
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
|
||||
todo = set(self.nodes)
|
||||
while todo:
|
||||
node = todo.pop()
|
||||
stack = [node]
|
||||
while stack:
|
||||
top = stack[-1]
|
||||
for child in self.node_map[top].children:
|
||||
# Use child.key instead of child to speed up the frequent
|
||||
# hashing.
|
||||
node = child.key
|
||||
if node in stack:
|
||||
cycle = stack[stack.index(node):]
|
||||
raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle))
|
||||
if node in todo:
|
||||
stack.append(node)
|
||||
todo.remove(node)
|
||||
break
|
||||
else:
|
||||
node = stack.pop()
|
||||
|
||||
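A minimal, hypothetical sketch of the check above: two migrations that declare each other as dependencies trip ensure_not_cyclic() (the app and migration names are invented; None stands in for real Migration instances):

from django.db.migrations.graph import MigrationGraph

g = MigrationGraph()
g.add_node(('shop', '0001_a'), None)
g.add_node(('shop', '0002_b'), None)
# Each one depends on the other -> a cycle.
g.add_dependency('shop.0002_b', ('shop', '0002_b'), ('shop', '0001_a'))
g.add_dependency('shop.0001_a', ('shop', '0001_a'), ('shop', '0002_b'))
g.ensure_not_cyclic()  # raises CircularDependencyError naming the cycle, e.g. "shop.0001_a, shop.0002_b"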
def __str__(self):
|
||||
return 'Graph: %s nodes, %s edges' % self._nodes_and_edges()
|
||||
|
||||
def __repr__(self):
|
||||
nodes, edges = self._nodes_and_edges()
|
||||
return '<%s: nodes=%s, edges=%s>' % (self.__class__.__name__, nodes, edges)
|
||||
|
||||
def _nodes_and_edges(self):
|
||||
return len(self.nodes), sum(len(node.parents) for node in self.node_map.values())
|
||||
|
||||
def _generate_plan(self, nodes, at_end):
|
||||
plan = []
|
||||
for node in nodes:
|
||||
for migration in self.forwards_plan(node):
|
||||
if migration not in plan and (at_end or migration not in nodes):
|
||||
plan.append(migration)
|
||||
return plan
|
||||
|
||||
def make_state(self, nodes=None, at_end=True, real_apps=None):
|
||||
"""
|
||||
Given a migration node or nodes, return a complete ProjectState for it.
|
||||
If at_end is False, return the state before the migration has run.
|
||||
If nodes is not provided, return the overall most current project state.
|
||||
"""
|
||||
if nodes is None:
|
||||
nodes = list(self.leaf_nodes())
|
||||
if not nodes:
|
||||
return ProjectState()
|
||||
if not isinstance(nodes[0], tuple):
|
||||
nodes = [nodes]
|
||||
plan = self._generate_plan(nodes, at_end)
|
||||
project_state = ProjectState(real_apps=real_apps)
|
||||
for node in plan:
|
||||
project_state = self.nodes[node].mutate_state(project_state, preserve=False)
|
||||
return project_state
|
||||
|
||||
def __contains__(self, node):
|
||||
return node in self.nodes
|
||||
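To make the MigrationGraph API above concrete, a minimal sketch of building a graph by hand and asking it for plans (a hypothetical 'shop' app; None stands in for real Migration instances):

from django.db.migrations.graph import MigrationGraph

graph = MigrationGraph()
graph.add_node(('shop', '0001_initial'), None)
graph.add_node(('shop', '0002_add_sku'), None)
# add_dependency(migration, child, parent): 0002 depends on 0001.
graph.add_dependency('shop.0002_add_sku', ('shop', '0002_add_sku'), ('shop', '0001_initial'))

graph.forwards_plan(('shop', '0002_add_sku'))
# -> [('shop', '0001_initial'), ('shop', '0002_add_sku')]
graph.backwards_plan(('shop', '0001_initial'))
# -> [('shop', '0002_add_sku'), ('shop', '0001_initial')]
graph.root_nodes()  # -> [('shop', '0001_initial')]
graph.leaf_nodes()  # -> [('shop', '0002_add_sku')]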
356
venv/Lib/site-packages/django/db/migrations/loader.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import pkgutil
|
||||
import sys
|
||||
from importlib import import_module, reload
|
||||
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.db.migrations.graph import MigrationGraph
|
||||
from django.db.migrations.recorder import MigrationRecorder
|
||||
|
||||
from .exceptions import (
|
||||
AmbiguityError, BadMigrationError, InconsistentMigrationHistory,
|
||||
NodeNotFoundError,
|
||||
)
|
||||
|
||||
MIGRATIONS_MODULE_NAME = 'migrations'
|
||||
|
||||
|
||||
class MigrationLoader:
"""
Load migration files from disk and their status from the database.

Migration files are expected to live in the "migrations" directory of
an app. Their names are entirely unimportant from a code perspective,
but will probably follow the 1234_name.py convention.

On initialization, this class will scan those directories, and open and
read the Python files, looking for a class called Migration, which should
inherit from django.db.migrations.Migration. See
django.db.migrations.migration for what that looks like.

Some migrations will be marked as "replacing" another set of migrations.
These are loaded into a separate set of migrations away from the main ones.
If all the migrations they replace are either unapplied or missing from
disk, then they are injected into the main set, replacing the named migrations.
Any dependency pointers to the replaced migrations are re-pointed to the
new migration.

This does mean that this class MUST also talk to the database as well as
to disk, but this is probably fine. We're already not just operating
in memory.
"""
|
||||
|
||||
def __init__(
|
||||
self, connection, load=True, ignore_no_migrations=False,
|
||||
replace_migrations=True,
|
||||
):
|
||||
self.connection = connection
|
||||
self.disk_migrations = None
|
||||
self.applied_migrations = None
|
||||
self.ignore_no_migrations = ignore_no_migrations
|
||||
self.replace_migrations = replace_migrations
|
||||
if load:
|
||||
self.build_graph()
|
||||
|
||||
@classmethod
|
||||
def migrations_module(cls, app_label):
|
||||
"""
|
||||
Return the path to the migrations module for the specified app_label
|
||||
and a boolean indicating if the module is specified in
|
||||
settings.MIGRATION_MODULE.
|
||||
"""
|
||||
if app_label in settings.MIGRATION_MODULES:
|
||||
return settings.MIGRATION_MODULES[app_label], True
|
||||
else:
|
||||
app_package_name = apps.get_app_config(app_label).name
|
||||
return '%s.%s' % (app_package_name, MIGRATIONS_MODULE_NAME), False
|
||||
|
||||
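A sketch of the return values under a hypothetical MIGRATION_MODULES setting (not output taken from this diff):

# MIGRATION_MODULES = {'shop': 'shop.custom_migrations', 'legacy': None}
#
# MigrationLoader.migrations_module('shop')    -> ('shop.custom_migrations', True)
# MigrationLoader.migrations_module('blog')    -> ('blog.migrations', False)
# MigrationLoader.migrations_module('legacy')  -> (None, True); load_disk() then treats 'legacy' as unmigrated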
def load_disk(self):
|
||||
"""Load the migrations from all INSTALLED_APPS from disk."""
|
||||
self.disk_migrations = {}
|
||||
self.unmigrated_apps = set()
|
||||
self.migrated_apps = set()
|
||||
for app_config in apps.get_app_configs():
|
||||
# Get the migrations module directory
|
||||
module_name, explicit = self.migrations_module(app_config.label)
|
||||
if module_name is None:
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
was_loaded = module_name in sys.modules
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
if (
|
||||
(explicit and self.ignore_no_migrations) or
|
||||
(not explicit and MIGRATIONS_MODULE_NAME in e.name.split('.'))
|
||||
):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
# Module is not a package (e.g. migrations.py).
|
||||
if not hasattr(module, '__path__'):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
# Empty directories are namespaces. Namespace packages have no
|
||||
# __file__ and don't use a list for __path__. See
|
||||
# https://docs.python.org/3/reference/import.html#namespace-packages
|
||||
if (
|
||||
getattr(module, '__file__', None) is None and
|
||||
not isinstance(module.__path__, list)
|
||||
):
|
||||
self.unmigrated_apps.add(app_config.label)
|
||||
continue
|
||||
# Force a reload if it's already loaded (tests need this)
|
||||
if was_loaded:
|
||||
reload(module)
|
||||
self.migrated_apps.add(app_config.label)
|
||||
migration_names = {
|
||||
name for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
|
||||
if not is_pkg and name[0] not in '_~'
|
||||
}
|
||||
# Load migrations
|
||||
for migration_name in migration_names:
|
||||
migration_path = '%s.%s' % (module_name, migration_name)
|
||||
try:
|
||||
migration_module = import_module(migration_path)
|
||||
except ImportError as e:
|
||||
if 'bad magic number' in str(e):
|
||||
raise ImportError(
|
||||
"Couldn't import %r as it appears to be a stale "
|
||||
".pyc file." % migration_path
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
if not hasattr(migration_module, "Migration"):
|
||||
raise BadMigrationError(
|
||||
"Migration %s in app %s has no Migration class" % (migration_name, app_config.label)
|
||||
)
|
||||
self.disk_migrations[app_config.label, migration_name] = migration_module.Migration(
|
||||
migration_name,
|
||||
app_config.label,
|
||||
)
|
||||
|
||||
def get_migration(self, app_label, name_prefix):
|
||||
"""Return the named migration or raise NodeNotFoundError."""
|
||||
return self.graph.nodes[app_label, name_prefix]
|
||||
|
||||
def get_migration_by_prefix(self, app_label, name_prefix):
|
||||
"""
|
||||
Return the migration(s) which match the given app label and name_prefix.
|
||||
"""
|
||||
# Do the search
|
||||
results = []
|
||||
for migration_app_label, migration_name in self.disk_migrations:
|
||||
if migration_app_label == app_label and migration_name.startswith(name_prefix):
|
||||
results.append((migration_app_label, migration_name))
|
||||
if len(results) > 1:
|
||||
raise AmbiguityError(
|
||||
"There is more than one migration for '%s' with the prefix '%s'" % (app_label, name_prefix)
|
||||
)
|
||||
elif not results:
|
||||
raise KeyError(
|
||||
f"There is no migration for '{app_label}' with the prefix "
|
||||
f"'{name_prefix}'"
|
||||
)
|
||||
else:
|
||||
return self.disk_migrations[results[0]]
|
||||
|
||||
def check_key(self, key, current_app):
|
||||
if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
|
||||
return key
|
||||
# Special-case __first__, which means "the first migration" for
|
||||
# migrated apps, and is ignored for unmigrated apps. It allows
|
||||
# makemigrations to declare dependencies on apps before they even have
|
||||
# migrations.
|
||||
if key[0] == current_app:
|
||||
# Ignore __first__ references to the same app (#22325)
|
||||
return
|
||||
if key[0] in self.unmigrated_apps:
|
||||
# This app isn't migrated, but something depends on it.
|
||||
# The models will get auto-added into the state, though
|
||||
# so we're fine.
|
||||
return
|
||||
if key[0] in self.migrated_apps:
|
||||
try:
|
||||
if key[1] == "__first__":
|
||||
return self.graph.root_nodes(key[0])[0]
|
||||
else: # "__latest__"
|
||||
return self.graph.leaf_nodes(key[0])[0]
|
||||
except IndexError:
|
||||
if self.ignore_no_migrations:
|
||||
return None
|
||||
else:
|
||||
raise ValueError("Dependency on app with no migrations: %s" % key[0])
|
||||
raise ValueError("Dependency on unknown app: %s" % key[0])
|
||||
|
||||
def add_internal_dependencies(self, key, migration):
|
||||
"""
|
||||
Internal dependencies need to be added first to ensure `__first__`
|
||||
dependencies find the correct root node.
|
||||
"""
|
||||
for parent in migration.dependencies:
|
||||
# Ignore __first__ references to the same app.
|
||||
if parent[0] == key[0] and parent[1] != '__first__':
|
||||
self.graph.add_dependency(migration, key, parent, skip_validation=True)
|
||||
|
||||
def add_external_dependencies(self, key, migration):
|
||||
for parent in migration.dependencies:
|
||||
# Skip internal dependencies
|
||||
if key[0] == parent[0]:
|
||||
continue
|
||||
parent = self.check_key(parent, key[0])
|
||||
if parent is not None:
|
||||
self.graph.add_dependency(migration, key, parent, skip_validation=True)
|
||||
for child in migration.run_before:
|
||||
child = self.check_key(child, key[0])
|
||||
if child is not None:
|
||||
self.graph.add_dependency(migration, child, key, skip_validation=True)
|
||||
|
||||
def build_graph(self):
|
||||
"""
|
||||
Build a migration dependency graph using both the disk and database.
|
||||
You'll need to rebuild the graph if you apply migrations. This isn't
|
||||
usually a problem as generally migration stuff runs in a one-shot process.
|
||||
"""
|
||||
# Load disk data
|
||||
self.load_disk()
|
||||
# Load database data
|
||||
if self.connection is None:
|
||||
self.applied_migrations = {}
|
||||
else:
|
||||
recorder = MigrationRecorder(self.connection)
|
||||
self.applied_migrations = recorder.applied_migrations()
|
||||
# To start, populate the migration graph with nodes for ALL migrations
|
||||
# and their dependencies. Also make note of replacing migrations at this step.
|
||||
self.graph = MigrationGraph()
|
||||
self.replacements = {}
|
||||
for key, migration in self.disk_migrations.items():
|
||||
self.graph.add_node(key, migration)
|
||||
# Replacing migrations.
|
||||
if migration.replaces:
|
||||
self.replacements[key] = migration
|
||||
for key, migration in self.disk_migrations.items():
|
||||
# Internal (same app) dependencies.
|
||||
self.add_internal_dependencies(key, migration)
|
||||
# Add external dependencies now that the internal ones have been resolved.
|
||||
for key, migration in self.disk_migrations.items():
|
||||
self.add_external_dependencies(key, migration)
|
||||
# Carry out replacements where possible and if enabled.
|
||||
if self.replace_migrations:
|
||||
for key, migration in self.replacements.items():
|
||||
# Get applied status of each of this migration's replacement
|
||||
# targets.
|
||||
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
|
||||
# The replacing migration is only marked as applied if all of
|
||||
# its replacement targets are.
|
||||
if all(applied_statuses):
|
||||
self.applied_migrations[key] = migration
|
||||
else:
|
||||
self.applied_migrations.pop(key, None)
|
||||
# A replacing migration can be used if either all or none of
|
||||
# its replacement targets have been applied.
|
||||
if all(applied_statuses) or (not any(applied_statuses)):
|
||||
self.graph.remove_replaced_nodes(key, migration.replaces)
|
||||
else:
|
||||
# This replacing migration cannot be used because it is
|
||||
# partially applied. Remove it from the graph and remap
|
||||
# dependencies to it (#25945).
|
||||
self.graph.remove_replacement_node(key, migration.replaces)
|
||||
# Ensure the graph is consistent.
|
||||
try:
|
||||
self.graph.validate_consistency()
|
||||
except NodeNotFoundError as exc:
|
||||
# Check if the missing node could have been replaced by any squash
|
||||
# migration but wasn't because the squash migration was partially
|
||||
# applied before. In that case raise a more understandable exception
|
||||
# (#23556).
|
||||
# Get reverse replacements.
|
||||
reverse_replacements = {}
|
||||
for key, migration in self.replacements.items():
|
||||
for replaced in migration.replaces:
|
||||
reverse_replacements.setdefault(replaced, set()).add(key)
|
||||
# Try to reraise exception with more detail.
|
||||
if exc.node in reverse_replacements:
|
||||
candidates = reverse_replacements.get(exc.node, set())
|
||||
is_replaced = any(candidate in self.graph.nodes for candidate in candidates)
|
||||
if not is_replaced:
|
||||
tries = ', '.join('%s.%s' % c for c in candidates)
|
||||
raise NodeNotFoundError(
|
||||
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
|
||||
"Django tried to replace migration {1}.{2} with any of [{3}] "
|
||||
"but wasn't able to because some of the replaced migrations "
|
||||
"are already applied.".format(
|
||||
exc.origin, exc.node[0], exc.node[1], tries
|
||||
),
|
||||
exc.node
|
||||
) from exc
|
||||
raise
|
||||
self.graph.ensure_not_cyclic()
|
||||
|
||||
def check_consistent_history(self, connection):
|
||||
"""
|
||||
Raise InconsistentMigrationHistory if any applied migrations have
|
||||
unapplied dependencies.
|
||||
"""
|
||||
recorder = MigrationRecorder(connection)
|
||||
applied = recorder.applied_migrations()
|
||||
for migration in applied:
|
||||
# If the migration is unknown, skip it.
|
||||
if migration not in self.graph.nodes:
|
||||
continue
|
||||
for parent in self.graph.node_map[migration].parents:
|
||||
if parent not in applied:
|
||||
# Skip unapplied squashed migrations that have all of their
|
||||
# `replaces` applied.
|
||||
if parent in self.replacements:
|
||||
if all(m in applied for m in self.replacements[parent].replaces):
|
||||
continue
|
||||
raise InconsistentMigrationHistory(
|
||||
"Migration {}.{} is applied before its dependency "
|
||||
"{}.{} on database '{}'.".format(
|
||||
migration[0], migration[1], parent[0], parent[1],
|
||||
connection.alias,
|
||||
)
|
||||
)
|
||||
|
||||
def detect_conflicts(self):
|
||||
"""
|
||||
Look through the loaded graph and detect any conflicts - apps
|
||||
with more than one leaf migration. Return a dict of the app labels
|
||||
that conflict with the migration names that conflict.
|
||||
"""
|
||||
seen_apps = {}
|
||||
conflicting_apps = set()
|
||||
for app_label, migration_name in self.graph.leaf_nodes():
|
||||
if app_label in seen_apps:
|
||||
conflicting_apps.add(app_label)
|
||||
seen_apps.setdefault(app_label, set()).add(migration_name)
|
||||
return {app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps}
|
||||
|
||||
def project_state(self, nodes=None, at_end=True):
|
||||
"""
|
||||
Return a ProjectState object representing the most recent state
|
||||
that the loaded migrations represent.
|
||||
|
||||
See graph.make_state() for the meaning of "nodes" and "at_end".
|
||||
"""
|
||||
return self.graph.make_state(nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps)
|
||||
|
||||
def collect_sql(self, plan):
|
||||
"""
|
||||
Take a migration plan and return a list of collected SQL statements
|
||||
that represent the best-efforts version of that plan.
|
||||
"""
|
||||
statements = []
|
||||
state = None
|
||||
for migration, backwards in plan:
|
||||
with self.connection.schema_editor(collect_sql=True, atomic=migration.atomic) as schema_editor:
|
||||
if state is None:
|
||||
state = self.project_state((migration.app_label, migration.name), at_end=False)
|
||||
if not backwards:
|
||||
state = migration.apply(state, schema_editor, collect_sql=True)
|
||||
else:
|
||||
state = migration.unapply(state, schema_editor, collect_sql=True)
|
||||
statements.extend(schema_editor.collected_sql)
|
||||
return statements
|
||||
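A short, hedged sketch of how the loader is normally driven (assumes a configured project; the target key is invented):

from django.db import connection
from django.db.migrations.loader import MigrationLoader

loader = MigrationLoader(connection)           # runs load_disk() and build_graph()
print(loader.graph)                            # e.g. "Graph: 12 nodes, 11 edges"
loader.check_consistent_history(connection)    # raises InconsistentMigrationHistory if broken
plan = loader.graph.forwards_plan(('shop', '0001_initial'))
state = loader.project_state()                 # ProjectState at the current leaf nodes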
218
venv/Lib/site-packages/django/db/migrations/migration.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from django.db.migrations.utils import get_migration_name_timestamp
|
||||
from django.db.transaction import atomic
|
||||
|
||||
from .exceptions import IrreversibleError
|
||||
|
||||
|
||||
class Migration:
"""
The base class for all migrations.

Migration files will import this from django.db.migrations.Migration
and subclass it as a class called Migration. It will have one or more
of the following attributes:

- operations: A list of Operation instances, probably from django.db.migrations.operations
- dependencies: A list of tuples of (app_path, migration_name)
- run_before: A list of tuples of (app_path, migration_name)
- replaces: A list of migration_names

Note that all migrations come out of migrations and into the Loader or
Graph as instances, having been initialized with their app label and name.
"""
|
||||
|
||||
# Operations to apply during this migration, in order.
|
||||
operations = []
|
||||
|
||||
# Other migrations that should be run before this migration.
|
||||
# Should be a list of (app, migration_name).
|
||||
dependencies = []
|
||||
|
||||
# Other migrations that should be run after this one (i.e. have
|
||||
# this migration added to their dependencies). Useful to make third-party
|
||||
# apps' migrations run after your AUTH_USER replacement, for example.
|
||||
run_before = []
|
||||
|
||||
# Migration names in this app that this migration replaces. If this is
|
||||
# non-empty, this migration will only be applied if all these migrations
|
||||
# are not applied.
|
||||
replaces = []
|
||||
|
||||
# Is this an initial migration? Initial migrations are skipped on
|
||||
# --fake-initial if the table or fields already exist. If None, check if
|
||||
# the migration has any dependencies to determine if there are dependencies
|
||||
# to tell if db introspection needs to be done. If True, always perform
|
||||
# introspection. If False, never perform introspection.
|
||||
initial = None
|
||||
|
||||
# Whether to wrap the whole migration in a transaction. Only has an effect
|
||||
# on database backends which support transactional DDL.
|
||||
atomic = True
|
||||
|
||||
def __init__(self, name, app_label):
|
||||
self.name = name
|
||||
self.app_label = app_label
|
||||
# Copy dependencies & other attrs as we might mutate them at runtime
|
||||
self.operations = list(self.__class__.operations)
|
||||
self.dependencies = list(self.__class__.dependencies)
|
||||
self.run_before = list(self.__class__.run_before)
|
||||
self.replaces = list(self.__class__.replaces)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, Migration) and
|
||||
self.name == other.name and
|
||||
self.app_label == other.app_label
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<Migration %s.%s>" % (self.app_label, self.name)
|
||||
|
||||
def __str__(self):
|
||||
return "%s.%s" % (self.app_label, self.name)
|
||||
|
||||
def __hash__(self):
|
||||
return hash("%s.%s" % (self.app_label, self.name))
|
||||
|
||||
def mutate_state(self, project_state, preserve=True):
|
||||
"""
|
||||
Take a ProjectState and return a new one with the migration's
|
||||
operations applied to it. Preserve the original object state by
|
||||
default and return a mutated state from a copy.
|
||||
"""
|
||||
new_state = project_state
|
||||
if preserve:
|
||||
new_state = project_state.clone()
|
||||
|
||||
for operation in self.operations:
|
||||
operation.state_forwards(self.app_label, new_state)
|
||||
return new_state
|
||||
|
||||
def apply(self, project_state, schema_editor, collect_sql=False):
|
||||
"""
|
||||
Take a project_state representing all migrations prior to this one
|
||||
and a schema_editor for a live database and apply the migration
|
||||
in a forwards order.
|
||||
|
||||
Return the resulting project state for efficient reuse by following
|
||||
Migrations.
|
||||
"""
|
||||
for operation in self.operations:
|
||||
# If this operation cannot be represented as SQL, place a comment
|
||||
# there instead
|
||||
if collect_sql:
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
schema_editor.collected_sql.append(
|
||||
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
|
||||
)
|
||||
schema_editor.collected_sql.append("-- %s" % operation.describe())
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
continue
|
||||
# Save the state before the operation has run
|
||||
old_state = project_state.clone()
|
||||
operation.state_forwards(self.app_label, project_state)
|
||||
# Run the operation
|
||||
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
|
||||
if not schema_editor.atomic_migration and atomic_operation:
|
||||
# Force a transaction on a non-transactional-DDL backend or an
|
||||
# atomic operation inside a non-atomic migration.
|
||||
with atomic(schema_editor.connection.alias):
|
||||
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
|
||||
else:
|
||||
# Normal behaviour
|
||||
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
|
||||
return project_state
|
||||
|
||||
def unapply(self, project_state, schema_editor, collect_sql=False):
|
||||
"""
|
||||
Take a project_state representing all migrations prior to this one
|
||||
and a schema_editor for a live database and apply the migration
|
||||
in a reverse order.
|
||||
|
||||
The backwards migration process consists of two phases:
|
||||
|
||||
1. The intermediate states from right before the first until right
|
||||
after the last operation inside this migration are preserved.
|
||||
2. The operations are applied in reverse order using the states
|
||||
recorded in step 1.
|
||||
"""
|
||||
# Construct all the intermediate states we need for a reverse migration
|
||||
to_run = []
|
||||
new_state = project_state
|
||||
# Phase 1
|
||||
for operation in self.operations:
|
||||
# If it's irreversible, error out
|
||||
if not operation.reversible:
|
||||
raise IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
|
||||
# Preserve new state from previous run to not tamper the same state
|
||||
# over all operations
|
||||
new_state = new_state.clone()
|
||||
old_state = new_state.clone()
|
||||
operation.state_forwards(self.app_label, new_state)
|
||||
to_run.insert(0, (operation, old_state, new_state))
|
||||
|
||||
# Phase 2
|
||||
for operation, to_state, from_state in to_run:
|
||||
if collect_sql:
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
schema_editor.collected_sql.append(
|
||||
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
|
||||
)
|
||||
schema_editor.collected_sql.append("-- %s" % operation.describe())
|
||||
schema_editor.collected_sql.append("--")
|
||||
if not operation.reduces_to_sql:
|
||||
continue
|
||||
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
|
||||
if not schema_editor.atomic_migration and atomic_operation:
|
||||
# Force a transaction on a non-transactional-DDL backend or an
|
||||
# atomic operation inside a non-atomic migration.
|
||||
with atomic(schema_editor.connection.alias):
|
||||
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
|
||||
else:
|
||||
# Normal behaviour
|
||||
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
|
||||
return project_state
|
||||
|
||||
def suggest_name(self):
|
||||
"""
|
||||
Suggest a name for the operations this migration might represent. Names
|
||||
are not guaranteed to be unique, but put some effort into the fallback
|
||||
name to avoid VCS conflicts if possible.
|
||||
"""
|
||||
if self.initial:
|
||||
return 'initial'
|
||||
|
||||
raw_fragments = [op.migration_name_fragment for op in self.operations]
|
||||
fragments = [name for name in raw_fragments if name]
|
||||
|
||||
if not fragments or len(fragments) != len(self.operations):
|
||||
return 'auto_%s' % get_migration_name_timestamp()
|
||||
|
||||
name = fragments[0]
|
||||
for fragment in fragments[1:]:
|
||||
new_name = f'{name}_{fragment}'
|
||||
if len(new_name) > 52:
|
||||
name = f'{name}_and_more'
|
||||
break
|
||||
name = new_name
|
||||
return name
|
||||
|
||||
|
||||
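Hypothetical outcomes of suggest_name() (model and field names invented):

# initial = True                                        -> 'initial'
# [CreateModel('Product', ...)]                         -> 'product'
# [AddField('product', 'sku', ...),
#  AddField('product', 'price', ...)]                   -> 'product_sku_product_price'
# a fragment chain that would exceed 52 characters      -> '<first fragment>_and_more'
# operations with no usable fragments (e.g. RunSQL)     -> 'auto_<timestamp>'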
class SwappableTuple(tuple):
|
||||
"""
|
||||
Subclass of tuple so Django can tell this was originally a swappable
|
||||
dependency when it reads the migration file.
|
||||
"""
|
||||
|
||||
def __new__(cls, value, setting):
|
||||
self = tuple.__new__(cls, value)
|
||||
self.setting = setting
|
||||
return self
|
||||
|
||||
|
||||
def swappable_dependency(value):
|
||||
"""Turn a setting value into a dependency."""
|
||||
return SwappableTuple((value.split(".", 1)[0], "__first__"), value)
|
||||
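For orientation, a minimal sketch of the kind of subclass a migration file defines (app, model, and field names are invented); swappable_dependency('auth.User'), for example, yields the tuple ('auth', '__first__') tagged with the setting value it came from:

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('shop', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='product',
            name='sku',
            field=models.CharField(default='', max_length=32),
        ),
    ]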
17
venv/Lib/site-packages/django/db/migrations/operations/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .fields import AddField, AlterField, RemoveField, RenameField
|
||||
from .models import (
|
||||
AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
|
||||
AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
|
||||
AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
|
||||
RemoveIndex, RenameModel,
|
||||
)
|
||||
from .special import RunPython, RunSQL, SeparateDatabaseAndState
|
||||
|
||||
__all__ = [
|
||||
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
|
||||
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
|
||||
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
|
||||
'AddConstraint', 'RemoveConstraint',
|
||||
'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
|
||||
'AlterOrderWithRespectTo', 'AlterModelManagers',
|
||||
]
|
||||
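These names are also re-exported through django.db.migrations, so migration files usually reach them via the migrations module; a small sketch (the forwards function is a made-up no-op):

from django.db import migrations


def forwards(apps, schema_editor):
    pass  # a data migration body would go here


op = migrations.RunPython(forwards, migrations.RunPython.noop)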
140
venv/Lib/site-packages/django/db/migrations/operations/base.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from django.db import router
|
||||
|
||||
|
||||
class Operation:
"""
Base class for migration operations.

It's responsible for both mutating the in-memory model state
(see db/migrations/state.py) to represent what it performs, as well
as actually performing it against a live database.

Note that some operations won't modify memory state at all (e.g. data
copying operations), and some will need their modifications to be
optionally specified by the user (e.g. custom Python code snippets)

Due to the way this class deals with deconstruction, it should be
considered immutable.
"""
|
||||
|
||||
# If this migration can be run in reverse.
|
||||
# Some operations are impossible to reverse, like deleting data.
|
||||
reversible = True
|
||||
|
||||
# Can this migration be represented as SQL? (things like RunPython cannot)
|
||||
reduces_to_sql = True
|
||||
|
||||
# Should this operation be forced as atomic even on backends with no
|
||||
# DDL transaction support (i.e., does it have no DDL, like RunPython)
|
||||
atomic = False
|
||||
|
||||
# Should this operation be considered safe to elide and optimize across?
|
||||
elidable = False
|
||||
|
||||
serialization_expand_args = []
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# We capture the arguments to make returning them trivial
|
||||
self = object.__new__(cls)
|
||||
self._constructor_args = (args, kwargs)
|
||||
return self
|
||||
|
||||
def deconstruct(self):
|
||||
"""
|
||||
Return a 3-tuple of class import path (or just name if it lives
|
||||
under django.db.migrations), positional arguments, and keyword
|
||||
arguments.
|
||||
"""
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
self._constructor_args[0],
|
||||
self._constructor_args[1],
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
"""
|
||||
Take the state from the previous migration, and mutate it
|
||||
so that it matches what this migration would perform.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of Operation must provide a state_forwards() method')
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
"""
|
||||
Perform the mutation on the database schema in the normal
|
||||
(forwards) direction.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of Operation must provide a database_forwards() method')
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
"""
|
||||
Perform the mutation on the database schema in the reverse
|
||||
direction - e.g. if this were CreateModel, it would in fact
|
||||
drop the model's table.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of Operation must provide a database_backwards() method')
|
||||
|
||||
def describe(self):
|
||||
"""
|
||||
Output a brief summary of what the action does.
|
||||
"""
|
||||
return "%s: %s" % (self.__class__.__name__, self._constructor_args)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
"""
|
||||
A filename part suitable for automatically naming a migration
|
||||
containing this operation, or None if not applicable.
|
||||
"""
|
||||
return None
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
model name (as a string), with an app label for accuracy.
|
||||
|
||||
Used for optimization. If in doubt, return True;
|
||||
returning a false positive will merely make the optimizer a little
|
||||
less efficient, while returning a false negative may result in an
|
||||
unusable optimized migration.
|
||||
"""
|
||||
return True
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
"""
|
||||
Return True if there is a chance this operation references the given
|
||||
field name, with an app label for accuracy.
|
||||
|
||||
Used for optimization. If in doubt, return True.
|
||||
"""
|
||||
return self.references_model(model_name, app_label)
|
||||
|
||||
def allow_migrate_model(self, connection_alias, model):
|
||||
"""
|
||||
Return whether or not a model may be migrated.
|
||||
|
||||
This is a thin wrapper around router.allow_migrate_model() that
|
||||
preemptively rejects any proxy, swapped out, or unmanaged model.
|
||||
"""
|
||||
if not model._meta.can_migrate(connection_alias):
|
||||
return False
|
||||
|
||||
return router.allow_migrate_model(connection_alias, model)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
"""
|
||||
Return either a list of operations the actual operation should be
|
||||
replaced with or a boolean that indicates whether or not the specified
|
||||
operation can be optimized across.
|
||||
"""
|
||||
if self.elidable:
|
||||
return [operation]
|
||||
elif operation.elidable:
|
||||
return [self]
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s %s%s>" % (
|
||||
self.__class__.__name__,
|
||||
", ".join(map(repr, self._constructor_args[0])),
|
||||
",".join(" %s=%r" % x for x in self._constructor_args[1].items()),
|
||||
)
|
||||
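To show how the hooks above fit together, a hedged sketch of a custom operation, modelled on the pattern Django's documentation uses for subclassing Operation (the class name and the PostgreSQL-specific SQL are illustrative):

from django.db.migrations.operations.base import Operation


class LoadExtension(Operation):
    """Create/drop a PostgreSQL extension; a purely database-level change."""

    reversible = True

    def __init__(self, name):
        self.name = name

    def state_forwards(self, app_label, state):
        pass  # no in-memory model state to mutate

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        schema_editor.execute('CREATE EXTENSION IF NOT EXISTS %s' % self.name)

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        schema_editor.execute('DROP EXTENSION IF EXISTS %s' % self.name)

    def describe(self):
        return 'Creates extension %s' % self.name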
341
venv/Lib/site-packages/django/db/migrations/operations/fields.py
Normal file
@@ -0,0 +1,341 @@
|
||||
from django.db.migrations.utils import field_references
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Operation
|
||||
|
||||
|
||||
class FieldOperation(Operation):
|
||||
def __init__(self, model_name, name, field=None):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
self.field = field
|
||||
|
||||
@cached_property
|
||||
def model_name_lower(self):
|
||||
return self.model_name.lower()
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
def is_same_model_operation(self, operation):
|
||||
return self.model_name_lower == operation.model_name_lower
|
||||
|
||||
def is_same_field_operation(self, operation):
|
||||
return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
name_lower = name.lower()
|
||||
if name_lower == self.model_name_lower:
|
||||
return True
|
||||
if self.field:
|
||||
return bool(field_references(
|
||||
(app_label, self.model_name_lower), self.field, (app_label, name_lower)
|
||||
))
|
||||
return False
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
model_name_lower = model_name.lower()
|
||||
# Check if this operation locally references the field.
|
||||
if model_name_lower == self.model_name_lower:
|
||||
if name == self.name:
|
||||
return True
|
||||
elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
|
||||
return True
|
||||
# Check if this operation remotely references the field.
|
||||
if self.field is None:
|
||||
return False
|
||||
return bool(field_references(
|
||||
(app_label, self.model_name_lower),
|
||||
self.field,
|
||||
(app_label, model_name_lower),
|
||||
name,
|
||||
))
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
return (
|
||||
super().reduce(operation, app_label) or
|
||||
not operation.references_field(self.model_name, self.name, app_label)
|
||||
)
|
||||
|
||||
|
||||
class AddField(FieldOperation):
|
||||
"""Add a field to a model."""
|
||||
|
||||
def __init__(self, model_name, name, field, preserve_default=True):
|
||||
self.preserve_default = preserve_default
|
||||
super().__init__(model_name, name, field)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
'field': self.field,
|
||||
}
|
||||
if self.preserve_default is not True:
|
||||
kwargs['preserve_default'] = self.preserve_default
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.add_field(
|
||||
app_label,
|
||||
self.model_name_lower,
|
||||
self.name,
|
||||
self.field,
|
||||
self.preserve_default,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
field = to_model._meta.get_field(self.name)
|
||||
if not self.preserve_default:
|
||||
field.default = self.field.default
|
||||
schema_editor.add_field(
|
||||
from_model,
|
||||
field,
|
||||
)
|
||||
if not self.preserve_default:
|
||||
field.default = NOT_PROVIDED
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
|
||||
|
||||
def describe(self):
|
||||
return "Add field %s to %s" % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return '%s_%s' % (self.model_name_lower, self.name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
|
||||
if isinstance(operation, AlterField):
|
||||
return [
|
||||
AddField(
|
||||
model_name=self.model_name,
|
||||
name=operation.name,
|
||||
field=operation.field,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RemoveField):
|
||||
return []
|
||||
elif isinstance(operation, RenameField):
|
||||
return [
|
||||
AddField(
|
||||
model_name=self.model_name,
|
||||
name=operation.new_name,
|
||||
field=self.field,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
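During squashing the optimizer calls reduce() on adjacent operations; hypothetical AddField reductions implied by the branches above (field objects elided):

# AddField('product', 'sku', f1) followed by AlterField('product', 'sku', f2)
#     -> [AddField('product', 'sku', f2)]
# AddField('product', 'sku', f1) followed by RemoveField('product', 'sku')
#     -> []
# AddField('product', 'sku', f1) followed by RenameField('product', 'sku', 'code')
#     -> [AddField('product', 'code', f1)]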
class RemoveField(FieldOperation):
|
||||
"""Remove a field from a model."""
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.remove_field(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
|
||||
|
||||
def describe(self):
|
||||
return "Remove field %s from %s" % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'remove_%s_%s' % (self.model_name_lower, self.name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
from .models import DeleteModel
|
||||
if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
|
||||
return [operation]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
class AlterField(FieldOperation):
|
||||
"""
|
||||
Alter a field's database column (e.g. null, max_length) to the provided
|
||||
new field.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name, name, field, preserve_default=True):
|
||||
self.preserve_default = preserve_default
|
||||
super().__init__(model_name, name, field)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
'field': self.field,
|
||||
}
|
||||
if self.preserve_default is not True:
|
||||
kwargs['preserve_default'] = self.preserve_default
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_field(
|
||||
app_label,
|
||||
self.model_name_lower,
|
||||
self.name,
|
||||
self.field,
|
||||
self.preserve_default,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
from_field = from_model._meta.get_field(self.name)
|
||||
to_field = to_model._meta.get_field(self.name)
|
||||
if not self.preserve_default:
|
||||
to_field.default = self.field.default
|
||||
schema_editor.alter_field(from_model, from_field, to_field)
|
||||
if not self.preserve_default:
|
||||
to_field.default = NOT_PROVIDED
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Alter field %s on %s" % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'alter_%s_%s' % (self.model_name_lower, self.name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
|
||||
return [operation]
|
||||
elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
|
||||
return [
|
||||
operation,
|
||||
AlterField(
|
||||
model_name=self.model_name,
|
||||
name=operation.new_name,
|
||||
field=self.field,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
class RenameField(FieldOperation):
|
||||
"""Rename a field on the model. Might affect db_column too."""
|
||||
|
||||
def __init__(self, model_name, old_name, new_name):
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
super().__init__(model_name, old_name)
|
||||
|
||||
@cached_property
|
||||
def old_name_lower(self):
|
||||
return self.old_name.lower()
|
||||
|
||||
@cached_property
|
||||
def new_name_lower(self):
|
||||
return self.new_name.lower()
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'old_name': self.old_name,
|
||||
'new_name': self.new_name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__name__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.alter_field(
|
||||
from_model,
|
||||
from_model._meta.get_field(self.old_name),
|
||||
to_model._meta.get_field(self.new_name),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.model_name)
|
||||
schema_editor.alter_field(
|
||||
from_model,
|
||||
from_model._meta.get_field(self.new_name),
|
||||
to_model._meta.get_field(self.old_name),
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'rename_%s_%s_%s' % (
|
||||
self.old_name_lower,
|
||||
self.model_name_lower,
|
||||
self.new_name_lower,
|
||||
)
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
return self.references_model(model_name, app_label) and (
|
||||
name.lower() == self.old_name_lower or
|
||||
name.lower() == self.new_name_lower
|
||||
)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if (isinstance(operation, RenameField) and
|
||||
self.is_same_model_operation(operation) and
|
||||
self.new_name_lower == operation.old_name_lower):
|
||||
return [
|
||||
RenameField(
|
||||
self.model_name,
|
||||
self.old_name,
|
||||
operation.new_name,
|
||||
),
|
||||
]
|
||||
# Skip `FieldOperation.reduce` as we want to run `references_field`
|
||||
# against self.old_name and self.new_name.
|
||||
return (
|
||||
super(FieldOperation, self).reduce(operation, app_label) or
|
||||
not (
|
||||
operation.references_field(self.model_name, self.old_name, app_label) or
|
||||
operation.references_field(self.model_name, self.new_name, app_label)
|
||||
)
|
||||
)
|
||||
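Likewise, the RenameField.reduce() above collapses a chained rename into one step (hypothetical names):

# RenameField('product', 'sku', 'code') followed by RenameField('product', 'code', 'ref')
#     -> [RenameField('product', 'sku', 'ref')]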
884
venv/Lib/site-packages/django/db/migrations/operations/models.py
Normal file
@@ -0,0 +1,884 @@
|
||||
from django.db import models
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.migrations.state import ModelState
|
||||
from django.db.migrations.utils import field_references, resolve_relation
|
||||
from django.db.models.options import normalize_together
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .fields import (
|
||||
AddField, AlterField, FieldOperation, RemoveField, RenameField,
|
||||
)
|
||||
|
||||
|
||||
def _check_for_duplicates(arg_name, objs):
|
||||
used_vals = set()
|
||||
for val in objs:
|
||||
if val in used_vals:
|
||||
raise ValueError(
|
||||
"Found duplicate value %s in CreateModel %s argument." % (val, arg_name)
|
||||
)
|
||||
used_vals.add(val)
|
||||
|
||||
|
||||
class ModelOperation(Operation):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
return name.lower() == self.name_lower
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
return (
|
||||
super().reduce(operation, app_label) or
|
||||
not operation.references_model(self.name, app_label)
|
||||
)
|
||||
|
||||
|
||||
class CreateModel(ModelOperation):
|
||||
"""Create a model's table."""
|
||||
|
||||
serialization_expand_args = ['fields', 'options', 'managers']
|
||||
|
||||
def __init__(self, name, fields, options=None, bases=None, managers=None):
|
||||
self.fields = fields
|
||||
self.options = options or {}
|
||||
self.bases = bases or (models.Model,)
|
||||
self.managers = managers or []
|
||||
super().__init__(name)
|
||||
# Sanity-check that there are no duplicated field names, bases, or
|
||||
# manager names
|
||||
_check_for_duplicates('fields', (name for name, _ in self.fields))
|
||||
_check_for_duplicates('bases', (
|
||||
base._meta.label_lower if hasattr(base, '_meta') else
|
||||
base.lower() if isinstance(base, str) else base
|
||||
for base in self.bases
|
||||
))
|
||||
_check_for_duplicates('managers', (name for name, _ in self.managers))
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'fields': self.fields,
|
||||
}
|
||||
if self.options:
|
||||
kwargs['options'] = self.options
|
||||
if self.bases and self.bases != (models.Model,):
|
||||
kwargs['bases'] = self.bases
|
||||
if self.managers and self.managers != [('objects', models.Manager())]:
|
||||
kwargs['managers'] = self.managers
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.add_model(ModelState(
|
||||
app_label,
|
||||
self.name,
|
||||
list(self.fields),
|
||||
dict(self.options),
|
||||
tuple(self.bases),
|
||||
list(self.managers),
|
||||
))
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.create_model(model)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.delete_model(model)
|
||||
|
||||
def describe(self):
|
||||
return "Create %smodel %s" % ("proxy " if self.options.get("proxy", False) else "", self.name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return self.name_lower
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
name_lower = name.lower()
|
||||
if name_lower == self.name_lower:
|
||||
return True
|
||||
|
||||
# Check we didn't inherit from the model
|
||||
reference_model_tuple = (app_label, name_lower)
|
||||
for base in self.bases:
|
||||
if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
|
||||
resolve_relation(base, app_label) == reference_model_tuple):
|
||||
return True
|
||||
|
||||
# Check we have no FKs/M2Ms with it
|
||||
for _name, field in self.fields:
|
||||
if field_references((app_label, self.name_lower), field, reference_model_tuple):
|
||||
return True
|
||||
return False
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if (isinstance(operation, DeleteModel) and
|
||||
self.name_lower == operation.name_lower and
|
||||
not self.options.get("proxy", False)):
|
||||
return []
|
||||
elif isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
operation.new_name,
|
||||
fields=self.fields,
|
||||
options=self.options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower:
|
||||
options = {**self.options, **operation.options}
|
||||
for key in operation.ALTER_OPTION_KEYS:
|
||||
if key not in operation.options:
|
||||
options.pop(key, None)
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields,
|
||||
options=options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterTogetherOptionOperation) and self.name_lower == operation.name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields,
|
||||
options={**self.options, **{operation.option_name: operation.option_value}},
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterOrderWithRespectTo) and self.name_lower == operation.name_lower:
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields,
|
||||
options={**self.options, 'order_with_respect_to': operation.order_with_respect_to},
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower:
|
||||
if isinstance(operation, AddField):
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=self.fields + [(operation.name, operation.field)],
|
||||
options=self.options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, AlterField):
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=[
|
||||
(n, operation.field if n == operation.name else v)
|
||||
for n, v in self.fields
|
||||
],
|
||||
options=self.options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RemoveField):
|
||||
options = self.options.copy()
|
||||
for option_name in ('unique_together', 'index_together'):
|
||||
option = options.pop(option_name, None)
|
||||
if option:
|
||||
option = set(filter(bool, (
|
||||
tuple(f for f in fields if f != operation.name_lower) for fields in option
|
||||
)))
|
||||
if option:
|
||||
options[option_name] = option
|
||||
order_with_respect_to = options.get('order_with_respect_to')
|
||||
if order_with_respect_to == operation.name_lower:
|
||||
del options['order_with_respect_to']
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=[
|
||||
(n, v)
|
||||
for n, v in self.fields
|
||||
if n.lower() != operation.name_lower
|
||||
],
|
||||
options=options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
elif isinstance(operation, RenameField):
|
||||
options = self.options.copy()
|
||||
for option_name in ('unique_together', 'index_together'):
|
||||
option = options.get(option_name)
|
||||
if option:
|
||||
options[option_name] = {
|
||||
tuple(operation.new_name if f == operation.old_name else f for f in fields)
|
||||
for fields in option
|
||||
}
|
||||
order_with_respect_to = options.get('order_with_respect_to')
|
||||
if order_with_respect_to == operation.old_name:
|
||||
options['order_with_respect_to'] = operation.new_name
|
||||
return [
|
||||
CreateModel(
|
||||
self.name,
|
||||
fields=[
|
||||
(operation.new_name if n == operation.old_name else n, v)
|
||||
for n, v in self.fields
|
||||
],
|
||||
options=options,
|
||||
bases=self.bases,
|
||||
managers=self.managers,
|
||||
),
|
||||
]
|
||||
return super().reduce(operation, app_label)
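# --- Illustrative sketch, not part of Django's source; app/model/field names are hypothetical. ---
# reduce() is what the migration optimizer calls: a CreateModel followed by an
# AddField on the same model collapses into a single CreateModel carrying both fields.
from django.db import migrations, models

create = migrations.CreateModel('Author', fields=[('name', models.CharField(max_length=100))])
add_age = migrations.AddField('Author', 'age', models.IntegerField(null=True))
[merged] = create.reduce(add_age, app_label='library')
assert [name for name, _ in merged.fields] == ['name', 'age']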
class DeleteModel(ModelOperation):
|
||||
"""Drop a model's table."""
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.remove_model(app_label, self.name_lower)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.delete_model(model)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.create_model(model)
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
# The deleted model could be referencing the specified model through
|
||||
# related fields.
|
||||
return True
|
||||
|
||||
def describe(self):
|
||||
return "Delete model %s" % self.name
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'delete_%s' % self.name_lower
|
||||
|
||||
|
||||
class RenameModel(ModelOperation):
|
||||
"""Rename a model."""
|
||||
|
||||
def __init__(self, old_name, new_name):
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
super().__init__(old_name)
|
||||
|
||||
@cached_property
|
||||
def old_name_lower(self):
|
||||
return self.old_name.lower()
|
||||
|
||||
@cached_property
|
||||
def new_name_lower(self):
|
||||
return self.new_name.lower()
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'old_name': self.old_name,
|
||||
'new_name': self.new_name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.rename_model(app_label, self.old_name, self.new_name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.new_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
|
||||
old_model = from_state.apps.get_model(app_label, self.old_name)
|
||||
# Move the main table
|
||||
schema_editor.alter_db_table(
|
||||
new_model,
|
||||
old_model._meta.db_table,
|
||||
new_model._meta.db_table,
|
||||
)
|
||||
# Alter the fields pointing to us
|
||||
for related_object in old_model._meta.related_objects:
|
||||
if related_object.related_model == old_model:
|
||||
model = new_model
|
||||
related_key = (app_label, self.new_name_lower)
|
||||
else:
|
||||
model = related_object.related_model
|
||||
related_key = (
|
||||
related_object.related_model._meta.app_label,
|
||||
related_object.related_model._meta.model_name,
|
||||
)
|
||||
to_field = to_state.apps.get_model(
|
||||
*related_key
|
||||
)._meta.get_field(related_object.field.name)
|
||||
schema_editor.alter_field(
|
||||
model,
|
||||
related_object.field,
|
||||
to_field,
|
||||
)
|
||||
# Rename M2M fields whose name is based on this model's name.
|
||||
fields = zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many)
|
||||
for (old_field, new_field) in fields:
|
||||
# Skip self-referential fields as these are renamed above.
|
||||
if new_field.model == new_field.related_model or not new_field.remote_field.through._meta.auto_created:
|
||||
continue
|
||||
# Rename the M2M table that's based on this model's name.
|
||||
old_m2m_model = old_field.remote_field.through
|
||||
new_m2m_model = new_field.remote_field.through
|
||||
schema_editor.alter_db_table(
|
||||
new_m2m_model,
|
||||
old_m2m_model._meta.db_table,
|
||||
new_m2m_model._meta.db_table,
|
||||
)
|
||||
# Rename the column in the M2M table that's based on this
|
||||
# model's name.
|
||||
schema_editor.alter_field(
|
||||
new_m2m_model,
|
||||
old_m2m_model._meta.get_field(old_model._meta.model_name),
|
||||
new_m2m_model._meta.get_field(new_model._meta.model_name),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
|
||||
self.new_name, self.old_name = self.old_name, self.new_name
|
||||
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
|
||||
self.new_name, self.old_name = self.old_name, self.new_name
|
||||
|
||||
def references_model(self, name, app_label):
|
||||
return (
|
||||
name.lower() == self.old_name_lower or
|
||||
name.lower() == self.new_name_lower
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Rename model %s to %s" % (self.old_name, self.new_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'rename_%s_%s' % (self.old_name_lower, self.new_name_lower)
|
||||
|
||||
def reduce(self, operation, app_label):
|
||||
if (isinstance(operation, RenameModel) and
|
||||
self.new_name_lower == operation.old_name_lower):
|
||||
return [
|
||||
RenameModel(
|
||||
self.old_name,
|
||||
operation.new_name,
|
||||
),
|
||||
]
|
||||
# Skip `ModelOperation.reduce` as we want to run `references_model`
|
||||
# against self.new_name.
|
||||
return (
|
||||
super(ModelOperation, self).reduce(operation, app_label) or
|
||||
not operation.references_model(self.new_name, app_label)
|
||||
)
|
||||
|
||||
|
||||
class ModelOptionOperation(ModelOperation):
|
||||
def reduce(self, operation, app_label):
|
||||
if isinstance(operation, (self.__class__, DeleteModel)) and self.name_lower == operation.name_lower:
|
||||
return [operation]
|
||||
return super().reduce(operation, app_label)
|
||||
|
||||
|
||||
class AlterModelTable(ModelOptionOperation):
|
||||
"""Rename a model's table."""
|
||||
|
||||
def __init__(self, name, table):
|
||||
self.table = table
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'table': self.table,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_model_options(app_label, self.name_lower, {'db_table': self.table})
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
|
||||
old_model = from_state.apps.get_model(app_label, self.name)
|
||||
schema_editor.alter_db_table(
|
||||
new_model,
|
||||
old_model._meta.db_table,
|
||||
new_model._meta.db_table,
|
||||
)
|
||||
# Rename M2M fields whose name is based on this model's db_table
|
||||
for (old_field, new_field) in zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many):
|
||||
if new_field.remote_field.through._meta.auto_created:
|
||||
schema_editor.alter_db_table(
|
||||
new_field.remote_field.through,
|
||||
old_field.remote_field.through._meta.db_table,
|
||||
new_field.remote_field.through._meta.db_table,
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
return self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Rename table for %s to %s" % (
|
||||
self.name,
|
||||
self.table if self.table is not None else "(default)"
|
||||
)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'alter_%s_table' % self.name_lower
|
||||
|
||||
|
||||
class AlterTogetherOptionOperation(ModelOptionOperation):
|
||||
option_name = None
|
||||
|
||||
def __init__(self, name, option_value):
|
||||
if option_value:
|
||||
option_value = set(normalize_together(option_value))
|
||||
setattr(self, self.option_name, option_value)
|
||||
super().__init__(name)
|
||||
|
||||
@cached_property
|
||||
def option_value(self):
|
||||
return getattr(self, self.option_name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
self.option_name: self.option_value,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_model_options(
|
||||
app_label,
|
||||
self.name_lower,
|
||||
{self.option_name: self.option_value},
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
new_model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
|
||||
old_model = from_state.apps.get_model(app_label, self.name)
|
||||
alter_together = getattr(schema_editor, 'alter_%s' % self.option_name)
|
||||
alter_together(
|
||||
new_model,
|
||||
getattr(old_model._meta, self.option_name, set()),
|
||||
getattr(new_model._meta, self.option_name, set()),
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
return self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
return (
|
||||
self.references_model(model_name, app_label) and
|
||||
(
|
||||
not self.option_value or
|
||||
any((name in fields) for fields in self.option_value)
|
||||
)
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Alter %s for %s (%s constraint(s))" % (self.option_name, self.name, len(self.option_value or ''))
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'alter_%s_%s' % (self.name_lower, self.option_name)
|
||||
|
||||
|
||||
class AlterUniqueTogether(AlterTogetherOptionOperation):
|
||||
"""
|
||||
Change the value of unique_together to the target one.
|
||||
Input value of unique_together must be a set of tuples.
|
||||
"""
|
||||
option_name = 'unique_together'
|
||||
|
||||
def __init__(self, name, unique_together):
|
||||
super().__init__(name, unique_together)
|
||||
|
||||
|
||||
class AlterIndexTogether(AlterTogetherOptionOperation):
|
||||
"""
|
||||
Change the value of index_together to the target one.
|
||||
Input value of index_together must be a set of tuples.
|
||||
"""
|
||||
option_name = "index_together"
|
||||
|
||||
def __init__(self, name, index_together):
|
||||
super().__init__(name, index_together)
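# --- Illustrative sketch, not part of Django's source; model/field names are hypothetical. ---
# Both "together" operations take the model name and the full target value as an
# iterable of field-name tuples; normalize_together() turns it into a set of tuples.
from django.db import migrations

alter_unique = migrations.AlterUniqueTogether(
    name='book',
    unique_together={('author', 'title')},
)
alter_index = migrations.AlterIndexTogether(
    name='book',
    index_together={('publisher', 'published_on')},
)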
class AlterOrderWithRespectTo(ModelOptionOperation):
|
||||
"""Represent a change with the order_with_respect_to option."""
|
||||
|
||||
option_name = 'order_with_respect_to'
|
||||
|
||||
def __init__(self, name, order_with_respect_to):
|
||||
self.order_with_respect_to = order_with_respect_to
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'order_with_respect_to': self.order_with_respect_to,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_model_options(
|
||||
app_label,
|
||||
self.name_lower,
|
||||
{self.option_name: self.order_with_respect_to},
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
to_model = to_state.apps.get_model(app_label, self.name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
|
||||
from_model = from_state.apps.get_model(app_label, self.name)
|
||||
# Remove a field if we need to
|
||||
if from_model._meta.order_with_respect_to and not to_model._meta.order_with_respect_to:
|
||||
schema_editor.remove_field(from_model, from_model._meta.get_field("_order"))
|
||||
# Add a field if we need to (altering the column is untouched as
|
||||
# it's likely a rename)
|
||||
elif to_model._meta.order_with_respect_to and not from_model._meta.order_with_respect_to:
|
||||
field = to_model._meta.get_field("_order")
|
||||
if not field.has_default():
|
||||
field.default = 0
|
||||
schema_editor.add_field(
|
||||
from_model,
|
||||
field,
|
||||
)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
self.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def references_field(self, model_name, name, app_label):
|
||||
return (
|
||||
self.references_model(model_name, app_label) and
|
||||
(
|
||||
self.order_with_respect_to is None or
|
||||
name == self.order_with_respect_to
|
||||
)
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return "Set order_with_respect_to on %s to %s" % (self.name, self.order_with_respect_to)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'alter_%s_order_with_respect_to' % self.name_lower
|
||||
|
||||
|
||||
class AlterModelOptions(ModelOptionOperation):
|
||||
"""
|
||||
Set new model options that don't directly affect the database schema
|
||||
(like verbose_name, permissions, ordering). Python code in migrations
|
||||
may still need them.
|
||||
"""
|
||||
|
||||
# Model options we want to compare and preserve in an AlterModelOptions op
|
||||
ALTER_OPTION_KEYS = [
|
||||
"base_manager_name",
|
||||
"default_manager_name",
|
||||
"default_related_name",
|
||||
"get_latest_by",
|
||||
"managed",
|
||||
"ordering",
|
||||
"permissions",
|
||||
"default_permissions",
|
||||
"select_on_save",
|
||||
"verbose_name",
|
||||
"verbose_name_plural",
|
||||
]
|
||||
|
||||
def __init__(self, name, options):
|
||||
self.options = options
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'name': self.name,
|
||||
'options': self.options,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_model_options(
|
||||
app_label,
|
||||
self.name_lower,
|
||||
self.options,
|
||||
self.ALTER_OPTION_KEYS,
|
||||
)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def describe(self):
|
||||
return "Change Meta options on %s" % self.name
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'alter_%s_options' % self.name_lower
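# --- Illustrative sketch, not part of Django's source; model and option values are hypothetical. ---
# AlterModelOptions only touches migration state (database_forwards() above is a
# no-op); only keys listed in ALTER_OPTION_KEYS are preserved across reductions.
from django.db import migrations

alter_options = migrations.AlterModelOptions(
    name='book',
    options={'ordering': ['title'], 'verbose_name': 'book entry'},
)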
class AlterModelManagers(ModelOptionOperation):
|
||||
"""Alter the model's managers."""
|
||||
|
||||
serialization_expand_args = ['managers']
|
||||
|
||||
def __init__(self, name, managers):
|
||||
self.managers = managers
|
||||
super().__init__(name)
|
||||
|
||||
def deconstruct(self):
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[self.name, self.managers],
|
||||
{}
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.alter_model_managers(app_label, self.name_lower, self.managers)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
pass
|
||||
|
||||
def describe(self):
|
||||
return "Change managers on %s" % self.name
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'alter_%s_managers' % self.name_lower
|
||||
|
||||
|
||||
class IndexOperation(Operation):
|
||||
option_name = 'indexes'
|
||||
|
||||
@cached_property
|
||||
def model_name_lower(self):
|
||||
return self.model_name.lower()
|
||||
|
||||
|
||||
class AddIndex(IndexOperation):
|
||||
"""Add an index on a model."""
|
||||
|
||||
def __init__(self, model_name, index):
|
||||
self.model_name = model_name
|
||||
if not index.name:
|
||||
raise ValueError(
|
||||
"Indexes passed to AddIndex operations require a name "
|
||||
"argument. %r doesn't have one." % index
|
||||
)
|
||||
self.index = index
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.add_index(app_label, self.model_name_lower, self.index)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.add_index(model, self.index)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.remove_index(model, self.index)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'index': self.index,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
if self.index.expressions:
|
||||
return 'Create index %s on %s on model %s' % (
|
||||
self.index.name,
|
||||
', '.join([str(expression) for expression in self.index.expressions]),
|
||||
self.model_name,
|
||||
)
|
||||
return 'Create index %s on field(s) %s of model %s' % (
|
||||
self.index.name,
|
||||
', '.join(self.index.fields),
|
||||
self.model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return '%s_%s' % (self.model_name_lower, self.index.name.lower())
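# --- Illustrative sketch, not part of Django's source; model/field/index names are hypothetical. ---
# The Index passed to AddIndex must carry an explicit name, otherwise the
# constructor above raises ValueError.
from django.db import migrations, models

add_index = migrations.AddIndex(
    model_name='book',
    index=models.Index(fields=['title'], name='book_title_idx'),
)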
class RemoveIndex(IndexOperation):
|
||||
"""Remove an index from a model."""
|
||||
|
||||
def __init__(self, model_name, name):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.remove_index(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = from_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
from_model_state = from_state.models[app_label, self.model_name_lower]
|
||||
index = from_model_state.get_index_by_name(self.name)
|
||||
schema_editor.remove_index(model, index)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
to_model_state = to_state.models[app_label, self.model_name_lower]
|
||||
index = to_model_state.get_index_by_name(self.name)
|
||||
schema_editor.add_index(model, index)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs,
|
||||
)
|
||||
|
||||
def describe(self):
|
||||
return 'Remove index %s from %s' % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
|
||||
|
||||
|
||||
class AddConstraint(IndexOperation):
|
||||
option_name = 'constraints'
|
||||
|
||||
def __init__(self, model_name, constraint):
|
||||
self.model_name = model_name
|
||||
self.constraint = constraint
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.add_constraint(app_label, self.model_name_lower, self.constraint)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.add_constraint(model, self.constraint)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
schema_editor.remove_constraint(model, self.constraint)
|
||||
|
||||
def deconstruct(self):
|
||||
return self.__class__.__name__, [], {
|
||||
'model_name': self.model_name,
|
||||
'constraint': self.constraint,
|
||||
}
|
||||
|
||||
def describe(self):
|
||||
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return '%s_%s' % (self.model_name_lower, self.constraint.name.lower())
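# --- Illustrative sketch, not part of Django's source; model/field/constraint names are hypothetical. ---
# AddConstraint works with any BaseConstraint subclass, e.g. a named CheckConstraint.
from django.db import migrations, models

add_constraint = migrations.AddConstraint(
    model_name='ticket',
    constraint=models.CheckConstraint(check=models.Q(price__gte=0), name='ticket_price_gte_0'),
)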
class RemoveConstraint(IndexOperation):
|
||||
option_name = 'constraints'
|
||||
|
||||
def __init__(self, model_name, name):
|
||||
self.model_name = model_name
|
||||
self.name = name
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
state.remove_constraint(app_label, self.model_name_lower, self.name)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
from_model_state = from_state.models[app_label, self.model_name_lower]
|
||||
constraint = from_model_state.get_constraint_by_name(self.name)
|
||||
schema_editor.remove_constraint(model, constraint)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
model = to_state.apps.get_model(app_label, self.model_name)
|
||||
if self.allow_migrate_model(schema_editor.connection.alias, model):
|
||||
to_model_state = to_state.models[app_label, self.model_name_lower]
|
||||
constraint = to_model_state.get_constraint_by_name(self.name)
|
||||
schema_editor.add_constraint(model, constraint)
|
||||
|
||||
def deconstruct(self):
|
||||
return self.__class__.__name__, [], {
|
||||
'model_name': self.model_name,
|
||||
'name': self.name,
|
||||
}
|
||||
|
||||
def describe(self):
|
||||
return 'Remove constraint %s from model %s' % (self.name, self.model_name)
|
||||
|
||||
@property
|
||||
def migration_name_fragment(self):
|
||||
return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
|
||||
@@ -0,0 +1,203 @@
|
||||
from django.db import router
|
||||
|
||||
from .base import Operation
|
||||
|
||||
|
||||
class SeparateDatabaseAndState(Operation):
|
||||
"""
|
||||
Take two lists of operations - ones that will be used for the database,
|
||||
and ones that will be used for the state change. This allows operations
|
||||
that don't support state change to have it applied, or have operations
|
||||
that affect the state or not the database, or so on.
|
||||
"""
|
||||
|
||||
serialization_expand_args = ['database_operations', 'state_operations']
|
||||
|
||||
def __init__(self, database_operations=None, state_operations=None):
|
||||
self.database_operations = database_operations or []
|
||||
self.state_operations = state_operations or []
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {}
|
||||
if self.database_operations:
|
||||
kwargs['database_operations'] = self.database_operations
|
||||
if self.state_operations:
|
||||
kwargs['state_operations'] = self.state_operations
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
for state_operation in self.state_operations:
|
||||
state_operation.state_forwards(app_label, state)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# We calculate state separately in here since our state functions aren't useful
|
||||
for database_operation in self.database_operations:
|
||||
to_state = from_state.clone()
|
||||
database_operation.state_forwards(app_label, to_state)
|
||||
database_operation.database_forwards(app_label, schema_editor, from_state, to_state)
|
||||
from_state = to_state
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# We calculate state separately in here since our state functions aren't useful
|
||||
to_states = {}
|
||||
for dbop in self.database_operations:
|
||||
to_states[dbop] = to_state
|
||||
to_state = to_state.clone()
|
||||
dbop.state_forwards(app_label, to_state)
|
||||
# to_state now has the states of all the database_operations applied
|
||||
# which is the from_state for the backwards migration of the last
|
||||
# operation.
|
||||
for database_operation in reversed(self.database_operations):
|
||||
from_state = to_state
|
||||
to_state = to_states[database_operation]
|
||||
database_operation.database_backwards(app_label, schema_editor, from_state, to_state)
|
||||
|
||||
def describe(self):
|
||||
return "Custom state/database change combination"
class RunSQL(Operation):
|
||||
"""
|
||||
Run some raw SQL. A reverse SQL statement may be provided.
|
||||
|
||||
Also accept a list of operations that represent the state change effected
|
||||
by this SQL change, in case it's custom column/table creation/deletion.
|
||||
"""
|
||||
noop = ''
|
||||
|
||||
def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False):
|
||||
self.sql = sql
|
||||
self.reverse_sql = reverse_sql
|
||||
self.state_operations = state_operations or []
|
||||
self.hints = hints or {}
|
||||
self.elidable = elidable
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'sql': self.sql,
|
||||
}
|
||||
if self.reverse_sql is not None:
|
||||
kwargs['reverse_sql'] = self.reverse_sql
|
||||
if self.state_operations:
|
||||
kwargs['state_operations'] = self.state_operations
|
||||
if self.hints:
|
||||
kwargs['hints'] = self.hints
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def reversible(self):
|
||||
return self.reverse_sql is not None
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
for state_operation in self.state_operations:
|
||||
state_operation.state_forwards(app_label, state)
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
self._run_sql(schema_editor, self.sql)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if self.reverse_sql is None:
|
||||
raise NotImplementedError("You cannot reverse this operation")
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
self._run_sql(schema_editor, self.reverse_sql)
|
||||
|
||||
def describe(self):
|
||||
return "Raw SQL operation"
|
||||
|
||||
def _run_sql(self, schema_editor, sqls):
|
||||
if isinstance(sqls, (list, tuple)):
|
||||
for sql in sqls:
|
||||
params = None
|
||||
if isinstance(sql, (list, tuple)):
|
||||
elements = len(sql)
|
||||
if elements == 2:
|
||||
sql, params = sql
|
||||
else:
|
||||
raise ValueError("Expected a 2-tuple but got %d" % elements)
|
||||
schema_editor.execute(sql, params=params)
|
||||
elif sqls != RunSQL.noop:
|
||||
statements = schema_editor.connection.ops.prepare_sql_script(sqls)
|
||||
for statement in statements:
|
||||
schema_editor.execute(statement, params=None)
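# --- Illustrative sketch, not part of Django's source; table/column names are hypothetical. ---
# sql may be a plain string or a list whose items are strings or (sql, params)
# 2-tuples, which is exactly what _run_sql() above unpacks.
from django.db import migrations

run_sql = migrations.RunSQL(
    sql=[
        ("INSERT INTO musician (name) VALUES (%s);", ['Reinhardt']),
        "UPDATE musician SET active = true;",
    ],
    reverse_sql=[("DELETE FROM musician WHERE name = %s;", ['Reinhardt'])],
)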
class RunPython(Operation):
|
||||
"""
|
||||
Run Python code in a context suitable for doing versioned ORM operations.
|
||||
"""
|
||||
|
||||
reduces_to_sql = False
|
||||
|
||||
def __init__(self, code, reverse_code=None, atomic=None, hints=None, elidable=False):
|
||||
self.atomic = atomic
|
||||
# Forwards code
|
||||
if not callable(code):
|
||||
raise ValueError("RunPython must be supplied with a callable")
|
||||
self.code = code
|
||||
# Reverse code
|
||||
if reverse_code is None:
|
||||
self.reverse_code = None
|
||||
else:
|
||||
if not callable(reverse_code):
|
||||
raise ValueError("RunPython must be supplied with callable arguments")
|
||||
self.reverse_code = reverse_code
|
||||
self.hints = hints or {}
|
||||
self.elidable = elidable
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = {
|
||||
'code': self.code,
|
||||
}
|
||||
if self.reverse_code is not None:
|
||||
kwargs['reverse_code'] = self.reverse_code
|
||||
if self.atomic is not None:
|
||||
kwargs['atomic'] = self.atomic
|
||||
if self.hints:
|
||||
kwargs['hints'] = self.hints
|
||||
return (
|
||||
self.__class__.__qualname__,
|
||||
[],
|
||||
kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def reversible(self):
|
||||
return self.reverse_code is not None
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
# RunPython objects have no state effect. To add some, combine this
|
||||
# with SeparateDatabaseAndState.
|
||||
pass
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
# RunPython has access to all models. Ensure that all models are
|
||||
# reloaded in case any are delayed.
|
||||
from_state.clear_delayed_apps_cache()
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
# We now execute the Python code in a context that contains a 'models'
|
||||
# object, representing the versioned models as an app registry.
|
||||
# We could try to override the global cache, but then people will still
|
||||
# use direct imports, so we go with a documentation approach instead.
|
||||
self.code(from_state.apps, schema_editor)
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if self.reverse_code is None:
|
||||
raise NotImplementedError("You cannot reverse this operation")
|
||||
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
|
||||
self.reverse_code(from_state.apps, schema_editor)
|
||||
|
||||
def describe(self):
|
||||
return "Raw Python operation"
|
||||
|
||||
@staticmethod
|
||||
def noop(apps, schema_editor):
|
||||
return None
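# --- Illustrative sketch, not part of Django's source; app/model/field names are hypothetical. ---
# The callables receive the historical app registry and the schema editor; models
# must be fetched via apps.get_model() so the versioned state is used.
from django.db import migrations


def forwards(apps, schema_editor):
    Country = apps.get_model('geo', 'Country')
    Country.objects.filter(code='').update(code='XX')


def backwards(apps, schema_editor):
    Country = apps.get_model('geo', 'Country')
    Country.objects.filter(code='XX').update(code='')


run_python = migrations.RunPython(forwards, reverse_code=backwards, elidable=False)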
69
venv/Lib/site-packages/django/db/migrations/optimizer.py
Normal file
69
venv/Lib/site-packages/django/db/migrations/optimizer.py
Normal file
@@ -0,0 +1,69 @@
class MigrationOptimizer:
    """
    Power the optimization process, where you provide a list of Operations
    and you are returned a list of equal or shorter length - operations
    are merged into one if possible.

    For example, a CreateModel and an AddField can be optimized into a
    new CreateModel, and CreateModel and DeleteModel can be optimized into
    nothing.
    """

    def optimize(self, operations, app_label):
        """
        Main optimization entry point. Pass in a list of Operation instances,
        get out a new list of Operation instances.

        Unfortunately, due to the scope of the optimization (two combinable
        operations might be separated by several hundred others), this can't be
        done as a peephole optimization with checks/output implemented on
        the Operations themselves; instead, the optimizer looks at each
        individual operation and scans forwards in the list to see if there
        are any matches, stopping at boundaries - operations which can't
        be optimized over (RunSQL, operations on the same field/model, etc.)

        The inner loop is run until the starting list is the same as the result
        list, and then the result is returned. This means that operation
        optimization must be stable and always return an equal or shorter list.
        """
        # Internal tracking variable for test assertions about # of loops
        if app_label is None:
            raise TypeError('app_label must be a str.')
        self._iterations = 0
        while True:
            result = self.optimize_inner(operations, app_label)
            self._iterations += 1
            if result == operations:
                return result
            operations = result

    def optimize_inner(self, operations, app_label):
        """Inner optimization loop."""
        new_operations = []
        for i, operation in enumerate(operations):
            right = True  # Should we reduce on the right or on the left.
            # Compare it to each operation after it
            for j, other in enumerate(operations[i + 1:]):
                result = operation.reduce(other, app_label)
                if isinstance(result, list):
                    in_between = operations[i + 1:i + j + 1]
                    if right:
                        new_operations.extend(in_between)
                        new_operations.extend(result)
                    elif all(op.reduce(other, app_label) is True for op in in_between):
                        # Perform a left reduction if all of the in-between
                        # operations can optimize through other.
                        new_operations.extend(result)
                        new_operations.extend(in_between)
                    else:
                        # Otherwise keep trying.
                        new_operations.append(operation)
                        break
                    new_operations.extend(operations[i + j + 2:])
                    return new_operations
                elif not result:
                    # Can't perform a right reduction.
                    right = False
            else:
                new_operations.append(operation)
        return new_operations
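# --- Illustrative sketch, not part of Django's source; app/model names are hypothetical. ---
# A CreateModel immediately followed by a DeleteModel of the same model reduces
# to an empty list, and optimize() loops until the operation list is stable.
from django.db import migrations, models
from django.db.migrations.optimizer import MigrationOptimizer

ops = [
    migrations.CreateModel('Draft', fields=[('id', models.AutoField(primary_key=True))]),
    migrations.DeleteModel('Draft'),
]
assert MigrationOptimizer().optimize(ops, app_label='cms') == []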
245
venv/Lib/site-packages/django/db/migrations/questioner.py
Normal file
245
venv/Lib/site-packages/django/db/migrations/questioner.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
from django.apps import apps
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.utils import timezone
|
||||
|
||||
from .loader import MigrationLoader
|
||||
|
||||
|
||||
class MigrationQuestioner:
|
||||
"""
|
||||
Give the autodetector responses to questions it might have.
|
||||
This base class has a built-in noninteractive mode, but the
|
||||
interactive subclass is what the command-line arguments will use.
|
||||
"""
|
||||
|
||||
def __init__(self, defaults=None, specified_apps=None, dry_run=None):
|
||||
self.defaults = defaults or {}
|
||||
self.specified_apps = specified_apps or set()
|
||||
self.dry_run = dry_run
|
||||
|
||||
def ask_initial(self, app_label):
|
||||
"""Should we create an initial migration for the app?"""
|
||||
# If it was specified on the command line, definitely true
|
||||
if app_label in self.specified_apps:
|
||||
return True
|
||||
# Otherwise, we look to see if it has a migrations module
|
||||
# without any Python files in it, apart from __init__.py.
|
||||
# Apps from the new app template will have these; the Python
|
||||
# file check will ensure we skip South ones.
|
||||
try:
|
||||
app_config = apps.get_app_config(app_label)
|
||||
except LookupError: # It's a fake app.
|
||||
return self.defaults.get("ask_initial", False)
|
||||
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
|
||||
if migrations_import_path is None:
|
||||
# It's an application with migrations disabled.
|
||||
return self.defaults.get("ask_initial", False)
|
||||
try:
|
||||
migrations_module = importlib.import_module(migrations_import_path)
|
||||
except ImportError:
|
||||
return self.defaults.get("ask_initial", False)
|
||||
else:
|
||||
if getattr(migrations_module, "__file__", None):
|
||||
filenames = os.listdir(os.path.dirname(migrations_module.__file__))
|
||||
elif hasattr(migrations_module, "__path__"):
|
||||
if len(migrations_module.__path__) > 1:
|
||||
return False
|
||||
filenames = os.listdir(list(migrations_module.__path__)[0])
|
||||
return not any(x.endswith(".py") for x in filenames if x != "__init__.py")
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
"""Adding a NOT NULL field to a model."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
"""Changing a NULL field to NOT NULL."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
def ask_rename(self, model_name, old_name, new_name, field_instance):
|
||||
"""Was this field really renamed?"""
|
||||
return self.defaults.get("ask_rename", False)
|
||||
|
||||
def ask_rename_model(self, old_model_state, new_model_state):
|
||||
"""Was this model really renamed?"""
|
||||
return self.defaults.get("ask_rename_model", False)
|
||||
|
||||
def ask_merge(self, app_label):
|
||||
"""Should these migrations really be merged?"""
|
||||
return self.defaults.get("ask_merge", False)
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
"""Adding an auto_now_add field to a model."""
|
||||
# None means quit
|
||||
return None
|
||||
|
||||
|
||||
class InteractiveMigrationQuestioner(MigrationQuestioner):
|
||||
|
||||
def _boolean_input(self, question, default=None):
|
||||
result = input("%s " % question)
|
||||
if not result and default is not None:
|
||||
return default
|
||||
while not result or result[0].lower() not in "yn":
|
||||
result = input("Please answer yes or no: ")
|
||||
return result[0].lower() == "y"
|
||||
|
||||
def _choice_input(self, question, choices):
|
||||
print(question)
|
||||
for i, choice in enumerate(choices):
|
||||
print(" %s) %s" % (i + 1, choice))
|
||||
result = input("Select an option: ")
|
||||
while True:
|
||||
try:
|
||||
value = int(result)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
if 0 < value <= len(choices):
|
||||
return value
|
||||
result = input("Please select a valid option: ")
|
||||
|
||||
def _ask_default(self, default=''):
|
||||
"""
|
||||
Prompt for a default value.
|
||||
|
||||
The ``default`` argument allows providing a custom default value (as a
|
||||
string) which will be shown to the user and used as the return value
|
||||
if the user doesn't provide any other input.
|
||||
"""
|
||||
print('Please enter the default value as valid Python.')
|
||||
if default:
|
||||
print(
|
||||
f"Accept the default '{default}' by pressing 'Enter' or "
|
||||
f"provide another value."
|
||||
)
|
||||
print(
|
||||
'The datetime and django.utils.timezone modules are available, so '
|
||||
'it is possible to provide e.g. timezone.now as a value.'
|
||||
)
|
||||
print("Type 'exit' to exit this prompt")
|
||||
while True:
|
||||
if default:
|
||||
prompt = "[default: {}] >>> ".format(default)
|
||||
else:
|
||||
prompt = ">>> "
|
||||
code = input(prompt)
|
||||
if not code and default:
|
||||
code = default
|
||||
if not code:
|
||||
print("Please enter some code, or 'exit' (without quotes) to exit.")
|
||||
elif code == "exit":
|
||||
sys.exit(1)
|
||||
else:
|
||||
try:
|
||||
return eval(code, {}, {'datetime': datetime, 'timezone': timezone})
|
||||
except (SyntaxError, NameError) as e:
|
||||
print("Invalid input: %s" % e)
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
"""Adding a NOT NULL field to a model."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
f"It is impossible to add a non-nullable field '{field_name}' "
|
||||
f"to {model_name} without specifying a default. This is "
|
||||
f"because the database needs something to populate existing "
|
||||
f"rows.\n"
|
||||
f"Please select a fix:",
|
||||
[
|
||||
("Provide a one-off default now (will be set on all existing "
|
||||
"rows with a null value for this column)"),
|
||||
'Quit and manually define a default value in models.py.',
|
||||
]
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default()
|
||||
return None
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
"""Changing a NULL field to NOT NULL."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
f"It is impossible to change a nullable field '{field_name}' "
|
||||
f"on {model_name} to non-nullable without providing a "
|
||||
f"default. This is because the database needs something to "
|
||||
f"populate existing rows.\n"
|
||||
f"Please select a fix:",
|
||||
[
|
||||
("Provide a one-off default now (will be set on all existing "
|
||||
"rows with a null value for this column)"),
|
||||
'Ignore for now. Existing rows that contain NULL values '
|
||||
'will have to be handled manually, for example with a '
|
||||
'RunPython or RunSQL operation.',
|
||||
'Quit and manually define a default value in models.py.',
|
||||
]
|
||||
)
|
||||
if choice == 2:
|
||||
return NOT_PROVIDED
|
||||
elif choice == 3:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default()
|
||||
return None
|
||||
|
||||
def ask_rename(self, model_name, old_name, new_name, field_instance):
|
||||
"""Was this field really renamed?"""
|
||||
msg = 'Was %s.%s renamed to %s.%s (a %s)? [y/N]'
|
||||
return self._boolean_input(msg % (model_name, old_name, model_name, new_name,
|
||||
field_instance.__class__.__name__), False)
|
||||
|
||||
def ask_rename_model(self, old_model_state, new_model_state):
|
||||
"""Was this model really renamed?"""
|
||||
msg = 'Was the model %s.%s renamed to %s? [y/N]'
|
||||
return self._boolean_input(msg % (old_model_state.app_label, old_model_state.name,
|
||||
new_model_state.name), False)
|
||||
|
||||
def ask_merge(self, app_label):
|
||||
return self._boolean_input(
|
||||
"\nMerging will only work if the operations printed above do not conflict\n" +
|
||||
"with each other (working on different fields or models)\n" +
|
||||
'Should these migration branches be merged? [y/N]',
|
||||
False,
|
||||
)
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
"""Adding an auto_now_add field to a model."""
|
||||
if not self.dry_run:
|
||||
choice = self._choice_input(
|
||||
f"It is impossible to add the field '{field_name}' with "
|
||||
f"'auto_now_add=True' to {model_name} without providing a "
|
||||
f"default. This is because the database needs something to "
|
||||
f"populate existing rows.\n",
|
||||
[
|
||||
'Provide a one-off default now which will be set on all '
|
||||
'existing rows',
|
||||
'Quit and manually define a default value in models.py.',
|
||||
]
|
||||
)
|
||||
if choice == 2:
|
||||
sys.exit(3)
|
||||
else:
|
||||
return self._ask_default(default='timezone.now')
|
||||
return None
|
||||
|
||||
|
||||
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
|
||||
|
||||
def ask_not_null_addition(self, field_name, model_name):
|
||||
# We can't ask the user, so act like the user aborted.
|
||||
sys.exit(3)
|
||||
|
||||
def ask_not_null_alteration(self, field_name, model_name):
|
||||
# We can't ask the user, so set as not provided.
|
||||
return NOT_PROVIDED
|
||||
|
||||
def ask_auto_now_add_addition(self, field_name, model_name):
|
||||
# We can't ask the user, so act like the user aborted.
|
||||
sys.exit(3)
|
||||
96
venv/Lib/site-packages/django/db/migrations/recorder.py
Normal file
96
venv/Lib/site-packages/django/db/migrations/recorder.py
Normal file
@@ -0,0 +1,96 @@
from django.apps.registry import Apps
from django.db import DatabaseError, models
from django.utils.functional import classproperty
from django.utils.timezone import now

from .exceptions import MigrationSchemaMissing


class MigrationRecorder:
    """
    Deal with storing migration records in the database.

    Because this table is actually itself used for dealing with model
    creation, it's the one thing we can't do normally via migrations.
    We manually handle table creation/schema updating (using schema backend)
    and then have a floating model to do queries with.

    If a migration is unapplied its row is removed from the table. Having
    a row in the table always means a migration is applied.
    """
    _migration_class = None

    @classproperty
    def Migration(cls):
        """
        Lazy load to avoid AppRegistryNotReady if installed apps import
        MigrationRecorder.
        """
        if cls._migration_class is None:
            class Migration(models.Model):
                app = models.CharField(max_length=255)
                name = models.CharField(max_length=255)
                applied = models.DateTimeField(default=now)

                class Meta:
                    apps = Apps()
                    app_label = 'migrations'
                    db_table = 'django_migrations'

                def __str__(self):
                    return 'Migration %s for %s' % (self.name, self.app)

            cls._migration_class = Migration
        return cls._migration_class

    def __init__(self, connection):
        self.connection = connection

    @property
    def migration_qs(self):
        return self.Migration.objects.using(self.connection.alias)

    def has_table(self):
        """Return True if the django_migrations table exists."""
        with self.connection.cursor() as cursor:
            tables = self.connection.introspection.table_names(cursor)
        return self.Migration._meta.db_table in tables

    def ensure_schema(self):
        """Ensure the table exists and has the correct schema."""
        # If the table's there, that's fine - we've never changed its schema
        # in the codebase.
        if self.has_table():
            return
        # Make the table
        try:
            with self.connection.schema_editor() as editor:
                editor.create_model(self.Migration)
        except DatabaseError as exc:
            raise MigrationSchemaMissing("Unable to create the django_migrations table (%s)" % exc)

    def applied_migrations(self):
        """
        Return a dict mapping (app_name, migration_name) to Migration instances
        for all applied migrations.
        """
        if self.has_table():
            return {(migration.app, migration.name): migration for migration in self.migration_qs}
        else:
            # If the django_migrations table doesn't exist, then no migrations
            # are applied.
            return {}

    def record_applied(self, app, name):
        """Record that a migration was applied."""
        self.ensure_schema()
        self.migration_qs.create(app=app, name=name)

    def record_unapplied(self, app, name):
        """Record that a migration was unapplied."""
        self.ensure_schema()
        self.migration_qs.filter(app=app, name=name).delete()

    def flush(self):
        """Delete all migration records. Useful for testing migrations."""
        self.migration_qs.all().delete()
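# --- Illustrative sketch, not part of Django's source; assumes a configured Django project
# (e.g. run from a management command or the Django shell). ---
# The recorder is keyed by connection; applied_migrations() maps
# (app_label, migration_name) pairs to Migration rows.
from django.db import connection
from django.db.migrations.recorder import MigrationRecorder

recorder = MigrationRecorder(connection)
if recorder.has_table():
    applied = recorder.applied_migrations()
    print(sorted(applied))  # e.g. [('admin', '0001_initial'), ...]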
357
venv/Lib/site-packages/django/db/migrations/serializer.py
Normal file
357
venv/Lib/site-packages/django/db/migrations/serializer.py
Normal file
@@ -0,0 +1,357 @@
|
||||
import builtins
|
||||
import collections.abc
|
||||
import datetime
|
||||
import decimal
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import types
|
||||
import uuid
|
||||
|
||||
from django.conf import SettingsReference
|
||||
from django.db import models
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
|
||||
from django.utils.functional import LazyObject, Promise
|
||||
from django.utils.timezone import utc
|
||||
from django.utils.version import get_docs_version
|
||||
|
||||
|
||||
class BaseSerializer:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def serialize(self):
|
||||
raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.')
|
||||
|
||||
|
||||
class BaseSequenceSerializer(BaseSerializer):
|
||||
def _format(self):
|
||||
raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.')
|
||||
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for item in self.value:
|
||||
item_string, item_imports = serializer_factory(item).serialize()
|
||||
imports.update(item_imports)
|
||||
strings.append(item_string)
|
||||
value = self._format()
|
||||
return value % (", ".join(strings)), imports
|
||||
|
||||
|
||||
class BaseSimpleSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(self.value), set()
|
||||
|
||||
|
||||
class ChoicesSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return serializer_factory(self.value.value).serialize()
|
||||
|
||||
|
||||
class DateTimeSerializer(BaseSerializer):
|
||||
"""For datetime.*, except datetime.datetime."""
|
||||
def serialize(self):
|
||||
return repr(self.value), {'import datetime'}
|
||||
|
||||
|
||||
class DatetimeDatetimeSerializer(BaseSerializer):
|
||||
"""For datetime.datetime."""
|
||||
def serialize(self):
|
||||
if self.value.tzinfo is not None and self.value.tzinfo != utc:
|
||||
self.value = self.value.astimezone(utc)
|
||||
imports = ["import datetime"]
|
||||
if self.value.tzinfo is not None:
|
||||
imports.append("from django.utils.timezone import utc")
|
||||
return repr(self.value).replace('datetime.timezone.utc', 'utc'), set(imports)
|
||||
|
||||
|
||||
class DecimalSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(self.value), {"from decimal import Decimal"}
|
||||
|
||||
|
||||
class DeconstructableSerializer(BaseSerializer):
|
||||
@staticmethod
|
||||
def serialize_deconstructed(path, args, kwargs):
|
||||
name, imports = DeconstructableSerializer._serialize_path(path)
|
||||
strings = []
|
||||
for arg in args:
|
||||
arg_string, arg_imports = serializer_factory(arg).serialize()
|
||||
strings.append(arg_string)
|
||||
imports.update(arg_imports)
|
||||
for kw, arg in sorted(kwargs.items()):
|
||||
arg_string, arg_imports = serializer_factory(arg).serialize()
|
||||
imports.update(arg_imports)
|
||||
strings.append("%s=%s" % (kw, arg_string))
|
||||
return "%s(%s)" % (name, ", ".join(strings)), imports
|
||||
|
||||
@staticmethod
|
||||
def _serialize_path(path):
|
||||
module, name = path.rsplit(".", 1)
|
||||
if module == "django.db.models":
|
||||
imports = {"from django.db import models"}
|
||||
name = "models.%s" % name
|
||||
else:
|
||||
imports = {"import %s" % module}
|
||||
name = path
|
||||
return name, imports
|
||||
|
||||
def serialize(self):
|
||||
return self.serialize_deconstructed(*self.value.deconstruct())
|
||||
|
||||
|
||||
class DictionarySerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for k, v in sorted(self.value.items()):
|
||||
k_string, k_imports = serializer_factory(k).serialize()
|
||||
v_string, v_imports = serializer_factory(v).serialize()
|
||||
imports.update(k_imports)
|
||||
imports.update(v_imports)
|
||||
strings.append((k_string, v_string))
|
||||
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
|
||||
|
||||
|
||||
class EnumSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
enum_class = self.value.__class__
|
||||
module = enum_class.__module__
|
||||
return (
|
||||
'%s.%s[%r]' % (module, enum_class.__qualname__, self.value.name),
|
||||
{'import %s' % module},
|
||||
)
|
||||
|
||||
|
||||
class FloatSerializer(BaseSimpleSerializer):
|
||||
def serialize(self):
|
||||
if math.isnan(self.value) or math.isinf(self.value):
|
||||
return 'float("{}")'.format(self.value), set()
|
||||
return super().serialize()
|
||||
|
||||
|
||||
class FrozensetSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
return "frozenset([%s])"
|
||||
|
||||
|
||||
class FunctionTypeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type):
|
||||
klass = self.value.__self__
|
||||
module = klass.__module__
|
||||
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module}
|
||||
# Further error checking
|
||||
if self.value.__name__ == '<lambda>':
|
||||
raise ValueError("Cannot serialize function: lambda")
|
||||
if self.value.__module__ is None:
|
||||
raise ValueError("Cannot serialize function %r: No module" % self.value)
|
||||
|
||||
module_name = self.value.__module__
|
||||
|
||||
if '<' not in self.value.__qualname__: # Qualname can include <locals>
|
||||
return '%s.%s' % (module_name, self.value.__qualname__), {'import %s' % self.value.__module__}
|
||||
|
||||
raise ValueError(
|
||||
'Could not find function %s in %s.\n' % (self.value.__name__, module_name)
|
||||
)
|
||||
|
||||
|
||||
class FunctoolsPartialSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
# Serialize functools.partial() arguments
|
||||
func_string, func_imports = serializer_factory(self.value.func).serialize()
|
||||
args_string, args_imports = serializer_factory(self.value.args).serialize()
|
||||
keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize()
|
||||
# Add any imports needed by arguments
|
||||
imports = {'import functools', *func_imports, *args_imports, *keywords_imports}
|
||||
return (
|
||||
'functools.%s(%s, *%s, **%s)' % (
|
||||
self.value.__class__.__name__,
|
||||
func_string,
|
||||
args_string,
|
||||
keywords_string,
|
||||
),
|
||||
imports,
|
||||
)
|
||||
|
||||
|
||||
class IterableSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
imports = set()
|
||||
strings = []
|
||||
for item in self.value:
|
||||
item_string, item_imports = serializer_factory(item).serialize()
|
||||
imports.update(item_imports)
|
||||
strings.append(item_string)
|
||||
# When len(strings)==0, the empty iterable should be serialized as
|
||||
# "()", not "(,)" because (,) is invalid Python syntax.
|
||||
value = "(%s)" if len(strings) != 1 else "(%s,)"
|
||||
return value % (", ".join(strings)), imports
|
||||
|
||||
|
||||
class ModelFieldSerializer(DeconstructableSerializer):
|
||||
def serialize(self):
|
||||
attr_name, path, args, kwargs = self.value.deconstruct()
|
||||
return self.serialize_deconstructed(path, args, kwargs)
|
||||
|
||||
|
||||
class ModelManagerSerializer(DeconstructableSerializer):
|
||||
def serialize(self):
|
||||
as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
|
||||
if as_manager:
|
||||
name, imports = self._serialize_path(qs_path)
|
||||
return "%s.as_manager()" % name, imports
|
||||
else:
|
||||
return self.serialize_deconstructed(manager_path, args, kwargs)
|
||||
|
||||
|
||||
class OperationSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
from django.db.migrations.writer import OperationWriter
|
||||
string, imports = OperationWriter(self.value, indentation=0).serialize()
|
||||
# Nested operation, trailing comma is handled in upper OperationWriter._write()
|
||||
return string.rstrip(','), imports
|
||||
|
||||
|
||||
class PathLikeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return repr(os.fspath(self.value)), {}
|
||||
|
||||
|
||||
class PathSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
# Convert concrete paths to pure paths to avoid issues with migrations
|
||||
# generated on one platform being used on a different platform.
|
||||
prefix = 'Pure' if isinstance(self.value, pathlib.Path) else ''
|
||||
return 'pathlib.%s%r' % (prefix, self.value), {'import pathlib'}
|
||||
|
||||
|
||||
class RegexSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize()
|
||||
# Turn off default implicit flags (e.g. re.U) because regexes with the
|
||||
# same implicit and explicit flags aren't equal.
|
||||
flags = self.value.flags ^ re.compile('').flags
|
||||
regex_flags, flag_imports = serializer_factory(flags).serialize()
|
||||
imports = {'import re', *pattern_imports, *flag_imports}
|
||||
args = [regex_pattern]
|
||||
if flags:
|
||||
args.append(regex_flags)
|
||||
return "re.compile(%s)" % ', '.join(args), imports
|
||||
|
||||
|
||||
class SequenceSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
return "[%s]"
|
||||
|
||||
|
||||
class SetSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
# Serialize as a set literal except when value is empty because {}
|
||||
# is an empty dict.
|
||||
return '{%s}' if self.value else 'set(%s)'
|
||||
|
||||
|
||||
class SettingsReferenceSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return "settings.%s" % self.value.setting_name, {"from django.conf import settings"}
|
||||
|
||||
|
||||
class TupleSerializer(BaseSequenceSerializer):
|
||||
def _format(self):
|
||||
# When len(value)==0, the empty tuple should be serialized as "()",
|
||||
# not "(,)" because (,) is invalid Python syntax.
|
||||
return "(%s)" if len(self.value) != 1 else "(%s,)"
|
||||
|
||||
|
||||
class TypeSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
special_cases = [
|
||||
(models.Model, "models.Model", ['from django.db import models']),
|
||||
(type(None), 'type(None)', []),
|
||||
]
|
||||
for case, string, imports in special_cases:
|
||||
if case is self.value:
|
||||
return string, set(imports)
|
||||
if hasattr(self.value, "__module__"):
|
||||
module = self.value.__module__
|
||||
if module == builtins.__name__:
|
||||
return self.value.__name__, set()
|
||||
else:
|
||||
return "%s.%s" % (module, self.value.__qualname__), {"import %s" % module}
|
||||
|
||||
|
||||
class UUIDSerializer(BaseSerializer):
|
||||
def serialize(self):
|
||||
return "uuid.%s" % repr(self.value), {"import uuid"}
|
||||
|
||||
|
||||
class Serializer:
|
||||
_registry = {
|
||||
# Some of these are order-dependent.
|
||||
frozenset: FrozensetSerializer,
|
||||
list: SequenceSerializer,
|
||||
set: SetSerializer,
|
||||
tuple: TupleSerializer,
|
||||
dict: DictionarySerializer,
|
||||
models.Choices: ChoicesSerializer,
|
||||
enum.Enum: EnumSerializer,
|
||||
datetime.datetime: DatetimeDatetimeSerializer,
|
||||
(datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
|
||||
SettingsReference: SettingsReferenceSerializer,
|
||||
float: FloatSerializer,
|
||||
(bool, int, type(None), bytes, str, range): BaseSimpleSerializer,
|
||||
decimal.Decimal: DecimalSerializer,
|
||||
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
|
||||
(types.FunctionType, types.BuiltinFunctionType, types.MethodType): FunctionTypeSerializer,
|
||||
collections.abc.Iterable: IterableSerializer,
|
||||
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
|
||||
uuid.UUID: UUIDSerializer,
|
||||
pathlib.PurePath: PathSerializer,
|
||||
os.PathLike: PathLikeSerializer,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register(cls, type_, serializer):
|
||||
if not issubclass(serializer, BaseSerializer):
|
||||
raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
|
||||
cls._registry[type_] = serializer
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, type_):
|
||||
cls._registry.pop(type_)
|
||||
|
||||
|
||||
def serializer_factory(value):
|
||||
if isinstance(value, Promise):
|
||||
value = str(value)
|
||||
elif isinstance(value, LazyObject):
|
||||
# The unwrapped value is returned as the first item of the arguments
|
||||
# tuple.
|
||||
value = value.__reduce__()[1][0]
|
||||
|
||||
if isinstance(value, models.Field):
|
||||
return ModelFieldSerializer(value)
|
||||
if isinstance(value, models.manager.BaseManager):
|
||||
return ModelManagerSerializer(value)
|
||||
if isinstance(value, Operation):
|
||||
return OperationSerializer(value)
|
||||
if isinstance(value, type):
|
||||
return TypeSerializer(value)
|
||||
# Anything that knows how to deconstruct itself.
|
||||
if hasattr(value, 'deconstruct'):
|
||||
return DeconstructableSerializer(value)
|
||||
for type_, serializer_cls in Serializer._registry.items():
|
||||
if isinstance(value, type_):
|
||||
return serializer_cls(value)
|
||||
raise ValueError(
|
||||
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
|
||||
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
|
||||
"topics/migrations/#migration-serializing" % (value, get_docs_version())
|
||||
)
|
||||
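# Usage sketch (not from the vendored serializer.py above): a custom type is
# made serializable by registering a BaseSerializer subclass whose serialize()
# returns an (expression_string, imports_set) pair, exactly as the built-in
# serializers above do. `Money` and the `example.money` module are
# hypothetical names used only for illustration.
import decimal

from django.db.migrations.serializer import BaseSerializer
from django.db.migrations.writer import MigrationWriter


class Money:
    def __init__(self, amount, currency):
        self.amount = decimal.Decimal(amount)
        self.currency = currency


class MoneySerializer(BaseSerializer):
    def serialize(self):
        # Reuse the existing machinery for the nested Decimal value.
        amount_string, amount_imports = MigrationWriter.serialize(self.value.amount)
        imports = {'import example.money', *amount_imports}
        return 'example.money.Money(%s, %r)' % (amount_string, self.value.currency), imports


# Added to the registry consulted by serializer_factory().
MigrationWriter.register_serializer(Money, MoneySerializer)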
904
venv/Lib/site-packages/django/db/migrations/state.py
Normal file
@@ -0,0 +1,904 @@
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.apps.registry import Apps, apps as global_apps
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db import models
|
||||
from django.db.migrations.utils import field_is_referenced, get_references
|
||||
from django.db.models import NOT_PROVIDED
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
from django.db.models.options import DEFAULT_NAMES, normalize_together
|
||||
from django.db.models.utils import make_model_tuple
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.version import get_docs_version
|
||||
|
||||
from .exceptions import InvalidBasesError
|
||||
from .utils import resolve_relation
|
||||
|
||||
|
||||
def _get_app_label_and_model_name(model, app_label=''):
|
||||
if isinstance(model, str):
|
||||
split = model.split('.', 1)
|
||||
return tuple(split) if len(split) == 2 else (app_label, split[0])
|
||||
else:
|
||||
return model._meta.app_label, model._meta.model_name
|
||||
|
||||
|
||||
def _get_related_models(m):
|
||||
"""Return all models that have a direct relationship to the given model."""
|
||||
related_models = [
|
||||
subclass for subclass in m.__subclasses__()
|
||||
if issubclass(subclass, models.Model)
|
||||
]
|
||||
related_fields_models = set()
|
||||
for f in m._meta.get_fields(include_parents=True, include_hidden=True):
|
||||
if f.is_relation and f.related_model is not None and not isinstance(f.related_model, str):
|
||||
related_fields_models.add(f.model)
|
||||
related_models.append(f.related_model)
|
||||
# Reverse accessors of foreign keys to proxy models are attached to their
|
||||
# concrete proxied model.
|
||||
opts = m._meta
|
||||
if opts.proxy and m in related_fields_models:
|
||||
related_models.append(opts.concrete_model)
|
||||
return related_models
|
||||
|
||||
|
||||
def get_related_models_tuples(model):
|
||||
"""
|
||||
Return a set of (app_label, model_name) tuples for all related
|
||||
models for the given model.
|
||||
"""
|
||||
return {
|
||||
(rel_mod._meta.app_label, rel_mod._meta.model_name)
|
||||
for rel_mod in _get_related_models(model)
|
||||
}
|
||||
|
||||
|
||||
def get_related_models_recursive(model):
|
||||
"""
|
||||
Return all models that have a direct or indirect relationship
|
||||
to the given model.
|
||||
|
||||
Relationships are either defined by explicit relational fields, like
|
||||
ForeignKey, ManyToManyField or OneToOneField, or by inheriting from another
|
||||
model (a superclass is related to its subclasses, but not vice versa). Note,
|
||||
however, that a model inheriting from a concrete model is also related to
|
||||
its superclass through the implicit *_ptr OneToOneField on the subclass.
|
||||
"""
|
||||
seen = set()
|
||||
queue = _get_related_models(model)
|
||||
for rel_mod in queue:
|
||||
rel_app_label, rel_model_name = rel_mod._meta.app_label, rel_mod._meta.model_name
|
||||
if (rel_app_label, rel_model_name) in seen:
|
||||
continue
|
||||
seen.add((rel_app_label, rel_model_name))
|
||||
queue.extend(_get_related_models(rel_mod))
|
||||
return seen - {(model._meta.app_label, model._meta.model_name)}
|
||||
|
||||
|
||||
class ProjectState:
|
||||
"""
|
||||
Represent the entire project's overall state. This is the item that is
|
||||
passed around - do it here rather than at the app level so that cross-app
|
||||
FKs/etc. resolve properly.
|
||||
"""
|
||||
|
||||
def __init__(self, models=None, real_apps=None):
|
||||
self.models = models or {}
|
||||
# Apps to include from main registry, usually unmigrated ones
|
||||
if real_apps is None:
|
||||
real_apps = set()
|
||||
else:
|
||||
assert isinstance(real_apps, set)
|
||||
self.real_apps = real_apps
|
||||
self.is_delayed = False
|
||||
# {remote_model_key: {model_key: {field_name: field}}}
|
||||
self._relations = None
|
||||
|
||||
@property
|
||||
def relations(self):
|
||||
if self._relations is None:
|
||||
self.resolve_fields_and_relations()
|
||||
return self._relations
|
||||
|
||||
def add_model(self, model_state):
|
||||
model_key = model_state.app_label, model_state.name_lower
|
||||
self.models[model_key] = model_state
|
||||
if self._relations is not None:
|
||||
self.resolve_model_relations(model_key)
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
self.reload_model(*model_key)
|
||||
|
||||
def remove_model(self, app_label, model_name):
|
||||
model_key = app_label, model_name
|
||||
del self.models[model_key]
|
||||
if self._relations is not None:
|
||||
self._relations.pop(model_key, None)
|
||||
# Call list() since _relations can change size during iteration.
|
||||
for related_model_key, model_relations in list(self._relations.items()):
|
||||
model_relations.pop(model_key, None)
|
||||
if not model_relations:
|
||||
del self._relations[related_model_key]
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
self.apps.unregister_model(*model_key)
|
||||
# Need to do this explicitly since unregister_model() doesn't clear
|
||||
# the cache automatically (#24513)
|
||||
self.apps.clear_cache()
|
||||
|
||||
def rename_model(self, app_label, old_name, new_name):
|
||||
# Add a new model.
|
||||
old_name_lower = old_name.lower()
|
||||
new_name_lower = new_name.lower()
|
||||
renamed_model = self.models[app_label, old_name_lower].clone()
|
||||
renamed_model.name = new_name
|
||||
self.models[app_label, new_name_lower] = renamed_model
|
||||
# Repoint all fields pointing to the old model to the new one.
|
||||
old_model_tuple = (app_label, old_name_lower)
|
||||
new_remote_model = f'{app_label}.{new_name}'
|
||||
to_reload = set()
|
||||
for model_state, name, field, reference in get_references(self, old_model_tuple):
|
||||
changed_field = None
|
||||
if reference.to:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.model = new_remote_model
|
||||
if reference.through:
|
||||
if changed_field is None:
|
||||
changed_field = field.clone()
|
||||
changed_field.remote_field.through = new_remote_model
|
||||
if changed_field:
|
||||
model_state.fields[name] = changed_field
|
||||
to_reload.add((model_state.app_label, model_state.name_lower))
|
||||
if self._relations is not None:
|
||||
old_name_key = app_label, old_name_lower
|
||||
new_name_key = app_label, new_name_lower
|
||||
if old_name_key in self._relations:
|
||||
self._relations[new_name_key] = self._relations.pop(old_name_key)
|
||||
for model_relations in self._relations.values():
|
||||
if old_name_key in model_relations:
|
||||
model_relations[new_name_key] = model_relations.pop(old_name_key)
|
||||
# Reload models related to old model before removing the old model.
|
||||
self.reload_models(to_reload, delay=True)
|
||||
# Remove the old model.
|
||||
self.remove_model(app_label, old_name_lower)
|
||||
self.reload_model(app_label, new_name_lower, delay=True)
|
||||
|
||||
def alter_model_options(self, app_label, model_name, options, option_keys=None):
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.options = {**model_state.options, **options}
|
||||
if option_keys:
|
||||
for key in option_keys:
|
||||
if key not in options:
|
||||
model_state.options.pop(key, False)
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def alter_model_managers(self, app_label, model_name, managers):
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.managers = list(managers)
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def _append_option(self, app_label, model_name, option_name, obj):
|
||||
model_state = self.models[app_label, model_name]
|
||||
model_state.options[option_name] = [*model_state.options[option_name], obj]
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def _remove_option(self, app_label, model_name, option_name, obj_name):
|
||||
model_state = self.models[app_label, model_name]
|
||||
objs = model_state.options[option_name]
|
||||
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
|
||||
self.reload_model(app_label, model_name, delay=True)
|
||||
|
||||
def add_index(self, app_label, model_name, index):
|
||||
self._append_option(app_label, model_name, 'indexes', index)
|
||||
|
||||
def remove_index(self, app_label, model_name, index_name):
|
||||
self._remove_option(app_label, model_name, 'indexes', index_name)
|
||||
|
||||
def add_constraint(self, app_label, model_name, constraint):
|
||||
self._append_option(app_label, model_name, 'constraints', constraint)
|
||||
|
||||
def remove_constraint(self, app_label, model_name, constraint_name):
|
||||
self._remove_option(app_label, model_name, 'constraints', constraint_name)
|
||||
|
||||
def add_field(self, app_label, model_name, name, field, preserve_default):
|
||||
# If preserve default is off, don't use the default for future state.
|
||||
if not preserve_default:
|
||||
field = field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = field
|
||||
model_key = app_label, model_name
|
||||
self.models[model_key].fields[name] = field
|
||||
if self._relations is not None:
|
||||
self.resolve_model_field_relations(model_key, name, field)
|
||||
# Delay rendering of relationships if it's not a relational field.
|
||||
delay = not field.is_relation
|
||||
self.reload_model(*model_key, delay=delay)
|
||||
|
||||
def remove_field(self, app_label, model_name, name):
|
||||
model_key = app_label, model_name
|
||||
model_state = self.models[model_key]
|
||||
old_field = model_state.fields.pop(name)
|
||||
if self._relations is not None:
|
||||
self.resolve_model_field_relations(model_key, name, old_field)
|
||||
# Delay rendering of relationships if it's not a relational field.
|
||||
delay = not old_field.is_relation
|
||||
self.reload_model(*model_key, delay=delay)
|
||||
|
||||
def alter_field(self, app_label, model_name, name, field, preserve_default):
|
||||
if not preserve_default:
|
||||
field = field.clone()
|
||||
field.default = NOT_PROVIDED
|
||||
else:
|
||||
field = field
|
||||
model_key = app_label, model_name
|
||||
fields = self.models[model_key].fields
|
||||
if self._relations is not None:
|
||||
old_field = fields.pop(name)
|
||||
if old_field.is_relation:
|
||||
self.resolve_model_field_relations(model_key, name, old_field)
|
||||
fields[name] = field
|
||||
if field.is_relation:
|
||||
self.resolve_model_field_relations(model_key, name, field)
|
||||
else:
|
||||
fields[name] = field
|
||||
# TODO: investigate if old relational fields must be reloaded or if
|
||||
# it's sufficient if the new field is (#27737).
|
||||
# Delay rendering of relationships if it's not a relational field and
|
||||
# not referenced by a foreign key.
|
||||
delay = (
|
||||
not field.is_relation and
|
||||
not field_is_referenced(self, model_key, (name, field))
|
||||
)
|
||||
self.reload_model(*model_key, delay=delay)
|
||||
|
||||
def rename_field(self, app_label, model_name, old_name, new_name):
|
||||
model_key = app_label, model_name
|
||||
model_state = self.models[model_key]
|
||||
# Rename the field.
|
||||
fields = model_state.fields
|
||||
try:
|
||||
found = fields.pop(old_name)
|
||||
except KeyError:
|
||||
raise FieldDoesNotExist(
|
||||
f"{app_label}.{model_name} has no field named '{old_name}'"
|
||||
)
|
||||
fields[new_name] = found
|
||||
for field in fields.values():
|
||||
# Fix from_fields to refer to the new field.
|
||||
from_fields = getattr(field, 'from_fields', None)
|
||||
if from_fields:
|
||||
field.from_fields = tuple([
|
||||
new_name if from_field_name == old_name else from_field_name
|
||||
for from_field_name in from_fields
|
||||
])
|
||||
# Fix index/unique_together to refer to the new field.
|
||||
options = model_state.options
|
||||
for option in ('index_together', 'unique_together'):
|
||||
if option in options:
|
||||
options[option] = [
|
||||
[new_name if n == old_name else n for n in together]
|
||||
for together in options[option]
|
||||
]
|
||||
# Fix to_fields to refer to the new field.
|
||||
delay = True
|
||||
references = get_references(self, model_key, (old_name, found))
|
||||
for *_, field, reference in references:
|
||||
delay = False
|
||||
if reference.to:
|
||||
remote_field, to_fields = reference.to
|
||||
if getattr(remote_field, 'field_name', None) == old_name:
|
||||
remote_field.field_name = new_name
|
||||
if to_fields:
|
||||
field.to_fields = tuple([
|
||||
new_name if to_field_name == old_name else to_field_name
|
||||
for to_field_name in to_fields
|
||||
])
|
||||
if self._relations is not None:
|
||||
old_name_lower = old_name.lower()
|
||||
new_name_lower = new_name.lower()
|
||||
for to_model in self._relations.values():
|
||||
if old_name_lower in to_model[model_key]:
|
||||
field = to_model[model_key].pop(old_name_lower)
|
||||
field.name = new_name_lower
|
||||
to_model[model_key][new_name_lower] = field
|
||||
self.reload_model(*model_key, delay=delay)
|
||||
|
||||
def _find_reload_model(self, app_label, model_name, delay=False):
|
||||
if delay:
|
||||
self.is_delayed = True
|
||||
|
||||
related_models = set()
|
||||
|
||||
try:
|
||||
old_model = self.apps.get_model(app_label, model_name)
|
||||
except LookupError:
|
||||
pass
|
||||
else:
|
||||
# Get all relations to and from the old model before reloading,
|
||||
# as _meta.apps may change
|
||||
if delay:
|
||||
related_models = get_related_models_tuples(old_model)
|
||||
else:
|
||||
related_models = get_related_models_recursive(old_model)
|
||||
|
||||
# Get all outgoing references from the model to be rendered
|
||||
model_state = self.models[(app_label, model_name)]
|
||||
# Directly related models are the models pointed to by ForeignKeys,
|
||||
# OneToOneFields, and ManyToManyFields.
|
||||
direct_related_models = set()
|
||||
for field in model_state.fields.values():
|
||||
if field.is_relation:
|
||||
if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
continue
|
||||
rel_app_label, rel_model_name = _get_app_label_and_model_name(field.related_model, app_label)
|
||||
direct_related_models.add((rel_app_label, rel_model_name.lower()))
|
||||
|
||||
# For all direct related models recursively get all related models.
|
||||
related_models.update(direct_related_models)
|
||||
for rel_app_label, rel_model_name in direct_related_models:
|
||||
try:
|
||||
rel_model = self.apps.get_model(rel_app_label, rel_model_name)
|
||||
except LookupError:
|
||||
pass
|
||||
else:
|
||||
if delay:
|
||||
related_models.update(get_related_models_tuples(rel_model))
|
||||
else:
|
||||
related_models.update(get_related_models_recursive(rel_model))
|
||||
|
||||
# Include the model itself
|
||||
related_models.add((app_label, model_name))
|
||||
|
||||
return related_models
|
||||
|
||||
def reload_model(self, app_label, model_name, delay=False):
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
related_models = self._find_reload_model(app_label, model_name, delay)
|
||||
self._reload(related_models)
|
||||
|
||||
def reload_models(self, models, delay=True):
|
||||
if 'apps' in self.__dict__: # hasattr would cache the property
|
||||
related_models = set()
|
||||
for app_label, model_name in models:
|
||||
related_models.update(self._find_reload_model(app_label, model_name, delay))
|
||||
self._reload(related_models)
|
||||
|
||||
def _reload(self, related_models):
|
||||
# Unregister all related models
|
||||
with self.apps.bulk_update():
|
||||
for rel_app_label, rel_model_name in related_models:
|
||||
self.apps.unregister_model(rel_app_label, rel_model_name)
|
||||
|
||||
states_to_be_rendered = []
|
||||
# Gather all models states of those models that will be rerendered.
|
||||
# This includes:
|
||||
# 1. All related models of unmigrated apps
|
||||
for model_state in self.apps.real_models:
|
||||
if (model_state.app_label, model_state.name_lower) in related_models:
|
||||
states_to_be_rendered.append(model_state)
|
||||
|
||||
# 2. All related models of migrated apps
|
||||
for rel_app_label, rel_model_name in related_models:
|
||||
try:
|
||||
model_state = self.models[rel_app_label, rel_model_name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
states_to_be_rendered.append(model_state)
|
||||
|
||||
# Render all models
|
||||
self.apps.render_multiple(states_to_be_rendered)
|
||||
|
||||
def update_model_field_relation(
|
||||
self, model, model_key, field_name, field, concretes,
|
||||
):
|
||||
remote_model_key = resolve_relation(model, *model_key)
|
||||
if remote_model_key[0] not in self.real_apps and remote_model_key in concretes:
|
||||
remote_model_key = concretes[remote_model_key]
|
||||
relations_to_remote_model = self._relations[remote_model_key]
|
||||
if field_name in self.models[model_key].fields:
|
||||
# The assert holds because it's a new relation, or an altered
|
||||
# relation, in which case references have been removed by
|
||||
# alter_field().
|
||||
assert field_name not in relations_to_remote_model[model_key]
|
||||
relations_to_remote_model[model_key][field_name] = field
|
||||
else:
|
||||
del relations_to_remote_model[model_key][field_name]
|
||||
if not relations_to_remote_model[model_key]:
|
||||
del relations_to_remote_model[model_key]
|
||||
|
||||
def resolve_model_field_relations(
|
||||
self, model_key, field_name, field, concretes=None,
|
||||
):
|
||||
remote_field = field.remote_field
|
||||
if not remote_field:
|
||||
return
|
||||
if concretes is None:
|
||||
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
|
||||
|
||||
self.update_model_field_relation(
|
||||
remote_field.model, model_key, field_name, field, concretes,
|
||||
)
|
||||
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if not through:
|
||||
return
|
||||
self.update_model_field_relation(through, model_key, field_name, field, concretes)
|
||||
|
||||
def resolve_model_relations(self, model_key, concretes=None):
|
||||
if concretes is None:
|
||||
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
|
||||
|
||||
model_state = self.models[model_key]
|
||||
for field_name, field in model_state.fields.items():
|
||||
self.resolve_model_field_relations(model_key, field_name, field, concretes)
|
||||
|
||||
def resolve_fields_and_relations(self):
|
||||
# Resolve fields.
|
||||
for model_state in self.models.values():
|
||||
for field_name, field in model_state.fields.items():
|
||||
field.name = field_name
|
||||
# Resolve relations.
|
||||
# {remote_model_key: {model_key: {field_name: field}}}
|
||||
self._relations = defaultdict(partial(defaultdict, dict))
|
||||
concretes, proxies = self._get_concrete_models_mapping_and_proxy_models()
|
||||
|
||||
for model_key in concretes:
|
||||
self.resolve_model_relations(model_key, concretes)
|
||||
|
||||
for model_key in proxies:
|
||||
self._relations[model_key] = self._relations[concretes[model_key]]
|
||||
|
||||
def get_concrete_model_key(self, model):
|
||||
concrete_models_mapping, _ = self._get_concrete_models_mapping_and_proxy_models()
|
||||
model_key = make_model_tuple(model)
|
||||
return concrete_models_mapping[model_key]
|
||||
|
||||
def _get_concrete_models_mapping_and_proxy_models(self):
|
||||
concrete_models_mapping = {}
|
||||
proxy_models = {}
|
||||
# Split models to proxy and concrete models.
|
||||
for model_key, model_state in self.models.items():
|
||||
if model_state.options.get('proxy'):
|
||||
proxy_models[model_key] = model_state
|
||||
# Find a concrete model for the proxy.
|
||||
concrete_models_mapping[model_key] = self._find_concrete_model_from_proxy(
|
||||
proxy_models, model_state,
|
||||
)
|
||||
else:
|
||||
concrete_models_mapping[model_key] = model_key
|
||||
return concrete_models_mapping, proxy_models
|
||||
|
||||
def _find_concrete_model_from_proxy(self, proxy_models, model_state):
|
||||
for base in model_state.bases:
|
||||
if not (isinstance(base, str) or issubclass(base, models.Model)):
|
||||
continue
|
||||
base_key = make_model_tuple(base)
|
||||
base_state = proxy_models.get(base_key)
|
||||
if not base_state:
|
||||
# Concrete model found, stop looking at bases.
|
||||
return base_key
|
||||
return self._find_concrete_model_from_proxy(proxy_models, base_state)
|
||||
|
||||
def clone(self):
|
||||
"""Return an exact copy of this ProjectState."""
|
||||
new_state = ProjectState(
|
||||
models={k: v.clone() for k, v in self.models.items()},
|
||||
real_apps=self.real_apps,
|
||||
)
|
||||
if 'apps' in self.__dict__:
|
||||
new_state.apps = self.apps.clone()
|
||||
new_state.is_delayed = self.is_delayed
|
||||
return new_state
|
||||
|
||||
def clear_delayed_apps_cache(self):
|
||||
if self.is_delayed and 'apps' in self.__dict__:
|
||||
del self.__dict__['apps']
|
||||
|
||||
@cached_property
|
||||
def apps(self):
|
||||
return StateApps(self.real_apps, self.models)
|
||||
|
||||
@classmethod
|
||||
def from_apps(cls, apps):
|
||||
"""Take an Apps and return a ProjectState matching it."""
|
||||
app_models = {}
|
||||
for model in apps.get_models(include_swapped=True):
|
||||
model_state = ModelState.from_model(model)
|
||||
app_models[(model_state.app_label, model_state.name_lower)] = model_state
|
||||
return cls(app_models)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.models == other.models and self.real_apps == other.real_apps
|
||||
|
||||
|
||||
class AppConfigStub(AppConfig):
|
||||
"""Stub of an AppConfig. Only provides a label and a dict of models."""
|
||||
def __init__(self, label):
|
||||
self.apps = None
|
||||
self.models = {}
|
||||
# App-label and app-name are not the same thing, so technically passing
|
||||
# in the label here is wrong. In practice, migrations don't care about
|
||||
# the app name, but we need something unique, and the label works fine.
|
||||
self.label = label
|
||||
self.name = label
|
||||
|
||||
def import_models(self):
|
||||
self.models = self.apps.all_models[self.label]
|
||||
|
||||
|
||||
class StateApps(Apps):
|
||||
"""
|
||||
Subclass of the global Apps registry class to better handle dynamic model
|
||||
additions and removals.
|
||||
"""
|
||||
def __init__(self, real_apps, models, ignore_swappable=False):
|
||||
# Any apps in self.real_apps should have all their models included
|
||||
# in the render. We don't use the original model instances as there
|
||||
# are some variables that refer to the Apps object.
|
||||
# FKs/M2Ms from real apps are also not included as they just
|
||||
# mess things up with partial states (due to lack of dependencies)
|
||||
self.real_models = []
|
||||
for app_label in real_apps:
|
||||
app = global_apps.get_app_config(app_label)
|
||||
for model in app.get_models():
|
||||
self.real_models.append(ModelState.from_model(model, exclude_rels=True))
|
||||
# Populate the app registry with a stub for each application.
|
||||
app_labels = {model_state.app_label for model_state in models.values()}
|
||||
app_configs = [AppConfigStub(label) for label in sorted([*real_apps, *app_labels])]
|
||||
super().__init__(app_configs)
|
||||
|
||||
# These locks get in the way of copying as implemented in clone(),
|
||||
# which is called whenever Django duplicates a StateApps before
|
||||
# updating it.
|
||||
self._lock = None
|
||||
self.ready_event = None
|
||||
|
||||
self.render_multiple([*models.values(), *self.real_models])
|
||||
|
||||
# There shouldn't be any operations pending at this point.
|
||||
from django.core.checks.model_checks import _check_lazy_references
|
||||
ignore = {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
|
||||
errors = _check_lazy_references(self, ignore=ignore)
|
||||
if errors:
|
||||
raise ValueError("\n".join(error.msg for error in errors))
|
||||
|
||||
@contextmanager
|
||||
def bulk_update(self):
|
||||
# Avoid clearing each model's cache for each change. Instead, clear
|
||||
# all caches when we're finished updating the model instances.
|
||||
ready = self.ready
|
||||
self.ready = False
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.ready = ready
|
||||
self.clear_cache()
|
||||
|
||||
def render_multiple(self, model_states):
|
||||
# We keep trying to render the models in a loop, ignoring invalid
|
||||
# base errors, until the size of the unrendered models doesn't
|
||||
# decrease by at least one, meaning there's a base dependency loop/
|
||||
# missing base.
|
||||
if not model_states:
|
||||
return
|
||||
# Prevent all model caches from being expired on each render.
|
||||
with self.bulk_update():
|
||||
unrendered_models = model_states
|
||||
while unrendered_models:
|
||||
new_unrendered_models = []
|
||||
for model in unrendered_models:
|
||||
try:
|
||||
model.render(self)
|
||||
except InvalidBasesError:
|
||||
new_unrendered_models.append(model)
|
||||
if len(new_unrendered_models) == len(unrendered_models):
|
||||
raise InvalidBasesError(
|
||||
"Cannot resolve bases for %r\nThis can happen if you are inheriting models from an "
|
||||
"app with migrations (e.g. contrib.auth)\n in an app with no migrations; see "
|
||||
"https://docs.djangoproject.com/en/%s/topics/migrations/#dependencies "
|
||||
"for more" % (new_unrendered_models, get_docs_version())
|
||||
)
|
||||
unrendered_models = new_unrendered_models
|
||||
|
||||
def clone(self):
|
||||
"""Return a clone of this registry."""
|
||||
clone = StateApps([], {})
|
||||
clone.all_models = copy.deepcopy(self.all_models)
|
||||
clone.app_configs = copy.deepcopy(self.app_configs)
|
||||
# Set the pointer to the correct app registry.
|
||||
for app_config in clone.app_configs.values():
|
||||
app_config.apps = clone
|
||||
# No need to actually clone them, they'll never change
|
||||
clone.real_models = self.real_models
|
||||
return clone
|
||||
|
||||
def register_model(self, app_label, model):
|
||||
self.all_models[app_label][model._meta.model_name] = model
|
||||
if app_label not in self.app_configs:
|
||||
self.app_configs[app_label] = AppConfigStub(app_label)
|
||||
self.app_configs[app_label].apps = self
|
||||
self.app_configs[app_label].models[model._meta.model_name] = model
|
||||
self.do_pending_operations(model)
|
||||
self.clear_cache()
|
||||
|
||||
def unregister_model(self, app_label, model_name):
|
||||
try:
|
||||
del self.all_models[app_label][model_name]
|
||||
del self.app_configs[app_label].models[model_name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
class ModelState:
|
||||
"""
|
||||
Represent a Django Model. Don't use the actual Model class as it's not
|
||||
designed to have its options changed - instead, mutate this one and then
|
||||
render it into a Model as required.
|
||||
|
||||
Note that while you are allowed to mutate .fields, you are not allowed
|
||||
to mutate the Field instances inside there themselves - you must instead
|
||||
assign new ones, as these are not detached during a clone.
|
||||
"""
|
||||
|
||||
def __init__(self, app_label, name, fields, options=None, bases=None, managers=None):
|
||||
self.app_label = app_label
|
||||
self.name = name
|
||||
self.fields = dict(fields)
|
||||
self.options = options or {}
|
||||
self.options.setdefault('indexes', [])
|
||||
self.options.setdefault('constraints', [])
|
||||
self.bases = bases or (models.Model,)
|
||||
self.managers = managers or []
|
||||
for name, field in self.fields.items():
|
||||
# Sanity-check that fields are NOT already bound to a model.
|
||||
if hasattr(field, 'model'):
|
||||
raise ValueError(
|
||||
'ModelState.fields cannot be bound to a model - "%s" is.' % name
|
||||
)
|
||||
# Sanity-check that relation fields are NOT referring to a model class.
|
||||
if field.is_relation and hasattr(field.related_model, '_meta'):
|
||||
raise ValueError(
|
||||
'ModelState.fields cannot refer to a model class - "%s.to" does. '
|
||||
'Use a string reference instead.' % name
|
||||
)
|
||||
if field.many_to_many and hasattr(field.remote_field.through, '_meta'):
|
||||
raise ValueError(
|
||||
'ModelState.fields cannot refer to a model class - "%s.through" does. '
|
||||
'Use a string reference instead.' % name
|
||||
)
|
||||
# Sanity-check that indexes have their name set.
|
||||
for index in self.options['indexes']:
|
||||
if not index.name:
|
||||
raise ValueError(
|
||||
"Indexes passed to ModelState require a name attribute. "
|
||||
"%r doesn't have one." % index
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def name_lower(self):
|
||||
return self.name.lower()
|
||||
|
||||
def get_field(self, field_name):
|
||||
field_name = (
|
||||
self.options['order_with_respect_to']
|
||||
if field_name == '_order'
|
||||
else field_name
|
||||
)
|
||||
return self.fields[field_name]
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model, exclude_rels=False):
|
||||
"""Given a model, return a ModelState representing it."""
|
||||
# Deconstruct the fields
|
||||
fields = []
|
||||
for field in model._meta.local_fields:
|
||||
if getattr(field, "remote_field", None) and exclude_rels:
|
||||
continue
|
||||
if isinstance(field, models.OrderWrt):
|
||||
continue
|
||||
name = field.name
|
||||
try:
|
||||
fields.append((name, field.clone()))
|
||||
except TypeError as e:
|
||||
raise TypeError("Couldn't reconstruct field %s on %s: %s" % (
|
||||
name,
|
||||
model._meta.label,
|
||||
e,
|
||||
))
|
||||
if not exclude_rels:
|
||||
for field in model._meta.local_many_to_many:
|
||||
name = field.name
|
||||
try:
|
||||
fields.append((name, field.clone()))
|
||||
except TypeError as e:
|
||||
raise TypeError("Couldn't reconstruct m2m field %s on %s: %s" % (
|
||||
name,
|
||||
model._meta.object_name,
|
||||
e,
|
||||
))
|
||||
# Extract the options
|
||||
options = {}
|
||||
for name in DEFAULT_NAMES:
|
||||
# Ignore some special options
|
||||
if name in ["apps", "app_label"]:
|
||||
continue
|
||||
elif name in model._meta.original_attrs:
|
||||
if name == "unique_together":
|
||||
ut = model._meta.original_attrs["unique_together"]
|
||||
options[name] = set(normalize_together(ut))
|
||||
elif name == "index_together":
|
||||
it = model._meta.original_attrs["index_together"]
|
||||
options[name] = set(normalize_together(it))
|
||||
elif name == "indexes":
|
||||
indexes = [idx.clone() for idx in model._meta.indexes]
|
||||
for index in indexes:
|
||||
if not index.name:
|
||||
index.set_name_with_model(model)
|
||||
options['indexes'] = indexes
|
||||
elif name == 'constraints':
|
||||
options['constraints'] = [con.clone() for con in model._meta.constraints]
|
||||
else:
|
||||
options[name] = model._meta.original_attrs[name]
|
||||
# If we're ignoring relationships, remove all field-listing model
|
||||
# options (that option basically just means "make a stub model")
|
||||
if exclude_rels:
|
||||
for key in ["unique_together", "index_together", "order_with_respect_to"]:
|
||||
if key in options:
|
||||
del options[key]
|
||||
# Private fields are ignored, so remove options that refer to them.
|
||||
elif options.get('order_with_respect_to') in {field.name for field in model._meta.private_fields}:
|
||||
del options['order_with_respect_to']
|
||||
|
||||
def flatten_bases(model):
|
||||
bases = []
|
||||
for base in model.__bases__:
|
||||
if hasattr(base, "_meta") and base._meta.abstract:
|
||||
bases.extend(flatten_bases(base))
|
||||
else:
|
||||
bases.append(base)
|
||||
return bases
|
||||
|
||||
# We can't rely on __mro__ directly because we only want to flatten
|
||||
# abstract models and not the whole tree. However by recursing on
|
||||
# __bases__ we may end up with duplicates and ordering issues, we
|
||||
# therefore discard any duplicates and reorder the bases according
|
||||
# to their index in the MRO.
|
||||
flattened_bases = sorted(set(flatten_bases(model)), key=lambda x: model.__mro__.index(x))
|
||||
|
||||
# Make our record
|
||||
bases = tuple(
|
||||
(
|
||||
base._meta.label_lower
|
||||
if hasattr(base, "_meta") else
|
||||
base
|
||||
)
|
||||
for base in flattened_bases
|
||||
)
|
||||
# Ensure at least one base inherits from models.Model
|
||||
if not any((isinstance(base, str) or issubclass(base, models.Model)) for base in bases):
|
||||
bases = (models.Model,)
|
||||
|
||||
managers = []
|
||||
manager_names = set()
|
||||
default_manager_shim = None
|
||||
for manager in model._meta.managers:
|
||||
if manager.name in manager_names:
|
||||
# Skip overridden managers.
|
||||
continue
|
||||
elif manager.use_in_migrations:
|
||||
# Copy managers usable in migrations.
|
||||
new_manager = copy.copy(manager)
|
||||
new_manager._set_creation_counter()
|
||||
elif manager is model._base_manager or manager is model._default_manager:
|
||||
# Shim custom managers used as default and base managers.
|
||||
new_manager = models.Manager()
|
||||
new_manager.model = manager.model
|
||||
new_manager.name = manager.name
|
||||
if manager is model._default_manager:
|
||||
default_manager_shim = new_manager
|
||||
else:
|
||||
continue
|
||||
manager_names.add(manager.name)
|
||||
managers.append((manager.name, new_manager))
|
||||
|
||||
# Ignore a shimmed default manager called objects if it's the only one.
|
||||
if managers == [('objects', default_manager_shim)]:
|
||||
managers = []
|
||||
|
||||
# Construct the new ModelState
|
||||
return cls(
|
||||
model._meta.app_label,
|
||||
model._meta.object_name,
|
||||
fields,
|
||||
options,
|
||||
bases,
|
||||
managers,
|
||||
)
|
||||
|
||||
def construct_managers(self):
|
||||
"""Deep-clone the managers using deconstruction."""
|
||||
# Sort all managers by their creation counter
|
||||
sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
|
||||
for mgr_name, manager in sorted_managers:
|
||||
as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
|
||||
if as_manager:
|
||||
qs_class = import_string(qs_path)
|
||||
yield mgr_name, qs_class.as_manager()
|
||||
else:
|
||||
manager_class = import_string(manager_path)
|
||||
yield mgr_name, manager_class(*args, **kwargs)
|
||||
|
||||
def clone(self):
|
||||
"""Return an exact copy of this ModelState."""
|
||||
return self.__class__(
|
||||
app_label=self.app_label,
|
||||
name=self.name,
|
||||
fields=dict(self.fields),
|
||||
# Since options are shallow-copied here, operations such as
|
||||
# AddIndex must replace their option (e.g 'indexes') rather
|
||||
# than mutating it.
|
||||
options=dict(self.options),
|
||||
bases=self.bases,
|
||||
managers=list(self.managers),
|
||||
)
|
||||
|
||||
def render(self, apps):
|
||||
"""Create a Model object from our current state into the given apps."""
|
||||
# First, make a Meta object
|
||||
meta_contents = {'app_label': self.app_label, 'apps': apps, **self.options}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
# Then, work out our bases
|
||||
try:
|
||||
bases = tuple(
|
||||
(apps.get_model(base) if isinstance(base, str) else base)
|
||||
for base in self.bases
|
||||
)
|
||||
except LookupError:
|
||||
raise InvalidBasesError("Cannot resolve one or more bases from %r" % (self.bases,))
|
||||
# Clone fields for the body, add other bits.
|
||||
body = {name: field.clone() for name, field in self.fields.items()}
|
||||
body['Meta'] = meta
|
||||
body['__module__'] = "__fake__"
|
||||
|
||||
# Restore managers
|
||||
body.update(self.construct_managers())
|
||||
# Then, make a Model object (apps.register_model is called in __new__)
|
||||
return type(self.name, bases, body)
|
||||
|
||||
def get_index_by_name(self, name):
|
||||
for index in self.options['indexes']:
|
||||
if index.name == name:
|
||||
return index
|
||||
raise ValueError("No index named %s on model %s" % (name, self.name))
|
||||
|
||||
def get_constraint_by_name(self, name):
|
||||
for constraint in self.options['constraints']:
|
||||
if constraint.name == name:
|
||||
return constraint
|
||||
raise ValueError('No constraint named %s on model %s' % (name, self.name))
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
(self.app_label == other.app_label) and
|
||||
(self.name == other.name) and
|
||||
(len(self.fields) == len(other.fields)) and
|
||||
all(
|
||||
k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
|
||||
for (k1, f1), (k2, f2) in zip(
|
||||
sorted(self.fields.items()),
|
||||
sorted(other.fields.items()),
|
||||
)
|
||||
) and
|
||||
(self.options == other.options) and
|
||||
(self.bases == other.bases) and
|
||||
(self.managers == other.managers)
|
||||
)
|
||||
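# Usage sketch (not from the vendored state.py above): ProjectState keeps
# ModelState objects keyed by (app_label, lowercased model name) and is meant
# to be cloned and mutated instead of touching the real app registry. Assumes
# a configured Django project with django.contrib.auth installed; the
# ('auth', 'user') key is only an example.
from django.apps import apps
from django.db.migrations.state import ProjectState

state = ProjectState.from_apps(apps)        # snapshot of the currently installed models
working_copy = state.clone()                # safe to mutate; the snapshot stays intact
user_state = working_copy.models.get(('auth', 'user'))   # a ModelState, or None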
118
venv/Lib/site-packages/django/db/migrations/utils.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import datetime
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
|
||||
|
||||
FieldReference = namedtuple('FieldReference', 'to through')
|
||||
|
||||
COMPILED_REGEX_TYPE = type(re.compile(''))
|
||||
|
||||
|
||||
class RegexObject:
|
||||
def __init__(self, obj):
|
||||
self.pattern = obj.pattern
|
||||
self.flags = obj.flags
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.pattern == other.pattern and self.flags == other.flags
|
||||
|
||||
|
||||
def get_migration_name_timestamp():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M")
|
||||
|
||||
|
||||
def resolve_relation(model, app_label=None, model_name=None):
|
||||
"""
|
||||
Turn a model class or a model reference string into a model tuple.
|
||||
|
||||
app_label and model_name are used to resolve the scope of recursive and
|
||||
unscoped model relationships.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
|
||||
if app_label is None or model_name is None:
|
||||
raise TypeError(
|
||||
'app_label and model_name must be provided to resolve '
|
||||
'recursive relationships.'
|
||||
)
|
||||
return app_label, model_name
|
||||
if '.' in model:
|
||||
app_label, model_name = model.split('.', 1)
|
||||
return app_label, model_name.lower()
|
||||
if app_label is None:
|
||||
raise TypeError(
|
||||
'app_label must be provided to resolve unscoped model '
|
||||
'relationships.'
|
||||
)
|
||||
return app_label, model.lower()
|
||||
return model._meta.app_label, model._meta.model_name
|
||||
|
||||
|
||||
def field_references(
|
||||
model_tuple,
|
||||
field,
|
||||
reference_model_tuple,
|
||||
reference_field_name=None,
|
||||
reference_field=None,
|
||||
):
|
||||
"""
|
||||
Return either False or a FieldReference if `field` references provided
|
||||
context.
|
||||
|
||||
False positives can be returned if `reference_field_name` is provided
|
||||
without `reference_field` because of the introspection limitation it
|
||||
incurs. This should not be an issue when this function is used to determine
|
||||
whether or not an optimization can take place.
|
||||
"""
|
||||
remote_field = field.remote_field
|
||||
if not remote_field:
|
||||
return False
|
||||
references_to = None
|
||||
references_through = None
|
||||
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
|
||||
to_fields = getattr(field, 'to_fields', None)
|
||||
if (
|
||||
reference_field_name is None or
|
||||
# Unspecified to_field(s).
|
||||
to_fields is None or
|
||||
# Reference to primary key.
|
||||
(None in to_fields and (reference_field is None or reference_field.primary_key)) or
|
||||
# Reference to field.
|
||||
reference_field_name in to_fields
|
||||
):
|
||||
references_to = (remote_field, to_fields)
|
||||
through = getattr(remote_field, 'through', None)
|
||||
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
|
||||
through_fields = remote_field.through_fields
|
||||
if (
|
||||
reference_field_name is None or
|
||||
# Unspecified through_fields.
|
||||
through_fields is None or
|
||||
# Reference to field.
|
||||
reference_field_name in through_fields
|
||||
):
|
||||
references_through = (remote_field, through_fields)
|
||||
if not (references_to or references_through):
|
||||
return False
|
||||
return FieldReference(references_to, references_through)
|
||||
|
||||
|
||||
def get_references(state, model_tuple, field_tuple=()):
|
||||
"""
|
||||
Generate (model_state, name, field, reference) tuples referencing the
|
||||
provided context.
|
||||
|
||||
If field_tuple is provided only references to this particular field of
|
||||
model_tuple will be generated.
|
||||
"""
|
||||
for state_model_tuple, model_state in state.models.items():
|
||||
for name, field in model_state.fields.items():
|
||||
reference = field_references(state_model_tuple, field, model_tuple, *field_tuple)
|
||||
if reference:
|
||||
yield model_state, name, field, reference
|
||||
|
||||
|
||||
def field_is_referenced(state, model_tuple, field_tuple):
|
||||
"""Return whether `field_tuple` is referenced by any state models."""
|
||||
return next(get_references(state, model_tuple, field_tuple), None) is not None
|
||||
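# Usage sketch (not from the vendored utils.py above): the three reference
# forms resolve_relation() accepts, following the branches of the function.
from django.db.migrations.utils import resolve_relation

resolve_relation('self', 'polls', 'question')    # ('polls', 'question') - recursive reference
resolve_relation('auth.User')                    # ('auth', 'user')      - dotted reference
resolve_relation('Question', app_label='polls')  # ('polls', 'question') - unscoped reference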
300
venv/Lib/site-packages/django/db/migrations/writer.py
Normal file
@@ -0,0 +1,300 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from importlib import import_module
|
||||
|
||||
from django import get_version
|
||||
from django.apps import apps
|
||||
# SettingsReference imported for backwards compatibility in Django 2.2.
|
||||
from django.conf import SettingsReference # NOQA
|
||||
from django.db import migrations
|
||||
from django.db.migrations.loader import MigrationLoader
|
||||
from django.db.migrations.serializer import Serializer, serializer_factory
|
||||
from django.utils.inspect import get_func_args
|
||||
from django.utils.module_loading import module_dir
|
||||
from django.utils.timezone import now
|
||||
|
||||
|
||||
class OperationWriter:
|
||||
def __init__(self, operation, indentation=2):
|
||||
self.operation = operation
|
||||
self.buff = []
|
||||
self.indentation = indentation
|
||||
|
||||
def serialize(self):
|
||||
|
||||
def _write(_arg_name, _arg_value):
|
||||
if (_arg_name in self.operation.serialization_expand_args and
|
||||
isinstance(_arg_value, (list, tuple, dict))):
|
||||
if isinstance(_arg_value, dict):
|
||||
self.feed('%s={' % _arg_name)
|
||||
self.indent()
|
||||
for key, value in _arg_value.items():
|
||||
key_string, key_imports = MigrationWriter.serialize(key)
|
||||
arg_string, arg_imports = MigrationWriter.serialize(value)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
self.feed('%s: %s' % (key_string, args[0]))
|
||||
for arg in args[1:-1]:
|
||||
self.feed(arg)
|
||||
self.feed('%s,' % args[-1])
|
||||
else:
|
||||
self.feed('%s: %s,' % (key_string, arg_string))
|
||||
imports.update(key_imports)
|
||||
imports.update(arg_imports)
|
||||
self.unindent()
|
||||
self.feed('},')
|
||||
else:
|
||||
self.feed('%s=[' % _arg_name)
|
||||
self.indent()
|
||||
for item in _arg_value:
|
||||
arg_string, arg_imports = MigrationWriter.serialize(item)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
for arg in args[:-1]:
|
||||
self.feed(arg)
|
||||
self.feed('%s,' % args[-1])
|
||||
else:
|
||||
self.feed('%s,' % arg_string)
|
||||
imports.update(arg_imports)
|
||||
self.unindent()
|
||||
self.feed('],')
|
||||
else:
|
||||
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
|
||||
args = arg_string.splitlines()
|
||||
if len(args) > 1:
|
||||
self.feed('%s=%s' % (_arg_name, args[0]))
|
||||
for arg in args[1:-1]:
|
||||
self.feed(arg)
|
||||
self.feed('%s,' % args[-1])
|
||||
else:
|
||||
self.feed('%s=%s,' % (_arg_name, arg_string))
|
||||
imports.update(arg_imports)
|
||||
|
||||
imports = set()
|
||||
name, args, kwargs = self.operation.deconstruct()
|
||||
operation_args = get_func_args(self.operation.__init__)
|
||||
|
||||
# See if this operation is in django.db.migrations. If it is,
|
||||
# we can just use the fact we already have that imported,
|
||||
# otherwise, we need to add an import for the operation class.
|
||||
if getattr(migrations, name, None) == self.operation.__class__:
|
||||
self.feed('migrations.%s(' % name)
|
||||
else:
|
||||
imports.add('import %s' % (self.operation.__class__.__module__))
|
||||
self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
|
||||
|
||||
self.indent()
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
arg_value = arg
|
||||
arg_name = operation_args[i]
|
||||
_write(arg_name, arg_value)
|
||||
|
||||
i = len(args)
|
||||
# Only iterate over remaining arguments
|
||||
for arg_name in operation_args[i:]:
|
||||
if arg_name in kwargs: # Don't sort to maintain signature order
|
||||
arg_value = kwargs[arg_name]
|
||||
_write(arg_name, arg_value)
|
||||
|
||||
self.unindent()
|
||||
self.feed('),')
|
||||
return self.render(), imports
|
||||
|
||||
def indent(self):
|
||||
self.indentation += 1
|
||||
|
||||
def unindent(self):
|
||||
self.indentation -= 1
|
||||
|
||||
def feed(self, line):
|
||||
self.buff.append(' ' * (self.indentation * 4) + line)
|
||||
|
||||
def render(self):
|
||||
return '\n'.join(self.buff)
|
||||
|
||||
|
||||
class MigrationWriter:
|
||||
"""
|
||||
Take a Migration instance and produce the contents
|
||||
of the migration file from it.
|
||||
"""
|
||||
|
||||
def __init__(self, migration, include_header=True):
|
||||
self.migration = migration
|
||||
self.include_header = include_header
|
||||
self.needs_manual_porting = False
|
||||
|
||||
def as_string(self):
|
||||
"""Return a string of the file contents."""
|
||||
items = {
|
||||
"replaces_str": "",
|
||||
"initial_str": "",
|
||||
}
|
||||
|
||||
imports = set()
|
||||
|
||||
# Deconstruct operations
|
||||
operations = []
|
||||
for operation in self.migration.operations:
|
||||
operation_string, operation_imports = OperationWriter(operation).serialize()
|
||||
imports.update(operation_imports)
|
||||
operations.append(operation_string)
|
||||
items["operations"] = "\n".join(operations) + "\n" if operations else ""
|
||||
|
||||
# Format dependencies and write out swappable dependencies right
|
||||
dependencies = []
|
||||
for dependency in self.migration.dependencies:
|
||||
if dependency[0] == "__setting__":
|
||||
dependencies.append(" migrations.swappable_dependency(settings.%s)," % dependency[1])
|
||||
imports.add("from django.conf import settings")
|
||||
else:
|
||||
dependencies.append(" %s," % self.serialize(dependency)[0])
|
||||
items["dependencies"] = "\n".join(dependencies) + "\n" if dependencies else ""
|
||||
|
||||
# Format imports nicely, swapping imports of functions from migration files
|
||||
# for comments
|
||||
migration_imports = set()
|
||||
for line in list(imports):
|
||||
if re.match(r"^import (.*)\.\d+[^\s]*$", line):
|
||||
migration_imports.add(line.split("import")[1].strip())
|
||||
imports.remove(line)
|
||||
self.needs_manual_porting = True
|
||||
|
||||
# django.db.migrations is always used, but models import may not be.
|
||||
# If models import exists, merge it with migrations import.
|
||||
if "from django.db import models" in imports:
|
||||
imports.discard("from django.db import models")
|
||||
imports.add("from django.db import migrations, models")
|
||||
else:
|
||||
imports.add("from django.db import migrations")
|
||||
|
||||
# Sort imports by the package / module to be imported (the part after
|
||||
# "from" in "from ... import ..." or after "import" in "import ...").
|
||||
sorted_imports = sorted(imports, key=lambda i: i.split()[1])
|
||||
items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
|
||||
if migration_imports:
|
||||
items["imports"] += (
|
||||
"\n\n# Functions from the following migrations need manual "
|
||||
"copying.\n# Move them and any dependencies into this file, "
|
||||
"then update the\n# RunPython operations to refer to the local "
|
||||
"versions:\n# %s"
|
||||
) % "\n# ".join(sorted(migration_imports))
|
||||
# If there's a replaces, make a string for it
|
||||
if self.migration.replaces:
|
||||
items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
|
||||
# Hinting that goes into comment
|
||||
if self.include_header:
|
||||
items['migration_header'] = MIGRATION_HEADER_TEMPLATE % {
|
||||
'version': get_version(),
|
||||
'timestamp': now().strftime("%Y-%m-%d %H:%M"),
|
||||
}
|
||||
else:
|
||||
items['migration_header'] = ""
|
||||
|
||||
if self.migration.initial:
|
||||
items['initial_str'] = "\n initial = True\n"
|
||||
|
||||
return MIGRATION_TEMPLATE % items
|
||||
|
||||
@property
|
||||
def basedir(self):
|
||||
migrations_package_name, _ = MigrationLoader.migrations_module(self.migration.app_label)
|
||||
|
||||
if migrations_package_name is None:
|
||||
raise ValueError(
|
||||
"Django can't create migrations for app '%s' because "
|
||||
"migrations have been disabled via the MIGRATION_MODULES "
|
||||
"setting." % self.migration.app_label
|
||||
)
|
||||
|
||||
# See if we can import the migrations module directly
|
||||
try:
|
||||
migrations_module = import_module(migrations_package_name)
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
return module_dir(migrations_module)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Alright, see if it's a direct submodule of the app
|
||||
app_config = apps.get_app_config(self.migration.app_label)
|
||||
maybe_app_name, _, migrations_package_basename = migrations_package_name.rpartition(".")
|
||||
if app_config.name == maybe_app_name:
|
||||
return os.path.join(app_config.path, migrations_package_basename)
|
||||
|
||||
# In case of using MIGRATION_MODULES setting and the custom package
|
||||
# doesn't exist, create one, starting from an existing package
|
||||
existing_dirs, missing_dirs = migrations_package_name.split("."), []
|
||||
while existing_dirs:
|
||||
missing_dirs.insert(0, existing_dirs.pop(-1))
|
||||
try:
|
||||
base_module = import_module(".".join(existing_dirs))
|
||||
except (ImportError, ValueError):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
base_dir = module_dir(base_module)
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not locate an appropriate location to create "
|
||||
"migrations package %s. Make sure the toplevel "
|
||||
"package exists and can be imported." %
|
||||
migrations_package_name)
|
||||
|
||||
final_dir = os.path.join(base_dir, *missing_dirs)
|
||||
os.makedirs(final_dir, exist_ok=True)
|
||||
for missing_dir in missing_dirs:
|
||||
base_dir = os.path.join(base_dir, missing_dir)
|
||||
with open(os.path.join(base_dir, "__init__.py"), "w"):
|
||||
pass
|
||||
|
||||
return final_dir
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return "%s.py" % self.migration.name
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return os.path.join(self.basedir, self.filename)
|
||||
|
||||
@classmethod
|
||||
def serialize(cls, value):
|
||||
return serializer_factory(value).serialize()
|
||||
|
||||
@classmethod
|
||||
def register_serializer(cls, type_, serializer):
|
||||
Serializer.register(type_, serializer)
|
||||
|
||||
@classmethod
|
||||
def unregister_serializer(cls, type_):
|
||||
Serializer.unregister(type_)
|
||||
|
||||
|
||||
MIGRATION_HEADER_TEMPLATE = """\
|
||||
# Generated by Django %(version)s on %(timestamp)s
|
||||
|
||||
"""
|
||||
|
||||
|
||||
MIGRATION_TEMPLATE = """\
|
||||
%(migration_header)s%(imports)s
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
%(replaces_str)s%(initial_str)s
|
||||
dependencies = [
|
||||
%(dependencies)s\
|
||||
]
|
||||
|
||||
operations = [
|
||||
%(operations)s\
|
||||
]
|
||||
"""
|
||||
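For orientation, here is roughly what a file rendered from MIGRATION_TEMPLATE looks like once the items dictionary is interpolated; the app, model, and timestamp below are made-up examples, not output of this module.

# Generated by Django 4.0 on 2022-01-01 12:00

from django.db import migrations, models


class Migration(migrations.Migration):

    initial = True

    dependencies = [
    ]

    operations = [
        migrations.CreateModel(
            name='Book',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('title', models.CharField(max_length=200)),
            ],
        ),
    ]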
52
venv/Lib/site-packages/django/db/models/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals
from django.db.models.aggregates import *  # NOQA
from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import *  # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
    CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL,
    ProtectedError, RestrictedError,
)
from django.db.models.enums import *  # NOQA
from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import (
    Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
    OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window,
    WindowFrame,
)
from django.db.models.fields import *  # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.json import JSONField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.indexes import *  # NOQA
from django.db.models.indexes import __all__ as indexes_all
from django.db.models.lookups import Lookup, Transform
from django.db.models.manager import Manager
from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects
from django.db.models.query_utils import FilteredRelation, Q

# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model  # isort:skip
from django.db.models.fields.related import (  # isort:skip
    ForeignKey, ForeignObject, OneToOneField, ManyToManyField,
    ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel,
)


__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
__all__ += [
    'ObjectDoesNotExist', 'signals',
    'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
    'SET_NULL', 'ProtectedError', 'RestrictedError',
    'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
    'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value',
    'ValueRange', 'When',
    'Window', 'WindowFrame',
    'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform',
    'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects',
    'DEFERRED', 'Model', 'FilteredRelation',
    'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
    'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel',
]
|
||||
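A minimal sketch of how this re-exported API is typically consumed in application code; the Author model below is hypothetical and not part of this module.

from django.db import models


class Author(models.Model):  # hypothetical model, for illustration only
    name = models.CharField(max_length=100)
    active = models.BooleanField(default=True)


# Q, F and the field classes used above all come from the re-exports in this module:
# Author.objects.filter(models.Q(active=True) | models.Q(name__startswith='A'))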
165
venv/Lib/site-packages/django/db/models/aggregates.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Classes to represent the definitions of aggregate functions.
|
||||
"""
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.expressions import Case, Func, Star, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin, NumericOutputFieldMixin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||
]
|
||||
|
||||
|
||||
class Aggregate(Func):
|
||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
filter_template = '%s FILTER (WHERE %%(filter)s)'
|
||||
window_compatible = True
|
||||
allow_distinct = False
|
||||
empty_result_set_value = None
|
||||
|
||||
def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if default is not None and self.empty_result_set_value is not None:
|
||||
raise TypeError(f'{self.__class__.__name__} does not allow default.')
|
||||
self.distinct = distinct
|
||||
self.filter = filter
|
||||
self.default = default
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
# Don't return the filter expression since it's not a source field.
|
||||
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||
|
||||
def get_source_expressions(self):
|
||||
source_expressions = super().get_source_expressions()
|
||||
if self.filter:
|
||||
return source_expressions + [self.filter]
|
||||
return source_expressions
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.filter = self.filter and exprs.pop()
|
||||
return super().set_source_expressions(exprs)
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if not summarize:
|
||||
# Call Aggregate.get_source_expressions() to avoid
|
||||
# returning self.filter and including that in this loop.
|
||||
expressions = super(Aggregate, c).get_source_expressions()
|
||||
for index, expr in enumerate(expressions):
|
||||
if expr.contains_aggregate:
|
||||
before_resolved = self.get_source_expressions()[index]
|
||||
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
|
||||
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
|
||||
if (default := c.default) is None:
|
||||
return c
|
||||
if hasattr(default, 'resolve_expression'):
|
||||
default = default.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.default = None # Reset the default argument before wrapping.
|
||||
return Coalesce(c, default, output_field=c._output_field_or_none)
|
||||
|
||||
@property
|
||||
def default_alias(self):
|
||||
expressions = self.get_source_expressions()
|
||||
if len(expressions) == 1 and hasattr(expressions[0], 'name'):
|
||||
return '%s__%s' % (expressions[0].name, self.name.lower())
|
||||
raise TypeError("Complex expressions require an alias")
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
return []
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context['distinct'] = 'DISTINCT ' if self.distinct else ''
|
||||
if self.filter:
|
||||
if connection.features.supports_aggregate_filter_clause:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
template = self.filter_template % extra_context.get('template', self.template)
|
||||
sql, params = super().as_sql(
|
||||
compiler, connection, template=template, filter=filter_sql,
|
||||
**extra_context
|
||||
)
|
||||
return sql, params + filter_params
|
||||
else:
|
||||
copy = self.copy()
|
||||
copy.filter = None
|
||||
source_expressions = copy.get_source_expressions()
|
||||
condition = When(self.filter, then=source_expressions[0])
|
||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def _get_repr_options(self):
|
||||
options = super()._get_repr_options()
|
||||
if self.distinct:
|
||||
options['distinct'] = self.distinct
|
||||
if self.filter:
|
||||
options['filter'] = self.filter
|
||||
return options
|
||||
|
||||
|
||||
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
|
||||
function = 'AVG'
|
||||
name = 'Avg'
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
function = 'COUNT'
|
||||
name = 'Count'
|
||||
output_field = IntegerField()
|
||||
allow_distinct = True
|
||||
empty_result_set_value = 0
|
||||
|
||||
def __init__(self, expression, filter=None, **extra):
|
||||
if expression == '*':
|
||||
expression = Star()
|
||||
if isinstance(expression, Star) and filter is not None:
|
||||
raise ValueError('Star cannot be used with filter. Please specify a field.')
|
||||
super().__init__(expression, filter=filter, **extra)
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = 'MAX'
|
||||
name = 'Max'
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = 'MIN'
|
||||
name = 'Min'
|
||||
|
||||
|
||||
class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
name = 'StdDev'
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
|
||||
|
||||
|
||||
class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = 'SUM'
|
||||
name = 'Sum'
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Variance(NumericOutputFieldMixin, Aggregate):
|
||||
name = 'Variance'
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'}
|
||||
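A hedged usage sketch for these aggregate classes; Book and its fields are assumptions, not part of this module.

from django.db.models import Avg, Count, Q

# Average price with a fallback for empty querysets; resolve_expression() above
# wraps the aggregate in Coalesce(aggregate, default) when `default` is given.
# Book.objects.aggregate(avg_price=Avg('price', default=0))

# Count only published rows; emitted as FILTER (WHERE ...) when the backend
# supports it, otherwise rewritten with the Case/When fallback in as_sql().
# Book.objects.aggregate(published=Count('pk', filter=Q(published=True)))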
2190
venv/Lib/site-packages/django/db/models/base.py
Normal file
File diff suppressed because it is too large
6
venv/Lib/site-packages/django/db/models/constants.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Constants used across the ORM in general.
|
||||
"""
|
||||
|
||||
# Separator used to split filter strings apart.
|
||||
LOOKUP_SEP = '__'
|
||||
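For illustration, LOOKUP_SEP is what splits a lookup path such as author__name__icontains into its parts:

from django.db.models.constants import LOOKUP_SEP

parts = 'author__name__icontains'.split(LOOKUP_SEP)
# -> ['author', 'name', 'icontains']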
255
venv/Lib/site-packages/django/db/models/constraints.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models.expressions import ExpressionList, F
|
||||
from django.db.models.indexes import IndexExpression
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ['CheckConstraint', 'Deferrable', 'UniqueConstraint']
|
||||
|
||||
|
||||
class BaseConstraint:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return False
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
raise NotImplementedError('This method must be implemented by a subclass.')
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
raise NotImplementedError('This method must be implemented by a subclass.')
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
raise NotImplementedError('This method must be implemented by a subclass.')
|
||||
|
||||
def deconstruct(self):
|
||||
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace('django.db.models.constraints', 'django.db.models')
|
||||
return (path, (), {'name': self.name})
|
||||
|
||||
def clone(self):
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
||||
|
||||
|
||||
class CheckConstraint(BaseConstraint):
|
||||
def __init__(self, *, check, name):
|
||||
self.check = check
|
||||
if not getattr(check, 'conditional', False):
|
||||
raise TypeError(
|
||||
'CheckConstraint.check must be a Q instance or boolean '
|
||||
'expression.'
|
||||
)
|
||||
super().__init__(name)
|
||||
|
||||
def _get_check_sql(self, model, schema_editor):
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.check)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._check_sql(self.name, check)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._create_check_sql(model, self.name, check)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
return schema_editor._delete_check_sql(model, self.name)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: check=%s name=%s>' % (
|
||||
self.__class__.__qualname__,
|
||||
self.check,
|
||||
repr(self.name),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CheckConstraint):
|
||||
return self.name == other.name and self.check == other.check
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
kwargs['check'] = self.check
|
||||
return path, args, kwargs
|
||||
|
||||
|
||||
class Deferrable(Enum):
|
||||
DEFERRED = 'deferred'
|
||||
IMMEDIATE = 'immediate'
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__qualname__}.{self._name_}'
|
||||
|
||||
|
||||
class UniqueConstraint(BaseConstraint):
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
condition=None,
|
||||
deferrable=None,
|
||||
include=None,
|
||||
opclasses=(),
|
||||
):
|
||||
if not name:
|
||||
raise ValueError('A unique constraint must be named.')
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
'At least one field or expression is required to define a '
|
||||
'unique constraint.'
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
'UniqueConstraint.fields and expressions are mutually exclusive.'
|
||||
)
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError('UniqueConstraint.condition must be a Q instance.')
|
||||
if condition and deferrable:
|
||||
raise ValueError(
|
||||
'UniqueConstraint with conditions cannot be deferred.'
|
||||
)
|
||||
if include and deferrable:
|
||||
raise ValueError(
|
||||
'UniqueConstraint with include fields cannot be deferred.'
|
||||
)
|
||||
if opclasses and deferrable:
|
||||
raise ValueError(
|
||||
'UniqueConstraint with opclasses cannot be deferred.'
|
||||
)
|
||||
if expressions and deferrable:
|
||||
raise ValueError(
|
||||
'UniqueConstraint with expressions cannot be deferred.'
|
||||
)
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
'UniqueConstraint.opclasses cannot be used with expressions. '
|
||||
'Use django.contrib.postgres.indexes.OpClass() instead.'
|
||||
)
|
||||
if not isinstance(deferrable, (type(None), Deferrable)):
|
||||
raise ValueError(
|
||||
'UniqueConstraint.deferrable must be a Deferrable instance.'
|
||||
)
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError('UniqueConstraint.include must be a list or tuple.')
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
'UniqueConstraint.fields and UniqueConstraint.opclasses must '
|
||||
'have the same number of elements.'
|
||||
)
|
||||
self.fields = tuple(fields)
|
||||
self.condition = condition
|
||||
self.deferrable = deferrable
|
||||
self.include = tuple(include) if include else ()
|
||||
self.opclasses = opclasses
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
super().__init__(name)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def _get_index_expressions(self, model, schema_editor):
|
||||
if not self.expressions:
|
||||
return None
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
return ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._unique_sql(
|
||||
model, fields, self.name, condition=condition,
|
||||
deferrable=self.deferrable, include=include,
|
||||
opclasses=self.opclasses, expressions=expressions,
|
||||
)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._create_unique_sql(
|
||||
model, fields, self.name, condition=condition,
|
||||
deferrable=self.deferrable, include=include,
|
||||
opclasses=self.opclasses, expressions=expressions,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._delete_unique_sql(
|
||||
model, self.name, condition=condition, deferrable=self.deferrable,
|
||||
include=include, opclasses=self.opclasses, expressions=expressions,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s:%s%s%s%s%s%s%s>' % (
|
||||
self.__class__.__qualname__,
|
||||
'' if not self.fields else ' fields=%s' % repr(self.fields),
|
||||
'' if not self.expressions else ' expressions=%s' % repr(self.expressions),
|
||||
' name=%s' % repr(self.name),
|
||||
'' if self.condition is None else ' condition=%s' % self.condition,
|
||||
'' if self.deferrable is None else ' deferrable=%r' % self.deferrable,
|
||||
'' if not self.include else ' include=%s' % repr(self.include),
|
||||
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, UniqueConstraint):
|
||||
return (
|
||||
self.name == other.name and
|
||||
self.fields == other.fields and
|
||||
self.condition == other.condition and
|
||||
self.deferrable == other.deferrable and
|
||||
self.include == other.include and
|
||||
self.opclasses == other.opclasses and
|
||||
self.expressions == other.expressions
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
if self.fields:
|
||||
kwargs['fields'] = self.fields
|
||||
if self.condition:
|
||||
kwargs['condition'] = self.condition
|
||||
if self.deferrable:
|
||||
kwargs['deferrable'] = self.deferrable
|
||||
if self.include:
|
||||
kwargs['include'] = self.include
|
||||
if self.opclasses:
|
||||
kwargs['opclasses'] = self.opclasses
|
||||
return path, self.expressions, kwargs
|
||||
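A hedged sketch of declaring these constraints on a model's Meta; the Product model and its fields are assumptions.

from django.db import models
from django.db.models import CheckConstraint, F, Q, UniqueConstraint


class Product(models.Model):  # hypothetical model, for illustration only
    sku = models.CharField(max_length=32)
    price = models.DecimalField(max_digits=8, decimal_places=2)
    discounted_price = models.DecimalField(max_digits=8, decimal_places=2, null=True)

    class Meta:
        constraints = [
            UniqueConstraint(fields=['sku'], name='unique_sku'),
            CheckConstraint(
                check=Q(discounted_price__isnull=True) | Q(discounted_price__lt=F('price')),
                name='discount_below_price',
            ),
        ]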
449
venv/Lib/site-packages/django/db/models/deletion.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from collections import Counter, defaultdict
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
|
||||
from django.db import IntegrityError, connections, transaction
|
||||
from django.db.models import query_utils, signals, sql
|
||||
|
||||
|
||||
class ProtectedError(IntegrityError):
|
||||
def __init__(self, msg, protected_objects):
|
||||
self.protected_objects = protected_objects
|
||||
super().__init__(msg, protected_objects)
|
||||
|
||||
|
||||
class RestrictedError(IntegrityError):
|
||||
def __init__(self, msg, restricted_objects):
|
||||
self.restricted_objects = restricted_objects
|
||||
super().__init__(msg, restricted_objects)
|
||||
|
||||
|
||||
def CASCADE(collector, field, sub_objs, using):
|
||||
collector.collect(
|
||||
sub_objs, source=field.remote_field.model, source_attr=field.name,
|
||||
nullable=field.null, fail_on_restricted=False,
|
||||
)
|
||||
if field.null and not connections[using].features.can_defer_constraint_checks:
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
def PROTECT(collector, field, sub_objs, using):
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model '%s' because they are "
|
||||
"referenced through a protected foreign key: '%s.%s'" % (
|
||||
field.remote_field.model.__name__, sub_objs[0].__class__.__name__, field.name
|
||||
),
|
||||
sub_objs
|
||||
)
|
||||
|
||||
|
||||
def RESTRICT(collector, field, sub_objs, using):
|
||||
collector.add_restricted_objects(field, sub_objs)
|
||||
collector.add_dependency(field.remote_field.model, field.model)
|
||||
|
||||
|
||||
def SET(value):
|
||||
if callable(value):
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value(), sub_objs)
|
||||
else:
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value, sub_objs)
|
||||
set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {})
|
||||
return set_on_delete
|
||||
|
||||
|
||||
def SET_NULL(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
def SET_DEFAULT(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, field.get_default(), sub_objs)
|
||||
|
||||
|
||||
def DO_NOTHING(collector, field, sub_objs, using):
|
||||
pass
|
||||
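A brief sketch of attaching these handlers to a ForeignKey; the models and the sentinel helper below are hypothetical.

from django.db import models


def get_sentinel_author():
    # Hypothetical helper that returns a placeholder row to reassign books to.
    return Author.objects.get_or_create(name='[deleted]')[0]


class Book(models.Model):  # hypothetical model, for illustration only
    # SET() accepts either a value or a callable evaluated at deletion time;
    # CASCADE, PROTECT, RESTRICT, SET_NULL, SET_DEFAULT and DO_NOTHING are
    # passed the same way as the second (on_delete) argument.
    author = models.ForeignKey('Author', models.SET(get_sentinel_author))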
|
||||
|
||||
def get_candidate_relations_to_delete(opts):
|
||||
# The candidate relations are the ones that come from N-1 and 1-1 relations.
|
||||
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
|
||||
return (
|
||||
f for f in opts.get_fields(include_hidden=True)
|
||||
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
|
||||
)
|
||||
|
||||
|
||||
class Collector:
|
||||
def __init__(self, using):
|
||||
self.using = using
|
||||
# Initially, {model: {instances}}, later values become lists.
|
||||
self.data = defaultdict(set)
|
||||
# {model: {(field, value): {instances}}}
|
||||
self.field_updates = defaultdict(partial(defaultdict, set))
|
||||
# {model: {field: {instances}}}
|
||||
self.restricted_objects = defaultdict(partial(defaultdict, set))
|
||||
# fast_deletes is a list of queryset-likes that can be deleted without
|
||||
# fetching the objects into memory.
|
||||
self.fast_deletes = []
|
||||
|
||||
# Tracks deletion-order dependency for databases without transactions
|
||||
# or ability to defer constraint checks. Only concrete model classes
|
||||
# should be included, as the dependencies exist only between actual
|
||||
# database tables; proxy models are represented here by their concrete
|
||||
# parent.
|
||||
self.dependencies = defaultdict(set) # {model: {models}}
|
||||
|
||||
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted. If the call is
|
||||
the result of a cascade, 'source' should be the model that caused it,
|
||||
and 'nullable' should be set to True if the relation can be null.
|
||||
|
||||
Return a list of all objects that were not already collected.
|
||||
"""
|
||||
if not objs:
|
||||
return []
|
||||
new_objs = []
|
||||
model = objs[0].__class__
|
||||
instances = self.data[model]
|
||||
for obj in objs:
|
||||
if obj not in instances:
|
||||
new_objs.append(obj)
|
||||
instances.update(new_objs)
|
||||
# Nullable relationships can be ignored -- they are nulled out before
|
||||
# deleting, and therefore do not affect the order in which objects have
|
||||
# to be deleted.
|
||||
if source is not None and not nullable:
|
||||
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
|
||||
return new_objs
|
||||
|
||||
def add_dependency(self, model, dependency, reverse_dependency=False):
|
||||
if reverse_dependency:
|
||||
model, dependency = dependency, model
|
||||
self.dependencies[model._meta.concrete_model].add(dependency._meta.concrete_model)
|
||||
self.data.setdefault(dependency, self.data.default_factory())
|
||||
|
||||
def add_field_update(self, field, value, objs):
|
||||
"""
|
||||
Schedule a field update. 'objs' must be a homogeneous iterable
|
||||
collection of model instances (e.g. a QuerySet).
|
||||
"""
|
||||
if not objs:
|
||||
return
|
||||
model = objs[0].__class__
|
||||
self.field_updates[model][field, value].update(objs)
|
||||
|
||||
def add_restricted_objects(self, field, objs):
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
self.restricted_objects[model][field].update(objs)
|
||||
|
||||
def clear_restricted_objects_from_set(self, model, objs):
|
||||
if model in self.restricted_objects:
|
||||
self.restricted_objects[model] = {
|
||||
field: items - objs
|
||||
for field, items in self.restricted_objects[model].items()
|
||||
}
|
||||
|
||||
def clear_restricted_objects_from_queryset(self, model, qs):
|
||||
if model in self.restricted_objects:
|
||||
objs = set(qs.filter(pk__in=[
|
||||
obj.pk
|
||||
for objs in self.restricted_objects[model].values() for obj in objs
|
||||
]))
|
||||
self.clear_restricted_objects_from_set(model, objs)
|
||||
|
||||
def _has_signal_listeners(self, model):
|
||||
return (
|
||||
signals.pre_delete.has_listeners(model) or
|
||||
signals.post_delete.has_listeners(model)
|
||||
)
|
||||
|
||||
def can_fast_delete(self, objs, from_field=None):
|
||||
"""
|
||||
Determine if the objects in the given queryset-like or single object
|
||||
can be fast-deleted. This can be done if there are no cascades, no
|
||||
parents and no signal listeners for the object class.
|
||||
|
||||
The 'from_field' tells where we are coming from - we need this to
|
||||
determine if the objects are in fact to be deleted. Allow also
|
||||
skipping parent -> child -> parent chain preventing fast delete of
|
||||
the child.
|
||||
"""
|
||||
if from_field and from_field.remote_field.on_delete is not CASCADE:
|
||||
return False
|
||||
if hasattr(objs, '_meta'):
|
||||
model = objs._meta.model
|
||||
elif hasattr(objs, 'model') and hasattr(objs, '_raw_delete'):
|
||||
model = objs.model
|
||||
else:
|
||||
return False
|
||||
if self._has_signal_listeners(model):
|
||||
return False
|
||||
# The use of from_field comes from the need to avoid cascade back to
|
||||
# parent when parent delete is cascading to child.
|
||||
opts = model._meta
|
||||
return (
|
||||
all(link == from_field for link in opts.concrete_model._meta.parents.values()) and
|
||||
# Foreign keys pointing to this model.
|
||||
all(
|
||||
related.field.remote_field.on_delete is DO_NOTHING
|
||||
for related in get_candidate_relations_to_delete(opts)
|
||||
) and (
|
||||
# Something like generic foreign key.
|
||||
not any(hasattr(field, 'bulk_related_objects') for field in opts.private_fields)
|
||||
)
|
||||
)
|
||||
|
||||
def get_del_batches(self, objs, fields):
|
||||
"""
|
||||
Return the objs in suitably sized batches for the used connection.
|
||||
"""
|
||||
field_names = [field.name for field in fields]
|
||||
conn_batch_size = max(
|
||||
connections[self.using].ops.bulk_batch_size(field_names, objs), 1)
|
||||
if len(objs) > conn_batch_size:
|
||||
return [objs[i:i + conn_batch_size]
|
||||
for i in range(0, len(objs), conn_batch_size)]
|
||||
else:
|
||||
return [objs]
|
||||
|
||||
def collect(self, objs, source=None, nullable=False, collect_related=True,
|
||||
source_attr=None, reverse_dependency=False, keep_parents=False,
|
||||
fail_on_restricted=True):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted as well as all
|
||||
parent instances. 'objs' must be a homogeneous iterable collection of
|
||||
model instances (e.g. a QuerySet). If 'collect_related' is True,
|
||||
related objects will be handled by their respective on_delete handler.
|
||||
|
||||
If the call is the result of a cascade, 'source' should be the model
|
||||
that caused it and 'nullable' should be set to True, if the relation
|
||||
can be null.
|
||||
|
||||
If 'reverse_dependency' is True, 'source' will be deleted before the
|
||||
current model, rather than after. (Needed for cascading to parent
|
||||
models, the one case in which the cascade follows the forwards
|
||||
direction of an FK rather than the reverse direction.)
|
||||
|
||||
If 'keep_parents' is True, data of parent models will not be deleted.
|
||||
|
||||
If 'fail_on_restricted' is False, error won't be raised even if it's
|
||||
prohibited to delete such objects due to RESTRICT, that defers
|
||||
restricted object checking in recursive calls where the top-level call
|
||||
may need to collect more objects to determine whether restricted ones
|
||||
can be deleted.
|
||||
"""
|
||||
if self.can_fast_delete(objs):
|
||||
self.fast_deletes.append(objs)
|
||||
return
|
||||
new_objs = self.add(objs, source, nullable,
|
||||
reverse_dependency=reverse_dependency)
|
||||
if not new_objs:
|
||||
return
|
||||
|
||||
model = new_objs[0].__class__
|
||||
|
||||
if not keep_parents:
|
||||
# Recursively collect concrete model's parent models, but not their
|
||||
# related objects. These will be found by meta.get_fields()
|
||||
concrete_model = model._meta.concrete_model
|
||||
for ptr in concrete_model._meta.parents.values():
|
||||
if ptr:
|
||||
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
|
||||
self.collect(parent_objs, source=model,
|
||||
source_attr=ptr.remote_field.related_name,
|
||||
collect_related=False,
|
||||
reverse_dependency=True,
|
||||
fail_on_restricted=False)
|
||||
if not collect_related:
|
||||
return
|
||||
|
||||
if keep_parents:
|
||||
parents = set(model._meta.get_parent_list())
|
||||
model_fast_deletes = defaultdict(list)
|
||||
protected_objects = defaultdict(list)
|
||||
for related in get_candidate_relations_to_delete(model._meta):
|
||||
# Preserve parent reverse relationships if keep_parents=True.
|
||||
if keep_parents and related.model in parents:
|
||||
continue
|
||||
field = related.field
|
||||
if field.remote_field.on_delete == DO_NOTHING:
|
||||
continue
|
||||
related_model = related.related_model
|
||||
if self.can_fast_delete(related_model, from_field=field):
|
||||
model_fast_deletes[related_model].append(field)
|
||||
continue
|
||||
batches = self.get_del_batches(new_objs, [field])
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, [field], batch)
|
||||
# Non-referenced fields can be deferred if no signal receivers
|
||||
# are connected for the related model as they'll never be
|
||||
# exposed to the user. Skip field deferring when some
|
||||
# relationships are select_related as interactions between both
|
||||
# features are hard to get right. This should only happen in
|
||||
# the rare cases where .related_objects is overridden anyway.
|
||||
if not (sub_objs.query.select_related or self._has_signal_listeners(related_model)):
|
||||
referenced_fields = set(chain.from_iterable(
|
||||
(rf.attname for rf in rel.field.foreign_related_fields)
|
||||
for rel in get_candidate_relations_to_delete(related_model._meta)
|
||||
))
|
||||
sub_objs = sub_objs.only(*tuple(referenced_fields))
|
||||
if sub_objs:
|
||||
try:
|
||||
field.remote_field.on_delete(self, field, sub_objs, self.using)
|
||||
except ProtectedError as error:
|
||||
key = "'%s.%s'" % (field.model.__name__, field.name)
|
||||
protected_objects[key] += error.protected_objects
|
||||
if protected_objects:
|
||||
raise ProtectedError(
|
||||
'Cannot delete some instances of model %r because they are '
|
||||
'referenced through protected foreign keys: %s.' % (
|
||||
model.__name__,
|
||||
', '.join(protected_objects),
|
||||
),
|
||||
set(chain.from_iterable(protected_objects.values())),
|
||||
)
|
||||
for related_model, related_fields in model_fast_deletes.items():
|
||||
batches = self.get_del_batches(new_objs, related_fields)
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, related_fields, batch)
|
||||
self.fast_deletes.append(sub_objs)
|
||||
for field in model._meta.private_fields:
|
||||
if hasattr(field, 'bulk_related_objects'):
|
||||
# It's something like generic foreign key.
|
||||
sub_objs = field.bulk_related_objects(new_objs, self.using)
|
||||
self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False)
|
||||
|
||||
if fail_on_restricted:
|
||||
# Raise an error if collected restricted objects (RESTRICT) aren't
|
||||
# candidates for deletion also collected via CASCADE.
|
||||
for related_model, instances in self.data.items():
|
||||
self.clear_restricted_objects_from_set(related_model, instances)
|
||||
for qs in self.fast_deletes:
|
||||
self.clear_restricted_objects_from_queryset(qs.model, qs)
|
||||
if self.restricted_objects.values():
|
||||
restricted_objects = defaultdict(list)
|
||||
for related_model, fields in self.restricted_objects.items():
|
||||
for field, objs in fields.items():
|
||||
if objs:
|
||||
key = "'%s.%s'" % (related_model.__name__, field.name)
|
||||
restricted_objects[key] += objs
|
||||
if restricted_objects:
|
||||
raise RestrictedError(
|
||||
'Cannot delete some instances of model %r because '
|
||||
'they are referenced through restricted foreign keys: '
|
||||
'%s.' % (
|
||||
model.__name__,
|
||||
', '.join(restricted_objects),
|
||||
),
|
||||
set(chain.from_iterable(restricted_objects.values())),
|
||||
)
|
||||
|
||||
def related_objects(self, related_model, related_fields, objs):
|
||||
"""
|
||||
Get a QuerySet of the related model to objs via related fields.
|
||||
"""
|
||||
predicate = query_utils.Q(
|
||||
*(
|
||||
(f'{related_field.name}__in', objs)
|
||||
for related_field in related_fields
|
||||
),
|
||||
_connector=query_utils.Q.OR,
|
||||
)
|
||||
return related_model._base_manager.using(self.using).filter(predicate)
|
||||
|
||||
def instances_with_model(self):
|
||||
for model, instances in self.data.items():
|
||||
for obj in instances:
|
||||
yield model, obj
|
||||
|
||||
def sort(self):
|
||||
sorted_models = []
|
||||
concrete_models = set()
|
||||
models = list(self.data)
|
||||
while len(sorted_models) < len(models):
|
||||
found = False
|
||||
for model in models:
|
||||
if model in sorted_models:
|
||||
continue
|
||||
dependencies = self.dependencies.get(model._meta.concrete_model)
|
||||
if not (dependencies and dependencies.difference(concrete_models)):
|
||||
sorted_models.append(model)
|
||||
concrete_models.add(model._meta.concrete_model)
|
||||
found = True
|
||||
if not found:
|
||||
return
|
||||
self.data = {model: self.data[model] for model in sorted_models}
|
||||
|
||||
def delete(self):
|
||||
# sort instance collections
|
||||
for model, instances in self.data.items():
|
||||
self.data[model] = sorted(instances, key=attrgetter("pk"))
|
||||
|
||||
# if possible, bring the models in an order suitable for databases that
|
||||
# don't support transactions or cannot defer constraint checks until the
|
||||
# end of a transaction.
|
||||
self.sort()
|
||||
# number of objects deleted for each model label
|
||||
deleted_counter = Counter()
|
||||
|
||||
# Optimize for the case with a single obj and no dependencies
|
||||
if len(self.data) == 1 and len(instances) == 1:
|
||||
instance = list(instances)[0]
|
||||
if self.can_fast_delete(instance):
|
||||
with transaction.mark_for_rollback_on_error(self.using):
|
||||
count = sql.DeleteQuery(model).delete_batch([instance.pk], self.using)
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return count, {model._meta.label: count}
|
||||
|
||||
with transaction.atomic(using=self.using, savepoint=False):
|
||||
# send pre_delete signals
|
||||
for model, obj in self.instances_with_model():
|
||||
if not model._meta.auto_created:
|
||||
signals.pre_delete.send(
|
||||
sender=model, instance=obj, using=self.using
|
||||
)
|
||||
|
||||
# fast deletes
|
||||
for qs in self.fast_deletes:
|
||||
count = qs._raw_delete(using=self.using)
|
||||
if count:
|
||||
deleted_counter[qs.model._meta.label] += count
|
||||
|
||||
# update fields
|
||||
for model, instances_for_fieldvalues in self.field_updates.items():
|
||||
for (field, value), instances in instances_for_fieldvalues.items():
|
||||
query = sql.UpdateQuery(model)
|
||||
query.update_batch([obj.pk for obj in instances],
|
||||
{field.name: value}, self.using)
|
||||
|
||||
# reverse instance collections
|
||||
for instances in self.data.values():
|
||||
instances.reverse()
|
||||
|
||||
# delete instances
|
||||
for model, instances in self.data.items():
|
||||
query = sql.DeleteQuery(model)
|
||||
pk_list = [obj.pk for obj in instances]
|
||||
count = query.delete_batch(pk_list, self.using)
|
||||
if count:
|
||||
deleted_counter[model._meta.label] += count
|
||||
|
||||
if not model._meta.auto_created:
|
||||
for obj in instances:
|
||||
signals.post_delete.send(
|
||||
sender=model, instance=obj, using=self.using
|
||||
)
|
||||
|
||||
# update collected instances
|
||||
for instances_for_fieldvalues in self.field_updates.values():
|
||||
for (field, value), instances in instances_for_fieldvalues.items():
|
||||
for obj in instances:
|
||||
setattr(obj, field.attname, value)
|
||||
for model, instances in self.data.items():
|
||||
for instance in instances:
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return sum(deleted_counter.values()), dict(deleted_counter)
|
||||
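Roughly how this collector is driven internally by QuerySet.delete(); a sketch only, and Book is a hypothetical model.

from django.db import DEFAULT_DB_ALIAS
from django.db.models.deletion import Collector

# collector = Collector(using=DEFAULT_DB_ALIAS)
# collector.collect(Book.objects.filter(out_of_print=True))
# total_deleted, deleted_per_model = collector.delete()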
91
venv/Lib/site-packages/django/db/models/enums.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import enum
|
||||
from types import DynamicClassAttribute
|
||||
|
||||
from django.utils.functional import Promise
|
||||
|
||||
__all__ = ['Choices', 'IntegerChoices', 'TextChoices']
|
||||
|
||||
|
||||
class ChoicesMeta(enum.EnumMeta):
|
||||
"""A metaclass for creating a enum choices."""
|
||||
|
||||
def __new__(metacls, classname, bases, classdict, **kwds):
|
||||
labels = []
|
||||
for key in classdict._member_names:
|
||||
value = classdict[key]
|
||||
if (
|
||||
isinstance(value, (list, tuple)) and
|
||||
len(value) > 1 and
|
||||
isinstance(value[-1], (Promise, str))
|
||||
):
|
||||
*value, label = value
|
||||
value = tuple(value)
|
||||
else:
|
||||
label = key.replace('_', ' ').title()
|
||||
labels.append(label)
|
||||
# Use dict.__setitem__() to suppress defenses against double
|
||||
# assignment in enum's classdict.
|
||||
dict.__setitem__(classdict, key, value)
|
||||
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
|
||||
for member, label in zip(cls.__members__.values(), labels):
|
||||
member._label_ = label
|
||||
return enum.unique(cls)
|
||||
|
||||
def __contains__(cls, member):
|
||||
if not isinstance(member, enum.Enum):
|
||||
# Allow non-enums to match against member values.
|
||||
return any(x.value == member for x in cls)
|
||||
return super().__contains__(member)
|
||||
|
||||
@property
|
||||
def names(cls):
|
||||
empty = ['__empty__'] if hasattr(cls, '__empty__') else []
|
||||
return empty + [member.name for member in cls]
|
||||
|
||||
@property
|
||||
def choices(cls):
|
||||
empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else []
|
||||
return empty + [(member.value, member.label) for member in cls]
|
||||
|
||||
@property
|
||||
def labels(cls):
|
||||
return [label for _, label in cls.choices]
|
||||
|
||||
@property
|
||||
def values(cls):
|
||||
return [value for value, _ in cls.choices]
|
||||
|
||||
|
||||
class Choices(enum.Enum, metaclass=ChoicesMeta):
|
||||
"""Class for creating enumerated choices."""
|
||||
|
||||
@DynamicClassAttribute
|
||||
def label(self):
|
||||
return self._label_
|
||||
|
||||
@property
|
||||
def do_not_call_in_templates(self):
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Use value when cast to str, so that Choices set as model instance
|
||||
attributes are rendered as expected in templates and similar contexts.
|
||||
"""
|
||||
return str(self.value)
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__qualname__}.{self._name_}'
|
||||
|
||||
|
||||
class IntegerChoices(int, Choices):
|
||||
"""Class for creating enumerated integer choices."""
|
||||
pass
|
||||
|
||||
|
||||
class TextChoices(str, Choices):
|
||||
"""Class for creating enumerated string choices."""
|
||||
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
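A hedged sketch of using the Choices enums on a model field; the Ticket model is hypothetical.

from django.db import models


class Ticket(models.Model):  # hypothetical model, for illustration only
    class Status(models.TextChoices):
        OPEN = 'OP', 'Open'
        CLOSED = 'CL', 'Closed'

    status = models.CharField(max_length=2, choices=Status.choices, default=Status.OPEN)


# Status.choices -> [('OP', 'Open'), ('CL', 'Closed')]; Status.labels and
# Status.values come from ChoicesMeta, and 'OP' in Ticket.Status works because
# __contains__ above matches raw values against members.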
1463
venv/Lib/site-packages/django/db/models/expressions.py
Normal file
File diff suppressed because it is too large
2530
venv/Lib/site-packages/django/db/models/fields/__init__.py
Normal file
File diff suppressed because it is too large
481
venv/Lib/site-packages/django/db/models/fields/files.py
Normal file
@@ -0,0 +1,481 @@
|
||||
import datetime
|
||||
import posixpath
|
||||
|
||||
from django import forms
|
||||
from django.core import checks
|
||||
from django.core.files.base import File
|
||||
from django.core.files.images import ImageFile
|
||||
from django.core.files.storage import Storage, default_storage
|
||||
from django.core.files.utils import validate_file_name
|
||||
from django.db.models import signals
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class FieldFile(File):
|
||||
def __init__(self, instance, field, name):
|
||||
super().__init__(None, name)
|
||||
self.instance = instance
|
||||
self.field = field
|
||||
self.storage = field.storage
|
||||
self._committed = True
|
||||
|
||||
def __eq__(self, other):
|
||||
# Older code may be expecting FileField values to be simple strings.
|
||||
# By overriding the == operator, it can remain backwards compatible.
|
||||
if hasattr(other, 'name'):
|
||||
return self.name == other.name
|
||||
return self.name == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
# The standard File contains most of the necessary properties, but
|
||||
# FieldFiles can be instantiated without a name, so that needs to
|
||||
# be checked for here.
|
||||
|
||||
def _require_file(self):
|
||||
if not self:
|
||||
raise ValueError("The '%s' attribute has no file associated with it." % self.field.name)
|
||||
|
||||
def _get_file(self):
|
||||
self._require_file()
|
||||
if getattr(self, '_file', None) is None:
|
||||
self._file = self.storage.open(self.name, 'rb')
|
||||
return self._file
|
||||
|
||||
def _set_file(self, file):
|
||||
self._file = file
|
||||
|
||||
def _del_file(self):
|
||||
del self._file
|
||||
|
||||
file = property(_get_file, _set_file, _del_file)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
self._require_file()
|
||||
return self.storage.path(self.name)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
self._require_file()
|
||||
return self.storage.url(self.name)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
self._require_file()
|
||||
if not self._committed:
|
||||
return self.file.size
|
||||
return self.storage.size(self.name)
|
||||
|
||||
def open(self, mode='rb'):
|
||||
self._require_file()
|
||||
if getattr(self, '_file', None) is None:
|
||||
self.file = self.storage.open(self.name, mode)
|
||||
else:
|
||||
self.file.open(mode)
|
||||
return self
|
||||
# open() doesn't alter the file's contents, but it does reset the pointer
|
||||
open.alters_data = True
|
||||
|
||||
# In addition to the standard File API, FieldFiles have extra methods
|
||||
# to further manipulate the underlying file, as well as update the
|
||||
# associated model instance.
|
||||
|
||||
def save(self, name, content, save=True):
|
||||
name = self.field.generate_filename(self.instance, name)
|
||||
self.name = self.storage.save(name, content, max_length=self.field.max_length)
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = True
|
||||
|
||||
# Save the object because it has changed, unless save is False
|
||||
if save:
|
||||
self.instance.save()
|
||||
save.alters_data = True
|
||||
|
||||
def delete(self, save=True):
|
||||
if not self:
|
||||
return
|
||||
# Only close the file if it's already open, which we know by the
|
||||
# presence of self._file
|
||||
if hasattr(self, '_file'):
|
||||
self.close()
|
||||
del self.file
|
||||
|
||||
self.storage.delete(self.name)
|
||||
|
||||
self.name = None
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = False
|
||||
|
||||
if save:
|
||||
self.instance.save()
|
||||
delete.alters_data = True
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
file = getattr(self, '_file', None)
|
||||
return file is None or file.closed
|
||||
|
||||
def close(self):
|
||||
file = getattr(self, '_file', None)
|
||||
if file is not None:
|
||||
file.close()
|
||||
|
||||
def __getstate__(self):
|
||||
# FieldFile needs access to its associated model field, an instance and
|
||||
# the file's name. Everything else will be restored later, by
|
||||
# FileDescriptor below.
|
||||
return {
|
||||
'name': self.name,
|
||||
'closed': False,
|
||||
'_committed': True,
|
||||
'_file': None,
|
||||
'instance': self.instance,
|
||||
'field': self.field,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.storage = self.field.storage
|
||||
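A short usage sketch of the FieldFile API as seen from a model instance; Document and its attachment field are assumptions (see the FileField example further below).

from django.core.files.base import ContentFile

# doc = Document.objects.get(pk=1)                      # Document: hypothetical model
# doc.attachment.save('notes.txt', ContentFile(b'hi'))  # stores the file, saves the model
# doc.attachment.url                                    # delegated to the storage backend
# doc.attachment.delete()                               # clears the field and removes the file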
|
||||
|
||||
class FileDescriptor(DeferredAttribute):
|
||||
"""
|
||||
The descriptor for the file attribute on the model instance. Return a
|
||||
FieldFile when accessed so you can write code like::
|
||||
|
||||
>>> from myapp.models import MyModel
|
||||
>>> instance = MyModel.objects.get(pk=1)
|
||||
>>> instance.file.size
|
||||
|
||||
Assign a file object on assignment so you can do::
|
||||
|
||||
>>> with open('/path/to/hello.world') as f:
|
||||
... instance.file = File(f)
|
||||
"""
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
# This is slightly complicated, so worth an explanation.
|
||||
# instance.file needs to ultimately return some instance of `File`,
|
||||
# probably a subclass. Additionally, this returned object needs to have
|
||||
# the FieldFile API so that users can easily do things like
|
||||
# instance.file.path and have that delegated to the file storage engine.
|
||||
# Easy enough if we're strict about assignment in __set__, but if you
|
||||
# peek below you can see that we're not. So depending on the current
|
||||
# value of the field we have to dynamically construct some sort of
|
||||
# "thing" to return.
|
||||
|
||||
# The instance dict contains whatever was originally assigned
|
||||
# in __set__.
|
||||
file = super().__get__(instance, cls)
|
||||
|
||||
# If this value is a string (instance.file = "path/to/file") or None
|
||||
# then we simply wrap it with the appropriate attribute class according
|
||||
# to the file field. [This is FieldFile for FileFields and
|
||||
# ImageFieldFile for ImageFields; it's also conceivable that user
|
||||
# subclasses might also want to subclass the attribute class]. This
|
||||
# object understands how to convert a path to a file, and also how to
|
||||
# handle None.
|
||||
if isinstance(file, str) or file is None:
|
||||
attr = self.field.attr_class(instance, self.field, file)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# Other types of files may be assigned as well, but they need to have
|
||||
# the FieldFile interface added to them. Thus, we wrap any other type of
|
||||
# File inside a FieldFile (well, the field's attr_class, which is
|
||||
# usually FieldFile).
|
||||
elif isinstance(file, File) and not isinstance(file, FieldFile):
|
||||
file_copy = self.field.attr_class(instance, self.field, file.name)
|
||||
file_copy.file = file
|
||||
file_copy._committed = False
|
||||
instance.__dict__[self.field.attname] = file_copy
|
||||
|
||||
# Finally, because of the (some would say boneheaded) way pickle works,
|
||||
# the underlying FieldFile might not actually itself have an associated
|
||||
# file. So we need to reset the details of the FieldFile in those cases.
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, 'field'):
|
||||
file.instance = instance
|
||||
file.field = self.field
|
||||
file.storage = self.field.storage
|
||||
|
||||
# Make sure that the instance is correct.
|
||||
elif isinstance(file, FieldFile) and instance is not file.instance:
|
||||
file.instance = instance
|
||||
|
||||
# That was fun, wasn't it?
|
||||
return instance.__dict__[self.field.attname]
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.__dict__[self.field.attname] = value
|
||||
|
||||
|
||||
class FileField(Field):
|
||||
|
||||
# The class to wrap instance attributes in. Accessing the file object off
|
||||
# the instance will always return an instance of attr_class.
|
||||
attr_class = FieldFile
|
||||
|
||||
# The descriptor to use for accessing the attribute off of the class.
|
||||
descriptor_class = FileDescriptor
|
||||
|
||||
description = _("File")
|
||||
|
||||
def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs):
|
||||
self._primary_key_set_explicitly = 'primary_key' in kwargs
|
||||
|
||||
self.storage = storage or default_storage
|
||||
if callable(self.storage):
|
||||
# Hold a reference to the callable for deconstruct().
|
||||
self._storage_callable = self.storage
|
||||
self.storage = self.storage()
|
||||
if not isinstance(self.storage, Storage):
|
||||
raise TypeError(
|
||||
"%s.storage must be a subclass/instance of %s.%s"
|
||||
% (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__)
|
||||
)
|
||||
self.upload_to = upload_to
|
||||
|
||||
kwargs.setdefault('max_length', 100)
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_primary_key(),
|
||||
*self._check_upload_to(),
|
||||
]
|
||||
|
||||
def _check_primary_key(self):
|
||||
if self._primary_key_set_explicitly:
|
||||
return [
|
||||
checks.Error(
|
||||
"'primary_key' is not a valid argument for a %s." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id='fields.E201',
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _check_upload_to(self):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith('/'):
|
||||
return [
|
||||
checks.Error(
|
||||
"%s's 'upload_to' argument must be a relative path, not an "
|
||||
"absolute path." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id='fields.E202',
|
||||
hint='Remove the leading slash.',
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if kwargs.get("max_length") == 100:
|
||||
del kwargs["max_length"]
|
||||
kwargs['upload_to'] = self.upload_to
|
||||
if self.storage is not default_storage:
|
||||
kwargs['storage'] = getattr(self, '_storage_callable', self.storage)
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return "FileField"
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
# Need to convert File objects provided via a form to string for database insertion
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
file = super().pre_save(model_instance, add)
|
||||
if file and not file._committed:
|
||||
# Commit the file to storage prior to saving the model
|
||||
file.save(file.name, file.file, save=False)
|
||||
return file
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def generate_filename(self, instance, filename):
|
||||
"""
|
||||
Apply (if callable) or prepend (if a string) upload_to to the filename,
|
||||
then delegate further processing of the name to the storage backend.
|
||||
Until the storage layer, all file paths are expected to be Unix style
|
||||
(with forward slashes).
|
||||
"""
|
||||
if callable(self.upload_to):
|
||||
filename = self.upload_to(instance, filename)
|
||||
else:
|
||||
dirname = datetime.datetime.now().strftime(str(self.upload_to))
|
||||
filename = posixpath.join(dirname, filename)
|
||||
filename = validate_file_name(filename, allow_relative_path=True)
|
||||
return self.storage.generate_filename(filename)
|
||||
|
||||
def save_form_data(self, instance, data):
|
||||
# Important: None means "no change", other false value means "clear"
|
||||
# This subtle distinction (rather than a more explicit marker) is
|
||||
# needed because we need to consume values that are also sane for a
|
||||
# regular (non Model-) Form to find in its cleaned_data dictionary.
|
||||
if data is not None:
|
||||
# This value will be converted to str and stored in the
|
||||
# database, so leaving False as-is is not acceptable.
|
||||
setattr(instance, self.name, data or '')
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(**{
|
||||
'form_class': forms.FileField,
|
||||
'max_length': self.max_length,
|
||||
**kwargs,
|
||||
})
|
||||
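A hedged sketch of declaring a FileField whose upload_to goes through the strftime handling in generate_filename() above; the Document model is hypothetical.

from django.db import models


class Document(models.Model):  # hypothetical model, for illustration only
    # upload_to may be a strftime pattern or a callable taking (instance, filename).
    attachment = models.FileField(upload_to='documents/%Y/%m/%d/')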
|
||||
|
||||
class ImageFileDescriptor(FileDescriptor):
|
||||
"""
|
||||
Just like the FileDescriptor, but for ImageFields. The only difference is
|
||||
assigning the width/height to the width_field/height_field, if appropriate.
|
||||
"""
|
||||
def __set__(self, instance, value):
|
||||
previous_file = instance.__dict__.get(self.field.attname)
|
||||
super().__set__(instance, value)
|
||||
|
||||
# To prevent recalculating image dimensions when we are instantiating
|
||||
# an object from the database (bug #11084), only update dimensions if
|
||||
# the field had a value before this assignment. Since the default
|
||||
# value for FileField subclasses is an instance of field.attr_class,
|
||||
# previous_file will only be None when we are called from
|
||||
# Model.__init__(). The ImageField.update_dimension_fields method
|
||||
# hooked up to the post_init signal handles the Model.__init__() cases.
|
||||
# Assignment happening outside of Model.__init__() will trigger the
|
||||
# update right here.
|
||||
if previous_file is not None:
|
||||
self.field.update_dimension_fields(instance, force=True)
|
||||
|
||||
|
||||
class ImageFieldFile(ImageFile, FieldFile):
|
||||
def delete(self, save=True):
|
||||
# Clear the image dimensions cache
|
||||
if hasattr(self, '_dimensions_cache'):
|
||||
del self._dimensions_cache
|
||||
super().delete(save)
|
||||
|
||||
|
||||
class ImageField(FileField):
|
||||
attr_class = ImageFieldFile
|
||||
descriptor_class = ImageFileDescriptor
|
||||
description = _("Image")
|
||||
|
||||
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
|
||||
self.width_field, self.height_field = width_field, height_field
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_image_library_installed(),
|
||||
]
|
||||
|
||||
def _check_image_library_installed(self):
|
||||
try:
|
||||
from PIL import Image # NOQA
|
||||
except ImportError:
|
||||
return [
|
||||
checks.Error(
|
||||
'Cannot use ImageField because Pillow is not installed.',
|
||||
hint=('Get Pillow at https://pypi.org/project/Pillow/ '
|
||||
'or run command "python -m pip install Pillow".'),
|
||||
obj=self,
|
||||
id='fields.E210',
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.width_field:
|
||||
kwargs['width_field'] = self.width_field
|
||||
if self.height_field:
|
||||
kwargs['height_field'] = self.height_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
# Attach update_dimension_fields so that dimension fields declared
|
||||
# after their corresponding image field don't stay cleared by
|
||||
# Model.__init__, see bug #11196.
|
||||
# Only run post-initialization dimension update on non-abstract models
|
||||
if not cls._meta.abstract:
|
||||
signals.post_init.connect(self.update_dimension_fields, sender=cls)
|
||||
|
||||
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
|
||||
"""
|
||||
Update field's width and height fields, if defined.
|
||||
|
||||
This method is hooked up to model's post_init signal to update
|
||||
dimensions after instantiating a model instance. However, dimensions
|
||||
won't be updated if the dimensions fields are already populated. This
|
||||
avoids unnecessary recalculation when loading an object from the
|
||||
database.
|
||||
|
||||
Dimensions can be forced to update with force=True, which is how
|
||||
ImageFileDescriptor.__set__ calls this method.
|
||||
"""
|
||||
# Nothing to update if the field doesn't have dimension fields or if
|
||||
# the field is deferred.
|
||||
has_dimension_fields = self.width_field or self.height_field
|
||||
if not has_dimension_fields or self.attname not in instance.__dict__:
|
||||
return
|
||||
|
||||
# getattr will call the ImageFileDescriptor's __get__ method, which
|
||||
# coerces the assigned value into an instance of self.attr_class
|
||||
# (ImageFieldFile in this case).
|
||||
file = getattr(instance, self.attname)
|
||||
|
||||
# Nothing to update if we have no file and not being forced to update.
|
||||
if not file and not force:
|
||||
return
|
||||
|
||||
dimension_fields_filled = not(
|
||||
(self.width_field and not getattr(instance, self.width_field)) or
|
||||
(self.height_field and not getattr(instance, self.height_field))
|
||||
)
|
||||
# When both dimension fields have values, we are most likely loading
|
||||
# data from the database or updating an image field that already had
|
||||
# an image stored. In the first case, we don't want to update the
|
||||
# dimension fields because we are already getting their values from the
|
||||
# database. In the second case, we do want to update the dimensions
|
||||
# fields and will skip this return because force will be True since we
|
||||
# were called from ImageFileDescriptor.__set__.
|
||||
if dimension_fields_filled and not force:
|
||||
return
|
||||
|
||||
# file should be an instance of ImageFieldFile or should be None.
|
||||
if file:
|
||||
width = file.width
|
||||
height = file.height
|
||||
else:
|
||||
# No file, so clear dimensions fields.
|
||||
width = None
|
||||
height = None
|
||||
|
||||
# Update the width and height fields.
|
||||
if self.width_field:
|
||||
setattr(instance, self.width_field, width)
|
||||
if self.height_field:
|
||||
setattr(instance, self.height_field, height)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(**{
|
||||
'form_class': forms.ImageField,
|
||||
**kwargs,
|
||||
})
|
||||
538
venv/Lib/site-packages/django/db/models/fields/json.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import json
|
||||
|
||||
from django import forms
|
||||
from django.core import checks, exceptions
|
||||
from django.db import NotSupportedError, connections, router
|
||||
from django.db.models import lookups
|
||||
from django.db.models.lookups import PostgresOperatorLookup, Transform
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ['JSONField']
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _('A JSON object')
|
||||
default_error_messages = {
|
||||
'invalid': _('Value must be valid JSON.'),
|
||||
}
|
||||
_default_hint = ('dict', '{}')
|
||||
|
||||
def __init__(
|
||||
self, verbose_name=None, name=None, encoder=None, decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError('The encoder parameter must be a callable object.')
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError('The decoder parameter must be a callable object.')
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get('databases') or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor and
|
||||
self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
'supports_json_field' in self.model._meta.required_db_features or
|
||||
connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
'%s does not support JSONFields.'
|
||||
% connection.display_name,
|
||||
obj=self.model,
|
||||
id='fields.E180',
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs['encoder'] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs['decoder'] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
# Some backends (SQLite at least) extract non-string values in their
|
||||
# SQL datatypes.
|
||||
if isinstance(expression, KeyTransform) and not isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return 'JSONField'
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
return json.dumps(value, cls=self.encoder)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
try:
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages['invalid'],
|
||||
code='invalid',
|
||||
params={'value': value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(**{
|
||||
'form_class': forms.JSONField,
|
||||
'encoder': self.encoder,
|
||||
'decoder': self.decoder,
|
||||
**kwargs,
|
||||
})
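

# A minimal usage sketch (hypothetical model and app label): the lookups and
# key transforms exercised below are the ones registered later in this file.
from django.db import models

class Event(models.Model):
    data = models.JSONField(default=dict)

    class Meta:
        app_label = 'demo'  # hypothetical app label

# Event.objects.filter(data__has_key='tags')
# Event.objects.filter(data__contains={'status': 'open'})   # requires supports_json_field_contains
# Event.objects.filter(data__user__name__icontains='ada')   # chained KeyTransforms + icontains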
|
||||


def compile_json_path(key_transforms, include_root=True):
    path = ['$'] if include_root else []
    for key_transform in key_transforms:
        try:
            num = int(key_transform)
        except ValueError:  # non-integer
            path.append('.')
            path.append(json.dumps(key_transform))
        else:
            path.append('[%s]' % num)
    return ''.join(path)
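
# Worked examples of the JSON paths this helper produces:
#   compile_json_path(['user', 'emails', '0'])    -> '$."user"."emails"[0]'
#   compile_json_path(['0'], include_root=False)  -> '[0]'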
|
||||
|
||||
|
||||
class DataContains(PostgresOperatorLookup):
|
||||
lookup_name = 'contains'
|
||||
postgres_operator = '@>'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
'contains lookup is not supported on this database backend.'
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(PostgresOperatorLookup):
|
||||
lookup_name = 'contained_by'
|
||||
postgres_operator = '<@'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
'contained_by lookup is not supported on this database backend.'
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
logical_operator = None
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = '$'
|
||||
sql = template % lhs
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
rhs_params = []
|
||||
if not isinstance(rhs, (list, tuple)):
|
||||
rhs = [rhs]
|
||||
for key in rhs:
|
||||
if isinstance(key, KeyTransform):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
rhs_params.append('%s%s' % (
|
||||
lhs_json_path,
|
||||
compile_json_path(rhs_key_transforms, include_root=False),
|
||||
))
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
|
||||
return sql, tuple(lhs_params) + tuple(rhs_params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle.
|
||||
return sql % tuple(params), []
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
|
||||
for key in rhs_key_transforms[:-1]:
|
||||
self.lhs = KeyTransform(key, self.lhs)
|
||||
self.rhs = rhs_key_transforms[-1]
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = 'has_key'
|
||||
postgres_operator = '?'
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = 'has_keys'
|
||||
postgres_operator = '?&'
|
||||
logical_operator = ' AND '
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = 'has_any_keys'
|
||||
postgres_operator = '?|'
|
||||
logical_operator = ' OR '
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == 'mysql':
|
||||
return 'LOWER(%s)' % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == 'mysql':
|
||||
return 'LOWER(%s)' % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == '%s' and rhs_params == [None]:
|
||||
rhs_params = ['null']
|
||||
if connection.vendor == 'mysql':
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs = rhs % tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
postgres_operator = '->'
|
||||
postgres_nested_operator = '#>'
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = str(key_name)
|
||||
|
||||
def preprocess_lhs(self, compiler, connection):
|
||||
key_transforms = [self.key_name]
|
||||
previous = self.lhs
|
||||
while isinstance(previous, KeyTransform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == 'oracle':
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace('%', '%%') for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return (
|
||||
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
|
||||
((lhs, json_path) * 2)
|
||||
), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ','.join([
|
||||
repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
|
||||
])
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = '->>'
|
||||
postgres_nested_operator = '#>>'
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
"""
|
||||
Mixin for combining with a lookup expecting a text lhs from a JSONField
|
||||
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
'Transform should be an instance of KeyTransform in order to '
|
||||
'use this lookup.'
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name, *key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = HasKey(
|
||||
self.lhs.lhs,
|
||||
self.lhs.key_name,
|
||||
).as_oracle(compiler, connection)
|
||||
if not self.rhs:
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = 'JSON_TYPE(%s, %%s) IS NULL'
|
||||
if not self.rhs:
|
||||
template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
|
||||
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler, connection, sql, param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, 'as_sql') and
|
||||
not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == 'oracle':
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql = sql % 'JSON_QUERY'
|
||||
else:
|
||||
sql = sql % 'JSON_VALUE'
|
||||
elif connection.vendor == 'mysql' or (
|
||||
connection.vendor == 'sqlite' and
|
||||
params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
|
||||
sql = 'JSON_UNQUOTE(%s)' % sql
|
||||
return sql, params
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
def process_rhs(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == 'oracle':
|
||||
func = []
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
for value in rhs_params:
|
||||
value = json.loads(value)
|
||||
if isinstance(value, (list, dict)):
|
||||
func.append(sql % 'JSON_QUERY')
|
||||
else:
|
||||
func.append(sql % 'JSON_VALUE')
|
||||
rhs = rhs % tuple(func)
|
||||
elif connection.vendor == 'sqlite':
|
||||
func = []
|
||||
for value in rhs_params:
|
||||
if value in connection.ops.jsonfield_datatype_values:
|
||||
func.append('%s')
|
||||
else:
|
||||
func.append("JSON_EXTRACT(%s, '$')")
|
||||
rhs = rhs % tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs_params == ['null']:
|
||||
# Field has key and it's NULL.
|
||||
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
|
||||
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
|
||||
is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
|
||||
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
|
||||
return (
|
||||
'%s AND %s' % (has_key_sql, is_null_sql),
|
||||
tuple(has_key_params) + tuple(is_null_params),
|
||||
)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformNumericLookupMixin:
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if not connection.features.has_native_json_field:
|
||||
rhs_params = [json.loads(value) for value in rhs_params]
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformIn)
|
||||
KeyTransform.register_lookup(KeyTransformExact)
|
||||
KeyTransform.register_lookup(KeyTransformIExact)
|
||||
KeyTransform.register_lookup(KeyTransformIsNull)
|
||||
KeyTransform.register_lookup(KeyTransformIContains)
|
||||
KeyTransform.register_lookup(KeyTransformStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformRegex)
|
||||
KeyTransform.register_lookup(KeyTransformIRegex)
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformLt)
|
||||
KeyTransform.register_lookup(KeyTransformLte)
|
||||
KeyTransform.register_lookup(KeyTransformGt)
|
||||
KeyTransform.register_lookup(KeyTransformGte)
|
||||
|
||||
|
||||
class KeyTransformFactory:

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        return KeyTransform(self.key_name, *args, **kwargs)
|
||||
56
venv/Lib/site-packages/django/db/models/fields/mixins.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from django.core import checks
|
||||
|
||||
NOT_PROVIDED = object()
|
||||
|
||||
|
||||
class FieldCacheMixin:
|
||||
"""Provide an API for working with the model's fields value cache."""
|
||||
|
||||
def get_cache_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cached_value(self, instance, default=NOT_PROVIDED):
|
||||
cache_name = self.get_cache_name()
|
||||
try:
|
||||
return instance._state.fields_cache[cache_name]
|
||||
except KeyError:
|
||||
if default is NOT_PROVIDED:
|
||||
raise
|
||||
return default
|
||||
|
||||
def is_cached(self, instance):
|
||||
return self.get_cache_name() in instance._state.fields_cache
|
||||
|
||||
def set_cached_value(self, instance, value):
|
||||
instance._state.fields_cache[self.get_cache_name()] = value
|
||||
|
||||
def delete_cached_value(self, instance):
|
||||
del instance._state.fields_cache[self.get_cache_name()]
|
||||
|
||||
|
||||
class CheckFieldDefaultMixin:
|
||||
_default_hint = ('<valid default>', '<invalid default>')
|
||||
|
||||
def _check_default(self):
|
||||
if self.has_default() and self.default is not None and not callable(self.default):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s default should be a callable instead of an instance "
|
||||
"so that it's not shared between all field instances." % (
|
||||
self.__class__.__name__,
|
||||
),
|
||||
hint=(
|
||||
'Use a callable instead, e.g., use `%s` instead of '
|
||||
'`%s`.' % self._default_hint
|
||||
),
|
||||
obj=self,
|
||||
id='fields.E010',
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
errors.extend(self._check_default())
|
||||
return errors
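

# A sketch of what this check catches, assuming a field that uses the mixin
# (JSONField, for example) on a hypothetical model:
from django.db import models

class Profile(models.Model):
    # Flagged with fields.E010: the same dict instance is shared by every row.
    settings_bad = models.JSONField(default={})
    # Passes the check: a fresh dict is created per instance.
    settings_ok = models.JSONField(default=dict)

    class Meta:
        app_label = 'demo'  # hypothetical app label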
|
||||
18
venv/Lib/site-packages/django/db/models/fields/proxy.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Field-like classes that aren't really fields. It's easier to use objects that
|
||||
have the same attributes as fields sometimes (avoids a lot of special casing).
|
||||
"""
|
||||
|
||||
from django.db.models import fields
|
||||
|
||||
|
||||
class OrderWrt(fields.IntegerField):
|
||||
"""
|
||||
A proxy for the _order database field that is used when
|
||||
Meta.order_with_respect_to is specified.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['name'] = '_order'
|
||||
kwargs['editable'] = False
|
||||
super().__init__(*args, **kwargs)
|
||||
1721
venv/Lib/site-packages/django/db/models/fields/related.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,159 @@
|
||||
from django.db.models.lookups import (
|
||||
Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
|
||||
|
||||
class MultiColSource:
|
||||
contains_aggregate = False
|
||||
|
||||
def __init__(self, alias, targets, sources, field):
|
||||
self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
|
||||
self.output_field = self.field
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({}, {})".format(
|
||||
self.__class__.__name__, self.alias, self.field)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(relabels.get(self.alias, self.alias),
|
||||
self.targets, self.sources, self.field)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
if isinstance(value, Model):
|
||||
value_list = []
|
||||
sources = lhs.output_field.get_path_info()[-1].target_fields
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
|
||||
try:
|
||||
value_list.append(getattr(value, source.attname))
|
||||
except AttributeError:
|
||||
# A case like Restaurant.objects.filter(place=restaurant_instance),
|
||||
# where place is a OneToOneField and the primary key of Restaurant.
|
||||
return (value.pk,)
|
||||
return tuple(value_list)
|
||||
if not isinstance(value, tuple):
|
||||
return (value,)
|
||||
return value
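
# E.g. (hypothetical models) filtering with a model instance instead of a key,
#   Book.objects.filter(author=author_instance)
# ends up here, and get_normalized_value() unwraps the instance into the tuple
# of target-field values (typically just its primary key).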
|
||||
|
||||
|
||||
class RelatedIn(In):
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
|
||||
# We need to run the related field's get_prep_value(). Consider case
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if hasattr(self.lhs.output_field, 'get_path_info'):
|
||||
# Run the target field's get_prep_value. We can safely assume there is
|
||||
# only one as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
|
||||
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
# For multicolumn lookups we need to build a multicolumn where clause.
|
||||
# This clause is either a SubqueryConstraint (for values that need to be compiled to
|
||||
# SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
|
||||
from django.db.models.sql.where import (
|
||||
AND, OR, SubqueryConstraint, WhereNode,
|
||||
)
|
||||
|
||||
root_constraint = WhereNode(connector=OR)
|
||||
if self.rhs_is_direct_value():
|
||||
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
|
||||
for value in values:
|
||||
value_constraint = WhereNode()
|
||||
for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
|
||||
lookup_class = target.get_lookup('exact')
|
||||
lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
|
||||
value_constraint.add(lookup, AND)
|
||||
root_constraint.add(value_constraint, OR)
|
||||
else:
|
||||
root_constraint.add(
|
||||
SubqueryConstraint(
|
||||
self.lhs.alias, [target.column for target in self.lhs.targets],
|
||||
[source.name for source in self.lhs.sources], self.rhs),
|
||||
AND)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
else:
|
||||
if (not getattr(self.rhs, 'has_select_fields', True) and
|
||||
not getattr(self.lhs.field.target_field, 'primary_key', False)):
|
||||
self.rhs.clear_select_clause()
|
||||
if (getattr(self.lhs.output_field, 'primary_key', False) and
|
||||
self.lhs.output_field.model == self.rhs.model):
|
||||
# A case like Restaurant.objects.filter(place__in=restaurant_qs),
|
||||
# where place is a OneToOneField and the primary key of
|
||||
# Restaurant.
|
||||
target_field = self.lhs.field.name
|
||||
else:
|
||||
target_field = self.lhs.field.target_field.name
|
||||
self.rhs.add_fields([target_field], True)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
|
||||
# We need to run the related field's get_prep_value(). Consider case
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
|
||||
# Get the target field. We can safely assume there is only one
|
||||
# as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
|
||||
self.rhs = target_field.get_prep_value(self.rhs)
|
||||
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
assert self.rhs_is_direct_value()
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
from django.db.models.sql.where import AND, WhereNode
|
||||
root_constraint = WhereNode()
|
||||
for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
|
||||
lookup_class = target.get_lookup(self.lookup_name)
|
||||
root_constraint.add(
|
||||
lookup_class(target.get_col(self.lhs.alias, source), val), AND)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedExact(RelatedLookupMixin, Exact):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThan(RelatedLookupMixin, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedIsNull(RelatedLookupMixin, IsNull):
|
||||
pass
|
||||
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
"Rel objects" for related fields.
|
||||
|
||||
"Rel objects" (for lack of a better name) carry information about the relation
|
||||
modeled by a related field and provide some utility functions. They're stored
|
||||
in the ``remote_field`` attribute of the field.
|
||||
|
||||
They also act as reverse fields for the purposes of the Meta API because
|
||||
they're the closest concept currently available.
|
||||
"""
|
||||
|
||||
from django.core import exceptions
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
from . import BLANK_CHOICE_DASH
|
||||
from .mixins import FieldCacheMixin
|
||||
|
||||
|
||||
class ForeignObjectRel(FieldCacheMixin):
|
||||
"""
|
||||
Used by ForeignObject to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
# Field flags
|
||||
auto_created = True
|
||||
concrete = False
|
||||
editable = False
|
||||
is_relation = True
|
||||
|
||||
# Reverse relations are always nullable (Django can't enforce that a
|
||||
# foreign key on the related model points to this model).
|
||||
null = True
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(self, field, to, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, parent_link=False, on_delete=None):
|
||||
self.field = field
|
||||
self.model = to
|
||||
self.related_name = related_name
|
||||
self.related_query_name = related_query_name
|
||||
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
|
||||
self.parent_link = parent_link
|
||||
self.on_delete = on_delete
|
||||
|
||||
self.symmetrical = False
|
||||
self.multiple = True
|
||||
|
||||
# Some of the following cached_properties can't be initialized in
|
||||
# __init__ as the field doesn't have its model yet. Calling these methods
|
||||
# before field.contribute_to_class() has been called will result in
|
||||
# AttributeError
|
||||
@cached_property
|
||||
def hidden(self):
|
||||
return self.is_hidden()
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self.field.related_query_name()
|
||||
|
||||
@property
|
||||
def remote_field(self):
|
||||
return self.field
|
||||
|
||||
@property
|
||||
def target_field(self):
|
||||
"""
|
||||
When filtering against this relation, return the field on the remote
|
||||
model against which the filtering should happen.
|
||||
"""
|
||||
target_fields = self.get_path_info()[-1].target_fields
|
||||
if len(target_fields) > 1:
|
||||
raise exceptions.FieldError("Can't use target_field for multicolumn relations.")
|
||||
return target_fields[0]
|
||||
|
||||
@cached_property
|
||||
def related_model(self):
|
||||
if not self.field.model:
|
||||
raise AttributeError(
|
||||
"This property can't be accessed before self.field.contribute_to_class has been called.")
|
||||
return self.field.model
|
||||
|
||||
@cached_property
|
||||
def many_to_many(self):
|
||||
return self.field.many_to_many
|
||||
|
||||
@cached_property
|
||||
def many_to_one(self):
|
||||
return self.field.one_to_many
|
||||
|
||||
@cached_property
|
||||
def one_to_many(self):
|
||||
return self.field.many_to_one
|
||||
|
||||
@cached_property
|
||||
def one_to_one(self):
|
||||
return self.field.one_to_one
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
return self.field.get_lookup(lookup_name)
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.field.get_internal_type()
|
||||
|
||||
@property
|
||||
def db_type(self):
|
||||
return self.field.db_type
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: %s.%s>' % (
|
||||
type(self).__name__,
|
||||
self.related_model._meta.app_label,
|
||||
self.related_model._meta.model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.field,
|
||||
self.model,
|
||||
self.related_name,
|
||||
self.related_query_name,
|
||||
make_hashable(self.limit_choices_to),
|
||||
self.parent_link,
|
||||
self.on_delete,
|
||||
self.symmetrical,
|
||||
self.multiple,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def get_choices(
|
||||
self, include_blank=True, blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None, ordering=(),
|
||||
):
|
||||
"""
|
||||
Return choices with a default blank choices included, for use
|
||||
as <select> choices for this field.
|
||||
|
||||
Analog of django.db.models.fields.Field.get_choices(), provided
|
||||
initially for utilization by RelatedFieldListFilter.
|
||||
"""
|
||||
limit_choices_to = limit_choices_to or self.limit_choices_to
|
||||
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
|
||||
if ordering:
|
||||
qs = qs.order_by(*ordering)
|
||||
return (blank_choice if include_blank else []) + [
|
||||
(x.pk, str(x)) for x in qs
|
||||
]
|
||||
|
||||
def is_hidden(self):
|
||||
"""Should the related object be hidden?"""
|
||||
return bool(self.related_name) and self.related_name[-1] == '+'
|
||||
|
||||
def get_joining_columns(self):
|
||||
return self.field.get_reverse_joining_columns()
|
||||
|
||||
def get_extra_restriction(self, alias, related_alias):
|
||||
return self.field.get_extra_restriction(related_alias, alias)
|
||||
|
||||
def set_field_name(self):
|
||||
"""
|
||||
Set the related field's name, this is not available until later stages
|
||||
of app loading, so set_field_name is called from
|
||||
set_attributes_from_rel()
|
||||
"""
|
||||
# By default foreign object doesn't relate to any remote field (for
|
||||
# example custom multicolumn joins currently have no remote field).
|
||||
self.field_name = None
|
||||
|
||||
def get_accessor_name(self, model=None):
|
||||
# This method encapsulates the logic that decides what name to give an
|
||||
# accessor descriptor that retrieves related many-to-one or
|
||||
# many-to-many objects. It uses the lowercased object_name + "_set",
|
||||
# but this can be overridden with the "related_name" option. Due to
|
||||
# backwards compatibility ModelForms need to be able to provide an
|
||||
# alternate model. See BaseInlineFormSet.get_default_prefix().
|
||||
opts = model._meta if model else self.related_model._meta
|
||||
model = model or self.related_model
|
||||
if self.multiple:
|
||||
# If this is a symmetrical m2m relation on self, there is no reverse accessor.
|
||||
if self.symmetrical and model == self.model:
|
||||
return None
|
||||
if self.related_name:
|
||||
return self.related_name
|
||||
return opts.model_name + ('_set' if self.multiple else '')
|
||||
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
|
||||
def get_cache_name(self):
|
||||
"""
|
||||
Return the name of the cache key to use for storing an instance of the
|
||||
forward model on the reverse model.
|
||||
"""
|
||||
return self.get_accessor_name()
|
||||
|
||||
|
||||
class ManyToOneRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by the ForeignKey field to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
|
||||
Note: Because we somewhat abuse the Rel objects by using them as reverse
|
||||
fields we get the funny situation where
|
||||
``ManyToOneRel.many_to_one == False`` and
|
||||
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
|
||||
ManyToOneRel class is a private API and there is work underway to turn
|
||||
reverse relations into actual fields.
|
||||
"""
|
||||
|
||||
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, parent_link=False, on_delete=None):
|
||||
super().__init__(
|
||||
field, to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.field_name = field_name
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state.pop('related_model', None)
|
||||
return state
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (self.field_name,)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the Field in the 'to' object to which this relationship is tied.
|
||||
"""
|
||||
field = self.model._meta.get_field(self.field_name)
|
||||
if not field.concrete:
|
||||
raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name)
|
||||
return field
|
||||
|
||||
def set_field_name(self):
|
||||
self.field_name = self.field_name or self.model._meta.pk.name
|
||||
|
||||
|
||||
class OneToOneRel(ManyToOneRel):
|
||||
"""
|
||||
Used by OneToOneField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, parent_link=False, on_delete=None):
|
||||
super().__init__(
|
||||
field, to, field_name,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.multiple = False
|
||||
|
||||
|
||||
class ManyToManyRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by ManyToManyField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(self, field, to, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, symmetrical=True, through=None,
|
||||
through_fields=None, db_constraint=True):
|
||||
super().__init__(
|
||||
field, to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
)
|
||||
|
||||
if through and not db_constraint:
|
||||
raise ValueError("Can't supply a through model and db_constraint=False")
|
||||
self.through = through
|
||||
|
||||
if through_fields and not through:
|
||||
raise ValueError("Cannot specify through_fields without a through model")
|
||||
self.through_fields = through_fields
|
||||
|
||||
self.symmetrical = symmetrical
|
||||
self.db_constraint = db_constraint
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (
|
||||
self.through,
|
||||
make_hashable(self.through_fields),
|
||||
self.db_constraint,
|
||||
)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the field in the 'to' object to which this relationship is tied.
|
||||
Provided for symmetry with ManyToOneRel.
|
||||
"""
|
||||
opts = self.through._meta
|
||||
if self.through_fields:
|
||||
field = opts.get_field(self.through_fields[0])
|
||||
else:
|
||||
for field in opts.fields:
|
||||
rel = getattr(field, 'remote_field', None)
|
||||
if rel and rel.model == self.model:
|
||||
break
|
||||
return field.foreign_related_fields[0]
|
||||
@@ -0,0 +1,46 @@
|
||||
from .comparison import (
|
||||
Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf,
|
||||
)
|
||||
from .datetime import (
|
||||
Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear,
|
||||
ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek,
|
||||
ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour,
|
||||
TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek,
|
||||
TruncYear,
|
||||
)
|
||||
from .math import (
|
||||
Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, Exp, Floor, Ln, Log,
|
||||
Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan,
|
||||
)
|
||||
from .text import (
|
||||
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
|
||||
Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
|
||||
RTrim, StrIndex, Substr, Trim, Upper,
|
||||
)
|
||||
from .window import (
|
||||
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
|
||||
PercentRank, Rank, RowNumber,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# comparison and conversion
|
||||
'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf',
|
||||
# datetime
|
||||
'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
|
||||
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay',
|
||||
'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc',
|
||||
'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
|
||||
'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear',
|
||||
# math
|
||||
'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 'Ceil', 'Cos', 'Cot', 'Degrees',
|
||||
'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random',
|
||||
'Round', 'Sign', 'Sin', 'Sqrt', 'Tan',
|
||||
# text
|
||||
'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
|
||||
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
|
||||
'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
|
||||
'Trim', 'Upper',
|
||||
# window
|
||||
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
|
||||
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
|
||||
]
|
||||
193
venv/Lib/site-packages/django/db/models/functions/comparison.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Database functions that do comparisons or type conversions."""
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields.json import JSONField
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
class Cast(Func):
|
||||
"""Coerce an expression to a new field type."""
|
||||
function = 'CAST'
|
||||
template = '%(function)s(%(expressions)s AS %(db_type)s)'
|
||||
|
||||
def __init__(self, expression, output_field):
|
||||
super().__init__(expression, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context['db_type'] = self.output_field.cast_db_type(connection)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
db_type = self.output_field.db_type(connection)
|
||||
if db_type in {'datetime', 'time'}:
|
||||
# Use strftime as datetime/time don't keep fractional seconds.
|
||||
template = 'strftime(%%s, %(expressions)s)'
|
||||
sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
|
||||
format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
|
||||
params.insert(0, format_string)
|
||||
return sql, params
|
||||
elif db_type == 'date':
|
||||
template = 'date(%(expressions)s)'
|
||||
return super().as_sql(compiler, connection, template=template, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
template = None
|
||||
output_type = self.output_field.get_internal_type()
|
||||
# MySQL doesn't support explicit cast to float.
|
||||
if output_type == 'FloatField':
|
||||
template = '(%(expressions)s + 0.0)'
|
||||
# MariaDB doesn't support explicit cast to JSON.
|
||||
elif output_type == 'JSONField' and connection.mysql_is_mariadb:
|
||||
template = "JSON_EXTRACT(%(expressions)s, '$')"
|
||||
return self.as_sql(compiler, connection, template=template, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# CAST would be valid too, but the :: shortcut syntax is more readable.
|
||||
# 'expressions' is wrapped in parentheses in case it's a complex
|
||||
# expression.
|
||||
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.output_field.get_internal_type() == 'JSONField':
|
||||
# Oracle doesn't support explicit cast to JSON.
|
||||
template = "JSON_QUERY(%(expressions)s, '$')"
|
||||
return super().as_sql(compiler, connection, template=template, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Coalesce(Func):
|
||||
"""Return, from left to right, the first non-null expression."""
|
||||
function = 'COALESCE'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError('Coalesce must take at least two expressions')
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def empty_result_set_value(self):
|
||||
for expression in self.get_source_expressions():
|
||||
result = expression.empty_result_set_value
|
||||
if result is NotImplemented or result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
|
||||
# so convert all fields to NCLOB when that type is expected.
|
||||
if self.output_field.get_internal_type() == 'TextField':
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions([
|
||||
Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
|
||||
])
|
||||
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Collate(Func):
|
||||
function = 'COLLATE'
|
||||
template = '%(expressions)s %(function)s %(collation)s'
|
||||
# Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
collation_re = _lazy_re_compile(r'^[\w\-]+$')
|
||||
|
||||
def __init__(self, expression, collation):
|
||||
if not (collation and self.collation_re.match(collation)):
|
||||
raise ValueError('Invalid collation name: %r.' % collation)
|
||||
self.collation = collation
|
||||
super().__init__(expression)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Greatest(Func):
|
||||
"""
|
||||
Return the maximum expression.
|
||||
|
||||
If any expression is null the return value is database-specific:
|
||||
On PostgreSQL, the maximum not-null expression is returned.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
|
||||
"""
|
||||
function = 'GREATEST'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError('Greatest must take at least two expressions')
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MAX function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
|
||||
|
||||
|
||||
class JSONObject(Func):
|
||||
function = 'JSON_OBJECT'
|
||||
output_field = JSONField()
|
||||
|
||||
def __init__(self, **fields):
|
||||
expressions = []
|
||||
for key, value in fields.items():
|
||||
expressions.extend((Value(key), value))
|
||||
super().__init__(*expressions)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.has_json_object_function:
|
||||
raise NotSupportedError(
|
||||
'JSONObject() is not supported on this database backend.'
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function='JSONB_BUILD_OBJECT',
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
class ArgJoiner:
|
||||
def join(self, args):
|
||||
args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
|
||||
return ', '.join(args)
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
arg_joiner=ArgJoiner(),
|
||||
template='%(function)s(%(expressions)s RETURNING CLOB)',
|
||||
**extra_context,
|
||||
)
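
# Hedged usage sketch (hypothetical queryset): build a per-row JSON object;
# backends lacking a JSON_OBJECT equivalent raise NotSupportedError, as above.
#   from django.db.models import F
#   from django.db.models.functions import JSONObject
#   Author.objects.annotate(card=JSONObject(name=F('name'), alias=F('alias')))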
|
||||
|
||||
|
||||
class Least(Func):
|
||||
"""
|
||||
Return the minimum expression.
|
||||
|
||||
If any expression is null the return value is database-specific:
|
||||
On PostgreSQL, return the minimum not-null expression.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, return null.
|
||||
"""
|
||||
function = 'LEAST'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError('Least must take at least two expressions')
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MIN function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
|
||||
|
||||
|
||||
class NullIf(Func):
|
||||
function = 'NULLIF'
|
||||
arity = 2
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
expression1 = self.get_source_expressions()[0]
|
||||
if isinstance(expression1, Value) and expression1.value is None:
|
||||
raise ValueError('Oracle does not allow Value(None) for expression1.')
|
||||
return super().as_sql(compiler, connection, **extra_context)
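
# Hedged usage sketch (hypothetical queryset) combining the functions above:
#   from django.db.models import Value
#   from django.db.models.functions import Greatest, NullIf
#   Article.objects.annotate(
#       last_touched=Greatest('updated_at', 'published_at'),
#       slug_or_null=NullIf('slug', Value('')),
#   )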
|
||||
339
venv/Lib/site-packages/django/db/models/functions/datetime.py
Normal file
@@ -0,0 +1,339 @@
|
||||
from datetime import datetime
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import (
|
||||
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
|
||||
)
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
class TimezoneMixin:
|
||||
tzinfo = None
|
||||
|
||||
def get_tzname(self):
|
||||
# Timezone conversions must happen to the input datetime *before*
|
||||
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
|
||||
# database as 2016-01-01 01:00:00 +00:00. Any results should be
|
||||
# based on the input datetime not the stored datetime.
|
||||
tzname = None
|
||||
if settings.USE_TZ:
|
||||
if self.tzinfo is None:
|
||||
tzname = timezone.get_current_timezone_name()
|
||||
else:
|
||||
tzname = timezone._get_timezone_name(self.tzinfo)
|
||||
return tzname
|
||||
|
||||
|
||||
class Extract(TimezoneMixin, Transform):
|
||||
lookup_name = None
|
||||
output_field = IntegerField()
|
||||
|
||||
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
|
||||
if self.lookup_name is None:
|
||||
self.lookup_name = lookup_name
|
||||
if self.lookup_name is None:
|
||||
raise ValueError('lookup_name must be provided')
|
||||
self.tzinfo = tzinfo
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
lhs_output_field = self.lhs.output_field
|
||||
if isinstance(lhs_output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError('tzinfo can only be used with DateTimeField.')
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
elif isinstance(lhs_output_field, DurationField):
|
||||
if not connection.features.has_native_duration_field:
|
||||
raise ValueError('Extract requires native DurationField database support.')
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
else:
|
||||
# resolve_expression has already validated the output_field so this
|
||||
# assert should never be hit.
|
||||
assert False, "Tried to Extract from an invalid type."
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
field = getattr(copy.lhs, 'output_field', None)
|
||||
if field is None:
|
||||
return copy
|
||||
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
|
||||
raise ValueError(
|
||||
'Extract input expression must be DateField, DateTimeField, '
|
||||
'TimeField, or DurationField.'
|
||||
)
|
||||
# Passing dates to functions expecting datetimes is most likely a mistake.
|
||||
if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
|
||||
raise ValueError(
|
||||
"Cannot extract time component '%s' from DateField '%s'." % (copy.lookup_name, field.name)
|
||||
)
|
||||
if (
|
||||
isinstance(field, DurationField) and
|
||||
copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot extract component '%s' from DurationField '%s'."
|
||||
% (copy.lookup_name, field.name)
|
||||
)
|
||||
return copy
|
||||
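# Usage sketch (not part of Django's source): Extract is normally used through
# the subclasses below or the lookups they register. Both forms assume a
# hypothetical Event model with a 'start' DateTimeField.
#
#     from django.db.models.functions import Extract
#
#     Event.objects.annotate(start_year=Extract('start', 'year'))
#     Event.objects.filter(start__year=2021)  # registered ExtractYear lookup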
|
||||
|
||||
class ExtractYear(Extract):
|
||||
lookup_name = 'year'
|
||||
|
||||
|
||||
class ExtractIsoYear(Extract):
|
||||
"""Return the ISO-8601 week-numbering year."""
|
||||
lookup_name = 'iso_year'
|
||||
|
||||
|
||||
class ExtractMonth(Extract):
|
||||
lookup_name = 'month'
|
||||
|
||||
|
||||
class ExtractDay(Extract):
|
||||
lookup_name = 'day'
|
||||
|
||||
|
||||
class ExtractWeek(Extract):
|
||||
"""
|
||||
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first day of the
|
||||
week.
|
||||
"""
|
||||
lookup_name = 'week'
|
||||
|
||||
|
||||
class ExtractWeekDay(Extract):
|
||||
"""
|
||||
Return Sunday=1 through Saturday=7.
|
||||
|
||||
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
|
||||
"""
|
||||
lookup_name = 'week_day'
|
||||
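# Worked example (not part of Django's source) of the Sunday=1 .. Saturday=7
# numbering described in the docstring:
#
#     import datetime
#
#     dt = datetime.date(2021, 1, 4)  # a Monday
#     (dt.isoweekday() % 7) + 1       # == 2 under this scheme (Sunday would be 1)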
|
||||
|
||||
class ExtractIsoWeekDay(Extract):
|
||||
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
|
||||
lookup_name = 'iso_week_day'
|
||||
|
||||
|
||||
class ExtractQuarter(Extract):
|
||||
lookup_name = 'quarter'
|
||||
|
||||
|
||||
class ExtractHour(Extract):
|
||||
lookup_name = 'hour'
|
||||
|
||||
|
||||
class ExtractMinute(Extract):
|
||||
lookup_name = 'minute'
|
||||
|
||||
|
||||
class ExtractSecond(Extract):
|
||||
lookup_name = 'second'
|
||||
|
||||
|
||||
DateField.register_lookup(ExtractYear)
|
||||
DateField.register_lookup(ExtractMonth)
|
||||
DateField.register_lookup(ExtractDay)
|
||||
DateField.register_lookup(ExtractWeekDay)
|
||||
DateField.register_lookup(ExtractIsoWeekDay)
|
||||
DateField.register_lookup(ExtractWeek)
|
||||
DateField.register_lookup(ExtractIsoYear)
|
||||
DateField.register_lookup(ExtractQuarter)
|
||||
|
||||
TimeField.register_lookup(ExtractHour)
|
||||
TimeField.register_lookup(ExtractMinute)
|
||||
TimeField.register_lookup(ExtractSecond)
|
||||
|
||||
DateTimeField.register_lookup(ExtractHour)
|
||||
DateTimeField.register_lookup(ExtractMinute)
|
||||
DateTimeField.register_lookup(ExtractSecond)
|
||||
|
||||
ExtractYear.register_lookup(YearExact)
|
||||
ExtractYear.register_lookup(YearGt)
|
||||
ExtractYear.register_lookup(YearGte)
|
||||
ExtractYear.register_lookup(YearLt)
|
||||
ExtractYear.register_lookup(YearLte)
|
||||
|
||||
ExtractIsoYear.register_lookup(YearExact)
|
||||
ExtractIsoYear.register_lookup(YearGt)
|
||||
ExtractIsoYear.register_lookup(YearGte)
|
||||
ExtractIsoYear.register_lookup(YearLt)
|
||||
ExtractIsoYear.register_lookup(YearLte)
|
||||
|
||||
|
||||
class Now(Func):
|
||||
template = 'CURRENT_TIMESTAMP'
|
||||
output_field = DateTimeField()
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
||||
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
||||
# other databases.
|
||||
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
|
||||
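# Usage sketch (not part of Django's source): assuming a hypothetical Article
# model with a 'published' DateTimeField, Now() compares against the database
# server's clock rather than Python's.
#
#     from django.db.models.functions import Now
#
#     Article.objects.filter(published__lte=Now())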
|
||||
|
||||
class TruncBase(TimezoneMixin, Transform):
|
||||
kind = None
|
||||
tzinfo = None
|
||||
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(self, expression, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
|
||||
self.tzinfo = tzinfo
|
||||
self.is_dst = is_dst
|
||||
super().__init__(expression, output_field=output_field, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
inner_sql, inner_params = compiler.compile(self.lhs)
|
||||
tzname = None
|
||||
if isinstance(self.lhs.output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError('tzinfo can only be used with DateTimeField.')
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
|
||||
elif isinstance(self.output_field, DateField):
|
||||
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
|
||||
else:
|
||||
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
|
||||
return sql, inner_params
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
field = copy.lhs.output_field
|
||||
# DateTimeField is a subclass of DateField so this works for both.
|
||||
if not isinstance(field, (DateField, TimeField)):
|
||||
raise TypeError(
|
||||
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
|
||||
)
|
||||
# If self.output_field was None, then accessing the field will trigger
|
||||
# the resolver to assign it to self.lhs.output_field.
|
||||
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
|
||||
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
|
||||
# Passing dates or times to functions expecting datetimes is most
|
||||
# likely a mistake.
|
||||
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
|
||||
output_field = class_output_field or copy.output_field
|
||||
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
|
||||
if type(field) == DateField and (
|
||||
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
|
||||
raise ValueError("Cannot truncate DateField '%s' to %s." % (
|
||||
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||
))
|
||||
elif isinstance(field, TimeField) and (
|
||||
isinstance(output_field, DateTimeField) or
|
||||
copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
|
||||
raise ValueError("Cannot truncate TimeField '%s' to %s." % (
|
||||
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||
))
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
if not settings.USE_TZ:
|
||||
pass
|
||||
elif value is not None:
|
||||
value = value.replace(tzinfo=None)
|
||||
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
|
||||
elif not connection.features.has_zoneinfo_database:
|
||||
raise ValueError(
|
||||
'Database returned an invalid datetime value. Are time '
|
||||
'zone definitions for your database installed?'
|
||||
)
|
||||
elif isinstance(value, datetime):
|
||||
if value is None:
|
||||
pass
|
||||
elif isinstance(self.output_field, DateField):
|
||||
value = value.date()
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
value = value.time()
|
||||
return value
|
||||
|
||||
|
||||
class Trunc(TruncBase):
|
||||
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
|
||||
self.kind = kind
|
||||
super().__init__(
|
||||
expression, output_field=output_field, tzinfo=tzinfo,
|
||||
is_dst=is_dst, **extra
|
||||
)
|
||||
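# Usage sketch (not part of Django's source): assuming a hypothetical Order
# model with a 'created' DateTimeField, rows can be grouped by truncated
# month (TruncMonth below is the equivalent shortcut).
#
#     from django.db.models import Count
#     from django.db.models.functions import Trunc
#
#     (Order.objects
#          .annotate(month=Trunc('created', 'month'))
#          .values('month')
#          .annotate(total=Count('id')))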
|
||||
|
||||
class TruncYear(TruncBase):
|
||||
kind = 'year'
|
||||
|
||||
|
||||
class TruncQuarter(TruncBase):
|
||||
kind = 'quarter'
|
||||
|
||||
|
||||
class TruncMonth(TruncBase):
|
||||
kind = 'month'
|
||||
|
||||
|
||||
class TruncWeek(TruncBase):
|
||||
"""Truncate to midnight on the Monday of the week."""
|
||||
kind = 'week'
|
||||
|
||||
|
||||
class TruncDay(TruncBase):
|
||||
kind = 'day'
|
||||
|
||||
|
||||
class TruncDate(TruncBase):
|
||||
kind = 'date'
|
||||
lookup_name = 'date'
|
||||
output_field = DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
lhs, lhs_params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
|
||||
return sql, lhs_params
|
||||
|
||||
|
||||
class TruncTime(TruncBase):
|
||||
kind = 'time'
|
||||
lookup_name = 'time'
|
||||
output_field = TimeField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to time rather than truncate to time.
|
||||
lhs, lhs_params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
|
||||
return sql, lhs_params
|
||||
|
||||
|
||||
class TruncHour(TruncBase):
|
||||
kind = 'hour'
|
||||
|
||||
|
||||
class TruncMinute(TruncBase):
|
||||
kind = 'minute'
|
||||
|
||||
|
||||
class TruncSecond(TruncBase):
|
||||
kind = 'second'
|
||||
|
||||
|
||||
DateTimeField.register_lookup(TruncDate)
|
||||
DateTimeField.register_lookup(TruncTime)
|
||||
197
venv/Lib/site-packages/django/db/models/functions/math.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import math
|
||||
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDecimalInputMixin, NumericOutputFieldMixin,
|
||||
)
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class Abs(Transform):
|
||||
function = 'ABS'
|
||||
lookup_name = 'abs'
|
||||
|
||||
|
||||
class ACos(NumericOutputFieldMixin, Transform):
|
||||
function = 'ACOS'
|
||||
lookup_name = 'acos'
|
||||
|
||||
|
||||
class ASin(NumericOutputFieldMixin, Transform):
|
||||
function = 'ASIN'
|
||||
lookup_name = 'asin'
|
||||
|
||||
|
||||
class ATan(NumericOutputFieldMixin, Transform):
|
||||
function = 'ATAN'
|
||||
lookup_name = 'atan'
|
||||
|
||||
|
||||
class ATan2(NumericOutputFieldMixin, Func):
|
||||
function = 'ATAN2'
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually ATan2(y, x), returning the inverse tangent
|
||||
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
|
||||
# Cast integers to float to avoid inconsistent/buggy behavior if the
|
||||
# arguments are mixed between integer and float or decimal.
|
||||
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions([
|
||||
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
|
||||
else expression for expression in self.get_source_expressions()[::-1]
|
||||
])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Ceil(Transform):
|
||||
function = 'CEILING'
|
||||
lookup_name = 'ceil'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
|
||||
|
||||
|
||||
class Cos(NumericOutputFieldMixin, Transform):
|
||||
function = 'COS'
|
||||
lookup_name = 'cos'
|
||||
|
||||
|
||||
class Cot(NumericOutputFieldMixin, Transform):
|
||||
function = 'COT'
|
||||
lookup_name = 'cot'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
|
||||
|
||||
|
||||
class Degrees(NumericOutputFieldMixin, Transform):
|
||||
function = 'DEGREES'
|
||||
lookup_name = 'degrees'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection,
|
||||
template='((%%(expressions)s) * 180 / %s)' % math.pi,
|
||||
**extra_context
|
||||
)
|
||||
|
||||
|
||||
class Exp(NumericOutputFieldMixin, Transform):
|
||||
function = 'EXP'
|
||||
lookup_name = 'exp'
|
||||
|
||||
|
||||
class Floor(Transform):
|
||||
function = 'FLOOR'
|
||||
lookup_name = 'floor'
|
||||
|
||||
|
||||
class Ln(NumericOutputFieldMixin, Transform):
|
||||
function = 'LN'
|
||||
lookup_name = 'ln'
|
||||
|
||||
|
||||
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = 'LOG'
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(connection.ops, 'spatialite', False):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually Log(b, x) returning the logarithm of x to
|
||||
# the base b, but on SpatiaLite it's Log(x, b).
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(self.get_source_expressions()[::-1])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = 'MOD'
|
||||
arity = 2
|
||||
|
||||
|
||||
class Pi(NumericOutputFieldMixin, Func):
|
||||
function = 'PI'
|
||||
arity = 0
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
|
||||
|
||||
|
||||
class Power(NumericOutputFieldMixin, Func):
|
||||
function = 'POWER'
|
||||
arity = 2
|
||||
|
||||
|
||||
class Radians(NumericOutputFieldMixin, Transform):
|
||||
function = 'RADIANS'
|
||||
lookup_name = 'radians'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection,
|
||||
template='((%%(expressions)s) * %s / 180)' % math.pi,
|
||||
**extra_context
|
||||
)
|
||||
|
||||
|
||||
class Random(NumericOutputFieldMixin, Func):
|
||||
function = 'RANDOM'
|
||||
arity = 0
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='RAND', **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='RAND', **extra_context)
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
return []
|
||||
|
||||
|
||||
class Round(FixDecimalInputMixin, Transform):
|
||||
function = 'ROUND'
|
||||
lookup_name = 'round'
|
||||
arity = None # Override Transform's arity=1 to enable passing precision.
|
||||
|
||||
def __init__(self, expression, precision=0, **extra):
|
||||
super().__init__(expression, precision, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
precision = self.get_source_expressions()[1]
|
||||
if isinstance(precision, Value) and precision.value < 0:
|
||||
raise ValueError('SQLite does not support negative precision.')
|
||||
return super().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source = self.get_source_expressions()[0]
|
||||
return source.output_field
|
||||
|
||||
|
||||
class Sign(Transform):
|
||||
function = 'SIGN'
|
||||
lookup_name = 'sign'
|
||||
|
||||
|
||||
class Sin(NumericOutputFieldMixin, Transform):
|
||||
function = 'SIN'
|
||||
lookup_name = 'sin'
|
||||
|
||||
|
||||
class Sqrt(NumericOutputFieldMixin, Transform):
|
||||
function = 'SQRT'
|
||||
lookup_name = 'sqrt'
|
||||
|
||||
|
||||
class Tan(NumericOutputFieldMixin, Transform):
|
||||
function = 'TAN'
|
||||
lookup_name = 'tan'
|
||||
52
venv/Lib/site-packages/django/db/models/functions/mixins.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import sys
|
||||
|
||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
|
||||
class FixDecimalInputMixin:
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
||||
# following function signatures:
|
||||
# - LOG(double, double)
|
||||
# - MOD(double, double)
|
||||
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions([
|
||||
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
|
||||
else expression for expression in self.get_source_expressions()
|
||||
])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class FixDurationInputMixin:
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||
if self.output_field.get_internal_type() == 'DurationField':
|
||||
sql = 'CAST(%s AS SIGNED)' % sql
|
||||
return sql, params
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.output_field.get_internal_type() == 'DurationField':
|
||||
expression = self.get_source_expressions()[0]
|
||||
options = self._get_repr_options()
|
||||
from django.db.backends.oracle.functions import (
|
||||
IntervalToSeconds, SecondsToInterval,
|
||||
)
|
||||
return compiler.compile(
|
||||
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class NumericOutputFieldMixin:
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source_fields = self.get_source_fields()
|
||||
if any(isinstance(s, DecimalField) for s in source_fields):
|
||||
return DecimalField()
|
||||
if any(isinstance(s, IntegerField) for s in source_fields):
|
||||
return FloatField()
|
||||
return super()._resolve_output_field() if source_fields else FloatField()
|
||||
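# Behaviour sketch (not part of Django's source): once resolved against a
# model, the rules above mean, for example:
#
#     ACos('integer_field')      -> FloatField output (integer sources are promoted)
#     Power('decimal_field', 2)  -> DecimalField output (any Decimal source wins)
#     Pi()                       -> FloatField output (no source fields at all)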
323
venv/Lib/site-packages/django/db/models/functions/text.py
Normal file
@@ -0,0 +1,323 @@
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.fields import CharField, IntegerField
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class MySQLSHA2Mixin:
|
||||
def as_mysql(self, compiler, connection, **extra_content):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template='SHA2(%%(expressions)s, %s)' % self.function[3:],
|
||||
**extra_content,
|
||||
)
|
||||
|
||||
|
||||
class OracleHashMixin:
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=(
|
||||
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
|
||||
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
|
||||
),
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
|
||||
class PostgreSQLSHAMixin:
|
||||
def as_postgresql(self, compiler, connection, **extra_content):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
|
||||
function=self.function.lower(),
|
||||
**extra_content,
|
||||
)
|
||||
|
||||
|
||||
class Chr(Transform):
|
||||
function = 'CHR'
|
||||
lookup_name = 'chr'
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, function='CHAR',
|
||||
template='%(function)s(%(expressions)s USING utf16)',
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection,
|
||||
template='%(function)s(%(expressions)s USING NCHAR_CS)',
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
|
||||
|
||||
|
||||
class ConcatPair(Func):
|
||||
"""
|
||||
Concatenate two arguments together. This is used by `Concat` because not
|
||||
all backend databases support more than two arguments.
|
||||
"""
|
||||
function = 'CONCAT'
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
coalesced = self.coalesce()
|
||||
return super(ConcatPair, coalesced).as_sql(
|
||||
compiler, connection, template='%(expressions)s', arg_joiner=' || ',
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||
return super().as_sql(
|
||||
compiler, connection, function='CONCAT_WS',
|
||||
template="%(function)s('', %(expressions)s)",
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def coalesce(self):
|
||||
# A NULL on either side makes the whole expression NULL, so wrap each side with Coalesce.
|
||||
c = self.copy()
|
||||
c.set_source_expressions([
|
||||
Coalesce(expression, Value('')) for expression in c.get_source_expressions()
|
||||
])
|
||||
return c
|
||||
|
||||
|
||||
class Concat(Func):
|
||||
"""
|
||||
Concatenate text fields together. Backends that result in an entire
|
||||
null expression when any arguments are null will wrap each argument in
|
||||
coalesce functions to ensure a non-null result.
|
||||
"""
|
||||
function = None
|
||||
template = "%(expressions)s"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError('Concat must take at least two expressions')
|
||||
paired = self._paired(expressions)
|
||||
super().__init__(paired, **extra)
|
||||
|
||||
def _paired(self, expressions):
|
||||
# wrap pairs of expressions in successive concat functions
|
||||
# exp = [a, b, c, d]
|
||||
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
|
||||
if len(expressions) == 2:
|
||||
return ConcatPair(*expressions)
|
||||
return ConcatPair(expressions[0], self._paired(expressions[1:]))
|
||||
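# Usage sketch (not part of Django's source): assuming hypothetical
# 'first_name' and 'last_name' CharFields, _paired() turns the call below into
# ConcatPair('first_name', ConcatPair(Value(' '), 'last_name')).
#
#     from django.db.models import Value
#     from django.db.models.functions import Concat
#
#     Person.objects.annotate(full_name=Concat('first_name', Value(' '), 'last_name'))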
|
||||
|
||||
class Left(Func):
|
||||
function = 'LEFT'
|
||||
arity = 2
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, length, **extra):
|
||||
"""
|
||||
expression: the name of a field, or an expression returning a string
|
||||
length: the number of characters to return from the start of the string
|
||||
"""
|
||||
if not hasattr(length, 'resolve_expression'):
|
||||
if length < 1:
|
||||
raise ValueError("'length' must be greater than 0.")
|
||||
super().__init__(expression, length, **extra)
|
||||
|
||||
def get_substr(self):
|
||||
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return self.get_substr().as_oracle(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Length(Transform):
|
||||
"""Return the number of characters in the expression."""
|
||||
function = 'LENGTH'
|
||||
lookup_name = 'length'
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
|
||||
|
||||
|
||||
class Lower(Transform):
|
||||
function = 'LOWER'
|
||||
lookup_name = 'lower'
|
||||
|
||||
|
||||
class LPad(Func):
|
||||
function = 'LPAD'
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, length, fill_text=Value(' '), **extra):
|
||||
if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
|
||||
raise ValueError("'length' must be greater or equal to 0.")
|
||||
super().__init__(expression, length, fill_text, **extra)
|
||||
|
||||
|
||||
class LTrim(Transform):
|
||||
function = 'LTRIM'
|
||||
lookup_name = 'ltrim'
|
||||
|
||||
|
||||
class MD5(OracleHashMixin, Transform):
|
||||
function = 'MD5'
|
||||
lookup_name = 'md5'
|
||||
|
||||
|
||||
class Ord(Transform):
|
||||
function = 'ASCII'
|
||||
lookup_name = 'ord'
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='ORD', **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
|
||||
|
||||
|
||||
class Repeat(Func):
|
||||
function = 'REPEAT'
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, number, **extra):
|
||||
if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
|
||||
raise ValueError("'number' must be greater or equal to 0.")
|
||||
super().__init__(expression, number, **extra)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
expression, number = self.source_expressions
|
||||
length = None if number is None else Length(expression) * number
|
||||
rpad = RPad(expression, length, expression)
|
||||
return rpad.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Replace(Func):
|
||||
function = 'REPLACE'
|
||||
|
||||
def __init__(self, expression, text, replacement=Value(''), **extra):
|
||||
super().__init__(expression, text, replacement, **extra)
|
||||
|
||||
|
||||
class Reverse(Transform):
|
||||
function = 'REVERSE'
|
||||
lookup_name = 'reverse'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# REVERSE in Oracle is undocumented and doesn't support multi-byte
|
||||
# strings. Use a special subquery instead.
|
||||
return super().as_sql(
|
||||
compiler, connection,
|
||||
template=(
|
||||
'(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
|
||||
'(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
|
||||
'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
|
||||
'GROUP BY %(expressions)s)'
|
||||
),
|
||||
**extra_context
|
||||
)
|
||||
|
||||
|
||||
class Right(Left):
|
||||
function = 'RIGHT'
|
||||
|
||||
def get_substr(self):
|
||||
return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
|
||||
|
||||
|
||||
class RPad(LPad):
|
||||
function = 'RPAD'
|
||||
|
||||
|
||||
class RTrim(Transform):
|
||||
function = 'RTRIM'
|
||||
lookup_name = 'rtrim'
|
||||
|
||||
|
||||
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = 'SHA1'
|
||||
lookup_name = 'sha1'
|
||||
|
||||
|
||||
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
|
||||
function = 'SHA224'
|
||||
lookup_name = 'sha224'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
raise NotSupportedError('SHA224 is not supported on Oracle.')
|
||||
|
||||
|
||||
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = 'SHA256'
|
||||
lookup_name = 'sha256'
|
||||
|
||||
|
||||
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = 'SHA384'
|
||||
lookup_name = 'sha384'
|
||||
|
||||
|
||||
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = 'SHA512'
|
||||
lookup_name = 'sha512'
|
||||
|
||||
|
||||
class StrIndex(Func):
|
||||
"""
|
||||
Return a positive integer corresponding to the 1-indexed position of the
|
||||
first occurrence of a substring inside another string, or 0 if the
|
||||
substring is not found.
|
||||
"""
|
||||
function = 'INSTR'
|
||||
arity = 2
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
|
||||
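# Usage sketch (not part of Django's source): with a hypothetical 'name'
# field, the annotation below is 1-indexed and yields 0 when 'smith' does not
# occur in the name.
#
#     from django.db.models import Value
#     from django.db.models.functions import StrIndex
#
#     Author.objects.annotate(smith_index=StrIndex('name', Value('smith')))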
|
||||
|
||||
class Substr(Func):
|
||||
function = 'SUBSTRING'
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, pos, length=None, **extra):
|
||||
"""
|
||||
expression: the name of a field, or an expression returning a string
|
||||
pos: an integer > 0, or an expression returning an integer
|
||||
length: an optional number of characters to return
|
||||
"""
|
||||
if not hasattr(pos, 'resolve_expression'):
|
||||
if pos < 1:
|
||||
raise ValueError("'pos' must be greater than 0")
|
||||
expressions = [expression, pos]
|
||||
if length is not None:
|
||||
expressions.append(length)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
|
||||
|
||||
|
||||
class Trim(Transform):
|
||||
function = 'TRIM'
|
||||
lookup_name = 'trim'
|
||||
|
||||
|
||||
class Upper(Transform):
|
||||
function = 'UPPER'
|
||||
lookup_name = 'upper'
|
||||
108
venv/Lib/site-packages/django/db/models/functions/window.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
|
||||
__all__ = [
|
||||
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
|
||||
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
|
||||
]
|
||||
|
||||
|
||||
class CumeDist(Func):
|
||||
function = 'CUME_DIST'
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class DenseRank(Func):
|
||||
function = 'DENSE_RANK'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class FirstValue(Func):
|
||||
arity = 1
|
||||
function = 'FIRST_VALUE'
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class LagLeadFunction(Func):
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, offset=1, default=None, **extra):
|
||||
if expression is None:
|
||||
raise ValueError(
|
||||
'%s requires a non-null source expression.' %
|
||||
self.__class__.__name__
|
||||
)
|
||||
if offset is None or offset <= 0:
|
||||
raise ValueError(
|
||||
'%s requires a positive integer for the offset.' %
|
||||
self.__class__.__name__
|
||||
)
|
||||
args = (expression, offset)
|
||||
if default is not None:
|
||||
args += (default,)
|
||||
super().__init__(*args, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
sources = self.get_source_expressions()
|
||||
return sources[0].output_field
|
||||
|
||||
|
||||
class Lag(LagLeadFunction):
|
||||
function = 'LAG'
|
||||
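# Usage sketch (not part of Django's source): window functions are used
# through a Window expression. Assuming a hypothetical Sale model with
# 'region', 'sold_on', and 'amount' fields, each row is annotated with the
# previous amount within its region.
#
#     from django.db.models import F, Window
#     from django.db.models.functions import Lag
#
#     Sale.objects.annotate(
#         prev_amount=Window(
#             expression=Lag('amount', offset=1, default=0),
#             partition_by=[F('region')],
#             order_by=F('sold_on').asc(),
#         ),
#     )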
|
||||
|
||||
class LastValue(Func):
|
||||
arity = 1
|
||||
function = 'LAST_VALUE'
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Lead(LagLeadFunction):
|
||||
function = 'LEAD'
|
||||
|
||||
|
||||
class NthValue(Func):
|
||||
function = 'NTH_VALUE'
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, nth=1, **extra):
|
||||
if expression is None:
|
||||
raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
|
||||
if nth is None or nth <= 0:
|
||||
raise ValueError('%s requires a positive integer for nth.' % self.__class__.__name__)
|
||||
super().__init__(expression, nth, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
sources = self.get_source_expressions()
|
||||
return sources[0].output_field
|
||||
|
||||
|
||||
class Ntile(Func):
|
||||
function = 'NTILE'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, num_buckets=1, **extra):
|
||||
if num_buckets <= 0:
|
||||
raise ValueError('num_buckets must be greater than 0.')
|
||||
super().__init__(num_buckets, **extra)
|
||||
|
||||
|
||||
class PercentRank(Func):
|
||||
function = 'PERCENT_RANK'
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Rank(Func):
|
||||
function = 'RANK'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class RowNumber(Func):
|
||||
function = 'ROW_NUMBER'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
270
venv/Lib/site-packages/django/db/models/indexes.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from django.db.backends.utils import names_digest, split_identifier
|
||||
from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
|
||||
from django.db.models.functions import Collate
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql import Query
|
||||
from django.utils.functional import partition
|
||||
|
||||
__all__ = ['Index']
|
||||
|
||||
|
||||
class Index:
|
||||
suffix = 'idx'
|
||||
# The max length of the name of the index (restricted to 30 for
|
||||
# cross-database compatibility with Oracle)
|
||||
max_name_length = 30
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
db_tablespace=None,
|
||||
opclasses=(),
|
||||
condition=None,
|
||||
include=None,
|
||||
):
|
||||
if opclasses and not name:
|
||||
raise ValueError('An index must be named to use opclasses.')
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError('Index.condition must be a Q instance.')
|
||||
if condition and not name:
|
||||
raise ValueError('An index must be named to use condition.')
|
||||
if not isinstance(fields, (list, tuple)):
|
||||
raise ValueError('Index.fields must be a list or tuple.')
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError('Index.opclasses must be a list or tuple.')
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
'At least one field or expression is required to define an '
|
||||
'index.'
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
'Index.fields and expressions are mutually exclusive.',
|
||||
)
|
||||
if expressions and not name:
|
||||
raise ValueError('An index must be named to use expressions.')
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
'Index.opclasses cannot be used with expressions. Use '
|
||||
'django.contrib.postgres.indexes.OpClass() instead.'
|
||||
)
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
|
||||
if fields and not all(isinstance(field, str) for field in fields):
|
||||
raise ValueError('Index.fields must contain only strings with field names.')
|
||||
if include and not name:
|
||||
raise ValueError('A covering index must be named.')
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError('Index.include must be a list or tuple.')
|
||||
self.fields = list(fields)
|
||||
# A list of 2-tuples with the field name and ordering ('' or 'DESC').
|
||||
self.fields_orders = [
|
||||
(field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '')
|
||||
for field_name in self.fields
|
||||
]
|
||||
self.name = name or ''
|
||||
self.db_tablespace = db_tablespace
|
||||
self.opclasses = opclasses
|
||||
self.condition = condition
|
||||
self.include = tuple(include) if include else ()
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def create_sql(self, model, schema_editor, using='', **kwargs):
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
if self.expressions:
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
expressions = ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
fields = None
|
||||
col_suffixes = None
|
||||
else:
|
||||
fields = [
|
||||
model._meta.get_field(field_name)
|
||||
for field_name, _ in self.fields_orders
|
||||
]
|
||||
col_suffixes = [order[1] for order in self.fields_orders]
|
||||
expressions = None
|
||||
return schema_editor._create_index_sql(
|
||||
model, fields=fields, name=self.name, using=using,
|
||||
db_tablespace=self.db_tablespace, col_suffixes=col_suffixes,
|
||||
opclasses=self.opclasses, condition=condition, include=include,
|
||||
expressions=expressions, **kwargs,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor, **kwargs):
|
||||
return schema_editor._delete_index_sql(model, self.name, **kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace('django.db.models.indexes', 'django.db.models')
|
||||
kwargs = {'name': self.name}
|
||||
if self.fields:
|
||||
kwargs['fields'] = self.fields
|
||||
if self.db_tablespace is not None:
|
||||
kwargs['db_tablespace'] = self.db_tablespace
|
||||
if self.opclasses:
|
||||
kwargs['opclasses'] = self.opclasses
|
||||
if self.condition:
|
||||
kwargs['condition'] = self.condition
|
||||
if self.include:
|
||||
kwargs['include'] = self.include
|
||||
return (path, self.expressions, kwargs)
|
||||
|
||||
def clone(self):
|
||||
"""Create a copy of this Index."""
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
|
||||
|
||||
def set_name_with_model(self, model):
|
||||
"""
|
||||
Generate a unique name for the index.
|
||||
|
||||
The name is divided into 3 parts - table name (12 chars), field name
|
||||
(8 chars) and unique hash + suffix (10 chars). Each part is made to
|
||||
fit its size by truncating the excess length.
|
||||
"""
|
||||
_, table_name = split_identifier(model._meta.db_table)
|
||||
column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders]
|
||||
column_names_with_order = [
|
||||
(('-%s' if order else '%s') % column_name)
|
||||
for column_name, (field_name, order) in zip(column_names, self.fields_orders)
|
||||
]
|
||||
# The length of the parts of the name is based on the default max
|
||||
# length of 30 characters.
|
||||
hash_data = [table_name] + column_names_with_order + [self.suffix]
|
||||
self.name = '%s_%s_%s' % (
|
||||
table_name[:11],
|
||||
column_names[0][:7],
|
||||
'%s_%s' % (names_digest(*hash_data, length=6), self.suffix),
|
||||
)
|
||||
if len(self.name) > self.max_name_length:
|
||||
raise ValueError(
|
||||
'Index too long for multiple database support. Is self.suffix '
|
||||
'longer than 3 characters?'
|
||||
)
|
||||
if self.name[0] == '_' or self.name[0].isdigit():
|
||||
self.name = 'D%s' % self.name[1:]
|
||||
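# Worked example (not part of Django's source) of the 30-character budget
# above: 11 (table) + 1 + 7 (column) + 1 + 6 (hash) + 1 + 3 (suffix 'idx')
# = 30, so a generated name such as the hypothetical
# 'myapp_artic_created_4f2a1b_idx' exactly fits max_name_length.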
|
||||
def __repr__(self):
|
||||
return '<%s:%s%s%s%s%s%s%s>' % (
|
||||
self.__class__.__qualname__,
|
||||
'' if not self.fields else ' fields=%s' % repr(self.fields),
|
||||
'' if not self.expressions else ' expressions=%s' % repr(self.expressions),
|
||||
'' if not self.name else ' name=%s' % repr(self.name),
|
||||
''
|
||||
if self.db_tablespace is None
|
||||
else ' db_tablespace=%s' % repr(self.db_tablespace),
|
||||
'' if self.condition is None else ' condition=%s' % self.condition,
|
||||
'' if not self.include else ' include=%s' % repr(self.include),
|
||||
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ == other.__class__:
|
||||
return self.deconstruct() == other.deconstruct()
|
||||
return NotImplemented
|
||||
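# Usage sketch (not part of Django's source): both forms below pass the
# validation in __init__ above. Field-based indexes may be left unnamed
# (set_name_with_model() fills the name in), while expression-based indexes
# must be named.
#
#     from django.db.models import Index
#     from django.db.models.functions import Lower
#
#     class Meta:
#         indexes = [
#             Index(fields=['-created', 'title']),
#             Index(Lower('title'), name='lower_title_idx'),
#         ]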
|
||||
|
||||
class IndexExpression(Func):
|
||||
"""Order and wrap expressions for CREATE INDEX statements."""
|
||||
template = '%(expressions)s'
|
||||
wrapper_classes = (OrderBy, Collate)
|
||||
|
||||
def set_wrapper_classes(self, connection=None):
|
||||
# Some databases (e.g. MySQL) treat COLLATE as an indexed expression.
|
||||
if connection and connection.features.collate_as_index_expression:
|
||||
self.wrapper_classes = tuple([
|
||||
wrapper_cls
|
||||
for wrapper_cls in self.wrapper_classes
|
||||
if wrapper_cls is not Collate
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def register_wrappers(cls, *wrapper_classes):
|
||||
cls.wrapper_classes = wrapper_classes
|
||||
|
||||
def resolve_expression(
|
||||
self,
|
||||
query=None,
|
||||
allow_joins=True,
|
||||
reuse=None,
|
||||
summarize=False,
|
||||
for_save=False,
|
||||
):
|
||||
expressions = list(self.flatten())
|
||||
# Split expressions and wrappers.
|
||||
index_expressions, wrappers = partition(
|
||||
lambda e: isinstance(e, self.wrapper_classes),
|
||||
expressions,
|
||||
)
|
||||
wrapper_types = [type(wrapper) for wrapper in wrappers]
|
||||
if len(wrapper_types) != len(set(wrapper_types)):
|
||||
raise ValueError(
|
||||
"Multiple references to %s can't be used in an indexed "
|
||||
"expression." % ', '.join([
|
||||
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
|
||||
])
|
||||
)
|
||||
if expressions[1:len(wrappers) + 1] != wrappers:
|
||||
raise ValueError(
|
||||
'%s must be topmost expressions in an indexed expression.'
|
||||
% ', '.join([
|
||||
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
|
||||
])
|
||||
)
|
||||
# Wrap expressions in parentheses if they are not column references.
|
||||
root_expression = index_expressions[1]
|
||||
resolve_root_expression = root_expression.resolve_expression(
|
||||
query,
|
||||
allow_joins,
|
||||
reuse,
|
||||
summarize,
|
||||
for_save,
|
||||
)
|
||||
if not isinstance(resolve_root_expression, Col):
|
||||
root_expression = Func(root_expression, template='(%(expressions)s)')
|
||||
|
||||
if wrappers:
|
||||
# Order wrappers and set their expressions.
|
||||
wrappers = sorted(
|
||||
wrappers,
|
||||
key=lambda w: self.wrapper_classes.index(type(w)),
|
||||
)
|
||||
wrappers = [wrapper.copy() for wrapper in wrappers]
|
||||
for i, wrapper in enumerate(wrappers[:-1]):
|
||||
wrapper.set_source_expressions([wrappers[i + 1]])
|
||||
# Set the root expression on the deepest wrapper.
|
||||
wrappers[-1].set_source_expressions([root_expression])
|
||||
self.set_source_expressions([wrappers[0]])
|
||||
else:
|
||||
# Use the root expression, if there are no wrappers.
|
||||
self.set_source_expressions([root_expression])
|
||||
return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
# Casting to numeric is unnecessary.
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
687
venv/Lib/site-packages/django/db/models/lookups.py
Normal file
@@ -0,0 +1,687 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models.expressions import Case, Expression, Func, Value, When
|
||||
from django.db.models.fields import (
|
||||
BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
|
||||
)
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
|
||||
class Lookup(Expression):
|
||||
lookup_name = None
|
||||
prepare_rhs = True
|
||||
can_use_none_as_rhs = False
|
||||
|
||||
def __init__(self, lhs, rhs):
|
||||
self.lhs, self.rhs = lhs, rhs
|
||||
self.rhs = self.get_prep_lookup()
|
||||
self.lhs = self.get_prep_lhs()
|
||||
if hasattr(self.lhs, 'get_bilateral_transforms'):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if bilateral_transforms:
|
||||
# Warn the user as soon as possible if they are trying to apply
|
||||
# a bilateral transformation on a nested QuerySet: that won't work.
|
||||
from django.db.models.sql.query import ( # avoid circular import
|
||||
Query,
|
||||
)
|
||||
if isinstance(rhs, Query):
|
||||
raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
|
||||
self.bilateral_transforms = bilateral_transforms
|
||||
|
||||
def apply_bilateral_transforms(self, value):
|
||||
for transform in self.bilateral_transforms:
|
||||
value = transform(value)
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})'
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
if rhs is None:
|
||||
rhs = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
sqls, sqls_params = [], []
|
||||
for p in rhs:
|
||||
value = Value(p, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
sql, sql_params = compiler.compile(value)
|
||||
sqls.append(sql)
|
||||
sqls_params.extend(sql_params)
|
||||
else:
|
||||
_, params = self.get_db_prep_lookup(rhs, connection)
|
||||
sqls, sqls_params = ['%s'] * len(params), params
|
||||
return sqls, sqls_params
|
||||
|
||||
def get_source_expressions(self):
|
||||
if self.rhs_is_direct_value():
|
||||
return [self.lhs]
|
||||
return [self.lhs, self.rhs]
|
||||
|
||||
def set_source_expressions(self, new_exprs):
|
||||
if len(new_exprs) == 1:
|
||||
self.lhs = new_exprs[0]
|
||||
else:
|
||||
self.lhs, self.rhs = new_exprs
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if not self.prepare_rhs or hasattr(self.rhs, 'resolve_expression'):
|
||||
return self.rhs
|
||||
if hasattr(self.lhs, 'output_field'):
|
||||
if hasattr(self.lhs.output_field, 'get_prep_value'):
|
||||
return self.lhs.output_field.get_prep_value(self.rhs)
|
||||
elif self.rhs_is_direct_value():
|
||||
return Value(self.rhs)
|
||||
return self.rhs
|
||||
|
||||
def get_prep_lhs(self):
|
||||
if hasattr(self.lhs, 'resolve_expression'):
|
||||
return self.lhs
|
||||
return Value(self.lhs)
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
return ('%s', [value])
|
||||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs = lhs or self.lhs
|
||||
if hasattr(lhs, 'resolve_expression'):
|
||||
lhs = lhs.resolve_expression(compiler.query)
|
||||
sql, params = compiler.compile(lhs)
|
||||
if isinstance(lhs, Lookup):
|
||||
# Wrapped in parentheses to respect operator precedence.
|
||||
sql = f'({sql})'
|
||||
return sql, params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
value = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
if self.rhs_is_direct_value():
|
||||
# Do not call get_db_prep_lookup here as the value will be
|
||||
# transformed before being used for lookup
|
||||
value = Value(value, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
if hasattr(value, 'as_sql'):
|
||||
sql, params = compiler.compile(value)
|
||||
# Ensure expression is wrapped in parentheses to respect operator
|
||||
# precedence but avoid double wrapping as it can be misinterpreted
|
||||
# on some backends (e.g. subqueries on SQLite).
|
||||
if sql and sql[0] != '(':
|
||||
sql = '(%s)' % sql
|
||||
return sql, params
|
||||
else:
|
||||
return self.get_db_prep_lookup(value, connection)
|
||||
|
||||
def rhs_is_direct_value(self):
|
||||
return not hasattr(self.rhs, 'as_sql')
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Oracle doesn't allow EXISTS() and filters to be compared to another
|
||||
# expression unless they're wrapped in a CASE WHEN.
|
||||
wrapped = False
|
||||
exprs = []
|
||||
for expr in (self.lhs, self.rhs):
|
||||
if connection.ops.conditional_expression_supported_in_where_clause(expr):
|
||||
expr = Case(When(expr, then=True), default=False)
|
||||
wrapped = True
|
||||
exprs.append(expr)
|
||||
lookup = type(self)(*exprs) if wrapped else self
|
||||
return lookup.as_sql(compiler, connection)
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return BooleanField()
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return self.__class__, self.lhs, self.rhs
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Lookup):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(make_hashable(self.identity))
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = self.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
c.rhs = self.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return c
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expressions in the SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f'CASE WHEN {sql} THEN 1 ELSE 0 END'
|
||||
return sql, params
|
||||
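# Usage sketch (not part of Django's source): the usual way to build on this
# class is a custom lookup registered against a field, as in Django's
# "How to write custom lookups" documentation.
#
#     from django.db.models import Field
#
#     class NotEqual(Lookup):
#         lookup_name = 'ne'
#
#         def as_sql(self, compiler, connection):
#             lhs, lhs_params = self.process_lhs(compiler, connection)
#             rhs, rhs_params = self.process_rhs(compiler, connection)
#             params = lhs_params + rhs_params
#             return '%s <> %s' % (lhs, rhs), params
#
#     Field.register_lookup(NotEqual)
#     # then: Author.objects.filter(age__ne=40)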
|
||||
|
||||
class Transform(RegisterLookupMixin, Func):
|
||||
"""
|
||||
RegisterLookupMixin() is first so that get_lookup() and get_transform()
|
||||
first examine self and then check output_field.
|
||||
"""
|
||||
bilateral = False
|
||||
arity = 1
|
||||
|
||||
@property
|
||||
def lhs(self):
|
||||
return self.get_source_expressions()[0]
|
||||
|
||||
def get_bilateral_transforms(self):
|
||||
if hasattr(self.lhs, 'get_bilateral_transforms'):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if self.bilateral:
|
||||
bilateral_transforms.append(self.__class__)
|
||||
return bilateral_transforms
|
||||
|
||||
|
||||
class BuiltinLookup(Lookup):
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
|
||||
field_internal_type = self.lhs.output_field.get_internal_type()
|
||||
db_type = self.lhs.output_field.db_type(connection=connection)
|
||||
lhs_sql = connection.ops.field_cast_sql(
|
||||
db_type, field_internal_type) % lhs_sql
|
||||
lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
|
||||
return lhs_sql, list(params)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
params.extend(rhs_params)
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
return '%s %s' % (lhs_sql, rhs_sql), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
|
||||
class FieldGetDbPrepValueMixin:
|
||||
"""
|
||||
Some lookups require Field.get_db_prep_value() to be called on their
|
||||
inputs.
|
||||
"""
|
||||
get_db_prep_lookup_value_is_iterable = False
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
# For relational fields, use the 'target_field' attribute of the
|
||||
# output_field.
|
||||
field = getattr(self.lhs.output_field, 'target_field', None)
|
||||
get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
|
||||
return (
|
||||
'%s',
|
||||
[get_db_prep_value(v, connection, prepared=True) for v in value]
|
||||
if self.get_db_prep_lookup_value_is_iterable else
|
||||
[get_db_prep_value(value, connection, prepared=True)]
|
||||
)
|
||||
|
||||
|
||||
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Some lookups require Field.get_db_prep_value() to be called on each value
    in an iterable.
    """
    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        if hasattr(self.rhs, 'resolve_expression'):
            return self.rhs
        prepared_values = []
        for rhs_value in self.rhs:
            if hasattr(rhs_value, 'resolve_expression'):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)
        else:
            return super().process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        params = [param]
        if hasattr(param, 'resolve_expression'):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, 'as_sql'):
            sql, params = compiler.compile(param)
        return sql, params

    def batch_process_rhs(self, compiler, connection, rhs=None):
        pre_processed = super().batch_process_rhs(compiler, connection, rhs)
        # The params list may contain expressions which compile to a
        # sql/param pair. Zip them to get sql and param pairs that refer to the
        # same argument and attempt to replace them with the result of
        # compiling the param step.
        sql, params = zip(*(
            self.resolve_expression_parameter(compiler, connection, sql, param)
            for sql, param in zip(*pre_processed)
        ))
        params = itertools.chain.from_iterable(params)
        return sql, tuple(params)

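# Illustrative sketch (not part of Django): batch_process_rhs() above pairs
# each placeholder with its parameter, lets any expression replace its pair
# with compiled SQL, then splits the pairs apart again and flattens the
# parameters. The toy resolve() below is hypothetical and only mirrors the
# reshaping done with zip() and itertools.chain.
def _batch_rhs_reshape_sketch():
    import itertools as _itertools

    pre_processed = (['%s', '%s'], [5, 10])  # (sqls, params) from the base class

    def resolve(sql, param):
        # A real expression would be compiled here into (sql, [params]).
        return sql, [param]

    sql, param_groups = zip(*(resolve(s, p) for s, p in zip(*pre_processed)))
    return sql, tuple(_itertools.chain.from_iterable(param_groups))  # (('%s', '%s'), (5, 10))
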
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
    """Lookup defined by operators on PostgreSQL."""
    postgres_operator = None

    def as_postgresql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = tuple(lhs_params) + tuple(rhs_params)
        return '%s %s %s' % (lhs, self.postgres_operator, rhs), params

@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'exact'

    def process_rhs(self, compiler, connection):
        from django.db.models.sql.query import Query
        if isinstance(self.rhs, Query):
            if self.rhs.has_limit_one():
                if not self.rhs.has_select_fields:
                    self.rhs.clear_select_clause()
                    self.rhs.add_fields(['pk'])
            else:
                raise ValueError(
                    'The QuerySet value for an exact lookup must be limited to '
                    'one result using slicing.'
                )
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        # Avoid comparison against direct rhs if lhs is a boolean value. That
        # turns "boolfield__exact=True" into "WHERE boolean_field" instead of
        # "WHERE boolean_field = True" when allowed.
        if (
            isinstance(self.rhs, bool) and
            getattr(self.lhs, 'conditional', False) and
            connection.ops.conditional_expression_supported_in_where_clause(self.lhs)
        ):
            lhs_sql, params = self.process_lhs(compiler, connection)
            template = '%s' if self.rhs else 'NOT %s'
            return template % lhs_sql, params
        return super().as_sql(compiler, connection)

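# Illustrative sketch (not part of Django): when the lhs is itself usable as a
# condition and the backend accepts booleans directly in WHERE, Exact.as_sql()
# above drops the "= %s" comparison entirely. The column name is hypothetical.
def _boolean_exact_sketch(rhs=True):
    lhs_sql = '"blog_post"."published"'
    template = '%s' if rhs else 'NOT %s'
    return template % lhs_sql  # '"blog_post"."published"' or 'NOT "blog_post"."published"'
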
@Field.register_lookup
class IExact(BuiltinLookup):
    lookup_name = 'iexact'
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        rhs, params = super().process_rhs(qn, connection)
        if params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params


@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'gt'


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'gte'


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'lt'


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'lte'

class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.
    """
    def get_prep_lookup(self):
        if isinstance(self.rhs, float):
            self.rhs = math.ceil(self.rhs)
        return super().get_prep_lookup()


@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
    pass


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
    pass

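# Illustrative sketch (not part of Django): IntegerFieldFloatRounding rounds a
# float query value up before it is compared, so e.g. a hypothetical
# Entry.objects.filter(votes__gte=2.1) matches the same rows as votes__gte=3.
def _integer_float_rounding_sketch(value=2.1):
    return math.ceil(value) if isinstance(value, float) else value  # 3
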
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    lookup_name = 'in'

    def process_rhs(self, compiler, connection):
        db_rhs = getattr(self.rhs, '_db', None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if self.rhs_is_direct_value():
            # Remove None from the list as NULL is never equal to anything.
            try:
                rhs = OrderedSet(self.rhs)
                rhs.discard(None)
            except TypeError:  # Unhashable items in self.rhs
                rhs = [r for r in self.rhs if r is not None]

            if not rhs:
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            from django.db.models.sql.query import (  # avoid circular import
                Query,
            )
            if isinstance(self.rhs, Query):
                query = self.rhs
                query.clear_ordering(clear_default=True)
                if not query.has_select_fields:
                    query.clear_select_clause()
                    query.add_fields(['pk'])

            return super().process_rhs(compiler, connection)

    def get_group_by_cols(self, alias=None):
        cols = self.lhs.get_group_by_cols()
        if hasattr(self.rhs, 'get_group_by_cols'):
            if not getattr(self.rhs, 'has_select_fields', True):
                self.rhs.clear_select_clause()
                self.rhs.add_fields(['pk'])
            cols.extend(self.rhs.get_group_by_cols())
        return cols

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs

    def as_sql(self, compiler, connection):
        max_in_list_size = connection.ops.max_in_list_size()
        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        in_clause_elements = ['(']
        params = []
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(' OR ')
            in_clause_elements.append('%s IN (' % lhs)
            params.extend(lhs_params)
            sqls = rhs[offset: offset + max_in_list_size]
            sqls_params = rhs_params[offset: offset + max_in_list_size]
            param_group = ', '.join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(')')
            params.extend(sqls_params)
        in_clause_elements.append(')')
        return ''.join(in_clause_elements), params

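# Illustrative sketch (not part of Django): split_parameter_list_as_sql() above
# rewrites one oversized IN clause as several OR'ed IN clauses for backends
# (such as Oracle) that cap the number of items in a list. The column name and
# the cap of 3 below are hypothetical.
def _split_in_clause_sketch(values=(1, 2, 3, 4, 5), max_in_list_size=3):
    lhs = '"app_item"."id"'
    clauses, params = [], []
    for offset in range(0, len(values), max_in_list_size):
        chunk = values[offset:offset + max_in_list_size]
        clauses.append('%s IN (%s)' % (lhs, ', '.join(['%s'] * len(chunk))))
        params.extend(chunk)
    # -> ('("app_item"."id" IN (%s, %s, %s) OR "app_item"."id" IN (%s, %s))', [1, 2, 3, 4, 5])
    return '(' + ' OR '.join(clauses) + ')', params
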
class PatternLookup(BuiltinLookup):
    param_pattern = '%%%s%%'
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        # col LIKE %s, ['thevalue%']
        # For Python values we can (and should) do that directly in Python,
        # but if the value is, for example, a reference to another column,
        # then we need to add the % pattern match to the lookup with something
        # like:
        # col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
            pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
            return pattern.format(rhs)
        else:
            return super().get_rhs_op(connection, rhs)

    def process_rhs(self, qn, connection):
        rhs, params = super().process_rhs(qn, connection)
        if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
            params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
        return rhs, params

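# Illustrative sketch (not part of Django): for a direct value, process_rhs()
# above escapes the LIKE wildcards in the user-supplied value and only then
# applies the lookup's own pattern ('%s%%' for startswith), so wildcards come
# from the lookup, never from user input. The escaping below mirrors
# prep_for_like_query() in spirit; exact behavior is backend-specific.
def _startswith_param_sketch(value='50%_off'):
    escaped = value.replace('\\', '\\\\').replace('%', r'\%').replace('_', r'\_')
    return '%s%%' % escaped  # e.g. '50%_off' -> '50\%\_off%'
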
@Field.register_lookup
class Contains(PatternLookup):
    lookup_name = 'contains'


@Field.register_lookup
class IContains(Contains):
    lookup_name = 'icontains'


@Field.register_lookup
class StartsWith(PatternLookup):
    lookup_name = 'startswith'
    param_pattern = '%s%%'


@Field.register_lookup
class IStartsWith(StartsWith):
    lookup_name = 'istartswith'


@Field.register_lookup
class EndsWith(PatternLookup):
    lookup_name = 'endswith'
    param_pattern = '%%%s'


@Field.register_lookup
class IEndsWith(EndsWith):
    lookup_name = 'iendswith'


@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        return "BETWEEN %s AND %s" % (rhs[0], rhs[1])

@Field.register_lookup
class IsNull(BuiltinLookup):
    lookup_name = 'isnull'
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if not isinstance(self.rhs, bool):
            raise ValueError(
                'The QuerySet value for an isnull lookup must be True or '
                'False.'
            )
        sql, params = compiler.compile(self.lhs)
        if self.rhs:
            return "%s IS NULL" % sql, params
        else:
            return "%s IS NOT NULL" % sql, params


@Field.register_lookup
class Regex(BuiltinLookup):
    lookup_name = 'regex'
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            sql_template = connection.ops.regex_lookup(self.lookup_name)
            return sql_template % (lhs, rhs), lhs_params + rhs_params


@Field.register_lookup
class IRegex(Regex):
    lookup_name = 'iregex'

class YearLookup(Lookup):
    def year_lookup_bounds(self, connection, year):
        from django.db.models.functions import ExtractIsoYear
        iso_year = isinstance(self.lhs, ExtractIsoYear)
        output_field = self.lhs.lhs.output_field
        if isinstance(output_field, DateTimeField):
            bounds = connection.ops.year_lookup_bounds_for_datetime_field(
                year, iso_year=iso_year,
            )
        else:
            bounds = connection.ops.year_lookup_bounds_for_date_field(
                year, iso_year=iso_year,
            )
        return bounds

    def as_sql(self, compiler, connection):
        # Avoid the extract operation if the rhs is a direct value to allow
        # indexes to be used.
        if self.rhs_is_direct_value():
            # Skip the extract part by directly using the originating field,
            # that is self.lhs.lhs.
            lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
            rhs_sql, _ = self.process_rhs(compiler, connection)
            rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
            start, finish = self.year_lookup_bounds(connection, self.rhs)
            params.extend(self.get_bound_params(start, finish))
            return '%s %s' % (lhs_sql, rhs_sql), params
        return super().as_sql(compiler, connection)

    def get_direct_rhs_sql(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound_params(self, start, finish):
        raise NotImplementedError(
            'subclasses of YearLookup must provide a get_bound_params() method'
        )


class YearExact(YearLookup, Exact):
    def get_direct_rhs_sql(self, connection, rhs):
        return 'BETWEEN %s AND %s'

    def get_bound_params(self, start, finish):
        return (start, finish)


class YearGt(YearLookup, GreaterThan):
    def get_bound_params(self, start, finish):
        return (finish,)


class YearGte(YearLookup, GreaterThanOrEqual):
    def get_bound_params(self, start, finish):
        return (start,)


class YearLt(YearLookup, LessThan):
    def get_bound_params(self, start, finish):
        return (start,)


class YearLte(YearLookup, LessThanOrEqual):
    def get_bound_params(self, start, finish):
        return (finish,)

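# Illustrative sketch (not part of Django): YearLookup avoids filtering on
# EXTRACT('year', ...), which would defeat indexes, by comparing the original
# column against the first and last day of the year (shown here for a
# DateField; datetime fields get full timestamps from the backend ops).
def _year_bounds_sketch(year=2021):
    import datetime
    start = datetime.date(year, 1, 1)
    finish = datetime.date(year, 12, 31)
    # pub_date__year=2021 then compiles to roughly: pub_date BETWEEN %s AND %s
    return start, finish
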
class UUIDTextMixin:
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.
    """
    def process_rhs(self, qn, connection):
        if not connection.features.has_native_uuid_field:
            from django.db.models.functions import Replace
            if self.rhs_is_direct_value():
                self.rhs = Value(self.rhs)
            self.rhs = Replace(self.rhs, Value('-'), Value(''), output_field=CharField())
        rhs, params = super().process_rhs(qn, connection)
        return rhs, params

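# Illustrative sketch (not part of Django): on backends without a native UUID
# column type the value is stored as 32 hex characters, so these text lookups
# strip the hyphens from the query value before comparing.
def _uuid_text_value_sketch():
    import uuid
    value = uuid.UUID('12345678-1234-5678-1234-567812345678')
    return str(value).replace('-', '')  # '12345678123456781234567812345678'
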
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    pass


@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    pass


@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    pass


@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    pass


@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    pass


@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    pass


@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    pass
203
venv/Lib/site-packages/django/db/models/manager.py
Normal file
203
venv/Lib/site-packages/django/db/models/manager.py
Normal file
@@ -0,0 +1,203 @@
import copy
import inspect
from importlib import import_module

from django.db import router
from django.db.models.query import QuerySet

class BaseManager:
    # To retain order, track each time a Manager instance is created.
    creation_counter = 0

    # Set to True for the 'objects' managers that are automatically created.
    auto_created = False

    #: If set to True the manager will be serialized into migrations and will
    #: thus be available in e.g. RunPython operations.
    use_in_migrations = False

    def __new__(cls, *args, **kwargs):
        # Capture the arguments to make returning them trivial.
        obj = super().__new__(cls)
        obj._constructor_args = (args, kwargs)
        return obj

    def __init__(self):
        super().__init__()
        self._set_creation_counter()
        self.model = None
        self.name = None
        self._db = None
        self._hints = {}

    def __str__(self):
        """Return "app_label.model_label.manager_name"."""
        return '%s.%s' % (self.model._meta.label, self.name)

    def __class_getitem__(cls, *args, **kwargs):
        return cls

    def deconstruct(self):
        """
        Return a 5-tuple of the form (as_manager (True), manager_class,
        queryset_class, args, kwargs).

        Raise a ValueError if the manager is dynamically generated.
        """
        qs_class = self._queryset_class
        if getattr(self, '_built_with_as_manager', False):
            # using MyQuerySet.as_manager()
            return (
                True,  # as_manager
                None,  # manager_class
                '%s.%s' % (qs_class.__module__, qs_class.__name__),  # qs_class
                None,  # args
                None,  # kwargs
            )
        else:
            module_name = self.__module__
            name = self.__class__.__name__
            # Make sure it's actually there and not an inner class
            module = import_module(module_name)
            if not hasattr(module, name):
                raise ValueError(
                    "Could not find manager %s in %s.\n"
                    "Please note that you need to inherit from managers you "
                    "dynamically generated with 'from_queryset()'."
                    % (name, module_name)
                )
            return (
                False,  # as_manager
                '%s.%s' % (module_name, name),  # manager_class
                None,  # qs_class
                self._constructor_args[0],  # args
                self._constructor_args[1],  # kwargs
            )

    def check(self, **kwargs):
        return []

    @classmethod
    def _get_queryset_methods(cls, queryset_class):
        def create_method(name, method):
            def manager_method(self, *args, **kwargs):
                return getattr(self.get_queryset(), name)(*args, **kwargs)
            manager_method.__name__ = method.__name__
            manager_method.__doc__ = method.__doc__
            return manager_method

        new_methods = {}
        for name, method in inspect.getmembers(queryset_class, predicate=inspect.isfunction):
            # Only copy missing methods.
            if hasattr(cls, name):
                continue
            # Only copy public methods or methods with the attribute `queryset_only=False`.
            queryset_only = getattr(method, 'queryset_only', None)
            if queryset_only or (queryset_only is None and name.startswith('_')):
                continue
            # Copy the method onto the manager.
            new_methods[name] = create_method(name, method)
        return new_methods

    @classmethod
    def from_queryset(cls, queryset_class, class_name=None):
        if class_name is None:
            class_name = '%sFrom%s' % (cls.__name__, queryset_class.__name__)
        return type(class_name, (cls,), {
            '_queryset_class': queryset_class,
            **cls._get_queryset_methods(queryset_class),
        })

    def contribute_to_class(self, cls, name):
        self.name = self.name or name
        self.model = cls

        setattr(cls, name, ManagerDescriptor(self))

        cls._meta.add_manager(self)

    def _set_creation_counter(self):
        """
        Set the creation counter value for this instance and increment the
        class-level copy.
        """
        self.creation_counter = BaseManager.creation_counter
        BaseManager.creation_counter += 1

    def db_manager(self, using=None, hints=None):
        obj = copy.copy(self)
        obj._db = using or self._db
        obj._hints = hints or self._hints
        return obj

    @property
    def db(self):
        return self._db or router.db_for_read(self.model, **self._hints)

    #######################
    # PROXIES TO QUERYSET #
    #######################

    def get_queryset(self):
        """
        Return a new QuerySet object. Subclasses can override this method to
        customize the behavior of the Manager.
        """
        return self._queryset_class(model=self.model, using=self._db, hints=self._hints)

    def all(self):
        # We can't proxy this method through the `QuerySet` like we do for the
        # rest of the `QuerySet` methods. This is because `QuerySet.all()`
        # works by creating a "copy" of the current queryset and in making said
        # copy, all the cached `prefetch_related` lookups are lost. See the
        # implementation of `RelatedManager.get_queryset()` for a better
        # understanding of how this comes into play.
        return self.get_queryset()

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self._constructor_args == other._constructor_args
        )

    def __hash__(self):
        return id(self)

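# Illustrative sketch (not part of Django): BaseManager.from_queryset() builds
# a Manager subclass whose public QuerySet methods are also callable on the
# manager. PublishedQuerySet, the `published` field, and Article are
# hypothetical; the classes are created inside a function so nothing runs at
# import time.
def _from_queryset_sketch():
    class PublishedQuerySet(QuerySet):
        def published(self):
            return self.filter(published=True)

    ArticleManager = BaseManager.from_queryset(PublishedQuerySet)
    # On a model:  objects = ArticleManager()
    # Article.objects.published() then proxies to PublishedQuerySet.published().
    return ArticleManager
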
class Manager(BaseManager.from_queryset(QuerySet)):
    pass


class ManagerDescriptor:

    def __init__(self, manager):
        self.manager = manager

    def __get__(self, instance, cls=None):
        if instance is not None:
            raise AttributeError("Manager isn't accessible via %s instances" % cls.__name__)

        if cls._meta.abstract:
            raise AttributeError("Manager isn't available; %s is abstract" % (
                cls._meta.object_name,
            ))

        if cls._meta.swapped:
            raise AttributeError(
                "Manager isn't available; '%s' has been swapped for '%s'" % (
                    cls._meta.label,
                    cls._meta.swapped,
                )
            )

        return cls._meta.managers_map[self.manager.name]


class EmptyManager(Manager):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def get_queryset(self):
        return super().get_queryset().none()