Ajoutez des fichiers projet.
This commit is contained in:
21
venv/Lib/site-packages/django/test/__init__.py
Normal file
21
venv/Lib/site-packages/django/test/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Django Unit Test framework."""
|
||||
|
||||
from django.test.client import (
|
||||
AsyncClient, AsyncRequestFactory, Client, RequestFactory,
|
||||
)
|
||||
from django.test.testcases import (
|
||||
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
|
||||
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
||||
)
|
||||
from django.test.utils import (
|
||||
ignore_warnings, modify_settings, override_settings,
|
||||
override_system_checks, tag,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
|
||||
'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
|
||||
'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
|
||||
'ignore_warnings', 'modify_settings', 'override_settings',
|
||||
'override_system_checks', 'tag',
|
||||
]
|
937
venv/Lib/site-packages/django/test/client.py
Normal file
937
venv/Lib/site-packages/django/test/client.py
Normal file
@@ -0,0 +1,937 @@
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from importlib import import_module
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
from django.core.handlers.base import BaseHandler
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.core.signals import (
|
||||
got_request_exception, request_finished, request_started,
|
||||
)
|
||||
from django.db import close_old_connections
|
||||
from django.http import HttpRequest, QueryDict, SimpleCookie
|
||||
from django.test import signals
|
||||
from django.test.utils import ContextList
|
||||
from django.urls import resolve
|
||||
from django.utils.encoding import force_bytes
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
from django.utils.http import urlencode
|
||||
from django.utils.itercompat import is_iterable
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
__all__ = (
|
||||
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RedirectCycleError',
|
||||
'RequestFactory', 'encode_file', 'encode_multipart',
|
||||
)
|
||||
|
||||
|
||||
BOUNDARY = 'BoUnDaRyStRiNg'
|
||||
MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
|
||||
CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w\d-]+);?')
|
||||
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
|
||||
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json')
|
||||
|
||||
|
||||
class RedirectCycleError(Exception):
|
||||
"""The test client has been asked to follow a redirect loop."""
|
||||
def __init__(self, message, last_response):
|
||||
super().__init__(message)
|
||||
self.last_response = last_response
|
||||
self.redirect_chain = last_response.redirect_chain
|
||||
|
||||
|
||||
class FakePayload:
|
||||
"""
|
||||
A wrapper around BytesIO that restricts what can be read since data from
|
||||
the network can't be sought and cannot be read outside of its content
|
||||
length. This makes sure that views can't do anything under the test client
|
||||
that wouldn't work in real life.
|
||||
"""
|
||||
def __init__(self, content=None):
|
||||
self.__content = BytesIO()
|
||||
self.__len = 0
|
||||
self.read_started = False
|
||||
if content is not None:
|
||||
self.write(content)
|
||||
|
||||
def __len__(self):
|
||||
return self.__len
|
||||
|
||||
def read(self, num_bytes=None):
|
||||
if not self.read_started:
|
||||
self.__content.seek(0)
|
||||
self.read_started = True
|
||||
if num_bytes is None:
|
||||
num_bytes = self.__len or 0
|
||||
assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
|
||||
content = self.__content.read(num_bytes)
|
||||
self.__len -= num_bytes
|
||||
return content
|
||||
|
||||
def write(self, content):
|
||||
if self.read_started:
|
||||
raise ValueError("Unable to write a payload after it's been read")
|
||||
content = force_bytes(content)
|
||||
self.__content.write(content)
|
||||
self.__len += len(content)
|
||||
|
||||
|
||||
def closing_iterator_wrapper(iterable, close):
|
||||
try:
|
||||
yield from iterable
|
||||
finally:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
close() # will fire request_finished
|
||||
request_finished.connect(close_old_connections)
|
||||
|
||||
|
||||
def conditional_content_removal(request, response):
|
||||
"""
|
||||
Simulate the behavior of most web servers by removing the content of
|
||||
responses for HEAD requests, 1xx, 204, and 304 responses. Ensure
|
||||
compliance with RFC 7230, section 3.3.3.
|
||||
"""
|
||||
if 100 <= response.status_code < 200 or response.status_code in (204, 304):
|
||||
if response.streaming:
|
||||
response.streaming_content = []
|
||||
else:
|
||||
response.content = b''
|
||||
if request.method == 'HEAD':
|
||||
if response.streaming:
|
||||
response.streaming_content = []
|
||||
else:
|
||||
response.content = b''
|
||||
return response
|
||||
|
||||
|
||||
class ClientHandler(BaseHandler):
|
||||
"""
|
||||
An HTTP Handler that can be used for testing purposes. Use the WSGI
|
||||
interface to compose requests, but return the raw HttpResponse object with
|
||||
the originating WSGIRequest attached to its ``wsgi_request`` attribute.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
||||
self.enforce_csrf_checks = enforce_csrf_checks
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, environ):
|
||||
# Set up middleware if needed. We couldn't do this earlier, because
|
||||
# settings weren't available.
|
||||
if self._middleware_chain is None:
|
||||
self.load_middleware()
|
||||
|
||||
request_started.disconnect(close_old_connections)
|
||||
request_started.send(sender=self.__class__, environ=environ)
|
||||
request_started.connect(close_old_connections)
|
||||
request = WSGIRequest(environ)
|
||||
# sneaky little hack so that we can easily get round
|
||||
# CsrfViewMiddleware. This makes life easier, and is probably
|
||||
# required for backwards compatibility with external tests against
|
||||
# admin views.
|
||||
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
|
||||
|
||||
# Request goes through middleware.
|
||||
response = self.get_response(request)
|
||||
|
||||
# Simulate behaviors of most web servers.
|
||||
conditional_content_removal(request, response)
|
||||
|
||||
# Attach the originating request to the response so that it could be
|
||||
# later retrieved.
|
||||
response.wsgi_request = request
|
||||
|
||||
# Emulate a WSGI server by calling the close method on completion.
|
||||
if response.streaming:
|
||||
response.streaming_content = closing_iterator_wrapper(
|
||||
response.streaming_content, response.close)
|
||||
else:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
response.close() # will fire request_finished
|
||||
request_finished.connect(close_old_connections)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class AsyncClientHandler(BaseHandler):
|
||||
"""An async version of ClientHandler."""
|
||||
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
||||
self.enforce_csrf_checks = enforce_csrf_checks
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def __call__(self, scope):
|
||||
# Set up middleware if needed. We couldn't do this earlier, because
|
||||
# settings weren't available.
|
||||
if self._middleware_chain is None:
|
||||
self.load_middleware(is_async=True)
|
||||
# Extract body file from the scope, if provided.
|
||||
if '_body_file' in scope:
|
||||
body_file = scope.pop('_body_file')
|
||||
else:
|
||||
body_file = FakePayload('')
|
||||
|
||||
request_started.disconnect(close_old_connections)
|
||||
await sync_to_async(request_started.send, thread_sensitive=False)(sender=self.__class__, scope=scope)
|
||||
request_started.connect(close_old_connections)
|
||||
request = ASGIRequest(scope, body_file)
|
||||
# Sneaky little hack so that we can easily get round
|
||||
# CsrfViewMiddleware. This makes life easier, and is probably required
|
||||
# for backwards compatibility with external tests against admin views.
|
||||
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
|
||||
# Request goes through middleware.
|
||||
response = await self.get_response_async(request)
|
||||
# Simulate behaviors of most web servers.
|
||||
conditional_content_removal(request, response)
|
||||
# Attach the originating ASGI request to the response so that it could
|
||||
# be later retrieved.
|
||||
response.asgi_request = request
|
||||
# Emulate a server by calling the close method on completion.
|
||||
if response.streaming:
|
||||
response.streaming_content = await sync_to_async(closing_iterator_wrapper, thread_sensitive=False)(
|
||||
response.streaming_content,
|
||||
response.close,
|
||||
)
|
||||
else:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
# Will fire request_finished.
|
||||
await sync_to_async(response.close, thread_sensitive=False)()
|
||||
request_finished.connect(close_old_connections)
|
||||
return response
|
||||
|
||||
|
||||
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
|
||||
"""
|
||||
Store templates and contexts that are rendered.
|
||||
|
||||
The context is copied so that it is an accurate representation at the time
|
||||
of rendering.
|
||||
"""
|
||||
store.setdefault('templates', []).append(template)
|
||||
if 'context' not in store:
|
||||
store['context'] = ContextList()
|
||||
store['context'].append(copy(context))
|
||||
|
||||
|
||||
def encode_multipart(boundary, data):
|
||||
"""
|
||||
Encode multipart POST data from a dictionary of form values.
|
||||
|
||||
The key will be used as the form data name; the value will be transmitted
|
||||
as content. If the value is a file, the contents of the file will be sent
|
||||
as an application/octet-stream; otherwise, str(value) will be sent.
|
||||
"""
|
||||
lines = []
|
||||
|
||||
def to_bytes(s):
|
||||
return force_bytes(s, settings.DEFAULT_CHARSET)
|
||||
|
||||
# Not by any means perfect, but good enough for our purposes.
|
||||
def is_file(thing):
|
||||
return hasattr(thing, "read") and callable(thing.read)
|
||||
|
||||
# Each bit of the multipart form data could be either a form value or a
|
||||
# file, or a *list* of form values and/or files. Remember that HTTP field
|
||||
# names can be duplicated!
|
||||
for (key, value) in data.items():
|
||||
if value is None:
|
||||
raise TypeError(
|
||||
"Cannot encode None for key '%s' as POST data. Did you mean "
|
||||
"to pass an empty string or omit the value?" % key
|
||||
)
|
||||
elif is_file(value):
|
||||
lines.extend(encode_file(boundary, key, value))
|
||||
elif not isinstance(value, str) and is_iterable(value):
|
||||
for item in value:
|
||||
if is_file(item):
|
||||
lines.extend(encode_file(boundary, key, item))
|
||||
else:
|
||||
lines.extend(to_bytes(val) for val in [
|
||||
'--%s' % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
'',
|
||||
item
|
||||
])
|
||||
else:
|
||||
lines.extend(to_bytes(val) for val in [
|
||||
'--%s' % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
'',
|
||||
value
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
to_bytes('--%s--' % boundary),
|
||||
b'',
|
||||
])
|
||||
return b'\r\n'.join(lines)
|
||||
|
||||
|
||||
def encode_file(boundary, key, file):
|
||||
def to_bytes(s):
|
||||
return force_bytes(s, settings.DEFAULT_CHARSET)
|
||||
|
||||
# file.name might not be a string. For example, it's an int for
|
||||
# tempfile.TemporaryFile().
|
||||
file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str)
|
||||
filename = os.path.basename(file.name) if file_has_string_name else ''
|
||||
|
||||
if hasattr(file, 'content_type'):
|
||||
content_type = file.content_type
|
||||
elif filename:
|
||||
content_type = mimetypes.guess_type(filename)[0]
|
||||
else:
|
||||
content_type = None
|
||||
|
||||
if content_type is None:
|
||||
content_type = 'application/octet-stream'
|
||||
filename = filename or key
|
||||
return [
|
||||
to_bytes('--%s' % boundary),
|
||||
to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"'
|
||||
% (key, filename)),
|
||||
to_bytes('Content-Type: %s' % content_type),
|
||||
b'',
|
||||
to_bytes(file.read())
|
||||
]
|
||||
|
||||
|
||||
class RequestFactory:
|
||||
"""
|
||||
Class that lets you create mock Request objects for use in testing.
|
||||
|
||||
Usage:
|
||||
|
||||
rf = RequestFactory()
|
||||
get_request = rf.get('/hello/')
|
||||
post_request = rf.post('/submit/', {'foo': 'bar'})
|
||||
|
||||
Once you have a request object you can pass it to any view function,
|
||||
just as if that view had been hooked up using a URLconf.
|
||||
"""
|
||||
def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
|
||||
self.json_encoder = json_encoder
|
||||
self.defaults = defaults
|
||||
self.cookies = SimpleCookie()
|
||||
self.errors = BytesIO()
|
||||
|
||||
def _base_environ(self, **request):
|
||||
"""
|
||||
The base environment for a request.
|
||||
"""
|
||||
# This is a minimal valid WSGI environ dictionary, plus:
|
||||
# - HTTP_COOKIE: for cookie support,
|
||||
# - REMOTE_ADDR: often useful, see #8551.
|
||||
# See https://www.python.org/dev/peps/pep-3333/#environ-variables
|
||||
return {
|
||||
'HTTP_COOKIE': '; '.join(sorted(
|
||||
'%s=%s' % (morsel.key, morsel.coded_value)
|
||||
for morsel in self.cookies.values()
|
||||
)),
|
||||
'PATH_INFO': '/',
|
||||
'REMOTE_ADDR': '127.0.0.1',
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'SCRIPT_NAME': '',
|
||||
'SERVER_NAME': 'testserver',
|
||||
'SERVER_PORT': '80',
|
||||
'SERVER_PROTOCOL': 'HTTP/1.1',
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.url_scheme': 'http',
|
||||
'wsgi.input': FakePayload(b''),
|
||||
'wsgi.errors': self.errors,
|
||||
'wsgi.multiprocess': True,
|
||||
'wsgi.multithread': False,
|
||||
'wsgi.run_once': False,
|
||||
**self.defaults,
|
||||
**request,
|
||||
}
|
||||
|
||||
def request(self, **request):
|
||||
"Construct a generic request object."
|
||||
return WSGIRequest(self._base_environ(**request))
|
||||
|
||||
def _encode_data(self, data, content_type):
|
||||
if content_type is MULTIPART_CONTENT:
|
||||
return encode_multipart(BOUNDARY, data)
|
||||
else:
|
||||
# Encode the content so that the byte representation is correct.
|
||||
match = CONTENT_TYPE_RE.match(content_type)
|
||||
if match:
|
||||
charset = match[1]
|
||||
else:
|
||||
charset = settings.DEFAULT_CHARSET
|
||||
return force_bytes(data, encoding=charset)
|
||||
|
||||
def _encode_json(self, data, content_type):
|
||||
"""
|
||||
Return encoded JSON if data is a dict, list, or tuple and content_type
|
||||
is application/json.
|
||||
"""
|
||||
should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple))
|
||||
return json.dumps(data, cls=self.json_encoder) if should_encode else data
|
||||
|
||||
def _get_path(self, parsed):
|
||||
path = parsed.path
|
||||
# If there are parameters, add them
|
||||
if parsed.params:
|
||||
path += ";" + parsed.params
|
||||
path = unquote_to_bytes(path)
|
||||
# Replace the behavior where non-ASCII values in the WSGI environ are
|
||||
# arbitrarily decoded with ISO-8859-1.
|
||||
# Refs comment in `get_bytes_from_wsgi()`.
|
||||
return path.decode('iso-8859-1')
|
||||
|
||||
def get(self, path, data=None, secure=False, **extra):
|
||||
"""Construct a GET request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic('GET', path, secure=secure, **{
|
||||
'QUERY_STRING': urlencode(data, doseq=True),
|
||||
**extra,
|
||||
})
|
||||
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
secure=False, **extra):
|
||||
"""Construct a POST request."""
|
||||
data = self._encode_json({} if data is None else data, content_type)
|
||||
post_data = self._encode_data(data, content_type)
|
||||
|
||||
return self.generic('POST', path, post_data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def head(self, path, data=None, secure=False, **extra):
|
||||
"""Construct a HEAD request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic('HEAD', path, secure=secure, **{
|
||||
'QUERY_STRING': urlencode(data, doseq=True),
|
||||
**extra,
|
||||
})
|
||||
|
||||
def trace(self, path, secure=False, **extra):
|
||||
"""Construct a TRACE request."""
|
||||
return self.generic('TRACE', path, secure=secure, **extra)
|
||||
|
||||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"Construct an OPTIONS request."
|
||||
return self.generic('OPTIONS', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a PUT request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic('PUT', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a PATCH request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic('PATCH', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a DELETE request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic('DELETE', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def generic(self, method, path, data='',
|
||||
content_type='application/octet-stream', secure=False,
|
||||
**extra):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
parsed = urlparse(str(path)) # path can be lazy
|
||||
data = force_bytes(data, settings.DEFAULT_CHARSET)
|
||||
r = {
|
||||
'PATH_INFO': self._get_path(parsed),
|
||||
'REQUEST_METHOD': method,
|
||||
'SERVER_PORT': '443' if secure else '80',
|
||||
'wsgi.url_scheme': 'https' if secure else 'http',
|
||||
}
|
||||
if data:
|
||||
r.update({
|
||||
'CONTENT_LENGTH': str(len(data)),
|
||||
'CONTENT_TYPE': content_type,
|
||||
'wsgi.input': FakePayload(data),
|
||||
})
|
||||
r.update(extra)
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the URL.
|
||||
if not r.get('QUERY_STRING'):
|
||||
# WSGI requires latin-1 encoded strings. See get_path_info().
|
||||
query_string = parsed[4].encode().decode('iso-8859-1')
|
||||
r['QUERY_STRING'] = query_string
|
||||
return self.request(**r)
|
||||
|
||||
|
||||
class AsyncRequestFactory(RequestFactory):
|
||||
"""
|
||||
Class that lets you create mock ASGI-like Request objects for use in
|
||||
testing. Usage:
|
||||
|
||||
rf = AsyncRequestFactory()
|
||||
get_request = await rf.get('/hello/')
|
||||
post_request = await rf.post('/submit/', {'foo': 'bar'})
|
||||
|
||||
Once you have a request object you can pass it to any view function,
|
||||
including synchronous ones. The reason we have a separate class here is:
|
||||
a) this makes ASGIRequest subclasses, and
|
||||
b) AsyncTestClient can subclass it.
|
||||
"""
|
||||
def _base_scope(self, **request):
|
||||
"""The base scope for a request."""
|
||||
# This is a minimal valid ASGI scope, plus:
|
||||
# - headers['cookie'] for cookie support,
|
||||
# - 'client' often useful, see #8551.
|
||||
scope = {
|
||||
'asgi': {'version': '3.0'},
|
||||
'type': 'http',
|
||||
'http_version': '1.1',
|
||||
'client': ['127.0.0.1', 0],
|
||||
'server': ('testserver', '80'),
|
||||
'scheme': 'http',
|
||||
'method': 'GET',
|
||||
'headers': [],
|
||||
**self.defaults,
|
||||
**request,
|
||||
}
|
||||
scope['headers'].append((
|
||||
b'cookie',
|
||||
b'; '.join(sorted(
|
||||
('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
|
||||
for morsel in self.cookies.values()
|
||||
)),
|
||||
))
|
||||
return scope
|
||||
|
||||
def request(self, **request):
|
||||
"""Construct a generic request object."""
|
||||
# This is synchronous, which means all methods on this class are.
|
||||
# AsyncClient, however, has an async request function, which makes all
|
||||
# its methods async.
|
||||
if '_body_file' in request:
|
||||
body_file = request.pop('_body_file')
|
||||
else:
|
||||
body_file = FakePayload('')
|
||||
return ASGIRequest(self._base_scope(**request), body_file)
|
||||
|
||||
def generic(
|
||||
self, method, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra,
|
||||
):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
parsed = urlparse(str(path)) # path can be lazy.
|
||||
data = force_bytes(data, settings.DEFAULT_CHARSET)
|
||||
s = {
|
||||
'method': method,
|
||||
'path': self._get_path(parsed),
|
||||
'server': ('127.0.0.1', '443' if secure else '80'),
|
||||
'scheme': 'https' if secure else 'http',
|
||||
'headers': [(b'host', b'testserver')],
|
||||
}
|
||||
if data:
|
||||
s['headers'].extend([
|
||||
(b'content-length', str(len(data)).encode('ascii')),
|
||||
(b'content-type', content_type.encode('ascii')),
|
||||
])
|
||||
s['_body_file'] = FakePayload(data)
|
||||
follow = extra.pop('follow', None)
|
||||
if follow is not None:
|
||||
s['follow'] = follow
|
||||
if query_string := extra.pop('QUERY_STRING', None):
|
||||
s['query_string'] = query_string
|
||||
s['headers'] += [
|
||||
(key.lower().encode('ascii'), value.encode('latin1'))
|
||||
for key, value in extra.items()
|
||||
]
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the
|
||||
# URL.
|
||||
if not s.get('query_string'):
|
||||
s['query_string'] = parsed[4]
|
||||
return self.request(**s)
|
||||
|
||||
|
||||
class ClientMixin:
|
||||
"""
|
||||
Mixin with common methods between Client and AsyncClient.
|
||||
"""
|
||||
def store_exc_info(self, **kwargs):
|
||||
"""Store exceptions when they are generated by a view."""
|
||||
self.exc_info = sys.exc_info()
|
||||
|
||||
def check_exception(self, response):
|
||||
"""
|
||||
Look for a signaled exception, clear the current context exception
|
||||
data, re-raise the signaled exception, and clear the signaled exception
|
||||
from the local cache.
|
||||
"""
|
||||
response.exc_info = self.exc_info
|
||||
if self.exc_info:
|
||||
_, exc_value, _ = self.exc_info
|
||||
self.exc_info = None
|
||||
if self.raise_request_exception:
|
||||
raise exc_value
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Return the current session variables."""
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
|
||||
if cookie:
|
||||
return engine.SessionStore(cookie.value)
|
||||
session = engine.SessionStore()
|
||||
session.save()
|
||||
self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
return session
|
||||
|
||||
def login(self, **credentials):
|
||||
"""
|
||||
Set the Factory to appear as if it has successfully logged into a site.
|
||||
|
||||
Return True if login is possible or False if the provided credentials
|
||||
are incorrect.
|
||||
"""
|
||||
from django.contrib.auth import authenticate
|
||||
user = authenticate(**credentials)
|
||||
if user:
|
||||
self._login(user)
|
||||
return True
|
||||
return False
|
||||
|
||||
def force_login(self, user, backend=None):
|
||||
def get_backend():
|
||||
from django.contrib.auth import load_backend
|
||||
for backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
if hasattr(backend, 'get_user'):
|
||||
return backend_path
|
||||
|
||||
if backend is None:
|
||||
backend = get_backend()
|
||||
user.backend = backend
|
||||
self._login(user, backend)
|
||||
|
||||
def _login(self, user, backend=None):
|
||||
from django.contrib.auth import login
|
||||
|
||||
# Create a fake request to store login details.
|
||||
request = HttpRequest()
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
else:
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
request.session = engine.SessionStore()
|
||||
login(request, user, backend)
|
||||
# Save the session values.
|
||||
request.session.save()
|
||||
# Set the cookie to represent the session.
|
||||
session_cookie = settings.SESSION_COOKIE_NAME
|
||||
self.cookies[session_cookie] = request.session.session_key
|
||||
cookie_data = {
|
||||
'max-age': None,
|
||||
'path': '/',
|
||||
'domain': settings.SESSION_COOKIE_DOMAIN,
|
||||
'secure': settings.SESSION_COOKIE_SECURE or None,
|
||||
'expires': None,
|
||||
}
|
||||
self.cookies[session_cookie].update(cookie_data)
|
||||
|
||||
def logout(self):
|
||||
"""Log out the user by removing the cookies and session object."""
|
||||
from django.contrib.auth import get_user, logout
|
||||
request = HttpRequest()
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
request.user = get_user(request)
|
||||
else:
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
request.session = engine.SessionStore()
|
||||
logout(request)
|
||||
self.cookies = SimpleCookie()
|
||||
|
||||
def _parse_json(self, response, **extra):
|
||||
if not hasattr(response, '_json'):
|
||||
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
||||
raise ValueError(
|
||||
'Content-Type header is "%s", not "application/json"'
|
||||
% response.get('Content-Type')
|
||||
)
|
||||
response._json = json.loads(response.content.decode(response.charset), **extra)
|
||||
return response._json
|
||||
|
||||
|
||||
class Client(ClientMixin, RequestFactory):
|
||||
"""
|
||||
A class that can act as a client for testing purposes.
|
||||
|
||||
It allows the user to compose GET and POST requests, and
|
||||
obtain the response that the server gave to those requests.
|
||||
The server Response objects are annotated with the details
|
||||
of the contexts and templates that were rendered during the
|
||||
process of serving the request.
|
||||
|
||||
Client objects are stateful - they will retain cookie (and
|
||||
thus session) details for the lifetime of the Client instance.
|
||||
|
||||
This is not intended as a replacement for Twill/Selenium or
|
||||
the like - it is here to allow testing against the
|
||||
contexts and templates produced by a view, rather than the
|
||||
HTML rendered to the end-user.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
||||
super().__init__(**defaults)
|
||||
self.handler = ClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
self.extra = None
|
||||
|
||||
def request(self, **request):
|
||||
"""
|
||||
The master request method. Compose the environment dictionary and pass
|
||||
to the handler, return the result of the handler. Assume defaults for
|
||||
the query environment, which can be overridden using the arguments to
|
||||
the request.
|
||||
"""
|
||||
environ = self._base_environ(**request)
|
||||
|
||||
# Curry a data dictionary into an instance of the template renderer
|
||||
# callback function.
|
||||
data = {}
|
||||
on_template_render = partial(store_rendered_templates, data)
|
||||
signal_uid = "template-render-%s" % id(request)
|
||||
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
|
||||
# Capture exceptions created by the handler.
|
||||
exception_uid = "request-exception-%s" % id(request)
|
||||
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
|
||||
try:
|
||||
response = self.handler(environ)
|
||||
finally:
|
||||
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
||||
got_request_exception.disconnect(dispatch_uid=exception_uid)
|
||||
# Check for signaled exceptions.
|
||||
self.check_exception(response)
|
||||
# Save the client and request that stimulated the response.
|
||||
response.client = self
|
||||
response.request = request
|
||||
# Add any rendered template detail to the response.
|
||||
response.templates = data.get('templates', [])
|
||||
response.context = data.get('context')
|
||||
response.json = partial(self._parse_json, response)
|
||||
# Attach the ResolverMatch instance to the response.
|
||||
urlconf = getattr(response.wsgi_request, 'urlconf', None)
|
||||
response.resolver_match = SimpleLazyObject(
|
||||
lambda: resolve(request['PATH_INFO'], urlconf=urlconf),
|
||||
)
|
||||
# Flatten a single context. Not really necessary anymore thanks to the
|
||||
# __getattr__ flattening in ContextList, but has some edge case
|
||||
# backwards compatibility implications.
|
||||
if response.context and len(response.context) == 1:
|
||||
response.context = response.context[0]
|
||||
# Update persistent cookie data.
|
||||
if response.cookies:
|
||||
self.cookies.update(response.cookies)
|
||||
return response
|
||||
|
||||
def get(self, path, data=None, follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using GET."""
|
||||
self.extra = extra
|
||||
response = super().get(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using POST."""
|
||||
self.extra = extra
|
||||
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def head(self, path, data=None, follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using HEAD."""
|
||||
self.extra = extra
|
||||
response = super().head(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using OPTIONS."""
|
||||
self.extra = extra
|
||||
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PUT."""
|
||||
self.extra = extra
|
||||
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PATCH."""
|
||||
self.extra = extra
|
||||
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a DELETE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def trace(self, path, data='', follow=False, secure=False, **extra):
|
||||
"""Send a TRACE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().trace(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def _handle_redirects(self, response, data='', content_type='', **extra):
|
||||
"""
|
||||
Follow any redirects by requesting responses from the server using GET.
|
||||
"""
|
||||
response.redirect_chain = []
|
||||
redirect_status_codes = (
|
||||
HTTPStatus.MOVED_PERMANENTLY,
|
||||
HTTPStatus.FOUND,
|
||||
HTTPStatus.SEE_OTHER,
|
||||
HTTPStatus.TEMPORARY_REDIRECT,
|
||||
HTTPStatus.PERMANENT_REDIRECT,
|
||||
)
|
||||
while response.status_code in redirect_status_codes:
|
||||
response_url = response.url
|
||||
redirect_chain = response.redirect_chain
|
||||
redirect_chain.append((response_url, response.status_code))
|
||||
|
||||
url = urlsplit(response_url)
|
||||
if url.scheme:
|
||||
extra['wsgi.url_scheme'] = url.scheme
|
||||
if url.hostname:
|
||||
extra['SERVER_NAME'] = url.hostname
|
||||
if url.port:
|
||||
extra['SERVER_PORT'] = str(url.port)
|
||||
|
||||
path = url.path
|
||||
# RFC 2616: bare domains without path are treated as the root.
|
||||
if not path and url.netloc:
|
||||
path = '/'
|
||||
# Prepend the request path to handle relative path redirects
|
||||
if not path.startswith('/'):
|
||||
path = urljoin(response.request['PATH_INFO'], path)
|
||||
|
||||
if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
|
||||
# Preserve request method and query string (if needed)
|
||||
# post-redirect for 307/308 responses.
|
||||
request_method = response.request['REQUEST_METHOD'].lower()
|
||||
if request_method not in ('get', 'head'):
|
||||
extra['QUERY_STRING'] = url.query
|
||||
request_method = getattr(self, request_method)
|
||||
else:
|
||||
request_method = self.get
|
||||
data = QueryDict(url.query)
|
||||
content_type = None
|
||||
|
||||
response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
|
||||
response.redirect_chain = redirect_chain
|
||||
|
||||
if redirect_chain[-1] in redirect_chain[:-1]:
|
||||
# Check that we're not redirecting to somewhere we've already
|
||||
# been to, to prevent loops.
|
||||
raise RedirectCycleError("Redirect loop detected.", last_response=response)
|
||||
if len(redirect_chain) > 20:
|
||||
# Such a lengthy chain likely also means a loop, but one with
|
||||
# a growing path, changing view, or changing query argument;
|
||||
# 20 is the value of "network.http.redirection-limit" from Firefox.
|
||||
raise RedirectCycleError("Too many redirects.", last_response=response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class AsyncClient(ClientMixin, AsyncRequestFactory):
|
||||
"""
|
||||
An async version of Client that creates ASGIRequests and calls through an
|
||||
async request path.
|
||||
|
||||
Does not currently support "follow" on its methods.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
||||
super().__init__(**defaults)
|
||||
self.handler = AsyncClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
self.extra = None
|
||||
|
||||
async def request(self, **request):
|
||||
"""
|
||||
The master request method. Compose the scope dictionary and pass to the
|
||||
handler, return the result of the handler. Assume defaults for the
|
||||
query environment, which can be overridden using the arguments to the
|
||||
request.
|
||||
"""
|
||||
if 'follow' in request:
|
||||
raise NotImplementedError(
|
||||
'AsyncClient request methods do not accept the follow '
|
||||
'parameter.'
|
||||
)
|
||||
scope = self._base_scope(**request)
|
||||
# Curry a data dictionary into an instance of the template renderer
|
||||
# callback function.
|
||||
data = {}
|
||||
on_template_render = partial(store_rendered_templates, data)
|
||||
signal_uid = 'template-render-%s' % id(request)
|
||||
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
|
||||
# Capture exceptions created by the handler.
|
||||
exception_uid = 'request-exception-%s' % id(request)
|
||||
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
|
||||
try:
|
||||
response = await self.handler(scope)
|
||||
finally:
|
||||
signals.template_rendered.disconnect(dispatch_uid=signal_uid)
|
||||
got_request_exception.disconnect(dispatch_uid=exception_uid)
|
||||
# Check for signaled exceptions.
|
||||
self.check_exception(response)
|
||||
# Save the client and request that stimulated the response.
|
||||
response.client = self
|
||||
response.request = request
|
||||
# Add any rendered template detail to the response.
|
||||
response.templates = data.get('templates', [])
|
||||
response.context = data.get('context')
|
||||
response.json = partial(self._parse_json, response)
|
||||
# Attach the ResolverMatch instance to the response.
|
||||
urlconf = getattr(response.asgi_request, 'urlconf', None)
|
||||
response.resolver_match = SimpleLazyObject(
|
||||
lambda: resolve(request['path'], urlconf=urlconf),
|
||||
)
|
||||
# Flatten a single context. Not really necessary anymore thanks to the
|
||||
# __getattr__ flattening in ContextList, but has some edge case
|
||||
# backwards compatibility implications.
|
||||
if response.context and len(response.context) == 1:
|
||||
response.context = response.context[0]
|
||||
# Update persistent cookie data.
|
||||
if response.cookies:
|
||||
self.cookies.update(response.cookies)
|
||||
return response
|
252
venv/Lib/site-packages/django/test/html.py
Normal file
252
venv/Lib/site-packages/django/test/html.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Compare two HTML documents."""
|
||||
|
||||
from html.parser import HTMLParser
|
||||
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
# ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020
|
||||
# SPACE.
|
||||
# https://infra.spec.whatwg.org/#ascii-whitespace
|
||||
ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+')
|
||||
|
||||
# https://html.spec.whatwg.org/#attributes-3
|
||||
BOOLEAN_ATTRIBUTES = {
|
||||
'allowfullscreen', 'async', 'autofocus', 'autoplay', 'checked', 'controls',
|
||||
'default', 'defer ', 'disabled', 'formnovalidate', 'hidden', 'ismap',
|
||||
'itemscope', 'loop', 'multiple', 'muted', 'nomodule', 'novalidate', 'open',
|
||||
'playsinline', 'readonly', 'required', 'reversed', 'selected',
|
||||
# Attributes for deprecated tags.
|
||||
'truespeed',
|
||||
}
|
||||
|
||||
|
||||
def normalize_whitespace(string):
|
||||
return ASCII_WHITESPACE.sub(' ', string)
|
||||
|
||||
|
||||
def normalize_attributes(attributes):
|
||||
normalized = []
|
||||
for name, value in attributes:
|
||||
if name == 'class' and value:
|
||||
# Special case handling of 'class' attribute, so that comparisons
|
||||
# of DOM instances are not sensitive to ordering of classes.
|
||||
value = ' '.join(sorted(
|
||||
value for value in ASCII_WHITESPACE.split(value) if value
|
||||
))
|
||||
# Boolean attributes without a value is same as attribute with value
|
||||
# that equals the attributes name. For example:
|
||||
# <input checked> == <input checked="checked">
|
||||
if name in BOOLEAN_ATTRIBUTES:
|
||||
if not value or value == name:
|
||||
value = None
|
||||
elif value is None:
|
||||
value = ''
|
||||
normalized.append((name, value))
|
||||
return normalized
|
||||
|
||||
|
||||
class Element:
|
||||
def __init__(self, name, attributes):
|
||||
self.name = name
|
||||
self.attributes = sorted(attributes)
|
||||
self.children = []
|
||||
|
||||
def append(self, element):
|
||||
if isinstance(element, str):
|
||||
element = normalize_whitespace(element)
|
||||
if self.children and isinstance(self.children[-1], str):
|
||||
self.children[-1] += element
|
||||
self.children[-1] = normalize_whitespace(self.children[-1])
|
||||
return
|
||||
elif self.children:
|
||||
# removing last children if it is only whitespace
|
||||
# this can result in incorrect dom representations since
|
||||
# whitespace between inline tags like <span> is significant
|
||||
if isinstance(self.children[-1], str) and self.children[-1].isspace():
|
||||
self.children.pop()
|
||||
if element:
|
||||
self.children.append(element)
|
||||
|
||||
def finalize(self):
|
||||
def rstrip_last_element(children):
|
||||
if children and isinstance(children[-1], str):
|
||||
children[-1] = children[-1].rstrip()
|
||||
if not children[-1]:
|
||||
children.pop()
|
||||
children = rstrip_last_element(children)
|
||||
return children
|
||||
|
||||
rstrip_last_element(self.children)
|
||||
for i, child in enumerate(self.children):
|
||||
if isinstance(child, str):
|
||||
self.children[i] = child.strip()
|
||||
elif hasattr(child, 'finalize'):
|
||||
child.finalize()
|
||||
|
||||
def __eq__(self, element):
|
||||
if not hasattr(element, 'name') or self.name != element.name:
|
||||
return False
|
||||
if self.attributes != element.attributes:
|
||||
return False
|
||||
return self.children == element.children
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name, *self.attributes))
|
||||
|
||||
def _count(self, element, count=True):
|
||||
if not isinstance(element, str) and self == element:
|
||||
return 1
|
||||
if isinstance(element, RootElement) and self.children == element.children:
|
||||
return 1
|
||||
i = 0
|
||||
elem_child_idx = 0
|
||||
for child in self.children:
|
||||
# child is text content and element is also text content, then
|
||||
# make a simple "text" in "text"
|
||||
if isinstance(child, str):
|
||||
if isinstance(element, str):
|
||||
if count:
|
||||
i += child.count(element)
|
||||
elif element in child:
|
||||
return 1
|
||||
else:
|
||||
# Look for element wholly within this child.
|
||||
i += child._count(element, count=count)
|
||||
if not count and i:
|
||||
return i
|
||||
# Also look for a sequence of element's children among self's
|
||||
# children. self.children == element.children is tested above,
|
||||
# but will fail if self has additional children. Ex: '<a/><b/>'
|
||||
# is contained in '<a/><b/><c/>'.
|
||||
if isinstance(element, RootElement) and element.children:
|
||||
elem_child = element.children[elem_child_idx]
|
||||
# Start or continue match, advance index.
|
||||
if elem_child == child:
|
||||
elem_child_idx += 1
|
||||
# Match found, reset index.
|
||||
if elem_child_idx == len(element.children):
|
||||
i += 1
|
||||
elem_child_idx = 0
|
||||
# No match, reset index.
|
||||
else:
|
||||
elem_child_idx = 0
|
||||
return i
|
||||
|
||||
def __contains__(self, element):
|
||||
return self._count(element, count=False) > 0
|
||||
|
||||
def count(self, element):
|
||||
return self._count(element, count=True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.children[key]
|
||||
|
||||
def __str__(self):
|
||||
output = '<%s' % self.name
|
||||
for key, value in self.attributes:
|
||||
if value is not None:
|
||||
output += ' %s="%s"' % (key, value)
|
||||
else:
|
||||
output += ' %s' % key
|
||||
if self.children:
|
||||
output += '>\n'
|
||||
output += ''.join(str(c) for c in self.children)
|
||||
output += '\n</%s>' % self.name
|
||||
else:
|
||||
output += '>'
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class RootElement(Element):
|
||||
def __init__(self):
|
||||
super().__init__(None, ())
|
||||
|
||||
def __str__(self):
|
||||
return ''.join(str(c) for c in self.children)
|
||||
|
||||
|
||||
class HTMLParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Parser(HTMLParser):
|
||||
# https://html.spec.whatwg.org/#void-elements
|
||||
SELF_CLOSING_TAGS = {
|
||||
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta',
|
||||
'param', 'source', 'track', 'wbr',
|
||||
# Deprecated tags
|
||||
'frame', 'spacer',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.root = RootElement()
|
||||
self.open_tags = []
|
||||
self.element_positions = {}
|
||||
|
||||
def error(self, msg):
|
||||
raise HTMLParseError(msg, self.getpos())
|
||||
|
||||
def format_position(self, position=None, element=None):
|
||||
if not position and element:
|
||||
position = self.element_positions[element]
|
||||
if position is None:
|
||||
position = self.getpos()
|
||||
if hasattr(position, 'lineno'):
|
||||
position = position.lineno, position.offset
|
||||
return 'Line %d, Column %d' % position
|
||||
|
||||
@property
|
||||
def current(self):
|
||||
if self.open_tags:
|
||||
return self.open_tags[-1]
|
||||
else:
|
||||
return self.root
|
||||
|
||||
def handle_startendtag(self, tag, attrs):
|
||||
self.handle_starttag(tag, attrs)
|
||||
if tag not in self.SELF_CLOSING_TAGS:
|
||||
self.handle_endtag(tag)
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
attrs = normalize_attributes(attrs)
|
||||
element = Element(tag, attrs)
|
||||
self.current.append(element)
|
||||
if tag not in self.SELF_CLOSING_TAGS:
|
||||
self.open_tags.append(element)
|
||||
self.element_positions[element] = self.getpos()
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
if not self.open_tags:
|
||||
self.error("Unexpected end tag `%s` (%s)" % (
|
||||
tag, self.format_position()))
|
||||
element = self.open_tags.pop()
|
||||
while element.name != tag:
|
||||
if not self.open_tags:
|
||||
self.error("Unexpected end tag `%s` (%s)" % (
|
||||
tag, self.format_position()))
|
||||
element = self.open_tags.pop()
|
||||
|
||||
def handle_data(self, data):
|
||||
self.current.append(data)
|
||||
|
||||
|
||||
def parse_html(html):
|
||||
"""
|
||||
Take a string that contains HTML and turn it into a Python object structure
|
||||
that can be easily compared against other HTML on semantic equivalence.
|
||||
Syntactical differences like which quotation is used on arguments will be
|
||||
ignored.
|
||||
"""
|
||||
parser = Parser()
|
||||
parser.feed(html)
|
||||
parser.close()
|
||||
document = parser.root
|
||||
document.finalize()
|
||||
# Removing ROOT element if it's not necessary
|
||||
if len(document.children) == 1 and not isinstance(document.children[0], str):
|
||||
document = document.children[0]
|
||||
return document
|
1099
venv/Lib/site-packages/django/test/runner.py
Normal file
1099
venv/Lib/site-packages/django/test/runner.py
Normal file
File diff suppressed because it is too large
Load Diff
132
venv/Lib/site-packages/django/test/selenium.py
Normal file
132
venv/Lib/site-packages/django/test/selenium.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.test import LiveServerTestCase, tag
|
||||
from django.utils.functional import classproperty
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.text import capfirst
|
||||
|
||||
|
||||
class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
# List of browsers to dynamically create test classes for.
|
||||
browsers = []
|
||||
# A selenium hub URL to test against.
|
||||
selenium_hub = None
|
||||
# The external host Selenium Hub can reach.
|
||||
external_host = None
|
||||
# Sentinel value to differentiate browser-specific instances.
|
||||
browser = None
|
||||
# Run browsers in headless mode.
|
||||
headless = False
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
"""
|
||||
Dynamically create new classes and add them to the test module when
|
||||
multiple browsers specs are provided (e.g. --selenium=firefox,chrome).
|
||||
"""
|
||||
test_class = super().__new__(cls, name, bases, attrs)
|
||||
# If the test class is either browser-specific or a test base, return it.
|
||||
if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()):
|
||||
return test_class
|
||||
elif test_class.browsers:
|
||||
# Reuse the created test class to make it browser-specific.
|
||||
# We can't rename it to include the browser name or create a
|
||||
# subclass like we do with the remaining browsers as it would
|
||||
# either duplicate tests or prevent pickling of its instances.
|
||||
first_browser = test_class.browsers[0]
|
||||
test_class.browser = first_browser
|
||||
# Listen on an external interface if using a selenium hub.
|
||||
host = test_class.host if not test_class.selenium_hub else '0.0.0.0'
|
||||
test_class.host = host
|
||||
test_class.external_host = cls.external_host
|
||||
# Create subclasses for each of the remaining browsers and expose
|
||||
# them through the test's module namespace.
|
||||
module = sys.modules[test_class.__module__]
|
||||
for browser in test_class.browsers[1:]:
|
||||
browser_test_class = cls.__new__(
|
||||
cls,
|
||||
"%s%s" % (capfirst(browser), name),
|
||||
(test_class,),
|
||||
{
|
||||
'browser': browser,
|
||||
'host': host,
|
||||
'external_host': cls.external_host,
|
||||
'__module__': test_class.__module__,
|
||||
}
|
||||
)
|
||||
setattr(module, browser_test_class.__name__, browser_test_class)
|
||||
return test_class
|
||||
# If no browsers were specified, skip this class (it'll still be discovered).
|
||||
return unittest.skip('No browsers specified.')(test_class)
|
||||
|
||||
@classmethod
|
||||
def import_webdriver(cls, browser):
|
||||
return import_string("selenium.webdriver.%s.webdriver.WebDriver" % browser)
|
||||
|
||||
@classmethod
|
||||
def import_options(cls, browser):
|
||||
return import_string('selenium.webdriver.%s.options.Options' % browser)
|
||||
|
||||
@classmethod
|
||||
def get_capability(cls, browser):
|
||||
from selenium.webdriver.common.desired_capabilities import (
|
||||
DesiredCapabilities,
|
||||
)
|
||||
return getattr(DesiredCapabilities, browser.upper())
|
||||
|
||||
def create_options(self):
|
||||
options = self.import_options(self.browser)()
|
||||
if self.headless:
|
||||
try:
|
||||
options.headless = True
|
||||
except AttributeError:
|
||||
pass # Only Chrome and Firefox support the headless mode.
|
||||
return options
|
||||
|
||||
def create_webdriver(self):
|
||||
if self.selenium_hub:
|
||||
from selenium import webdriver
|
||||
return webdriver.Remote(
|
||||
command_executor=self.selenium_hub,
|
||||
desired_capabilities=self.get_capability(self.browser),
|
||||
)
|
||||
return self.import_webdriver(self.browser)(options=self.create_options())
|
||||
|
||||
|
||||
@tag('selenium')
|
||||
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
|
||||
implicit_wait = 10
|
||||
external_host = None
|
||||
|
||||
@classproperty
|
||||
def live_server_url(cls):
|
||||
return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port)
|
||||
|
||||
@classproperty
|
||||
def allowed_host(cls):
|
||||
return cls.external_host or cls.host
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.selenium = cls.create_webdriver()
|
||||
cls.selenium.implicitly_wait(cls.implicit_wait)
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def _tearDownClassInternal(cls):
|
||||
# quit() the WebDriver before attempting to terminate and join the
|
||||
# single-threaded LiveServerThread to avoid a dead lock if the browser
|
||||
# kept a connection alive.
|
||||
if hasattr(cls, 'selenium'):
|
||||
cls.selenium.quit()
|
||||
super()._tearDownClassInternal()
|
||||
|
||||
@contextmanager
|
||||
def disable_implicit_wait(self):
|
||||
"""Disable the default implicit wait."""
|
||||
self.selenium.implicitly_wait(0)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.selenium.implicitly_wait(self.implicit_wait)
|
209
venv/Lib/site-packages/django/test/signals.py
Normal file
209
venv/Lib/site-packages/django/test/signals.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from asgiref.local import Local
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.signals import setting_changed
|
||||
from django.db import connections, router
|
||||
from django.db.utils import ConnectionRouter
|
||||
from django.dispatch import Signal, receiver
|
||||
from django.utils import timezone
|
||||
from django.utils.formats import FORMAT_SETTINGS, reset_format_cache
|
||||
from django.utils.functional import empty
|
||||
|
||||
template_rendered = Signal()
|
||||
|
||||
# Most setting_changed receivers are supposed to be added below,
|
||||
# except for cases where the receiver is related to a contrib app.
|
||||
|
||||
# Settings that may not work well when using 'override_settings' (#19031)
|
||||
COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_cache_handlers(**kwargs):
|
||||
if kwargs['setting'] == 'CACHES':
|
||||
from django.core.cache import caches, close_caches
|
||||
close_caches()
|
||||
caches._settings = caches.settings = caches.configure_settings(None)
|
||||
caches._connections = Local()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def update_installed_apps(**kwargs):
|
||||
if kwargs['setting'] == 'INSTALLED_APPS':
|
||||
# Rebuild any AppDirectoriesFinder instance.
|
||||
from django.contrib.staticfiles.finders import get_finder
|
||||
get_finder.cache_clear()
|
||||
# Rebuild management commands cache
|
||||
from django.core.management import get_commands
|
||||
get_commands.cache_clear()
|
||||
# Rebuild get_app_template_dirs cache.
|
||||
from django.template.utils import get_app_template_dirs
|
||||
get_app_template_dirs.cache_clear()
|
||||
# Rebuild translations cache.
|
||||
from django.utils.translation import trans_real
|
||||
trans_real._translations = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def update_connections_time_zone(**kwargs):
|
||||
if kwargs['setting'] == 'TIME_ZONE':
|
||||
# Reset process time zone
|
||||
if hasattr(time, 'tzset'):
|
||||
if kwargs['value']:
|
||||
os.environ['TZ'] = kwargs['value']
|
||||
else:
|
||||
os.environ.pop('TZ', None)
|
||||
time.tzset()
|
||||
|
||||
# Reset local time zone cache
|
||||
timezone.get_default_timezone.cache_clear()
|
||||
|
||||
# Reset the database connections' time zone
|
||||
if kwargs['setting'] in {'TIME_ZONE', 'USE_TZ'}:
|
||||
for conn in connections.all():
|
||||
try:
|
||||
del conn.timezone
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
del conn.timezone_name
|
||||
except AttributeError:
|
||||
pass
|
||||
conn.ensure_timezone()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_routers_cache(**kwargs):
|
||||
if kwargs['setting'] == 'DATABASE_ROUTERS':
|
||||
router.routers = ConnectionRouter().routers
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def reset_template_engines(**kwargs):
|
||||
if kwargs['setting'] in {
|
||||
'TEMPLATES',
|
||||
'DEBUG',
|
||||
'INSTALLED_APPS',
|
||||
}:
|
||||
from django.template import engines
|
||||
try:
|
||||
del engines.templates
|
||||
except AttributeError:
|
||||
pass
|
||||
engines._templates = None
|
||||
engines._engines = {}
|
||||
from django.template.engine import Engine
|
||||
Engine.get_default.cache_clear()
|
||||
from django.forms.renderers import get_default_renderer
|
||||
get_default_renderer.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_serializers_cache(**kwargs):
|
||||
if kwargs['setting'] == 'SERIALIZATION_MODULES':
|
||||
from django.core import serializers
|
||||
serializers._serializers = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def language_changed(**kwargs):
|
||||
if kwargs['setting'] in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}:
|
||||
from django.utils.translation import trans_real
|
||||
trans_real._default = None
|
||||
trans_real._active = Local()
|
||||
if kwargs['setting'] in {'LANGUAGES', 'LOCALE_PATHS'}:
|
||||
from django.utils.translation import trans_real
|
||||
trans_real._translations = {}
|
||||
trans_real.check_for_language.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def localize_settings_changed(**kwargs):
|
||||
if kwargs['setting'] in FORMAT_SETTINGS or kwargs['setting'] == 'USE_THOUSAND_SEPARATOR':
|
||||
reset_format_cache()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def file_storage_changed(**kwargs):
|
||||
if kwargs['setting'] == 'DEFAULT_FILE_STORAGE':
|
||||
from django.core.files.storage import default_storage
|
||||
default_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def complex_setting_changed(**kwargs):
|
||||
if kwargs['enter'] and kwargs['setting'] in COMPLEX_OVERRIDE_SETTINGS:
|
||||
# Considering the current implementation of the signals framework,
|
||||
# this stacklevel shows the line containing the override_settings call.
|
||||
warnings.warn("Overriding setting %s can lead to unexpected behavior."
|
||||
% kwargs['setting'], stacklevel=6)
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def root_urlconf_changed(**kwargs):
|
||||
if kwargs['setting'] == 'ROOT_URLCONF':
|
||||
from django.urls import clear_url_caches, set_urlconf
|
||||
clear_url_caches()
|
||||
set_urlconf(None)
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def static_storage_changed(**kwargs):
|
||||
if kwargs['setting'] in {
|
||||
'STATICFILES_STORAGE',
|
||||
'STATIC_ROOT',
|
||||
'STATIC_URL',
|
||||
}:
|
||||
from django.contrib.staticfiles.storage import staticfiles_storage
|
||||
staticfiles_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def static_finders_changed(**kwargs):
|
||||
if kwargs['setting'] in {
|
||||
'STATICFILES_DIRS',
|
||||
'STATIC_ROOT',
|
||||
}:
|
||||
from django.contrib.staticfiles.finders import get_finder
|
||||
get_finder.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def auth_password_validators_changed(**kwargs):
|
||||
if kwargs['setting'] == 'AUTH_PASSWORD_VALIDATORS':
|
||||
from django.contrib.auth.password_validation import (
|
||||
get_default_password_validators,
|
||||
)
|
||||
get_default_password_validators.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def user_model_swapped(**kwargs):
|
||||
if kwargs['setting'] == 'AUTH_USER_MODEL':
|
||||
apps.clear_cache()
|
||||
try:
|
||||
from django.contrib.auth import get_user_model
|
||||
UserModel = get_user_model()
|
||||
except ImproperlyConfigured:
|
||||
# Some tests set an invalid AUTH_USER_MODEL.
|
||||
pass
|
||||
else:
|
||||
from django.contrib.auth import backends
|
||||
backends.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth import forms
|
||||
forms.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth.handlers import modwsgi
|
||||
modwsgi.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth.management.commands import changepassword
|
||||
changepassword.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth import views
|
||||
views.UserModel = UserModel
|
1637
venv/Lib/site-packages/django/test/testcases.py
Normal file
1637
venv/Lib/site-packages/django/test/testcases.py
Normal file
File diff suppressed because it is too large
Load Diff
953
venv/Lib/site-packages/django/test/utils.py
Normal file
953
venv/Lib/site-packages/django/test/utils.py
Normal file
@@ -0,0 +1,953 @@
|
||||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from io import StringIO
|
||||
from itertools import chain
|
||||
from types import SimpleNamespace
|
||||
from unittest import TestCase, skipIf, skipUnless
|
||||
from xml.dom.minidom import Node, parseString
|
||||
|
||||
from django.apps import apps
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import UserSettingsHolder, settings
|
||||
from django.core import mail
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.signals import request_started
|
||||
from django.db import DEFAULT_DB_ALIAS, connections, reset_queries
|
||||
from django.db.models.options import Options
|
||||
from django.template import Template
|
||||
from django.test.signals import setting_changed, template_rendered
|
||||
from django.urls import get_script_prefix, set_script_prefix
|
||||
from django.utils.deprecation import RemovedInDjango50Warning
|
||||
from django.utils.translation import deactivate
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError:
|
||||
jinja2 = None
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
|
||||
'CaptureQueriesContext',
|
||||
'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings',
|
||||
'override_system_checks', 'tag',
|
||||
'requires_tz_support',
|
||||
'setup_databases', 'setup_test_environment', 'teardown_test_environment',
|
||||
)
|
||||
|
||||
TZ_SUPPORT = hasattr(time, 'tzset')
|
||||
|
||||
|
||||
class Approximate:
|
||||
def __init__(self, val, places=7):
|
||||
self.val = val
|
||||
self.places = places
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.val)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.val == other or round(abs(self.val - other), self.places) == 0
|
||||
|
||||
|
||||
class ContextList(list):
|
||||
"""
|
||||
A wrapper that provides direct key access to context items contained
|
||||
in a list of context objects.
|
||||
"""
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, str):
|
||||
for subcontext in self:
|
||||
if key in subcontext:
|
||||
return subcontext[key]
|
||||
raise KeyError(key)
|
||||
else:
|
||||
return super().__getitem__(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self.__getitem__(key)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
try:
|
||||
self[key]
|
||||
except KeyError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def keys(self):
|
||||
"""
|
||||
Flattened keys of subcontexts.
|
||||
"""
|
||||
return set(chain.from_iterable(d for subcontext in self for d in subcontext))
|
||||
|
||||
|
||||
def instrumented_test_render(self, context):
|
||||
"""
|
||||
An instrumented Template render method, providing a signal that can be
|
||||
intercepted by the test Client.
|
||||
"""
|
||||
template_rendered.send(sender=self, template=self, context=context)
|
||||
return self.nodelist.render(context)
|
||||
|
||||
|
||||
class _TestState:
|
||||
pass
|
||||
|
||||
|
||||
def setup_test_environment(debug=None):
|
||||
"""
|
||||
Perform global pre-test setup, such as installing the instrumented template
|
||||
renderer and setting the email backend to the locmem email backend.
|
||||
"""
|
||||
if hasattr(_TestState, 'saved_data'):
|
||||
# Executing this function twice would overwrite the saved values.
|
||||
raise RuntimeError(
|
||||
"setup_test_environment() was already called and can't be called "
|
||||
"again without first calling teardown_test_environment()."
|
||||
)
|
||||
|
||||
if debug is None:
|
||||
debug = settings.DEBUG
|
||||
|
||||
saved_data = SimpleNamespace()
|
||||
_TestState.saved_data = saved_data
|
||||
|
||||
saved_data.allowed_hosts = settings.ALLOWED_HOSTS
|
||||
# Add the default host of the test client.
|
||||
settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']
|
||||
|
||||
saved_data.debug = settings.DEBUG
|
||||
settings.DEBUG = debug
|
||||
|
||||
saved_data.email_backend = settings.EMAIL_BACKEND
|
||||
settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
|
||||
|
||||
saved_data.template_render = Template._render
|
||||
Template._render = instrumented_test_render
|
||||
|
||||
mail.outbox = []
|
||||
|
||||
deactivate()
|
||||
|
||||
|
||||
def teardown_test_environment():
|
||||
"""
|
||||
Perform any global post-test teardown, such as restoring the original
|
||||
template renderer and restoring the email sending functions.
|
||||
"""
|
||||
saved_data = _TestState.saved_data
|
||||
|
||||
settings.ALLOWED_HOSTS = saved_data.allowed_hosts
|
||||
settings.DEBUG = saved_data.debug
|
||||
settings.EMAIL_BACKEND = saved_data.email_backend
|
||||
Template._render = saved_data.template_render
|
||||
|
||||
del _TestState.saved_data
|
||||
del mail.outbox
|
||||
|
||||
|
||||
def setup_databases(
|
||||
verbosity,
|
||||
interactive,
|
||||
*,
|
||||
time_keeper=None,
|
||||
keepdb=False,
|
||||
debug_sql=False,
|
||||
parallel=0,
|
||||
aliases=None,
|
||||
serialized_aliases=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create the test databases."""
|
||||
if time_keeper is None:
|
||||
time_keeper = NullTimeKeeper()
|
||||
|
||||
test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases)
|
||||
|
||||
old_names = []
|
||||
|
||||
for db_name, aliases in test_databases.values():
|
||||
first_alias = None
|
||||
for alias in aliases:
|
||||
connection = connections[alias]
|
||||
old_names.append((connection, db_name, first_alias is None))
|
||||
|
||||
# Actually create the database for the first connection
|
||||
if first_alias is None:
|
||||
first_alias = alias
|
||||
with time_keeper.timed(" Creating '%s'" % alias):
|
||||
# RemovedInDjango50Warning: when the deprecation ends,
|
||||
# replace with:
|
||||
# serialize_alias = serialized_aliases is None or alias in serialized_aliases
|
||||
try:
|
||||
serialize_alias = connection.settings_dict['TEST']['SERIALIZE']
|
||||
except KeyError:
|
||||
serialize_alias = (
|
||||
serialized_aliases is None or
|
||||
alias in serialized_aliases
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
'The SERIALIZE test database setting is '
|
||||
'deprecated as it can be inferred from the '
|
||||
'TestCase/TransactionTestCase.databases that '
|
||||
'enable the serialized_rollback feature.',
|
||||
category=RemovedInDjango50Warning,
|
||||
)
|
||||
connection.creation.create_test_db(
|
||||
verbosity=verbosity,
|
||||
autoclobber=not interactive,
|
||||
keepdb=keepdb,
|
||||
serialize=serialize_alias,
|
||||
)
|
||||
if parallel > 1:
|
||||
for index in range(parallel):
|
||||
with time_keeper.timed(" Cloning '%s'" % alias):
|
||||
connection.creation.clone_test_db(
|
||||
suffix=str(index + 1),
|
||||
verbosity=verbosity,
|
||||
keepdb=keepdb,
|
||||
)
|
||||
# Configure all other connections as mirrors of the first one
|
||||
else:
|
||||
connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)
|
||||
|
||||
# Configure the test mirrors.
|
||||
for alias, mirror_alias in mirrored_aliases.items():
|
||||
connections[alias].creation.set_as_test_mirror(
|
||||
connections[mirror_alias].settings_dict)
|
||||
|
||||
if debug_sql:
|
||||
for alias in connections:
|
||||
connections[alias].force_debug_cursor = True
|
||||
|
||||
return old_names
|
||||
|
||||
|
||||
def iter_test_cases(tests):
|
||||
"""
|
||||
Return an iterator over a test suite's unittest.TestCase objects.
|
||||
|
||||
The tests argument can also be an iterable of TestCase objects.
|
||||
"""
|
||||
for test in tests:
|
||||
if isinstance(test, str):
|
||||
# Prevent an unfriendly RecursionError that can happen with
|
||||
# strings.
|
||||
raise TypeError(
|
||||
f'Test {test!r} must be a test case or test suite not string '
|
||||
f'(was found in {tests!r}).'
|
||||
)
|
||||
if isinstance(test, TestCase):
|
||||
yield test
|
||||
else:
|
||||
# Otherwise, assume it is a test suite.
|
||||
yield from iter_test_cases(test)
|
||||
|
||||
|
||||
def dependency_ordered(test_databases, dependencies):
|
||||
"""
|
||||
Reorder test_databases into an order that honors the dependencies
|
||||
described in TEST[DEPENDENCIES].
|
||||
"""
|
||||
ordered_test_databases = []
|
||||
resolved_databases = set()
|
||||
|
||||
# Maps db signature to dependencies of all its aliases
|
||||
dependencies_map = {}
|
||||
|
||||
# Check that no database depends on its own alias
|
||||
for sig, (_, aliases) in test_databases:
|
||||
all_deps = set()
|
||||
for alias in aliases:
|
||||
all_deps.update(dependencies.get(alias, []))
|
||||
if not all_deps.isdisjoint(aliases):
|
||||
raise ImproperlyConfigured(
|
||||
"Circular dependency: databases %r depend on each other, "
|
||||
"but are aliases." % aliases
|
||||
)
|
||||
dependencies_map[sig] = all_deps
|
||||
|
||||
while test_databases:
|
||||
changed = False
|
||||
deferred = []
|
||||
|
||||
# Try to find a DB that has all its dependencies met
|
||||
for signature, (db_name, aliases) in test_databases:
|
||||
if dependencies_map[signature].issubset(resolved_databases):
|
||||
resolved_databases.update(aliases)
|
||||
ordered_test_databases.append((signature, (db_name, aliases)))
|
||||
changed = True
|
||||
else:
|
||||
deferred.append((signature, (db_name, aliases)))
|
||||
|
||||
if not changed:
|
||||
raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
|
||||
test_databases = deferred
|
||||
return ordered_test_databases
|
||||
|
||||
|
||||
def get_unique_databases_and_mirrors(aliases=None):
|
||||
"""
|
||||
Figure out which databases actually need to be created.
|
||||
|
||||
Deduplicate entries in DATABASES that correspond the same database or are
|
||||
configured as test mirrors.
|
||||
|
||||
Return two values:
|
||||
- test_databases: ordered mapping of signatures to (name, list of aliases)
|
||||
where all aliases share the same underlying database.
|
||||
- mirrored_aliases: mapping of mirror aliases to original aliases.
|
||||
"""
|
||||
if aliases is None:
|
||||
aliases = connections
|
||||
mirrored_aliases = {}
|
||||
test_databases = {}
|
||||
dependencies = {}
|
||||
default_sig = connections[DEFAULT_DB_ALIAS].creation.test_db_signature()
|
||||
|
||||
for alias in connections:
|
||||
connection = connections[alias]
|
||||
test_settings = connection.settings_dict['TEST']
|
||||
|
||||
if test_settings['MIRROR']:
|
||||
# If the database is marked as a test mirror, save the alias.
|
||||
mirrored_aliases[alias] = test_settings['MIRROR']
|
||||
elif alias in aliases:
|
||||
# Store a tuple with DB parameters that uniquely identify it.
|
||||
# If we have two aliases with the same values for that tuple,
|
||||
# we only need to create the test database once.
|
||||
item = test_databases.setdefault(
|
||||
connection.creation.test_db_signature(),
|
||||
(connection.settings_dict['NAME'], []),
|
||||
)
|
||||
# The default database must be the first because data migrations
|
||||
# use the default alias by default.
|
||||
if alias == DEFAULT_DB_ALIAS:
|
||||
item[1].insert(0, alias)
|
||||
else:
|
||||
item[1].append(alias)
|
||||
|
||||
if 'DEPENDENCIES' in test_settings:
|
||||
dependencies[alias] = test_settings['DEPENDENCIES']
|
||||
else:
|
||||
if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
|
||||
dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
||||
|
||||
test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
|
||||
return test_databases, mirrored_aliases
|
||||
|
||||
|
||||
def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
|
||||
"""Destroy all the non-mirror databases."""
|
||||
for connection, old_name, destroy in old_config:
|
||||
if destroy:
|
||||
if parallel > 1:
|
||||
for index in range(parallel):
|
||||
connection.creation.destroy_test_db(
|
||||
suffix=str(index + 1),
|
||||
verbosity=verbosity,
|
||||
keepdb=keepdb,
|
||||
)
|
||||
connection.creation.destroy_test_db(old_name, verbosity, keepdb)
|
||||
|
||||
|
||||
def get_runner(settings, test_runner_class=None):
|
||||
test_runner_class = test_runner_class or settings.TEST_RUNNER
|
||||
test_path = test_runner_class.split('.')
|
||||
# Allow for relative paths
|
||||
if len(test_path) > 1:
|
||||
test_module_name = '.'.join(test_path[:-1])
|
||||
else:
|
||||
test_module_name = '.'
|
||||
test_module = __import__(test_module_name, {}, {}, test_path[-1])
|
||||
return getattr(test_module, test_path[-1])
|
||||
|
||||
|
||||
class TestContextDecorator:
|
||||
"""
|
||||
A base class that can either be used as a context manager during tests
|
||||
or as a test function or unittest.TestCase subclass decorator to perform
|
||||
temporary alterations.
|
||||
|
||||
`attr_name`: attribute assigned the return value of enable() if used as
|
||||
a class decorator.
|
||||
|
||||
`kwarg_name`: keyword argument passing the return value of enable() if
|
||||
used as a function decorator.
|
||||
"""
|
||||
def __init__(self, attr_name=None, kwarg_name=None):
|
||||
self.attr_name = attr_name
|
||||
self.kwarg_name = kwarg_name
|
||||
|
||||
def enable(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def disable(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self):
|
||||
return self.enable()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.disable()
|
||||
|
||||
def decorate_class(self, cls):
|
||||
if issubclass(cls, TestCase):
|
||||
decorated_setUp = cls.setUp
|
||||
|
||||
def setUp(inner_self):
|
||||
context = self.enable()
|
||||
inner_self.addCleanup(self.disable)
|
||||
if self.attr_name:
|
||||
setattr(inner_self, self.attr_name, context)
|
||||
decorated_setUp(inner_self)
|
||||
|
||||
cls.setUp = setUp
|
||||
return cls
|
||||
raise TypeError('Can only decorate subclasses of unittest.TestCase')
|
||||
|
||||
def decorate_callable(self, func):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
# If the inner function is an async function, we must execute async
|
||||
# as well so that the `with` statement executes at the right time.
|
||||
@wraps(func)
|
||||
async def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return func(*args, **kwargs)
|
||||
return inner
|
||||
|
||||
def __call__(self, decorated):
|
||||
if isinstance(decorated, type):
|
||||
return self.decorate_class(decorated)
|
||||
elif callable(decorated):
|
||||
return self.decorate_callable(decorated)
|
||||
raise TypeError('Cannot decorate object of type %s' % type(decorated))
|
||||
|
||||
|
||||
class override_settings(TestContextDecorator):
|
||||
"""
|
||||
Act as either a decorator or a context manager. If it's a decorator, take a
|
||||
function and return a wrapped function. If it's a contextmanager, use it
|
||||
with the ``with`` statement. In either event, entering/exiting are called
|
||||
before and after, respectively, the function/block is executed.
|
||||
"""
|
||||
enable_exception = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.options = kwargs
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
# Keep this code at the beginning to leave the settings unchanged
|
||||
# in case it raises an exception because INSTALLED_APPS is invalid.
|
||||
if 'INSTALLED_APPS' in self.options:
|
||||
try:
|
||||
apps.set_installed_apps(self.options['INSTALLED_APPS'])
|
||||
except Exception:
|
||||
apps.unset_installed_apps()
|
||||
raise
|
||||
override = UserSettingsHolder(settings._wrapped)
|
||||
for key, new_value in self.options.items():
|
||||
setattr(override, key, new_value)
|
||||
self.wrapped = settings._wrapped
|
||||
settings._wrapped = override
|
||||
for key, new_value in self.options.items():
|
||||
try:
|
||||
setting_changed.send(
|
||||
sender=settings._wrapped.__class__,
|
||||
setting=key, value=new_value, enter=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.enable_exception = exc
|
||||
self.disable()
|
||||
|
||||
def disable(self):
|
||||
if 'INSTALLED_APPS' in self.options:
|
||||
apps.unset_installed_apps()
|
||||
settings._wrapped = self.wrapped
|
||||
del self.wrapped
|
||||
responses = []
|
||||
for key in self.options:
|
||||
new_value = getattr(settings, key, None)
|
||||
responses_for_setting = setting_changed.send_robust(
|
||||
sender=settings._wrapped.__class__,
|
||||
setting=key, value=new_value, enter=False,
|
||||
)
|
||||
responses.extend(responses_for_setting)
|
||||
if self.enable_exception is not None:
|
||||
exc = self.enable_exception
|
||||
self.enable_exception = None
|
||||
raise exc
|
||||
for _, response in responses:
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
|
||||
def save_options(self, test_func):
|
||||
if test_func._overridden_settings is None:
|
||||
test_func._overridden_settings = self.options
|
||||
else:
|
||||
# Duplicate dict to prevent subclasses from altering their parent.
|
||||
test_func._overridden_settings = {
|
||||
**test_func._overridden_settings,
|
||||
**self.options,
|
||||
}
|
||||
|
||||
def decorate_class(self, cls):
|
||||
from django.test import SimpleTestCase
|
||||
if not issubclass(cls, SimpleTestCase):
|
||||
raise ValueError(
|
||||
"Only subclasses of Django SimpleTestCase can be decorated "
|
||||
"with override_settings")
|
||||
self.save_options(cls)
|
||||
return cls
|
||||
|
||||
|
||||
class modify_settings(override_settings):
|
||||
"""
|
||||
Like override_settings, but makes it possible to append, prepend, or remove
|
||||
items instead of redefining the entire list.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if args:
|
||||
# Hack used when instantiating from SimpleTestCase.setUpClass.
|
||||
assert not kwargs
|
||||
self.operations = args[0]
|
||||
else:
|
||||
assert not args
|
||||
self.operations = list(kwargs.items())
|
||||
super(override_settings, self).__init__()
|
||||
|
||||
def save_options(self, test_func):
|
||||
if test_func._modified_settings is None:
|
||||
test_func._modified_settings = self.operations
|
||||
else:
|
||||
# Duplicate list to prevent subclasses from altering their parent.
|
||||
test_func._modified_settings = list(
|
||||
test_func._modified_settings) + self.operations
|
||||
|
||||
def enable(self):
|
||||
self.options = {}
|
||||
for name, operations in self.operations:
|
||||
try:
|
||||
# When called from SimpleTestCase.setUpClass, values may be
|
||||
# overridden several times; cumulate changes.
|
||||
value = self.options[name]
|
||||
except KeyError:
|
||||
value = list(getattr(settings, name, []))
|
||||
for action, items in operations.items():
|
||||
# items my be a single value or an iterable.
|
||||
if isinstance(items, str):
|
||||
items = [items]
|
||||
if action == 'append':
|
||||
value = value + [item for item in items if item not in value]
|
||||
elif action == 'prepend':
|
||||
value = [item for item in items if item not in value] + value
|
||||
elif action == 'remove':
|
||||
value = [item for item in value if item not in items]
|
||||
else:
|
||||
raise ValueError("Unsupported action: %s" % action)
|
||||
self.options[name] = value
|
||||
super().enable()
|
||||
|
||||
|
||||
class override_system_checks(TestContextDecorator):
|
||||
"""
|
||||
Act as a decorator. Override list of registered system checks.
|
||||
Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
|
||||
you also need to exclude its system checks.
|
||||
"""
|
||||
def __init__(self, new_checks, deployment_checks=None):
|
||||
from django.core.checks.registry import registry
|
||||
self.registry = registry
|
||||
self.new_checks = new_checks
|
||||
self.deployment_checks = deployment_checks
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
self.old_checks = self.registry.registered_checks
|
||||
self.registry.registered_checks = set()
|
||||
for check in self.new_checks:
|
||||
self.registry.register(check, *getattr(check, 'tags', ()))
|
||||
self.old_deployment_checks = self.registry.deployment_checks
|
||||
if self.deployment_checks is not None:
|
||||
self.registry.deployment_checks = set()
|
||||
for check in self.deployment_checks:
|
||||
self.registry.register(check, *getattr(check, 'tags', ()), deploy=True)
|
||||
|
||||
def disable(self):
|
||||
self.registry.registered_checks = self.old_checks
|
||||
self.registry.deployment_checks = self.old_deployment_checks
|
||||
|
||||
|
||||
def compare_xml(want, got):
|
||||
"""
|
||||
Try to do a 'xml-comparison' of want and got. Plain string comparison
|
||||
doesn't always work because, for example, attribute ordering should not be
|
||||
important. Ignore comment nodes, processing instructions, document type
|
||||
node, and leading and trailing whitespaces.
|
||||
|
||||
Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
|
||||
"""
|
||||
_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
|
||||
|
||||
def norm_whitespace(v):
|
||||
return _norm_whitespace_re.sub(' ', v)
|
||||
|
||||
def child_text(element):
|
||||
return ''.join(c.data for c in element.childNodes
|
||||
if c.nodeType == Node.TEXT_NODE)
|
||||
|
||||
def children(element):
|
||||
return [c for c in element.childNodes
|
||||
if c.nodeType == Node.ELEMENT_NODE]
|
||||
|
||||
def norm_child_text(element):
|
||||
return norm_whitespace(child_text(element))
|
||||
|
||||
def attrs_dict(element):
|
||||
return dict(element.attributes.items())
|
||||
|
||||
def check_element(want_element, got_element):
|
||||
if want_element.tagName != got_element.tagName:
|
||||
return False
|
||||
if norm_child_text(want_element) != norm_child_text(got_element):
|
||||
return False
|
||||
if attrs_dict(want_element) != attrs_dict(got_element):
|
||||
return False
|
||||
want_children = children(want_element)
|
||||
got_children = children(got_element)
|
||||
if len(want_children) != len(got_children):
|
||||
return False
|
||||
return all(check_element(want, got) for want, got in zip(want_children, got_children))
|
||||
|
||||
def first_node(document):
|
||||
for node in document.childNodes:
|
||||
if node.nodeType not in (
|
||||
Node.COMMENT_NODE,
|
||||
Node.DOCUMENT_TYPE_NODE,
|
||||
Node.PROCESSING_INSTRUCTION_NODE,
|
||||
):
|
||||
return node
|
||||
|
||||
want = want.strip().replace('\\n', '\n')
|
||||
got = got.strip().replace('\\n', '\n')
|
||||
|
||||
# If the string is not a complete xml document, we may need to add a
|
||||
# root element. This allow us to compare fragments, like "<foo/><bar/>"
|
||||
if not want.startswith('<?xml'):
|
||||
wrapper = '<root>%s</root>'
|
||||
want = wrapper % want
|
||||
got = wrapper % got
|
||||
|
||||
# Parse the want and got strings, and compare the parsings.
|
||||
want_root = first_node(parseString(want))
|
||||
got_root = first_node(parseString(got))
|
||||
|
||||
return check_element(want_root, got_root)
|
||||
|
||||
|
||||
class CaptureQueriesContext:
|
||||
"""
|
||||
Context manager that captures queries executed by the specified connection.
|
||||
"""
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.captured_queries)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.captured_queries[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.captured_queries)
|
||||
|
||||
@property
|
||||
def captured_queries(self):
|
||||
return self.connection.queries[self.initial_queries:self.final_queries]
|
||||
|
||||
def __enter__(self):
|
||||
self.force_debug_cursor = self.connection.force_debug_cursor
|
||||
self.connection.force_debug_cursor = True
|
||||
# Run any initialization queries if needed so that they won't be
|
||||
# included as part of the count.
|
||||
self.connection.ensure_connection()
|
||||
self.initial_queries = len(self.connection.queries_log)
|
||||
self.final_queries = None
|
||||
request_started.disconnect(reset_queries)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.connection.force_debug_cursor = self.force_debug_cursor
|
||||
request_started.connect(reset_queries)
|
||||
if exc_type is not None:
|
||||
return
|
||||
self.final_queries = len(self.connection.queries_log)
|
||||
|
||||
|
||||
class ignore_warnings(TestContextDecorator):
|
||||
def __init__(self, **kwargs):
|
||||
self.ignore_kwargs = kwargs
|
||||
if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs:
|
||||
self.filter_func = warnings.filterwarnings
|
||||
else:
|
||||
self.filter_func = warnings.simplefilter
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
self.catch_warnings = warnings.catch_warnings()
|
||||
self.catch_warnings.__enter__()
|
||||
self.filter_func('ignore', **self.ignore_kwargs)
|
||||
|
||||
def disable(self):
|
||||
self.catch_warnings.__exit__(*sys.exc_info())
|
||||
|
||||
|
||||
# On OSes that don't provide tzset (Windows), we can't set the timezone
|
||||
# in which the program runs. As a consequence, we must skip tests that
|
||||
# don't enforce a specific timezone (with timezone.override or equivalent),
|
||||
# or attempt to interpret naive datetimes in the default timezone.
|
||||
|
||||
requires_tz_support = skipUnless(
|
||||
TZ_SUPPORT,
|
||||
"This test relies on the ability to run a program in an arbitrary "
|
||||
"time zone, but your operating system isn't able to do that."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def extend_sys_path(*paths):
|
||||
"""Context manager to temporarily add paths to sys.path."""
|
||||
_orig_sys_path = sys.path[:]
|
||||
sys.path.extend(paths)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.path = _orig_sys_path
|
||||
|
||||
|
||||
@contextmanager
|
||||
def isolate_lru_cache(lru_cache_object):
|
||||
"""Clear the cache of an LRU cache object on entering and exiting."""
|
||||
lru_cache_object.cache_clear()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lru_cache_object.cache_clear()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def captured_output(stream_name):
|
||||
"""Return a context manager used by captured_stdout/stdin/stderr
|
||||
that temporarily replaces the sys stream *stream_name* with a StringIO.
|
||||
|
||||
Note: This function and the following ``captured_std*`` are copied
|
||||
from CPython's ``test.support`` module."""
|
||||
orig_stdout = getattr(sys, stream_name)
|
||||
setattr(sys, stream_name, StringIO())
|
||||
try:
|
||||
yield getattr(sys, stream_name)
|
||||
finally:
|
||||
setattr(sys, stream_name, orig_stdout)
|
||||
|
||||
|
||||
def captured_stdout():
|
||||
"""Capture the output of sys.stdout:
|
||||
|
||||
with captured_stdout() as stdout:
|
||||
print("hello")
|
||||
self.assertEqual(stdout.getvalue(), "hello\n")
|
||||
"""
|
||||
return captured_output("stdout")
|
||||
|
||||
|
||||
def captured_stderr():
|
||||
"""Capture the output of sys.stderr:
|
||||
|
||||
with captured_stderr() as stderr:
|
||||
print("hello", file=sys.stderr)
|
||||
self.assertEqual(stderr.getvalue(), "hello\n")
|
||||
"""
|
||||
return captured_output("stderr")
|
||||
|
||||
|
||||
def captured_stdin():
|
||||
"""Capture the input to sys.stdin:
|
||||
|
||||
with captured_stdin() as stdin:
|
||||
stdin.write('hello\n')
|
||||
stdin.seek(0)
|
||||
# call test code that consumes from sys.stdin
|
||||
captured = input()
|
||||
self.assertEqual(captured, "hello")
|
||||
"""
|
||||
return captured_output("stdin")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_time(t):
|
||||
"""
|
||||
Context manager to temporarily freeze time.time(). This temporarily
|
||||
modifies the time function of the time module. Modules which import the
|
||||
time function directly (e.g. `from time import time`) won't be affected
|
||||
This isn't meant as a public API, but helps reduce some repetitive code in
|
||||
Django's test suite.
|
||||
"""
|
||||
_real_time = time.time
|
||||
time.time = lambda: t
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
time.time = _real_time
|
||||
|
||||
|
||||
def require_jinja2(test_func):
|
||||
"""
|
||||
Decorator to enable a Jinja2 template engine in addition to the regular
|
||||
Django template engine for a test or skip it if Jinja2 isn't available.
|
||||
"""
|
||||
test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
|
||||
return override_settings(TEMPLATES=[{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'APP_DIRS': True,
|
||||
}, {
|
||||
'BACKEND': 'django.template.backends.jinja2.Jinja2',
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {'keep_trailing_newline': True},
|
||||
}])(test_func)
|
||||
|
||||
|
||||
class override_script_prefix(TestContextDecorator):
|
||||
"""Decorator or context manager to temporary override the script prefix."""
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
super().__init__()
|
||||
|
||||
def enable(self):
|
||||
self.old_prefix = get_script_prefix()
|
||||
set_script_prefix(self.prefix)
|
||||
|
||||
def disable(self):
|
||||
set_script_prefix(self.old_prefix)
|
||||
|
||||
|
||||
class LoggingCaptureMixin:
|
||||
"""
|
||||
Capture the output from the 'django' logger and store it on the class's
|
||||
logger_output attribute.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.logger = logging.getLogger('django')
|
||||
self.old_stream = self.logger.handlers[0].stream
|
||||
self.logger_output = StringIO()
|
||||
self.logger.handlers[0].stream = self.logger_output
|
||||
|
||||
def tearDown(self):
|
||||
self.logger.handlers[0].stream = self.old_stream
|
||||
|
||||
|
||||
class isolate_apps(TestContextDecorator):
|
||||
"""
|
||||
Act as either a decorator or a context manager to register models defined
|
||||
in its wrapped context to an isolated registry.
|
||||
|
||||
The list of installed apps the isolated registry should contain must be
|
||||
passed as arguments.
|
||||
|
||||
Two optional keyword arguments can be specified:
|
||||
|
||||
`attr_name`: attribute assigned the isolated registry if used as a class
|
||||
decorator.
|
||||
|
||||
`kwarg_name`: keyword argument passing the isolated registry if used as a
|
||||
function decorator.
|
||||
"""
|
||||
def __init__(self, *installed_apps, **kwargs):
|
||||
self.installed_apps = installed_apps
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def enable(self):
|
||||
self.old_apps = Options.default_apps
|
||||
apps = Apps(self.installed_apps)
|
||||
setattr(Options, 'default_apps', apps)
|
||||
return apps
|
||||
|
||||
def disable(self):
|
||||
setattr(Options, 'default_apps', self.old_apps)
|
||||
|
||||
|
||||
class TimeKeeper:
|
||||
def __init__(self):
|
||||
self.records = collections.defaultdict(list)
|
||||
|
||||
@contextmanager
|
||||
def timed(self, name):
|
||||
self.records[name]
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
end_time = time.perf_counter() - start_time
|
||||
self.records[name].append(end_time)
|
||||
|
||||
def print_results(self):
|
||||
for name, end_times in self.records.items():
|
||||
for record_time in end_times:
|
||||
record = '%s took %.3fs' % (name, record_time)
|
||||
sys.stderr.write(record + os.linesep)
|
||||
|
||||
|
||||
class NullTimeKeeper:
|
||||
@contextmanager
|
||||
def timed(self, name):
|
||||
yield
|
||||
|
||||
def print_results(self):
|
||||
pass
|
||||
|
||||
|
||||
def tag(*tags):
|
||||
"""Decorator to add tags to a test class or method."""
|
||||
def decorator(obj):
|
||||
if hasattr(obj, 'tags'):
|
||||
obj.tags = obj.tags.union(tags)
|
||||
else:
|
||||
setattr(obj, 'tags', set(tags))
|
||||
return obj
|
||||
return decorator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def register_lookup(field, *lookups, lookup_name=None):
|
||||
"""
|
||||
Context manager to temporarily register lookups on a model field using
|
||||
lookup_name (or the lookup's lookup_name if not provided).
|
||||
"""
|
||||
try:
|
||||
for lookup in lookups:
|
||||
field.register_lookup(lookup, lookup_name)
|
||||
yield
|
||||
finally:
|
||||
for lookup in lookups:
|
||||
field._unregister_lookup(lookup, lookup_name)
|
Reference in New Issue
Block a user