Ajoutez des fichiers projet.

This commit is contained in:
Ambulance Clerc
2021-12-18 18:43:17 +01:00
parent 3c4d48ed26
commit 46254605fc
4842 changed files with 732322 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
import django
from django.core.handlers.asgi import ASGIHandler
def get_asgi_application():
"""
The public interface to Django's ASGI support. Return an ASGI 3 callable.
Avoids making django.core.handlers.ASGIHandler a public API, in case the
internal implementation changes or moves in the future.
"""
django.setup(set_prefix=False)
return ASGIHandler()

View File

@@ -0,0 +1,66 @@
"""
Caching framework.
This package defines set of cache backends that all conform to a simple API.
In a nutshell, a cache is a set of values -- which can be any object that
may be pickled -- identified by string keys. For the complete API, see
the abstract BaseCache class in django.core.cache.backends.base.
Client code should use the `cache` variable defined here to access the default
cache backend and look up non-default cache backends in the `caches` dict-like
object.
See docs/topics/cache.txt for information on the public API.
"""
from django.core import signals
from django.core.cache.backends.base import (
BaseCache, CacheKeyWarning, InvalidCacheBackendError, InvalidCacheKey,
)
from django.utils.connection import BaseConnectionHandler, ConnectionProxy
from django.utils.module_loading import import_string
__all__ = [
'cache', 'caches', 'DEFAULT_CACHE_ALIAS', 'InvalidCacheBackendError',
'CacheKeyWarning', 'BaseCache', 'InvalidCacheKey',
]
DEFAULT_CACHE_ALIAS = 'default'
class CacheHandler(BaseConnectionHandler):
settings_name = 'CACHES'
exception_class = InvalidCacheBackendError
def create_connection(self, alias):
params = self.settings[alias].copy()
backend = params.pop('BACKEND')
location = params.pop('LOCATION', '')
try:
backend_cls = import_string(backend)
except ImportError as e:
raise InvalidCacheBackendError(
"Could not find backend '%s': %s" % (backend, e)
) from e
return backend_cls(location, params)
def all(self, initialized_only=False):
return [
self[alias] for alias in self
# If initialized_only is True, return only initialized caches.
if not initialized_only or hasattr(self._connections, alias)
]
caches = CacheHandler()
cache = ConnectionProxy(caches, DEFAULT_CACHE_ALIAS)
def close_caches(**kwargs):
# Some caches need to do a cleanup at the end of a request cycle. If not
# implemented in a particular backend cache.close() is a no-op.
for cache in caches.all(initialized_only=True):
cache.close()
signals.request_finished.connect(close_caches)

View File

@@ -0,0 +1,385 @@
"Base Cache class."
import time
import warnings
from asgiref.sync import sync_to_async
from django.core.exceptions import ImproperlyConfigured
from django.utils.module_loading import import_string
class InvalidCacheBackendError(ImproperlyConfigured):
pass
class CacheKeyWarning(RuntimeWarning):
pass
class InvalidCacheKey(ValueError):
pass
# Stub class to ensure not passing in a `timeout` argument results in
# the default timeout
DEFAULT_TIMEOUT = object()
# Memcached does not accept keys longer than this.
MEMCACHE_MAX_KEY_LENGTH = 250
def default_key_func(key, key_prefix, version):
"""
Default function to generate keys.
Construct the key used by all other methods. By default, prepend
the `key_prefix`. KEY_FUNCTION can be used to specify an alternate
function with custom key making behavior.
"""
return '%s:%s:%s' % (key_prefix, version, key)
def get_key_func(key_func):
"""
Function to decide which key function to use.
Default to ``default_key_func``.
"""
if key_func is not None:
if callable(key_func):
return key_func
else:
return import_string(key_func)
return default_key_func
class BaseCache:
_missing_key = object()
def __init__(self, params):
timeout = params.get('timeout', params.get('TIMEOUT', 300))
if timeout is not None:
try:
timeout = int(timeout)
except (ValueError, TypeError):
timeout = 300
self.default_timeout = timeout
options = params.get('OPTIONS', {})
max_entries = params.get('max_entries', options.get('MAX_ENTRIES', 300))
try:
self._max_entries = int(max_entries)
except (ValueError, TypeError):
self._max_entries = 300
cull_frequency = params.get('cull_frequency', options.get('CULL_FREQUENCY', 3))
try:
self._cull_frequency = int(cull_frequency)
except (ValueError, TypeError):
self._cull_frequency = 3
self.key_prefix = params.get('KEY_PREFIX', '')
self.version = params.get('VERSION', 1)
self.key_func = get_key_func(params.get('KEY_FUNCTION'))
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
"""
Return the timeout value usable by this backend based upon the provided
timeout.
"""
if timeout == DEFAULT_TIMEOUT:
timeout = self.default_timeout
elif timeout == 0:
# ticket 21147 - avoid time.time() related precision issues
timeout = -1
return None if timeout is None else time.time() + timeout
def make_key(self, key, version=None):
"""
Construct the key used by all other methods. By default, use the
key_func to generate a key (which, by default, prepends the
`key_prefix' and 'version'). A different key function can be provided
at the time of cache construction; alternatively, you can subclass the
cache backend to provide custom key making behavior.
"""
if version is None:
version = self.version
return self.key_func(key, self.key_prefix, version)
def validate_key(self, key):
"""
Warn about keys that would not be portable to the memcached
backend. This encourages (but does not force) writing backend-portable
cache code.
"""
for warning in memcache_key_warnings(key):
warnings.warn(warning, CacheKeyWarning)
def make_and_validate_key(self, key, version=None):
"""Helper to make and validate keys."""
key = self.make_key(key, version=version)
self.validate_key(key)
return key
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
"""
Set a value in the cache if the key does not already exist. If
timeout is given, use that timeout for the key; otherwise use the
default cache timeout.
Return True if the value was stored, False otherwise.
"""
raise NotImplementedError('subclasses of BaseCache must provide an add() method')
async def aadd(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
return await sync_to_async(self.add, thread_sensitive=True)(key, value, timeout, version)
def get(self, key, default=None, version=None):
"""
Fetch a given key from the cache. If the key does not exist, return
default, which itself defaults to None.
"""
raise NotImplementedError('subclasses of BaseCache must provide a get() method')
async def aget(self, key, default=None, version=None):
return await sync_to_async(self.get, thread_sensitive=True)(key, default, version)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
"""
Set a value in the cache. If timeout is given, use that timeout for the
key; otherwise use the default cache timeout.
"""
raise NotImplementedError('subclasses of BaseCache must provide a set() method')
async def aset(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
return await sync_to_async(self.set, thread_sensitive=True)(key, value, timeout, version)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
"""
Update the key's expiry time using timeout. Return True if successful
or False if the key does not exist.
"""
raise NotImplementedError('subclasses of BaseCache must provide a touch() method')
async def atouch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
return await sync_to_async(self.touch, thread_sensitive=True)(key, timeout, version)
def delete(self, key, version=None):
"""
Delete a key from the cache and return whether it succeeded, failing
silently.
"""
raise NotImplementedError('subclasses of BaseCache must provide a delete() method')
async def adelete(self, key, version=None):
return await sync_to_async(self.delete, thread_sensitive=True)(key, version)
def get_many(self, keys, version=None):
"""
Fetch a bunch of keys from the cache. For certain backends (memcached,
pgsql) this can be *much* faster when fetching multiple values.
Return a dict mapping each key in keys to its value. If the given
key is missing, it will be missing from the response dict.
"""
d = {}
for k in keys:
val = self.get(k, self._missing_key, version=version)
if val is not self._missing_key:
d[k] = val
return d
async def aget_many(self, keys, version=None):
"""See get_many()."""
d = {}
for k in keys:
val = await self.aget(k, self._missing_key, version=version)
if val is not self._missing_key:
d[k] = val
return d
def get_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
"""
Fetch a given key from the cache. If the key does not exist,
add the key and set it to the default value. The default value can
also be any callable. If timeout is given, use that timeout for the
key; otherwise use the default cache timeout.
Return the value of the key stored or retrieved.
"""
val = self.get(key, self._missing_key, version=version)
if val is self._missing_key:
if callable(default):
default = default()
self.add(key, default, timeout=timeout, version=version)
# Fetch the value again to avoid a race condition if another caller
# added a value between the first get() and the add() above.
return self.get(key, default, version=version)
return val
async def aget_or_set(self, key, default, timeout=DEFAULT_TIMEOUT, version=None):
"""See get_or_set()."""
val = await self.aget(key, self._missing_key, version=version)
if val is self._missing_key:
if callable(default):
default = default()
await self.aadd(key, default, timeout=timeout, version=version)
# Fetch the value again to avoid a race condition if another caller
# added a value between the first aget() and the aadd() above.
return await self.aget(key, default, version=version)
return val
def has_key(self, key, version=None):
"""
Return True if the key is in the cache and has not expired.
"""
return self.get(key, self._missing_key, version=version) is not self._missing_key
async def ahas_key(self, key, version=None):
return (
await self.aget(key, self._missing_key, version=version)
is not self._missing_key
)
def incr(self, key, delta=1, version=None):
"""
Add delta to value in the cache. If the key does not exist, raise a
ValueError exception.
"""
value = self.get(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
new_value = value + delta
self.set(key, new_value, version=version)
return new_value
async def aincr(self, key, delta=1, version=None):
"""See incr()."""
value = await self.aget(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
new_value = value + delta
await self.aset(key, new_value, version=version)
return new_value
def decr(self, key, delta=1, version=None):
"""
Subtract delta from value in the cache. If the key does not exist, raise
a ValueError exception.
"""
return self.incr(key, -delta, version=version)
async def adecr(self, key, delta=1, version=None):
return await self.aincr(key, -delta, version=version)
def __contains__(self, key):
"""
Return True if the key is in the cache and has not expired.
"""
# This is a separate method, rather than just a copy of has_key(),
# so that it always has the same functionality as has_key(), even
# if a subclass overrides it.
return self.has_key(key)
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
"""
Set a bunch of values in the cache at once from a dict of key/value
pairs. For certain backends (memcached), this is much more efficient
than calling set() multiple times.
If timeout is given, use that timeout for the key; otherwise use the
default cache timeout.
On backends that support it, return a list of keys that failed
insertion, or an empty list if all keys were inserted successfully.
"""
for key, value in data.items():
self.set(key, value, timeout=timeout, version=version)
return []
async def aset_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
for key, value in data.items():
await self.aset(key, value, timeout=timeout, version=version)
return []
def delete_many(self, keys, version=None):
"""
Delete a bunch of values in the cache at once. For certain backends
(memcached), this is much more efficient than calling delete() multiple
times.
"""
for key in keys:
self.delete(key, version=version)
async def adelete_many(self, keys, version=None):
for key in keys:
await self.adelete(key, version=version)
def clear(self):
"""Remove *all* values from the cache at once."""
raise NotImplementedError('subclasses of BaseCache must provide a clear() method')
async def aclear(self):
return await sync_to_async(self.clear, thread_sensitive=True)()
def incr_version(self, key, delta=1, version=None):
"""
Add delta to the cache version for the supplied key. Return the new
version.
"""
if version is None:
version = self.version
value = self.get(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
self.set(key, value, version=version + delta)
self.delete(key, version=version)
return version + delta
async def aincr_version(self, key, delta=1, version=None):
"""See incr_version()."""
if version is None:
version = self.version
value = await self.aget(key, self._missing_key, version=version)
if value is self._missing_key:
raise ValueError("Key '%s' not found" % key)
await self.aset(key, value, version=version + delta)
await self.adelete(key, version=version)
return version + delta
def decr_version(self, key, delta=1, version=None):
"""
Subtract delta from the cache version for the supplied key. Return the
new version.
"""
return self.incr_version(key, -delta, version)
async def adecr_version(self, key, delta=1, version=None):
return await self.aincr_version(key, -delta, version)
def close(self, **kwargs):
"""Close the cache connection"""
pass
async def aclose(self, **kwargs):
pass
def memcache_key_warnings(key):
if len(key) > MEMCACHE_MAX_KEY_LENGTH:
yield (
'Cache key will cause errors if used with memcached: %r '
'(longer than %s)' % (key, MEMCACHE_MAX_KEY_LENGTH)
)
for char in key:
if ord(char) < 33 or ord(char) == 127:
yield (
'Cache key contains characters that will cause errors if '
'used with memcached: %r' % key
)
break

View File

@@ -0,0 +1,267 @@
"Database cache backend."
import base64
import pickle
from datetime import datetime
from django.conf import settings
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.db import DatabaseError, connections, models, router, transaction
from django.utils import timezone
class Options:
"""A class that will quack like a Django model _meta class.
This allows cache operations to be controlled by the router
"""
def __init__(self, table):
self.db_table = table
self.app_label = 'django_cache'
self.model_name = 'cacheentry'
self.verbose_name = 'cache entry'
self.verbose_name_plural = 'cache entries'
self.object_name = 'CacheEntry'
self.abstract = False
self.managed = True
self.proxy = False
self.swapped = False
class BaseDatabaseCache(BaseCache):
def __init__(self, table, params):
super().__init__(params)
self._table = table
class CacheEntry:
_meta = Options(table)
self.cache_model_class = CacheEntry
class DatabaseCache(BaseDatabaseCache):
# This class uses cursors provided by the database connection. This means
# it reads expiration values as aware or naive datetimes, depending on the
# value of USE_TZ and whether the database supports time zones. The ORM's
# conversion and adaptation infrastructure is then used to avoid comparing
# aware and naive datetimes accidentally.
pickle_protocol = pickle.HIGHEST_PROTOCOL
def get(self, key, default=None, version=None):
return self.get_many([key], version).get(key, default)
def get_many(self, keys, version=None):
if not keys:
return {}
key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
db = router.db_for_read(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
table = quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute(
'SELECT %s, %s, %s FROM %s WHERE %s IN (%s)' % (
quote_name('cache_key'),
quote_name('value'),
quote_name('expires'),
table,
quote_name('cache_key'),
', '.join(['%s'] * len(key_map)),
),
list(key_map),
)
rows = cursor.fetchall()
result = {}
expired_keys = []
expression = models.Expression(output_field=models.DateTimeField())
converters = (connection.ops.get_db_converters(expression) + expression.get_db_converters(connection))
for key, value, expires in rows:
for converter in converters:
expires = converter(expires, expression, connection)
if expires < timezone.now():
expired_keys.append(key)
else:
value = connection.ops.process_clob(value)
value = pickle.loads(base64.b64decode(value.encode()))
result[key_map.get(key)] = value
self._base_delete_many(expired_keys)
return result
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
self._base_set('set', key, value, timeout)
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._base_set('add', key, value, timeout)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._base_set('touch', key, None, timeout)
def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT):
timeout = self.get_backend_timeout(timeout)
db = router.db_for_write(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
table = quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM %s" % table)
num = cursor.fetchone()[0]
now = timezone.now()
now = now.replace(microsecond=0)
if timeout is None:
exp = datetime.max
else:
tz = timezone.utc if settings.USE_TZ else None
exp = datetime.fromtimestamp(timeout, tz=tz)
exp = exp.replace(microsecond=0)
if num > self._max_entries:
self._cull(db, cursor, now, num)
pickled = pickle.dumps(value, self.pickle_protocol)
# The DB column is expecting a string, so make sure the value is a
# string, not bytes. Refs #19274.
b64encoded = base64.b64encode(pickled).decode('latin1')
try:
# Note: typecasting for datetimes is needed by some 3rd party
# database backends. All core backends work without typecasting,
# so be careful about changes here - test suite will NOT pick
# regressions.
with transaction.atomic(using=db):
cursor.execute(
'SELECT %s, %s FROM %s WHERE %s = %%s' % (
quote_name('cache_key'),
quote_name('expires'),
table,
quote_name('cache_key'),
),
[key]
)
result = cursor.fetchone()
if result:
current_expires = result[1]
expression = models.Expression(output_field=models.DateTimeField())
for converter in (connection.ops.get_db_converters(expression) +
expression.get_db_converters(connection)):
current_expires = converter(current_expires, expression, connection)
exp = connection.ops.adapt_datetimefield_value(exp)
if result and mode == 'touch':
cursor.execute(
'UPDATE %s SET %s = %%s WHERE %s = %%s' % (
table,
quote_name('expires'),
quote_name('cache_key')
),
[exp, key]
)
elif result and (mode == 'set' or (mode == 'add' and current_expires < now)):
cursor.execute(
'UPDATE %s SET %s = %%s, %s = %%s WHERE %s = %%s' % (
table,
quote_name('value'),
quote_name('expires'),
quote_name('cache_key'),
),
[b64encoded, exp, key]
)
elif mode != 'touch':
cursor.execute(
'INSERT INTO %s (%s, %s, %s) VALUES (%%s, %%s, %%s)' % (
table,
quote_name('cache_key'),
quote_name('value'),
quote_name('expires'),
),
[key, b64encoded, exp]
)
else:
return False # touch failed.
except DatabaseError:
# To be threadsafe, updates/inserts are allowed to fail silently
return False
else:
return True
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return self._base_delete_many([key])
def delete_many(self, keys, version=None):
keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._base_delete_many(keys)
def _base_delete_many(self, keys):
if not keys:
return False
db = router.db_for_write(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
table = quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute(
'DELETE FROM %s WHERE %s IN (%s)' % (
table,
quote_name('cache_key'),
', '.join(['%s'] * len(keys)),
),
keys,
)
return bool(cursor.rowcount)
def has_key(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
db = router.db_for_read(self.cache_model_class)
connection = connections[db]
quote_name = connection.ops.quote_name
now = timezone.now().replace(microsecond=0, tzinfo=None)
with connection.cursor() as cursor:
cursor.execute(
'SELECT %s FROM %s WHERE %s = %%s and expires > %%s' % (
quote_name('cache_key'),
quote_name(self._table),
quote_name('cache_key'),
),
[key, connection.ops.adapt_datetimefield_value(now)]
)
return cursor.fetchone() is not None
def _cull(self, db, cursor, now, num):
if self._cull_frequency == 0:
self.clear()
else:
connection = connections[db]
table = connection.ops.quote_name(self._table)
cursor.execute("DELETE FROM %s WHERE expires < %%s" % table,
[connection.ops.adapt_datetimefield_value(now)])
deleted_count = cursor.rowcount
remaining_num = num - deleted_count
if remaining_num > self._max_entries:
cull_num = remaining_num // self._cull_frequency
cursor.execute(
connection.ops.cache_key_culling_sql() % table,
[cull_num])
last_cache_key = cursor.fetchone()
if last_cache_key:
cursor.execute(
'DELETE FROM %s WHERE cache_key < %%s' % table,
[last_cache_key[0]],
)
def clear(self):
db = router.db_for_write(self.cache_model_class)
connection = connections[db]
table = connection.ops.quote_name(self._table)
with connection.cursor() as cursor:
cursor.execute('DELETE FROM %s' % table)

View File

@@ -0,0 +1,34 @@
"Dummy cache backend"
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
class DummyCache(BaseCache):
def __init__(self, host, *args, **kwargs):
super().__init__(*args, **kwargs)
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
self.make_and_validate_key(key, version=version)
return True
def get(self, key, default=None, version=None):
self.make_and_validate_key(key, version=version)
return default
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
self.make_and_validate_key(key, version=version)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
self.make_and_validate_key(key, version=version)
return False
def delete(self, key, version=None):
self.make_and_validate_key(key, version=version)
return False
def has_key(self, key, version=None):
self.make_and_validate_key(key, version=version)
return False
def clear(self):
pass

View File

@@ -0,0 +1,163 @@
"File-based cache backend"
import glob
import hashlib
import os
import pickle
import random
import tempfile
import time
import zlib
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.core.files import locks
from django.core.files.move import file_move_safe
class FileBasedCache(BaseCache):
cache_suffix = '.djcache'
pickle_protocol = pickle.HIGHEST_PROTOCOL
def __init__(self, dir, params):
super().__init__(params)
self._dir = os.path.abspath(dir)
self._createdir()
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
if self.has_key(key, version):
return False
self.set(key, value, timeout, version)
return True
def get(self, key, default=None, version=None):
fname = self._key_to_file(key, version)
try:
with open(fname, 'rb') as f:
if not self._is_expired(f):
return pickle.loads(zlib.decompress(f.read()))
except FileNotFoundError:
pass
return default
def _write_content(self, file, timeout, value):
expiry = self.get_backend_timeout(timeout)
file.write(pickle.dumps(expiry, self.pickle_protocol))
file.write(zlib.compress(pickle.dumps(value, self.pickle_protocol)))
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
self._createdir() # Cache dir can be deleted at any time.
fname = self._key_to_file(key, version)
self._cull() # make some room if necessary
fd, tmp_path = tempfile.mkstemp(dir=self._dir)
renamed = False
try:
with open(fd, 'wb') as f:
self._write_content(f, timeout, value)
file_move_safe(tmp_path, fname, allow_overwrite=True)
renamed = True
finally:
if not renamed:
os.remove(tmp_path)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
try:
with open(self._key_to_file(key, version), 'r+b') as f:
try:
locks.lock(f, locks.LOCK_EX)
if self._is_expired(f):
return False
else:
previous_value = pickle.loads(zlib.decompress(f.read()))
f.seek(0)
self._write_content(f, timeout, previous_value)
return True
finally:
locks.unlock(f)
except FileNotFoundError:
return False
def delete(self, key, version=None):
return self._delete(self._key_to_file(key, version))
def _delete(self, fname):
if not fname.startswith(self._dir) or not os.path.exists(fname):
return False
try:
os.remove(fname)
except FileNotFoundError:
# The file may have been removed by another process.
return False
return True
def has_key(self, key, version=None):
fname = self._key_to_file(key, version)
if os.path.exists(fname):
with open(fname, 'rb') as f:
return not self._is_expired(f)
return False
def _cull(self):
"""
Remove random cache entries if max_entries is reached at a ratio
of num_entries / cull_frequency. A value of 0 for CULL_FREQUENCY means
that the entire cache will be purged.
"""
filelist = self._list_cache_files()
num_entries = len(filelist)
if num_entries < self._max_entries:
return # return early if no culling is required
if self._cull_frequency == 0:
return self.clear() # Clear the cache when CULL_FREQUENCY = 0
# Delete a random selection of entries
filelist = random.sample(filelist,
int(num_entries / self._cull_frequency))
for fname in filelist:
self._delete(fname)
def _createdir(self):
# Set the umask because os.makedirs() doesn't apply the "mode" argument
# to intermediate-level directories.
old_umask = os.umask(0o077)
try:
os.makedirs(self._dir, 0o700, exist_ok=True)
finally:
os.umask(old_umask)
def _key_to_file(self, key, version=None):
"""
Convert a key into a cache file path. Basically this is the
root cache path joined with the md5sum of the key and a suffix.
"""
key = self.make_and_validate_key(key, version=version)
return os.path.join(self._dir, ''.join(
[hashlib.md5(key.encode()).hexdigest(), self.cache_suffix]))
def clear(self):
"""
Remove all the cache files.
"""
for fname in self._list_cache_files():
self._delete(fname)
def _is_expired(self, f):
"""
Take an open cache file `f` and delete it if it's expired.
"""
try:
exp = pickle.load(f)
except EOFError:
exp = 0 # An empty file is considered expired.
if exp is not None and exp < time.time():
f.close() # On Windows a file has to be closed before deleting
self._delete(f.name)
return True
return False
def _list_cache_files(self):
"""
Get a list of paths to all the cache files. These are all the files
in the root cache dir that end on the cache_suffix.
"""
return [
os.path.join(self._dir, fname)
for fname in glob.glob1(self._dir, '*%s' % self.cache_suffix)
]

View File

@@ -0,0 +1,117 @@
"Thread-safe in-memory cache backend."
import pickle
import time
from collections import OrderedDict
from threading import Lock
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
# Global in-memory store of cache data. Keyed by name, to provide
# multiple named local memory caches.
_caches = {}
_expire_info = {}
_locks = {}
class LocMemCache(BaseCache):
pickle_protocol = pickle.HIGHEST_PROTOCOL
def __init__(self, name, params):
super().__init__(params)
self._cache = _caches.setdefault(name, OrderedDict())
self._expire_info = _expire_info.setdefault(name, {})
self._lock = _locks.setdefault(name, Lock())
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
pickled = pickle.dumps(value, self.pickle_protocol)
with self._lock:
if self._has_expired(key):
self._set(key, pickled, timeout)
return True
return False
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
return default
pickled = self._cache[key]
self._cache.move_to_end(key, last=False)
return pickle.loads(pickled)
def _set(self, key, value, timeout=DEFAULT_TIMEOUT):
if len(self._cache) >= self._max_entries:
self._cull()
self._cache[key] = value
self._cache.move_to_end(key, last=False)
self._expire_info[key] = self.get_backend_timeout(timeout)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
pickled = pickle.dumps(value, self.pickle_protocol)
with self._lock:
self._set(key, pickled, timeout)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
return False
self._expire_info[key] = self.get_backend_timeout(timeout)
return True
def incr(self, key, delta=1, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
raise ValueError("Key '%s' not found" % key)
pickled = self._cache[key]
value = pickle.loads(pickled)
new_value = value + delta
pickled = pickle.dumps(new_value, self.pickle_protocol)
self._cache[key] = pickled
self._cache.move_to_end(key, last=False)
return new_value
def has_key(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
if self._has_expired(key):
self._delete(key)
return False
return True
def _has_expired(self, key):
exp = self._expire_info.get(key, -1)
return exp is not None and exp <= time.time()
def _cull(self):
if self._cull_frequency == 0:
self._cache.clear()
self._expire_info.clear()
else:
count = len(self._cache) // self._cull_frequency
for i in range(count):
key, _ = self._cache.popitem()
del self._expire_info[key]
def _delete(self, key):
try:
del self._cache[key]
del self._expire_info[key]
except KeyError:
return False
return True
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
with self._lock:
return self._delete(key)
def clear(self):
with self._lock:
self._cache.clear()
self._expire_info.clear()

View File

@@ -0,0 +1,212 @@
"Memcached cache backend"
import pickle
import re
import time
import warnings
from django.core.cache.backends.base import (
DEFAULT_TIMEOUT, BaseCache, InvalidCacheKey, memcache_key_warnings,
)
from django.utils.deprecation import RemovedInDjango41Warning
from django.utils.functional import cached_property
class BaseMemcachedCache(BaseCache):
def __init__(self, server, params, library, value_not_found_exception):
super().__init__(params)
if isinstance(server, str):
self._servers = re.split('[;,]', server)
else:
self._servers = server
# Exception type raised by the underlying client library for a
# nonexistent key.
self.LibraryValueNotFoundException = value_not_found_exception
self._lib = library
self._class = library.Client
self._options = params.get('OPTIONS') or {}
@property
def client_servers(self):
return self._servers
@cached_property
def _cache(self):
"""
Implement transparent thread-safe access to a memcached client.
"""
return self._class(self.client_servers, **self._options)
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
"""
Memcached deals with long (> 30 days) timeouts in a special
way. Call this function to obtain a safe value for your timeout.
"""
if timeout == DEFAULT_TIMEOUT:
timeout = self.default_timeout
if timeout is None:
# Using 0 in memcache sets a non-expiring timeout.
return 0
elif int(timeout) == 0:
# Other cache backends treat 0 as set-and-expire. To achieve this
# in memcache backends, a negative timeout must be passed.
timeout = -1
if timeout > 2592000: # 60*60*24*30, 30 days
# See https://github.com/memcached/memcached/wiki/Programming#expiration
# "Expiration times can be set from 0, meaning "never expire", to
# 30 days. Any time higher than 30 days is interpreted as a Unix
# timestamp date. If you want to expire an object on January 1st of
# next year, this is how you do that."
#
# This means that we have to switch to absolute timestamps.
timeout += int(time.time())
return int(timeout)
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.add(key, value, self.get_backend_timeout(timeout))
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.get(key, default)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
if not self._cache.set(key, value, self.get_backend_timeout(timeout)):
# make sure the key doesn't keep its old value in case of failure to set (memcached's 1MB limit)
self._cache.delete(key)
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return bool(self._cache.touch(key, self.get_backend_timeout(timeout)))
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return bool(self._cache.delete(key))
def get_many(self, keys, version=None):
key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
ret = self._cache.get_multi(key_map.keys())
return {key_map[k]: v for k, v in ret.items()}
def close(self, **kwargs):
# Many clients don't clean up connections properly.
self._cache.disconnect_all()
def incr(self, key, delta=1, version=None):
key = self.make_and_validate_key(key, version=version)
try:
# Memcached doesn't support negative delta.
if delta < 0:
val = self._cache.decr(key, -delta)
else:
val = self._cache.incr(key, delta)
# Normalize an exception raised by the underlying client library to
# ValueError in the event of a nonexistent key when calling
# incr()/decr().
except self.LibraryValueNotFoundException:
val = None
if val is None:
raise ValueError("Key '%s' not found" % key)
return val
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
safe_data = {}
original_keys = {}
for key, value in data.items():
safe_key = self.make_and_validate_key(key, version=version)
safe_data[safe_key] = value
original_keys[safe_key] = key
failed_keys = self._cache.set_multi(safe_data, self.get_backend_timeout(timeout))
return [original_keys[k] for k in failed_keys]
def delete_many(self, keys, version=None):
keys = [self.make_and_validate_key(key, version=version) for key in keys]
self._cache.delete_multi(keys)
def clear(self):
self._cache.flush_all()
def validate_key(self, key):
for warning in memcache_key_warnings(key):
raise InvalidCacheKey(warning)
class MemcachedCache(BaseMemcachedCache):
"An implementation of a cache binding using python-memcached"
# python-memcached doesn't support default values in get().
# https://github.com/linsomniac/python-memcached/issues/159
_missing_key = None
def __init__(self, server, params):
warnings.warn(
'MemcachedCache is deprecated in favor of PyMemcacheCache and '
'PyLibMCCache.',
RemovedInDjango41Warning, stacklevel=2,
)
# python-memcached ≥ 1.45 returns None for a nonexistent key in
# incr/decr(), python-memcached < 1.45 raises ValueError.
import memcache
super().__init__(server, params, library=memcache, value_not_found_exception=ValueError)
self._options = {'pickleProtocol': pickle.HIGHEST_PROTOCOL, **self._options}
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
val = self._cache.get(key)
# python-memcached doesn't support default values in get().
# https://github.com/linsomniac/python-memcached/issues/159
# Remove this method if that issue is fixed.
if val is None:
return default
return val
def delete(self, key, version=None):
# python-memcached's delete() returns True when key doesn't exist.
# https://github.com/linsomniac/python-memcached/issues/170
# Call _deletetouch() without the NOT_FOUND in expected results.
key = self.make_and_validate_key(key, version=version)
return bool(self._cache._deletetouch([b'DELETED'], 'delete', key))
class PyLibMCCache(BaseMemcachedCache):
"An implementation of a cache binding using pylibmc"
def __init__(self, server, params):
import pylibmc
super().__init__(server, params, library=pylibmc, value_not_found_exception=pylibmc.NotFound)
@property
def client_servers(self):
output = []
for server in self._servers:
output.append(server[5:] if server.startswith('unix:') else server)
return output
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
if timeout == 0:
return self._cache.delete(key)
return self._cache.touch(key, self.get_backend_timeout(timeout))
def close(self, **kwargs):
# libmemcached manages its own connections. Don't call disconnect_all()
# as it resets the failover state and creates unnecessary reconnects.
pass
class PyMemcacheCache(BaseMemcachedCache):
"""An implementation of a cache binding using pymemcache."""
def __init__(self, server, params):
import pymemcache.serde
super().__init__(server, params, library=pymemcache, value_not_found_exception=KeyError)
self._class = self._lib.HashClient
self._options = {
'allow_unicode_keys': True,
'default_noreply': False,
'serde': pymemcache.serde.pickle_serde,
**self._options,
}

View File

@@ -0,0 +1,224 @@
"""Redis cache backend."""
import random
import re
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.core.serializers.base import PickleSerializer
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
class RedisSerializer(PickleSerializer):
def dumps(self, obj):
if isinstance(obj, int):
return obj
return super().dumps(obj)
def loads(self, data):
try:
return int(data)
except ValueError:
return super().loads(data)
class RedisCacheClient:
def __init__(
self,
servers,
serializer=None,
db=None,
pool_class=None,
parser_class=None,
):
import redis
self._lib = redis
self._servers = servers
self._pools = {}
self._client = self._lib.Redis
if isinstance(pool_class, str):
pool_class = import_string(pool_class)
self._pool_class = pool_class or self._lib.ConnectionPool
if isinstance(serializer, str):
serializer = import_string(serializer)
if callable(serializer):
serializer = serializer()
self._serializer = serializer or RedisSerializer()
if isinstance(parser_class, str):
parser_class = import_string(parser_class)
parser_class = parser_class or self._lib.connection.DefaultParser
self._pool_options = {'parser_class': parser_class, 'db': db}
def _get_connection_pool_index(self, write):
# Write to the first server. Read from other servers if there are more,
# otherwise read from the first server.
if write or len(self._servers) == 1:
return 0
return random.randint(1, len(self._servers) - 1)
def _get_connection_pool(self, write):
index = self._get_connection_pool_index(write)
if index not in self._pools:
self._pools[index] = self._pool_class.from_url(
self._servers[index], **self._pool_options,
)
return self._pools[index]
def get_client(self, key=None, *, write=False):
# key is used so that the method signature remains the same and custom
# cache client can be implemented which might require the key to select
# the server, e.g. sharding.
pool = self._get_connection_pool(write)
return self._client(connection_pool=pool)
def add(self, key, value, timeout):
client = self.get_client(key, write=True)
value = self._serializer.dumps(value)
if timeout == 0:
if ret := bool(client.set(key, value, nx=True)):
client.delete(key)
return ret
else:
return bool(client.set(key, value, ex=timeout, nx=True))
def get(self, key, default):
client = self.get_client(key)
value = client.get(key)
return default if value is None else self._serializer.loads(value)
def set(self, key, value, timeout):
client = self.get_client(key, write=True)
value = self._serializer.dumps(value)
if timeout == 0:
client.delete(key)
else:
client.set(key, value, ex=timeout)
def touch(self, key, timeout):
client = self.get_client(key, write=True)
if timeout is None:
return bool(client.persist(key))
else:
return bool(client.expire(key, timeout))
def delete(self, key):
client = self.get_client(key, write=True)
return bool(client.delete(key))
def get_many(self, keys):
client = self.get_client(None)
ret = client.mget(keys)
return {
k: self._serializer.loads(v) for k, v in zip(keys, ret) if v is not None
}
def has_key(self, key):
client = self.get_client(key)
return bool(client.exists(key))
def incr(self, key, delta):
client = self.get_client(key)
if not client.exists(key):
raise ValueError("Key '%s' not found." % key)
return client.incr(key, delta)
def set_many(self, data, timeout):
client = self.get_client(None, write=True)
pipeline = client.pipeline()
pipeline.mset({k: self._serializer.dumps(v) for k, v in data.items()})
if timeout is not None:
# Setting timeout for each key as redis does not support timeout
# with mset().
for key in data:
pipeline.expire(key, timeout)
pipeline.execute()
def delete_many(self, keys):
client = self.get_client(None, write=True)
client.delete(*keys)
def clear(self):
client = self.get_client(None, write=True)
return bool(client.flushdb())
class RedisCache(BaseCache):
def __init__(self, server, params):
super().__init__(params)
if isinstance(server, str):
self._servers = re.split('[;,]', server)
else:
self._servers = server
self._class = RedisCacheClient
self._options = params.get('OPTIONS', {})
@cached_property
def _cache(self):
return self._class(self._servers, **self._options)
def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
if timeout == DEFAULT_TIMEOUT:
timeout = self.default_timeout
# The key will be made persistent if None used as a timeout.
# Non-positive values will cause the key to be deleted.
return None if timeout is None else max(0, int(timeout))
def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.add(key, value, self.get_backend_timeout(timeout))
def get(self, key, default=None, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.get(key, default)
def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
self._cache.set(key, value, self.get_backend_timeout(timeout))
def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.touch(key, self.get_backend_timeout(timeout))
def delete(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.delete(key)
def get_many(self, keys, version=None):
key_map = {self.make_and_validate_key(key, version=version): key for key in keys}
ret = self._cache.get_many(key_map.keys())
return {key_map[k]: v for k, v in ret.items()}
def has_key(self, key, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.has_key(key)
def incr(self, key, delta=1, version=None):
key = self.make_and_validate_key(key, version=version)
return self._cache.incr(key, delta)
def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
safe_data = {}
for key, value in data.items():
key = self.make_and_validate_key(key, version=version)
safe_data[key] = value
self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
return []
def delete_many(self, keys, version=None):
safe_keys = []
for key in keys:
key = self.make_and_validate_key(key, version=version)
safe_keys.append(key)
self._cache.delete_many(safe_keys)
def clear(self):
return self._cache.clear()

View File

@@ -0,0 +1,12 @@
import hashlib
TEMPLATE_FRAGMENT_KEY_TEMPLATE = 'template.cache.%s.%s'
def make_template_fragment_key(fragment_name, vary_on=None):
hasher = hashlib.md5()
if vary_on is not None:
for arg in vary_on:
hasher.update(str(arg).encode())
hasher.update(b':')
return TEMPLATE_FRAGMENT_KEY_TEMPLATE % (fragment_name, hasher.hexdigest())

View File

@@ -0,0 +1,27 @@
from .messages import (
CRITICAL, DEBUG, ERROR, INFO, WARNING, CheckMessage, Critical, Debug,
Error, Info, Warning,
)
from .registry import Tags, register, run_checks, tag_exists
# Import these to force registration of checks
import django.core.checks.async_checks # NOQA isort:skip
import django.core.checks.caches # NOQA isort:skip
import django.core.checks.compatibility.django_4_0 # NOQA isort:skip
import django.core.checks.database # NOQA isort:skip
import django.core.checks.files # NOQA isort:skip
import django.core.checks.model_checks # NOQA isort:skip
import django.core.checks.security.base # NOQA isort:skip
import django.core.checks.security.csrf # NOQA isort:skip
import django.core.checks.security.sessions # NOQA isort:skip
import django.core.checks.templates # NOQA isort:skip
import django.core.checks.translation # NOQA isort:skip
import django.core.checks.urls # NOQA isort:skip
__all__ = [
'CheckMessage',
'Debug', 'Info', 'Warning', 'Error', 'Critical',
'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL',
'register', 'run_checks', 'tag_exists', 'Tags',
]

View File

@@ -0,0 +1,16 @@
import os
from . import Error, Tags, register
E001 = Error(
'You should not set the DJANGO_ALLOW_ASYNC_UNSAFE environment variable in '
'deployment. This disables async safety protection.',
id='async.E001',
)
@register(Tags.async_support, deploy=True)
def check_async_unsafe(app_configs, **kwargs):
if os.environ.get('DJANGO_ALLOW_ASYNC_UNSAFE'):
return [E001]
return []

View File

@@ -0,0 +1,72 @@
import pathlib
from django.conf import settings
from django.core.cache import DEFAULT_CACHE_ALIAS, caches
from django.core.cache.backends.filebased import FileBasedCache
from . import Error, Tags, Warning, register
E001 = Error(
"You must define a '%s' cache in your CACHES setting." % DEFAULT_CACHE_ALIAS,
id='caches.E001',
)
@register(Tags.caches)
def check_default_cache_is_configured(app_configs, **kwargs):
if DEFAULT_CACHE_ALIAS not in settings.CACHES:
return [E001]
return []
@register(Tags.caches, deploy=True)
def check_cache_location_not_exposed(app_configs, **kwargs):
errors = []
for name in ('MEDIA_ROOT', 'STATIC_ROOT', 'STATICFILES_DIRS'):
setting = getattr(settings, name, None)
if not setting:
continue
if name == 'STATICFILES_DIRS':
paths = set()
for staticfiles_dir in setting:
if isinstance(staticfiles_dir, (list, tuple)):
_, staticfiles_dir = staticfiles_dir
paths.add(pathlib.Path(staticfiles_dir).resolve())
else:
paths = {pathlib.Path(setting).resolve()}
for alias in settings.CACHES:
cache = caches[alias]
if not isinstance(cache, FileBasedCache):
continue
cache_path = pathlib.Path(cache._dir).resolve()
if any(path == cache_path for path in paths):
relation = 'matches'
elif any(path in cache_path.parents for path in paths):
relation = 'is inside'
elif any(cache_path in path.parents for path in paths):
relation = 'contains'
else:
continue
errors.append(Warning(
f"Your '{alias}' cache configuration might expose your cache "
f"or lead to corruption of your data because its LOCATION "
f"{relation} {name}.",
id='caches.W002',
))
return errors
@register(Tags.caches)
def check_file_based_cache_is_absolute(app_configs, **kwargs):
errors = []
for alias, config in settings.CACHES.items():
cache = caches[alias]
if not isinstance(cache, FileBasedCache):
continue
if not pathlib.Path(config['LOCATION']).is_absolute():
errors.append(Warning(
f"Your '{alias}' cache LOCATION path is relative. Use an "
f"absolute path instead.",
id='caches.W003',
))
return errors

View File

@@ -0,0 +1,18 @@
from django.conf import settings
from .. import Error, Tags, register
@register(Tags.compatibility)
def check_csrf_trusted_origins(app_configs, **kwargs):
errors = []
for origin in settings.CSRF_TRUSTED_ORIGINS:
if '://' not in origin:
errors.append(Error(
'As of Django 4.0, the values in the CSRF_TRUSTED_ORIGINS '
'setting must start with a scheme (usually http:// or '
'https://) but found %s. See the release notes for details.'
% origin,
id='4_0.E001',
))
return errors

View File

@@ -0,0 +1,14 @@
from django.db import connections
from . import Tags, register
@register(Tags.database)
def check_database_backends(databases=None, **kwargs):
if databases is None:
return []
issues = []
for alias in databases:
conn = connections[alias]
issues.extend(conn.validation.check(**kwargs))
return issues

View File

@@ -0,0 +1,19 @@
from pathlib import Path
from django.conf import settings
from . import Error, Tags, register
@register(Tags.files)
def check_setting_file_upload_temp_dir(app_configs, **kwargs):
setting = getattr(settings, 'FILE_UPLOAD_TEMP_DIR', None)
if setting and not Path(setting).is_dir():
return [
Error(
f"The FILE_UPLOAD_TEMP_DIR setting refers to the nonexistent "
f"directory '{setting}'.",
id="files.E001",
),
]
return []

View File

@@ -0,0 +1,76 @@
# Levels
DEBUG = 10
INFO = 20
WARNING = 30
ERROR = 40
CRITICAL = 50
class CheckMessage:
def __init__(self, level, msg, hint=None, obj=None, id=None):
if not isinstance(level, int):
raise TypeError('The first argument should be level.')
self.level = level
self.msg = msg
self.hint = hint
self.obj = obj
self.id = id
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
all(getattr(self, attr) == getattr(other, attr)
for attr in ['level', 'msg', 'hint', 'obj', 'id'])
)
def __str__(self):
from django.db import models
if self.obj is None:
obj = "?"
elif isinstance(self.obj, models.base.ModelBase):
# We need to hardcode ModelBase and Field cases because its __str__
# method doesn't return "applabel.modellabel" and cannot be changed.
obj = self.obj._meta.label
else:
obj = str(self.obj)
id = "(%s) " % self.id if self.id else ""
hint = "\n\tHINT: %s" % self.hint if self.hint else ''
return "%s: %s%s%s" % (obj, id, self.msg, hint)
def __repr__(self):
return "<%s: level=%r, msg=%r, hint=%r, obj=%r, id=%r>" % \
(self.__class__.__name__, self.level, self.msg, self.hint, self.obj, self.id)
def is_serious(self, level=ERROR):
return self.level >= level
def is_silenced(self):
from django.conf import settings
return self.id in settings.SILENCED_SYSTEM_CHECKS
class Debug(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(DEBUG, *args, **kwargs)
class Info(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(INFO, *args, **kwargs)
class Warning(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(WARNING, *args, **kwargs)
class Error(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(ERROR, *args, **kwargs)
class Critical(CheckMessage):
def __init__(self, *args, **kwargs):
super().__init__(CRITICAL, *args, **kwargs)

View File

@@ -0,0 +1,210 @@
import inspect
import types
from collections import defaultdict
from itertools import chain
from django.apps import apps
from django.conf import settings
from django.core.checks import Error, Tags, Warning, register
@register(Tags.models)
def check_all_models(app_configs=None, **kwargs):
db_table_models = defaultdict(list)
indexes = defaultdict(list)
constraints = defaultdict(list)
errors = []
if app_configs is None:
models = apps.get_models()
else:
models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
for model in models:
if model._meta.managed and not model._meta.proxy:
db_table_models[model._meta.db_table].append(model._meta.label)
if not inspect.ismethod(model.check):
errors.append(
Error(
"The '%s.check()' class method is currently overridden by %r."
% (model.__name__, model.check),
obj=model,
id='models.E020'
)
)
else:
errors.extend(model.check(**kwargs))
for model_index in model._meta.indexes:
indexes[model_index.name].append(model._meta.label)
for model_constraint in model._meta.constraints:
constraints[model_constraint.name].append(model._meta.label)
if settings.DATABASE_ROUTERS:
error_class, error_id = Warning, 'models.W035'
error_hint = (
'You have configured settings.DATABASE_ROUTERS. Verify that %s '
'are correctly routed to separate databases.'
)
else:
error_class, error_id = Error, 'models.E028'
error_hint = None
for db_table, model_labels in db_table_models.items():
if len(model_labels) != 1:
model_labels_str = ', '.join(model_labels)
errors.append(
error_class(
"db_table '%s' is used by multiple models: %s."
% (db_table, model_labels_str),
obj=db_table,
hint=(error_hint % model_labels_str) if error_hint else None,
id=error_id,
)
)
for index_name, model_labels in indexes.items():
if len(model_labels) > 1:
model_labels = set(model_labels)
errors.append(
Error(
"index name '%s' is not unique %s %s." % (
index_name,
'for model' if len(model_labels) == 1 else 'among models:',
', '.join(sorted(model_labels)),
),
id='models.E029' if len(model_labels) == 1 else 'models.E030',
),
)
for constraint_name, model_labels in constraints.items():
if len(model_labels) > 1:
model_labels = set(model_labels)
errors.append(
Error(
"constraint name '%s' is not unique %s %s." % (
constraint_name,
'for model' if len(model_labels) == 1 else 'among models:',
', '.join(sorted(model_labels)),
),
id='models.E031' if len(model_labels) == 1 else 'models.E032',
),
)
return errors
def _check_lazy_references(apps, ignore=None):
"""
Ensure all lazy (i.e. string) model references have been resolved.
Lazy references are used in various places throughout Django, primarily in
related fields and model signals. Identify those common cases and provide
more helpful error messages for them.
The ignore parameter is used by StateApps to exclude swappable models from
this check.
"""
pending_models = set(apps._pending_operations) - (ignore or set())
# Short circuit if there aren't any errors.
if not pending_models:
return []
from django.db.models import signals
model_signals = {
signal: name for name, signal in vars(signals).items()
if isinstance(signal, signals.ModelSignal)
}
def extract_operation(obj):
"""
Take a callable found in Apps._pending_operations and identify the
original callable passed to Apps.lazy_model_operation(). If that
callable was a partial, return the inner, non-partial function and
any arguments and keyword arguments that were supplied with it.
obj is a callback defined locally in Apps.lazy_model_operation() and
annotated there with a `func` attribute so as to imitate a partial.
"""
operation, args, keywords = obj, [], {}
while hasattr(operation, 'func'):
args.extend(getattr(operation, 'args', []))
keywords.update(getattr(operation, 'keywords', {}))
operation = operation.func
return operation, args, keywords
def app_model_error(model_key):
try:
apps.get_app_config(model_key[0])
model_error = "app '%s' doesn't provide model '%s'" % model_key
except LookupError:
model_error = "app '%s' isn't installed" % model_key[0]
return model_error
# Here are several functions which return CheckMessage instances for the
# most common usages of lazy operations throughout Django. These functions
# take the model that was being waited on as an (app_label, modelname)
# pair, the original lazy function, and its positional and keyword args as
# determined by extract_operation().
def field_error(model_key, func, args, keywords):
error_msg = (
"The field %(field)s was declared with a lazy reference "
"to '%(model)s', but %(model_error)s."
)
params = {
'model': '.'.join(model_key),
'field': keywords['field'],
'model_error': app_model_error(model_key),
}
return Error(error_msg % params, obj=keywords['field'], id='fields.E307')
def signal_connect_error(model_key, func, args, keywords):
error_msg = (
"%(receiver)s was connected to the '%(signal)s' signal with a "
"lazy reference to the sender '%(model)s', but %(model_error)s."
)
receiver = args[0]
# The receiver is either a function or an instance of class
# defining a `__call__` method.
if isinstance(receiver, types.FunctionType):
description = "The function '%s'" % receiver.__name__
elif isinstance(receiver, types.MethodType):
description = "Bound method '%s.%s'" % (receiver.__self__.__class__.__name__, receiver.__name__)
else:
description = "An instance of class '%s'" % receiver.__class__.__name__
signal_name = model_signals.get(func.__self__, 'unknown')
params = {
'model': '.'.join(model_key),
'receiver': description,
'signal': signal_name,
'model_error': app_model_error(model_key),
}
return Error(error_msg % params, obj=receiver.__module__, id='signals.E001')
def default_error(model_key, func, args, keywords):
error_msg = "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
params = {
'op': func,
'model': '.'.join(model_key),
'model_error': app_model_error(model_key),
}
return Error(error_msg % params, obj=func, id='models.E022')
# Maps common uses of lazy operations to corresponding error functions
# defined above. If a key maps to None, no error will be produced.
# default_error() will be used for usages that don't appear in this dict.
known_lazy = {
('django.db.models.fields.related', 'resolve_related_class'): field_error,
('django.db.models.fields.related', 'set_managed'): None,
('django.dispatch.dispatcher', 'connect'): signal_connect_error,
}
def build_error(model_key, func, args, keywords):
key = (func.__module__, func.__name__)
error_fn = known_lazy.get(key, default_error)
return error_fn(model_key, func, args, keywords) if error_fn else None
return sorted(filter(None, (
build_error(model_key, *extract_operation(func))
for model_key in pending_models
for func in apps._pending_operations[model_key]
)), key=lambda error: error.msg)
@register(Tags.models)
def check_lazy_references(app_configs=None, **kwargs):
return _check_lazy_references(apps)

View File

@@ -0,0 +1,105 @@
from itertools import chain
from django.utils.inspect import func_accepts_kwargs
from django.utils.itercompat import is_iterable
class Tags:
"""
Built-in tags for internal checks.
"""
admin = 'admin'
async_support = 'async_support'
caches = 'caches'
compatibility = 'compatibility'
database = 'database'
files = 'files'
models = 'models'
security = 'security'
signals = 'signals'
sites = 'sites'
staticfiles = 'staticfiles'
templates = 'templates'
translation = 'translation'
urls = 'urls'
class CheckRegistry:
def __init__(self):
self.registered_checks = set()
self.deployment_checks = set()
def register(self, check=None, *tags, **kwargs):
"""
Can be used as a function or a decorator. Register given function
`f` labeled with given `tags`. The function should receive **kwargs
and return list of Errors and Warnings.
Example::
registry = CheckRegistry()
@registry.register('mytag', 'anothertag')
def my_check(app_configs, **kwargs):
# ... perform checks and collect `errors` ...
return errors
# or
registry.register(my_check, 'mytag', 'anothertag')
"""
def inner(check):
if not func_accepts_kwargs(check):
raise TypeError(
'Check functions must accept keyword arguments (**kwargs).'
)
check.tags = tags
checks = self.deployment_checks if kwargs.get('deploy') else self.registered_checks
checks.add(check)
return check
if callable(check):
return inner(check)
else:
if check:
tags += (check,)
return inner
def run_checks(self, app_configs=None, tags=None, include_deployment_checks=False, databases=None):
"""
Run all registered checks and return list of Errors and Warnings.
"""
errors = []
checks = self.get_checks(include_deployment_checks)
if tags is not None:
checks = [check for check in checks if not set(check.tags).isdisjoint(tags)]
for check in checks:
new_errors = check(app_configs=app_configs, databases=databases)
if not is_iterable(new_errors):
raise TypeError(
'The function %r did not return a list. All functions '
'registered with the checks registry must return a list.'
% check,
)
errors.extend(new_errors)
return errors
def tag_exists(self, tag, include_deployment_checks=False):
return tag in self.tags_available(include_deployment_checks)
def tags_available(self, deployment_checks=False):
return set(chain.from_iterable(
check.tags for check in self.get_checks(deployment_checks)
))
def get_checks(self, include_deployment_checks=False):
checks = list(self.registered_checks)
if include_deployment_checks:
checks.extend(self.deployment_checks)
return checks
registry = CheckRegistry()
register = registry.register
run_checks = registry.run_checks
tag_exists = registry.tag_exists

View File

@@ -0,0 +1,257 @@
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from .. import Error, Tags, Warning, register
CROSS_ORIGIN_OPENER_POLICY_VALUES = {
'same-origin', 'same-origin-allow-popups', 'unsafe-none',
}
REFERRER_POLICY_VALUES = {
'no-referrer', 'no-referrer-when-downgrade', 'origin',
'origin-when-cross-origin', 'same-origin', 'strict-origin',
'strict-origin-when-cross-origin', 'unsafe-url',
}
SECRET_KEY_INSECURE_PREFIX = 'django-insecure-'
SECRET_KEY_MIN_LENGTH = 50
SECRET_KEY_MIN_UNIQUE_CHARACTERS = 5
W001 = Warning(
"You do not have 'django.middleware.security.SecurityMiddleware' "
"in your MIDDLEWARE so the SECURE_HSTS_SECONDS, "
"SECURE_CONTENT_TYPE_NOSNIFF, SECURE_REFERRER_POLICY, "
"SECURE_CROSS_ORIGIN_OPENER_POLICY, and SECURE_SSL_REDIRECT settings will "
"have no effect.",
id='security.W001',
)
W002 = Warning(
"You do not have "
"'django.middleware.clickjacking.XFrameOptionsMiddleware' in your "
"MIDDLEWARE, so your pages will not be served with an "
"'x-frame-options' header. Unless there is a good reason for your "
"site to be served in a frame, you should consider enabling this "
"header to help prevent clickjacking attacks.",
id='security.W002',
)
W004 = Warning(
"You have not set a value for the SECURE_HSTS_SECONDS setting. "
"If your entire site is served only over SSL, you may want to consider "
"setting a value and enabling HTTP Strict Transport Security. "
"Be sure to read the documentation first; enabling HSTS carelessly "
"can cause serious, irreversible problems.",
id='security.W004',
)
W005 = Warning(
"You have not set the SECURE_HSTS_INCLUDE_SUBDOMAINS setting to True. "
"Without this, your site is potentially vulnerable to attack "
"via an insecure connection to a subdomain. Only set this to True if "
"you are certain that all subdomains of your domain should be served "
"exclusively via SSL.",
id='security.W005',
)
W006 = Warning(
"Your SECURE_CONTENT_TYPE_NOSNIFF setting is not set to True, "
"so your pages will not be served with an "
"'X-Content-Type-Options: nosniff' header. "
"You should consider enabling this header to prevent the "
"browser from identifying content types incorrectly.",
id='security.W006',
)
W008 = Warning(
"Your SECURE_SSL_REDIRECT setting is not set to True. "
"Unless your site should be available over both SSL and non-SSL "
"connections, you may want to either set this setting True "
"or configure a load balancer or reverse-proxy server "
"to redirect all connections to HTTPS.",
id='security.W008',
)
W009 = Warning(
"Your SECRET_KEY has less than %(min_length)s characters, less than "
"%(min_unique_chars)s unique characters, or it's prefixed with "
"'%(insecure_prefix)s' indicating that it was generated automatically by "
"Django. Please generate a long and random SECRET_KEY, otherwise many of "
"Django's security-critical features will be vulnerable to attack." % {
'min_length': SECRET_KEY_MIN_LENGTH,
'min_unique_chars': SECRET_KEY_MIN_UNIQUE_CHARACTERS,
'insecure_prefix': SECRET_KEY_INSECURE_PREFIX,
},
id='security.W009',
)
W018 = Warning(
"You should not have DEBUG set to True in deployment.",
id='security.W018',
)
W019 = Warning(
"You have "
"'django.middleware.clickjacking.XFrameOptionsMiddleware' in your "
"MIDDLEWARE, but X_FRAME_OPTIONS is not set to 'DENY'. "
"Unless there is a good reason for your site to serve other parts of "
"itself in a frame, you should change it to 'DENY'.",
id='security.W019',
)
W020 = Warning(
"ALLOWED_HOSTS must not be empty in deployment.",
id='security.W020',
)
W021 = Warning(
"You have not set the SECURE_HSTS_PRELOAD setting to True. Without this, "
"your site cannot be submitted to the browser preload list.",
id='security.W021',
)
W022 = Warning(
'You have not set the SECURE_REFERRER_POLICY setting. Without this, your '
'site will not send a Referrer-Policy header. You should consider '
'enabling this header to protect user privacy.',
id='security.W022',
)
E023 = Error(
'You have set the SECURE_REFERRER_POLICY setting to an invalid value.',
hint='Valid values are: {}.'.format(', '.join(sorted(REFERRER_POLICY_VALUES))),
id='security.E023',
)
E024 = Error(
'You have set the SECURE_CROSS_ORIGIN_OPENER_POLICY setting to an invalid '
'value.',
hint='Valid values are: {}.'.format(
', '.join(sorted(CROSS_ORIGIN_OPENER_POLICY_VALUES)),
),
id='security.E024',
)
def _security_middleware():
return 'django.middleware.security.SecurityMiddleware' in settings.MIDDLEWARE
def _xframe_middleware():
return 'django.middleware.clickjacking.XFrameOptionsMiddleware' in settings.MIDDLEWARE
@register(Tags.security, deploy=True)
def check_security_middleware(app_configs, **kwargs):
passed_check = _security_middleware()
return [] if passed_check else [W001]
@register(Tags.security, deploy=True)
def check_xframe_options_middleware(app_configs, **kwargs):
passed_check = _xframe_middleware()
return [] if passed_check else [W002]
@register(Tags.security, deploy=True)
def check_sts(app_configs, **kwargs):
passed_check = not _security_middleware() or settings.SECURE_HSTS_SECONDS
return [] if passed_check else [W004]
@register(Tags.security, deploy=True)
def check_sts_include_subdomains(app_configs, **kwargs):
passed_check = (
not _security_middleware() or
not settings.SECURE_HSTS_SECONDS or
settings.SECURE_HSTS_INCLUDE_SUBDOMAINS is True
)
return [] if passed_check else [W005]
@register(Tags.security, deploy=True)
def check_sts_preload(app_configs, **kwargs):
passed_check = (
not _security_middleware() or
not settings.SECURE_HSTS_SECONDS or
settings.SECURE_HSTS_PRELOAD is True
)
return [] if passed_check else [W021]
@register(Tags.security, deploy=True)
def check_content_type_nosniff(app_configs, **kwargs):
passed_check = (
not _security_middleware() or
settings.SECURE_CONTENT_TYPE_NOSNIFF is True
)
return [] if passed_check else [W006]
@register(Tags.security, deploy=True)
def check_ssl_redirect(app_configs, **kwargs):
passed_check = (
not _security_middleware() or
settings.SECURE_SSL_REDIRECT is True
)
return [] if passed_check else [W008]
@register(Tags.security, deploy=True)
def check_secret_key(app_configs, **kwargs):
try:
secret_key = settings.SECRET_KEY
except (ImproperlyConfigured, AttributeError):
passed_check = False
else:
passed_check = (
len(set(secret_key)) >= SECRET_KEY_MIN_UNIQUE_CHARACTERS and
len(secret_key) >= SECRET_KEY_MIN_LENGTH and
not secret_key.startswith(SECRET_KEY_INSECURE_PREFIX)
)
return [] if passed_check else [W009]
@register(Tags.security, deploy=True)
def check_debug(app_configs, **kwargs):
passed_check = not settings.DEBUG
return [] if passed_check else [W018]
@register(Tags.security, deploy=True)
def check_xframe_deny(app_configs, **kwargs):
passed_check = (
not _xframe_middleware() or
settings.X_FRAME_OPTIONS == 'DENY'
)
return [] if passed_check else [W019]
@register(Tags.security, deploy=True)
def check_allowed_hosts(app_configs, **kwargs):
return [] if settings.ALLOWED_HOSTS else [W020]
@register(Tags.security, deploy=True)
def check_referrer_policy(app_configs, **kwargs):
if _security_middleware():
if settings.SECURE_REFERRER_POLICY is None:
return [W022]
# Support a comma-separated string or iterable of values to allow fallback.
if isinstance(settings.SECURE_REFERRER_POLICY, str):
values = {v.strip() for v in settings.SECURE_REFERRER_POLICY.split(',')}
else:
values = set(settings.SECURE_REFERRER_POLICY)
if not values <= REFERRER_POLICY_VALUES:
return [E023]
return []
@register(Tags.security, deploy=True)
def check_cross_origin_opener_policy(app_configs, **kwargs):
if (
_security_middleware() and
settings.SECURE_CROSS_ORIGIN_OPENER_POLICY is not None and
settings.SECURE_CROSS_ORIGIN_OPENER_POLICY not in CROSS_ORIGIN_OPENER_POLICY_VALUES
):
return [E024]
return []

View File

@@ -0,0 +1,67 @@
import inspect
from django.conf import settings
from .. import Error, Tags, Warning, register
W003 = Warning(
"You don't appear to be using Django's built-in "
"cross-site request forgery protection via the middleware "
"('django.middleware.csrf.CsrfViewMiddleware' is not in your "
"MIDDLEWARE). Enabling the middleware is the safest approach "
"to ensure you don't leave any holes.",
id='security.W003',
)
W016 = Warning(
"You have 'django.middleware.csrf.CsrfViewMiddleware' in your "
"MIDDLEWARE, but you have not set CSRF_COOKIE_SECURE to True. "
"Using a secure-only CSRF cookie makes it more difficult for network "
"traffic sniffers to steal the CSRF token.",
id='security.W016',
)
def _csrf_middleware():
return 'django.middleware.csrf.CsrfViewMiddleware' in settings.MIDDLEWARE
@register(Tags.security, deploy=True)
def check_csrf_middleware(app_configs, **kwargs):
passed_check = _csrf_middleware()
return [] if passed_check else [W003]
@register(Tags.security, deploy=True)
def check_csrf_cookie_secure(app_configs, **kwargs):
passed_check = (
settings.CSRF_USE_SESSIONS or
not _csrf_middleware() or
settings.CSRF_COOKIE_SECURE
)
return [] if passed_check else [W016]
@register(Tags.security)
def check_csrf_failure_view(app_configs, **kwargs):
from django.middleware.csrf import _get_failure_view
errors = []
try:
view = _get_failure_view()
except ImportError:
msg = (
"The CSRF failure view '%s' could not be imported." %
settings.CSRF_FAILURE_VIEW
)
errors.append(Error(msg, id='security.E102'))
else:
try:
inspect.signature(view).bind(None, reason=None)
except TypeError:
msg = (
"The CSRF failure view '%s' does not take the correct number of arguments." %
settings.CSRF_FAILURE_VIEW
)
errors.append(Error(msg, id='security.E101'))
return errors

View File

@@ -0,0 +1,97 @@
from django.conf import settings
from .. import Tags, Warning, register
def add_session_cookie_message(message):
return message + (
" Using a secure-only session cookie makes it more difficult for "
"network traffic sniffers to hijack user sessions."
)
W010 = Warning(
add_session_cookie_message(
"You have 'django.contrib.sessions' in your INSTALLED_APPS, "
"but you have not set SESSION_COOKIE_SECURE to True."
),
id='security.W010',
)
W011 = Warning(
add_session_cookie_message(
"You have 'django.contrib.sessions.middleware.SessionMiddleware' "
"in your MIDDLEWARE, but you have not set "
"SESSION_COOKIE_SECURE to True."
),
id='security.W011',
)
W012 = Warning(
add_session_cookie_message("SESSION_COOKIE_SECURE is not set to True."),
id='security.W012',
)
def add_httponly_message(message):
return message + (
" Using an HttpOnly session cookie makes it more difficult for "
"cross-site scripting attacks to hijack user sessions."
)
W013 = Warning(
add_httponly_message(
"You have 'django.contrib.sessions' in your INSTALLED_APPS, "
"but you have not set SESSION_COOKIE_HTTPONLY to True.",
),
id='security.W013',
)
W014 = Warning(
add_httponly_message(
"You have 'django.contrib.sessions.middleware.SessionMiddleware' "
"in your MIDDLEWARE, but you have not set "
"SESSION_COOKIE_HTTPONLY to True."
),
id='security.W014',
)
W015 = Warning(
add_httponly_message("SESSION_COOKIE_HTTPONLY is not set to True."),
id='security.W015',
)
@register(Tags.security, deploy=True)
def check_session_cookie_secure(app_configs, **kwargs):
errors = []
if not settings.SESSION_COOKIE_SECURE:
if _session_app():
errors.append(W010)
if _session_middleware():
errors.append(W011)
if len(errors) > 1:
errors = [W012]
return errors
@register(Tags.security, deploy=True)
def check_session_cookie_httponly(app_configs, **kwargs):
errors = []
if not settings.SESSION_COOKIE_HTTPONLY:
if _session_app():
errors.append(W013)
if _session_middleware():
errors.append(W014)
if len(errors) > 1:
errors = [W015]
return errors
def _session_middleware():
return 'django.contrib.sessions.middleware.SessionMiddleware' in settings.MIDDLEWARE
def _session_app():
return "django.contrib.sessions" in settings.INSTALLED_APPS

View File

@@ -0,0 +1,35 @@
import copy
from django.conf import settings
from . import Error, Tags, register
E001 = Error(
"You have 'APP_DIRS': True in your TEMPLATES but also specify 'loaders' "
"in OPTIONS. Either remove APP_DIRS or remove the 'loaders' option.",
id='templates.E001',
)
E002 = Error(
"'string_if_invalid' in TEMPLATES OPTIONS must be a string but got: {} ({}).",
id="templates.E002",
)
@register(Tags.templates)
def check_setting_app_dirs_loaders(app_configs, **kwargs):
return [E001] if any(
conf.get('APP_DIRS') and 'loaders' in conf.get('OPTIONS', {})
for conf in settings.TEMPLATES
) else []
@register(Tags.templates)
def check_string_if_invalid_is_string(app_configs, **kwargs):
errors = []
for conf in settings.TEMPLATES:
string_if_invalid = conf.get('OPTIONS', {}).get('string_if_invalid', '')
if not isinstance(string_if_invalid, str):
error = copy.copy(E002)
error.msg = error.msg.format(string_if_invalid, type(string_if_invalid).__name__)
errors.append(error)
return errors

View File

@@ -0,0 +1,64 @@
from django.conf import settings
from django.utils.translation import get_supported_language_variant
from django.utils.translation.trans_real import language_code_re
from . import Error, Tags, register
E001 = Error(
'You have provided an invalid value for the LANGUAGE_CODE setting: {!r}.',
id='translation.E001',
)
E002 = Error(
'You have provided an invalid language code in the LANGUAGES setting: {!r}.',
id='translation.E002',
)
E003 = Error(
'You have provided an invalid language code in the LANGUAGES_BIDI setting: {!r}.',
id='translation.E003',
)
E004 = Error(
'You have provided a value for the LANGUAGE_CODE setting that is not in '
'the LANGUAGES setting.',
id='translation.E004',
)
@register(Tags.translation)
def check_setting_language_code(app_configs, **kwargs):
"""Error if LANGUAGE_CODE setting is invalid."""
tag = settings.LANGUAGE_CODE
if not isinstance(tag, str) or not language_code_re.match(tag):
return [Error(E001.msg.format(tag), id=E001.id)]
return []
@register(Tags.translation)
def check_setting_languages(app_configs, **kwargs):
"""Error if LANGUAGES setting is invalid."""
return [
Error(E002.msg.format(tag), id=E002.id)
for tag, _ in settings.LANGUAGES if not isinstance(tag, str) or not language_code_re.match(tag)
]
@register(Tags.translation)
def check_setting_languages_bidi(app_configs, **kwargs):
"""Error if LANGUAGES_BIDI setting is invalid."""
return [
Error(E003.msg.format(tag), id=E003.id)
for tag in settings.LANGUAGES_BIDI if not isinstance(tag, str) or not language_code_re.match(tag)
]
@register(Tags.translation)
def check_language_settings_consistent(app_configs, **kwargs):
"""Error if language settings are not consistent with each other."""
try:
get_supported_language_variant(settings.LANGUAGE_CODE)
except LookupError:
return [E004]
else:
return []

View File

@@ -0,0 +1,110 @@
from collections import Counter
from django.conf import settings
from . import Error, Tags, Warning, register
@register(Tags.urls)
def check_url_config(app_configs, **kwargs):
if getattr(settings, 'ROOT_URLCONF', None):
from django.urls import get_resolver
resolver = get_resolver()
return check_resolver(resolver)
return []
def check_resolver(resolver):
"""
Recursively check the resolver.
"""
check_method = getattr(resolver, 'check', None)
if check_method is not None:
return check_method()
elif not hasattr(resolver, 'resolve'):
return get_warning_for_invalid_pattern(resolver)
else:
return []
@register(Tags.urls)
def check_url_namespaces_unique(app_configs, **kwargs):
"""
Warn if URL namespaces used in applications aren't unique.
"""
if not getattr(settings, 'ROOT_URLCONF', None):
return []
from django.urls import get_resolver
resolver = get_resolver()
all_namespaces = _load_all_namespaces(resolver)
counter = Counter(all_namespaces)
non_unique_namespaces = [n for n, count in counter.items() if count > 1]
errors = []
for namespace in non_unique_namespaces:
errors.append(Warning(
"URL namespace '{}' isn't unique. You may not be able to reverse "
"all URLs in this namespace".format(namespace),
id="urls.W005",
))
return errors
def _load_all_namespaces(resolver, parents=()):
"""
Recursively load all namespaces from URL patterns.
"""
url_patterns = getattr(resolver, 'url_patterns', [])
namespaces = [
':'.join(parents + (url.namespace,)) for url in url_patterns
if getattr(url, 'namespace', None) is not None
]
for pattern in url_patterns:
namespace = getattr(pattern, 'namespace', None)
current = parents
if namespace is not None:
current += (namespace,)
namespaces.extend(_load_all_namespaces(pattern, current))
return namespaces
def get_warning_for_invalid_pattern(pattern):
"""
Return a list containing a warning that the pattern is invalid.
describe_pattern() cannot be used here, because we cannot rely on the
urlpattern having regex or name attributes.
"""
if isinstance(pattern, str):
hint = (
"Try removing the string '{}'. The list of urlpatterns should not "
"have a prefix string as the first element.".format(pattern)
)
elif isinstance(pattern, tuple):
hint = "Try using path() instead of a tuple."
else:
hint = None
return [Error(
"Your URL pattern {!r} is invalid. Ensure that urlpatterns is a list "
"of path() and/or re_path() instances.".format(pattern),
hint=hint,
id="urls.E004",
)]
@register(Tags.urls)
def check_url_settings(app_configs, **kwargs):
errors = []
for name in ('STATIC_URL', 'MEDIA_URL'):
value = getattr(settings, name)
if value and not value.endswith('/'):
errors.append(E006(name))
return errors
def E006(name):
return Error(
'The {} setting must end with a slash.'.format(name),
id='urls.E006',
)

View File

@@ -0,0 +1,217 @@
"""
Global Django exception and warning classes.
"""
import operator
from django.utils.hashable import make_hashable
class FieldDoesNotExist(Exception):
"""The requested model field does not exist"""
pass
class AppRegistryNotReady(Exception):
"""The django.apps registry is not populated yet"""
pass
class ObjectDoesNotExist(Exception):
"""The requested object does not exist"""
silent_variable_failure = True
class MultipleObjectsReturned(Exception):
"""The query returned multiple objects when only one was expected."""
pass
class SuspiciousOperation(Exception):
"""The user did something suspicious"""
class SuspiciousMultipartForm(SuspiciousOperation):
"""Suspect MIME request in multipart form data"""
pass
class SuspiciousFileOperation(SuspiciousOperation):
"""A Suspicious filesystem operation was attempted"""
pass
class DisallowedHost(SuspiciousOperation):
"""HTTP_HOST header contains invalid value"""
pass
class DisallowedRedirect(SuspiciousOperation):
"""Redirect to scheme not in allowed list"""
pass
class TooManyFieldsSent(SuspiciousOperation):
"""
The number of fields in a GET or POST request exceeded
settings.DATA_UPLOAD_MAX_NUMBER_FIELDS.
"""
pass
class RequestDataTooBig(SuspiciousOperation):
"""
The size of the request (excluding any file uploads) exceeded
settings.DATA_UPLOAD_MAX_MEMORY_SIZE.
"""
pass
class RequestAborted(Exception):
"""The request was closed before it was completed, or timed out."""
pass
class BadRequest(Exception):
"""The request is malformed and cannot be processed."""
pass
class PermissionDenied(Exception):
"""The user did not have permission to do that"""
pass
class ViewDoesNotExist(Exception):
"""The requested view does not exist"""
pass
class MiddlewareNotUsed(Exception):
"""This middleware is not used in this server configuration"""
pass
class ImproperlyConfigured(Exception):
"""Django is somehow improperly configured"""
pass
class FieldError(Exception):
"""Some kind of problem with a model field."""
pass
NON_FIELD_ERRORS = '__all__'
class ValidationError(Exception):
"""An error while validating data."""
def __init__(self, message, code=None, params=None):
"""
The `message` argument can be a single error, a list of errors, or a
dictionary that maps field names to lists of errors. What we define as
an "error" can be either a simple string or an instance of
ValidationError with its message attribute set, and what we define as
list or dictionary can be an actual `list` or `dict` or an instance
of ValidationError with its `error_list` or `error_dict` attribute set.
"""
super().__init__(message, code, params)
if isinstance(message, ValidationError):
if hasattr(message, 'error_dict'):
message = message.error_dict
elif not hasattr(message, 'message'):
message = message.error_list
else:
message, code, params = message.message, message.code, message.params
if isinstance(message, dict):
self.error_dict = {}
for field, messages in message.items():
if not isinstance(messages, ValidationError):
messages = ValidationError(messages)
self.error_dict[field] = messages.error_list
elif isinstance(message, list):
self.error_list = []
for message in message:
# Normalize plain strings to instances of ValidationError.
if not isinstance(message, ValidationError):
message = ValidationError(message)
if hasattr(message, 'error_dict'):
self.error_list.extend(sum(message.error_dict.values(), []))
else:
self.error_list.extend(message.error_list)
else:
self.message = message
self.code = code
self.params = params
self.error_list = [self]
@property
def message_dict(self):
# Trigger an AttributeError if this ValidationError
# doesn't have an error_dict.
getattr(self, 'error_dict')
return dict(self)
@property
def messages(self):
if hasattr(self, 'error_dict'):
return sum(dict(self).values(), [])
return list(self)
def update_error_dict(self, error_dict):
if hasattr(self, 'error_dict'):
for field, error_list in self.error_dict.items():
error_dict.setdefault(field, []).extend(error_list)
else:
error_dict.setdefault(NON_FIELD_ERRORS, []).extend(self.error_list)
return error_dict
def __iter__(self):
if hasattr(self, 'error_dict'):
for field, errors in self.error_dict.items():
yield field, list(ValidationError(errors))
else:
for error in self.error_list:
message = error.message
if error.params:
message %= error.params
yield str(message)
def __str__(self):
if hasattr(self, 'error_dict'):
return repr(dict(self))
return repr(list(self))
def __repr__(self):
return 'ValidationError(%s)' % self
def __eq__(self, other):
if not isinstance(other, ValidationError):
return NotImplemented
return hash(self) == hash(other)
def __hash__(self):
if hasattr(self, 'message'):
return hash((
self.message,
self.code,
make_hashable(self.params),
))
if hasattr(self, 'error_dict'):
return hash(make_hashable(self.error_dict))
return hash(tuple(sorted(self.error_list, key=operator.attrgetter('message'))))
class EmptyResultSet(Exception):
"""A database query predicate is impossible."""
pass
class SynchronousOnlyOperation(Exception):
"""The user tried to call a sync-only function from an async context."""
pass

View File

@@ -0,0 +1,3 @@
from django.core.files.base import File
__all__ = ['File']

View File

@@ -0,0 +1,160 @@
import os
from io import BytesIO, StringIO, UnsupportedOperation
from django.core.files.utils import FileProxyMixin
from django.utils.functional import cached_property
class File(FileProxyMixin):
DEFAULT_CHUNK_SIZE = 64 * 2 ** 10
def __init__(self, file, name=None):
self.file = file
if name is None:
name = getattr(file, 'name', None)
self.name = name
if hasattr(file, 'mode'):
self.mode = file.mode
def __str__(self):
return self.name or ''
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self or "None")
def __bool__(self):
return bool(self.name)
def __len__(self):
return self.size
@cached_property
def size(self):
if hasattr(self.file, 'size'):
return self.file.size
if hasattr(self.file, 'name'):
try:
return os.path.getsize(self.file.name)
except (OSError, TypeError):
pass
if hasattr(self.file, 'tell') and hasattr(self.file, 'seek'):
pos = self.file.tell()
self.file.seek(0, os.SEEK_END)
size = self.file.tell()
self.file.seek(pos)
return size
raise AttributeError("Unable to determine the file's size.")
def chunks(self, chunk_size=None):
"""
Read the file and yield chunks of ``chunk_size`` bytes (defaults to
``File.DEFAULT_CHUNK_SIZE``).
"""
chunk_size = chunk_size or self.DEFAULT_CHUNK_SIZE
try:
self.seek(0)
except (AttributeError, UnsupportedOperation):
pass
while True:
data = self.read(chunk_size)
if not data:
break
yield data
def multiple_chunks(self, chunk_size=None):
"""
Return ``True`` if you can expect multiple chunks.
NB: If a particular file representation is in memory, subclasses should
always return ``False`` -- there's no good reason to read from memory in
chunks.
"""
return self.size > (chunk_size or self.DEFAULT_CHUNK_SIZE)
def __iter__(self):
# Iterate over this file-like object by newlines
buffer_ = None
for chunk in self.chunks():
for line in chunk.splitlines(True):
if buffer_:
if endswith_cr(buffer_) and not equals_lf(line):
# Line split after a \r newline; yield buffer_.
yield buffer_
# Continue with line.
else:
# Line either split without a newline (line
# continues after buffer_) or with \r\n
# newline (line == b'\n').
line = buffer_ + line
# buffer_ handled, clear it.
buffer_ = None
# If this is the end of a \n or \r\n line, yield.
if endswith_lf(line):
yield line
else:
buffer_ = line
if buffer_ is not None:
yield buffer_
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
self.close()
def open(self, mode=None):
if not self.closed:
self.seek(0)
elif self.name and os.path.exists(self.name):
self.file = open(self.name, mode or self.mode)
else:
raise ValueError("The file cannot be reopened.")
return self
def close(self):
self.file.close()
class ContentFile(File):
"""
A File-like object that takes just raw content, rather than an actual file.
"""
def __init__(self, content, name=None):
stream_class = StringIO if isinstance(content, str) else BytesIO
super().__init__(stream_class(content), name=name)
self.size = len(content)
def __str__(self):
return 'Raw content'
def __bool__(self):
return True
def open(self, mode=None):
self.seek(0)
return self
def close(self):
pass
def write(self, data):
self.__dict__.pop('size', None) # Clear the computed size.
return self.file.write(data)
def endswith_cr(line):
"""Return True if line (a text or bytestring) ends with '\r'."""
return line.endswith('\r' if isinstance(line, str) else b'\r')
def endswith_lf(line):
"""Return True if line (a text or bytestring) ends with '\n'."""
return line.endswith('\n' if isinstance(line, str) else b'\n')
def equals_lf(line):
"""Return True if line (a text or bytestring) equals '\n'."""
return line == ('\n' if isinstance(line, str) else b'\n')

View File

@@ -0,0 +1,87 @@
"""
Utility functions for handling images.
Requires Pillow as you might imagine.
"""
import struct
import zlib
from django.core.files import File
class ImageFile(File):
"""
A mixin for use alongside django.core.files.base.File, which provides
additional features for dealing with images.
"""
@property
def width(self):
return self._get_image_dimensions()[0]
@property
def height(self):
return self._get_image_dimensions()[1]
def _get_image_dimensions(self):
if not hasattr(self, '_dimensions_cache'):
close = self.closed
self.open()
self._dimensions_cache = get_image_dimensions(self, close=close)
return self._dimensions_cache
def get_image_dimensions(file_or_path, close=False):
"""
Return the (width, height) of an image, given an open file or a path. Set
'close' to True to close the file at the end if it is initially in an open
state.
"""
from PIL import ImageFile as PillowImageFile
p = PillowImageFile.Parser()
if hasattr(file_or_path, 'read'):
file = file_or_path
file_pos = file.tell()
file.seek(0)
else:
try:
file = open(file_or_path, 'rb')
except OSError:
return (None, None)
close = True
try:
# Most of the time Pillow only needs a small chunk to parse the image
# and get the dimensions, but with some TIFF files Pillow needs to
# parse the whole file.
chunk_size = 1024
while 1:
data = file.read(chunk_size)
if not data:
break
try:
p.feed(data)
except zlib.error as e:
# ignore zlib complaining on truncated stream, just feed more
# data to parser (ticket #19457).
if e.args[0].startswith("Error -5"):
pass
else:
raise
except struct.error:
# Ignore PIL failing on a too short buffer when reads return
# less bytes than expected. Skip and feed more data to the
# parser (ticket #24544).
pass
except RuntimeError:
# e.g. "RuntimeError: could not create decoder object" for
# WebP files. A different chunk_size may work.
pass
if p.image:
return p.image.size
chunk_size *= 2
return (None, None)
finally:
if close:
file.close()
else:
file.seek(file_pos)

View File

@@ -0,0 +1,118 @@
"""
Portable file locking utilities.
Based partially on an example by Jonathan Feignberg in the Python
Cookbook [1] (licensed under the Python Software License) and a ctypes port by
Anatoly Techtonik for Roundup [2] (license [3]).
[1] http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/65203
[2] https://sourceforge.net/p/roundup/code/ci/default/tree/roundup/backends/portalocker.py
[3] https://sourceforge.net/p/roundup/code/ci/default/tree/COPYING.txt
Example Usage::
>>> from django.core.files import locks
>>> with open('./file', 'wb') as f:
... locks.lock(f, locks.LOCK_EX)
... f.write('Django')
"""
import os
__all__ = ('LOCK_EX', 'LOCK_SH', 'LOCK_NB', 'lock', 'unlock')
def _fd(f):
"""Get a filedescriptor from something which could be a file or an fd."""
return f.fileno() if hasattr(f, 'fileno') else f
if os.name == 'nt':
import msvcrt
from ctypes import (
POINTER, Structure, Union, byref, c_int64, c_ulong, c_void_p, sizeof,
windll,
)
from ctypes.wintypes import BOOL, DWORD, HANDLE
LOCK_SH = 0 # the default
LOCK_NB = 0x1 # LOCKFILE_FAIL_IMMEDIATELY
LOCK_EX = 0x2 # LOCKFILE_EXCLUSIVE_LOCK
# --- Adapted from the pyserial project ---
# detect size of ULONG_PTR
if sizeof(c_ulong) != sizeof(c_void_p):
ULONG_PTR = c_int64
else:
ULONG_PTR = c_ulong
PVOID = c_void_p
# --- Union inside Structure by stackoverflow:3480240 ---
class _OFFSET(Structure):
_fields_ = [
('Offset', DWORD),
('OffsetHigh', DWORD)]
class _OFFSET_UNION(Union):
_anonymous_ = ['_offset']
_fields_ = [
('_offset', _OFFSET),
('Pointer', PVOID)]
class OVERLAPPED(Structure):
_anonymous_ = ['_offset_union']
_fields_ = [
('Internal', ULONG_PTR),
('InternalHigh', ULONG_PTR),
('_offset_union', _OFFSET_UNION),
('hEvent', HANDLE)]
LPOVERLAPPED = POINTER(OVERLAPPED)
# --- Define function prototypes for extra safety ---
LockFileEx = windll.kernel32.LockFileEx
LockFileEx.restype = BOOL
LockFileEx.argtypes = [HANDLE, DWORD, DWORD, DWORD, DWORD, LPOVERLAPPED]
UnlockFileEx = windll.kernel32.UnlockFileEx
UnlockFileEx.restype = BOOL
UnlockFileEx.argtypes = [HANDLE, DWORD, DWORD, DWORD, LPOVERLAPPED]
def lock(f, flags):
hfile = msvcrt.get_osfhandle(_fd(f))
overlapped = OVERLAPPED()
ret = LockFileEx(hfile, flags, 0, 0, 0xFFFF0000, byref(overlapped))
return bool(ret)
def unlock(f):
hfile = msvcrt.get_osfhandle(_fd(f))
overlapped = OVERLAPPED()
ret = UnlockFileEx(hfile, 0, 0, 0xFFFF0000, byref(overlapped))
return bool(ret)
else:
try:
import fcntl
LOCK_SH = fcntl.LOCK_SH # shared lock
LOCK_NB = fcntl.LOCK_NB # non-blocking
LOCK_EX = fcntl.LOCK_EX
except (ImportError, AttributeError):
# File locking is not supported.
LOCK_EX = LOCK_SH = LOCK_NB = 0
# Dummy functions that don't do anything.
def lock(f, flags):
# File is not locked
return False
def unlock(f):
# File is unlocked
return True
else:
def lock(f, flags):
try:
fcntl.flock(_fd(f), flags)
return True
except BlockingIOError:
return False
def unlock(f):
fcntl.flock(_fd(f), fcntl.LOCK_UN)
return True

View File

@@ -0,0 +1,87 @@
"""
Move a file in the safest way possible::
>>> from django.core.files.move import file_move_safe
>>> file_move_safe("/tmp/old_file", "/tmp/new_file")
"""
import errno
import os
from shutil import copystat
from django.core.files import locks
__all__ = ['file_move_safe']
def _samefile(src, dst):
# Macintosh, Unix.
if hasattr(os.path, 'samefile'):
try:
return os.path.samefile(src, dst)
except OSError:
return False
# All other platforms: check for same pathname.
return (os.path.normcase(os.path.abspath(src)) ==
os.path.normcase(os.path.abspath(dst)))
def file_move_safe(old_file_name, new_file_name, chunk_size=1024 * 64, allow_overwrite=False):
"""
Move a file from one location to another in the safest way possible.
First, try ``os.rename``, which is simple but will break across filesystems.
If that fails, stream manually from one file to another in pure Python.
If the destination file exists and ``allow_overwrite`` is ``False``, raise
``FileExistsError``.
"""
# There's no reason to move if we don't have to.
if _samefile(old_file_name, new_file_name):
return
try:
if not allow_overwrite and os.access(new_file_name, os.F_OK):
raise FileExistsError('Destination file %s exists and allow_overwrite is False.' % new_file_name)
os.rename(old_file_name, new_file_name)
return
except OSError:
# OSError happens with os.rename() if moving to another filesystem or
# when moving opened files on certain operating systems.
pass
# first open the old file, so that it won't go away
with open(old_file_name, 'rb') as old_file:
# now open the new file, not forgetting allow_overwrite
fd = os.open(new_file_name, (os.O_WRONLY | os.O_CREAT | getattr(os, 'O_BINARY', 0) |
(os.O_EXCL if not allow_overwrite else 0)))
try:
locks.lock(fd, locks.LOCK_EX)
current_chunk = None
while current_chunk != b'':
current_chunk = old_file.read(chunk_size)
os.write(fd, current_chunk)
finally:
locks.unlock(fd)
os.close(fd)
try:
copystat(old_file_name, new_file_name)
except PermissionError as e:
# Certain filesystems (e.g. CIFS) fail to copy the file's metadata if
# the type of the destination filesystem isn't the same as the source
# filesystem; ignore that.
if e.errno != errno.EPERM:
raise
try:
os.remove(old_file_name)
except PermissionError as e:
# Certain operating systems (Cygwin and Windows)
# fail when deleting opened files, ignore it. (For the
# systems where this happens, temporary files will be auto-deleted
# on close anyway.)
if getattr(e, 'winerror', 0) != 32:
raise

View File

@@ -0,0 +1,373 @@
import os
import pathlib
from datetime import datetime
from urllib.parse import urljoin
from django.conf import settings
from django.core.exceptions import SuspiciousFileOperation
from django.core.files import File, locks
from django.core.files.move import file_move_safe
from django.core.files.utils import validate_file_name
from django.core.signals import setting_changed
from django.utils import timezone
from django.utils._os import safe_join
from django.utils.crypto import get_random_string
from django.utils.deconstruct import deconstructible
from django.utils.encoding import filepath_to_uri
from django.utils.functional import LazyObject, cached_property
from django.utils.module_loading import import_string
from django.utils.text import get_valid_filename
__all__ = (
'Storage', 'FileSystemStorage', 'DefaultStorage', 'default_storage',
'get_storage_class',
)
class Storage:
"""
A base storage class, providing some default behaviors that all other
storage systems can inherit or override, as necessary.
"""
# The following methods represent a public interface to private methods.
# These shouldn't be overridden by subclasses unless absolutely necessary.
def open(self, name, mode='rb'):
"""Retrieve the specified file from storage."""
return self._open(name, mode)
def save(self, name, content, max_length=None):
"""
Save new content to the file specified by name. The content should be
a proper File object or any Python file-like object, ready to be read
from the beginning.
"""
# Get the proper name for the file, as it will actually be saved.
if name is None:
name = content.name
if not hasattr(content, 'chunks'):
content = File(content, name)
name = self.get_available_name(name, max_length=max_length)
return self._save(name, content)
# These methods are part of the public API, with default implementations.
def get_valid_name(self, name):
"""
Return a filename, based on the provided filename, that's suitable for
use in the target storage system.
"""
return get_valid_filename(name)
def get_alternative_name(self, file_root, file_ext):
"""
Return an alternative filename, by adding an underscore and a random 7
character alphanumeric string (before the file extension, if one
exists) to the filename.
"""
return '%s_%s%s' % (file_root, get_random_string(7), file_ext)
def get_available_name(self, name, max_length=None):
"""
Return a filename that's free on the target storage system and
available for new content to be written to.
"""
dir_name, file_name = os.path.split(name)
if '..' in pathlib.PurePath(dir_name).parts:
raise SuspiciousFileOperation("Detected path traversal attempt in '%s'" % dir_name)
validate_file_name(file_name)
file_root, file_ext = os.path.splitext(file_name)
# If the filename already exists, generate an alternative filename
# until it doesn't exist.
# Truncate original name if required, so the new filename does not
# exceed the max_length.
while self.exists(name) or (max_length and len(name) > max_length):
# file_ext includes the dot.
name = os.path.join(dir_name, self.get_alternative_name(file_root, file_ext))
if max_length is None:
continue
# Truncate file_root if max_length exceeded.
truncation = len(name) - max_length
if truncation > 0:
file_root = file_root[:-truncation]
# Entire file_root was truncated in attempt to find an available filename.
if not file_root:
raise SuspiciousFileOperation(
'Storage can not find an available filename for "%s". '
'Please make sure that the corresponding file field '
'allows sufficient "max_length".' % name
)
name = os.path.join(dir_name, self.get_alternative_name(file_root, file_ext))
return name
def generate_filename(self, filename):
"""
Validate the filename by calling get_valid_name() and return a filename
to be passed to the save() method.
"""
# `filename` may include a path as returned by FileField.upload_to.
dirname, filename = os.path.split(filename)
if '..' in pathlib.PurePath(dirname).parts:
raise SuspiciousFileOperation("Detected path traversal attempt in '%s'" % dirname)
return os.path.normpath(os.path.join(dirname, self.get_valid_name(filename)))
def path(self, name):
"""
Return a local filesystem path where the file can be retrieved using
Python's built-in open() function. Storage systems that can't be
accessed using open() should *not* implement this method.
"""
raise NotImplementedError("This backend doesn't support absolute paths.")
# The following methods form the public API for storage systems, but with
# no default implementations. Subclasses must implement *all* of these.
def delete(self, name):
"""
Delete the specified file from the storage system.
"""
raise NotImplementedError('subclasses of Storage must provide a delete() method')
def exists(self, name):
"""
Return True if a file referenced by the given name already exists in the
storage system, or False if the name is available for a new file.
"""
raise NotImplementedError('subclasses of Storage must provide an exists() method')
def listdir(self, path):
"""
List the contents of the specified path. Return a 2-tuple of lists:
the first item being directories, the second item being files.
"""
raise NotImplementedError('subclasses of Storage must provide a listdir() method')
def size(self, name):
"""
Return the total size, in bytes, of the file specified by name.
"""
raise NotImplementedError('subclasses of Storage must provide a size() method')
def url(self, name):
"""
Return an absolute URL where the file's contents can be accessed
directly by a web browser.
"""
raise NotImplementedError('subclasses of Storage must provide a url() method')
def get_accessed_time(self, name):
"""
Return the last accessed time (as a datetime) of the file specified by
name. The datetime will be timezone-aware if USE_TZ=True.
"""
raise NotImplementedError('subclasses of Storage must provide a get_accessed_time() method')
def get_created_time(self, name):
"""
Return the creation time (as a datetime) of the file specified by name.
The datetime will be timezone-aware if USE_TZ=True.
"""
raise NotImplementedError('subclasses of Storage must provide a get_created_time() method')
def get_modified_time(self, name):
"""
Return the last modified time (as a datetime) of the file specified by
name. The datetime will be timezone-aware if USE_TZ=True.
"""
raise NotImplementedError('subclasses of Storage must provide a get_modified_time() method')
@deconstructible
class FileSystemStorage(Storage):
"""
Standard filesystem storage
"""
# The combination of O_CREAT and O_EXCL makes os.open() raise OSError if
# the file already exists before it's opened.
OS_OPEN_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, 'O_BINARY', 0)
def __init__(self, location=None, base_url=None, file_permissions_mode=None,
directory_permissions_mode=None):
self._location = location
self._base_url = base_url
self._file_permissions_mode = file_permissions_mode
self._directory_permissions_mode = directory_permissions_mode
setting_changed.connect(self._clear_cached_properties)
def _clear_cached_properties(self, setting, **kwargs):
"""Reset setting based property values."""
if setting == 'MEDIA_ROOT':
self.__dict__.pop('base_location', None)
self.__dict__.pop('location', None)
elif setting == 'MEDIA_URL':
self.__dict__.pop('base_url', None)
elif setting == 'FILE_UPLOAD_PERMISSIONS':
self.__dict__.pop('file_permissions_mode', None)
elif setting == 'FILE_UPLOAD_DIRECTORY_PERMISSIONS':
self.__dict__.pop('directory_permissions_mode', None)
def _value_or_setting(self, value, setting):
return setting if value is None else value
@cached_property
def base_location(self):
return self._value_or_setting(self._location, settings.MEDIA_ROOT)
@cached_property
def location(self):
return os.path.abspath(self.base_location)
@cached_property
def base_url(self):
if self._base_url is not None and not self._base_url.endswith('/'):
self._base_url += '/'
return self._value_or_setting(self._base_url, settings.MEDIA_URL)
@cached_property
def file_permissions_mode(self):
return self._value_or_setting(self._file_permissions_mode, settings.FILE_UPLOAD_PERMISSIONS)
@cached_property
def directory_permissions_mode(self):
return self._value_or_setting(self._directory_permissions_mode, settings.FILE_UPLOAD_DIRECTORY_PERMISSIONS)
def _open(self, name, mode='rb'):
return File(open(self.path(name), mode))
def _save(self, name, content):
full_path = self.path(name)
# Create any intermediate directories that do not exist.
directory = os.path.dirname(full_path)
try:
if self.directory_permissions_mode is not None:
# Set the umask because os.makedirs() doesn't apply the "mode"
# argument to intermediate-level directories.
old_umask = os.umask(0o777 & ~self.directory_permissions_mode)
try:
os.makedirs(directory, self.directory_permissions_mode, exist_ok=True)
finally:
os.umask(old_umask)
else:
os.makedirs(directory, exist_ok=True)
except FileExistsError:
raise FileExistsError('%s exists and is not a directory.' % directory)
# There's a potential race condition between get_available_name and
# saving the file; it's possible that two threads might return the
# same name, at which point all sorts of fun happens. So we need to
# try to create the file, but if it already exists we have to go back
# to get_available_name() and try again.
while True:
try:
# This file has a file path that we can move.
if hasattr(content, 'temporary_file_path'):
file_move_safe(content.temporary_file_path(), full_path)
# This is a normal uploadedfile that we can stream.
else:
# The current umask value is masked out by os.open!
fd = os.open(full_path, self.OS_OPEN_FLAGS, 0o666)
_file = None
try:
locks.lock(fd, locks.LOCK_EX)
for chunk in content.chunks():
if _file is None:
mode = 'wb' if isinstance(chunk, bytes) else 'wt'
_file = os.fdopen(fd, mode)
_file.write(chunk)
finally:
locks.unlock(fd)
if _file is not None:
_file.close()
else:
os.close(fd)
except FileExistsError:
# A new name is needed if the file exists.
name = self.get_available_name(name)
full_path = self.path(name)
else:
# OK, the file save worked. Break out of the loop.
break
if self.file_permissions_mode is not None:
os.chmod(full_path, self.file_permissions_mode)
# Store filenames with forward slashes, even on Windows.
return str(name).replace('\\', '/')
def delete(self, name):
if not name:
raise ValueError('The name must be given to delete().')
name = self.path(name)
# If the file or directory exists, delete it from the filesystem.
try:
if os.path.isdir(name):
os.rmdir(name)
else:
os.remove(name)
except FileNotFoundError:
# FileNotFoundError is raised if the file or directory was removed
# concurrently.
pass
def exists(self, name):
return os.path.lexists(self.path(name))
def listdir(self, path):
path = self.path(path)
directories, files = [], []
with os.scandir(path) as entries:
for entry in entries:
if entry.is_dir():
directories.append(entry.name)
else:
files.append(entry.name)
return directories, files
def path(self, name):
return safe_join(self.location, name)
def size(self, name):
return os.path.getsize(self.path(name))
def url(self, name):
if self.base_url is None:
raise ValueError("This file is not accessible via a URL.")
url = filepath_to_uri(name)
if url is not None:
url = url.lstrip('/')
return urljoin(self.base_url, url)
def _datetime_from_timestamp(self, ts):
"""
If timezone support is enabled, make an aware datetime object in UTC;
otherwise make a naive one in the local timezone.
"""
tz = timezone.utc if settings.USE_TZ else None
return datetime.fromtimestamp(ts, tz=tz)
def get_accessed_time(self, name):
return self._datetime_from_timestamp(os.path.getatime(self.path(name)))
def get_created_time(self, name):
return self._datetime_from_timestamp(os.path.getctime(self.path(name)))
def get_modified_time(self, name):
return self._datetime_from_timestamp(os.path.getmtime(self.path(name)))
def get_storage_class(import_path=None):
return import_string(import_path or settings.DEFAULT_FILE_STORAGE)
class DefaultStorage(LazyObject):
def _setup(self):
self._wrapped = get_storage_class()()
default_storage = DefaultStorage()

View File

@@ -0,0 +1,74 @@
"""
The temp module provides a NamedTemporaryFile that can be reopened in the same
process on any platform. Most platforms use the standard Python
tempfile.NamedTemporaryFile class, but Windows users are given a custom class.
This is needed because the Python implementation of NamedTemporaryFile uses the
O_TEMPORARY flag under Windows, which prevents the file from being reopened
if the same flag is not provided [1][2]. Note that this does not address the
more general issue of opening a file for writing and reading in multiple
processes in a manner that works across platforms.
The custom version of NamedTemporaryFile doesn't support the same keyword
arguments available in tempfile.NamedTemporaryFile.
1: https://mail.python.org/pipermail/python-list/2005-December/336957.html
2: https://bugs.python.org/issue14243
"""
import os
import tempfile
from django.core.files.utils import FileProxyMixin
__all__ = ('NamedTemporaryFile', 'gettempdir',)
if os.name == 'nt':
class TemporaryFile(FileProxyMixin):
"""
Temporary file object constructor that supports reopening of the
temporary file in Windows.
Unlike tempfile.NamedTemporaryFile from the standard library,
__init__() doesn't support the 'delete', 'buffering', 'encoding', or
'newline' keyword arguments.
"""
def __init__(self, mode='w+b', bufsize=-1, suffix='', prefix='', dir=None):
fd, name = tempfile.mkstemp(suffix=suffix, prefix=prefix, dir=dir)
self.name = name
self.file = os.fdopen(fd, mode, bufsize)
self.close_called = False
# Because close can be called during shutdown
# we need to cache os.unlink and access it
# as self.unlink only
unlink = os.unlink
def close(self):
if not self.close_called:
self.close_called = True
try:
self.file.close()
except OSError:
pass
try:
self.unlink(self.name)
except OSError:
pass
def __del__(self):
self.close()
def __enter__(self):
self.file.__enter__()
return self
def __exit__(self, exc, value, tb):
self.file.__exit__(exc, value, tb)
NamedTemporaryFile = TemporaryFile
else:
NamedTemporaryFile = tempfile.NamedTemporaryFile
gettempdir = tempfile.gettempdir

View File

@@ -0,0 +1,120 @@
"""
Classes representing uploaded files.
"""
import os
from io import BytesIO
from django.conf import settings
from django.core.files import temp as tempfile
from django.core.files.base import File
from django.core.files.utils import validate_file_name
__all__ = ('UploadedFile', 'TemporaryUploadedFile', 'InMemoryUploadedFile',
'SimpleUploadedFile')
class UploadedFile(File):
"""
An abstract uploaded file (``TemporaryUploadedFile`` and
``InMemoryUploadedFile`` are the built-in concrete subclasses).
An ``UploadedFile`` object behaves somewhat like a file object and
represents some file data that the user submitted with a form.
"""
def __init__(self, file=None, name=None, content_type=None, size=None, charset=None, content_type_extra=None):
super().__init__(file, name)
self.size = size
self.content_type = content_type
self.charset = charset
self.content_type_extra = content_type_extra
def __repr__(self):
return "<%s: %s (%s)>" % (self.__class__.__name__, self.name, self.content_type)
def _get_name(self):
return self._name
def _set_name(self, name):
# Sanitize the file name so that it can't be dangerous.
if name is not None:
# Just use the basename of the file -- anything else is dangerous.
name = os.path.basename(name)
# File names longer than 255 characters can cause problems on older OSes.
if len(name) > 255:
name, ext = os.path.splitext(name)
ext = ext[:255]
name = name[:255 - len(ext)] + ext
name = validate_file_name(name)
self._name = name
name = property(_get_name, _set_name)
class TemporaryUploadedFile(UploadedFile):
"""
A file uploaded to a temporary location (i.e. stream-to-disk).
"""
def __init__(self, name, content_type, size, charset, content_type_extra=None):
_, ext = os.path.splitext(name)
file = tempfile.NamedTemporaryFile(suffix='.upload' + ext, dir=settings.FILE_UPLOAD_TEMP_DIR)
super().__init__(file, name, content_type, size, charset, content_type_extra)
def temporary_file_path(self):
"""Return the full path of this file."""
return self.file.name
def close(self):
try:
return self.file.close()
except FileNotFoundError:
# The file was moved or deleted before the tempfile could unlink
# it. Still sets self.file.close_called and calls
# self.file.file.close() before the exception.
pass
class InMemoryUploadedFile(UploadedFile):
"""
A file uploaded into memory (i.e. stream-to-memory).
"""
def __init__(self, file, field_name, name, content_type, size, charset, content_type_extra=None):
super().__init__(file, name, content_type, size, charset, content_type_extra)
self.field_name = field_name
def open(self, mode=None):
self.file.seek(0)
return self
def chunks(self, chunk_size=None):
self.file.seek(0)
yield self.read()
def multiple_chunks(self, chunk_size=None):
# Since it's in memory, we'll never have multiple chunks.
return False
class SimpleUploadedFile(InMemoryUploadedFile):
"""
A simple representation of a file, which just has content, size, and a name.
"""
def __init__(self, name, content, content_type='text/plain'):
content = content or b''
super().__init__(BytesIO(content), None, name, content_type, len(content), None, None)
@classmethod
def from_dict(cls, file_dict):
"""
Create a SimpleUploadedFile object from a dictionary with keys:
- filename
- content-type
- content
"""
return cls(file_dict['filename'],
file_dict['content'],
file_dict.get('content-type', 'text/plain'))

View File

@@ -0,0 +1,221 @@
"""
Base file upload handler classes, and the built-in concrete subclasses
"""
import os
from io import BytesIO
from django.conf import settings
from django.core.files.uploadedfile import (
InMemoryUploadedFile, TemporaryUploadedFile,
)
from django.utils.module_loading import import_string
__all__ = [
'UploadFileException', 'StopUpload', 'SkipFile', 'FileUploadHandler',
'TemporaryFileUploadHandler', 'MemoryFileUploadHandler', 'load_handler',
'StopFutureHandlers'
]
class UploadFileException(Exception):
"""
Any error having to do with uploading files.
"""
pass
class StopUpload(UploadFileException):
"""
This exception is raised when an upload must abort.
"""
def __init__(self, connection_reset=False):
"""
If ``connection_reset`` is ``True``, Django knows will halt the upload
without consuming the rest of the upload. This will cause the browser to
show a "connection reset" error.
"""
self.connection_reset = connection_reset
def __str__(self):
if self.connection_reset:
return 'StopUpload: Halt current upload.'
else:
return 'StopUpload: Consume request data, then halt.'
class SkipFile(UploadFileException):
"""
This exception is raised by an upload handler that wants to skip a given file.
"""
pass
class StopFutureHandlers(UploadFileException):
"""
Upload handlers that have handled a file and do not want future handlers to
run should raise this exception instead of returning None.
"""
pass
class FileUploadHandler:
"""
Base class for streaming upload handlers.
"""
chunk_size = 64 * 2 ** 10 # : The default chunk size is 64 KB.
def __init__(self, request=None):
self.file_name = None
self.content_type = None
self.content_length = None
self.charset = None
self.content_type_extra = None
self.request = request
def handle_raw_input(self, input_data, META, content_length, boundary, encoding=None):
"""
Handle the raw input from the client.
Parameters:
:input_data:
An object that supports reading via .read().
:META:
``request.META``.
:content_length:
The (integer) value of the Content-Length header from the
client.
:boundary: The boundary from the Content-Type header. Be sure to
prepend two '--'.
"""
pass
def new_file(self, field_name, file_name, content_type, content_length, charset=None, content_type_extra=None):
"""
Signal that a new file has been started.
Warning: As with any data from the client, you should not trust
content_length (and sometimes won't even get it).
"""
self.field_name = field_name
self.file_name = file_name
self.content_type = content_type
self.content_length = content_length
self.charset = charset
self.content_type_extra = content_type_extra
def receive_data_chunk(self, raw_data, start):
"""
Receive data from the streamed upload parser. ``start`` is the position
in the file of the chunk.
"""
raise NotImplementedError('subclasses of FileUploadHandler must provide a receive_data_chunk() method')
def file_complete(self, file_size):
"""
Signal that a file has completed. File size corresponds to the actual
size accumulated by all the chunks.
Subclasses should return a valid ``UploadedFile`` object.
"""
raise NotImplementedError('subclasses of FileUploadHandler must provide a file_complete() method')
def upload_complete(self):
"""
Signal that the upload is complete. Subclasses should perform cleanup
that is necessary for this handler.
"""
pass
def upload_interrupted(self):
"""
Signal that the upload was interrupted. Subclasses should perform
cleanup that is necessary for this handler.
"""
pass
class TemporaryFileUploadHandler(FileUploadHandler):
"""
Upload handler that streams data into a temporary file.
"""
def new_file(self, *args, **kwargs):
"""
Create the file object to append to as data is coming in.
"""
super().new_file(*args, **kwargs)
self.file = TemporaryUploadedFile(self.file_name, self.content_type, 0, self.charset, self.content_type_extra)
def receive_data_chunk(self, raw_data, start):
self.file.write(raw_data)
def file_complete(self, file_size):
self.file.seek(0)
self.file.size = file_size
return self.file
def upload_interrupted(self):
if hasattr(self, 'file'):
temp_location = self.file.temporary_file_path()
try:
self.file.close()
os.remove(temp_location)
except FileNotFoundError:
pass
class MemoryFileUploadHandler(FileUploadHandler):
"""
File upload handler to stream uploads into memory (used for small files).
"""
def handle_raw_input(self, input_data, META, content_length, boundary, encoding=None):
"""
Use the content_length to signal whether or not this handler should be
used.
"""
# Check the content-length header to see if we should
# If the post is too large, we cannot use the Memory handler.
self.activated = content_length <= settings.FILE_UPLOAD_MAX_MEMORY_SIZE
def new_file(self, *args, **kwargs):
super().new_file(*args, **kwargs)
if self.activated:
self.file = BytesIO()
raise StopFutureHandlers()
def receive_data_chunk(self, raw_data, start):
"""Add the data to the BytesIO file."""
if self.activated:
self.file.write(raw_data)
else:
return raw_data
def file_complete(self, file_size):
"""Return a file object if this handler is activated."""
if not self.activated:
return
self.file.seek(0)
return InMemoryUploadedFile(
file=self.file,
field_name=self.field_name,
name=self.file_name,
content_type=self.content_type,
size=file_size,
charset=self.charset,
content_type_extra=self.content_type_extra
)
def load_handler(path, *args, **kwargs):
"""
Given a path to a handler, return an instance of that handler.
E.g.::
>>> from django.http import HttpRequest
>>> request = HttpRequest()
>>> load_handler('django.core.files.uploadhandler.TemporaryFileUploadHandler', request)
<TemporaryFileUploadHandler object at 0x...>
"""
return import_string(path)(*args, **kwargs)

View File

@@ -0,0 +1,78 @@
import os
import pathlib
from django.core.exceptions import SuspiciousFileOperation
def validate_file_name(name, allow_relative_path=False):
# Remove potentially dangerous names
if os.path.basename(name) in {'', '.', '..'}:
raise SuspiciousFileOperation("Could not derive file name from '%s'" % name)
if allow_relative_path:
# Use PurePosixPath() because this branch is checked only in
# FileField.generate_filename() where all file paths are expected to be
# Unix style (with forward slashes).
path = pathlib.PurePosixPath(name)
if path.is_absolute() or '..' in path.parts:
raise SuspiciousFileOperation(
"Detected path traversal attempt in '%s'" % name
)
elif name != os.path.basename(name):
raise SuspiciousFileOperation("File name '%s' includes path elements" % name)
return name
class FileProxyMixin:
"""
A mixin class used to forward file methods to an underlaying file
object. The internal file object has to be called "file"::
class FileProxy(FileProxyMixin):
def __init__(self, file):
self.file = file
"""
encoding = property(lambda self: self.file.encoding)
fileno = property(lambda self: self.file.fileno)
flush = property(lambda self: self.file.flush)
isatty = property(lambda self: self.file.isatty)
newlines = property(lambda self: self.file.newlines)
read = property(lambda self: self.file.read)
readinto = property(lambda self: self.file.readinto)
readline = property(lambda self: self.file.readline)
readlines = property(lambda self: self.file.readlines)
seek = property(lambda self: self.file.seek)
tell = property(lambda self: self.file.tell)
truncate = property(lambda self: self.file.truncate)
write = property(lambda self: self.file.write)
writelines = property(lambda self: self.file.writelines)
@property
def closed(self):
return not self.file or self.file.closed
def readable(self):
if self.closed:
return False
if hasattr(self.file, 'readable'):
return self.file.readable()
return True
def writable(self):
if self.closed:
return False
if hasattr(self.file, 'writable'):
return self.file.writable()
return 'w' in getattr(self.file, 'mode', '')
def seekable(self):
if self.closed:
return False
if hasattr(self.file, 'seekable'):
return self.file.seekable()
return True
def __iter__(self):
return iter(self.file)

View File

@@ -0,0 +1,295 @@
import logging
import sys
import tempfile
import traceback
from asgiref.sync import ThreadSensitiveContext, sync_to_async
from django.conf import settings
from django.core import signals
from django.core.exceptions import RequestAborted, RequestDataTooBig
from django.core.handlers import base
from django.http import (
FileResponse, HttpRequest, HttpResponse, HttpResponseBadRequest,
HttpResponseServerError, QueryDict, parse_cookie,
)
from django.urls import set_script_prefix
from django.utils.functional import cached_property
logger = logging.getLogger('django.request')
class ASGIRequest(HttpRequest):
"""
Custom request subclass that decodes from an ASGI-standard request dict
and wraps request body handling.
"""
# Number of seconds until a Request gives up on trying to read a request
# body and aborts.
body_receive_timeout = 60
def __init__(self, scope, body_file):
self.scope = scope
self._post_parse_error = False
self._read_started = False
self.resolver_match = None
self.script_name = self.scope.get('root_path', '')
if self.script_name and scope['path'].startswith(self.script_name):
# TODO: Better is-prefix checking, slash handling?
self.path_info = scope['path'][len(self.script_name):]
else:
self.path_info = scope['path']
# The Django path is different from ASGI scope path args, it should
# combine with script name.
if self.script_name:
self.path = '%s/%s' % (
self.script_name.rstrip('/'),
self.path_info.replace('/', '', 1),
)
else:
self.path = scope['path']
# HTTP basics.
self.method = self.scope['method'].upper()
# Ensure query string is encoded correctly.
query_string = self.scope.get('query_string', '')
if isinstance(query_string, bytes):
query_string = query_string.decode()
self.META = {
'REQUEST_METHOD': self.method,
'QUERY_STRING': query_string,
'SCRIPT_NAME': self.script_name,
'PATH_INFO': self.path_info,
# WSGI-expecting code will need these for a while
'wsgi.multithread': True,
'wsgi.multiprocess': True,
}
if self.scope.get('client'):
self.META['REMOTE_ADDR'] = self.scope['client'][0]
self.META['REMOTE_HOST'] = self.META['REMOTE_ADDR']
self.META['REMOTE_PORT'] = self.scope['client'][1]
if self.scope.get('server'):
self.META['SERVER_NAME'] = self.scope['server'][0]
self.META['SERVER_PORT'] = str(self.scope['server'][1])
else:
self.META['SERVER_NAME'] = 'unknown'
self.META['SERVER_PORT'] = '0'
# Headers go into META.
for name, value in self.scope.get('headers', []):
name = name.decode('latin1')
if name == 'content-length':
corrected_name = 'CONTENT_LENGTH'
elif name == 'content-type':
corrected_name = 'CONTENT_TYPE'
else:
corrected_name = 'HTTP_%s' % name.upper().replace('-', '_')
# HTTP/2 say only ASCII chars are allowed in headers, but decode
# latin1 just in case.
value = value.decode('latin1')
if corrected_name in self.META:
value = self.META[corrected_name] + ',' + value
self.META[corrected_name] = value
# Pull out request encoding, if provided.
self._set_content_type_params(self.META)
# Directly assign the body file to be our stream.
self._stream = body_file
# Other bits.
self.resolver_match = None
@cached_property
def GET(self):
return QueryDict(self.META['QUERY_STRING'])
def _get_scheme(self):
return self.scope.get('scheme') or super()._get_scheme()
def _get_post(self):
if not hasattr(self, '_post'):
self._load_post_and_files()
return self._post
def _set_post(self, post):
self._post = post
def _get_files(self):
if not hasattr(self, '_files'):
self._load_post_and_files()
return self._files
POST = property(_get_post, _set_post)
FILES = property(_get_files)
@cached_property
def COOKIES(self):
return parse_cookie(self.META.get('HTTP_COOKIE', ''))
class ASGIHandler(base.BaseHandler):
"""Handler for ASGI requests."""
request_class = ASGIRequest
# Size to chunk response bodies into for multiple response messages.
chunk_size = 2 ** 16
def __init__(self):
super().__init__()
self.load_middleware(is_async=True)
async def __call__(self, scope, receive, send):
"""
Async entrypoint - parses the request and hands off to get_response.
"""
# Serve only HTTP connections.
# FIXME: Allow to override this.
if scope['type'] != 'http':
raise ValueError(
'Django can only handle ASGI/HTTP connections, not %s.'
% scope['type']
)
async with ThreadSensitiveContext():
await self.handle(scope, receive, send)
async def handle(self, scope, receive, send):
"""
Handles the ASGI request. Called via the __call__ method.
"""
# Receive the HTTP request body as a stream object.
try:
body_file = await self.read_body(receive)
except RequestAborted:
return
# Request is complete and can be served.
set_script_prefix(self.get_script_prefix(scope))
await sync_to_async(signals.request_started.send, thread_sensitive=True)(sender=self.__class__, scope=scope)
# Get the request and check for basic issues.
request, error_response = self.create_request(scope, body_file)
if request is None:
await self.send_response(error_response, send)
return
# Get the response, using the async mode of BaseHandler.
response = await self.get_response_async(request)
response._handler_class = self.__class__
# Increase chunk size on file responses (ASGI servers handles low-level
# chunking).
if isinstance(response, FileResponse):
response.block_size = self.chunk_size
# Send the response.
await self.send_response(response, send)
async def read_body(self, receive):
"""Reads an HTTP body from an ASGI connection."""
# Use the tempfile that auto rolls-over to a disk file as it fills up.
body_file = tempfile.SpooledTemporaryFile(max_size=settings.FILE_UPLOAD_MAX_MEMORY_SIZE, mode='w+b')
while True:
message = await receive()
if message['type'] == 'http.disconnect':
# Early client disconnect.
raise RequestAborted()
# Add a body chunk from the message, if provided.
if 'body' in message:
body_file.write(message['body'])
# Quit out if that's the end.
if not message.get('more_body', False):
break
body_file.seek(0)
return body_file
def create_request(self, scope, body_file):
"""
Create the Request object and returns either (request, None) or
(None, response) if there is an error response.
"""
try:
return self.request_class(scope, body_file), None
except UnicodeDecodeError:
logger.warning(
'Bad Request (UnicodeDecodeError)',
exc_info=sys.exc_info(),
extra={'status_code': 400},
)
return None, HttpResponseBadRequest()
except RequestDataTooBig:
return None, HttpResponse('413 Payload too large', status=413)
def handle_uncaught_exception(self, request, resolver, exc_info):
"""Last-chance handler for exceptions."""
# There's no WSGI server to catch the exception further up
# if this fails, so translate it into a plain text response.
try:
return super().handle_uncaught_exception(request, resolver, exc_info)
except Exception:
return HttpResponseServerError(
traceback.format_exc() if settings.DEBUG else 'Internal Server Error',
content_type='text/plain',
)
async def send_response(self, response, send):
"""Encode and send a response out over ASGI."""
# Collect cookies into headers. Have to preserve header case as there
# are some non-RFC compliant clients that require e.g. Content-Type.
response_headers = []
for header, value in response.items():
if isinstance(header, str):
header = header.encode('ascii')
if isinstance(value, str):
value = value.encode('latin1')
response_headers.append((bytes(header), bytes(value)))
for c in response.cookies.values():
response_headers.append(
(b'Set-Cookie', c.output(header='').encode('ascii').strip())
)
# Initial response message.
await send({
'type': 'http.response.start',
'status': response.status_code,
'headers': response_headers,
})
# Streaming responses need to be pinned to their iterator.
if response.streaming:
# Access `__iter__` and not `streaming_content` directly in case
# it has been overridden in a subclass.
for part in response:
for chunk, _ in self.chunk_bytes(part):
await send({
'type': 'http.response.body',
'body': chunk,
# Ignore "more" as there may be more parts; instead,
# use an empty final closing message with False.
'more_body': True,
})
# Final closing message.
await send({'type': 'http.response.body'})
# Other responses just need chunking.
else:
# Yield chunks of response.
for chunk, last in self.chunk_bytes(response.content):
await send({
'type': 'http.response.body',
'body': chunk,
'more_body': not last,
})
await sync_to_async(response.close, thread_sensitive=True)()
@classmethod
def chunk_bytes(cls, data):
"""
Chunks some data up so it can be sent in reasonable size messages.
Yields (chunk, last_chunk) tuples.
"""
position = 0
if not data:
yield data, True
return
while position < len(data):
yield (
data[position:position + cls.chunk_size],
(position + cls.chunk_size) >= len(data),
)
position += cls.chunk_size
def get_script_prefix(self, scope):
"""
Return the script prefix to use from either the scope or a setting.
"""
if settings.FORCE_SCRIPT_NAME:
return settings.FORCE_SCRIPT_NAME
return scope.get('root_path', '') or ''

View File

@@ -0,0 +1,350 @@
import asyncio
import logging
import types
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, MiddlewareNotUsed
from django.core.signals import request_finished
from django.db import connections, transaction
from django.urls import get_resolver, set_urlconf
from django.utils.log import log_response
from django.utils.module_loading import import_string
from .exception import convert_exception_to_response
logger = logging.getLogger('django.request')
class BaseHandler:
_view_middleware = None
_template_response_middleware = None
_exception_middleware = None
_middleware_chain = None
def load_middleware(self, is_async=False):
"""
Populate middleware lists from settings.MIDDLEWARE.
Must be called after the environment is fixed (see __call__ in subclasses).
"""
self._view_middleware = []
self._template_response_middleware = []
self._exception_middleware = []
get_response = self._get_response_async if is_async else self._get_response
handler = convert_exception_to_response(get_response)
handler_is_async = is_async
for middleware_path in reversed(settings.MIDDLEWARE):
middleware = import_string(middleware_path)
middleware_can_sync = getattr(middleware, 'sync_capable', True)
middleware_can_async = getattr(middleware, 'async_capable', False)
if not middleware_can_sync and not middleware_can_async:
raise RuntimeError(
'Middleware %s must have at least one of '
'sync_capable/async_capable set to True.' % middleware_path
)
elif not handler_is_async and middleware_can_sync:
middleware_is_async = False
else:
middleware_is_async = middleware_can_async
try:
# Adapt handler, if needed.
adapted_handler = self.adapt_method_mode(
middleware_is_async, handler, handler_is_async,
debug=settings.DEBUG, name='middleware %s' % middleware_path,
)
mw_instance = middleware(adapted_handler)
except MiddlewareNotUsed as exc:
if settings.DEBUG:
if str(exc):
logger.debug('MiddlewareNotUsed(%r): %s', middleware_path, exc)
else:
logger.debug('MiddlewareNotUsed: %r', middleware_path)
continue
else:
handler = adapted_handler
if mw_instance is None:
raise ImproperlyConfigured(
'Middleware factory %s returned None.' % middleware_path
)
if hasattr(mw_instance, 'process_view'):
self._view_middleware.insert(
0,
self.adapt_method_mode(is_async, mw_instance.process_view),
)
if hasattr(mw_instance, 'process_template_response'):
self._template_response_middleware.append(
self.adapt_method_mode(is_async, mw_instance.process_template_response),
)
if hasattr(mw_instance, 'process_exception'):
# The exception-handling stack is still always synchronous for
# now, so adapt that way.
self._exception_middleware.append(
self.adapt_method_mode(False, mw_instance.process_exception),
)
handler = convert_exception_to_response(mw_instance)
handler_is_async = middleware_is_async
# Adapt the top of the stack, if needed.
handler = self.adapt_method_mode(is_async, handler, handler_is_async)
# We only assign to this when initialization is complete as it is used
# as a flag for initialization being complete.
self._middleware_chain = handler
def adapt_method_mode(
self, is_async, method, method_is_async=None, debug=False, name=None,
):
"""
Adapt a method to be in the correct "mode":
- If is_async is False:
- Synchronous methods are left alone
- Asynchronous methods are wrapped with async_to_sync
- If is_async is True:
- Synchronous methods are wrapped with sync_to_async()
- Asynchronous methods are left alone
"""
if method_is_async is None:
method_is_async = asyncio.iscoroutinefunction(method)
if debug and not name:
name = name or 'method %s()' % method.__qualname__
if is_async:
if not method_is_async:
if debug:
logger.debug('Synchronous %s adapted.', name)
return sync_to_async(method, thread_sensitive=True)
elif method_is_async:
if debug:
logger.debug('Asynchronous %s adapted.', name)
return async_to_sync(method)
return method
def get_response(self, request):
"""Return an HttpResponse object for the given HttpRequest."""
# Setup default url resolver for this thread
set_urlconf(settings.ROOT_URLCONF)
response = self._middleware_chain(request)
response._resource_closers.append(request.close)
if response.status_code >= 400:
log_response(
'%s: %s', response.reason_phrase, request.path,
response=response,
request=request,
)
return response
async def get_response_async(self, request):
"""
Asynchronous version of get_response.
Funneling everything, including WSGI, into a single async
get_response() is too slow. Avoid the context switch by using
a separate async response path.
"""
# Setup default url resolver for this thread.
set_urlconf(settings.ROOT_URLCONF)
response = await self._middleware_chain(request)
response._resource_closers.append(request.close)
if response.status_code >= 400:
await sync_to_async(log_response, thread_sensitive=False)(
'%s: %s', response.reason_phrase, request.path,
response=response,
request=request,
)
return response
def _get_response(self, request):
"""
Resolve and call the view, then apply view, exception, and
template_response middleware. This method is everything that happens
inside the request/response middleware.
"""
response = None
callback, callback_args, callback_kwargs = self.resolve_request(request)
# Apply view middleware
for middleware_method in self._view_middleware:
response = middleware_method(request, callback, callback_args, callback_kwargs)
if response:
break
if response is None:
wrapped_callback = self.make_view_atomic(callback)
# If it is an asynchronous view, run it in a subthread.
if asyncio.iscoroutinefunction(wrapped_callback):
wrapped_callback = async_to_sync(wrapped_callback)
try:
response = wrapped_callback(request, *callback_args, **callback_kwargs)
except Exception as e:
response = self.process_exception_by_middleware(e, request)
if response is None:
raise
# Complain if the view returned None (a common error).
self.check_response(response, callback)
# If the response supports deferred rendering, apply template
# response middleware and then render the response
if hasattr(response, 'render') and callable(response.render):
for middleware_method in self._template_response_middleware:
response = middleware_method(request, response)
# Complain if the template response middleware returned None (a common error).
self.check_response(
response,
middleware_method,
name='%s.process_template_response' % (
middleware_method.__self__.__class__.__name__,
)
)
try:
response = response.render()
except Exception as e:
response = self.process_exception_by_middleware(e, request)
if response is None:
raise
return response
async def _get_response_async(self, request):
"""
Resolve and call the view, then apply view, exception, and
template_response middleware. This method is everything that happens
inside the request/response middleware.
"""
response = None
callback, callback_args, callback_kwargs = self.resolve_request(request)
# Apply view middleware.
for middleware_method in self._view_middleware:
response = await middleware_method(request, callback, callback_args, callback_kwargs)
if response:
break
if response is None:
wrapped_callback = self.make_view_atomic(callback)
# If it is a synchronous view, run it in a subthread
if not asyncio.iscoroutinefunction(wrapped_callback):
wrapped_callback = sync_to_async(wrapped_callback, thread_sensitive=True)
try:
response = await wrapped_callback(request, *callback_args, **callback_kwargs)
except Exception as e:
response = await sync_to_async(
self.process_exception_by_middleware,
thread_sensitive=True,
)(e, request)
if response is None:
raise
# Complain if the view returned None or an uncalled coroutine.
self.check_response(response, callback)
# If the response supports deferred rendering, apply template
# response middleware and then render the response
if hasattr(response, 'render') and callable(response.render):
for middleware_method in self._template_response_middleware:
response = await middleware_method(request, response)
# Complain if the template response middleware returned None or
# an uncalled coroutine.
self.check_response(
response,
middleware_method,
name='%s.process_template_response' % (
middleware_method.__self__.__class__.__name__,
)
)
try:
if asyncio.iscoroutinefunction(response.render):
response = await response.render()
else:
response = await sync_to_async(response.render, thread_sensitive=True)()
except Exception as e:
response = await sync_to_async(
self.process_exception_by_middleware,
thread_sensitive=True,
)(e, request)
if response is None:
raise
# Make sure the response is not a coroutine
if asyncio.iscoroutine(response):
raise RuntimeError('Response is still a coroutine.')
return response
def resolve_request(self, request):
"""
Retrieve/set the urlconf for the request. Return the view resolved,
with its args and kwargs.
"""
# Work out the resolver.
if hasattr(request, 'urlconf'):
urlconf = request.urlconf
set_urlconf(urlconf)
resolver = get_resolver(urlconf)
else:
resolver = get_resolver()
# Resolve the view, and assign the match object back to the request.
resolver_match = resolver.resolve(request.path_info)
request.resolver_match = resolver_match
return resolver_match
def check_response(self, response, callback, name=None):
"""
Raise an error if the view returned None or an uncalled coroutine.
"""
if not(response is None or asyncio.iscoroutine(response)):
return
if not name:
if isinstance(callback, types.FunctionType): # FBV
name = 'The view %s.%s' % (callback.__module__, callback.__name__)
else: # CBV
name = 'The view %s.%s.__call__' % (
callback.__module__,
callback.__class__.__name__,
)
if response is None:
raise ValueError(
"%s didn't return an HttpResponse object. It returned None "
"instead." % name
)
elif asyncio.iscoroutine(response):
raise ValueError(
"%s didn't return an HttpResponse object. It returned an "
"unawaited coroutine instead. You may need to add an 'await' "
"into your view." % name
)
# Other utility methods.
def make_view_atomic(self, view):
non_atomic_requests = getattr(view, '_non_atomic_requests', set())
for db in connections.all():
if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests:
if asyncio.iscoroutinefunction(view):
raise RuntimeError(
'You cannot use ATOMIC_REQUESTS with async views.'
)
view = transaction.atomic(using=db.alias)(view)
return view
def process_exception_by_middleware(self, exception, request):
"""
Pass the exception to the exception middleware. If no middleware
return a response for this exception, return None.
"""
for middleware_method in self._exception_middleware:
response = middleware_method(request, exception)
if response:
return response
return None
def reset_urlconf(sender, **kwargs):
"""Reset the URLconf after each request is finished."""
set_urlconf(None)
request_finished.connect(reset_urlconf)

View File

@@ -0,0 +1,149 @@
import asyncio
import logging
import sys
from functools import wraps
from asgiref.sync import sync_to_async
from django.conf import settings
from django.core import signals
from django.core.exceptions import (
BadRequest, PermissionDenied, RequestDataTooBig, SuspiciousOperation,
TooManyFieldsSent,
)
from django.http import Http404
from django.http.multipartparser import MultiPartParserError
from django.urls import get_resolver, get_urlconf
from django.utils.log import log_response
from django.views import debug
def convert_exception_to_response(get_response):
"""
Wrap the given get_response callable in exception-to-response conversion.
All exceptions will be converted. All known 4xx exceptions (Http404,
PermissionDenied, MultiPartParserError, SuspiciousOperation) will be
converted to the appropriate response, and all other exceptions will be
converted to 500 responses.
This decorator is automatically applied to all middleware to ensure that
no middleware leaks an exception and that the next middleware in the stack
can rely on getting a response instead of an exception.
"""
if asyncio.iscoroutinefunction(get_response):
@wraps(get_response)
async def inner(request):
try:
response = await get_response(request)
except Exception as exc:
response = await sync_to_async(response_for_exception, thread_sensitive=False)(request, exc)
return response
return inner
else:
@wraps(get_response)
def inner(request):
try:
response = get_response(request)
except Exception as exc:
response = response_for_exception(request, exc)
return response
return inner
def response_for_exception(request, exc):
if isinstance(exc, Http404):
if settings.DEBUG:
response = debug.technical_404_response(request, exc)
else:
response = get_exception_response(request, get_resolver(get_urlconf()), 404, exc)
elif isinstance(exc, PermissionDenied):
response = get_exception_response(request, get_resolver(get_urlconf()), 403, exc)
log_response(
'Forbidden (Permission denied): %s', request.path,
response=response,
request=request,
exc_info=sys.exc_info(),
)
elif isinstance(exc, MultiPartParserError):
response = get_exception_response(request, get_resolver(get_urlconf()), 400, exc)
log_response(
'Bad request (Unable to parse request body): %s', request.path,
response=response,
request=request,
exc_info=sys.exc_info(),
)
elif isinstance(exc, BadRequest):
if settings.DEBUG:
response = debug.technical_500_response(request, *sys.exc_info(), status_code=400)
else:
response = get_exception_response(request, get_resolver(get_urlconf()), 400, exc)
log_response(
'%s: %s', str(exc), request.path,
response=response,
request=request,
exc_info=sys.exc_info(),
)
elif isinstance(exc, SuspiciousOperation):
if isinstance(exc, (RequestDataTooBig, TooManyFieldsSent)):
# POST data can't be accessed again, otherwise the original
# exception would be raised.
request._mark_post_parse_error()
# The request logger receives events for any problematic request
# The security logger receives events for all SuspiciousOperations
security_logger = logging.getLogger('django.security.%s' % exc.__class__.__name__)
security_logger.error(
str(exc),
extra={'status_code': 400, 'request': request},
)
if settings.DEBUG:
response = debug.technical_500_response(request, *sys.exc_info(), status_code=400)
else:
response = get_exception_response(request, get_resolver(get_urlconf()), 400, exc)
else:
signals.got_request_exception.send(sender=None, request=request)
response = handle_uncaught_exception(request, get_resolver(get_urlconf()), sys.exc_info())
log_response(
'%s: %s', response.reason_phrase, request.path,
response=response,
request=request,
exc_info=sys.exc_info(),
)
# Force a TemplateResponse to be rendered.
if not getattr(response, 'is_rendered', True) and callable(getattr(response, 'render', None)):
response = response.render()
return response
def get_exception_response(request, resolver, status_code, exception):
try:
callback = resolver.resolve_error_handler(status_code)
response = callback(request, exception=exception)
except Exception:
signals.got_request_exception.send(sender=None, request=request)
response = handle_uncaught_exception(request, resolver, sys.exc_info())
return response
def handle_uncaught_exception(request, resolver, exc_info):
"""
Processing for any otherwise uncaught exceptions (those that will
generate HTTP 500 responses).
"""
if settings.DEBUG_PROPAGATE_EXCEPTIONS:
raise
if settings.DEBUG:
return debug.technical_500_response(request, *exc_info)
# Return an HttpResponse that displays a friendly error message.
callback = resolver.resolve_error_handler(500)
return callback(request)

View File

@@ -0,0 +1,210 @@
from io import BytesIO
from django.conf import settings
from django.core import signals
from django.core.handlers import base
from django.http import HttpRequest, QueryDict, parse_cookie
from django.urls import set_script_prefix
from django.utils.encoding import repercent_broken_unicode
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile
_slashes_re = _lazy_re_compile(br'/+')
class LimitedStream:
"""Wrap another stream to disallow reading it past a number of bytes."""
def __init__(self, stream, limit, buf_size=64 * 1024 * 1024):
self.stream = stream
self.remaining = limit
self.buffer = b''
self.buf_size = buf_size
def _read_limited(self, size=None):
if size is None or size > self.remaining:
size = self.remaining
if size == 0:
return b''
result = self.stream.read(size)
self.remaining -= len(result)
return result
def read(self, size=None):
if size is None:
result = self.buffer + self._read_limited()
self.buffer = b''
elif size < len(self.buffer):
result = self.buffer[:size]
self.buffer = self.buffer[size:]
else: # size >= len(self.buffer)
result = self.buffer + self._read_limited(size - len(self.buffer))
self.buffer = b''
return result
def readline(self, size=None):
while b'\n' not in self.buffer and \
(size is None or len(self.buffer) < size):
if size:
# since size is not None here, len(self.buffer) < size
chunk = self._read_limited(size - len(self.buffer))
else:
chunk = self._read_limited()
if not chunk:
break
self.buffer += chunk
sio = BytesIO(self.buffer)
if size:
line = sio.readline(size)
else:
line = sio.readline()
self.buffer = sio.read()
return line
class WSGIRequest(HttpRequest):
def __init__(self, environ):
script_name = get_script_name(environ)
# If PATH_INFO is empty (e.g. accessing the SCRIPT_NAME URL without a
# trailing slash), operate as if '/' was requested.
path_info = get_path_info(environ) or '/'
self.environ = environ
self.path_info = path_info
# be careful to only replace the first slash in the path because of
# http://test/something and http://test//something being different as
# stated in https://www.ietf.org/rfc/rfc2396.txt
self.path = '%s/%s' % (script_name.rstrip('/'),
path_info.replace('/', '', 1))
self.META = environ
self.META['PATH_INFO'] = path_info
self.META['SCRIPT_NAME'] = script_name
self.method = environ['REQUEST_METHOD'].upper()
# Set content_type, content_params, and encoding.
self._set_content_type_params(environ)
try:
content_length = int(environ.get('CONTENT_LENGTH'))
except (ValueError, TypeError):
content_length = 0
self._stream = LimitedStream(self.environ['wsgi.input'], content_length)
self._read_started = False
self.resolver_match = None
def _get_scheme(self):
return self.environ.get('wsgi.url_scheme')
@cached_property
def GET(self):
# The WSGI spec says 'QUERY_STRING' may be absent.
raw_query_string = get_bytes_from_wsgi(self.environ, 'QUERY_STRING', '')
return QueryDict(raw_query_string, encoding=self._encoding)
def _get_post(self):
if not hasattr(self, '_post'):
self._load_post_and_files()
return self._post
def _set_post(self, post):
self._post = post
@cached_property
def COOKIES(self):
raw_cookie = get_str_from_wsgi(self.environ, 'HTTP_COOKIE', '')
return parse_cookie(raw_cookie)
@property
def FILES(self):
if not hasattr(self, '_files'):
self._load_post_and_files()
return self._files
POST = property(_get_post, _set_post)
class WSGIHandler(base.BaseHandler):
request_class = WSGIRequest
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_middleware()
def __call__(self, environ, start_response):
set_script_prefix(get_script_name(environ))
signals.request_started.send(sender=self.__class__, environ=environ)
request = self.request_class(environ)
response = self.get_response(request)
response._handler_class = self.__class__
status = '%d %s' % (response.status_code, response.reason_phrase)
response_headers = [
*response.items(),
*(('Set-Cookie', c.output(header='')) for c in response.cookies.values()),
]
start_response(status, response_headers)
if getattr(response, 'file_to_stream', None) is not None and environ.get('wsgi.file_wrapper'):
# If `wsgi.file_wrapper` is used the WSGI server does not call
# .close on the response, but on the file wrapper. Patch it to use
# response.close instead which takes care of closing all files.
response.file_to_stream.close = response.close
response = environ['wsgi.file_wrapper'](response.file_to_stream, response.block_size)
return response
def get_path_info(environ):
"""Return the HTTP request's PATH_INFO as a string."""
path_info = get_bytes_from_wsgi(environ, 'PATH_INFO', '/')
return repercent_broken_unicode(path_info).decode()
def get_script_name(environ):
"""
Return the equivalent of the HTTP request's SCRIPT_NAME environment
variable. If Apache mod_rewrite is used, return what would have been
the script name prior to any rewriting (so it's the script name as seen
from the client's perspective), unless the FORCE_SCRIPT_NAME setting is
set (to anything).
"""
if settings.FORCE_SCRIPT_NAME is not None:
return settings.FORCE_SCRIPT_NAME
# If Apache's mod_rewrite had a whack at the URL, Apache set either
# SCRIPT_URL or REDIRECT_URL to the full resource URL before applying any
# rewrites. Unfortunately not every web server (lighttpd!) passes this
# information through all the time, so FORCE_SCRIPT_NAME, above, is still
# needed.
script_url = get_bytes_from_wsgi(environ, 'SCRIPT_URL', '') or get_bytes_from_wsgi(environ, 'REDIRECT_URL', '')
if script_url:
if b'//' in script_url:
# mod_wsgi squashes multiple successive slashes in PATH_INFO,
# do the same with script_url before manipulating paths (#17133).
script_url = _slashes_re.sub(b'/', script_url)
path_info = get_bytes_from_wsgi(environ, 'PATH_INFO', '')
script_name = script_url[:-len(path_info)] if path_info else script_url
else:
script_name = get_bytes_from_wsgi(environ, 'SCRIPT_NAME', '')
return script_name.decode()
def get_bytes_from_wsgi(environ, key, default):
"""
Get a value from the WSGI environ dictionary as bytes.
key and default should be strings.
"""
value = environ.get(key, default)
# Non-ASCII values in the WSGI environ are arbitrarily decoded with
# ISO-8859-1. This is wrong for Django websites where UTF-8 is the default.
# Re-encode to recover the original bytestring.
return value.encode('iso-8859-1')
def get_str_from_wsgi(environ, key, default):
"""
Get a value from the WSGI environ dictionary as str.
key and default should be str objects.
"""
value = get_bytes_from_wsgi(environ, key, default)
return value.decode(errors='replace')

View File

@@ -0,0 +1,121 @@
"""
Tools for sending email.
"""
from django.conf import settings
# Imported for backwards compatibility and for the sake
# of a cleaner namespace. These symbols used to be in
# django/core/mail.py before the introduction of email
# backends and the subsequent reorganization (See #10355)
from django.core.mail.message import (
DEFAULT_ATTACHMENT_MIME_TYPE, BadHeaderError, EmailMessage,
EmailMultiAlternatives, SafeMIMEMultipart, SafeMIMEText,
forbid_multi_line_headers, make_msgid,
)
from django.core.mail.utils import DNS_NAME, CachedDnsName
from django.utils.module_loading import import_string
__all__ = [
'CachedDnsName', 'DNS_NAME', 'EmailMessage', 'EmailMultiAlternatives',
'SafeMIMEText', 'SafeMIMEMultipart', 'DEFAULT_ATTACHMENT_MIME_TYPE',
'make_msgid', 'BadHeaderError', 'forbid_multi_line_headers',
'get_connection', 'send_mail', 'send_mass_mail', 'mail_admins',
'mail_managers',
]
def get_connection(backend=None, fail_silently=False, **kwds):
"""Load an email backend and return an instance of it.
If backend is None (default), use settings.EMAIL_BACKEND.
Both fail_silently and other keyword arguments are used in the
constructor of the backend.
"""
klass = import_string(backend or settings.EMAIL_BACKEND)
return klass(fail_silently=fail_silently, **kwds)
def send_mail(subject, message, from_email, recipient_list,
fail_silently=False, auth_user=None, auth_password=None,
connection=None, html_message=None):
"""
Easy wrapper for sending a single message to a recipient list. All members
of the recipient list will see the other recipients in the 'To' field.
If from_email is None, use the DEFAULT_FROM_EMAIL setting.
If auth_user is None, use the EMAIL_HOST_USER setting.
If auth_password is None, use the EMAIL_HOST_PASSWORD setting.
Note: The API for this method is frozen. New code wanting to extend the
functionality should use the EmailMessage class directly.
"""
connection = connection or get_connection(
username=auth_user,
password=auth_password,
fail_silently=fail_silently,
)
mail = EmailMultiAlternatives(subject, message, from_email, recipient_list, connection=connection)
if html_message:
mail.attach_alternative(html_message, 'text/html')
return mail.send()
def send_mass_mail(datatuple, fail_silently=False, auth_user=None,
auth_password=None, connection=None):
"""
Given a datatuple of (subject, message, from_email, recipient_list), send
each message to each recipient list. Return the number of emails sent.
If from_email is None, use the DEFAULT_FROM_EMAIL setting.
If auth_user and auth_password are set, use them to log in.
If auth_user is None, use the EMAIL_HOST_USER setting.
If auth_password is None, use the EMAIL_HOST_PASSWORD setting.
Note: The API for this method is frozen. New code wanting to extend the
functionality should use the EmailMessage class directly.
"""
connection = connection or get_connection(
username=auth_user,
password=auth_password,
fail_silently=fail_silently,
)
messages = [
EmailMessage(subject, message, sender, recipient, connection=connection)
for subject, message, sender, recipient in datatuple
]
return connection.send_messages(messages)
def mail_admins(subject, message, fail_silently=False, connection=None,
html_message=None):
"""Send a message to the admins, as defined by the ADMINS setting."""
if not settings.ADMINS:
return
if not all(isinstance(a, (list, tuple)) and len(a) == 2 for a in settings.ADMINS):
raise ValueError('The ADMINS setting must be a list of 2-tuples.')
mail = EmailMultiAlternatives(
'%s%s' % (settings.EMAIL_SUBJECT_PREFIX, subject), message,
settings.SERVER_EMAIL, [a[1] for a in settings.ADMINS],
connection=connection,
)
if html_message:
mail.attach_alternative(html_message, 'text/html')
mail.send(fail_silently=fail_silently)
def mail_managers(subject, message, fail_silently=False, connection=None,
html_message=None):
"""Send a message to the managers, as defined by the MANAGERS setting."""
if not settings.MANAGERS:
return
if not all(isinstance(a, (list, tuple)) and len(a) == 2 for a in settings.MANAGERS):
raise ValueError('The MANAGERS setting must be a list of 2-tuples.')
mail = EmailMultiAlternatives(
'%s%s' % (settings.EMAIL_SUBJECT_PREFIX, subject), message,
settings.SERVER_EMAIL, [a[1] for a in settings.MANAGERS],
connection=connection,
)
if html_message:
mail.attach_alternative(html_message, 'text/html')
mail.send(fail_silently=fail_silently)

View File

@@ -0,0 +1 @@
# Mail backends shipped with Django.

View File

@@ -0,0 +1,59 @@
"""Base email backend class."""
class BaseEmailBackend:
"""
Base class for email backend implementations.
Subclasses must at least overwrite send_messages().
open() and close() can be called indirectly by using a backend object as a
context manager:
with backend as connection:
# do something with connection
pass
"""
def __init__(self, fail_silently=False, **kwargs):
self.fail_silently = fail_silently
def open(self):
"""
Open a network connection.
This method can be overwritten by backend implementations to
open a network connection.
It's up to the backend implementation to track the status of
a network connection if it's needed by the backend.
This method can be called by applications to force a single
network connection to be used when sending mails. See the
send_messages() method of the SMTP backend for a reference
implementation.
The default implementation does nothing.
"""
pass
def close(self):
"""Close a network connection."""
pass
def __enter__(self):
try:
self.open()
except Exception:
self.close()
raise
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def send_messages(self, email_messages):
"""
Send one or more EmailMessage objects and return the number of email
messages sent.
"""
raise NotImplementedError('subclasses of BaseEmailBackend must override send_messages() method')

View File

@@ -0,0 +1,42 @@
"""
Email backend that writes messages to console instead of sending them.
"""
import sys
import threading
from django.core.mail.backends.base import BaseEmailBackend
class EmailBackend(BaseEmailBackend):
def __init__(self, *args, **kwargs):
self.stream = kwargs.pop('stream', sys.stdout)
self._lock = threading.RLock()
super().__init__(*args, **kwargs)
def write_message(self, message):
msg = message.message()
msg_data = msg.as_bytes()
charset = msg.get_charset().get_output_charset() if msg.get_charset() else 'utf-8'
msg_data = msg_data.decode(charset)
self.stream.write('%s\n' % msg_data)
self.stream.write('-' * 79)
self.stream.write('\n')
def send_messages(self, email_messages):
"""Write all messages to the stream in a thread-safe way."""
if not email_messages:
return
msg_count = 0
with self._lock:
try:
stream_created = self.open()
for message in email_messages:
self.write_message(message)
self.stream.flush() # flush after each message
msg_count += 1
if stream_created:
self.close()
except Exception:
if not self.fail_silently:
raise
return msg_count

View File

@@ -0,0 +1,10 @@
"""
Dummy email backend that does nothing.
"""
from django.core.mail.backends.base import BaseEmailBackend
class EmailBackend(BaseEmailBackend):
def send_messages(self, email_messages):
return len(list(email_messages))

View File

@@ -0,0 +1,64 @@
"""Email backend that writes messages to a file."""
import datetime
import os
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.mail.backends.console import (
EmailBackend as ConsoleEmailBackend,
)
class EmailBackend(ConsoleEmailBackend):
def __init__(self, *args, file_path=None, **kwargs):
self._fname = None
if file_path is not None:
self.file_path = file_path
else:
self.file_path = getattr(settings, 'EMAIL_FILE_PATH', None)
self.file_path = os.path.abspath(self.file_path)
try:
os.makedirs(self.file_path, exist_ok=True)
except FileExistsError:
raise ImproperlyConfigured(
'Path for saving email messages exists, but is not a directory: %s' % self.file_path
)
except OSError as err:
raise ImproperlyConfigured(
'Could not create directory for saving email messages: %s (%s)' % (self.file_path, err)
)
# Make sure that self.file_path is writable.
if not os.access(self.file_path, os.W_OK):
raise ImproperlyConfigured('Could not write to directory: %s' % self.file_path)
# Finally, call super().
# Since we're using the console-based backend as a base,
# force the stream to be None, so we don't default to stdout
kwargs['stream'] = None
super().__init__(*args, **kwargs)
def write_message(self, message):
self.stream.write(message.message().as_bytes() + b'\n')
self.stream.write(b'-' * 79)
self.stream.write(b'\n')
def _get_filename(self):
"""Return a unique file name."""
if self._fname is None:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
fname = "%s-%s.log" % (timestamp, abs(id(self)))
self._fname = os.path.join(self.file_path, fname)
return self._fname
def open(self):
if self.stream is None:
self.stream = open(self._get_filename(), 'ab')
return True
return False
def close(self):
try:
if self.stream is not None:
self.stream.close()
finally:
self.stream = None

View File

@@ -0,0 +1,30 @@
"""
Backend for test environment.
"""
from django.core import mail
from django.core.mail.backends.base import BaseEmailBackend
class EmailBackend(BaseEmailBackend):
"""
An email backend for use during test sessions.
The test connection stores email messages in a dummy outbox,
rather than sending them out on the wire.
The dummy outbox is accessible through the outbox instance attribute.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not hasattr(mail, 'outbox'):
mail.outbox = []
def send_messages(self, messages):
"""Redirect messages to the dummy outbox"""
msg_count = 0
for message in messages: # .message() triggers header validation
message.message()
mail.outbox.append(message)
msg_count += 1
return msg_count

View File

@@ -0,0 +1,130 @@
"""SMTP email backend class."""
import smtplib
import ssl
import threading
from django.conf import settings
from django.core.mail.backends.base import BaseEmailBackend
from django.core.mail.message import sanitize_address
from django.core.mail.utils import DNS_NAME
class EmailBackend(BaseEmailBackend):
"""
A wrapper that manages the SMTP network connection.
"""
def __init__(self, host=None, port=None, username=None, password=None,
use_tls=None, fail_silently=False, use_ssl=None, timeout=None,
ssl_keyfile=None, ssl_certfile=None,
**kwargs):
super().__init__(fail_silently=fail_silently)
self.host = host or settings.EMAIL_HOST
self.port = port or settings.EMAIL_PORT
self.username = settings.EMAIL_HOST_USER if username is None else username
self.password = settings.EMAIL_HOST_PASSWORD if password is None else password
self.use_tls = settings.EMAIL_USE_TLS if use_tls is None else use_tls
self.use_ssl = settings.EMAIL_USE_SSL if use_ssl is None else use_ssl
self.timeout = settings.EMAIL_TIMEOUT if timeout is None else timeout
self.ssl_keyfile = settings.EMAIL_SSL_KEYFILE if ssl_keyfile is None else ssl_keyfile
self.ssl_certfile = settings.EMAIL_SSL_CERTFILE if ssl_certfile is None else ssl_certfile
if self.use_ssl and self.use_tls:
raise ValueError(
"EMAIL_USE_TLS/EMAIL_USE_SSL are mutually exclusive, so only set "
"one of those settings to True.")
self.connection = None
self._lock = threading.RLock()
@property
def connection_class(self):
return smtplib.SMTP_SSL if self.use_ssl else smtplib.SMTP
def open(self):
"""
Ensure an open connection to the email server. Return whether or not a
new connection was required (True or False) or None if an exception
passed silently.
"""
if self.connection:
# Nothing to do if the connection is already open.
return False
# If local_hostname is not specified, socket.getfqdn() gets used.
# For performance, we use the cached FQDN for local_hostname.
connection_params = {'local_hostname': DNS_NAME.get_fqdn()}
if self.timeout is not None:
connection_params['timeout'] = self.timeout
if self.use_ssl:
connection_params.update({
'keyfile': self.ssl_keyfile,
'certfile': self.ssl_certfile,
})
try:
self.connection = self.connection_class(self.host, self.port, **connection_params)
# TLS/SSL are mutually exclusive, so only attempt TLS over
# non-secure connections.
if not self.use_ssl and self.use_tls:
self.connection.starttls(keyfile=self.ssl_keyfile, certfile=self.ssl_certfile)
if self.username and self.password:
self.connection.login(self.username, self.password)
return True
except OSError:
if not self.fail_silently:
raise
def close(self):
"""Close the connection to the email server."""
if self.connection is None:
return
try:
try:
self.connection.quit()
except (ssl.SSLError, smtplib.SMTPServerDisconnected):
# This happens when calling quit() on a TLS connection
# sometimes, or when the connection was already disconnected
# by the server.
self.connection.close()
except smtplib.SMTPException:
if self.fail_silently:
return
raise
finally:
self.connection = None
def send_messages(self, email_messages):
"""
Send one or more EmailMessage objects and return the number of email
messages sent.
"""
if not email_messages:
return 0
with self._lock:
new_conn_created = self.open()
if not self.connection or new_conn_created is None:
# We failed silently on open().
# Trying to send would be pointless.
return 0
num_sent = 0
for message in email_messages:
sent = self._send(message)
if sent:
num_sent += 1
if new_conn_created:
self.close()
return num_sent
def _send(self, email_message):
"""A helper method that does the actual sending."""
if not email_message.recipients():
return False
encoding = email_message.encoding or settings.DEFAULT_CHARSET
from_email = sanitize_address(email_message.from_email, encoding)
recipients = [sanitize_address(addr, encoding) for addr in email_message.recipients()]
message = email_message.message()
try:
self.connection.sendmail(from_email, recipients, message.as_bytes(linesep='\r\n'))
except smtplib.SMTPException:
if not self.fail_silently:
raise
return False
return True

View File

@@ -0,0 +1,451 @@
import mimetypes
from email import (
charset as Charset, encoders as Encoders, generator, message_from_string,
)
from email.errors import HeaderParseError
from email.header import Header
from email.headerregistry import Address, parser
from email.message import Message
from email.mime.base import MIMEBase
from email.mime.message import MIMEMessage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formataddr, formatdate, getaddresses, make_msgid
from io import BytesIO, StringIO
from pathlib import Path
from django.conf import settings
from django.core.mail.utils import DNS_NAME
from django.utils.encoding import force_str, punycode
# Don't BASE64-encode UTF-8 messages so that we avoid unwanted attention from
# some spam filters.
utf8_charset = Charset.Charset('utf-8')
utf8_charset.body_encoding = None # Python defaults to BASE64
utf8_charset_qp = Charset.Charset('utf-8')
utf8_charset_qp.body_encoding = Charset.QP
# Default MIME type to use on attachments (if it is not explicitly given
# and cannot be guessed).
DEFAULT_ATTACHMENT_MIME_TYPE = 'application/octet-stream'
RFC5322_EMAIL_LINE_LENGTH_LIMIT = 998
class BadHeaderError(ValueError):
pass
# Header names that contain structured address data (RFC #5322)
ADDRESS_HEADERS = {
'from',
'sender',
'reply-to',
'to',
'cc',
'bcc',
'resent-from',
'resent-sender',
'resent-to',
'resent-cc',
'resent-bcc',
}
def forbid_multi_line_headers(name, val, encoding):
"""Forbid multi-line headers to prevent header injection."""
encoding = encoding or settings.DEFAULT_CHARSET
val = str(val) # val may be lazy
if '\n' in val or '\r' in val:
raise BadHeaderError("Header values can't contain newlines (got %r for header %r)" % (val, name))
try:
val.encode('ascii')
except UnicodeEncodeError:
if name.lower() in ADDRESS_HEADERS:
val = ', '.join(sanitize_address(addr, encoding) for addr in getaddresses((val,)))
else:
val = Header(val, encoding).encode()
else:
if name.lower() == 'subject':
val = Header(val).encode()
return name, val
def sanitize_address(addr, encoding):
"""
Format a pair of (name, address) or an email address string.
"""
address = None
if not isinstance(addr, tuple):
addr = force_str(addr)
try:
token, rest = parser.get_mailbox(addr)
except (HeaderParseError, ValueError, IndexError):
raise ValueError('Invalid address "%s"' % addr)
else:
if rest:
# The entire email address must be parsed.
raise ValueError(
'Invalid address; only %s could be parsed from "%s"'
% (token, addr)
)
nm = token.display_name or ''
localpart = token.local_part
domain = token.domain or ''
else:
nm, address = addr
localpart, domain = address.rsplit('@', 1)
address_parts = nm + localpart + domain
if '\n' in address_parts or '\r' in address_parts:
raise ValueError('Invalid address; address parts cannot contain newlines.')
# Avoid UTF-8 encode, if it's possible.
try:
nm.encode('ascii')
nm = Header(nm).encode()
except UnicodeEncodeError:
nm = Header(nm, encoding).encode()
try:
localpart.encode('ascii')
except UnicodeEncodeError:
localpart = Header(localpart, encoding).encode()
domain = punycode(domain)
parsed_address = Address(username=localpart, domain=domain)
return formataddr((nm, parsed_address.addr_spec))
class MIMEMixin:
def as_string(self, unixfrom=False, linesep='\n'):
"""Return the entire formatted message as a string.
Optional `unixfrom' when True, means include the Unix From_ envelope
header.
This overrides the default as_string() implementation to not mangle
lines that begin with 'From '. See bug #13433 for details.
"""
fp = StringIO()
g = generator.Generator(fp, mangle_from_=False)
g.flatten(self, unixfrom=unixfrom, linesep=linesep)
return fp.getvalue()
def as_bytes(self, unixfrom=False, linesep='\n'):
"""Return the entire formatted message as bytes.
Optional `unixfrom' when True, means include the Unix From_ envelope
header.
This overrides the default as_bytes() implementation to not mangle
lines that begin with 'From '. See bug #13433 for details.
"""
fp = BytesIO()
g = generator.BytesGenerator(fp, mangle_from_=False)
g.flatten(self, unixfrom=unixfrom, linesep=linesep)
return fp.getvalue()
class SafeMIMEMessage(MIMEMixin, MIMEMessage):
def __setitem__(self, name, val):
# message/rfc822 attachments must be ASCII
name, val = forbid_multi_line_headers(name, val, 'ascii')
MIMEMessage.__setitem__(self, name, val)
class SafeMIMEText(MIMEMixin, MIMEText):
def __init__(self, _text, _subtype='plain', _charset=None):
self.encoding = _charset
MIMEText.__init__(self, _text, _subtype=_subtype, _charset=_charset)
def __setitem__(self, name, val):
name, val = forbid_multi_line_headers(name, val, self.encoding)
MIMEText.__setitem__(self, name, val)
def set_payload(self, payload, charset=None):
if charset == 'utf-8' and not isinstance(charset, Charset.Charset):
has_long_lines = any(
len(line.encode()) > RFC5322_EMAIL_LINE_LENGTH_LIMIT
for line in payload.splitlines()
)
# Quoted-Printable encoding has the side effect of shortening long
# lines, if any (#22561).
charset = utf8_charset_qp if has_long_lines else utf8_charset
MIMEText.set_payload(self, payload, charset=charset)
class SafeMIMEMultipart(MIMEMixin, MIMEMultipart):
def __init__(self, _subtype='mixed', boundary=None, _subparts=None, encoding=None, **_params):
self.encoding = encoding
MIMEMultipart.__init__(self, _subtype, boundary, _subparts, **_params)
def __setitem__(self, name, val):
name, val = forbid_multi_line_headers(name, val, self.encoding)
MIMEMultipart.__setitem__(self, name, val)
class EmailMessage:
"""A container for email information."""
content_subtype = 'plain'
mixed_subtype = 'mixed'
encoding = None # None => use settings default
def __init__(self, subject='', body='', from_email=None, to=None, bcc=None,
connection=None, attachments=None, headers=None, cc=None,
reply_to=None):
"""
Initialize a single email message (which can be sent to multiple
recipients).
"""
if to:
if isinstance(to, str):
raise TypeError('"to" argument must be a list or tuple')
self.to = list(to)
else:
self.to = []
if cc:
if isinstance(cc, str):
raise TypeError('"cc" argument must be a list or tuple')
self.cc = list(cc)
else:
self.cc = []
if bcc:
if isinstance(bcc, str):
raise TypeError('"bcc" argument must be a list or tuple')
self.bcc = list(bcc)
else:
self.bcc = []
if reply_to:
if isinstance(reply_to, str):
raise TypeError('"reply_to" argument must be a list or tuple')
self.reply_to = list(reply_to)
else:
self.reply_to = []
self.from_email = from_email or settings.DEFAULT_FROM_EMAIL
self.subject = subject
self.body = body or ''
self.attachments = []
if attachments:
for attachment in attachments:
if isinstance(attachment, MIMEBase):
self.attach(attachment)
else:
self.attach(*attachment)
self.extra_headers = headers or {}
self.connection = connection
def get_connection(self, fail_silently=False):
from django.core.mail import get_connection
if not self.connection:
self.connection = get_connection(fail_silently=fail_silently)
return self.connection
def message(self):
encoding = self.encoding or settings.DEFAULT_CHARSET
msg = SafeMIMEText(self.body, self.content_subtype, encoding)
msg = self._create_message(msg)
msg['Subject'] = self.subject
msg['From'] = self.extra_headers.get('From', self.from_email)
self._set_list_header_if_not_empty(msg, 'To', self.to)
self._set_list_header_if_not_empty(msg, 'Cc', self.cc)
self._set_list_header_if_not_empty(msg, 'Reply-To', self.reply_to)
# Email header names are case-insensitive (RFC 2045), so we have to
# accommodate that when doing comparisons.
header_names = [key.lower() for key in self.extra_headers]
if 'date' not in header_names:
# formatdate() uses stdlib methods to format the date, which use
# the stdlib/OS concept of a timezone, however, Django sets the
# TZ environment variable based on the TIME_ZONE setting which
# will get picked up by formatdate().
msg['Date'] = formatdate(localtime=settings.EMAIL_USE_LOCALTIME)
if 'message-id' not in header_names:
# Use cached DNS_NAME for performance
msg['Message-ID'] = make_msgid(domain=DNS_NAME)
for name, value in self.extra_headers.items():
if name.lower() != 'from': # From is already handled
msg[name] = value
return msg
def recipients(self):
"""
Return a list of all recipients of the email (includes direct
addressees as well as Cc and Bcc entries).
"""
return [email for email in (self.to + self.cc + self.bcc) if email]
def send(self, fail_silently=False):
"""Send the email message."""
if not self.recipients():
# Don't bother creating the network connection if there's nobody to
# send to.
return 0
return self.get_connection(fail_silently).send_messages([self])
def attach(self, filename=None, content=None, mimetype=None):
"""
Attach a file with the given filename and content. The filename can
be omitted and the mimetype is guessed, if not provided.
If the first parameter is a MIMEBase subclass, insert it directly
into the resulting message attachments.
For a text/* mimetype (guessed or specified), when a bytes object is
specified as content, decode it as UTF-8. If that fails, set the
mimetype to DEFAULT_ATTACHMENT_MIME_TYPE and don't decode the content.
"""
if isinstance(filename, MIMEBase):
if content is not None or mimetype is not None:
raise ValueError(
'content and mimetype must not be given when a MIMEBase '
'instance is provided.'
)
self.attachments.append(filename)
elif content is None:
raise ValueError('content must be provided.')
else:
mimetype = mimetype or mimetypes.guess_type(filename)[0] or DEFAULT_ATTACHMENT_MIME_TYPE
basetype, subtype = mimetype.split('/', 1)
if basetype == 'text':
if isinstance(content, bytes):
try:
content = content.decode()
except UnicodeDecodeError:
# If mimetype suggests the file is text but it's
# actually binary, read() raises a UnicodeDecodeError.
mimetype = DEFAULT_ATTACHMENT_MIME_TYPE
self.attachments.append((filename, content, mimetype))
def attach_file(self, path, mimetype=None):
"""
Attach a file from the filesystem.
Set the mimetype to DEFAULT_ATTACHMENT_MIME_TYPE if it isn't specified
and cannot be guessed.
For a text/* mimetype (guessed or specified), decode the file's content
as UTF-8. If that fails, set the mimetype to
DEFAULT_ATTACHMENT_MIME_TYPE and don't decode the content.
"""
path = Path(path)
with path.open('rb') as file:
content = file.read()
self.attach(path.name, content, mimetype)
def _create_message(self, msg):
return self._create_attachments(msg)
def _create_attachments(self, msg):
if self.attachments:
encoding = self.encoding or settings.DEFAULT_CHARSET
body_msg = msg
msg = SafeMIMEMultipart(_subtype=self.mixed_subtype, encoding=encoding)
if self.body or body_msg.is_multipart():
msg.attach(body_msg)
for attachment in self.attachments:
if isinstance(attachment, MIMEBase):
msg.attach(attachment)
else:
msg.attach(self._create_attachment(*attachment))
return msg
def _create_mime_attachment(self, content, mimetype):
"""
Convert the content, mimetype pair into a MIME attachment object.
If the mimetype is message/rfc822, content may be an
email.Message or EmailMessage object, as well as a str.
"""
basetype, subtype = mimetype.split('/', 1)
if basetype == 'text':
encoding = self.encoding or settings.DEFAULT_CHARSET
attachment = SafeMIMEText(content, subtype, encoding)
elif basetype == 'message' and subtype == 'rfc822':
# Bug #18967: per RFC2046 s5.2.1, message/rfc822 attachments
# must not be base64 encoded.
if isinstance(content, EmailMessage):
# convert content into an email.Message first
content = content.message()
elif not isinstance(content, Message):
# For compatibility with existing code, parse the message
# into an email.Message object if it is not one already.
content = message_from_string(force_str(content))
attachment = SafeMIMEMessage(content, subtype)
else:
# Encode non-text attachments with base64.
attachment = MIMEBase(basetype, subtype)
attachment.set_payload(content)
Encoders.encode_base64(attachment)
return attachment
def _create_attachment(self, filename, content, mimetype=None):
"""
Convert the filename, content, mimetype triple into a MIME attachment
object.
"""
attachment = self._create_mime_attachment(content, mimetype)
if filename:
try:
filename.encode('ascii')
except UnicodeEncodeError:
filename = ('utf-8', '', filename)
attachment.add_header('Content-Disposition', 'attachment', filename=filename)
return attachment
def _set_list_header_if_not_empty(self, msg, header, values):
"""
Set msg's header, either from self.extra_headers, if present, or from
the values argument.
"""
if values:
try:
value = self.extra_headers[header]
except KeyError:
value = ', '.join(str(v) for v in values)
msg[header] = value
class EmailMultiAlternatives(EmailMessage):
"""
A version of EmailMessage that makes it easy to send multipart/alternative
messages. For example, including text and HTML versions of the text is
made easier.
"""
alternative_subtype = 'alternative'
def __init__(self, subject='', body='', from_email=None, to=None, bcc=None,
connection=None, attachments=None, headers=None, alternatives=None,
cc=None, reply_to=None):
"""
Initialize a single email message (which can be sent to multiple
recipients).
"""
super().__init__(
subject, body, from_email, to, bcc, connection, attachments,
headers, cc, reply_to,
)
self.alternatives = alternatives or []
def attach_alternative(self, content, mimetype):
"""Attach an alternative content representation."""
if content is None or mimetype is None:
raise ValueError('Both content and mimetype must be provided.')
self.alternatives.append((content, mimetype))
def _create_message(self, msg):
return self._create_attachments(self._create_alternatives(msg))
def _create_alternatives(self, msg):
encoding = self.encoding or settings.DEFAULT_CHARSET
if self.alternatives:
body_msg = msg
msg = SafeMIMEMultipart(_subtype=self.alternative_subtype, encoding=encoding)
if self.body:
msg.attach(body_msg)
for alternative in self.alternatives:
msg.attach(self._create_mime_attachment(*alternative))
return msg

View File

@@ -0,0 +1,22 @@
"""
Email message and email sending related helper functions.
"""
import socket
from django.utils.encoding import punycode
# Cache the hostname, but do it lazily: socket.getfqdn() can take a couple of
# seconds, which slows down the restart of the server.
class CachedDnsName:
def __str__(self):
return self.get_fqdn()
def get_fqdn(self):
if not hasattr(self, '_fqdn'):
self._fqdn = punycode(socket.getfqdn())
return self._fqdn
DNS_NAME = CachedDnsName()

View File

@@ -0,0 +1,425 @@
import functools
import os
import pkgutil
import sys
from argparse import (
_AppendConstAction, _CountAction, _StoreConstAction, _SubParsersAction,
)
from collections import defaultdict
from difflib import get_close_matches
from importlib import import_module
import django
from django.apps import apps
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.management.base import (
BaseCommand, CommandError, CommandParser, handle_default_options,
)
from django.core.management.color import color_style
from django.utils import autoreload
def find_commands(management_dir):
"""
Given a path to a management directory, return a list of all the command
names that are available.
"""
command_dir = os.path.join(management_dir, 'commands')
return [name for _, name, is_pkg in pkgutil.iter_modules([command_dir])
if not is_pkg and not name.startswith('_')]
def load_command_class(app_name, name):
"""
Given a command name and an application name, return the Command
class instance. Allow all errors raised by the import process
(ImportError, AttributeError) to propagate.
"""
module = import_module('%s.management.commands.%s' % (app_name, name))
return module.Command()
@functools.lru_cache(maxsize=None)
def get_commands():
"""
Return a dictionary mapping command names to their callback applications.
Look for a management.commands package in django.core, and in each
installed application -- if a commands package exists, register all
commands in that package.
Core commands are always included. If a settings module has been
specified, also include user-defined commands.
The dictionary is in the format {command_name: app_name}. Key-value
pairs from this dictionary can then be used in calls to
load_command_class(app_name, command_name)
If a specific version of a command must be loaded (e.g., with the
startapp command), the instantiated module can be placed in the
dictionary in place of the application name.
The dictionary is cached on the first call and reused on subsequent
calls.
"""
commands = {name: 'django.core' for name in find_commands(__path__[0])}
if not settings.configured:
return commands
for app_config in reversed(list(apps.get_app_configs())):
path = os.path.join(app_config.path, 'management')
commands.update({name: app_config.name for name in find_commands(path)})
return commands
def call_command(command_name, *args, **options):
"""
Call the given command, with the given options and args/kwargs.
This is the primary API you should use for calling specific commands.
`command_name` may be a string or a command object. Using a string is
preferred unless the command object is required for further processing or
testing.
Some examples:
call_command('migrate')
call_command('shell', plain=True)
call_command('sqlmigrate', 'myapp')
from django.core.management.commands import flush
cmd = flush.Command()
call_command(cmd, verbosity=0, interactive=False)
# Do something with cmd ...
"""
if isinstance(command_name, BaseCommand):
# Command object passed in.
command = command_name
command_name = command.__class__.__module__.split('.')[-1]
else:
# Load the command object by name.
try:
app_name = get_commands()[command_name]
except KeyError:
raise CommandError("Unknown command: %r" % command_name)
if isinstance(app_name, BaseCommand):
# If the command is already loaded, use it directly.
command = app_name
else:
command = load_command_class(app_name, command_name)
# Simulate argument parsing to get the option defaults (see #10080 for details).
parser = command.create_parser('', command_name)
# Use the `dest` option name from the parser option
opt_mapping = {
min(s_opt.option_strings).lstrip('-').replace('-', '_'): s_opt.dest
for s_opt in parser._actions if s_opt.option_strings
}
arg_options = {opt_mapping.get(key, key): value for key, value in options.items()}
parse_args = []
for arg in args:
if isinstance(arg, (list, tuple)):
parse_args += map(str, arg)
else:
parse_args.append(str(arg))
def get_actions(parser):
# Parser actions and actions from sub-parser choices.
for opt in parser._actions:
if isinstance(opt, _SubParsersAction):
for sub_opt in opt.choices.values():
yield from get_actions(sub_opt)
else:
yield opt
parser_actions = list(get_actions(parser))
mutually_exclusive_required_options = {
opt
for group in parser._mutually_exclusive_groups
for opt in group._group_actions if group.required
}
# Any required arguments which are passed in via **options must be passed
# to parse_args().
for opt in parser_actions:
if (
opt.dest in options and
(opt.required or opt in mutually_exclusive_required_options)
):
opt_dest_count = sum(v == opt.dest for v in opt_mapping.values())
if opt_dest_count > 1:
raise TypeError(
f'Cannot pass the dest {opt.dest!r} that matches multiple '
f'arguments via **options.'
)
parse_args.append(min(opt.option_strings))
if isinstance(opt, (_AppendConstAction, _CountAction, _StoreConstAction)):
continue
value = arg_options[opt.dest]
if isinstance(value, (list, tuple)):
parse_args += map(str, value)
else:
parse_args.append(str(value))
defaults = parser.parse_args(args=parse_args)
defaults = dict(defaults._get_kwargs(), **arg_options)
# Raise an error if any unknown options were passed.
stealth_options = set(command.base_stealth_options + command.stealth_options)
dest_parameters = {action.dest for action in parser_actions}
valid_options = (dest_parameters | stealth_options).union(opt_mapping)
unknown_options = set(options) - valid_options
if unknown_options:
raise TypeError(
"Unknown option(s) for %s command: %s. "
"Valid options are: %s." % (
command_name,
', '.join(sorted(unknown_options)),
', '.join(sorted(valid_options)),
)
)
# Move positional args out of options to mimic legacy optparse
args = defaults.pop('args', ())
if 'skip_checks' not in options:
defaults['skip_checks'] = True
return command.execute(*args, **defaults)
class ManagementUtility:
"""
Encapsulate the logic of the django-admin and manage.py utilities.
"""
def __init__(self, argv=None):
self.argv = argv or sys.argv[:]
self.prog_name = os.path.basename(self.argv[0])
if self.prog_name == '__main__.py':
self.prog_name = 'python -m django'
self.settings_exception = None
def main_help_text(self, commands_only=False):
"""Return the script's main help text, as a string."""
if commands_only:
usage = sorted(get_commands())
else:
usage = [
"",
"Type '%s help <subcommand>' for help on a specific subcommand." % self.prog_name,
"",
"Available subcommands:",
]
commands_dict = defaultdict(lambda: [])
for name, app in get_commands().items():
if app == 'django.core':
app = 'django'
else:
app = app.rpartition('.')[-1]
commands_dict[app].append(name)
style = color_style()
for app in sorted(commands_dict):
usage.append("")
usage.append(style.NOTICE("[%s]" % app))
for name in sorted(commands_dict[app]):
usage.append(" %s" % name)
# Output an extra note if settings are not properly configured
if self.settings_exception is not None:
usage.append(style.NOTICE(
"Note that only Django core commands are listed "
"as settings are not properly configured (error: %s)."
% self.settings_exception))
return '\n'.join(usage)
def fetch_command(self, subcommand):
"""
Try to fetch the given subcommand, printing a message with the
appropriate command called from the command line (usually
"django-admin" or "manage.py") if it can't be found.
"""
# Get commands outside of try block to prevent swallowing exceptions
commands = get_commands()
try:
app_name = commands[subcommand]
except KeyError:
if os.environ.get('DJANGO_SETTINGS_MODULE'):
# If `subcommand` is missing due to misconfigured settings, the
# following line will retrigger an ImproperlyConfigured exception
# (get_commands() swallows the original one) so the user is
# informed about it.
settings.INSTALLED_APPS
elif not settings.configured:
sys.stderr.write("No Django settings specified.\n")
possible_matches = get_close_matches(subcommand, commands)
sys.stderr.write('Unknown command: %r' % subcommand)
if possible_matches:
sys.stderr.write('. Did you mean %s?' % possible_matches[0])
sys.stderr.write("\nType '%s help' for usage.\n" % self.prog_name)
sys.exit(1)
if isinstance(app_name, BaseCommand):
# If the command is already loaded, use it directly.
klass = app_name
else:
klass = load_command_class(app_name, subcommand)
return klass
def autocomplete(self):
"""
Output completion suggestions for BASH.
The output of this function is passed to BASH's `COMREPLY` variable and
treated as completion suggestions. `COMREPLY` expects a space
separated string as the result.
The `COMP_WORDS` and `COMP_CWORD` BASH environment variables are used
to get information about the cli input. Please refer to the BASH
man-page for more information about this variables.
Subcommand options are saved as pairs. A pair consists of
the long option string (e.g. '--exclude') and a boolean
value indicating if the option requires arguments. When printing to
stdout, an equal sign is appended to options which require arguments.
Note: If debugging this function, it is recommended to write the debug
output in a separate file. Otherwise the debug output will be treated
and formatted as potential completion suggestions.
"""
# Don't complete if user hasn't sourced bash_completion file.
if 'DJANGO_AUTO_COMPLETE' not in os.environ:
return
cwords = os.environ['COMP_WORDS'].split()[1:]
cword = int(os.environ['COMP_CWORD'])
try:
curr = cwords[cword - 1]
except IndexError:
curr = ''
subcommands = [*get_commands(), 'help']
options = [('--help', False)]
# subcommand
if cword == 1:
print(' '.join(sorted(filter(lambda x: x.startswith(curr), subcommands))))
# subcommand options
# special case: the 'help' subcommand has no options
elif cwords[0] in subcommands and cwords[0] != 'help':
subcommand_cls = self.fetch_command(cwords[0])
# special case: add the names of installed apps to options
if cwords[0] in ('dumpdata', 'sqlmigrate', 'sqlsequencereset', 'test'):
try:
app_configs = apps.get_app_configs()
# Get the last part of the dotted path as the app name.
options.extend((app_config.label, 0) for app_config in app_configs)
except ImportError:
# Fail silently if DJANGO_SETTINGS_MODULE isn't set. The
# user will find out once they execute the command.
pass
parser = subcommand_cls.create_parser('', cwords[0])
options.extend(
(min(s_opt.option_strings), s_opt.nargs != 0)
for s_opt in parser._actions if s_opt.option_strings
)
# filter out previously specified options from available options
prev_opts = {x.split('=')[0] for x in cwords[1:cword - 1]}
options = (opt for opt in options if opt[0] not in prev_opts)
# filter options by current input
options = sorted((k, v) for k, v in options if k.startswith(curr))
for opt_label, require_arg in options:
# append '=' to options which require args
if require_arg:
opt_label += '='
print(opt_label)
# Exit code of the bash completion function is never passed back to
# the user, so it's safe to always exit with 0.
# For more details see #25420.
sys.exit(0)
def execute(self):
"""
Given the command-line arguments, figure out which subcommand is being
run, create a parser appropriate to that command, and run it.
"""
try:
subcommand = self.argv[1]
except IndexError:
subcommand = 'help' # Display help if no arguments were given.
# Preprocess options to extract --settings and --pythonpath.
# These options could affect the commands that are available, so they
# must be processed early.
parser = CommandParser(
prog=self.prog_name,
usage='%(prog)s subcommand [options] [args]',
add_help=False,
allow_abbrev=False,
)
parser.add_argument('--settings')
parser.add_argument('--pythonpath')
parser.add_argument('args', nargs='*') # catch-all
try:
options, args = parser.parse_known_args(self.argv[2:])
handle_default_options(options)
except CommandError:
pass # Ignore any option errors at this point.
try:
settings.INSTALLED_APPS
except ImproperlyConfigured as exc:
self.settings_exception = exc
except ImportError as exc:
self.settings_exception = exc
if settings.configured:
# Start the auto-reloading dev server even if the code is broken.
# The hardcoded condition is a code smell but we can't rely on a
# flag on the command class because we haven't located it yet.
if subcommand == 'runserver' and '--noreload' not in self.argv:
try:
autoreload.check_errors(django.setup)()
except Exception:
# The exception will be raised later in the child process
# started by the autoreloader. Pretend it didn't happen by
# loading an empty list of applications.
apps.all_models = defaultdict(dict)
apps.app_configs = {}
apps.apps_ready = apps.models_ready = apps.ready = True
# Remove options not compatible with the built-in runserver
# (e.g. options for the contrib.staticfiles' runserver).
# Changes here require manually testing as described in
# #27522.
_parser = self.fetch_command('runserver').create_parser('django', 'runserver')
_options, _args = _parser.parse_known_args(self.argv[2:])
for _arg in _args:
self.argv.remove(_arg)
# In all other cases, django.setup() is required to succeed.
else:
django.setup()
self.autocomplete()
if subcommand == 'help':
if '--commands' in args:
sys.stdout.write(self.main_help_text(commands_only=True) + '\n')
elif not options.args:
sys.stdout.write(self.main_help_text() + '\n')
else:
self.fetch_command(options.args[0]).print_help(self.prog_name, options.args[0])
# Special-cases: We want 'django-admin --version' and
# 'django-admin --help' to work, for backwards compatibility.
elif subcommand == 'version' or self.argv[1:] == ['--version']:
sys.stdout.write(django.get_version() + '\n')
elif self.argv[1:] in (['--help'], ['-h']):
sys.stdout.write(self.main_help_text() + '\n')
else:
self.fetch_command(subcommand).run_from_argv(self.argv)
def execute_from_command_line(argv=None):
"""Run a ManagementUtility."""
utility = ManagementUtility(argv)
utility.execute()

View File

@@ -0,0 +1,600 @@
"""
Base classes for writing management commands (named commands which can
be executed through ``django-admin`` or ``manage.py``).
"""
import argparse
import os
import sys
import warnings
from argparse import ArgumentParser, HelpFormatter
from io import TextIOBase
import django
from django.core import checks
from django.core.exceptions import ImproperlyConfigured
from django.core.management.color import color_style, no_style
from django.db import DEFAULT_DB_ALIAS, connections
from django.utils.deprecation import RemovedInDjango41Warning
ALL_CHECKS = '__all__'
class CommandError(Exception):
"""
Exception class indicating a problem while executing a management
command.
If this exception is raised during the execution of a management
command, it will be caught and turned into a nicely-printed error
message to the appropriate output stream (i.e., stderr); as a
result, raising this exception (with a sensible description of the
error) is the preferred way to indicate that something has gone
wrong in the execution of a command.
"""
def __init__(self, *args, returncode=1, **kwargs):
self.returncode = returncode
super().__init__(*args, **kwargs)
class SystemCheckError(CommandError):
"""
The system check framework detected unrecoverable errors.
"""
pass
class CommandParser(ArgumentParser):
"""
Customized ArgumentParser class to improve some error messages and prevent
SystemExit in several occasions, as SystemExit is unacceptable when a
command is called programmatically.
"""
def __init__(self, *, missing_args_message=None, called_from_command_line=None, **kwargs):
self.missing_args_message = missing_args_message
self.called_from_command_line = called_from_command_line
super().__init__(**kwargs)
def parse_args(self, args=None, namespace=None):
# Catch missing argument for a better error message
if (self.missing_args_message and
not (args or any(not arg.startswith('-') for arg in args))):
self.error(self.missing_args_message)
return super().parse_args(args, namespace)
def error(self, message):
if self.called_from_command_line:
super().error(message)
else:
raise CommandError("Error: %s" % message)
def handle_default_options(options):
"""
Include any default options that all commands should accept here
so that ManagementUtility can handle them before searching for
user commands.
"""
if options.settings:
os.environ['DJANGO_SETTINGS_MODULE'] = options.settings
if options.pythonpath:
sys.path.insert(0, options.pythonpath)
def no_translations(handle_func):
"""Decorator that forces a command to run with translations deactivated."""
def wrapped(*args, **kwargs):
from django.utils import translation
saved_locale = translation.get_language()
translation.deactivate_all()
try:
res = handle_func(*args, **kwargs)
finally:
if saved_locale is not None:
translation.activate(saved_locale)
return res
return wrapped
class DjangoHelpFormatter(HelpFormatter):
"""
Customized formatter so that command-specific arguments appear in the
--help output before arguments common to all commands.
"""
show_last = {
'--version', '--verbosity', '--traceback', '--settings', '--pythonpath',
'--no-color', '--force-color', '--skip-checks',
}
def _reordered_actions(self, actions):
return sorted(
actions,
key=lambda a: set(a.option_strings) & self.show_last != set()
)
def add_usage(self, usage, actions, *args, **kwargs):
super().add_usage(usage, self._reordered_actions(actions), *args, **kwargs)
def add_arguments(self, actions):
super().add_arguments(self._reordered_actions(actions))
class OutputWrapper(TextIOBase):
"""
Wrapper around stdout/stderr
"""
@property
def style_func(self):
return self._style_func
@style_func.setter
def style_func(self, style_func):
if style_func and self.isatty():
self._style_func = style_func
else:
self._style_func = lambda x: x
def __init__(self, out, ending='\n'):
self._out = out
self.style_func = None
self.ending = ending
def __getattr__(self, name):
return getattr(self._out, name)
def flush(self):
if hasattr(self._out, 'flush'):
self._out.flush()
def isatty(self):
return hasattr(self._out, 'isatty') and self._out.isatty()
def write(self, msg='', style_func=None, ending=None):
ending = self.ending if ending is None else ending
if ending and not msg.endswith(ending):
msg += ending
style_func = style_func or self.style_func
self._out.write(style_func(msg))
class BaseCommand:
"""
The base class from which all management commands ultimately
derive.
Use this class if you want access to all of the mechanisms which
parse the command-line arguments and work out what code to call in
response; if you don't need to change any of that behavior,
consider using one of the subclasses defined in this file.
If you are interested in overriding/customizing various aspects of
the command-parsing and -execution behavior, the normal flow works
as follows:
1. ``django-admin`` or ``manage.py`` loads the command class
and calls its ``run_from_argv()`` method.
2. The ``run_from_argv()`` method calls ``create_parser()`` to get
an ``ArgumentParser`` for the arguments, parses them, performs
any environment changes requested by options like
``pythonpath``, and then calls the ``execute()`` method,
passing the parsed arguments.
3. The ``execute()`` method attempts to carry out the command by
calling the ``handle()`` method with the parsed arguments; any
output produced by ``handle()`` will be printed to standard
output and, if the command is intended to produce a block of
SQL statements, will be wrapped in ``BEGIN`` and ``COMMIT``.
4. If ``handle()`` or ``execute()`` raised any exception (e.g.
``CommandError``), ``run_from_argv()`` will instead print an error
message to ``stderr``.
Thus, the ``handle()`` method is typically the starting point for
subclasses; many built-in commands and command types either place
all of their logic in ``handle()``, or perform some additional
parsing work in ``handle()`` and then delegate from it to more
specialized methods as needed.
Several attributes affect behavior at various steps along the way:
``help``
A short description of the command, which will be printed in
help messages.
``output_transaction``
A boolean indicating whether the command outputs SQL
statements; if ``True``, the output will automatically be
wrapped with ``BEGIN;`` and ``COMMIT;``. Default value is
``False``.
``requires_migrations_checks``
A boolean; if ``True``, the command prints a warning if the set of
migrations on disk don't match the migrations in the database.
``requires_system_checks``
A list or tuple of tags, e.g. [Tags.staticfiles, Tags.models]. System
checks registered in the chosen tags will be checked for errors prior
to executing the command. The value '__all__' can be used to specify
that all system checks should be performed. Default value is '__all__'.
To validate an individual application's models
rather than all applications' models, call
``self.check(app_configs)`` from ``handle()``, where ``app_configs``
is the list of application's configuration provided by the
app registry.
``stealth_options``
A tuple of any options the command uses which aren't defined by the
argument parser.
"""
# Metadata about this command.
help = ''
# Configuration shortcuts that alter various logic.
_called_from_command_line = False
output_transaction = False # Whether to wrap the output in a "BEGIN; COMMIT;"
requires_migrations_checks = False
requires_system_checks = '__all__'
# Arguments, common to all commands, which aren't defined by the argument
# parser.
base_stealth_options = ('stderr', 'stdout')
# Command-specific options not defined by the argument parser.
stealth_options = ()
suppressed_base_arguments = set()
def __init__(self, stdout=None, stderr=None, no_color=False, force_color=False):
self.stdout = OutputWrapper(stdout or sys.stdout)
self.stderr = OutputWrapper(stderr or sys.stderr)
if no_color and force_color:
raise CommandError("'no_color' and 'force_color' can't be used together.")
if no_color:
self.style = no_style()
else:
self.style = color_style(force_color)
self.stderr.style_func = self.style.ERROR
if self.requires_system_checks in [False, True]:
warnings.warn(
"Using a boolean value for requires_system_checks is "
"deprecated. Use '__all__' instead of True, and [] (an empty "
"list) instead of False.",
RemovedInDjango41Warning,
)
self.requires_system_checks = ALL_CHECKS if self.requires_system_checks else []
if (
not isinstance(self.requires_system_checks, (list, tuple)) and
self.requires_system_checks != ALL_CHECKS
):
raise TypeError('requires_system_checks must be a list or tuple.')
def get_version(self):
"""
Return the Django version, which should be correct for all built-in
Django commands. User-supplied commands can override this method to
return their own version.
"""
return django.get_version()
def create_parser(self, prog_name, subcommand, **kwargs):
"""
Create and return the ``ArgumentParser`` which will be used to
parse the arguments to this command.
"""
parser = CommandParser(
prog='%s %s' % (os.path.basename(prog_name), subcommand),
description=self.help or None,
formatter_class=DjangoHelpFormatter,
missing_args_message=getattr(self, 'missing_args_message', None),
called_from_command_line=getattr(self, '_called_from_command_line', None),
**kwargs
)
self.add_base_argument(
parser, '--version', action='version', version=self.get_version(),
help="Show program's version number and exit.",
)
self.add_base_argument(
parser, '-v', '--verbosity', default=1,
type=int, choices=[0, 1, 2, 3],
help='Verbosity level; 0=minimal output, 1=normal output, 2=verbose output, 3=very verbose output',
)
self.add_base_argument(
parser, '--settings',
help=(
'The Python path to a settings module, e.g. '
'"myproject.settings.main". If this isn\'t provided, the '
'DJANGO_SETTINGS_MODULE environment variable will be used.'
),
)
self.add_base_argument(
parser, '--pythonpath',
help='A directory to add to the Python path, e.g. "/home/djangoprojects/myproject".',
)
self.add_base_argument(
parser, '--traceback', action='store_true',
help='Raise on CommandError exceptions.',
)
self.add_base_argument(
parser, '--no-color', action='store_true',
help="Don't colorize the command output.",
)
self.add_base_argument(
parser, '--force-color', action='store_true',
help='Force colorization of the command output.',
)
if self.requires_system_checks:
parser.add_argument(
'--skip-checks', action='store_true',
help='Skip system checks.',
)
self.add_arguments(parser)
return parser
def add_arguments(self, parser):
"""
Entry point for subclassed commands to add custom arguments.
"""
pass
def add_base_argument(self, parser, *args, **kwargs):
"""
Call the parser's add_argument() method, suppressing the help text
according to BaseCommand.suppressed_base_arguments.
"""
for arg in args:
if arg in self.suppressed_base_arguments:
kwargs['help'] = argparse.SUPPRESS
break
parser.add_argument(*args, **kwargs)
def print_help(self, prog_name, subcommand):
"""
Print the help message for this command, derived from
``self.usage()``.
"""
parser = self.create_parser(prog_name, subcommand)
parser.print_help()
def run_from_argv(self, argv):
"""
Set up any environment changes requested (e.g., Python path
and Django settings), then run this command. If the
command raises a ``CommandError``, intercept it and print it sensibly
to stderr. If the ``--traceback`` option is present or the raised
``Exception`` is not ``CommandError``, raise it.
"""
self._called_from_command_line = True
parser = self.create_parser(argv[0], argv[1])
options = parser.parse_args(argv[2:])
cmd_options = vars(options)
# Move positional args out of options to mimic legacy optparse
args = cmd_options.pop('args', ())
handle_default_options(options)
try:
self.execute(*args, **cmd_options)
except CommandError as e:
if options.traceback:
raise
# SystemCheckError takes care of its own formatting.
if isinstance(e, SystemCheckError):
self.stderr.write(str(e), lambda x: x)
else:
self.stderr.write('%s: %s' % (e.__class__.__name__, e))
sys.exit(e.returncode)
finally:
try:
connections.close_all()
except ImproperlyConfigured:
# Ignore if connections aren't setup at this point (e.g. no
# configured settings).
pass
def execute(self, *args, **options):
"""
Try to execute this command, performing system checks if needed (as
controlled by the ``requires_system_checks`` attribute, except if
force-skipped).
"""
if options['force_color'] and options['no_color']:
raise CommandError("The --no-color and --force-color options can't be used together.")
if options['force_color']:
self.style = color_style(force_color=True)
elif options['no_color']:
self.style = no_style()
self.stderr.style_func = None
if options.get('stdout'):
self.stdout = OutputWrapper(options['stdout'])
if options.get('stderr'):
self.stderr = OutputWrapper(options['stderr'])
if self.requires_system_checks and not options['skip_checks']:
if self.requires_system_checks == ALL_CHECKS:
self.check()
else:
self.check(tags=self.requires_system_checks)
if self.requires_migrations_checks:
self.check_migrations()
output = self.handle(*args, **options)
if output:
if self.output_transaction:
connection = connections[options.get('database', DEFAULT_DB_ALIAS)]
output = '%s\n%s\n%s' % (
self.style.SQL_KEYWORD(connection.ops.start_transaction_sql()),
output,
self.style.SQL_KEYWORD(connection.ops.end_transaction_sql()),
)
self.stdout.write(output)
return output
def check(self, app_configs=None, tags=None, display_num_errors=False,
include_deployment_checks=False, fail_level=checks.ERROR,
databases=None):
"""
Use the system check framework to validate entire Django project.
Raise CommandError for any serious message (error or critical errors).
If there are only light messages (like warnings), print them to stderr
and don't raise an exception.
"""
all_issues = checks.run_checks(
app_configs=app_configs,
tags=tags,
include_deployment_checks=include_deployment_checks,
databases=databases,
)
header, body, footer = "", "", ""
visible_issue_count = 0 # excludes silenced warnings
if all_issues:
debugs = [e for e in all_issues if e.level < checks.INFO and not e.is_silenced()]
infos = [e for e in all_issues if checks.INFO <= e.level < checks.WARNING and not e.is_silenced()]
warnings = [e for e in all_issues if checks.WARNING <= e.level < checks.ERROR and not e.is_silenced()]
errors = [e for e in all_issues if checks.ERROR <= e.level < checks.CRITICAL and not e.is_silenced()]
criticals = [e for e in all_issues if checks.CRITICAL <= e.level and not e.is_silenced()]
sorted_issues = [
(criticals, 'CRITICALS'),
(errors, 'ERRORS'),
(warnings, 'WARNINGS'),
(infos, 'INFOS'),
(debugs, 'DEBUGS'),
]
for issues, group_name in sorted_issues:
if issues:
visible_issue_count += len(issues)
formatted = (
self.style.ERROR(str(e))
if e.is_serious()
else self.style.WARNING(str(e))
for e in issues)
formatted = "\n".join(sorted(formatted))
body += '\n%s:\n%s\n' % (group_name, formatted)
if visible_issue_count:
header = "System check identified some issues:\n"
if display_num_errors:
if visible_issue_count:
footer += '\n'
footer += "System check identified %s (%s silenced)." % (
"no issues" if visible_issue_count == 0 else
"1 issue" if visible_issue_count == 1 else
"%s issues" % visible_issue_count,
len(all_issues) - visible_issue_count,
)
if any(e.is_serious(fail_level) and not e.is_silenced() for e in all_issues):
msg = self.style.ERROR("SystemCheckError: %s" % header) + body + footer
raise SystemCheckError(msg)
else:
msg = header + body + footer
if msg:
if visible_issue_count:
self.stderr.write(msg, lambda x: x)
else:
self.stdout.write(msg)
def check_migrations(self):
"""
Print a warning if the set of migrations on disk don't match the
migrations in the database.
"""
from django.db.migrations.executor import MigrationExecutor
try:
executor = MigrationExecutor(connections[DEFAULT_DB_ALIAS])
except ImproperlyConfigured:
# No databases are configured (or the dummy one)
return
plan = executor.migration_plan(executor.loader.graph.leaf_nodes())
if plan:
apps_waiting_migration = sorted({migration.app_label for migration, backwards in plan})
self.stdout.write(
self.style.NOTICE(
"\nYou have %(unapplied_migration_count)s unapplied migration(s). "
"Your project may not work properly until you apply the "
"migrations for app(s): %(apps_waiting_migration)s." % {
"unapplied_migration_count": len(plan),
"apps_waiting_migration": ", ".join(apps_waiting_migration),
}
)
)
self.stdout.write(self.style.NOTICE("Run 'python manage.py migrate' to apply them."))
def handle(self, *args, **options):
"""
The actual logic of the command. Subclasses must implement
this method.
"""
raise NotImplementedError('subclasses of BaseCommand must provide a handle() method')
class AppCommand(BaseCommand):
"""
A management command which takes one or more installed application labels
as arguments, and does something with each of them.
Rather than implementing ``handle()``, subclasses must implement
``handle_app_config()``, which will be called once for each application.
"""
missing_args_message = "Enter at least one application label."
def add_arguments(self, parser):
parser.add_argument('args', metavar='app_label', nargs='+', help='One or more application label.')
def handle(self, *app_labels, **options):
from django.apps import apps
try:
app_configs = [apps.get_app_config(app_label) for app_label in app_labels]
except (LookupError, ImportError) as e:
raise CommandError("%s. Are you sure your INSTALLED_APPS setting is correct?" % e)
output = []
for app_config in app_configs:
app_output = self.handle_app_config(app_config, **options)
if app_output:
output.append(app_output)
return '\n'.join(output)
def handle_app_config(self, app_config, **options):
"""
Perform the command's actions for app_config, an AppConfig instance
corresponding to an application label given on the command line.
"""
raise NotImplementedError(
"Subclasses of AppCommand must provide"
"a handle_app_config() method.")
class LabelCommand(BaseCommand):
"""
A management command which takes one or more arbitrary arguments
(labels) on the command line, and does something with each of
them.
Rather than implementing ``handle()``, subclasses must implement
``handle_label()``, which will be called once for each label.
If the arguments should be names of installed applications, use
``AppCommand`` instead.
"""
label = 'label'
missing_args_message = "Enter at least one %s." % label
def add_arguments(self, parser):
parser.add_argument('args', metavar=self.label, nargs='+')
def handle(self, *labels, **options):
output = []
for label in labels:
label_output = self.handle_label(label, **options)
if label_output:
output.append(label_output)
return '\n'.join(output)
def handle_label(self, label, **options):
"""
Perform the command's actions for ``label``, which will be the
string as given on the command line.
"""
raise NotImplementedError('subclasses of LabelCommand must provide a handle_label() method')

View File

@@ -0,0 +1,107 @@
"""
Sets up the terminal color scheme.
"""
import functools
import os
import sys
from django.utils import termcolors
try:
import colorama
colorama.init()
except (ImportError, OSError):
HAS_COLORAMA = False
else:
HAS_COLORAMA = True
def supports_color():
"""
Return True if the running system's terminal supports color,
and False otherwise.
"""
def vt_codes_enabled_in_windows_registry():
"""
Check the Windows Registry to see if VT code handling has been enabled
by default, see https://superuser.com/a/1300251/447564.
"""
try:
# winreg is only available on Windows.
import winreg
except ImportError:
return False
else:
reg_key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, 'Console')
try:
reg_key_value, _ = winreg.QueryValueEx(reg_key, 'VirtualTerminalLevel')
except FileNotFoundError:
return False
else:
return reg_key_value == 1
# isatty is not always implemented, #6223.
is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
return is_a_tty and (
sys.platform != 'win32' or
HAS_COLORAMA or
'ANSICON' in os.environ or
# Windows Terminal supports VT codes.
'WT_SESSION' in os.environ or
# Microsoft Visual Studio Code's built-in terminal supports colors.
os.environ.get('TERM_PROGRAM') == 'vscode' or
vt_codes_enabled_in_windows_registry()
)
class Style:
pass
def make_style(config_string=''):
"""
Create a Style object from the given config_string.
If config_string is empty django.utils.termcolors.DEFAULT_PALETTE is used.
"""
style = Style()
color_settings = termcolors.parse_color_setting(config_string)
# The nocolor palette has all available roles.
# Use that palette as the basis for populating
# the palette as defined in the environment.
for role in termcolors.PALETTES[termcolors.NOCOLOR_PALETTE]:
if color_settings:
format = color_settings.get(role, {})
style_func = termcolors.make_style(**format)
else:
def style_func(x):
return x
setattr(style, role, style_func)
# For backwards compatibility,
# set style for ERROR_OUTPUT == ERROR
style.ERROR_OUTPUT = style.ERROR
return style
@functools.lru_cache(maxsize=None)
def no_style():
"""
Return a Style object with no color scheme.
"""
return make_style('nocolor')
def color_style(force_color=False):
"""
Return a Style object from the Django color scheme.
"""
if not force_color and not supports_color():
return no_style()
return make_style(os.environ.get('DJANGO_COLORS', ''))

View File

@@ -0,0 +1,70 @@
from django.apps import apps
from django.core import checks
from django.core.checks.registry import registry
from django.core.management.base import BaseCommand, CommandError
class Command(BaseCommand):
help = "Checks the entire Django project for potential problems."
requires_system_checks = []
def add_arguments(self, parser):
parser.add_argument('args', metavar='app_label', nargs='*')
parser.add_argument(
'--tag', '-t', action='append', dest='tags',
help='Run only checks labeled with given tag.',
)
parser.add_argument(
'--list-tags', action='store_true',
help='List available tags.',
)
parser.add_argument(
'--deploy', action='store_true',
help='Check deployment settings.',
)
parser.add_argument(
'--fail-level',
default='ERROR',
choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'],
help=(
'Message level that will cause the command to exit with a '
'non-zero status. Default is ERROR.'
),
)
parser.add_argument(
'--database', action='append', dest='databases',
help='Run database related checks against these aliases.',
)
def handle(self, *app_labels, **options):
include_deployment_checks = options['deploy']
if options['list_tags']:
self.stdout.write('\n'.join(sorted(registry.tags_available(include_deployment_checks))))
return
if app_labels:
app_configs = [apps.get_app_config(app_label) for app_label in app_labels]
else:
app_configs = None
tags = options['tags']
if tags:
try:
invalid_tag = next(
tag for tag in tags if not checks.tag_exists(tag, include_deployment_checks)
)
except StopIteration:
# no invalid tags
pass
else:
raise CommandError('There is no system check with the "%s" tag.' % invalid_tag)
self.check(
app_configs=app_configs,
tags=tags,
display_num_errors=True,
include_deployment_checks=include_deployment_checks,
fail_level=getattr(checks, options['fail_level']),
databases=options['databases'],
)

View File

@@ -0,0 +1,168 @@
import codecs
import concurrent.futures
import glob
import os
from pathlib import Path
from django.core.management.base import BaseCommand, CommandError
from django.core.management.utils import (
find_command, is_ignored_path, popen_wrapper,
)
def has_bom(fn):
with fn.open('rb') as f:
sample = f.read(4)
return sample.startswith((codecs.BOM_UTF8, codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE))
def is_writable(path):
# Known side effect: updating file access/modified time to current time if
# it is writable.
try:
with open(path, 'a'):
os.utime(path, None)
except OSError:
return False
return True
class Command(BaseCommand):
help = 'Compiles .po files to .mo files for use with builtin gettext support.'
requires_system_checks = []
program = 'msgfmt'
program_options = ['--check-format']
def add_arguments(self, parser):
parser.add_argument(
'--locale', '-l', action='append', default=[],
help='Locale(s) to process (e.g. de_AT). Default is to process all. '
'Can be used multiple times.',
)
parser.add_argument(
'--exclude', '-x', action='append', default=[],
help='Locales to exclude. Default is none. Can be used multiple times.',
)
parser.add_argument(
'--use-fuzzy', '-f', dest='fuzzy', action='store_true',
help='Use fuzzy translations.',
)
parser.add_argument(
'--ignore', '-i', action='append', dest='ignore_patterns',
default=[], metavar='PATTERN',
help='Ignore directories matching this glob-style pattern. '
'Use multiple times to ignore more.',
)
def handle(self, **options):
locale = options['locale']
exclude = options['exclude']
ignore_patterns = set(options['ignore_patterns'])
self.verbosity = options['verbosity']
if options['fuzzy']:
self.program_options = self.program_options + ['-f']
if find_command(self.program) is None:
raise CommandError("Can't find %s. Make sure you have GNU gettext "
"tools 0.15 or newer installed." % self.program)
basedirs = [os.path.join('conf', 'locale'), 'locale']
if os.environ.get('DJANGO_SETTINGS_MODULE'):
from django.conf import settings
basedirs.extend(settings.LOCALE_PATHS)
# Walk entire tree, looking for locale directories
for dirpath, dirnames, filenames in os.walk('.', topdown=True):
for dirname in dirnames:
if is_ignored_path(os.path.normpath(os.path.join(dirpath, dirname)), ignore_patterns):
dirnames.remove(dirname)
elif dirname == 'locale':
basedirs.append(os.path.join(dirpath, dirname))
# Gather existing directories.
basedirs = set(map(os.path.abspath, filter(os.path.isdir, basedirs)))
if not basedirs:
raise CommandError("This script should be run from the Django Git "
"checkout or your project or app tree, or with "
"the settings module specified.")
# Build locale list
all_locales = []
for basedir in basedirs:
locale_dirs = filter(os.path.isdir, glob.glob('%s/*' % basedir))
all_locales.extend(map(os.path.basename, locale_dirs))
# Account for excluded locales
locales = locale or all_locales
locales = set(locales).difference(exclude)
self.has_errors = False
for basedir in basedirs:
if locales:
dirs = [os.path.join(basedir, locale, 'LC_MESSAGES') for locale in locales]
else:
dirs = [basedir]
locations = []
for ldir in dirs:
for dirpath, dirnames, filenames in os.walk(ldir):
locations.extend((dirpath, f) for f in filenames if f.endswith('.po'))
if locations:
self.compile_messages(locations)
if self.has_errors:
raise CommandError('compilemessages generated one or more errors.')
def compile_messages(self, locations):
"""
Locations is a list of tuples: [(directory, file), ...]
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for i, (dirpath, f) in enumerate(locations):
po_path = Path(dirpath) / f
mo_path = po_path.with_suffix('.mo')
try:
if mo_path.stat().st_mtime >= po_path.stat().st_mtime:
if self.verbosity > 0:
self.stdout.write(
'File “%s” is already compiled and up to date.'
% po_path
)
continue
except FileNotFoundError:
pass
if self.verbosity > 0:
self.stdout.write('processing file %s in %s' % (f, dirpath))
if has_bom(po_path):
self.stderr.write(
'The %s file has a BOM (Byte Order Mark). Django only '
'supports .po files encoded in UTF-8 and without any BOM.' % po_path
)
self.has_errors = True
continue
# Check writability on first location
if i == 0 and not is_writable(mo_path):
self.stderr.write(
'The po files under %s are in a seemingly not writable location. '
'mo files will not be updated/created.' % dirpath
)
self.has_errors = True
return
args = [self.program, *self.program_options, '-o', mo_path, po_path]
futures.append(executor.submit(popen_wrapper, args))
for future in concurrent.futures.as_completed(futures):
output, errors, status = future.result()
if status:
if self.verbosity > 0:
if errors:
self.stderr.write("Execution of %s failed: %s" % (self.program, errors))
else:
self.stderr.write("Execution of %s failed" % self.program)
self.has_errors = True

View File

@@ -0,0 +1,107 @@
from django.conf import settings
from django.core.cache import caches
from django.core.cache.backends.db import BaseDatabaseCache
from django.core.management.base import BaseCommand, CommandError
from django.db import (
DEFAULT_DB_ALIAS, DatabaseError, connections, models, router, transaction,
)
class Command(BaseCommand):
help = "Creates the tables needed to use the SQL cache backend."
requires_system_checks = []
def add_arguments(self, parser):
parser.add_argument(
'args', metavar='table_name', nargs='*',
help='Optional table names. Otherwise, settings.CACHES is used to find cache tables.',
)
parser.add_argument(
'--database',
default=DEFAULT_DB_ALIAS,
help='Nominates a database onto which the cache tables will be '
'installed. Defaults to the "default" database.',
)
parser.add_argument(
'--dry-run', action='store_true',
help='Does not create the table, just prints the SQL that would be run.',
)
def handle(self, *tablenames, **options):
db = options['database']
self.verbosity = options['verbosity']
dry_run = options['dry_run']
if tablenames:
# Legacy behavior, tablename specified as argument
for tablename in tablenames:
self.create_table(db, tablename, dry_run)
else:
for cache_alias in settings.CACHES:
cache = caches[cache_alias]
if isinstance(cache, BaseDatabaseCache):
self.create_table(db, cache._table, dry_run)
def create_table(self, database, tablename, dry_run):
cache = BaseDatabaseCache(tablename, {})
if not router.allow_migrate_model(database, cache.cache_model_class):
return
connection = connections[database]
if tablename in connection.introspection.table_names():
if self.verbosity > 0:
self.stdout.write("Cache table '%s' already exists." % tablename)
return
fields = (
# "key" is a reserved word in MySQL, so use "cache_key" instead.
models.CharField(name='cache_key', max_length=255, unique=True, primary_key=True),
models.TextField(name='value'),
models.DateTimeField(name='expires', db_index=True),
)
table_output = []
index_output = []
qn = connection.ops.quote_name
for f in fields:
field_output = [
qn(f.name),
f.db_type(connection=connection),
'%sNULL' % ('NOT ' if not f.null else ''),
]
if f.primary_key:
field_output.append("PRIMARY KEY")
elif f.unique:
field_output.append("UNIQUE")
if f.db_index:
unique = "UNIQUE " if f.unique else ""
index_output.append(
"CREATE %sINDEX %s ON %s (%s);" %
(unique, qn('%s_%s' % (tablename, f.name)), qn(tablename), qn(f.name))
)
table_output.append(" ".join(field_output))
full_statement = ["CREATE TABLE %s (" % qn(tablename)]
for i, line in enumerate(table_output):
full_statement.append(' %s%s' % (line, ',' if i < len(table_output) - 1 else ''))
full_statement.append(');')
full_statement = "\n".join(full_statement)
if dry_run:
self.stdout.write(full_statement)
for statement in index_output:
self.stdout.write(statement)
return
with transaction.atomic(using=database, savepoint=connection.features.can_rollback_ddl):
with connection.cursor() as curs:
try:
curs.execute(full_statement)
except DatabaseError as e:
raise CommandError(
"Cache table '%s' could not be created.\nThe error was: %s." %
(tablename, e))
for statement in index_output:
curs.execute(statement)
if self.verbosity > 1:
self.stdout.write("Cache table '%s' created." % tablename)

View File

@@ -0,0 +1,43 @@
import subprocess
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS, connections
class Command(BaseCommand):
help = (
"Runs the command-line client for specified database, or the "
"default database if none is provided."
)
requires_system_checks = []
def add_arguments(self, parser):
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a database onto which to open a shell. Defaults to the "default" database.',
)
parameters = parser.add_argument_group('parameters', prefix_chars='--')
parameters.add_argument('parameters', nargs='*')
def handle(self, **options):
connection = connections[options['database']]
try:
connection.client.runshell(options['parameters'])
except FileNotFoundError:
# Note that we're assuming the FileNotFoundError relates to the
# command missing. It could be raised for some other reason, in
# which case this error message would be inaccurate. Still, this
# message catches the common case.
raise CommandError(
'You appear not to have the %r program installed or on your path.' %
connection.client.executable_name
)
except subprocess.CalledProcessError as e:
raise CommandError(
'"%s" returned non-zero exit status %s.' % (
' '.join(e.cmd),
e.returncode,
),
returncode=e.returncode,
)

View File

@@ -0,0 +1,79 @@
from django.core.management.base import BaseCommand
def module_to_dict(module, omittable=lambda k: k.startswith('_') or not k.isupper()):
"""Convert a module namespace to a Python dictionary."""
return {k: repr(getattr(module, k)) for k in dir(module) if not omittable(k)}
class Command(BaseCommand):
help = """Displays differences between the current settings.py and Django's
default settings."""
requires_system_checks = []
def add_arguments(self, parser):
parser.add_argument(
'--all', action='store_true',
help=(
'Display all settings, regardless of their value. In "hash" '
'mode, default values are prefixed by "###".'
),
)
parser.add_argument(
'--default', metavar='MODULE',
help=(
"The settings module to compare the current settings against. Leave empty to "
"compare against Django's default settings."
),
)
parser.add_argument(
'--output', default='hash', choices=('hash', 'unified'),
help=(
"Selects the output format. 'hash' mode displays each changed "
"setting, with the settings that don't appear in the defaults "
"followed by ###. 'unified' mode prefixes the default setting "
"with a minus sign, followed by the changed setting prefixed "
"with a plus sign."
),
)
def handle(self, **options):
from django.conf import Settings, global_settings, settings
# Because settings are imported lazily, we need to explicitly load them.
if not settings.configured:
settings._setup()
user_settings = module_to_dict(settings._wrapped)
default = options['default']
default_settings = module_to_dict(Settings(default) if default else global_settings)
output_func = {
'hash': self.output_hash,
'unified': self.output_unified,
}[options['output']]
return '\n'.join(output_func(user_settings, default_settings, **options))
def output_hash(self, user_settings, default_settings, **options):
# Inspired by Postfix's "postconf -n".
output = []
for key in sorted(user_settings):
if key not in default_settings:
output.append("%s = %s ###" % (key, user_settings[key]))
elif user_settings[key] != default_settings[key]:
output.append("%s = %s" % (key, user_settings[key]))
elif options['all']:
output.append("### %s = %s" % (key, user_settings[key]))
return output
def output_unified(self, user_settings, default_settings, **options):
output = []
for key in sorted(user_settings):
if key not in default_settings:
output.append(self.style.SUCCESS("+ %s = %s" % (key, user_settings[key])))
elif user_settings[key] != default_settings[key]:
output.append(self.style.ERROR("- %s = %s" % (key, default_settings[key])))
output.append(self.style.SUCCESS("+ %s = %s" % (key, user_settings[key])))
elif options['all']:
output.append(" %s = %s" % (key, user_settings[key]))
return output

View File

@@ -0,0 +1,245 @@
import gzip
import os
import warnings
from django.apps import apps
from django.core import serializers
from django.core.management.base import BaseCommand, CommandError
from django.core.management.utils import parse_apps_and_model_labels
from django.db import DEFAULT_DB_ALIAS, router
try:
import bz2
has_bz2 = True
except ImportError:
has_bz2 = False
try:
import lzma
has_lzma = True
except ImportError:
has_lzma = False
class ProxyModelWarning(Warning):
pass
class Command(BaseCommand):
help = (
"Output the contents of the database as a fixture of the given format "
"(using each model's default manager unless --all is specified)."
)
def add_arguments(self, parser):
parser.add_argument(
'args', metavar='app_label[.ModelName]', nargs='*',
help='Restricts dumped data to the specified app_label or app_label.ModelName.',
)
parser.add_argument(
'--format', default='json',
help='Specifies the output serialization format for fixtures.',
)
parser.add_argument(
'--indent', type=int,
help='Specifies the indent level to use when pretty-printing output.',
)
parser.add_argument(
'--database',
default=DEFAULT_DB_ALIAS,
help='Nominates a specific database to dump fixtures from. '
'Defaults to the "default" database.',
)
parser.add_argument(
'-e', '--exclude', action='append', default=[],
help='An app_label or app_label.ModelName to exclude '
'(use multiple --exclude to exclude multiple apps/models).',
)
parser.add_argument(
'--natural-foreign', action='store_true', dest='use_natural_foreign_keys',
help='Use natural foreign keys if they are available.',
)
parser.add_argument(
'--natural-primary', action='store_true', dest='use_natural_primary_keys',
help='Use natural primary keys if they are available.',
)
parser.add_argument(
'-a', '--all', action='store_true', dest='use_base_manager',
help="Use Django's base manager to dump all models stored in the database, "
"including those that would otherwise be filtered or modified by a custom manager.",
)
parser.add_argument(
'--pks', dest='primary_keys',
help="Only dump objects with given primary keys. Accepts a comma-separated "
"list of keys. This option only works when you specify one model.",
)
parser.add_argument(
'-o', '--output',
help='Specifies file to which the output is written.'
)
def handle(self, *app_labels, **options):
format = options['format']
indent = options['indent']
using = options['database']
excludes = options['exclude']
output = options['output']
show_traceback = options['traceback']
use_natural_foreign_keys = options['use_natural_foreign_keys']
use_natural_primary_keys = options['use_natural_primary_keys']
use_base_manager = options['use_base_manager']
pks = options['primary_keys']
if pks:
primary_keys = [pk.strip() for pk in pks.split(',')]
else:
primary_keys = []
excluded_models, excluded_apps = parse_apps_and_model_labels(excludes)
if not app_labels:
if primary_keys:
raise CommandError("You can only use --pks option with one model")
app_list = dict.fromkeys(
app_config for app_config in apps.get_app_configs()
if app_config.models_module is not None and app_config not in excluded_apps
)
else:
if len(app_labels) > 1 and primary_keys:
raise CommandError("You can only use --pks option with one model")
app_list = {}
for label in app_labels:
try:
app_label, model_label = label.split('.')
try:
app_config = apps.get_app_config(app_label)
except LookupError as e:
raise CommandError(str(e))
if app_config.models_module is None or app_config in excluded_apps:
continue
try:
model = app_config.get_model(model_label)
except LookupError:
raise CommandError("Unknown model: %s.%s" % (app_label, model_label))
app_list_value = app_list.setdefault(app_config, [])
# We may have previously seen an "all-models" request for
# this app (no model qualifier was given). In this case
# there is no need adding specific models to the list.
if app_list_value is not None and model not in app_list_value:
app_list_value.append(model)
except ValueError:
if primary_keys:
raise CommandError("You can only use --pks option with one model")
# This is just an app - no model qualifier
app_label = label
try:
app_config = apps.get_app_config(app_label)
except LookupError as e:
raise CommandError(str(e))
if app_config.models_module is None or app_config in excluded_apps:
continue
app_list[app_config] = None
# Check that the serialization format exists; this is a shortcut to
# avoid collating all the objects and _then_ failing.
if format not in serializers.get_public_serializer_formats():
try:
serializers.get_serializer(format)
except serializers.SerializerDoesNotExist:
pass
raise CommandError("Unknown serialization format: %s" % format)
def get_objects(count_only=False):
"""
Collate the objects to be serialized. If count_only is True, just
count the number of objects to be serialized.
"""
if use_natural_foreign_keys:
models = serializers.sort_dependencies(app_list.items(), allow_cycles=True)
else:
# There is no need to sort dependencies when natural foreign
# keys are not used.
models = []
for (app_config, model_list) in app_list.items():
if model_list is None:
models.extend(app_config.get_models())
else:
models.extend(model_list)
for model in models:
if model in excluded_models:
continue
if model._meta.proxy and model._meta.proxy_for_model not in models:
warnings.warn(
"%s is a proxy model and won't be serialized." % model._meta.label,
category=ProxyModelWarning,
)
if not model._meta.proxy and router.allow_migrate_model(using, model):
if use_base_manager:
objects = model._base_manager
else:
objects = model._default_manager
queryset = objects.using(using).order_by(model._meta.pk.name)
if primary_keys:
queryset = queryset.filter(pk__in=primary_keys)
if count_only:
yield queryset.order_by().count()
else:
yield from queryset.iterator()
try:
self.stdout.ending = None
progress_output = None
object_count = 0
# If dumpdata is outputting to stdout, there is no way to display progress
if output and self.stdout.isatty() and options['verbosity'] > 0:
progress_output = self.stdout
object_count = sum(get_objects(count_only=True))
if output:
file_root, file_ext = os.path.splitext(output)
compression_formats = {
'.bz2': (open, {}, file_root),
'.gz': (gzip.open, {}, output),
'.lzma': (open, {}, file_root),
'.xz': (open, {}, file_root),
'.zip': (open, {}, file_root),
}
if has_bz2:
compression_formats['.bz2'] = (bz2.open, {}, output)
if has_lzma:
compression_formats['.lzma'] = (
lzma.open, {'format': lzma.FORMAT_ALONE}, output
)
compression_formats['.xz'] = (lzma.open, {}, output)
try:
open_method, kwargs, file_path = compression_formats[file_ext]
except KeyError:
open_method, kwargs, file_path = (open, {}, output)
if file_path != output:
file_name = os.path.basename(file_path)
warnings.warn(
f"Unsupported file extension ({file_ext}). "
f"Fixtures saved in '{file_name}'.",
RuntimeWarning,
)
stream = open_method(file_path, 'wt', **kwargs)
else:
stream = None
try:
serializers.serialize(
format, get_objects(), indent=indent,
use_natural_foreign_keys=use_natural_foreign_keys,
use_natural_primary_keys=use_natural_primary_keys,
stream=stream or self.stdout, progress_output=progress_output,
object_count=object_count,
)
finally:
if stream:
stream.close()
except Exception as e:
if show_traceback:
raise
raise CommandError("Unable to serialize database: %s" % e)

View File

@@ -0,0 +1,82 @@
from importlib import import_module
from django.apps import apps
from django.core.management.base import BaseCommand, CommandError
from django.core.management.color import no_style
from django.core.management.sql import emit_post_migrate_signal, sql_flush
from django.db import DEFAULT_DB_ALIAS, connections
class Command(BaseCommand):
help = (
'Removes ALL DATA from the database, including data added during '
'migrations. Does not achieve a "fresh install" state.'
)
stealth_options = ('reset_sequences', 'allow_cascade', 'inhibit_post_migrate')
def add_arguments(self, parser):
parser.add_argument(
'--noinput', '--no-input', action='store_false', dest='interactive',
help='Tells Django to NOT prompt the user for input of any kind.',
)
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a database to flush. Defaults to the "default" database.',
)
def handle(self, **options):
database = options['database']
connection = connections[database]
verbosity = options['verbosity']
interactive = options['interactive']
# The following are stealth options used by Django's internals.
reset_sequences = options.get('reset_sequences', True)
allow_cascade = options.get('allow_cascade', False)
inhibit_post_migrate = options.get('inhibit_post_migrate', False)
self.style = no_style()
# Import the 'management' module within each installed app, to register
# dispatcher events.
for app_config in apps.get_app_configs():
try:
import_module('.management', app_config.name)
except ImportError:
pass
sql_list = sql_flush(self.style, connection,
reset_sequences=reset_sequences,
allow_cascade=allow_cascade)
if interactive:
confirm = input("""You have requested a flush of the database.
This will IRREVERSIBLY DESTROY all data currently in the "%s" database,
and return each table to an empty state.
Are you sure you want to do this?
Type 'yes' to continue, or 'no' to cancel: """ % connection.settings_dict['NAME'])
else:
confirm = 'yes'
if confirm == 'yes':
try:
connection.ops.execute_sql_flush(sql_list)
except Exception as exc:
raise CommandError(
"Database %s couldn't be flushed. Possible reasons:\n"
" * The database isn't running or isn't configured correctly.\n"
" * At least one of the expected database tables doesn't exist.\n"
" * The SQL was invalid.\n"
"Hint: Look at the output of 'django-admin sqlflush'. "
"That's the SQL this command wasn't able to run." % (
connection.settings_dict['NAME'],
)
) from exc
# Empty sql_list may signify an empty database and post_migrate would then crash
if sql_list and not inhibit_post_migrate:
# Emit the post migrate signal. This allows individual applications to
# respond as if the database had been migrated from scratch.
emit_post_migrate_signal(verbosity, interactive, database)
else:
self.stdout.write('Flush cancelled.')

View File

@@ -0,0 +1,299 @@
import keyword
import re
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.models.constants import LOOKUP_SEP
class Command(BaseCommand):
help = "Introspects the database tables in the given database and outputs a Django model module."
requires_system_checks = []
stealth_options = ('table_name_filter',)
db_module = 'django.db'
def add_arguments(self, parser):
parser.add_argument(
'table', nargs='*', type=str,
help='Selects what tables or views should be introspected.',
)
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a database to introspect. Defaults to using the "default" database.',
)
parser.add_argument(
'--include-partitions', action='store_true', help='Also output models for partition tables.',
)
parser.add_argument(
'--include-views', action='store_true', help='Also output models for database views.',
)
def handle(self, **options):
try:
for line in self.handle_inspection(options):
self.stdout.write(line)
except NotImplementedError:
raise CommandError("Database inspection isn't supported for the currently selected database backend.")
def handle_inspection(self, options):
connection = connections[options['database']]
# 'table_name_filter' is a stealth option
table_name_filter = options.get('table_name_filter')
def table2model(table_name):
return re.sub(r'[^a-zA-Z0-9]', '', table_name.title())
with connection.cursor() as cursor:
yield "# This is an auto-generated Django model module."
yield "# You'll have to do the following manually to clean this up:"
yield "# * Rearrange models' order"
yield "# * Make sure each model has one field with primary_key=True"
yield "# * Make sure each ForeignKey and OneToOneField has `on_delete` set to the desired behavior"
yield (
"# * Remove `managed = False` lines if you wish to allow "
"Django to create, modify, and delete the table"
)
yield "# Feel free to rename the models, but don't rename db_table values or field names."
yield 'from %s import models' % self.db_module
known_models = []
table_info = connection.introspection.get_table_list(cursor)
# Determine types of tables and/or views to be introspected.
types = {'t'}
if options['include_partitions']:
types.add('p')
if options['include_views']:
types.add('v')
for table_name in (options['table'] or sorted(info.name for info in table_info if info.type in types)):
if table_name_filter is not None and callable(table_name_filter):
if not table_name_filter(table_name):
continue
try:
try:
relations = connection.introspection.get_relations(cursor, table_name)
except NotImplementedError:
relations = {}
try:
constraints = connection.introspection.get_constraints(cursor, table_name)
except NotImplementedError:
constraints = {}
primary_key_column = connection.introspection.get_primary_key_column(cursor, table_name)
unique_columns = [
c['columns'][0] for c in constraints.values()
if c['unique'] and len(c['columns']) == 1
]
table_description = connection.introspection.get_table_description(cursor, table_name)
except Exception as e:
yield "# Unable to inspect table '%s'" % table_name
yield "# The error was: %s" % e
continue
yield ''
yield ''
yield 'class %s(models.Model):' % table2model(table_name)
known_models.append(table2model(table_name))
used_column_names = [] # Holds column names used in the table so far
column_to_field_name = {} # Maps column names to names of model fields
for row in table_description:
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = {} # Holds Field parameters such as 'db_column'.
column_name = row.name
is_relation = column_name in relations
att_name, params, notes = self.normalize_col_name(
column_name, used_column_names, is_relation)
extra_params.update(params)
comment_notes.extend(notes)
used_column_names.append(att_name)
column_to_field_name[column_name] = att_name
# Add primary_key and unique, if necessary.
if column_name == primary_key_column:
extra_params['primary_key'] = True
elif column_name in unique_columns:
extra_params['unique'] = True
if is_relation:
if extra_params.pop('unique', False) or extra_params.get('primary_key'):
rel_type = 'OneToOneField'
else:
rel_type = 'ForeignKey'
rel_to = (
"self" if relations[column_name][1] == table_name
else table2model(relations[column_name][1])
)
if rel_to in known_models:
field_type = '%s(%s' % (rel_type, rel_to)
else:
field_type = "%s('%s'" % (rel_type, rel_to)
else:
# Calling `get_field_type` to get the field type string and any
# additional parameters and notes.
field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
extra_params.update(field_params)
comment_notes.extend(field_notes)
field_type += '('
# Don't output 'id = meta.AutoField(primary_key=True)', because
# that's assumed if it doesn't exist.
if att_name == 'id' and extra_params == {'primary_key': True}:
if field_type == 'AutoField(':
continue
elif field_type == connection.features.introspected_field_types['AutoField'] + '(':
comment_notes.append('AutoField?')
# Add 'null' and 'blank', if the 'null_ok' flag was present in the
# table description.
if row.null_ok: # If it's NULL...
extra_params['blank'] = True
extra_params['null'] = True
field_desc = '%s = %s%s' % (
att_name,
# Custom fields will have a dotted path
'' if '.' in field_type else 'models.',
field_type,
)
if field_type.startswith(('ForeignKey(', 'OneToOneField(')):
field_desc += ', models.DO_NOTHING'
if extra_params:
if not field_desc.endswith('('):
field_desc += ', '
field_desc += ', '.join('%s=%r' % (k, v) for k, v in extra_params.items())
field_desc += ')'
if comment_notes:
field_desc += ' # ' + ' '.join(comment_notes)
yield ' %s' % field_desc
is_view = any(info.name == table_name and info.type == 'v' for info in table_info)
is_partition = any(info.name == table_name and info.type == 'p' for info in table_info)
yield from self.get_meta(table_name, constraints, column_to_field_name, is_view, is_partition)
def normalize_col_name(self, col_name, used_column_names, is_relation):
"""
Modify the column name to make it Python-compatible as a field name
"""
field_params = {}
field_notes = []
new_name = col_name.lower()
if new_name != col_name:
field_notes.append('Field name made lowercase.')
if is_relation:
if new_name.endswith('_id'):
new_name = new_name[:-3]
else:
field_params['db_column'] = col_name
new_name, num_repl = re.subn(r'\W', '_', new_name)
if num_repl > 0:
field_notes.append('Field renamed to remove unsuitable characters.')
if new_name.find(LOOKUP_SEP) >= 0:
while new_name.find(LOOKUP_SEP) >= 0:
new_name = new_name.replace(LOOKUP_SEP, '_')
if col_name.lower().find(LOOKUP_SEP) >= 0:
# Only add the comment if the double underscore was in the original name
field_notes.append("Field renamed because it contained more than one '_' in a row.")
if new_name.startswith('_'):
new_name = 'field%s' % new_name
field_notes.append("Field renamed because it started with '_'.")
if new_name.endswith('_'):
new_name = '%sfield' % new_name
field_notes.append("Field renamed because it ended with '_'.")
if keyword.iskeyword(new_name):
new_name += '_field'
field_notes.append('Field renamed because it was a Python reserved word.')
if new_name[0].isdigit():
new_name = 'number_%s' % new_name
field_notes.append("Field renamed because it wasn't a valid Python identifier.")
if new_name in used_column_names:
num = 0
while '%s_%d' % (new_name, num) in used_column_names:
num += 1
new_name = '%s_%d' % (new_name, num)
field_notes.append('Field renamed because of name conflict.')
if col_name != new_name and field_notes:
field_params['db_column'] = col_name
return new_name, field_params, field_notes
def get_field_type(self, connection, table_name, row):
"""
Given the database connection, the table name, and the cursor row
description, this routine will return the given field type name, as
well as any additional keyword parameters and notes for the field.
"""
field_params = {}
field_notes = []
try:
field_type = connection.introspection.get_field_type(row.type_code, row)
except KeyError:
field_type = 'TextField'
field_notes.append('This field type is a guess.')
# Add max_length for all CharFields.
if field_type == 'CharField' and row.internal_size:
field_params['max_length'] = int(row.internal_size)
if field_type in {'CharField', 'TextField'} and row.collation:
field_params['db_collation'] = row.collation
if field_type == 'DecimalField':
if row.precision is None or row.scale is None:
field_notes.append(
'max_digits and decimal_places have been guessed, as this '
'database handles decimal fields as float')
field_params['max_digits'] = row.precision if row.precision is not None else 10
field_params['decimal_places'] = row.scale if row.scale is not None else 5
else:
field_params['max_digits'] = row.precision
field_params['decimal_places'] = row.scale
return field_type, field_params, field_notes
def get_meta(self, table_name, constraints, column_to_field_name, is_view, is_partition):
"""
Return a sequence comprising the lines of code necessary
to construct the inner Meta class for the model corresponding
to the given database table name.
"""
unique_together = []
has_unsupported_constraint = False
for params in constraints.values():
if params['unique']:
columns = params['columns']
if None in columns:
has_unsupported_constraint = True
columns = [x for x in columns if x is not None]
if len(columns) > 1:
unique_together.append(str(tuple(column_to_field_name[c] for c in columns)))
if is_view:
managed_comment = " # Created from a view. Don't remove."
elif is_partition:
managed_comment = " # Created from a partition. Don't remove."
else:
managed_comment = ''
meta = ['']
if has_unsupported_constraint:
meta.append(' # A unique constraint could not be introspected.')
meta += [
' class Meta:',
' managed = False%s' % managed_comment,
' db_table = %r' % table_name
]
if unique_together:
tup = '(' + ', '.join(unique_together) + ',)'
meta += [" unique_together = %s" % tup]
return meta

View File

@@ -0,0 +1,384 @@
import functools
import glob
import gzip
import os
import sys
import warnings
import zipfile
from itertools import product
from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.core.exceptions import ImproperlyConfigured
from django.core.management.base import BaseCommand, CommandError
from django.core.management.color import no_style
from django.core.management.utils import parse_apps_and_model_labels
from django.db import (
DEFAULT_DB_ALIAS, DatabaseError, IntegrityError, connections, router,
transaction,
)
from django.utils.functional import cached_property
try:
import bz2
has_bz2 = True
except ImportError:
has_bz2 = False
try:
import lzma
has_lzma = True
except ImportError:
has_lzma = False
READ_STDIN = '-'
class Command(BaseCommand):
help = 'Installs the named fixture(s) in the database.'
missing_args_message = (
"No database fixture specified. Please provide the path of at least "
"one fixture in the command line."
)
def add_arguments(self, parser):
parser.add_argument('args', metavar='fixture', nargs='+', help='Fixture labels.')
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a specific database to load fixtures into. Defaults to the "default" database.',
)
parser.add_argument(
'--app', dest='app_label',
help='Only look for fixtures in the specified app.',
)
parser.add_argument(
'--ignorenonexistent', '-i', action='store_true', dest='ignore',
help='Ignores entries in the serialized data for fields that do not '
'currently exist on the model.',
)
parser.add_argument(
'-e', '--exclude', action='append', default=[],
help='An app_label or app_label.ModelName to exclude. Can be used multiple times.',
)
parser.add_argument(
'--format',
help='Format of serialized data when reading from stdin.',
)
def handle(self, *fixture_labels, **options):
self.ignore = options['ignore']
self.using = options['database']
self.app_label = options['app_label']
self.verbosity = options['verbosity']
self.excluded_models, self.excluded_apps = parse_apps_and_model_labels(options['exclude'])
self.format = options['format']
with transaction.atomic(using=self.using):
self.loaddata(fixture_labels)
# Close the DB connection -- unless we're still in a transaction. This
# is required as a workaround for an edge case in MySQL: if the same
# connection is used to create tables, load data, and query, the query
# can return incorrect results. See Django #7572, MySQL #37735.
if transaction.get_autocommit(self.using):
connections[self.using].close()
@cached_property
def compression_formats(self):
"""A dict mapping format names to (open function, mode arg) tuples."""
# Forcing binary mode may be revisited after dropping Python 2 support (see #22399)
compression_formats = {
None: (open, 'rb'),
'gz': (gzip.GzipFile, 'rb'),
'zip': (SingleZipReader, 'r'),
'stdin': (lambda *args: sys.stdin, None),
}
if has_bz2:
compression_formats['bz2'] = (bz2.BZ2File, 'r')
if has_lzma:
compression_formats['lzma'] = (lzma.LZMAFile, 'r')
compression_formats['xz'] = (lzma.LZMAFile, 'r')
return compression_formats
def reset_sequences(self, connection, models):
"""Reset database sequences for the given connection and models."""
sequence_sql = connection.ops.sequence_reset_sql(no_style(), models)
if sequence_sql:
if self.verbosity >= 2:
self.stdout.write('Resetting sequences')
with connection.cursor() as cursor:
for line in sequence_sql:
cursor.execute(line)
def loaddata(self, fixture_labels):
connection = connections[self.using]
# Keep a count of the installed objects and fixtures
self.fixture_count = 0
self.loaded_object_count = 0
self.fixture_object_count = 0
self.models = set()
self.serialization_formats = serializers.get_public_serializer_formats()
# Django's test suite repeatedly tries to load initial_data fixtures
# from apps that don't have any fixtures. Because disabling constraint
# checks can be expensive on some database (especially MSSQL), bail
# out early if no fixtures are found.
for fixture_label in fixture_labels:
if self.find_fixtures(fixture_label):
break
else:
return
self.objs_with_deferred_fields = []
with connection.constraint_checks_disabled():
for fixture_label in fixture_labels:
self.load_label(fixture_label)
for obj in self.objs_with_deferred_fields:
obj.save_deferred_fields(using=self.using)
# Since we disabled constraint checks, we must manually check for
# any invalid keys that might have been added
table_names = [model._meta.db_table for model in self.models]
try:
connection.check_constraints(table_names=table_names)
except Exception as e:
e.args = ("Problem installing fixtures: %s" % e,)
raise
# If we found even one object in a fixture, we need to reset the
# database sequences.
if self.loaded_object_count > 0:
self.reset_sequences(connection, self.models)
if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count:
self.stdout.write(
"Installed %d object(s) from %d fixture(s)"
% (self.loaded_object_count, self.fixture_count)
)
else:
self.stdout.write(
"Installed %d object(s) (of %d) from %d fixture(s)"
% (self.loaded_object_count, self.fixture_object_count, self.fixture_count)
)
def save_obj(self, obj):
"""Save an object if permitted."""
if (
obj.object._meta.app_config in self.excluded_apps or
type(obj.object) in self.excluded_models
):
return False
saved = False
if router.allow_migrate_model(self.using, obj.object.__class__):
saved = True
self.models.add(obj.object.__class__)
try:
obj.save(using=self.using)
# psycopg2 raises ValueError if data contains NUL chars.
except (DatabaseError, IntegrityError, ValueError) as e:
e.args = ('Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s' % {
'object_label': obj.object._meta.label,
'pk': obj.object.pk,
'error_msg': e,
},)
raise
if obj.deferred_fields:
self.objs_with_deferred_fields.append(obj)
return saved
def load_label(self, fixture_label):
"""Load fixtures files for a given label."""
show_progress = self.verbosity >= 3
for fixture_file, fixture_dir, fixture_name in self.find_fixtures(fixture_label):
_, ser_fmt, cmp_fmt = self.parse_name(os.path.basename(fixture_file))
open_method, mode = self.compression_formats[cmp_fmt]
fixture = open_method(fixture_file, mode)
self.fixture_count += 1
objects_in_fixture = 0
loaded_objects_in_fixture = 0
if self.verbosity >= 2:
self.stdout.write(
"Installing %s fixture '%s' from %s."
% (ser_fmt, fixture_name, humanize(fixture_dir))
)
try:
objects = serializers.deserialize(
ser_fmt, fixture, using=self.using, ignorenonexistent=self.ignore,
handle_forward_references=True,
)
for obj in objects:
objects_in_fixture += 1
if self.save_obj(obj):
loaded_objects_in_fixture += 1
if show_progress:
self.stdout.write(
'\rProcessed %i object(s).' % loaded_objects_in_fixture,
ending=''
)
except Exception as e:
if not isinstance(e, CommandError):
e.args = ("Problem installing fixture '%s': %s" % (fixture_file, e),)
raise
finally:
fixture.close()
if objects_in_fixture and show_progress:
self.stdout.write() # Add a newline after progress indicator.
self.loaded_object_count += loaded_objects_in_fixture
self.fixture_object_count += objects_in_fixture
# Warn if the fixture we loaded contains 0 objects.
if objects_in_fixture == 0:
warnings.warn(
"No fixture data found for '%s'. (File format may be "
"invalid.)" % fixture_name,
RuntimeWarning
)
def get_fixture_name_and_dirs(self, fixture_name):
dirname, basename = os.path.split(fixture_name)
if os.path.isabs(fixture_name):
fixture_dirs = [dirname]
else:
fixture_dirs = self.fixture_dirs
if os.path.sep in os.path.normpath(fixture_name):
fixture_dirs = [os.path.join(dir_, dirname) for dir_ in fixture_dirs]
return basename, fixture_dirs
def get_targets(self, fixture_name, ser_fmt, cmp_fmt):
databases = [self.using, None]
cmp_fmts = self.compression_formats if cmp_fmt is None else [cmp_fmt]
ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt]
return {
'%s.%s' % (
fixture_name,
'.'.join([ext for ext in combo if ext]),
) for combo in product(databases, ser_fmts, cmp_fmts)
}
def find_fixture_files_in_dir(self, fixture_dir, fixture_name, targets):
fixture_files_in_dir = []
path = os.path.join(fixture_dir, fixture_name)
for candidate in glob.iglob(glob.escape(path) + '*'):
if os.path.basename(candidate) in targets:
# Save the fixture_dir and fixture_name for future error
# messages.
fixture_files_in_dir.append((candidate, fixture_dir, fixture_name))
return fixture_files_in_dir
@functools.lru_cache(maxsize=None)
def find_fixtures(self, fixture_label):
"""Find fixture files for a given label."""
if fixture_label == READ_STDIN:
return [(READ_STDIN, None, READ_STDIN)]
fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label)
if self.verbosity >= 2:
self.stdout.write("Loading '%s' fixtures..." % fixture_name)
fixture_name, fixture_dirs = self.get_fixture_name_and_dirs(fixture_name)
targets = self.get_targets(fixture_name, ser_fmt, cmp_fmt)
fixture_files = []
for fixture_dir in fixture_dirs:
if self.verbosity >= 2:
self.stdout.write("Checking %s for fixtures..." % humanize(fixture_dir))
fixture_files_in_dir = self.find_fixture_files_in_dir(
fixture_dir, fixture_name, targets,
)
if self.verbosity >= 2 and not fixture_files_in_dir:
self.stdout.write("No fixture '%s' in %s." %
(fixture_name, humanize(fixture_dir)))
# Check kept for backwards-compatibility; it isn't clear why
# duplicates are only allowed in different directories.
if len(fixture_files_in_dir) > 1:
raise CommandError(
"Multiple fixtures named '%s' in %s. Aborting." %
(fixture_name, humanize(fixture_dir)))
fixture_files.extend(fixture_files_in_dir)
if not fixture_files:
raise CommandError("No fixture named '%s' found." % fixture_name)
return fixture_files
@cached_property
def fixture_dirs(self):
"""
Return a list of fixture directories.
The list contains the 'fixtures' subdirectory of each installed
application, if it exists, the directories in FIXTURE_DIRS, and the
current directory.
"""
dirs = []
fixture_dirs = settings.FIXTURE_DIRS
if len(fixture_dirs) != len(set(fixture_dirs)):
raise ImproperlyConfigured("settings.FIXTURE_DIRS contains duplicates.")
for app_config in apps.get_app_configs():
app_label = app_config.label
app_dir = os.path.join(app_config.path, 'fixtures')
if app_dir in fixture_dirs:
raise ImproperlyConfigured(
"'%s' is a default fixture directory for the '%s' app "
"and cannot be listed in settings.FIXTURE_DIRS." % (app_dir, app_label)
)
if self.app_label and app_label != self.app_label:
continue
if os.path.isdir(app_dir):
dirs.append(app_dir)
dirs.extend(fixture_dirs)
dirs.append('')
return [os.path.realpath(d) for d in dirs]
def parse_name(self, fixture_name):
"""
Split fixture name in name, serialization format, compression format.
"""
if fixture_name == READ_STDIN:
if not self.format:
raise CommandError('--format must be specified when reading from stdin.')
return READ_STDIN, self.format, 'stdin'
parts = fixture_name.rsplit('.', 2)
if len(parts) > 1 and parts[-1] in self.compression_formats:
cmp_fmt = parts[-1]
parts = parts[:-1]
else:
cmp_fmt = None
if len(parts) > 1:
if parts[-1] in self.serialization_formats:
ser_fmt = parts[-1]
parts = parts[:-1]
else:
raise CommandError(
"Problem installing fixture '%s': %s is not a known "
"serialization format." % ('.'.join(parts[:-1]), parts[-1]))
else:
ser_fmt = None
name = '.'.join(parts)
return name, ser_fmt, cmp_fmt
class SingleZipReader(zipfile.ZipFile):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if len(self.namelist()) != 1:
raise ValueError("Zip-compressed fixtures must contain one file.")
def read(self):
return zipfile.ZipFile.read(self, self.namelist()[0])
def humanize(dirname):
return "'%s'" % dirname if dirname else 'absolute path'

View File

@@ -0,0 +1,675 @@
import glob
import os
import re
import sys
from functools import total_ordering
from itertools import dropwhile
import django
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.files.temp import NamedTemporaryFile
from django.core.management.base import BaseCommand, CommandError
from django.core.management.utils import (
find_command, handle_extensions, is_ignored_path, popen_wrapper,
)
from django.utils.encoding import DEFAULT_LOCALE_ENCODING
from django.utils.functional import cached_property
from django.utils.jslex import prepare_js_for_gettext
from django.utils.regex_helper import _lazy_re_compile
from django.utils.text import get_text_list
from django.utils.translation import templatize
plural_forms_re = _lazy_re_compile(r'^(?P<value>"Plural-Forms.+?\\n")\s*$', re.MULTILINE | re.DOTALL)
STATUS_OK = 0
NO_LOCALE_DIR = object()
def check_programs(*programs):
for program in programs:
if find_command(program) is None:
raise CommandError(
"Can't find %s. Make sure you have GNU gettext tools 0.15 or "
"newer installed." % program
)
@total_ordering
class TranslatableFile:
def __init__(self, dirpath, file_name, locale_dir):
self.file = file_name
self.dirpath = dirpath
self.locale_dir = locale_dir
def __repr__(self):
return "<%s: %s>" % (
self.__class__.__name__,
os.sep.join([self.dirpath, self.file]),
)
def __eq__(self, other):
return self.path == other.path
def __lt__(self, other):
return self.path < other.path
@property
def path(self):
return os.path.join(self.dirpath, self.file)
class BuildFile:
"""
Represent the state of a translatable file during the build process.
"""
def __init__(self, command, domain, translatable):
self.command = command
self.domain = domain
self.translatable = translatable
@cached_property
def is_templatized(self):
if self.domain == 'djangojs':
return self.command.gettext_version < (0, 18, 3)
elif self.domain == 'django':
file_ext = os.path.splitext(self.translatable.file)[1]
return file_ext != '.py'
return False
@cached_property
def path(self):
return self.translatable.path
@cached_property
def work_path(self):
"""
Path to a file which is being fed into GNU gettext pipeline. This may
be either a translatable or its preprocessed version.
"""
if not self.is_templatized:
return self.path
extension = {
'djangojs': 'c',
'django': 'py',
}.get(self.domain)
filename = '%s.%s' % (self.translatable.file, extension)
return os.path.join(self.translatable.dirpath, filename)
def preprocess(self):
"""
Preprocess (if necessary) a translatable file before passing it to
xgettext GNU gettext utility.
"""
if not self.is_templatized:
return
with open(self.path, encoding='utf-8') as fp:
src_data = fp.read()
if self.domain == 'djangojs':
content = prepare_js_for_gettext(src_data)
elif self.domain == 'django':
content = templatize(src_data, origin=self.path[2:])
with open(self.work_path, 'w', encoding='utf-8') as fp:
fp.write(content)
def postprocess_messages(self, msgs):
"""
Postprocess messages generated by xgettext GNU gettext utility.
Transform paths as if these messages were generated from original
translatable files rather than from preprocessed versions.
"""
if not self.is_templatized:
return msgs
# Remove '.py' suffix
if os.name == 'nt':
# Preserve '.\' prefix on Windows to respect gettext behavior
old_path = self.work_path
new_path = self.path
else:
old_path = self.work_path[2:]
new_path = self.path[2:]
return re.sub(
r'^(#: .*)(' + re.escape(old_path) + r')',
lambda match: match[0].replace(old_path, new_path),
msgs,
flags=re.MULTILINE
)
def cleanup(self):
"""
Remove a preprocessed copy of a translatable file (if any).
"""
if self.is_templatized:
# This check is needed for the case of a symlinked file and its
# source being processed inside a single group (locale dir);
# removing either of those two removes both.
if os.path.exists(self.work_path):
os.unlink(self.work_path)
def normalize_eols(raw_contents):
"""
Take a block of raw text that will be passed through str.splitlines() to
get universal newlines treatment.
Return the resulting block of text with normalized `\n` EOL sequences ready
to be written to disk using current platform's native EOLs.
"""
lines_list = raw_contents.splitlines()
# Ensure last line has its EOL
if lines_list and lines_list[-1]:
lines_list.append('')
return '\n'.join(lines_list)
def write_pot_file(potfile, msgs):
"""
Write the `potfile` with the `msgs` contents, making sure its format is
valid.
"""
pot_lines = msgs.splitlines()
if os.path.exists(potfile):
# Strip the header
lines = dropwhile(len, pot_lines)
else:
lines = []
found, header_read = False, False
for line in pot_lines:
if not found and not header_read:
if 'charset=CHARSET' in line:
found = True
line = line.replace('charset=CHARSET', 'charset=UTF-8')
if not line and not found:
header_read = True
lines.append(line)
msgs = '\n'.join(lines)
# Force newlines of POT files to '\n' to work around
# https://savannah.gnu.org/bugs/index.php?52395
with open(potfile, 'a', encoding='utf-8', newline='\n') as fp:
fp.write(msgs)
class Command(BaseCommand):
help = (
"Runs over the entire source tree of the current directory and "
"pulls out all strings marked for translation. It creates (or updates) a message "
"file in the conf/locale (in the django tree) or locale (for projects and "
"applications) directory.\n\nYou must run this command with one of either the "
"--locale, --exclude, or --all options."
)
translatable_file_class = TranslatableFile
build_file_class = BuildFile
requires_system_checks = []
msgmerge_options = ['-q', '--previous']
msguniq_options = ['--to-code=utf-8']
msgattrib_options = ['--no-obsolete']
xgettext_options = ['--from-code=UTF-8', '--add-comments=Translators']
def add_arguments(self, parser):
parser.add_argument(
'--locale', '-l', default=[], action='append',
help='Creates or updates the message files for the given locale(s) (e.g. pt_BR). '
'Can be used multiple times.',
)
parser.add_argument(
'--exclude', '-x', default=[], action='append',
help='Locales to exclude. Default is none. Can be used multiple times.',
)
parser.add_argument(
'--domain', '-d', default='django',
help='The domain of the message files (default: "django").',
)
parser.add_argument(
'--all', '-a', action='store_true',
help='Updates the message files for all existing locales.',
)
parser.add_argument(
'--extension', '-e', dest='extensions', action='append',
help='The file extension(s) to examine (default: "html,txt,py", or "js" '
'if the domain is "djangojs"). Separate multiple extensions with '
'commas, or use -e multiple times.',
)
parser.add_argument(
'--symlinks', '-s', action='store_true',
help='Follows symlinks to directories when examining source code '
'and templates for translation strings.',
)
parser.add_argument(
'--ignore', '-i', action='append', dest='ignore_patterns',
default=[], metavar='PATTERN',
help='Ignore files or directories matching this glob-style pattern. '
'Use multiple times to ignore more.',
)
parser.add_argument(
'--no-default-ignore', action='store_false', dest='use_default_ignore_patterns',
help="Don't ignore the common glob-style patterns 'CVS', '.*', '*~' and '*.pyc'.",
)
parser.add_argument(
'--no-wrap', action='store_true',
help="Don't break long message lines into several lines.",
)
parser.add_argument(
'--no-location', action='store_true',
help="Don't write '#: filename:line' lines.",
)
parser.add_argument(
'--add-location',
choices=('full', 'file', 'never'), const='full', nargs='?',
help=(
"Controls '#: filename:line' lines. If the option is 'full' "
"(the default if not given), the lines include both file name "
"and line number. If it's 'file', the line number is omitted. If "
"it's 'never', the lines are suppressed (same as --no-location). "
"--add-location requires gettext 0.19 or newer."
),
)
parser.add_argument(
'--no-obsolete', action='store_true',
help="Remove obsolete message strings.",
)
parser.add_argument(
'--keep-pot', action='store_true',
help="Keep .pot file after making messages. Useful when debugging.",
)
def handle(self, *args, **options):
locale = options['locale']
exclude = options['exclude']
self.domain = options['domain']
self.verbosity = options['verbosity']
process_all = options['all']
extensions = options['extensions']
self.symlinks = options['symlinks']
ignore_patterns = options['ignore_patterns']
if options['use_default_ignore_patterns']:
ignore_patterns += ['CVS', '.*', '*~', '*.pyc']
self.ignore_patterns = list(set(ignore_patterns))
# Avoid messing with mutable class variables
if options['no_wrap']:
self.msgmerge_options = self.msgmerge_options[:] + ['--no-wrap']
self.msguniq_options = self.msguniq_options[:] + ['--no-wrap']
self.msgattrib_options = self.msgattrib_options[:] + ['--no-wrap']
self.xgettext_options = self.xgettext_options[:] + ['--no-wrap']
if options['no_location']:
self.msgmerge_options = self.msgmerge_options[:] + ['--no-location']
self.msguniq_options = self.msguniq_options[:] + ['--no-location']
self.msgattrib_options = self.msgattrib_options[:] + ['--no-location']
self.xgettext_options = self.xgettext_options[:] + ['--no-location']
if options['add_location']:
if self.gettext_version < (0, 19):
raise CommandError(
"The --add-location option requires gettext 0.19 or later. "
"You have %s." % '.'.join(str(x) for x in self.gettext_version)
)
arg_add_location = "--add-location=%s" % options['add_location']
self.msgmerge_options = self.msgmerge_options[:] + [arg_add_location]
self.msguniq_options = self.msguniq_options[:] + [arg_add_location]
self.msgattrib_options = self.msgattrib_options[:] + [arg_add_location]
self.xgettext_options = self.xgettext_options[:] + [arg_add_location]
self.no_obsolete = options['no_obsolete']
self.keep_pot = options['keep_pot']
if self.domain not in ('django', 'djangojs'):
raise CommandError("currently makemessages only supports domains "
"'django' and 'djangojs'")
if self.domain == 'djangojs':
exts = extensions or ['js']
else:
exts = extensions or ['html', 'txt', 'py']
self.extensions = handle_extensions(exts)
if (not locale and not exclude and not process_all) or self.domain is None:
raise CommandError(
"Type '%s help %s' for usage information."
% (os.path.basename(sys.argv[0]), sys.argv[1])
)
if self.verbosity > 1:
self.stdout.write(
'examining files with the extensions: %s'
% get_text_list(list(self.extensions), 'and')
)
self.invoked_for_django = False
self.locale_paths = []
self.default_locale_path = None
if os.path.isdir(os.path.join('conf', 'locale')):
self.locale_paths = [os.path.abspath(os.path.join('conf', 'locale'))]
self.default_locale_path = self.locale_paths[0]
self.invoked_for_django = True
else:
if self.settings_available:
self.locale_paths.extend(settings.LOCALE_PATHS)
# Allow to run makemessages inside an app dir
if os.path.isdir('locale'):
self.locale_paths.append(os.path.abspath('locale'))
if self.locale_paths:
self.default_locale_path = self.locale_paths[0]
os.makedirs(self.default_locale_path, exist_ok=True)
# Build locale list
looks_like_locale = re.compile(r'[a-z]{2}')
locale_dirs = filter(os.path.isdir, glob.glob('%s/*' % self.default_locale_path))
all_locales = [
lang_code for lang_code in map(os.path.basename, locale_dirs)
if looks_like_locale.match(lang_code)
]
# Account for excluded locales
if process_all:
locales = all_locales
else:
locales = locale or all_locales
locales = set(locales).difference(exclude)
if locales:
check_programs('msguniq', 'msgmerge', 'msgattrib')
check_programs('xgettext')
try:
potfiles = self.build_potfiles()
# Build po files for each selected locale
for locale in locales:
if '-' in locale:
self.stdout.write(
'invalid locale %s, did you mean %s?' % (
locale,
locale.replace('-', '_'),
),
)
continue
if self.verbosity > 0:
self.stdout.write('processing locale %s' % locale)
for potfile in potfiles:
self.write_po_file(potfile, locale)
finally:
if not self.keep_pot:
self.remove_potfiles()
@cached_property
def gettext_version(self):
# Gettext tools will output system-encoded bytestrings instead of UTF-8,
# when looking up the version. It's especially a problem on Windows.
out, err, status = popen_wrapper(
['xgettext', '--version'],
stdout_encoding=DEFAULT_LOCALE_ENCODING,
)
m = re.search(r'(\d+)\.(\d+)\.?(\d+)?', out)
if m:
return tuple(int(d) for d in m.groups() if d is not None)
else:
raise CommandError("Unable to get gettext version. Is it installed?")
@cached_property
def settings_available(self):
try:
settings.LOCALE_PATHS
except ImproperlyConfigured:
if self.verbosity > 1:
self.stderr.write("Running without configured settings.")
return False
return True
def build_potfiles(self):
"""
Build pot files and apply msguniq to them.
"""
file_list = self.find_files(".")
self.remove_potfiles()
self.process_files(file_list)
potfiles = []
for path in self.locale_paths:
potfile = os.path.join(path, '%s.pot' % self.domain)
if not os.path.exists(potfile):
continue
args = ['msguniq'] + self.msguniq_options + [potfile]
msgs, errors, status = popen_wrapper(args)
if errors:
if status != STATUS_OK:
raise CommandError(
"errors happened while running msguniq\n%s" % errors)
elif self.verbosity > 0:
self.stdout.write(errors)
msgs = normalize_eols(msgs)
with open(potfile, 'w', encoding='utf-8') as fp:
fp.write(msgs)
potfiles.append(potfile)
return potfiles
def remove_potfiles(self):
for path in self.locale_paths:
pot_path = os.path.join(path, '%s.pot' % self.domain)
if os.path.exists(pot_path):
os.unlink(pot_path)
def find_files(self, root):
"""
Get all files in the given root. Also check that there is a matching
locale dir for each file.
"""
all_files = []
ignored_roots = []
if self.settings_available:
ignored_roots = [os.path.normpath(p) for p in (settings.MEDIA_ROOT, settings.STATIC_ROOT) if p]
for dirpath, dirnames, filenames in os.walk(root, topdown=True, followlinks=self.symlinks):
for dirname in dirnames[:]:
if (is_ignored_path(os.path.normpath(os.path.join(dirpath, dirname)), self.ignore_patterns) or
os.path.join(os.path.abspath(dirpath), dirname) in ignored_roots):
dirnames.remove(dirname)
if self.verbosity > 1:
self.stdout.write('ignoring directory %s' % dirname)
elif dirname == 'locale':
dirnames.remove(dirname)
self.locale_paths.insert(0, os.path.join(os.path.abspath(dirpath), dirname))
for filename in filenames:
file_path = os.path.normpath(os.path.join(dirpath, filename))
file_ext = os.path.splitext(filename)[1]
if file_ext not in self.extensions or is_ignored_path(file_path, self.ignore_patterns):
if self.verbosity > 1:
self.stdout.write('ignoring file %s in %s' % (filename, dirpath))
else:
locale_dir = None
for path in self.locale_paths:
if os.path.abspath(dirpath).startswith(os.path.dirname(path)):
locale_dir = path
break
locale_dir = locale_dir or self.default_locale_path or NO_LOCALE_DIR
all_files.append(self.translatable_file_class(dirpath, filename, locale_dir))
return sorted(all_files)
def process_files(self, file_list):
"""
Group translatable files by locale directory and run pot file build
process for each group.
"""
file_groups = {}
for translatable in file_list:
file_group = file_groups.setdefault(translatable.locale_dir, [])
file_group.append(translatable)
for locale_dir, files in file_groups.items():
self.process_locale_dir(locale_dir, files)
def process_locale_dir(self, locale_dir, files):
"""
Extract translatable literals from the specified files, creating or
updating the POT file for a given locale directory.
Use the xgettext GNU gettext utility.
"""
build_files = []
for translatable in files:
if self.verbosity > 1:
self.stdout.write('processing file %s in %s' % (
translatable.file, translatable.dirpath
))
if self.domain not in ('djangojs', 'django'):
continue
build_file = self.build_file_class(self, self.domain, translatable)
try:
build_file.preprocess()
except UnicodeDecodeError as e:
self.stdout.write(
'UnicodeDecodeError: skipped file %s in %s (reason: %s)' % (
translatable.file, translatable.dirpath, e,
)
)
continue
except BaseException:
# Cleanup before exit.
for build_file in build_files:
build_file.cleanup()
raise
build_files.append(build_file)
if self.domain == 'djangojs':
is_templatized = build_file.is_templatized
args = [
'xgettext',
'-d', self.domain,
'--language=%s' % ('C' if is_templatized else 'JavaScript',),
'--keyword=gettext_noop',
'--keyword=gettext_lazy',
'--keyword=ngettext_lazy:1,2',
'--keyword=pgettext:1c,2',
'--keyword=npgettext:1c,2,3',
'--output=-',
]
elif self.domain == 'django':
args = [
'xgettext',
'-d', self.domain,
'--language=Python',
'--keyword=gettext_noop',
'--keyword=gettext_lazy',
'--keyword=ngettext_lazy:1,2',
'--keyword=pgettext:1c,2',
'--keyword=npgettext:1c,2,3',
'--keyword=pgettext_lazy:1c,2',
'--keyword=npgettext_lazy:1c,2,3',
'--output=-',
]
else:
return
input_files = [bf.work_path for bf in build_files]
with NamedTemporaryFile(mode='w+') as input_files_list:
input_files_list.write('\n'.join(input_files))
input_files_list.flush()
args.extend(['--files-from', input_files_list.name])
args.extend(self.xgettext_options)
msgs, errors, status = popen_wrapper(args)
if errors:
if status != STATUS_OK:
for build_file in build_files:
build_file.cleanup()
raise CommandError(
'errors happened while running xgettext on %s\n%s' %
('\n'.join(input_files), errors)
)
elif self.verbosity > 0:
# Print warnings
self.stdout.write(errors)
if msgs:
if locale_dir is NO_LOCALE_DIR:
for build_file in build_files:
build_file.cleanup()
file_path = os.path.normpath(build_files[0].path)
raise CommandError(
"Unable to find a locale path to store translations for "
"file %s. Make sure the 'locale' directory exists in an "
"app or LOCALE_PATHS setting is set." % file_path
)
for build_file in build_files:
msgs = build_file.postprocess_messages(msgs)
potfile = os.path.join(locale_dir, '%s.pot' % self.domain)
write_pot_file(potfile, msgs)
for build_file in build_files:
build_file.cleanup()
def write_po_file(self, potfile, locale):
"""
Create or update the PO file for self.domain and `locale`.
Use contents of the existing `potfile`.
Use msgmerge and msgattrib GNU gettext utilities.
"""
basedir = os.path.join(os.path.dirname(potfile), locale, 'LC_MESSAGES')
os.makedirs(basedir, exist_ok=True)
pofile = os.path.join(basedir, '%s.po' % self.domain)
if os.path.exists(pofile):
args = ['msgmerge'] + self.msgmerge_options + [pofile, potfile]
msgs, errors, status = popen_wrapper(args)
if errors:
if status != STATUS_OK:
raise CommandError(
"errors happened while running msgmerge\n%s" % errors)
elif self.verbosity > 0:
self.stdout.write(errors)
else:
with open(potfile, encoding='utf-8') as fp:
msgs = fp.read()
if not self.invoked_for_django:
msgs = self.copy_plural_forms(msgs, locale)
msgs = normalize_eols(msgs)
msgs = msgs.replace(
"#. #-#-#-#-# %s.pot (PACKAGE VERSION) #-#-#-#-#\n" % self.domain, "")
with open(pofile, 'w', encoding='utf-8') as fp:
fp.write(msgs)
if self.no_obsolete:
args = ['msgattrib'] + self.msgattrib_options + ['-o', pofile, pofile]
msgs, errors, status = popen_wrapper(args)
if errors:
if status != STATUS_OK:
raise CommandError(
"errors happened while running msgattrib\n%s" % errors)
elif self.verbosity > 0:
self.stdout.write(errors)
def copy_plural_forms(self, msgs, locale):
"""
Copy plural forms header contents from a Django catalog of locale to
the msgs string, inserting it at the right place. msgs should be the
contents of a newly created .po file.
"""
django_dir = os.path.normpath(os.path.join(os.path.dirname(django.__file__)))
if self.domain == 'djangojs':
domains = ('djangojs', 'django')
else:
domains = ('django',)
for domain in domains:
django_po = os.path.join(django_dir, 'conf', 'locale', locale, 'LC_MESSAGES', '%s.po' % domain)
if os.path.exists(django_po):
with open(django_po, encoding='utf-8') as fp:
m = plural_forms_re.search(fp.read())
if m:
plural_form_line = m['value']
if self.verbosity > 1:
self.stdout.write('copying plural forms: %s' % plural_form_line)
lines = []
found = False
for line in msgs.splitlines():
if not found and (not line or plural_forms_re.search(line)):
line = plural_form_line
found = True
lines.append(line)
msgs = '\n'.join(lines)
break
return msgs

View File

@@ -0,0 +1,325 @@
import os
import sys
import warnings
from itertools import takewhile
from django.apps import apps
from django.conf import settings
from django.core.management.base import (
BaseCommand, CommandError, no_translations,
)
from django.db import DEFAULT_DB_ALIAS, OperationalError, connections, router
from django.db.migrations import Migration
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.questioner import (
InteractiveMigrationQuestioner, MigrationQuestioner,
NonInteractiveMigrationQuestioner,
)
from django.db.migrations.state import ProjectState
from django.db.migrations.utils import get_migration_name_timestamp
from django.db.migrations.writer import MigrationWriter
class Command(BaseCommand):
help = "Creates new migration(s) for apps."
def add_arguments(self, parser):
parser.add_argument(
'args', metavar='app_label', nargs='*',
help='Specify the app label(s) to create migrations for.',
)
parser.add_argument(
'--dry-run', action='store_true',
help="Just show what migrations would be made; don't actually write them.",
)
parser.add_argument(
'--merge', action='store_true',
help="Enable fixing of migration conflicts.",
)
parser.add_argument(
'--empty', action='store_true',
help="Create an empty migration.",
)
parser.add_argument(
'--noinput', '--no-input', action='store_false', dest='interactive',
help='Tells Django to NOT prompt the user for input of any kind.',
)
parser.add_argument(
'-n', '--name',
help="Use this name for migration file(s).",
)
parser.add_argument(
'--no-header', action='store_false', dest='include_header',
help='Do not add header comments to new migration file(s).',
)
parser.add_argument(
'--check', action='store_true', dest='check_changes',
help='Exit with a non-zero status if model changes are missing migrations.',
)
@no_translations
def handle(self, *app_labels, **options):
self.verbosity = options['verbosity']
self.interactive = options['interactive']
self.dry_run = options['dry_run']
self.merge = options['merge']
self.empty = options['empty']
self.migration_name = options['name']
if self.migration_name and not self.migration_name.isidentifier():
raise CommandError('The migration name must be a valid Python identifier.')
self.include_header = options['include_header']
check_changes = options['check_changes']
# Make sure the app they asked for exists
app_labels = set(app_labels)
has_bad_labels = False
for app_label in app_labels:
try:
apps.get_app_config(app_label)
except LookupError as err:
self.stderr.write(str(err))
has_bad_labels = True
if has_bad_labels:
sys.exit(2)
# Load the current graph state. Pass in None for the connection so
# the loader doesn't try to resolve replaced migrations from DB.
loader = MigrationLoader(None, ignore_no_migrations=True)
# Raise an error if any migrations are applied before their dependencies.
consistency_check_labels = {config.label for config in apps.get_app_configs()}
# Non-default databases are only checked if database routers used.
aliases_to_check = connections if settings.DATABASE_ROUTERS else [DEFAULT_DB_ALIAS]
for alias in sorted(aliases_to_check):
connection = connections[alias]
if (connection.settings_dict['ENGINE'] != 'django.db.backends.dummy' and any(
# At least one model must be migrated to the database.
router.allow_migrate(connection.alias, app_label, model_name=model._meta.object_name)
for app_label in consistency_check_labels
for model in apps.get_app_config(app_label).get_models()
)):
try:
loader.check_consistent_history(connection)
except OperationalError as error:
warnings.warn(
"Got an error checking a consistent migration history "
"performed for database connection '%s': %s"
% (alias, error),
RuntimeWarning,
)
# Before anything else, see if there's conflicting apps and drop out
# hard if there are any and they don't want to merge
conflicts = loader.detect_conflicts()
# If app_labels is specified, filter out conflicting migrations for unspecified apps
if app_labels:
conflicts = {
app_label: conflict for app_label, conflict in conflicts.items()
if app_label in app_labels
}
if conflicts and not self.merge:
name_str = "; ".join(
"%s in %s" % (", ".join(names), app)
for app, names in conflicts.items()
)
raise CommandError(
"Conflicting migrations detected; multiple leaf nodes in the "
"migration graph: (%s).\nTo fix them run "
"'python manage.py makemigrations --merge'" % name_str
)
# If they want to merge and there's nothing to merge, then politely exit
if self.merge and not conflicts:
self.stdout.write("No conflicts detected to merge.")
return
# If they want to merge and there is something to merge, then
# divert into the merge code
if self.merge and conflicts:
return self.handle_merge(loader, conflicts)
if self.interactive:
questioner = InteractiveMigrationQuestioner(specified_apps=app_labels, dry_run=self.dry_run)
else:
questioner = NonInteractiveMigrationQuestioner(specified_apps=app_labels, dry_run=self.dry_run)
# Set up autodetector
autodetector = MigrationAutodetector(
loader.project_state(),
ProjectState.from_apps(apps),
questioner,
)
# If they want to make an empty migration, make one for each app
if self.empty:
if not app_labels:
raise CommandError("You must supply at least one app label when using --empty.")
# Make a fake changes() result we can pass to arrange_for_graph
changes = {
app: [Migration("custom", app)]
for app in app_labels
}
changes = autodetector.arrange_for_graph(
changes=changes,
graph=loader.graph,
migration_name=self.migration_name,
)
self.write_migration_files(changes)
return
# Detect changes
changes = autodetector.changes(
graph=loader.graph,
trim_to_apps=app_labels or None,
convert_apps=app_labels or None,
migration_name=self.migration_name,
)
if not changes:
# No changes? Tell them.
if self.verbosity >= 1:
if app_labels:
if len(app_labels) == 1:
self.stdout.write("No changes detected in app '%s'" % app_labels.pop())
else:
self.stdout.write("No changes detected in apps '%s'" % ("', '".join(app_labels)))
else:
self.stdout.write("No changes detected")
else:
self.write_migration_files(changes)
if check_changes:
sys.exit(1)
def write_migration_files(self, changes):
"""
Take a changes dict and write them out as migration files.
"""
directory_created = {}
for app_label, app_migrations in changes.items():
if self.verbosity >= 1:
self.stdout.write(self.style.MIGRATE_HEADING("Migrations for '%s':" % app_label))
for migration in app_migrations:
# Describe the migration
writer = MigrationWriter(migration, self.include_header)
if self.verbosity >= 1:
# Display a relative path if it's below the current working
# directory, or an absolute path otherwise.
try:
migration_string = os.path.relpath(writer.path)
except ValueError:
migration_string = writer.path
if migration_string.startswith('..'):
migration_string = writer.path
self.stdout.write(' %s\n' % self.style.MIGRATE_LABEL(migration_string))
for operation in migration.operations:
self.stdout.write(' - %s' % operation.describe())
if not self.dry_run:
# Write the migrations file to the disk.
migrations_directory = os.path.dirname(writer.path)
if not directory_created.get(app_label):
os.makedirs(migrations_directory, exist_ok=True)
init_path = os.path.join(migrations_directory, "__init__.py")
if not os.path.isfile(init_path):
open(init_path, "w").close()
# We just do this once per app
directory_created[app_label] = True
migration_string = writer.as_string()
with open(writer.path, "w", encoding='utf-8') as fh:
fh.write(migration_string)
elif self.verbosity == 3:
# Alternatively, makemigrations --dry-run --verbosity 3
# will output the migrations to stdout rather than saving
# the file to the disk.
self.stdout.write(self.style.MIGRATE_HEADING(
"Full migrations file '%s':" % writer.filename
))
self.stdout.write(writer.as_string())
def handle_merge(self, loader, conflicts):
"""
Handles merging together conflicted migrations interactively,
if it's safe; otherwise, advises on how to fix it.
"""
if self.interactive:
questioner = InteractiveMigrationQuestioner()
else:
questioner = MigrationQuestioner(defaults={'ask_merge': True})
for app_label, migration_names in conflicts.items():
# Grab out the migrations in question, and work out their
# common ancestor.
merge_migrations = []
for migration_name in migration_names:
migration = loader.get_migration(app_label, migration_name)
migration.ancestry = [
mig for mig in loader.graph.forwards_plan((app_label, migration_name))
if mig[0] == migration.app_label
]
merge_migrations.append(migration)
def all_items_equal(seq):
return all(item == seq[0] for item in seq[1:])
merge_migrations_generations = zip(*(m.ancestry for m in merge_migrations))
common_ancestor_count = sum(1 for common_ancestor_generation
in takewhile(all_items_equal, merge_migrations_generations))
if not common_ancestor_count:
raise ValueError("Could not find common ancestor of %s" % migration_names)
# Now work out the operations along each divergent branch
for migration in merge_migrations:
migration.branch = migration.ancestry[common_ancestor_count:]
migrations_ops = (loader.get_migration(node_app, node_name).operations
for node_app, node_name in migration.branch)
migration.merged_operations = sum(migrations_ops, [])
# In future, this could use some of the Optimizer code
# (can_optimize_through) to automatically see if they're
# mergeable. For now, we always just prompt the user.
if self.verbosity > 0:
self.stdout.write(self.style.MIGRATE_HEADING("Merging %s" % app_label))
for migration in merge_migrations:
self.stdout.write(self.style.MIGRATE_LABEL(" Branch %s" % migration.name))
for operation in migration.merged_operations:
self.stdout.write(' - %s' % operation.describe())
if questioner.ask_merge(app_label):
# If they still want to merge it, then write out an empty
# file depending on the migrations needing merging.
numbers = [
MigrationAutodetector.parse_number(migration.name)
for migration in merge_migrations
]
try:
biggest_number = max(x for x in numbers if x is not None)
except ValueError:
biggest_number = 1
subclass = type("Migration", (Migration,), {
"dependencies": [(app_label, migration.name) for migration in merge_migrations],
})
parts = ['%04i' % (biggest_number + 1)]
if self.migration_name:
parts.append(self.migration_name)
else:
parts.append('merge')
leaf_names = '_'.join(sorted(migration.name for migration in merge_migrations))
if len(leaf_names) > 47:
parts.append(get_migration_name_timestamp())
else:
parts.append(leaf_names)
migration_name = '_'.join(parts)
new_migration = subclass(migration_name, app_label)
writer = MigrationWriter(new_migration, self.include_header)
if not self.dry_run:
# Write the merge migrations file to the disk
with open(writer.path, "w", encoding='utf-8') as fh:
fh.write(writer.as_string())
if self.verbosity > 0:
self.stdout.write("\nCreated new merge migration %s" % writer.path)
elif self.verbosity == 3:
# Alternatively, makemigrations --merge --dry-run --verbosity 3
# will output the merge migrations to stdout rather than saving
# the file to the disk.
self.stdout.write(self.style.MIGRATE_HEADING(
"Full merge migrations file '%s':" % writer.filename
))
self.stdout.write(writer.as_string())

View File

@@ -0,0 +1,386 @@
import sys
import time
from importlib import import_module
from django.apps import apps
from django.core.management.base import (
BaseCommand, CommandError, no_translations,
)
from django.core.management.sql import (
emit_post_migrate_signal, emit_pre_migrate_signal,
)
from django.db import DEFAULT_DB_ALIAS, connections, router
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.loader import AmbiguityError
from django.db.migrations.state import ModelState, ProjectState
from django.utils.module_loading import module_has_submodule
from django.utils.text import Truncator
class Command(BaseCommand):
help = "Updates database schema. Manages both apps with migrations and those without."
requires_system_checks = []
def add_arguments(self, parser):
parser.add_argument(
'--skip-checks', action='store_true',
help='Skip system checks.',
)
parser.add_argument(
'app_label', nargs='?',
help='App label of an application to synchronize the state.',
)
parser.add_argument(
'migration_name', nargs='?',
help='Database state will be brought to the state after that '
'migration. Use the name "zero" to unapply all migrations.',
)
parser.add_argument(
'--noinput', '--no-input', action='store_false', dest='interactive',
help='Tells Django to NOT prompt the user for input of any kind.',
)
parser.add_argument(
'--database',
default=DEFAULT_DB_ALIAS,
help='Nominates a database to synchronize. Defaults to the "default" database.',
)
parser.add_argument(
'--fake', action='store_true',
help='Mark migrations as run without actually running them.',
)
parser.add_argument(
'--fake-initial', action='store_true',
help='Detect if tables already exist and fake-apply initial migrations if so. Make sure '
'that the current database schema matches your initial migration before using this '
'flag. Django will only check for an existing table name.',
)
parser.add_argument(
'--plan', action='store_true',
help='Shows a list of the migration actions that will be performed.',
)
parser.add_argument(
'--run-syncdb', action='store_true',
help='Creates tables for apps without migrations.',
)
parser.add_argument(
'--check', action='store_true', dest='check_unapplied',
help='Exits with a non-zero status if unapplied migrations exist.',
)
@no_translations
def handle(self, *args, **options):
database = options['database']
if not options['skip_checks']:
self.check(databases=[database])
self.verbosity = options['verbosity']
self.interactive = options['interactive']
# Import the 'management' module within each installed app, to register
# dispatcher events.
for app_config in apps.get_app_configs():
if module_has_submodule(app_config.module, "management"):
import_module('.management', app_config.name)
# Get the database we're operating from
connection = connections[database]
# Hook for backends needing any database preparation
connection.prepare_database()
# Work out which apps have migrations and which do not
executor = MigrationExecutor(connection, self.migration_progress_callback)
# Raise an error if any migrations are applied before their dependencies.
executor.loader.check_consistent_history(connection)
# Before anything else, see if there's conflicting apps and drop out
# hard if there are any
conflicts = executor.loader.detect_conflicts()
if conflicts:
name_str = "; ".join(
"%s in %s" % (", ".join(names), app)
for app, names in conflicts.items()
)
raise CommandError(
"Conflicting migrations detected; multiple leaf nodes in the "
"migration graph: (%s).\nTo fix them run "
"'python manage.py makemigrations --merge'" % name_str
)
# If they supplied command line arguments, work out what they mean.
run_syncdb = options['run_syncdb']
target_app_labels_only = True
if options['app_label']:
# Validate app_label.
app_label = options['app_label']
try:
apps.get_app_config(app_label)
except LookupError as err:
raise CommandError(str(err))
if run_syncdb:
if app_label in executor.loader.migrated_apps:
raise CommandError("Can't use run_syncdb with app '%s' as it has migrations." % app_label)
elif app_label not in executor.loader.migrated_apps:
raise CommandError("App '%s' does not have migrations." % app_label)
if options['app_label'] and options['migration_name']:
migration_name = options['migration_name']
if migration_name == "zero":
targets = [(app_label, None)]
else:
try:
migration = executor.loader.get_migration_by_prefix(app_label, migration_name)
except AmbiguityError:
raise CommandError(
"More than one migration matches '%s' in app '%s'. "
"Please be more specific." %
(migration_name, app_label)
)
except KeyError:
raise CommandError("Cannot find a migration matching '%s' from app '%s'." % (
migration_name, app_label))
target = (app_label, migration.name)
# Partially applied squashed migrations are not included in the
# graph, use the last replacement instead.
if (
target not in executor.loader.graph.nodes and
target in executor.loader.replacements
):
incomplete_migration = executor.loader.replacements[target]
target = incomplete_migration.replaces[-1]
targets = [target]
target_app_labels_only = False
elif options['app_label']:
targets = [key for key in executor.loader.graph.leaf_nodes() if key[0] == app_label]
else:
targets = executor.loader.graph.leaf_nodes()
plan = executor.migration_plan(targets)
exit_dry = plan and options['check_unapplied']
if options['plan']:
self.stdout.write('Planned operations:', self.style.MIGRATE_LABEL)
if not plan:
self.stdout.write(' No planned migration operations.')
for migration, backwards in plan:
self.stdout.write(str(migration), self.style.MIGRATE_HEADING)
for operation in migration.operations:
message, is_error = self.describe_operation(operation, backwards)
style = self.style.WARNING if is_error else None
self.stdout.write(' ' + message, style)
if exit_dry:
sys.exit(1)
return
if exit_dry:
sys.exit(1)
# At this point, ignore run_syncdb if there aren't any apps to sync.
run_syncdb = options['run_syncdb'] and executor.loader.unmigrated_apps
# Print some useful info
if self.verbosity >= 1:
self.stdout.write(self.style.MIGRATE_HEADING("Operations to perform:"))
if run_syncdb:
if options['app_label']:
self.stdout.write(
self.style.MIGRATE_LABEL(" Synchronize unmigrated app: %s" % app_label)
)
else:
self.stdout.write(
self.style.MIGRATE_LABEL(" Synchronize unmigrated apps: ") +
(", ".join(sorted(executor.loader.unmigrated_apps)))
)
if target_app_labels_only:
self.stdout.write(
self.style.MIGRATE_LABEL(" Apply all migrations: ") +
(", ".join(sorted({a for a, n in targets})) or "(none)")
)
else:
if targets[0][1] is None:
self.stdout.write(
self.style.MIGRATE_LABEL(' Unapply all migrations: ') +
str(targets[0][0])
)
else:
self.stdout.write(self.style.MIGRATE_LABEL(
" Target specific migration: ") + "%s, from %s"
% (targets[0][1], targets[0][0])
)
pre_migrate_state = executor._create_project_state(with_applied_migrations=True)
pre_migrate_apps = pre_migrate_state.apps
emit_pre_migrate_signal(
self.verbosity, self.interactive, connection.alias, stdout=self.stdout, apps=pre_migrate_apps, plan=plan,
)
# Run the syncdb phase.
if run_syncdb:
if self.verbosity >= 1:
self.stdout.write(self.style.MIGRATE_HEADING("Synchronizing apps without migrations:"))
if options['app_label']:
self.sync_apps(connection, [app_label])
else:
self.sync_apps(connection, executor.loader.unmigrated_apps)
# Migrate!
if self.verbosity >= 1:
self.stdout.write(self.style.MIGRATE_HEADING("Running migrations:"))
if not plan:
if self.verbosity >= 1:
self.stdout.write(" No migrations to apply.")
# If there's changes that aren't in migrations yet, tell them how to fix it.
autodetector = MigrationAutodetector(
executor.loader.project_state(),
ProjectState.from_apps(apps),
)
changes = autodetector.changes(graph=executor.loader.graph)
if changes:
self.stdout.write(self.style.NOTICE(
" Your models in app(s): %s have changes that are not "
"yet reflected in a migration, and so won't be "
"applied." % ", ".join(repr(app) for app in sorted(changes))
))
self.stdout.write(self.style.NOTICE(
" Run 'manage.py makemigrations' to make new "
"migrations, and then re-run 'manage.py migrate' to "
"apply them."
))
fake = False
fake_initial = False
else:
fake = options['fake']
fake_initial = options['fake_initial']
post_migrate_state = executor.migrate(
targets, plan=plan, state=pre_migrate_state.clone(), fake=fake,
fake_initial=fake_initial,
)
# post_migrate signals have access to all models. Ensure that all models
# are reloaded in case any are delayed.
post_migrate_state.clear_delayed_apps_cache()
post_migrate_apps = post_migrate_state.apps
# Re-render models of real apps to include relationships now that
# we've got a final state. This wouldn't be necessary if real apps
# models were rendered with relationships in the first place.
with post_migrate_apps.bulk_update():
model_keys = []
for model_state in post_migrate_apps.real_models:
model_key = model_state.app_label, model_state.name_lower
model_keys.append(model_key)
post_migrate_apps.unregister_model(*model_key)
post_migrate_apps.render_multiple([
ModelState.from_model(apps.get_model(*model)) for model in model_keys
])
# Send the post_migrate signal, so individual apps can do whatever they need
# to do at this point.
emit_post_migrate_signal(
self.verbosity, self.interactive, connection.alias, stdout=self.stdout, apps=post_migrate_apps, plan=plan,
)
def migration_progress_callback(self, action, migration=None, fake=False):
if self.verbosity >= 1:
compute_time = self.verbosity > 1
if action == "apply_start":
if compute_time:
self.start = time.monotonic()
self.stdout.write(" Applying %s..." % migration, ending="")
self.stdout.flush()
elif action == "apply_success":
elapsed = " (%.3fs)" % (time.monotonic() - self.start) if compute_time else ""
if fake:
self.stdout.write(self.style.SUCCESS(" FAKED" + elapsed))
else:
self.stdout.write(self.style.SUCCESS(" OK" + elapsed))
elif action == "unapply_start":
if compute_time:
self.start = time.monotonic()
self.stdout.write(" Unapplying %s..." % migration, ending="")
self.stdout.flush()
elif action == "unapply_success":
elapsed = " (%.3fs)" % (time.monotonic() - self.start) if compute_time else ""
if fake:
self.stdout.write(self.style.SUCCESS(" FAKED" + elapsed))
else:
self.stdout.write(self.style.SUCCESS(" OK" + elapsed))
elif action == "render_start":
if compute_time:
self.start = time.monotonic()
self.stdout.write(" Rendering model states...", ending="")
self.stdout.flush()
elif action == "render_success":
elapsed = " (%.3fs)" % (time.monotonic() - self.start) if compute_time else ""
self.stdout.write(self.style.SUCCESS(" DONE" + elapsed))
def sync_apps(self, connection, app_labels):
"""Run the old syncdb-style operation on a list of app_labels."""
with connection.cursor() as cursor:
tables = connection.introspection.table_names(cursor)
# Build the manifest of apps and models that are to be synchronized.
all_models = [
(
app_config.label,
router.get_migratable_models(app_config, connection.alias, include_auto_created=False),
)
for app_config in apps.get_app_configs()
if app_config.models_module is not None and app_config.label in app_labels
]
def model_installed(model):
opts = model._meta
converter = connection.introspection.identifier_converter
return not (
(converter(opts.db_table) in tables) or
(opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)
)
manifest = {
app_name: list(filter(model_installed, model_list))
for app_name, model_list in all_models
}
# Create the tables for each model
if self.verbosity >= 1:
self.stdout.write(' Creating tables...')
with connection.schema_editor() as editor:
for app_name, model_list in manifest.items():
for model in model_list:
# Never install unmanaged models, etc.
if not model._meta.can_migrate(connection):
continue
if self.verbosity >= 3:
self.stdout.write(
' Processing %s.%s model' % (app_name, model._meta.object_name)
)
if self.verbosity >= 1:
self.stdout.write(' Creating table %s' % model._meta.db_table)
editor.create_model(model)
# Deferred SQL is executed when exiting the editor's context.
if self.verbosity >= 1:
self.stdout.write(' Running deferred SQL...')
@staticmethod
def describe_operation(operation, backwards):
"""Return a string that describes a migration operation for --plan."""
prefix = ''
is_error = False
if hasattr(operation, 'code'):
code = operation.reverse_code if backwards else operation.code
action = (code.__doc__ or '') if code else None
elif hasattr(operation, 'sql'):
action = operation.reverse_sql if backwards else operation.sql
else:
action = ''
if backwards:
prefix = 'Undo '
if action is not None:
action = str(action).replace('\n', '')
elif backwards:
action = 'IRREVERSIBLE'
is_error = True
if action:
action = ' -> ' + action
truncated = Truncator(action)
return prefix + operation.describe() + truncated.chars(40), is_error

View File

@@ -0,0 +1,164 @@
import errno
import os
import re
import socket
import sys
from datetime import datetime
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.core.servers.basehttp import (
WSGIServer, get_internal_wsgi_application, run,
)
from django.utils import autoreload
from django.utils.regex_helper import _lazy_re_compile
naiveip_re = _lazy_re_compile(r"""^(?:
(?P<addr>
(?P<ipv4>\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address
(?P<ipv6>\[[a-fA-F0-9:]+\]) | # IPv6 address
(?P<fqdn>[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN
):)?(?P<port>\d+)$""", re.X)
class Command(BaseCommand):
help = "Starts a lightweight web server for development."
# Validation is called explicitly each time the server is reloaded.
requires_system_checks = []
stealth_options = ('shutdown_message',)
suppressed_base_arguments = {'--verbosity', '--traceback'}
default_addr = '127.0.0.1'
default_addr_ipv6 = '::1'
default_port = '8000'
protocol = 'http'
server_cls = WSGIServer
def add_arguments(self, parser):
parser.add_argument(
'addrport', nargs='?',
help='Optional port number, or ipaddr:port'
)
parser.add_argument(
'--ipv6', '-6', action='store_true', dest='use_ipv6',
help='Tells Django to use an IPv6 address.',
)
parser.add_argument(
'--nothreading', action='store_false', dest='use_threading',
help='Tells Django to NOT use threading.',
)
parser.add_argument(
'--noreload', action='store_false', dest='use_reloader',
help='Tells Django to NOT use the auto-reloader.',
)
parser.add_argument(
'--skip-checks', action='store_true',
help='Skip system checks.',
)
def execute(self, *args, **options):
if options['no_color']:
# We rely on the environment because it's currently the only
# way to reach WSGIRequestHandler. This seems an acceptable
# compromise considering `runserver` runs indefinitely.
os.environ["DJANGO_COLORS"] = "nocolor"
super().execute(*args, **options)
def get_handler(self, *args, **options):
"""Return the default WSGI handler for the runner."""
return get_internal_wsgi_application()
def handle(self, *args, **options):
if not settings.DEBUG and not settings.ALLOWED_HOSTS:
raise CommandError('You must set settings.ALLOWED_HOSTS if DEBUG is False.')
self.use_ipv6 = options['use_ipv6']
if self.use_ipv6 and not socket.has_ipv6:
raise CommandError('Your Python does not support IPv6.')
self._raw_ipv6 = False
if not options['addrport']:
self.addr = ''
self.port = self.default_port
else:
m = re.match(naiveip_re, options['addrport'])
if m is None:
raise CommandError('"%s" is not a valid port number '
'or address:port pair.' % options['addrport'])
self.addr, _ipv4, _ipv6, _fqdn, self.port = m.groups()
if not self.port.isdigit():
raise CommandError("%r is not a valid port number." % self.port)
if self.addr:
if _ipv6:
self.addr = self.addr[1:-1]
self.use_ipv6 = True
self._raw_ipv6 = True
elif self.use_ipv6 and not _fqdn:
raise CommandError('"%s" is not a valid IPv6 address.' % self.addr)
if not self.addr:
self.addr = self.default_addr_ipv6 if self.use_ipv6 else self.default_addr
self._raw_ipv6 = self.use_ipv6
self.run(**options)
def run(self, **options):
"""Run the server, using the autoreloader if needed."""
use_reloader = options['use_reloader']
if use_reloader:
autoreload.run_with_reloader(self.inner_run, **options)
else:
self.inner_run(None, **options)
def inner_run(self, *args, **options):
# If an exception was silenced in ManagementUtility.execute in order
# to be raised in the child process, raise it now.
autoreload.raise_last_exception()
threading = options['use_threading']
# 'shutdown_message' is a stealth option.
shutdown_message = options.get('shutdown_message', '')
quit_command = 'CTRL-BREAK' if sys.platform == 'win32' else 'CONTROL-C'
if not options['skip_checks']:
self.stdout.write('Performing system checks...\n\n')
self.check(display_num_errors=True)
# Need to check migrations here, so can't use the
# requires_migrations_check attribute.
self.check_migrations()
now = datetime.now().strftime('%B %d, %Y - %X')
self.stdout.write(now)
self.stdout.write((
"Django version %(version)s, using settings %(settings)r\n"
"Starting development server at %(protocol)s://%(addr)s:%(port)s/\n"
"Quit the server with %(quit_command)s."
) % {
"version": self.get_version(),
"settings": settings.SETTINGS_MODULE,
"protocol": self.protocol,
"addr": '[%s]' % self.addr if self._raw_ipv6 else self.addr,
"port": self.port,
"quit_command": quit_command,
})
try:
handler = self.get_handler(*args, **options)
run(self.addr, int(self.port), handler,
ipv6=self.use_ipv6, threading=threading, server_cls=self.server_cls)
except OSError as e:
# Use helpful error messages instead of ugly tracebacks.
ERRORS = {
errno.EACCES: "You don't have permission to access that port.",
errno.EADDRINUSE: "That port is already in use.",
errno.EADDRNOTAVAIL: "That IP address can't be assigned to.",
}
try:
error_text = ERRORS[e.errno]
except KeyError:
error_text = e
self.stderr.write("Error: %s" % error_text)
# Need to use an OS exit because sys.exit doesn't work in a thread
os._exit(1)
except KeyboardInterrupt:
if shutdown_message:
self.stdout.write(shutdown_message)
sys.exit(0)

View File

@@ -0,0 +1,40 @@
import socket
from django.core.mail import mail_admins, mail_managers, send_mail
from django.core.management.base import BaseCommand
from django.utils import timezone
class Command(BaseCommand):
help = "Sends a test email to the email addresses specified as arguments."
missing_args_message = "You must specify some email recipients, or pass the --managers or --admin options."
def add_arguments(self, parser):
parser.add_argument(
'email', nargs='*',
help='One or more email addresses to send a test email to.',
)
parser.add_argument(
'--managers', action='store_true',
help='Send a test email to the addresses specified in settings.MANAGERS.',
)
parser.add_argument(
'--admins', action='store_true',
help='Send a test email to the addresses specified in settings.ADMINS.',
)
def handle(self, *args, **kwargs):
subject = 'Test email from %s on %s' % (socket.gethostname(), timezone.now())
send_mail(
subject=subject,
message="If you\'re reading this, it was successful.",
from_email=None,
recipient_list=kwargs['email'],
)
if kwargs['managers']:
mail_managers(subject, "This email was sent to the site managers.")
if kwargs['admins']:
mail_admins(subject, "This email was sent to the site admins.")

View File

@@ -0,0 +1,115 @@
import os
import select
import sys
import traceback
from django.core.management import BaseCommand, CommandError
from django.utils.datastructures import OrderedSet
class Command(BaseCommand):
help = (
"Runs a Python interactive interpreter. Tries to use IPython or "
"bpython, if one of them is available. Any standard input is executed "
"as code."
)
requires_system_checks = []
shells = ['ipython', 'bpython', 'python']
def add_arguments(self, parser):
parser.add_argument(
'--no-startup', action='store_true',
help='When using plain Python, ignore the PYTHONSTARTUP environment variable and ~/.pythonrc.py script.',
)
parser.add_argument(
'-i', '--interface', choices=self.shells,
help='Specify an interactive interpreter interface. Available options: "ipython", "bpython", and "python"',
)
parser.add_argument(
'-c', '--command',
help='Instead of opening an interactive shell, run a command as Django and exit.',
)
def ipython(self, options):
from IPython import start_ipython
start_ipython(argv=[])
def bpython(self, options):
import bpython
bpython.embed()
def python(self, options):
import code
# Set up a dictionary to serve as the environment for the shell.
imported_objects = {}
# We want to honor both $PYTHONSTARTUP and .pythonrc.py, so follow system
# conventions and get $PYTHONSTARTUP first then .pythonrc.py.
if not options['no_startup']:
for pythonrc in OrderedSet([os.environ.get("PYTHONSTARTUP"), os.path.expanduser('~/.pythonrc.py')]):
if not pythonrc:
continue
if not os.path.isfile(pythonrc):
continue
with open(pythonrc) as handle:
pythonrc_code = handle.read()
# Match the behavior of the cpython shell where an error in
# PYTHONSTARTUP prints an exception and continues.
try:
exec(compile(pythonrc_code, pythonrc, 'exec'), imported_objects)
except Exception:
traceback.print_exc()
# By default, this will set up readline to do tab completion and to read and
# write history to the .python_history file, but this can be overridden by
# $PYTHONSTARTUP or ~/.pythonrc.py.
try:
hook = sys.__interactivehook__
except AttributeError:
# Match the behavior of the cpython shell where a missing
# sys.__interactivehook__ is ignored.
pass
else:
try:
hook()
except Exception:
# Match the behavior of the cpython shell where an error in
# sys.__interactivehook__ prints a warning and the exception
# and continues.
print('Failed calling sys.__interactivehook__')
traceback.print_exc()
# Set up tab completion for objects imported by $PYTHONSTARTUP or
# ~/.pythonrc.py.
try:
import readline
import rlcompleter
readline.set_completer(rlcompleter.Completer(imported_objects).complete)
except ImportError:
pass
# Start the interactive interpreter.
code.interact(local=imported_objects)
def handle(self, **options):
# Execute the command and exit.
if options['command']:
exec(options['command'], globals())
return
# Execute stdin if it has anything to read and exit.
# Not supported on Windows due to select.select() limitations.
if sys.platform != 'win32' and not sys.stdin.isatty() and select.select([sys.stdin], [], [], 0)[0]:
exec(sys.stdin.read(), globals())
return
available_shells = [options['interface']] if options['interface'] else self.shells
for shell in available_shells:
try:
return getattr(self, shell)(options)
except ImportError:
pass
raise CommandError("Couldn't import {} interface.".format(shell))

View File

@@ -0,0 +1,157 @@
import sys
from django.apps import apps
from django.core.management.base import BaseCommand
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
class Command(BaseCommand):
help = "Shows all available migrations for the current project"
def add_arguments(self, parser):
parser.add_argument(
'app_label', nargs='*',
help='App labels of applications to limit the output to.',
)
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help=(
'Nominates a database to show migrations for. Defaults to the '
'"default" database.'
),
)
formats = parser.add_mutually_exclusive_group()
formats.add_argument(
'--list', '-l', action='store_const', dest='format', const='list',
help=(
'Shows a list of all migrations and which are applied. '
'With a verbosity level of 2 or above, the applied datetimes '
'will be included.'
),
)
formats.add_argument(
'--plan', '-p', action='store_const', dest='format', const='plan',
help=(
'Shows all migrations in the order they will be applied. '
'With a verbosity level of 2 or above all direct migration dependencies '
'and reverse dependencies (run_before) will be included.'
)
)
parser.set_defaults(format='list')
def handle(self, *args, **options):
self.verbosity = options['verbosity']
# Get the database we're operating from
db = options['database']
connection = connections[db]
if options['format'] == "plan":
return self.show_plan(connection, options['app_label'])
else:
return self.show_list(connection, options['app_label'])
def _validate_app_names(self, loader, app_names):
has_bad_names = False
for app_name in app_names:
try:
apps.get_app_config(app_name)
except LookupError as err:
self.stderr.write(str(err))
has_bad_names = True
if has_bad_names:
sys.exit(2)
def show_list(self, connection, app_names=None):
"""
Show a list of all migrations on the system, or only those of
some named apps.
"""
# Load migrations from disk/DB
loader = MigrationLoader(connection, ignore_no_migrations=True)
recorder = MigrationRecorder(connection)
recorded_migrations = recorder.applied_migrations()
graph = loader.graph
# If we were passed a list of apps, validate it
if app_names:
self._validate_app_names(loader, app_names)
# Otherwise, show all apps in alphabetic order
else:
app_names = sorted(loader.migrated_apps)
# For each app, print its migrations in order from oldest (roots) to
# newest (leaves).
for app_name in app_names:
self.stdout.write(app_name, self.style.MIGRATE_LABEL)
shown = set()
for node in graph.leaf_nodes(app_name):
for plan_node in graph.forwards_plan(node):
if plan_node not in shown and plan_node[0] == app_name:
# Give it a nice title if it's a squashed one
title = plan_node[1]
if graph.nodes[plan_node].replaces:
title += " (%s squashed migrations)" % len(graph.nodes[plan_node].replaces)
applied_migration = loader.applied_migrations.get(plan_node)
# Mark it as applied/unapplied
if applied_migration:
if plan_node in recorded_migrations:
output = ' [X] %s' % title
else:
title += " Run 'manage.py migrate' to finish recording."
output = ' [-] %s' % title
if self.verbosity >= 2 and hasattr(applied_migration, 'applied'):
output += ' (applied at %s)' % applied_migration.applied.strftime('%Y-%m-%d %H:%M:%S')
self.stdout.write(output)
else:
self.stdout.write(" [ ] %s" % title)
shown.add(plan_node)
# If we didn't print anything, then a small message
if not shown:
self.stdout.write(" (no migrations)", self.style.ERROR)
def show_plan(self, connection, app_names=None):
"""
Show all known migrations (or only those of the specified app_names)
in the order they will be applied.
"""
# Load migrations from disk/DB
loader = MigrationLoader(connection)
graph = loader.graph
if app_names:
self._validate_app_names(loader, app_names)
targets = [key for key in graph.leaf_nodes() if key[0] in app_names]
else:
targets = graph.leaf_nodes()
plan = []
seen = set()
# Generate the plan
for target in targets:
for migration in graph.forwards_plan(target):
if migration not in seen:
node = graph.node_map[migration]
plan.append(node)
seen.add(migration)
# Output
def print_deps(node):
out = []
for parent in sorted(node.parents):
out.append("%s.%s" % parent.key)
if out:
return " ... (%s)" % ", ".join(out)
return ""
for node in plan:
deps = ""
if self.verbosity >= 2:
deps = print_deps(node)
if node.key in loader.applied_migrations:
self.stdout.write("[X] %s.%s%s" % (node.key[0], node.key[1], deps))
else:
self.stdout.write("[ ] %s.%s%s" % (node.key[0], node.key[1], deps))
if not plan:
self.stdout.write('(no migrations)', self.style.ERROR)

View File

@@ -0,0 +1,25 @@
from django.core.management.base import BaseCommand
from django.core.management.sql import sql_flush
from django.db import DEFAULT_DB_ALIAS, connections
class Command(BaseCommand):
help = (
"Returns a list of the SQL statements required to return all tables in "
"the database to the state they were in just after they were installed."
)
output_transaction = True
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a database to print the SQL for. Defaults to the "default" database.',
)
def handle(self, **options):
sql_statements = sql_flush(self.style, connections[options['database']])
if not sql_statements and options['verbosity'] >= 1:
self.stderr.write('No tables found.')
return '\n'.join(sql_statements)

View File

@@ -0,0 +1,68 @@
from django.apps import apps
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.loader import AmbiguityError, MigrationLoader
class Command(BaseCommand):
help = "Prints the SQL statements for the named migration."
output_transaction = True
def add_arguments(self, parser):
parser.add_argument('app_label', help='App label of the application containing the migration.')
parser.add_argument('migration_name', help='Migration name to print the SQL for.')
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a database to create SQL for. Defaults to the "default" database.',
)
parser.add_argument(
'--backwards', action='store_true',
help='Creates SQL to unapply the migration, rather than to apply it',
)
def execute(self, *args, **options):
# sqlmigrate doesn't support coloring its output but we need to force
# no_color=True so that the BEGIN/COMMIT statements added by
# output_transaction don't get colored either.
options['no_color'] = True
return super().execute(*args, **options)
def handle(self, *args, **options):
# Get the database we're operating from
connection = connections[options['database']]
# Load up a loader to get all the migration data, but don't replace
# migrations.
loader = MigrationLoader(connection, replace_migrations=False)
# Resolve command-line arguments into a migration
app_label, migration_name = options['app_label'], options['migration_name']
# Validate app_label
try:
apps.get_app_config(app_label)
except LookupError as err:
raise CommandError(str(err))
if app_label not in loader.migrated_apps:
raise CommandError("App '%s' does not have migrations" % app_label)
try:
migration = loader.get_migration_by_prefix(app_label, migration_name)
except AmbiguityError:
raise CommandError("More than one migration matches '%s' in app '%s'. Please be more specific." % (
migration_name, app_label))
except KeyError:
raise CommandError("Cannot find a migration matching '%s' from app '%s'. Is it in INSTALLED_APPS?" % (
migration_name, app_label))
target = (app_label, migration.name)
# Show begin/end around output for atomic migrations, if the database
# supports transactional DDL.
self.output_transaction = migration.atomic and connection.features.can_rollback_ddl
# Make a plan that represents just the requested migrations and show SQL
# for it
plan = [(loader.graph.nodes[target], options['backwards'])]
sql_statements = loader.collect_sql(plan)
if not sql_statements and options['verbosity'] >= 1:
self.stderr.write('No operations found.')
return '\n'.join(sql_statements)

View File

@@ -0,0 +1,25 @@
from django.core.management.base import AppCommand
from django.db import DEFAULT_DB_ALIAS, connections
class Command(AppCommand):
help = 'Prints the SQL statements for resetting sequences for the given app name(s).'
output_transaction = True
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
'--database', default=DEFAULT_DB_ALIAS,
help='Nominates a database to print the SQL for. Defaults to the "default" database.',
)
def handle_app_config(self, app_config, **options):
if app_config.models_module is None:
return
connection = connections[options['database']]
models = app_config.get_models(include_auto_created=True)
statements = connection.ops.sequence_reset_sql(self.style, models)
if not statements and options['verbosity'] >= 1:
self.stderr.write('No sequences found.')
return '\n'.join(statements)

View File

@@ -0,0 +1,218 @@
from django.apps import apps
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS, connections, migrations
from django.db.migrations.loader import AmbiguityError, MigrationLoader
from django.db.migrations.migration import SwappableTuple
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.writer import MigrationWriter
from django.utils.version import get_docs_version
class Command(BaseCommand):
help = "Squashes an existing set of migrations (from first until specified) into a single new one."
def add_arguments(self, parser):
parser.add_argument(
'app_label',
help='App label of the application to squash migrations for.',
)
parser.add_argument(
'start_migration_name', nargs='?',
help='Migrations will be squashed starting from and including this migration.',
)
parser.add_argument(
'migration_name',
help='Migrations will be squashed until and including this migration.',
)
parser.add_argument(
'--no-optimize', action='store_true',
help='Do not try to optimize the squashed operations.',
)
parser.add_argument(
'--noinput', '--no-input', action='store_false', dest='interactive',
help='Tells Django to NOT prompt the user for input of any kind.',
)
parser.add_argument(
'--squashed-name',
help='Sets the name of the new squashed migration.',
)
parser.add_argument(
'--no-header', action='store_false', dest='include_header',
help='Do not add a header comment to the new squashed migration.',
)
def handle(self, **options):
self.verbosity = options['verbosity']
self.interactive = options['interactive']
app_label = options['app_label']
start_migration_name = options['start_migration_name']
migration_name = options['migration_name']
no_optimize = options['no_optimize']
squashed_name = options['squashed_name']
include_header = options['include_header']
# Validate app_label.
try:
apps.get_app_config(app_label)
except LookupError as err:
raise CommandError(str(err))
# Load the current graph state, check the app and migration they asked for exists
loader = MigrationLoader(connections[DEFAULT_DB_ALIAS])
if app_label not in loader.migrated_apps:
raise CommandError(
"App '%s' does not have migrations (so squashmigrations on "
"it makes no sense)" % app_label
)
migration = self.find_migration(loader, app_label, migration_name)
# Work out the list of predecessor migrations
migrations_to_squash = [
loader.get_migration(al, mn)
for al, mn in loader.graph.forwards_plan((migration.app_label, migration.name))
if al == migration.app_label
]
if start_migration_name:
start_migration = self.find_migration(loader, app_label, start_migration_name)
start = loader.get_migration(start_migration.app_label, start_migration.name)
try:
start_index = migrations_to_squash.index(start)
migrations_to_squash = migrations_to_squash[start_index:]
except ValueError:
raise CommandError(
"The migration '%s' cannot be found. Maybe it comes after "
"the migration '%s'?\n"
"Have a look at:\n"
" python manage.py showmigrations %s\n"
"to debug this issue." % (start_migration, migration, app_label)
)
# Tell them what we're doing and optionally ask if we should proceed
if self.verbosity > 0 or self.interactive:
self.stdout.write(self.style.MIGRATE_HEADING("Will squash the following migrations:"))
for migration in migrations_to_squash:
self.stdout.write(" - %s" % migration.name)
if self.interactive:
answer = None
while not answer or answer not in "yn":
answer = input("Do you wish to proceed? [yN] ")
if not answer:
answer = "n"
break
else:
answer = answer[0].lower()
if answer != "y":
return
# Load the operations from all those migrations and concat together,
# along with collecting external dependencies and detecting
# double-squashing
operations = []
dependencies = set()
# We need to take all dependencies from the first migration in the list
# as it may be 0002 depending on 0001
first_migration = True
for smigration in migrations_to_squash:
if smigration.replaces:
raise CommandError(
"You cannot squash squashed migrations! Please transition "
"it to a normal migration first: "
"https://docs.djangoproject.com/en/%s/topics/migrations/#squashing-migrations" % get_docs_version()
)
operations.extend(smigration.operations)
for dependency in smigration.dependencies:
if isinstance(dependency, SwappableTuple):
if settings.AUTH_USER_MODEL == dependency.setting:
dependencies.add(("__setting__", "AUTH_USER_MODEL"))
else:
dependencies.add(dependency)
elif dependency[0] != smigration.app_label or first_migration:
dependencies.add(dependency)
first_migration = False
if no_optimize:
if self.verbosity > 0:
self.stdout.write(self.style.MIGRATE_HEADING("(Skipping optimization.)"))
new_operations = operations
else:
if self.verbosity > 0:
self.stdout.write(self.style.MIGRATE_HEADING("Optimizing..."))
optimizer = MigrationOptimizer()
new_operations = optimizer.optimize(operations, migration.app_label)
if self.verbosity > 0:
if len(new_operations) == len(operations):
self.stdout.write(" No optimizations possible.")
else:
self.stdout.write(
" Optimized from %s operations to %s operations." %
(len(operations), len(new_operations))
)
# Work out the value of replaces (any squashed ones we're re-squashing)
# need to feed their replaces into ours
replaces = []
for migration in migrations_to_squash:
if migration.replaces:
replaces.extend(migration.replaces)
else:
replaces.append((migration.app_label, migration.name))
# Make a new migration with those operations
subclass = type("Migration", (migrations.Migration,), {
"dependencies": dependencies,
"operations": new_operations,
"replaces": replaces,
})
if start_migration_name:
if squashed_name:
# Use the name from --squashed-name.
prefix, _ = start_migration.name.split('_', 1)
name = '%s_%s' % (prefix, squashed_name)
else:
# Generate a name.
name = '%s_squashed_%s' % (start_migration.name, migration.name)
new_migration = subclass(name, app_label)
else:
name = '0001_%s' % (squashed_name or 'squashed_%s' % migration.name)
new_migration = subclass(name, app_label)
new_migration.initial = True
# Write out the new migration file
writer = MigrationWriter(new_migration, include_header)
with open(writer.path, "w", encoding='utf-8') as fh:
fh.write(writer.as_string())
if self.verbosity > 0:
self.stdout.write(
self.style.MIGRATE_HEADING('Created new squashed migration %s' % writer.path) + '\n'
' You should commit this migration but leave the old ones in place;\n'
' the new migration will be used for new installs. Once you are sure\n'
' all instances of the codebase have applied the migrations you squashed,\n'
' you can delete them.'
)
if writer.needs_manual_porting:
self.stdout.write(
self.style.MIGRATE_HEADING('Manual porting required') + '\n'
' Your migrations contained functions that must be manually copied over,\n'
' as we could not safely copy their implementation.\n'
' See the comment at the top of the squashed migration for details.'
)
def find_migration(self, loader, app_label, name):
try:
return loader.get_migration_by_prefix(app_label, name)
except AmbiguityError:
raise CommandError(
"More than one migration matches '%s' in app '%s'. Please be "
"more specific." % (name, app_label)
)
except KeyError:
raise CommandError(
"Cannot find a migration matching '%s' from app '%s'." %
(name, app_label)
)

View File

@@ -0,0 +1,14 @@
from django.core.management.templates import TemplateCommand
class Command(TemplateCommand):
help = (
"Creates a Django app directory structure for the given app name in "
"the current directory or optionally in the given directory."
)
missing_args_message = "You must provide an application name."
def handle(self, **options):
app_name = options.pop('name')
target = options.pop('directory')
super().handle('app', app_name, target, **options)

View File

@@ -0,0 +1,21 @@
from django.core.checks.security.base import SECRET_KEY_INSECURE_PREFIX
from django.core.management.templates import TemplateCommand
from ..utils import get_random_secret_key
class Command(TemplateCommand):
help = (
"Creates a Django project directory structure for the given project "
"name in the current directory or optionally in the given directory."
)
missing_args_message = "You must provide a project name."
def handle(self, **options):
project_name = options.pop('name')
target = options.pop('directory')
# Create a random SECRET_KEY to put it in the main settings.
options['secret_key'] = SECRET_KEY_INSECURE_PREFIX + get_random_secret_key()
super().handle('project', project_name, target, **options)

View File

@@ -0,0 +1,62 @@
import sys
from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.management.utils import get_command_line_option
from django.test.runner import get_max_test_processes
from django.test.utils import NullTimeKeeper, TimeKeeper, get_runner
class Command(BaseCommand):
help = 'Discover and run tests in the specified modules or the current directory.'
# DiscoverRunner runs the checks after databases are set up.
requires_system_checks = []
test_runner = None
def run_from_argv(self, argv):
"""
Pre-parse the command line to extract the value of the --testrunner
option. This allows a test runner to define additional command line
arguments.
"""
self.test_runner = get_command_line_option(argv, '--testrunner')
super().run_from_argv(argv)
def add_arguments(self, parser):
parser.add_argument(
'args', metavar='test_label', nargs='*',
help='Module paths to test; can be modulename, modulename.TestCase or modulename.TestCase.test_method'
)
parser.add_argument(
'--noinput', '--no-input', action='store_false', dest='interactive',
help='Tells Django to NOT prompt the user for input of any kind.',
)
parser.add_argument(
'--failfast', action='store_true',
help='Tells Django to stop running the test suite after first failed test.',
)
parser.add_argument(
'--testrunner',
help='Tells Django to use specified test runner class instead of '
'the one specified by the TEST_RUNNER setting.',
)
test_runner_class = get_runner(settings, self.test_runner)
if hasattr(test_runner_class, 'add_arguments'):
test_runner_class.add_arguments(parser)
def handle(self, *test_labels, **options):
TestRunner = get_runner(settings, options['testrunner'])
time_keeper = TimeKeeper() if options.get('timing', False) else NullTimeKeeper()
parallel = options.get('parallel')
if parallel == 'auto':
options['parallel'] = get_max_test_processes()
test_runner = TestRunner(**options)
with time_keeper.timed('Total run'):
failures = test_runner.run_tests(test_labels)
time_keeper.print_results()
if failures:
sys.exit(1)

View File

@@ -0,0 +1,54 @@
from django.core.management import call_command
from django.core.management.base import BaseCommand
from django.db import connection
class Command(BaseCommand):
help = 'Runs a development server with data from the given fixture(s).'
requires_system_checks = []
def add_arguments(self, parser):
parser.add_argument(
'args', metavar='fixture', nargs='*',
help='Path(s) to fixtures to load before running the server.',
)
parser.add_argument(
'--noinput', '--no-input', action='store_false', dest='interactive',
help='Tells Django to NOT prompt the user for input of any kind.',
)
parser.add_argument(
'--addrport', default='',
help='Port number or ipaddr:port to run the server on.',
)
parser.add_argument(
'--ipv6', '-6', action='store_true', dest='use_ipv6',
help='Tells Django to use an IPv6 address.',
)
def handle(self, *fixture_labels, **options):
verbosity = options['verbosity']
interactive = options['interactive']
# Create a test database.
db_name = connection.creation.create_test_db(verbosity=verbosity, autoclobber=not interactive, serialize=False)
# Import the fixture data into the test database.
call_command('loaddata', *fixture_labels, **{'verbosity': verbosity})
# Run the development server. Turn off auto-reloading because it causes
# a strange error -- it causes this handle() method to be called
# multiple times.
shutdown_message = (
'\nServer stopped.\nNote that the test database, %r, has not been '
'deleted. You can explore it on your own.' % db_name
)
use_threading = connection.features.test_db_allows_multiple_connections
call_command(
'runserver',
addrport=options['addrport'],
shutdown_message=shutdown_message,
use_reloader=False,
use_ipv6=options['use_ipv6'],
use_threading=use_threading
)

View File

@@ -0,0 +1,53 @@
import sys
from django.apps import apps
from django.db import models
def sql_flush(style, connection, reset_sequences=True, allow_cascade=False):
"""
Return a list of the SQL statements used to flush the database.
"""
tables = connection.introspection.django_table_names(only_existing=True, include_views=False)
return connection.ops.sql_flush(
style,
tables,
reset_sequences=reset_sequences,
allow_cascade=allow_cascade,
)
def emit_pre_migrate_signal(verbosity, interactive, db, **kwargs):
# Emit the pre_migrate signal for every application.
for app_config in apps.get_app_configs():
if app_config.models_module is None:
continue
if verbosity >= 2:
stdout = kwargs.get('stdout', sys.stdout)
stdout.write('Running pre-migrate handlers for application %s' % app_config.label)
models.signals.pre_migrate.send(
sender=app_config,
app_config=app_config,
verbosity=verbosity,
interactive=interactive,
using=db,
**kwargs
)
def emit_post_migrate_signal(verbosity, interactive, db, **kwargs):
# Emit the post_migrate signal for every application.
for app_config in apps.get_app_configs():
if app_config.models_module is None:
continue
if verbosity >= 2:
stdout = kwargs.get('stdout', sys.stdout)
stdout.write('Running post-migrate handlers for application %s' % app_config.label)
models.signals.post_migrate.send(
sender=app_config,
app_config=app_config,
verbosity=verbosity,
interactive=interactive,
using=db,
**kwargs
)

View File

@@ -0,0 +1,356 @@
import argparse
import cgi
import mimetypes
import os
import posixpath
import shutil
import stat
import tempfile
from importlib import import_module
from urllib.request import urlretrieve
import django
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.core.management.utils import handle_extensions
from django.template import Context, Engine
from django.utils import archive
from django.utils.version import get_docs_version
class TemplateCommand(BaseCommand):
"""
Copy either a Django application layout template or a Django project
layout template into the specified directory.
:param style: A color style object (see django.core.management.color).
:param app_or_project: The string 'app' or 'project'.
:param name: The name of the application or project.
:param directory: The directory to which the template should be copied.
:param options: The additional variables passed to project or app templates
"""
requires_system_checks = []
# The supported URL schemes
url_schemes = ['http', 'https', 'ftp']
# Rewrite the following suffixes when determining the target filename.
rewrite_template_suffixes = (
# Allow shipping invalid .py files without byte-compilation.
('.py-tpl', '.py'),
)
def add_arguments(self, parser):
parser.add_argument('name', help='Name of the application or project.')
parser.add_argument('directory', nargs='?', help='Optional destination directory')
parser.add_argument('--template', help='The path or URL to load the template from.')
parser.add_argument(
'--extension', '-e', dest='extensions',
action='append', default=['py'],
help='The file extension(s) to render (default: "py"). '
'Separate multiple extensions with commas, or use '
'-e multiple times.'
)
parser.add_argument(
'--name', '-n', dest='files',
action='append', default=[],
help='The file name(s) to render. Separate multiple file names '
'with commas, or use -n multiple times.'
)
parser.add_argument(
'--exclude', '-x',
action='append', default=argparse.SUPPRESS, nargs='?', const='',
help=(
'The directory name(s) to exclude, in addition to .git and '
'__pycache__. Can be used multiple times.'
),
)
def handle(self, app_or_project, name, target=None, **options):
self.app_or_project = app_or_project
self.a_or_an = 'an' if app_or_project == 'app' else 'a'
self.paths_to_remove = []
self.verbosity = options['verbosity']
self.validate_name(name)
# if some directory is given, make sure it's nicely expanded
if target is None:
top_dir = os.path.join(os.getcwd(), name)
try:
os.makedirs(top_dir)
except FileExistsError:
raise CommandError("'%s' already exists" % top_dir)
except OSError as e:
raise CommandError(e)
else:
top_dir = os.path.abspath(os.path.expanduser(target))
if app_or_project == 'app':
self.validate_name(os.path.basename(top_dir), 'directory')
if not os.path.exists(top_dir):
raise CommandError("Destination directory '%s' does not "
"exist, please create it first." % top_dir)
extensions = tuple(handle_extensions(options['extensions']))
extra_files = []
excluded_directories = ['.git', '__pycache__']
for file in options['files']:
extra_files.extend(map(lambda x: x.strip(), file.split(',')))
if exclude := options.get('exclude'):
for directory in exclude:
excluded_directories.append(directory.strip())
if self.verbosity >= 2:
self.stdout.write(
'Rendering %s template files with extensions: %s'
% (app_or_project, ', '.join(extensions))
)
self.stdout.write(
'Rendering %s template files with filenames: %s'
% (app_or_project, ', '.join(extra_files))
)
base_name = '%s_name' % app_or_project
base_subdir = '%s_template' % app_or_project
base_directory = '%s_directory' % app_or_project
camel_case_name = 'camel_case_%s_name' % app_or_project
camel_case_value = ''.join(x for x in name.title() if x != '_')
context = Context({
**options,
base_name: name,
base_directory: top_dir,
camel_case_name: camel_case_value,
'docs_version': get_docs_version(),
'django_version': django.__version__,
}, autoescape=False)
# Setup a stub settings environment for template rendering
if not settings.configured:
settings.configure()
django.setup()
template_dir = self.handle_template(options['template'],
base_subdir)
prefix_length = len(template_dir) + 1
for root, dirs, files in os.walk(template_dir):
path_rest = root[prefix_length:]
relative_dir = path_rest.replace(base_name, name)
if relative_dir:
target_dir = os.path.join(top_dir, relative_dir)
os.makedirs(target_dir, exist_ok=True)
for dirname in dirs[:]:
if 'exclude' not in options:
if dirname.startswith('.') or dirname == '__pycache__':
dirs.remove(dirname)
elif dirname in excluded_directories:
dirs.remove(dirname)
for filename in files:
if filename.endswith(('.pyo', '.pyc', '.py.class')):
# Ignore some files as they cause various breakages.
continue
old_path = os.path.join(root, filename)
new_path = os.path.join(
top_dir, relative_dir, filename.replace(base_name, name)
)
for old_suffix, new_suffix in self.rewrite_template_suffixes:
if new_path.endswith(old_suffix):
new_path = new_path[:-len(old_suffix)] + new_suffix
break # Only rewrite once
if os.path.exists(new_path):
raise CommandError(
"%s already exists. Overlaying %s %s into an existing "
"directory won't replace conflicting files." % (
new_path, self.a_or_an, app_or_project,
)
)
# Only render the Python files, as we don't want to
# accidentally render Django templates files
if new_path.endswith(extensions) or filename in extra_files:
with open(old_path, encoding='utf-8') as template_file:
content = template_file.read()
template = Engine().from_string(content)
content = template.render(context)
with open(new_path, 'w', encoding='utf-8') as new_file:
new_file.write(content)
else:
shutil.copyfile(old_path, new_path)
if self.verbosity >= 2:
self.stdout.write('Creating %s' % new_path)
try:
shutil.copymode(old_path, new_path)
self.make_writeable(new_path)
except OSError:
self.stderr.write(
"Notice: Couldn't set permission bits on %s. You're "
"probably using an uncommon filesystem setup. No "
"problem." % new_path, self.style.NOTICE)
if self.paths_to_remove:
if self.verbosity >= 2:
self.stdout.write('Cleaning up temporary files.')
for path_to_remove in self.paths_to_remove:
if os.path.isfile(path_to_remove):
os.remove(path_to_remove)
else:
shutil.rmtree(path_to_remove)
def handle_template(self, template, subdir):
"""
Determine where the app or project templates are.
Use django.__path__[0] as the default because the Django install
directory isn't known.
"""
if template is None:
return os.path.join(django.__path__[0], 'conf', subdir)
else:
if template.startswith('file://'):
template = template[7:]
expanded_template = os.path.expanduser(template)
expanded_template = os.path.normpath(expanded_template)
if os.path.isdir(expanded_template):
return expanded_template
if self.is_url(template):
# downloads the file and returns the path
absolute_path = self.download(template)
else:
absolute_path = os.path.abspath(expanded_template)
if os.path.exists(absolute_path):
return self.extract(absolute_path)
raise CommandError("couldn't handle %s template %s." %
(self.app_or_project, template))
def validate_name(self, name, name_or_dir='name'):
if name is None:
raise CommandError('you must provide {an} {app} name'.format(
an=self.a_or_an,
app=self.app_or_project,
))
# Check it's a valid directory name.
if not name.isidentifier():
raise CommandError(
"'{name}' is not a valid {app} {type}. Please make sure the "
"{type} is a valid identifier.".format(
name=name,
app=self.app_or_project,
type=name_or_dir,
)
)
# Check it cannot be imported.
try:
import_module(name)
except ImportError:
pass
else:
raise CommandError(
"'{name}' conflicts with the name of an existing Python "
"module and cannot be used as {an} {app} {type}. Please try "
"another {type}.".format(
name=name,
an=self.a_or_an,
app=self.app_or_project,
type=name_or_dir,
)
)
def download(self, url):
"""
Download the given URL and return the file name.
"""
def cleanup_url(url):
tmp = url.rstrip('/')
filename = tmp.split('/')[-1]
if url.endswith('/'):
display_url = tmp + '/'
else:
display_url = url
return filename, display_url
prefix = 'django_%s_template_' % self.app_or_project
tempdir = tempfile.mkdtemp(prefix=prefix, suffix='_download')
self.paths_to_remove.append(tempdir)
filename, display_url = cleanup_url(url)
if self.verbosity >= 2:
self.stdout.write('Downloading %s' % display_url)
try:
the_path, info = urlretrieve(url, os.path.join(tempdir, filename))
except OSError as e:
raise CommandError("couldn't download URL %s to %s: %s" %
(url, filename, e))
used_name = the_path.split('/')[-1]
# Trying to get better name from response headers
content_disposition = info.get('content-disposition')
if content_disposition:
_, params = cgi.parse_header(content_disposition)
guessed_filename = params.get('filename') or used_name
else:
guessed_filename = used_name
# Falling back to content type guessing
ext = self.splitext(guessed_filename)[1]
content_type = info.get('content-type')
if not ext and content_type:
ext = mimetypes.guess_extension(content_type)
if ext:
guessed_filename += ext
# Move the temporary file to a filename that has better
# chances of being recognized by the archive utils
if used_name != guessed_filename:
guessed_path = os.path.join(tempdir, guessed_filename)
shutil.move(the_path, guessed_path)
return guessed_path
# Giving up
return the_path
def splitext(self, the_path):
"""
Like os.path.splitext, but takes off .tar, too
"""
base, ext = posixpath.splitext(the_path)
if base.lower().endswith('.tar'):
ext = base[-4:] + ext
base = base[:-4]
return base, ext
def extract(self, filename):
"""
Extract the given file to a temporary directory and return
the path of the directory with the extracted content.
"""
prefix = 'django_%s_template_' % self.app_or_project
tempdir = tempfile.mkdtemp(prefix=prefix, suffix='_extract')
self.paths_to_remove.append(tempdir)
if self.verbosity >= 2:
self.stdout.write('Extracting %s' % filename)
try:
archive.extract(filename, tempdir)
return tempdir
except (archive.ArchiveException, OSError) as e:
raise CommandError("couldn't extract file %s to %s: %s" %
(filename, tempdir, e))
def is_url(self, template):
"""Return True if the name looks like a URL."""
if ':' not in template:
return False
scheme = template.split(':', 1)[0].lower()
return scheme in self.url_schemes
def make_writeable(self, filename):
"""
Make sure that the file is writeable.
Useful if our source is read-only.
"""
if not os.access(filename, os.W_OK):
st = os.stat(filename)
new_permissions = stat.S_IMODE(st.st_mode) | stat.S_IWUSR
os.chmod(filename, new_permissions)

View File

@@ -0,0 +1,153 @@
import fnmatch
import os
from pathlib import Path
from subprocess import PIPE, run
from django.apps import apps as installed_apps
from django.utils.crypto import get_random_string
from django.utils.encoding import DEFAULT_LOCALE_ENCODING
from .base import CommandError, CommandParser
def popen_wrapper(args, stdout_encoding='utf-8'):
"""
Friendly wrapper around Popen.
Return stdout output, stderr output, and OS status code.
"""
try:
p = run(args, stdout=PIPE, stderr=PIPE, close_fds=os.name != 'nt')
except OSError as err:
raise CommandError('Error executing %s' % args[0]) from err
return (
p.stdout.decode(stdout_encoding),
p.stderr.decode(DEFAULT_LOCALE_ENCODING, errors='replace'),
p.returncode
)
def handle_extensions(extensions):
"""
Organize multiple extensions that are separated with commas or passed by
using --extension/-e multiple times.
For example: running 'django-admin makemessages -e js,txt -e xhtml -a'
would result in an extension list: ['.js', '.txt', '.xhtml']
>>> handle_extensions(['.html', 'html,js,py,py,py,.py', 'py,.py'])
{'.html', '.js', '.py'}
>>> handle_extensions(['.html, txt,.tpl'])
{'.html', '.tpl', '.txt'}
"""
ext_list = []
for ext in extensions:
ext_list.extend(ext.replace(' ', '').split(','))
for i, ext in enumerate(ext_list):
if not ext.startswith('.'):
ext_list[i] = '.%s' % ext_list[i]
return set(ext_list)
def find_command(cmd, path=None, pathext=None):
if path is None:
path = os.environ.get('PATH', '').split(os.pathsep)
if isinstance(path, str):
path = [path]
# check if there are funny path extensions for executables, e.g. Windows
if pathext is None:
pathext = os.environ.get('PATHEXT', '.COM;.EXE;.BAT;.CMD').split(os.pathsep)
# don't use extensions if the command ends with one of them
for ext in pathext:
if cmd.endswith(ext):
pathext = ['']
break
# check if we find the command on PATH
for p in path:
f = os.path.join(p, cmd)
if os.path.isfile(f):
return f
for ext in pathext:
fext = f + ext
if os.path.isfile(fext):
return fext
return None
def get_random_secret_key():
"""
Return a 50 character random string usable as a SECRET_KEY setting value.
"""
chars = 'abcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*(-_=+)'
return get_random_string(50, chars)
def parse_apps_and_model_labels(labels):
"""
Parse a list of "app_label.ModelName" or "app_label" strings into actual
objects and return a two-element tuple:
(set of model classes, set of app_configs).
Raise a CommandError if some specified models or apps don't exist.
"""
apps = set()
models = set()
for label in labels:
if '.' in label:
try:
model = installed_apps.get_model(label)
except LookupError:
raise CommandError('Unknown model: %s' % label)
models.add(model)
else:
try:
app_config = installed_apps.get_app_config(label)
except LookupError as e:
raise CommandError(str(e))
apps.add(app_config)
return models, apps
def get_command_line_option(argv, option):
"""
Return the value of a command line option (which should include leading
dashes, e.g. '--testrunner') from an argument list. Return None if the
option wasn't passed or if the argument list couldn't be parsed.
"""
parser = CommandParser(add_help=False, allow_abbrev=False)
parser.add_argument(option, dest='value')
try:
options, _ = parser.parse_known_args(argv[2:])
except CommandError:
return None
else:
return options.value
def normalize_path_patterns(patterns):
"""Normalize an iterable of glob style patterns based on OS."""
patterns = [os.path.normcase(p) for p in patterns]
dir_suffixes = {'%s*' % path_sep for path_sep in {'/', os.sep}}
norm_patterns = []
for pattern in patterns:
for dir_suffix in dir_suffixes:
if pattern.endswith(dir_suffix):
norm_patterns.append(pattern[:-len(dir_suffix)])
break
else:
norm_patterns.append(pattern)
return norm_patterns
def is_ignored_path(path, ignore_patterns):
"""
Check if the given path should be ignored or not based on matching
one of the glob style `ignore_patterns`.
"""
path = Path(path)
def ignore(pattern):
return fnmatch.fnmatchcase(path.name, pattern) or fnmatch.fnmatchcase(str(path), pattern)
return any(ignore(pattern) for pattern in normalize_path_patterns(ignore_patterns))

View File

@@ -0,0 +1,224 @@
import collections.abc
import inspect
import warnings
from math import ceil
from django.utils.functional import cached_property
from django.utils.inspect import method_has_no_args
from django.utils.translation import gettext_lazy as _
class UnorderedObjectListWarning(RuntimeWarning):
pass
class InvalidPage(Exception):
pass
class PageNotAnInteger(InvalidPage):
pass
class EmptyPage(InvalidPage):
pass
class Paginator:
# Translators: String used to replace omitted page numbers in elided page
# range generated by paginators, e.g. [1, 2, '…', 5, 6, 7, '…', 9, 10].
ELLIPSIS = _('')
def __init__(self, object_list, per_page, orphans=0,
allow_empty_first_page=True):
self.object_list = object_list
self._check_object_list_is_ordered()
self.per_page = int(per_page)
self.orphans = int(orphans)
self.allow_empty_first_page = allow_empty_first_page
def __iter__(self):
for page_number in self.page_range:
yield self.page(page_number)
def validate_number(self, number):
"""Validate the given 1-based page number."""
try:
if isinstance(number, float) and not number.is_integer():
raise ValueError
number = int(number)
except (TypeError, ValueError):
raise PageNotAnInteger(_('That page number is not an integer'))
if number < 1:
raise EmptyPage(_('That page number is less than 1'))
if number > self.num_pages:
if number == 1 and self.allow_empty_first_page:
pass
else:
raise EmptyPage(_('That page contains no results'))
return number
def get_page(self, number):
"""
Return a valid page, even if the page argument isn't a number or isn't
in range.
"""
try:
number = self.validate_number(number)
except PageNotAnInteger:
number = 1
except EmptyPage:
number = self.num_pages
return self.page(number)
def page(self, number):
"""Return a Page object for the given 1-based page number."""
number = self.validate_number(number)
bottom = (number - 1) * self.per_page
top = bottom + self.per_page
if top + self.orphans >= self.count:
top = self.count
return self._get_page(self.object_list[bottom:top], number, self)
def _get_page(self, *args, **kwargs):
"""
Return an instance of a single page.
This hook can be used by subclasses to use an alternative to the
standard :cls:`Page` object.
"""
return Page(*args, **kwargs)
@cached_property
def count(self):
"""Return the total number of objects, across all pages."""
c = getattr(self.object_list, 'count', None)
if callable(c) and not inspect.isbuiltin(c) and method_has_no_args(c):
return c()
return len(self.object_list)
@cached_property
def num_pages(self):
"""Return the total number of pages."""
if self.count == 0 and not self.allow_empty_first_page:
return 0
hits = max(1, self.count - self.orphans)
return ceil(hits / self.per_page)
@property
def page_range(self):
"""
Return a 1-based range of pages for iterating through within
a template for loop.
"""
return range(1, self.num_pages + 1)
def _check_object_list_is_ordered(self):
"""
Warn if self.object_list is unordered (typically a QuerySet).
"""
ordered = getattr(self.object_list, 'ordered', None)
if ordered is not None and not ordered:
obj_list_repr = (
'{} {}'.format(self.object_list.model, self.object_list.__class__.__name__)
if hasattr(self.object_list, 'model')
else '{!r}'.format(self.object_list)
)
warnings.warn(
'Pagination may yield inconsistent results with an unordered '
'object_list: {}.'.format(obj_list_repr),
UnorderedObjectListWarning,
stacklevel=3
)
def get_elided_page_range(self, number=1, *, on_each_side=3, on_ends=2):
"""
Return a 1-based range of pages with some values elided.
If the page range is larger than a given size, the whole range is not
provided and a compact form is returned instead, e.g. for a paginator
with 50 pages, if page 43 were the current page, the output, with the
default arguments, would be:
1, 2, …, 40, 41, 42, 43, 44, 45, 46, …, 49, 50.
"""
number = self.validate_number(number)
if self.num_pages <= (on_each_side + on_ends) * 2:
yield from self.page_range
return
if number > (1 + on_each_side + on_ends) + 1:
yield from range(1, on_ends + 1)
yield self.ELLIPSIS
yield from range(number - on_each_side, number + 1)
else:
yield from range(1, number + 1)
if number < (self.num_pages - on_each_side - on_ends) - 1:
yield from range(number + 1, number + on_each_side + 1)
yield self.ELLIPSIS
yield from range(self.num_pages - on_ends + 1, self.num_pages + 1)
else:
yield from range(number + 1, self.num_pages + 1)
class Page(collections.abc.Sequence):
def __init__(self, object_list, number, paginator):
self.object_list = object_list
self.number = number
self.paginator = paginator
def __repr__(self):
return '<Page %s of %s>' % (self.number, self.paginator.num_pages)
def __len__(self):
return len(self.object_list)
def __getitem__(self, index):
if not isinstance(index, (int, slice)):
raise TypeError(
'Page indices must be integers or slices, not %s.'
% type(index).__name__
)
# The object_list is converted to a list so that if it was a QuerySet
# it won't be a database hit per __getitem__.
if not isinstance(self.object_list, list):
self.object_list = list(self.object_list)
return self.object_list[index]
def has_next(self):
return self.number < self.paginator.num_pages
def has_previous(self):
return self.number > 1
def has_other_pages(self):
return self.has_previous() or self.has_next()
def next_page_number(self):
return self.paginator.validate_number(self.number + 1)
def previous_page_number(self):
return self.paginator.validate_number(self.number - 1)
def start_index(self):
"""
Return the 1-based index of the first object on this page,
relative to total objects in the paginator.
"""
# Special case, return zero if no items.
if self.paginator.count == 0:
return 0
return (self.paginator.per_page * (self.number - 1)) + 1
def end_index(self):
"""
Return the 1-based index of the last object on this page,
relative to total objects found (hits).
"""
# Special case for the last page because there can be orphans.
if self.number == self.paginator.num_pages:
return self.paginator.count
return self.number * self.paginator.per_page

View File

@@ -0,0 +1,245 @@
"""
Interfaces for serializing Django objects.
Usage::
from django.core import serializers
json = serializers.serialize("json", some_queryset)
objects = list(serializers.deserialize("json", json))
To add your own serializers, use the SERIALIZATION_MODULES setting::
SERIALIZATION_MODULES = {
"csv": "path.to.csv.serializer",
"txt": "path.to.txt.serializer",
}
"""
import importlib
from django.apps import apps
from django.conf import settings
from django.core.serializers.base import SerializerDoesNotExist
# Built-in serializers
BUILTIN_SERIALIZERS = {
"xml": "django.core.serializers.xml_serializer",
"python": "django.core.serializers.python",
"json": "django.core.serializers.json",
"yaml": "django.core.serializers.pyyaml",
"jsonl": "django.core.serializers.jsonl",
}
_serializers = {}
class BadSerializer:
"""
Stub serializer to hold exception raised during registration
This allows the serializer registration to cache serializers and if there
is an error raised in the process of creating a serializer it will be
raised and passed along to the caller when the serializer is used.
"""
internal_use_only = False
def __init__(self, exception):
self.exception = exception
def __call__(self, *args, **kwargs):
raise self.exception
def register_serializer(format, serializer_module, serializers=None):
"""Register a new serializer.
``serializer_module`` should be the fully qualified module name
for the serializer.
If ``serializers`` is provided, the registration will be added
to the provided dictionary.
If ``serializers`` is not provided, the registration will be made
directly into the global register of serializers. Adding serializers
directly is not a thread-safe operation.
"""
if serializers is None and not _serializers:
_load_serializers()
try:
module = importlib.import_module(serializer_module)
except ImportError as exc:
bad_serializer = BadSerializer(exc)
module = type('BadSerializerModule', (), {
'Deserializer': bad_serializer,
'Serializer': bad_serializer,
})
if serializers is None:
_serializers[format] = module
else:
serializers[format] = module
def unregister_serializer(format):
"Unregister a given serializer. This is not a thread-safe operation."
if not _serializers:
_load_serializers()
if format not in _serializers:
raise SerializerDoesNotExist(format)
del _serializers[format]
def get_serializer(format):
if not _serializers:
_load_serializers()
if format not in _serializers:
raise SerializerDoesNotExist(format)
return _serializers[format].Serializer
def get_serializer_formats():
if not _serializers:
_load_serializers()
return list(_serializers)
def get_public_serializer_formats():
if not _serializers:
_load_serializers()
return [k for k, v in _serializers.items() if not v.Serializer.internal_use_only]
def get_deserializer(format):
if not _serializers:
_load_serializers()
if format not in _serializers:
raise SerializerDoesNotExist(format)
return _serializers[format].Deserializer
def serialize(format, queryset, **options):
"""
Serialize a queryset (or any iterator that returns database objects) using
a certain serializer.
"""
s = get_serializer(format)()
s.serialize(queryset, **options)
return s.getvalue()
def deserialize(format, stream_or_string, **options):
"""
Deserialize a stream or a string. Return an iterator that yields ``(obj,
m2m_relation_dict)``, where ``obj`` is an instantiated -- but *unsaved* --
object, and ``m2m_relation_dict`` is a dictionary of ``{m2m_field_name :
list_of_related_objects}``.
"""
d = get_deserializer(format)
return d(stream_or_string, **options)
def _load_serializers():
"""
Register built-in and settings-defined serializers. This is done lazily so
that user code has a chance to (e.g.) set up custom settings without
needing to be careful of import order.
"""
global _serializers
serializers = {}
for format in BUILTIN_SERIALIZERS:
register_serializer(format, BUILTIN_SERIALIZERS[format], serializers)
if hasattr(settings, "SERIALIZATION_MODULES"):
for format in settings.SERIALIZATION_MODULES:
register_serializer(format, settings.SERIALIZATION_MODULES[format], serializers)
_serializers = serializers
def sort_dependencies(app_list, allow_cycles=False):
"""Sort a list of (app_config, models) pairs into a single list of models.
The single list of models is sorted so that any model with a natural key
is serialized before a normal model, and any model with a natural key
dependency has it's dependencies serialized first.
If allow_cycles is True, return the best-effort ordering that will respect
most of dependencies but ignore some of them to break the cycles.
"""
# Process the list of models, and get the list of dependencies
model_dependencies = []
models = set()
for app_config, model_list in app_list:
if model_list is None:
model_list = app_config.get_models()
for model in model_list:
models.add(model)
# Add any explicitly defined dependencies
if hasattr(model, 'natural_key'):
deps = getattr(model.natural_key, 'dependencies', [])
if deps:
deps = [apps.get_model(dep) for dep in deps]
else:
deps = []
# Now add a dependency for any FK relation with a model that
# defines a natural key
for field in model._meta.fields:
if field.remote_field:
rel_model = field.remote_field.model
if hasattr(rel_model, 'natural_key') and rel_model != model:
deps.append(rel_model)
# Also add a dependency for any simple M2M relation with a model
# that defines a natural key. M2M relations with explicit through
# models don't count as dependencies.
for field in model._meta.many_to_many:
if field.remote_field.through._meta.auto_created:
rel_model = field.remote_field.model
if hasattr(rel_model, 'natural_key') and rel_model != model:
deps.append(rel_model)
model_dependencies.append((model, deps))
model_dependencies.reverse()
# Now sort the models to ensure that dependencies are met. This
# is done by repeatedly iterating over the input list of models.
# If all the dependencies of a given model are in the final list,
# that model is promoted to the end of the final list. This process
# continues until the input list is empty, or we do a full iteration
# over the input models without promoting a model to the final list.
# If we do a full iteration without a promotion, that means there are
# circular dependencies in the list.
model_list = []
while model_dependencies:
skipped = []
changed = False
while model_dependencies:
model, deps = model_dependencies.pop()
# If all of the models in the dependency list are either already
# on the final model list, or not on the original serialization list,
# then we've found another model with all it's dependencies satisfied.
if all(d not in models or d in model_list for d in deps):
model_list.append(model)
changed = True
else:
skipped.append((model, deps))
if not changed:
if allow_cycles:
# If cycles are allowed, add the last skipped model and ignore
# its dependencies. This could be improved by some graph
# analysis to ignore as few dependencies as possible.
model, _ = skipped.pop()
model_list.append(model)
else:
raise RuntimeError(
"Can't resolve dependencies for %s in serialized app list."
% ', '.join(
model._meta.label
for model, deps in sorted(skipped, key=lambda obj: obj[0].__name__)
),
)
model_dependencies = skipped
return model_list

View File

@@ -0,0 +1,338 @@
"""
Module for abstract serializer/unserializer base classes.
"""
import pickle
from io import StringIO
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
DEFER_FIELD = object()
class PickleSerializer:
"""
Simple wrapper around pickle to be used in signing.dumps()/loads() and
cache backends.
"""
def __init__(self, protocol=None):
self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol
def dumps(self, obj):
return pickle.dumps(obj, self.protocol)
def loads(self, data):
return pickle.loads(data)
class SerializerDoesNotExist(KeyError):
"""The requested serializer was not found."""
pass
class SerializationError(Exception):
"""Something bad happened during serialization."""
pass
class DeserializationError(Exception):
"""Something bad happened during deserialization."""
@classmethod
def WithData(cls, original_exc, model, fk, field_value):
"""
Factory method for creating a deserialization error which has a more
explanatory message.
"""
return cls("%s: (%s:pk=%s) field_value was '%s'" % (original_exc, model, fk, field_value))
class M2MDeserializationError(Exception):
"""Something bad happened during deserialization of a ManyToManyField."""
def __init__(self, original_exc, pk):
self.original_exc = original_exc
self.pk = pk
class ProgressBar:
progress_width = 75
def __init__(self, output, total_count):
self.output = output
self.total_count = total_count
self.prev_done = 0
def update(self, count):
if not self.output:
return
perc = count * 100 // self.total_count
done = perc * self.progress_width // 100
if self.prev_done >= done:
return
self.prev_done = done
cr = '' if self.total_count == 1 else '\r'
self.output.write(cr + '[' + '.' * done + ' ' * (self.progress_width - done) + ']')
if done == self.progress_width:
self.output.write('\n')
self.output.flush()
class Serializer:
"""
Abstract serializer base class.
"""
# Indicates if the implemented serializer is only available for
# internal Django use.
internal_use_only = False
progress_class = ProgressBar
stream_class = StringIO
def serialize(self, queryset, *, stream=None, fields=None, use_natural_foreign_keys=False,
use_natural_primary_keys=False, progress_output=None, object_count=0, **options):
"""
Serialize a queryset.
"""
self.options = options
self.stream = stream if stream is not None else self.stream_class()
self.selected_fields = fields
self.use_natural_foreign_keys = use_natural_foreign_keys
self.use_natural_primary_keys = use_natural_primary_keys
progress_bar = self.progress_class(progress_output, object_count)
self.start_serialization()
self.first = True
for count, obj in enumerate(queryset, start=1):
self.start_object(obj)
# Use the concrete parent class' _meta instead of the object's _meta
# This is to avoid local_fields problems for proxy models. Refs #17717.
concrete_model = obj._meta.concrete_model
# When using natural primary keys, retrieve the pk field of the
# parent for multi-table inheritance child models. That field must
# be serialized, otherwise deserialization isn't possible.
if self.use_natural_primary_keys:
pk = concrete_model._meta.pk
pk_parent = pk if pk.remote_field and pk.remote_field.parent_link else None
else:
pk_parent = None
for field in concrete_model._meta.local_fields:
if field.serialize or field is pk_parent:
if field.remote_field is None:
if self.selected_fields is None or field.attname in self.selected_fields:
self.handle_field(obj, field)
else:
if self.selected_fields is None or field.attname[:-3] in self.selected_fields:
self.handle_fk_field(obj, field)
for field in concrete_model._meta.local_many_to_many:
if field.serialize:
if self.selected_fields is None or field.attname in self.selected_fields:
self.handle_m2m_field(obj, field)
self.end_object(obj)
progress_bar.update(count)
self.first = self.first and False
self.end_serialization()
return self.getvalue()
def start_serialization(self):
"""
Called when serializing of the queryset starts.
"""
raise NotImplementedError('subclasses of Serializer must provide a start_serialization() method')
def end_serialization(self):
"""
Called when serializing of the queryset ends.
"""
pass
def start_object(self, obj):
"""
Called when serializing of an object starts.
"""
raise NotImplementedError('subclasses of Serializer must provide a start_object() method')
def end_object(self, obj):
"""
Called when serializing of an object ends.
"""
pass
def handle_field(self, obj, field):
"""
Called to handle each individual (non-relational) field on an object.
"""
raise NotImplementedError('subclasses of Serializer must provide a handle_field() method')
def handle_fk_field(self, obj, field):
"""
Called to handle a ForeignKey field.
"""
raise NotImplementedError('subclasses of Serializer must provide a handle_fk_field() method')
def handle_m2m_field(self, obj, field):
"""
Called to handle a ManyToManyField.
"""
raise NotImplementedError('subclasses of Serializer must provide a handle_m2m_field() method')
def getvalue(self):
"""
Return the fully serialized queryset (or None if the output stream is
not seekable).
"""
if callable(getattr(self.stream, 'getvalue', None)):
return self.stream.getvalue()
class Deserializer:
"""
Abstract base deserializer class.
"""
def __init__(self, stream_or_string, **options):
"""
Init this serializer given a stream or a string
"""
self.options = options
if isinstance(stream_or_string, str):
self.stream = StringIO(stream_or_string)
else:
self.stream = stream_or_string
def __iter__(self):
return self
def __next__(self):
"""Iteration interface -- return the next item in the stream"""
raise NotImplementedError('subclasses of Deserializer must provide a __next__() method')
class DeserializedObject:
"""
A deserialized model.
Basically a container for holding the pre-saved deserialized data along
with the many-to-many data saved with the object.
Call ``save()`` to save the object (with the many-to-many data) to the
database; call ``save(save_m2m=False)`` to save just the object fields
(and not touch the many-to-many stuff.)
"""
def __init__(self, obj, m2m_data=None, deferred_fields=None):
self.object = obj
self.m2m_data = m2m_data
self.deferred_fields = deferred_fields
def __repr__(self):
return "<%s: %s(pk=%s)>" % (
self.__class__.__name__,
self.object._meta.label,
self.object.pk,
)
def save(self, save_m2m=True, using=None, **kwargs):
# Call save on the Model baseclass directly. This bypasses any
# model-defined save. The save is also forced to be raw.
# raw=True is passed to any pre/post_save signals.
models.Model.save_base(self.object, using=using, raw=True, **kwargs)
if self.m2m_data and save_m2m:
for accessor_name, object_list in self.m2m_data.items():
getattr(self.object, accessor_name).set(object_list)
# prevent a second (possibly accidental) call to save() from saving
# the m2m data twice.
self.m2m_data = None
def save_deferred_fields(self, using=None):
self.m2m_data = {}
for field, field_value in self.deferred_fields.items():
opts = self.object._meta
label = opts.app_label + '.' + opts.model_name
if isinstance(field.remote_field, models.ManyToManyRel):
try:
values = deserialize_m2m_values(field, field_value, using, handle_forward_references=False)
except M2MDeserializationError as e:
raise DeserializationError.WithData(e.original_exc, label, self.object.pk, e.pk)
self.m2m_data[field.name] = values
elif isinstance(field.remote_field, models.ManyToOneRel):
try:
value = deserialize_fk_value(field, field_value, using, handle_forward_references=False)
except Exception as e:
raise DeserializationError.WithData(e, label, self.object.pk, field_value)
setattr(self.object, field.attname, value)
self.save()
def build_instance(Model, data, db):
"""
Build a model instance.
If the model instance doesn't have a primary key and the model supports
natural keys, try to retrieve it from the database.
"""
default_manager = Model._meta.default_manager
pk = data.get(Model._meta.pk.attname)
if (pk is None and hasattr(default_manager, 'get_by_natural_key') and
hasattr(Model, 'natural_key')):
natural_key = Model(**data).natural_key()
try:
data[Model._meta.pk.attname] = Model._meta.pk.to_python(
default_manager.db_manager(db).get_by_natural_key(*natural_key).pk
)
except Model.DoesNotExist:
pass
return Model(**data)
def deserialize_m2m_values(field, field_value, using, handle_forward_references):
model = field.remote_field.model
if hasattr(model._default_manager, 'get_by_natural_key'):
def m2m_convert(value):
if hasattr(value, '__iter__') and not isinstance(value, str):
return model._default_manager.db_manager(using).get_by_natural_key(*value).pk
else:
return model._meta.pk.to_python(value)
else:
def m2m_convert(v):
return model._meta.pk.to_python(v)
try:
pks_iter = iter(field_value)
except TypeError as e:
raise M2MDeserializationError(e, field_value)
try:
values = []
for pk in pks_iter:
values.append(m2m_convert(pk))
return values
except Exception as e:
if isinstance(e, ObjectDoesNotExist) and handle_forward_references:
return DEFER_FIELD
else:
raise M2MDeserializationError(e, pk)
def deserialize_fk_value(field, field_value, using, handle_forward_references):
if field_value is None:
return None
model = field.remote_field.model
default_manager = model._default_manager
field_name = field.remote_field.field_name
if (hasattr(default_manager, 'get_by_natural_key') and
hasattr(field_value, '__iter__') and not isinstance(field_value, str)):
try:
obj = default_manager.db_manager(using).get_by_natural_key(*field_value)
except ObjectDoesNotExist:
if handle_forward_references:
return DEFER_FIELD
else:
raise
value = getattr(obj, field_name)
# If this is a natural foreign key to an object that has a FK/O2O as
# the foreign key, use the FK value.
if model._meta.pk.remote_field:
value = value.pk
return value
return model._meta.get_field(field_name).to_python(field_value)

View File

@@ -0,0 +1,105 @@
"""
Serialize data to/from JSON
"""
import datetime
import decimal
import json
import uuid
from django.core.serializers.base import DeserializationError
from django.core.serializers.python import (
Deserializer as PythonDeserializer, Serializer as PythonSerializer,
)
from django.utils.duration import duration_iso_string
from django.utils.functional import Promise
from django.utils.timezone import is_aware
class Serializer(PythonSerializer):
"""Convert a queryset to JSON."""
internal_use_only = False
def _init_options(self):
self._current = None
self.json_kwargs = self.options.copy()
self.json_kwargs.pop('stream', None)
self.json_kwargs.pop('fields', None)
if self.options.get('indent'):
# Prevent trailing spaces
self.json_kwargs['separators'] = (',', ': ')
self.json_kwargs.setdefault('cls', DjangoJSONEncoder)
self.json_kwargs.setdefault('ensure_ascii', False)
def start_serialization(self):
self._init_options()
self.stream.write("[")
def end_serialization(self):
if self.options.get("indent"):
self.stream.write("\n")
self.stream.write("]")
if self.options.get("indent"):
self.stream.write("\n")
def end_object(self, obj):
# self._current has the field data
indent = self.options.get("indent")
if not self.first:
self.stream.write(",")
if not indent:
self.stream.write(" ")
if indent:
self.stream.write("\n")
json.dump(self.get_dump_object(obj), self.stream, **self.json_kwargs)
self._current = None
def getvalue(self):
# Grandparent super
return super(PythonSerializer, self).getvalue()
def Deserializer(stream_or_string, **options):
"""Deserialize a stream or string of JSON data."""
if not isinstance(stream_or_string, (bytes, str)):
stream_or_string = stream_or_string.read()
if isinstance(stream_or_string, bytes):
stream_or_string = stream_or_string.decode()
try:
objects = json.loads(stream_or_string)
yield from PythonDeserializer(objects, **options)
except (GeneratorExit, DeserializationError):
raise
except Exception as exc:
raise DeserializationError() from exc
class DjangoJSONEncoder(json.JSONEncoder):
"""
JSONEncoder subclass that knows how to encode date/time, decimal types, and
UUIDs.
"""
def default(self, o):
# See "Date Time String Format" in the ECMA-262 specification.
if isinstance(o, datetime.datetime):
r = o.isoformat()
if o.microsecond:
r = r[:23] + r[26:]
if r.endswith('+00:00'):
r = r[:-6] + 'Z'
return r
elif isinstance(o, datetime.date):
return o.isoformat()
elif isinstance(o, datetime.time):
if is_aware(o):
raise ValueError("JSON can't represent timezone-aware times.")
r = o.isoformat()
if o.microsecond:
r = r[:12]
return r
elif isinstance(o, datetime.timedelta):
return duration_iso_string(o)
elif isinstance(o, (decimal.Decimal, uuid.UUID, Promise)):
return str(o)
else:
return super().default(o)

View File

@@ -0,0 +1,57 @@
"""
Serialize data to/from JSON Lines
"""
import json
from django.core.serializers.base import DeserializationError
from django.core.serializers.json import DjangoJSONEncoder
from django.core.serializers.python import (
Deserializer as PythonDeserializer, Serializer as PythonSerializer,
)
class Serializer(PythonSerializer):
"""Convert a queryset to JSON Lines."""
internal_use_only = False
def _init_options(self):
self._current = None
self.json_kwargs = self.options.copy()
self.json_kwargs.pop('stream', None)
self.json_kwargs.pop('fields', None)
self.json_kwargs.pop('indent', None)
self.json_kwargs['separators'] = (',', ': ')
self.json_kwargs.setdefault('cls', DjangoJSONEncoder)
self.json_kwargs.setdefault('ensure_ascii', False)
def start_serialization(self):
self._init_options()
def end_object(self, obj):
# self._current has the field data
json.dump(self.get_dump_object(obj), self.stream, **self.json_kwargs)
self.stream.write("\n")
self._current = None
def getvalue(self):
# Grandparent super
return super(PythonSerializer, self).getvalue()
def Deserializer(stream_or_string, **options):
"""Deserialize a stream or string of JSON data."""
if isinstance(stream_or_string, bytes):
stream_or_string = stream_or_string.decode()
if isinstance(stream_or_string, (bytes, str)):
stream_or_string = stream_or_string.split("\n")
for line in stream_or_string:
if not line.strip():
continue
try:
yield from PythonDeserializer([json.loads(line)], **options)
except (GeneratorExit, DeserializationError):
raise
except Exception as exc:
raise DeserializationError() from exc

View File

@@ -0,0 +1,157 @@
"""
A Python "serializer". Doesn't do much serializing per se -- just converts to
and from basic Python data types (lists, dicts, strings, etc.). Useful as a basis for
other serializers.
"""
from django.apps import apps
from django.core.serializers import base
from django.db import DEFAULT_DB_ALIAS, models
from django.utils.encoding import is_protected_type
class Serializer(base.Serializer):
"""
Serialize a QuerySet to basic Python objects.
"""
internal_use_only = True
def start_serialization(self):
self._current = None
self.objects = []
def end_serialization(self):
pass
def start_object(self, obj):
self._current = {}
def end_object(self, obj):
self.objects.append(self.get_dump_object(obj))
self._current = None
def get_dump_object(self, obj):
data = {'model': str(obj._meta)}
if not self.use_natural_primary_keys or not hasattr(obj, 'natural_key'):
data["pk"] = self._value_from_field(obj, obj._meta.pk)
data['fields'] = self._current
return data
def _value_from_field(self, obj, field):
value = field.value_from_object(obj)
# Protected types (i.e., primitives like None, numbers, dates,
# and Decimals) are passed through as is. All other values are
# converted to string first.
return value if is_protected_type(value) else field.value_to_string(obj)
def handle_field(self, obj, field):
self._current[field.name] = self._value_from_field(obj, field)
def handle_fk_field(self, obj, field):
if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'):
related = getattr(obj, field.name)
if related:
value = related.natural_key()
else:
value = None
else:
value = self._value_from_field(obj, field)
self._current[field.name] = value
def handle_m2m_field(self, obj, field):
if field.remote_field.through._meta.auto_created:
if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'):
def m2m_value(value):
return value.natural_key()
else:
def m2m_value(value):
return self._value_from_field(value, value._meta.pk)
m2m_iter = getattr(obj, '_prefetched_objects_cache', {}).get(
field.name,
getattr(obj, field.name).iterator(),
)
self._current[field.name] = [m2m_value(related) for related in m2m_iter]
def getvalue(self):
return self.objects
def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options):
"""
Deserialize simple Python objects back into Django ORM instances.
It's expected that you pass the Python objects themselves (instead of a
stream or a string) to the constructor
"""
handle_forward_references = options.pop('handle_forward_references', False)
field_names_cache = {} # Model: <list of field_names>
for d in object_list:
# Look up the model and starting build a dict of data for it.
try:
Model = _get_model(d["model"])
except base.DeserializationError:
if ignorenonexistent:
continue
else:
raise
data = {}
if 'pk' in d:
try:
data[Model._meta.pk.attname] = Model._meta.pk.to_python(d.get('pk'))
except Exception as e:
raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), None)
m2m_data = {}
deferred_fields = {}
if Model not in field_names_cache:
field_names_cache[Model] = {f.name for f in Model._meta.get_fields()}
field_names = field_names_cache[Model]
# Handle each field
for (field_name, field_value) in d["fields"].items():
if ignorenonexistent and field_name not in field_names:
# skip fields no longer on model
continue
field = Model._meta.get_field(field_name)
# Handle M2M relations
if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
try:
values = base.deserialize_m2m_values(field, field_value, using, handle_forward_references)
except base.M2MDeserializationError as e:
raise base.DeserializationError.WithData(e.original_exc, d['model'], d.get('pk'), e.pk)
if values == base.DEFER_FIELD:
deferred_fields[field] = field_value
else:
m2m_data[field.name] = values
# Handle FK fields
elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
try:
value = base.deserialize_fk_value(field, field_value, using, handle_forward_references)
except Exception as e:
raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
if value == base.DEFER_FIELD:
deferred_fields[field] = field_value
else:
data[field.attname] = value
# Handle all other fields
else:
try:
data[field.name] = field.to_python(field_value)
except Exception as e:
raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
obj = base.build_instance(Model, data, using)
yield base.DeserializedObject(obj, m2m_data, deferred_fields)
def _get_model(model_identifier):
"""Look up a model from an "app_label.model_name" string."""
try:
return apps.get_model(model_identifier)
except (LookupError, TypeError):
raise base.DeserializationError("Invalid model identifier: '%s'" % model_identifier)

View File

@@ -0,0 +1,80 @@
"""
YAML serializer.
Requires PyYaml (https://pyyaml.org/), but that's checked for in __init__.
"""
import collections
import decimal
from io import StringIO
import yaml
from django.core.serializers.base import DeserializationError
from django.core.serializers.python import (
Deserializer as PythonDeserializer, Serializer as PythonSerializer,
)
from django.db import models
# Use the C (faster) implementation if possible
try:
from yaml import CSafeDumper as SafeDumper, CSafeLoader as SafeLoader
except ImportError:
from yaml import SafeDumper, SafeLoader
class DjangoSafeDumper(SafeDumper):
def represent_decimal(self, data):
return self.represent_scalar('tag:yaml.org,2002:str', str(data))
def represent_ordered_dict(self, data):
return self.represent_mapping('tag:yaml.org,2002:map', data.items())
DjangoSafeDumper.add_representer(decimal.Decimal, DjangoSafeDumper.represent_decimal)
DjangoSafeDumper.add_representer(collections.OrderedDict, DjangoSafeDumper.represent_ordered_dict)
# Workaround to represent dictionaries in insertion order.
# See https://github.com/yaml/pyyaml/pull/143.
DjangoSafeDumper.add_representer(dict, DjangoSafeDumper.represent_ordered_dict)
class Serializer(PythonSerializer):
"""Convert a queryset to YAML."""
internal_use_only = False
def handle_field(self, obj, field):
# A nasty special case: base YAML doesn't support serialization of time
# types (as opposed to dates or datetimes, which it does support). Since
# we want to use the "safe" serializer for better interoperability, we
# need to do something with those pesky times. Converting 'em to strings
# isn't perfect, but it's better than a "!!python/time" type which would
# halt deserialization under any other language.
if isinstance(field, models.TimeField) and getattr(obj, field.name) is not None:
self._current[field.name] = str(getattr(obj, field.name))
else:
super().handle_field(obj, field)
def end_serialization(self):
self.options.setdefault('allow_unicode', True)
yaml.dump(self.objects, self.stream, Dumper=DjangoSafeDumper, **self.options)
def getvalue(self):
# Grandparent super
return super(PythonSerializer, self).getvalue()
def Deserializer(stream_or_string, **options):
"""Deserialize a stream or string of YAML data."""
if isinstance(stream_or_string, bytes):
stream_or_string = stream_or_string.decode()
if isinstance(stream_or_string, str):
stream = StringIO(stream_or_string)
else:
stream = stream_or_string
try:
yield from PythonDeserializer(yaml.load(stream, Loader=SafeLoader), **options)
except (GeneratorExit, DeserializationError):
raise
except Exception as exc:
raise DeserializationError() from exc

View File

@@ -0,0 +1,432 @@
"""
XML serializer.
"""
import json
from xml.dom import pulldom
from xml.sax import handler
from xml.sax.expatreader import ExpatParser as _ExpatParser
from django.apps import apps
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.core.serializers import base
from django.db import DEFAULT_DB_ALIAS, models
from django.utils.xmlutils import (
SimplerXMLGenerator, UnserializableContentError,
)
class Serializer(base.Serializer):
"""Serialize a QuerySet to XML."""
def indent(self, level):
if self.options.get('indent') is not None:
self.xml.ignorableWhitespace('\n' + ' ' * self.options.get('indent') * level)
def start_serialization(self):
"""
Start serialization -- open the XML document and the root element.
"""
self.xml = SimplerXMLGenerator(self.stream, self.options.get("encoding", settings.DEFAULT_CHARSET))
self.xml.startDocument()
self.xml.startElement("django-objects", {"version": "1.0"})
def end_serialization(self):
"""
End serialization -- end the document.
"""
self.indent(0)
self.xml.endElement("django-objects")
self.xml.endDocument()
def start_object(self, obj):
"""
Called as each object is handled.
"""
if not hasattr(obj, "_meta"):
raise base.SerializationError("Non-model object (%s) encountered during serialization" % type(obj))
self.indent(1)
attrs = {'model': str(obj._meta)}
if not self.use_natural_primary_keys or not hasattr(obj, 'natural_key'):
obj_pk = obj.pk
if obj_pk is not None:
attrs['pk'] = str(obj_pk)
self.xml.startElement("object", attrs)
def end_object(self, obj):
"""
Called after handling all fields for an object.
"""
self.indent(1)
self.xml.endElement("object")
def handle_field(self, obj, field):
"""
Handle each field on an object (except for ForeignKeys and
ManyToManyFields).
"""
self.indent(2)
self.xml.startElement('field', {
'name': field.name,
'type': field.get_internal_type(),
})
# Get a "string version" of the object's data.
if getattr(obj, field.name) is not None:
value = field.value_to_string(obj)
if field.get_internal_type() == 'JSONField':
# Dump value since JSONField.value_to_string() doesn't output
# strings.
value = json.dumps(value, cls=field.encoder)
try:
self.xml.characters(value)
except UnserializableContentError:
raise ValueError("%s.%s (pk:%s) contains unserializable characters" % (
obj.__class__.__name__, field.name, obj.pk))
else:
self.xml.addQuickElement("None")
self.xml.endElement("field")
def handle_fk_field(self, obj, field):
"""
Handle a ForeignKey (they need to be treated slightly
differently from regular fields).
"""
self._start_relational_field(field)
related_att = getattr(obj, field.get_attname())
if related_att is not None:
if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'):
related = getattr(obj, field.name)
# If related object has a natural key, use it
related = related.natural_key()
# Iterable natural keys are rolled out as subelements
for key_value in related:
self.xml.startElement("natural", {})
self.xml.characters(str(key_value))
self.xml.endElement("natural")
else:
self.xml.characters(str(related_att))
else:
self.xml.addQuickElement("None")
self.xml.endElement("field")
def handle_m2m_field(self, obj, field):
"""
Handle a ManyToManyField. Related objects are only serialized as
references to the object's PK (i.e. the related *data* is not dumped,
just the relation).
"""
if field.remote_field.through._meta.auto_created:
self._start_relational_field(field)
if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'):
# If the objects in the m2m have a natural key, use it
def handle_m2m(value):
natural = value.natural_key()
# Iterable natural keys are rolled out as subelements
self.xml.startElement("object", {})
for key_value in natural:
self.xml.startElement("natural", {})
self.xml.characters(str(key_value))
self.xml.endElement("natural")
self.xml.endElement("object")
else:
def handle_m2m(value):
self.xml.addQuickElement("object", attrs={
'pk': str(value.pk)
})
m2m_iter = getattr(obj, '_prefetched_objects_cache', {}).get(
field.name,
getattr(obj, field.name).iterator(),
)
for relobj in m2m_iter:
handle_m2m(relobj)
self.xml.endElement("field")
def _start_relational_field(self, field):
"""Output the <field> element for relational fields."""
self.indent(2)
self.xml.startElement('field', {
'name': field.name,
'rel': field.remote_field.__class__.__name__,
'to': str(field.remote_field.model._meta),
})
class Deserializer(base.Deserializer):
"""Deserialize XML."""
def __init__(self, stream_or_string, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options):
super().__init__(stream_or_string, **options)
self.handle_forward_references = options.pop('handle_forward_references', False)
self.event_stream = pulldom.parse(self.stream, self._make_parser())
self.db = using
self.ignore = ignorenonexistent
def _make_parser(self):
"""Create a hardened XML parser (no custom/external entities)."""
return DefusedExpatParser()
def __next__(self):
for event, node in self.event_stream:
if event == "START_ELEMENT" and node.nodeName == "object":
self.event_stream.expandNode(node)
return self._handle_object(node)
raise StopIteration
def _handle_object(self, node):
"""Convert an <object> node to a DeserializedObject."""
# Look up the model using the model loading mechanism. If this fails,
# bail.
Model = self._get_model_from_node(node, "model")
# Start building a data dictionary from the object.
data = {}
if node.hasAttribute('pk'):
data[Model._meta.pk.attname] = Model._meta.pk.to_python(
node.getAttribute('pk'))
# Also start building a dict of m2m data (this is saved as
# {m2m_accessor_attribute : [list_of_related_objects]})
m2m_data = {}
deferred_fields = {}
field_names = {f.name for f in Model._meta.get_fields()}
# Deserialize each field.
for field_node in node.getElementsByTagName("field"):
# If the field is missing the name attribute, bail (are you
# sensing a pattern here?)
field_name = field_node.getAttribute("name")
if not field_name:
raise base.DeserializationError("<field> node is missing the 'name' attribute")
# Get the field from the Model. This will raise a
# FieldDoesNotExist if, well, the field doesn't exist, which will
# be propagated correctly unless ignorenonexistent=True is used.
if self.ignore and field_name not in field_names:
continue
field = Model._meta.get_field(field_name)
# As is usually the case, relation fields get the special treatment.
if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
value = self._handle_m2m_field_node(field_node, field)
if value == base.DEFER_FIELD:
deferred_fields[field] = [
[
getInnerText(nat_node).strip()
for nat_node in obj_node.getElementsByTagName('natural')
]
for obj_node in field_node.getElementsByTagName('object')
]
else:
m2m_data[field.name] = value
elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
value = self._handle_fk_field_node(field_node, field)
if value == base.DEFER_FIELD:
deferred_fields[field] = [
getInnerText(k).strip()
for k in field_node.getElementsByTagName('natural')
]
else:
data[field.attname] = value
else:
if field_node.getElementsByTagName('None'):
value = None
else:
value = field.to_python(getInnerText(field_node).strip())
# Load value since JSONField.to_python() outputs strings.
if field.get_internal_type() == 'JSONField':
value = json.loads(value, cls=field.decoder)
data[field.name] = value
obj = base.build_instance(Model, data, self.db)
# Return a DeserializedObject so that the m2m data has a place to live.
return base.DeserializedObject(obj, m2m_data, deferred_fields)
def _handle_fk_field_node(self, node, field):
"""
Handle a <field> node for a ForeignKey
"""
# Check if there is a child node named 'None', returning None if so.
if node.getElementsByTagName('None'):
return None
else:
model = field.remote_field.model
if hasattr(model._default_manager, 'get_by_natural_key'):
keys = node.getElementsByTagName('natural')
if keys:
# If there are 'natural' subelements, it must be a natural key
field_value = [getInnerText(k).strip() for k in keys]
try:
obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value)
except ObjectDoesNotExist:
if self.handle_forward_references:
return base.DEFER_FIELD
else:
raise
obj_pk = getattr(obj, field.remote_field.field_name)
# If this is a natural foreign key to an object that
# has a FK/O2O as the foreign key, use the FK value
if field.remote_field.model._meta.pk.remote_field:
obj_pk = obj_pk.pk
else:
# Otherwise, treat like a normal PK
field_value = getInnerText(node).strip()
obj_pk = model._meta.get_field(field.remote_field.field_name).to_python(field_value)
return obj_pk
else:
field_value = getInnerText(node).strip()
return model._meta.get_field(field.remote_field.field_name).to_python(field_value)
def _handle_m2m_field_node(self, node, field):
"""
Handle a <field> node for a ManyToManyField.
"""
model = field.remote_field.model
default_manager = model._default_manager
if hasattr(default_manager, 'get_by_natural_key'):
def m2m_convert(n):
keys = n.getElementsByTagName('natural')
if keys:
# If there are 'natural' subelements, it must be a natural key
field_value = [getInnerText(k).strip() for k in keys]
obj_pk = default_manager.db_manager(self.db).get_by_natural_key(*field_value).pk
else:
# Otherwise, treat like a normal PK value.
obj_pk = model._meta.pk.to_python(n.getAttribute('pk'))
return obj_pk
else:
def m2m_convert(n):
return model._meta.pk.to_python(n.getAttribute('pk'))
values = []
try:
for c in node.getElementsByTagName('object'):
values.append(m2m_convert(c))
except Exception as e:
if isinstance(e, ObjectDoesNotExist) and self.handle_forward_references:
return base.DEFER_FIELD
else:
raise base.M2MDeserializationError(e, c)
else:
return values
def _get_model_from_node(self, node, attr):
"""
Look up a model from a <object model=...> or a <field rel=... to=...>
node.
"""
model_identifier = node.getAttribute(attr)
if not model_identifier:
raise base.DeserializationError(
"<%s> node is missing the required '%s' attribute"
% (node.nodeName, attr))
try:
return apps.get_model(model_identifier)
except (LookupError, TypeError):
raise base.DeserializationError(
"<%s> node has invalid model identifier: '%s'"
% (node.nodeName, model_identifier))
def getInnerText(node):
"""Get all the inner text of a DOM node (recursively)."""
# inspired by https://mail.python.org/pipermail/xml-sig/2005-March/011022.html
inner_text = []
for child in node.childNodes:
if child.nodeType == child.TEXT_NODE or child.nodeType == child.CDATA_SECTION_NODE:
inner_text.append(child.data)
elif child.nodeType == child.ELEMENT_NODE:
inner_text.extend(getInnerText(child))
else:
pass
return "".join(inner_text)
# Below code based on Christian Heimes' defusedxml
class DefusedExpatParser(_ExpatParser):
"""
An expat parser hardened against XML bomb attacks.
Forbid DTDs, external entity references
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setFeature(handler.feature_external_ges, False)
self.setFeature(handler.feature_external_pes, False)
def start_doctype_decl(self, name, sysid, pubid, has_internal_subset):
raise DTDForbidden(name, sysid, pubid)
def entity_decl(self, name, is_parameter_entity, value, base,
sysid, pubid, notation_name):
raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name)
def unparsed_entity_decl(self, name, base, sysid, pubid, notation_name):
# expat 1.2
raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name)
def external_entity_ref_handler(self, context, base, sysid, pubid):
raise ExternalReferenceForbidden(context, base, sysid, pubid)
def reset(self):
_ExpatParser.reset(self)
parser = self._parser
parser.StartDoctypeDeclHandler = self.start_doctype_decl
parser.EntityDeclHandler = self.entity_decl
parser.UnparsedEntityDeclHandler = self.unparsed_entity_decl
parser.ExternalEntityRefHandler = self.external_entity_ref_handler
class DefusedXmlException(ValueError):
"""Base exception."""
def __repr__(self):
return str(self)
class DTDForbidden(DefusedXmlException):
"""Document type definition is forbidden."""
def __init__(self, name, sysid, pubid):
super().__init__()
self.name = name
self.sysid = sysid
self.pubid = pubid
def __str__(self):
tpl = "DTDForbidden(name='{}', system_id={!r}, public_id={!r})"
return tpl.format(self.name, self.sysid, self.pubid)
class EntitiesForbidden(DefusedXmlException):
"""Entity definition is forbidden."""
def __init__(self, name, value, base, sysid, pubid, notation_name):
super().__init__()
self.name = name
self.value = value
self.base = base
self.sysid = sysid
self.pubid = pubid
self.notation_name = notation_name
def __str__(self):
tpl = "EntitiesForbidden(name='{}', system_id={!r}, public_id={!r})"
return tpl.format(self.name, self.sysid, self.pubid)
class ExternalReferenceForbidden(DefusedXmlException):
"""Resolving an external reference is forbidden."""
def __init__(self, context, base, sysid, pubid):
super().__init__()
self.context = context
self.base = base
self.sysid = sysid
self.pubid = pubid
def __str__(self):
tpl = "ExternalReferenceForbidden(system_id='{}', public_id={})"
return tpl.format(self.sysid, self.pubid)

View File

@@ -0,0 +1,238 @@
"""
HTTP server that implements the Python WSGI protocol (PEP 333, rev 1.21).
Based on wsgiref.simple_server which is part of the standard library since 2.5.
This is a simple server for use in testing or debugging Django apps. It hasn't
been reviewed for security issues. DON'T USE IT FOR PRODUCTION USE!
"""
import logging
import socket
import socketserver
import sys
from wsgiref import simple_server
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import LimitedStream
from django.core.wsgi import get_wsgi_application
from django.db import connections
from django.utils.module_loading import import_string
__all__ = ('WSGIServer', 'WSGIRequestHandler')
logger = logging.getLogger('django.server')
def get_internal_wsgi_application():
"""
Load and return the WSGI application as configured by the user in
``settings.WSGI_APPLICATION``. With the default ``startproject`` layout,
this will be the ``application`` object in ``projectname/wsgi.py``.
This function, and the ``WSGI_APPLICATION`` setting itself, are only useful
for Django's internal server (runserver); external WSGI servers should just
be configured to point to the correct application object directly.
If settings.WSGI_APPLICATION is not set (is ``None``), return
whatever ``django.core.wsgi.get_wsgi_application`` returns.
"""
from django.conf import settings
app_path = getattr(settings, 'WSGI_APPLICATION')
if app_path is None:
return get_wsgi_application()
try:
return import_string(app_path)
except ImportError as err:
raise ImproperlyConfigured(
"WSGI application '%s' could not be loaded; "
"Error importing module." % app_path
) from err
def is_broken_pipe_error():
exc_type, _, _ = sys.exc_info()
return issubclass(exc_type, (
BrokenPipeError,
ConnectionAbortedError,
ConnectionResetError,
))
class WSGIServer(simple_server.WSGIServer):
"""BaseHTTPServer that implements the Python WSGI protocol"""
request_queue_size = 10
def __init__(self, *args, ipv6=False, allow_reuse_address=True, **kwargs):
if ipv6:
self.address_family = socket.AF_INET6
self.allow_reuse_address = allow_reuse_address
super().__init__(*args, **kwargs)
def handle_error(self, request, client_address):
if is_broken_pipe_error():
logger.info("- Broken pipe from %s\n", client_address)
else:
super().handle_error(request, client_address)
class ThreadedWSGIServer(socketserver.ThreadingMixIn, WSGIServer):
"""A threaded version of the WSGIServer"""
daemon_threads = True
def __init__(self, *args, connections_override=None, **kwargs):
super().__init__(*args, **kwargs)
self.connections_override = connections_override
# socketserver.ThreadingMixIn.process_request() passes this method as
# the target to a new Thread object.
def process_request_thread(self, request, client_address):
if self.connections_override:
# Override this thread's database connections with the ones
# provided by the parent thread.
for alias, conn in self.connections_override.items():
connections[alias] = conn
super().process_request_thread(request, client_address)
def _close_connections(self):
# Used for mocking in tests.
connections.close_all()
def close_request(self, request):
self._close_connections()
super().close_request(request)
class ServerHandler(simple_server.ServerHandler):
http_version = '1.1'
def __init__(self, stdin, stdout, stderr, environ, **kwargs):
"""
Use a LimitedStream so that unread request data will be ignored at
the end of the request. WSGIRequest uses a LimitedStream but it
shouldn't discard the data since the upstream servers usually do this.
This fix applies only for testserver/runserver.
"""
try:
content_length = int(environ.get('CONTENT_LENGTH'))
except (ValueError, TypeError):
content_length = 0
super().__init__(LimitedStream(stdin, content_length), stdout, stderr, environ, **kwargs)
def cleanup_headers(self):
super().cleanup_headers()
# HTTP/1.1 requires support for persistent connections. Send 'close' if
# the content length is unknown to prevent clients from reusing the
# connection.
if 'Content-Length' not in self.headers:
self.headers['Connection'] = 'close'
# Persistent connections require threading server.
elif not isinstance(self.request_handler.server, socketserver.ThreadingMixIn):
self.headers['Connection'] = 'close'
# Mark the connection for closing if it's set as such above or if the
# application sent the header.
if self.headers.get('Connection') == 'close':
self.request_handler.close_connection = True
def close(self):
self.get_stdin()._read_limited()
super().close()
class WSGIRequestHandler(simple_server.WSGIRequestHandler):
protocol_version = 'HTTP/1.1'
def address_string(self):
# Short-circuit parent method to not call socket.getfqdn
return self.client_address[0]
def log_message(self, format, *args):
extra = {
'request': self.request,
'server_time': self.log_date_time_string(),
}
if args[1][0] == '4':
# 0x16 = Handshake, 0x03 = SSL 3.0 or TLS 1.x
if args[0].startswith('\x16\x03'):
extra['status_code'] = 500
logger.error(
"You're accessing the development server over HTTPS, but "
"it only supports HTTP.\n", extra=extra,
)
return
if args[1].isdigit() and len(args[1]) == 3:
status_code = int(args[1])
extra['status_code'] = status_code
if status_code >= 500:
level = logger.error
elif status_code >= 400:
level = logger.warning
else:
level = logger.info
else:
level = logger.info
level(format, *args, extra=extra)
def get_environ(self):
# Strip all headers with underscores in the name before constructing
# the WSGI environ. This prevents header-spoofing based on ambiguity
# between underscores and dashes both normalized to underscores in WSGI
# env vars. Nginx and Apache 2.4+ both do this as well.
for k in self.headers:
if '_' in k:
del self.headers[k]
return super().get_environ()
def handle(self):
self.close_connection = True
self.handle_one_request()
while not self.close_connection:
self.handle_one_request()
try:
self.connection.shutdown(socket.SHUT_WR)
except (AttributeError, OSError):
pass
def handle_one_request(self):
"""Copy of WSGIRequestHandler.handle() but with different ServerHandler"""
self.raw_requestline = self.rfile.readline(65537)
if len(self.raw_requestline) > 65536:
self.requestline = ''
self.request_version = ''
self.command = ''
self.send_error(414)
return
if not self.parse_request(): # An error code has been sent, just exit
return
handler = ServerHandler(
self.rfile, self.wfile, self.get_stderr(), self.get_environ()
)
handler.request_handler = self # backpointer for logging & connection closing
handler.run(self.server.get_app())
def run(addr, port, wsgi_handler, ipv6=False, threading=False, server_cls=WSGIServer):
server_address = (addr, port)
if threading:
httpd_cls = type('WSGIServer', (socketserver.ThreadingMixIn, server_cls), {})
else:
httpd_cls = server_cls
httpd = httpd_cls(server_address, WSGIRequestHandler, ipv6=ipv6)
if threading:
# ThreadingMixIn.daemon_threads indicates how threads will behave on an
# abrupt shutdown; like quitting the server by the user or restarting
# by the auto-reloader. True means the server will not wait for thread
# termination before it quits. This will make auto-reloader faster
# and will prevent the need to kill the server manually if a thread
# isn't terminating correctly.
httpd.daemon_threads = True
httpd.set_app(wsgi_handler)
httpd.serve_forever()

View File

@@ -0,0 +1,6 @@
from django.dispatch import Signal
request_started = Signal()
request_finished = Signal()
got_request_exception = Signal()
setting_changed = Signal()

View File

@@ -0,0 +1,237 @@
"""
Functions for creating and restoring url-safe signed JSON objects.
The format used looks like this:
>>> signing.dumps("hello")
'ImhlbGxvIg:1QaUZC:YIye-ze3TTx7gtSv422nZA4sgmk'
There are two components here, separated by a ':'. The first component is a
URLsafe base64 encoded JSON of the object passed to dumps(). The second
component is a base64 encoded hmac/SHA1 hash of "$first_component:$secret"
signing.loads(s) checks the signature and returns the deserialized object.
If the signature fails, a BadSignature exception is raised.
>>> signing.loads("ImhlbGxvIg:1QaUZC:YIye-ze3TTx7gtSv422nZA4sgmk")
'hello'
>>> signing.loads("ImhlbGxvIg:1QaUZC:YIye-ze3TTx7gtSv422nZA4sgmk-modified")
...
BadSignature: Signature failed: ImhlbGxvIg:1QaUZC:YIye-ze3TTx7gtSv422nZA4sgmk-modified
You can optionally compress the JSON prior to base64 encoding it to save
space, using the compress=True argument. This checks if compression actually
helps and only applies compression if the result is a shorter string:
>>> signing.dumps(list(range(1, 20)), compress=True)
'.eJwFwcERACAIwLCF-rCiILN47r-GyZVJsNgkxaFxoDgxcOHGxMKD_T7vhAml:1QaUaL:BA0thEZrp4FQVXIXuOvYJtLJSrQ'
The fact that the string is compressed is signalled by the prefixed '.' at the
start of the base64 JSON.
There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
These functions make use of all of them.
"""
import base64
import datetime
import json
import time
import zlib
from django.conf import settings
from django.utils.crypto import constant_time_compare, salted_hmac
from django.utils.encoding import force_bytes
from django.utils.module_loading import import_string
from django.utils.regex_helper import _lazy_re_compile
_SEP_UNSAFE = _lazy_re_compile(r'^[A-z0-9-_=]*$')
BASE62_ALPHABET = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
class BadSignature(Exception):
"""Signature does not match."""
pass
class SignatureExpired(BadSignature):
"""Signature timestamp is older than required max_age."""
pass
def b62_encode(s):
if s == 0:
return '0'
sign = '-' if s < 0 else ''
s = abs(s)
encoded = ''
while s > 0:
s, remainder = divmod(s, 62)
encoded = BASE62_ALPHABET[remainder] + encoded
return sign + encoded
def b62_decode(s):
if s == '0':
return 0
sign = 1
if s[0] == '-':
s = s[1:]
sign = -1
decoded = 0
for digit in s:
decoded = decoded * 62 + BASE62_ALPHABET.index(digit)
return sign * decoded
def b64_encode(s):
return base64.urlsafe_b64encode(s).strip(b'=')
def b64_decode(s):
pad = b'=' * (-len(s) % 4)
return base64.urlsafe_b64decode(s + pad)
def base64_hmac(salt, value, key, algorithm='sha1'):
return b64_encode(salted_hmac(salt, value, key, algorithm=algorithm).digest()).decode()
def get_cookie_signer(salt='django.core.signing.get_cookie_signer'):
Signer = import_string(settings.SIGNING_BACKEND)
key = force_bytes(settings.SECRET_KEY) # SECRET_KEY may be str or bytes.
return Signer(b'django.http.cookies' + key, salt=salt)
class JSONSerializer:
"""
Simple wrapper around json to be used in signing.dumps and
signing.loads.
"""
def dumps(self, obj):
return json.dumps(obj, separators=(',', ':')).encode('latin-1')
def loads(self, data):
return json.loads(data.decode('latin-1'))
def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer, compress=False):
"""
Return URL-safe, hmac signed base64 compressed JSON string. If key is
None, use settings.SECRET_KEY instead. The hmac algorithm is the default
Signer algorithm.
If compress is True (not the default), check if compressing using zlib can
save some space. Prepend a '.' to signify compression. This is included
in the signature, to protect against zip bombs.
Salt can be used to namespace the hash, so that a signed string is
only valid for a given namespace. Leaving this at the default
value or re-using a salt value across different parts of your
application without good cause is a security risk.
The serializer is expected to return a bytestring.
"""
return TimestampSigner(key, salt=salt).sign_object(obj, serializer=serializer, compress=compress)
def loads(s, key=None, salt='django.core.signing', serializer=JSONSerializer, max_age=None):
"""
Reverse of dumps(), raise BadSignature if signature fails.
The serializer is expected to accept a bytestring.
"""
return TimestampSigner(key, salt=salt).unsign_object(s, serializer=serializer, max_age=max_age)
class Signer:
def __init__(self, key=None, sep=':', salt=None, algorithm=None):
self.key = key or settings.SECRET_KEY
self.sep = sep
if _SEP_UNSAFE.match(self.sep):
raise ValueError(
'Unsafe Signer separator: %r (cannot be empty or consist of '
'only A-z0-9-_=)' % sep,
)
self.salt = salt or '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
self.algorithm = algorithm or 'sha256'
def signature(self, value):
return base64_hmac(self.salt + 'signer', value, self.key, algorithm=self.algorithm)
def sign(self, value):
return '%s%s%s' % (value, self.sep, self.signature(value))
def unsign(self, signed_value):
if self.sep not in signed_value:
raise BadSignature('No "%s" found in value' % self.sep)
value, sig = signed_value.rsplit(self.sep, 1)
if constant_time_compare(sig, self.signature(value)):
return value
raise BadSignature('Signature "%s" does not match' % sig)
def sign_object(self, obj, serializer=JSONSerializer, compress=False):
"""
Return URL-safe, hmac signed base64 compressed JSON string.
If compress is True (not the default), check if compressing using zlib
can save some space. Prepend a '.' to signify compression. This is
included in the signature, to protect against zip bombs.
The serializer is expected to return a bytestring.
"""
data = serializer().dumps(obj)
# Flag for if it's been compressed or not.
is_compressed = False
if compress:
# Avoid zlib dependency unless compress is being used.
compressed = zlib.compress(data)
if len(compressed) < (len(data) - 1):
data = compressed
is_compressed = True
base64d = b64_encode(data).decode()
if is_compressed:
base64d = '.' + base64d
return self.sign(base64d)
def unsign_object(self, signed_obj, serializer=JSONSerializer, **kwargs):
# Signer.unsign() returns str but base64 and zlib compression operate
# on bytes.
base64d = self.unsign(signed_obj, **kwargs).encode()
decompress = base64d[:1] == b'.'
if decompress:
# It's compressed; uncompress it first.
base64d = base64d[1:]
data = b64_decode(base64d)
if decompress:
data = zlib.decompress(data)
return serializer().loads(data)
class TimestampSigner(Signer):
def timestamp(self):
return b62_encode(int(time.time()))
def sign(self, value):
value = '%s%s%s' % (value, self.sep, self.timestamp())
return super().sign(value)
def unsign(self, value, max_age=None):
"""
Retrieve original value and check it wasn't signed more
than max_age seconds ago.
"""
result = super().unsign(value)
value, timestamp = result.rsplit(self.sep, 1)
timestamp = b62_decode(timestamp)
if max_age is not None:
if isinstance(max_age, datetime.timedelta):
max_age = max_age.total_seconds()
# Check timestamp is not older than max_age
age = time.time() - timestamp
if age > max_age:
raise SignatureExpired(
'Signature age %s > %s seconds' % (age, max_age))
return value

View File

@@ -0,0 +1,577 @@
import ipaddress
import re
import warnings
from pathlib import Path
from urllib.parse import urlsplit, urlunsplit
from django.core.exceptions import ValidationError
from django.utils.deconstruct import deconstructible
from django.utils.deprecation import RemovedInDjango41Warning
from django.utils.encoding import punycode
from django.utils.ipv6 import is_valid_ipv6_address
from django.utils.regex_helper import _lazy_re_compile
from django.utils.translation import gettext_lazy as _, ngettext_lazy
# These values, if given to validate(), will trigger the self.required check.
EMPTY_VALUES = (None, '', [], (), {})
@deconstructible
class RegexValidator:
regex = ''
message = _('Enter a valid value.')
code = 'invalid'
inverse_match = False
flags = 0
def __init__(self, regex=None, message=None, code=None, inverse_match=None, flags=None):
if regex is not None:
self.regex = regex
if message is not None:
self.message = message
if code is not None:
self.code = code
if inverse_match is not None:
self.inverse_match = inverse_match
if flags is not None:
self.flags = flags
if self.flags and not isinstance(self.regex, str):
raise TypeError("If the flags are set, regex must be a regular expression string.")
self.regex = _lazy_re_compile(self.regex, self.flags)
def __call__(self, value):
"""
Validate that the input contains (or does *not* contain, if
inverse_match is True) a match for the regular expression.
"""
regex_matches = self.regex.search(str(value))
invalid_input = regex_matches if self.inverse_match else not regex_matches
if invalid_input:
raise ValidationError(self.message, code=self.code, params={'value': value})
def __eq__(self, other):
return (
isinstance(other, RegexValidator) and
self.regex.pattern == other.regex.pattern and
self.regex.flags == other.regex.flags and
(self.message == other.message) and
(self.code == other.code) and
(self.inverse_match == other.inverse_match)
)
@deconstructible
class URLValidator(RegexValidator):
ul = '\u00a1-\uffff' # Unicode letters range (must not be a raw string).
# IP patterns
ipv4_re = r'(?:0|25[0-5]|2[0-4]\d|1\d?\d?|[1-9]\d?)(?:\.(?:0|25[0-5]|2[0-4]\d|1\d?\d?|[1-9]\d?)){3}'
ipv6_re = r'\[[0-9a-f:.]+\]' # (simple regex, validated later)
# Host patterns
hostname_re = r'[a-z' + ul + r'0-9](?:[a-z' + ul + r'0-9-]{0,61}[a-z' + ul + r'0-9])?'
# Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
domain_re = r'(?:\.(?!-)[a-z' + ul + r'0-9-]{1,63}(?<!-))*'
tld_re = (
r'\.' # dot
r'(?!-)' # can't start with a dash
r'(?:[a-z' + ul + '-]{2,63}' # domain label
r'|xn--[a-z0-9]{1,59})' # or punycode label
r'(?<!-)' # can't end with a dash
r'\.?' # may have a trailing dot
)
host_re = '(' + hostname_re + domain_re + tld_re + '|localhost)'
regex = _lazy_re_compile(
r'^(?:[a-z0-9.+-]*)://' # scheme is validated separately
r'(?:[^\s:@/]+(?::[^\s:@/]*)?@)?' # user:pass authentication
r'(?:' + ipv4_re + '|' + ipv6_re + '|' + host_re + ')'
r'(?::\d{1,5})?' # port
r'(?:[/?#][^\s]*)?' # resource path
r'\Z', re.IGNORECASE)
message = _('Enter a valid URL.')
schemes = ['http', 'https', 'ftp', 'ftps']
unsafe_chars = frozenset('\t\r\n')
def __init__(self, schemes=None, **kwargs):
super().__init__(**kwargs)
if schemes is not None:
self.schemes = schemes
def __call__(self, value):
if not isinstance(value, str):
raise ValidationError(self.message, code=self.code, params={'value': value})
if self.unsafe_chars.intersection(value):
raise ValidationError(self.message, code=self.code, params={'value': value})
# Check if the scheme is valid.
scheme = value.split('://')[0].lower()
if scheme not in self.schemes:
raise ValidationError(self.message, code=self.code, params={'value': value})
# Then check full URL
try:
super().__call__(value)
except ValidationError as e:
# Trivial case failed. Try for possible IDN domain
if value:
try:
scheme, netloc, path, query, fragment = urlsplit(value)
except ValueError: # for example, "Invalid IPv6 URL"
raise ValidationError(self.message, code=self.code, params={'value': value})
try:
netloc = punycode(netloc) # IDN -> ACE
except UnicodeError: # invalid domain part
raise e
url = urlunsplit((scheme, netloc, path, query, fragment))
super().__call__(url)
else:
raise
else:
# Now verify IPv6 in the netloc part
host_match = re.search(r'^\[(.+)\](?::\d{1,5})?$', urlsplit(value).netloc)
if host_match:
potential_ip = host_match[1]
try:
validate_ipv6_address(potential_ip)
except ValidationError:
raise ValidationError(self.message, code=self.code, params={'value': value})
# The maximum length of a full host name is 253 characters per RFC 1034
# section 3.1. It's defined to be 255 bytes or less, but this includes
# one byte for the length of the name and one byte for the trailing dot
# that's used to indicate absolute names in DNS.
if len(urlsplit(value).hostname) > 253:
raise ValidationError(self.message, code=self.code, params={'value': value})
integer_validator = RegexValidator(
_lazy_re_compile(r'^-?\d+\Z'),
message=_('Enter a valid integer.'),
code='invalid',
)
def validate_integer(value):
return integer_validator(value)
@deconstructible
class EmailValidator:
message = _('Enter a valid email address.')
code = 'invalid'
user_regex = _lazy_re_compile(
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z" # dot-atom
r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)', # quoted-string
re.IGNORECASE)
domain_regex = _lazy_re_compile(
# max length for domain name labels is 63 characters per RFC 1034
r'((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z',
re.IGNORECASE)
literal_regex = _lazy_re_compile(
# literal form, ipv4 or ipv6 address (SMTP 4.1.3)
r'\[([A-F0-9:.]+)\]\Z',
re.IGNORECASE)
domain_allowlist = ['localhost']
@property
def domain_whitelist(self):
warnings.warn(
'The domain_whitelist attribute is deprecated in favor of '
'domain_allowlist.',
RemovedInDjango41Warning,
stacklevel=2,
)
return self.domain_allowlist
@domain_whitelist.setter
def domain_whitelist(self, allowlist):
warnings.warn(
'The domain_whitelist attribute is deprecated in favor of '
'domain_allowlist.',
RemovedInDjango41Warning,
stacklevel=2,
)
self.domain_allowlist = allowlist
def __init__(self, message=None, code=None, allowlist=None, *, whitelist=None):
if whitelist is not None:
allowlist = whitelist
warnings.warn(
'The whitelist argument is deprecated in favor of allowlist.',
RemovedInDjango41Warning,
stacklevel=2,
)
if message is not None:
self.message = message
if code is not None:
self.code = code
if allowlist is not None:
self.domain_allowlist = allowlist
def __call__(self, value):
if not value or '@' not in value:
raise ValidationError(self.message, code=self.code, params={'value': value})
user_part, domain_part = value.rsplit('@', 1)
if not self.user_regex.match(user_part):
raise ValidationError(self.message, code=self.code, params={'value': value})
if (domain_part not in self.domain_allowlist and
not self.validate_domain_part(domain_part)):
# Try for possible IDN domain-part
try:
domain_part = punycode(domain_part)
except UnicodeError:
pass
else:
if self.validate_domain_part(domain_part):
return
raise ValidationError(self.message, code=self.code, params={'value': value})
def validate_domain_part(self, domain_part):
if self.domain_regex.match(domain_part):
return True
literal_match = self.literal_regex.match(domain_part)
if literal_match:
ip_address = literal_match[1]
try:
validate_ipv46_address(ip_address)
return True
except ValidationError:
pass
return False
def __eq__(self, other):
return (
isinstance(other, EmailValidator) and
(self.domain_allowlist == other.domain_allowlist) and
(self.message == other.message) and
(self.code == other.code)
)
validate_email = EmailValidator()
slug_re = _lazy_re_compile(r'^[-a-zA-Z0-9_]+\Z')
validate_slug = RegexValidator(
slug_re,
# Translators: "letters" means latin letters: a-z and A-Z.
_('Enter a valid “slug” consisting of letters, numbers, underscores or hyphens.'),
'invalid'
)
slug_unicode_re = _lazy_re_compile(r'^[-\w]+\Z')
validate_unicode_slug = RegexValidator(
slug_unicode_re,
_('Enter a valid “slug” consisting of Unicode letters, numbers, underscores, or hyphens.'),
'invalid'
)
def validate_ipv4_address(value):
try:
ipaddress.IPv4Address(value)
except ValueError:
raise ValidationError(_('Enter a valid IPv4 address.'), code='invalid', params={'value': value})
else:
# Leading zeros are forbidden to avoid ambiguity with the octal
# notation. This restriction is included in Python 3.9.5+.
# TODO: Remove when dropping support for PY39.
if any(
octet != '0' and octet[0] == '0'
for octet in value.split('.')
):
raise ValidationError(
_('Enter a valid IPv4 address.'),
code='invalid',
params={'value': value},
)
def validate_ipv6_address(value):
if not is_valid_ipv6_address(value):
raise ValidationError(_('Enter a valid IPv6 address.'), code='invalid', params={'value': value})
def validate_ipv46_address(value):
try:
validate_ipv4_address(value)
except ValidationError:
try:
validate_ipv6_address(value)
except ValidationError:
raise ValidationError(_('Enter a valid IPv4 or IPv6 address.'), code='invalid', params={'value': value})
ip_address_validator_map = {
'both': ([validate_ipv46_address], _('Enter a valid IPv4 or IPv6 address.')),
'ipv4': ([validate_ipv4_address], _('Enter a valid IPv4 address.')),
'ipv6': ([validate_ipv6_address], _('Enter a valid IPv6 address.')),
}
def ip_address_validators(protocol, unpack_ipv4):
"""
Depending on the given parameters, return the appropriate validators for
the GenericIPAddressField.
"""
if protocol != 'both' and unpack_ipv4:
raise ValueError(
"You can only use `unpack_ipv4` if `protocol` is set to 'both'")
try:
return ip_address_validator_map[protocol.lower()]
except KeyError:
raise ValueError("The protocol '%s' is unknown. Supported: %s"
% (protocol, list(ip_address_validator_map)))
def int_list_validator(sep=',', message=None, code='invalid', allow_negative=False):
regexp = _lazy_re_compile(r'^%(neg)s\d+(?:%(sep)s%(neg)s\d+)*\Z' % {
'neg': '(-)?' if allow_negative else '',
'sep': re.escape(sep),
})
return RegexValidator(regexp, message=message, code=code)
validate_comma_separated_integer_list = int_list_validator(
message=_('Enter only digits separated by commas.'),
)
@deconstructible
class BaseValidator:
message = _('Ensure this value is %(limit_value)s (it is %(show_value)s).')
code = 'limit_value'
def __init__(self, limit_value, message=None):
self.limit_value = limit_value
if message:
self.message = message
def __call__(self, value):
cleaned = self.clean(value)
limit_value = self.limit_value() if callable(self.limit_value) else self.limit_value
params = {'limit_value': limit_value, 'show_value': cleaned, 'value': value}
if self.compare(cleaned, limit_value):
raise ValidationError(self.message, code=self.code, params=params)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.limit_value == other.limit_value and
self.message == other.message and
self.code == other.code
)
def compare(self, a, b):
return a is not b
def clean(self, x):
return x
@deconstructible
class MaxValueValidator(BaseValidator):
message = _('Ensure this value is less than or equal to %(limit_value)s.')
code = 'max_value'
def compare(self, a, b):
return a > b
@deconstructible
class MinValueValidator(BaseValidator):
message = _('Ensure this value is greater than or equal to %(limit_value)s.')
code = 'min_value'
def compare(self, a, b):
return a < b
@deconstructible
class MinLengthValidator(BaseValidator):
message = ngettext_lazy(
'Ensure this value has at least %(limit_value)d character (it has %(show_value)d).',
'Ensure this value has at least %(limit_value)d characters (it has %(show_value)d).',
'limit_value')
code = 'min_length'
def compare(self, a, b):
return a < b
def clean(self, x):
return len(x)
@deconstructible
class MaxLengthValidator(BaseValidator):
message = ngettext_lazy(
'Ensure this value has at most %(limit_value)d character (it has %(show_value)d).',
'Ensure this value has at most %(limit_value)d characters (it has %(show_value)d).',
'limit_value')
code = 'max_length'
def compare(self, a, b):
return a > b
def clean(self, x):
return len(x)
@deconstructible
class DecimalValidator:
"""
Validate that the input does not exceed the maximum number of digits
expected, otherwise raise ValidationError.
"""
messages = {
'invalid': _('Enter a number.'),
'max_digits': ngettext_lazy(
'Ensure that there are no more than %(max)s digit in total.',
'Ensure that there are no more than %(max)s digits in total.',
'max'
),
'max_decimal_places': ngettext_lazy(
'Ensure that there are no more than %(max)s decimal place.',
'Ensure that there are no more than %(max)s decimal places.',
'max'
),
'max_whole_digits': ngettext_lazy(
'Ensure that there are no more than %(max)s digit before the decimal point.',
'Ensure that there are no more than %(max)s digits before the decimal point.',
'max'
),
}
def __init__(self, max_digits, decimal_places):
self.max_digits = max_digits
self.decimal_places = decimal_places
def __call__(self, value):
digit_tuple, exponent = value.as_tuple()[1:]
if exponent in {'F', 'n', 'N'}:
raise ValidationError(self.messages['invalid'], code='invalid', params={'value': value})
if exponent >= 0:
# A positive exponent adds that many trailing zeros.
digits = len(digit_tuple) + exponent
decimals = 0
else:
# If the absolute value of the negative exponent is larger than the
# number of digits, then it's the same as the number of digits,
# because it'll consume all of the digits in digit_tuple and then
# add abs(exponent) - len(digit_tuple) leading zeros after the
# decimal point.
if abs(exponent) > len(digit_tuple):
digits = decimals = abs(exponent)
else:
digits = len(digit_tuple)
decimals = abs(exponent)
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
raise ValidationError(
self.messages['max_digits'],
code='max_digits',
params={'max': self.max_digits, 'value': value},
)
if self.decimal_places is not None and decimals > self.decimal_places:
raise ValidationError(
self.messages['max_decimal_places'],
code='max_decimal_places',
params={'max': self.decimal_places, 'value': value},
)
if (self.max_digits is not None and self.decimal_places is not None and
whole_digits > (self.max_digits - self.decimal_places)):
raise ValidationError(
self.messages['max_whole_digits'],
code='max_whole_digits',
params={'max': (self.max_digits - self.decimal_places), 'value': value},
)
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.max_digits == other.max_digits and
self.decimal_places == other.decimal_places
)
@deconstructible
class FileExtensionValidator:
message = _(
'File extension “%(extension)s” is not allowed. '
'Allowed extensions are: %(allowed_extensions)s.'
)
code = 'invalid_extension'
def __init__(self, allowed_extensions=None, message=None, code=None):
if allowed_extensions is not None:
allowed_extensions = [allowed_extension.lower() for allowed_extension in allowed_extensions]
self.allowed_extensions = allowed_extensions
if message is not None:
self.message = message
if code is not None:
self.code = code
def __call__(self, value):
extension = Path(value.name).suffix[1:].lower()
if self.allowed_extensions is not None and extension not in self.allowed_extensions:
raise ValidationError(
self.message,
code=self.code,
params={
'extension': extension,
'allowed_extensions': ', '.join(self.allowed_extensions),
'value': value,
}
)
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.allowed_extensions == other.allowed_extensions and
self.message == other.message and
self.code == other.code
)
def get_available_image_extensions():
try:
from PIL import Image
except ImportError:
return []
else:
Image.init()
return [ext.lower()[1:] for ext in Image.EXTENSION]
def validate_image_file_extension(value):
return FileExtensionValidator(allowed_extensions=get_available_image_extensions())(value)
@deconstructible
class ProhibitNullCharactersValidator:
"""Validate that the string doesn't contain the null character."""
message = _('Null characters are not allowed.')
code = 'null_characters_not_allowed'
def __init__(self, message=None, code=None):
if message is not None:
self.message = message
if code is not None:
self.code = code
def __call__(self, value):
if '\x00' in str(value):
raise ValidationError(self.message, code=self.code, params={'value': value})
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.message == other.message and
self.code == other.code
)

View File

@@ -0,0 +1,13 @@
import django
from django.core.handlers.wsgi import WSGIHandler
def get_wsgi_application():
"""
The public interface to Django's WSGI support. Return a WSGI callable.
Avoids making django.core.handlers.WSGIHandler a public API, in case the
internal WSGI implementation changes or moves in the future.
"""
django.setup(set_prefix=False)
return WSGIHandler()