Ajoutez des fichiers projet.

This commit is contained in:
Ambulance Clerc
2021-12-18 18:43:17 +01:00
parent 3c4d48ed26
commit 46254605fc
4842 changed files with 732322 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
from .comparison import (
Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf,
)
from .datetime import (
Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear,
ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek,
ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour,
TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek,
TruncYear,
)
from .math import (
Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, Exp, Floor, Ln, Log,
Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan,
)
from .text import (
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
RTrim, StrIndex, Substr, Trim, Upper,
)
from .window import (
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
PercentRank, Rank, RowNumber,
)
__all__ = [
# comparison and conversion
'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf',
# datetime
'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay',
'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc',
'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear',
# math
'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 'Ceil', 'Cos', 'Cot', 'Degrees',
'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random',
'Round', 'Sign', 'Sin', 'Sqrt', 'Tan',
# text
'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
'Trim', 'Upper',
# window
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]

View File

@@ -0,0 +1,193 @@
"""Database functions that do comparisons or type conversions."""
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields.json import JSONField
from django.utils.regex_helper import _lazy_re_compile
class Cast(Func):
"""Coerce an expression to a new field type."""
function = 'CAST'
template = '%(function)s(%(expressions)s AS %(db_type)s)'
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
extra_context['db_type'] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
db_type = self.output_field.db_type(connection)
if db_type in {'datetime', 'time'}:
# Use strftime as datetime/time don't keep fractional seconds.
template = 'strftime(%%s, %(expressions)s)'
sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
params.insert(0, format_string)
return sql, params
elif db_type == 'date':
template = 'date(%(expressions)s)'
return super().as_sql(compiler, connection, template=template, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
template = None
output_type = self.output_field.get_internal_type()
# MySQL doesn't support explicit cast to float.
if output_type == 'FloatField':
template = '(%(expressions)s + 0.0)'
# MariaDB doesn't support explicit cast to JSON.
elif output_type == 'JSONField' and connection.mysql_is_mariadb:
template = "JSON_EXTRACT(%(expressions)s, '$')"
return self.as_sql(compiler, connection, template=template, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'JSONField':
# Oracle doesn't support explicit cast to JSON.
template = "JSON_QUERY(%(expressions)s, '$')"
return super().as_sql(compiler, connection, template=template, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Coalesce(Func):
"""Return, from left to right, the first non-null expression."""
function = 'COALESCE'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Coalesce must take at least two expressions')
super().__init__(*expressions, **extra)
@property
def empty_result_set_value(self):
for expression in self.get_source_expressions():
result = expression.empty_result_set_value
if result is NotImplemented or result is not None:
return result
return None
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == 'TextField':
clone = self.copy()
clone.set_source_expressions([
Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
])
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Collate(Func):
function = 'COLLATE'
template = '%(expressions)s %(function)s %(collation)s'
# Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
collation_re = _lazy_re_compile(r'^[\w\-]+$')
def __init__(self, expression, collation):
if not (collation and self.collation_re.match(collation)):
raise ValueError('Invalid collation name: %r.' % collation)
self.collation = collation
super().__init__(expression)
def as_sql(self, compiler, connection, **extra_context):
extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
return super().as_sql(compiler, connection, **extra_context)
class Greatest(Func):
"""
Return the maximum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, the maximum not-null expression is returned.
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
"""
function = 'GREATEST'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Greatest must take at least two expressions')
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
class JSONObject(Func):
function = 'JSON_OBJECT'
output_field = JSONField()
def __init__(self, **fields):
expressions = []
for key, value in fields.items():
expressions.extend((Value(key), value))
super().__init__(*expressions)
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.has_json_object_function:
raise NotSupportedError(
'JSONObject() is not supported on this database backend.'
)
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
return self.as_sql(
compiler,
connection,
function='JSONB_BUILD_OBJECT',
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
class ArgJoiner:
def join(self, args):
args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
return ', '.join(args)
return self.as_sql(
compiler,
connection,
arg_joiner=ArgJoiner(),
template='%(function)s(%(expressions)s RETURNING CLOB)',
**extra_context,
)
class Least(Func):
"""
Return the minimum expression.
If any expression is null the return value is database-specific:
On PostgreSQL, return the minimum not-null expression.
On MySQL, Oracle, and SQLite, if any expression is null, return null.
"""
function = 'LEAST'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Least must take at least two expressions')
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
class NullIf(Func):
function = 'NULLIF'
arity = 2
def as_oracle(self, compiler, connection, **extra_context):
expression1 = self.get_source_expressions()[0]
if isinstance(expression1, Value) and expression1.value is None:
raise ValueError('Oracle does not allow Value(None) for expression1.')
return super().as_sql(compiler, connection, **extra_context)

View File

@@ -0,0 +1,339 @@
from datetime import datetime
from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
)
from django.db.models.lookups import (
Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
)
from django.utils import timezone
class TimezoneMixin:
tzinfo = None
def get_tzname(self):
# Timezone conversions must happen to the input datetime *before*
# applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
# database as 2016-01-01 01:00:00 +00:00. Any results should be
# based on the input datetime not the stored datetime.
tzname = None
if settings.USE_TZ:
if self.tzinfo is None:
tzname = timezone.get_current_timezone_name()
else:
tzname = timezone._get_timezone_name(self.tzinfo)
return tzname
class Extract(TimezoneMixin, Transform):
lookup_name = None
output_field = IntegerField()
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
raise ValueError('lookup_name must be provided')
self.tzinfo = tzinfo
super().__init__(expression, **extra)
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
elif self.tzinfo is not None:
raise ValueError('tzinfo can only be used with DateTimeField.')
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
raise ValueError('Extract requires native DurationField database support.')
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
# resolve_expression has already validated the output_field so this
# assert should never be hit.
assert False, "Tried to Extract from an invalid type."
return sql, params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = getattr(copy.lhs, 'output_field', None)
if field is None:
return copy
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
raise ValueError(
'Extract input expression must be DateField, DateTimeField, '
'TimeField, or DurationField.'
)
# Passing dates to functions expecting datetimes is most likely a mistake.
if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
raise ValueError(
"Cannot extract time component '%s' from DateField '%s'." % (copy.lookup_name, field.name)
)
if (
isinstance(field, DurationField) and
copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
):
raise ValueError(
"Cannot extract component '%s' from DurationField '%s'."
% (copy.lookup_name, field.name)
)
return copy
class ExtractYear(Extract):
lookup_name = 'year'
class ExtractIsoYear(Extract):
"""Return the ISO-8601 week-numbering year."""
lookup_name = 'iso_year'
class ExtractMonth(Extract):
lookup_name = 'month'
class ExtractDay(Extract):
lookup_name = 'day'
class ExtractWeek(Extract):
"""
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
week.
"""
lookup_name = 'week'
class ExtractWeekDay(Extract):
"""
Return Sunday=1 through Saturday=7.
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
lookup_name = 'week_day'
class ExtractIsoWeekDay(Extract):
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
lookup_name = 'iso_week_day'
class ExtractQuarter(Extract):
lookup_name = 'quarter'
class ExtractHour(Extract):
lookup_name = 'hour'
class ExtractMinute(Extract):
lookup_name = 'minute'
class ExtractSecond(Extract):
lookup_name = 'second'
DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)
TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)
ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)
ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)
class Now(Func):
template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
class TruncBase(TimezoneMixin, Transform):
kind = None
tzinfo = None
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(self, expression, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
self.tzinfo = tzinfo
self.is_dst = is_dst
super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
inner_sql, inner_params = compiler.compile(self.lhs)
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError('tzinfo can only be used with DateTimeField.')
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
else:
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
return sql, inner_params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
if not isinstance(field, (DateField, TimeField)):
raise TypeError(
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
)
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
output_field = class_output_field or copy.output_field
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
if type(field) == DateField and (
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
raise ValueError("Cannot truncate DateField '%s' to %s." % (
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField) or
copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
raise ValueError("Cannot truncate TimeField '%s' to %s." % (
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
return copy
def convert_value(self, value, expression, connection):
if isinstance(self.output_field, DateTimeField):
if not settings.USE_TZ:
pass
elif value is not None:
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
'Database returned an invalid datetime value. Are time '
'zone definitions for your database installed?'
)
elif isinstance(value, datetime):
if value is None:
pass
elif isinstance(self.output_field, DateField):
value = value.date()
elif isinstance(self.output_field, TimeField):
value = value.time()
return value
class Trunc(TruncBase):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
self.kind = kind
super().__init__(
expression, output_field=output_field, tzinfo=tzinfo,
is_dst=is_dst, **extra
)
class TruncYear(TruncBase):
kind = 'year'
class TruncQuarter(TruncBase):
kind = 'quarter'
class TruncMonth(TruncBase):
kind = 'month'
class TruncWeek(TruncBase):
"""Truncate to midnight on the Monday of the week."""
kind = 'week'
class TruncDay(TruncBase):
kind = 'day'
class TruncDate(TruncBase):
kind = 'date'
lookup_name = 'date'
output_field = DateField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs)
tzname = self.get_tzname()
sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
return sql, lhs_params
class TruncTime(TruncBase):
kind = 'time'
lookup_name = 'time'
output_field = TimeField()
def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time.
lhs, lhs_params = compiler.compile(self.lhs)
tzname = self.get_tzname()
sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
return sql, lhs_params
class TruncHour(TruncBase):
kind = 'hour'
class TruncMinute(TruncBase):
kind = 'minute'
class TruncSecond(TruncBase):
kind = 'second'
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)

View File

@@ -0,0 +1,197 @@
import math
from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
FixDecimalInputMixin, NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform
class Abs(Transform):
function = 'ABS'
lookup_name = 'abs'
class ACos(NumericOutputFieldMixin, Transform):
function = 'ACOS'
lookup_name = 'acos'
class ASin(NumericOutputFieldMixin, Transform):
function = 'ASIN'
lookup_name = 'asin'
class ATan(NumericOutputFieldMixin, Transform):
function = 'ATAN'
lookup_name = 'atan'
class ATan2(NumericOutputFieldMixin, Func):
function = 'ATAN2'
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0):
return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
# Cast integers to float to avoid inconsistent/buggy behavior if the
# arguments are mixed between integer and float or decimal.
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
clone = self.copy()
clone.set_source_expressions([
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
else expression for expression in self.get_source_expressions()[::-1]
])
return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform):
function = 'CEILING'
lookup_name = 'ceil'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
class Cos(NumericOutputFieldMixin, Transform):
function = 'COS'
lookup_name = 'cos'
class Cot(NumericOutputFieldMixin, Transform):
function = 'COT'
lookup_name = 'cot'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
class Degrees(NumericOutputFieldMixin, Transform):
function = 'DEGREES'
lookup_name = 'degrees'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection,
template='((%%(expressions)s) * 180 / %s)' % math.pi,
**extra_context
)
class Exp(NumericOutputFieldMixin, Transform):
function = 'EXP'
lookup_name = 'exp'
class Floor(Transform):
function = 'FLOOR'
lookup_name = 'floor'
class Ln(NumericOutputFieldMixin, Transform):
function = 'LN'
lookup_name = 'ln'
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'LOG'
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, 'spatialite', False):
return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b).
clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[::-1])
return clone.as_sql(compiler, connection, **extra_context)
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'MOD'
arity = 2
class Pi(NumericOutputFieldMixin, Func):
function = 'PI'
arity = 0
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
class Power(NumericOutputFieldMixin, Func):
function = 'POWER'
arity = 2
class Radians(NumericOutputFieldMixin, Transform):
function = 'RADIANS'
lookup_name = 'radians'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection,
template='((%%(expressions)s) * %s / 180)' % math.pi,
**extra_context
)
class Random(NumericOutputFieldMixin, Func):
function = 'RANDOM'
arity = 0
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='RAND', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='RAND', **extra_context)
def get_group_by_cols(self, alias=None):
return []
class Round(FixDecimalInputMixin, Transform):
function = 'ROUND'
lookup_name = 'round'
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
super().__init__(expression, precision, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
raise ValueError('SQLite does not support negative precision.')
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
source = self.get_source_expressions()[0]
return source.output_field
class Sign(Transform):
function = 'SIGN'
lookup_name = 'sign'
class Sin(NumericOutputFieldMixin, Transform):
function = 'SIN'
lookup_name = 'sin'
class Sqrt(NumericOutputFieldMixin, Transform):
function = 'SQRT'
lookup_name = 'sqrt'
class Tan(NumericOutputFieldMixin, Transform):
function = 'TAN'
lookup_name = 'tan'

View File

@@ -0,0 +1,52 @@
import sys
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast
class FixDecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
# - LOG(double, double)
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions([
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
else expression for expression in self.get_source_expressions()
])
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s AS SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import (
IntervalToSeconds, SecondsToInterval,
)
return compiler.compile(
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):
source_fields = self.get_source_fields()
if any(isinstance(s, DecimalField) for s in source_fields):
return DecimalField()
if any(isinstance(s, IntegerField) for s in source_fields):
return FloatField()
return super()._resolve_output_field() if source_fields else FloatField()

View File

@@ -0,0 +1,323 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import CharField, IntegerField
from django.db.models.functions import Coalesce
from django.db.models.lookups import Transform
class MySQLSHA2Mixin:
def as_mysql(self, compiler, connection, **extra_content):
return super().as_sql(
compiler,
connection,
template='SHA2(%%(expressions)s, %s)' % self.function[3:],
**extra_content,
)
class OracleHashMixin:
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template=(
"LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
"%(expressions)s, 'AL32UTF8'), '%(function)s')))"
),
**extra_context,
)
class PostgreSQLSHAMixin:
def as_postgresql(self, compiler, connection, **extra_content):
return super().as_sql(
compiler,
connection,
template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
function=self.function.lower(),
**extra_content,
)
class Chr(Transform):
function = 'CHR'
lookup_name = 'chr'
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function='CHAR',
template='%(function)s(%(expressions)s USING utf16)',
**extra_context
)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection,
template='%(function)s(%(expressions)s USING NCHAR_CS)',
**extra_context
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
class ConcatPair(Func):
"""
Concatenate two arguments together. This is used by `Concat` because not
all backend databases support more than two arguments.
"""
function = 'CONCAT'
def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler, connection, template='%(expressions)s', arg_joiner=' || ',
**extra_context
)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql(
compiler, connection, function='CONCAT_WS',
template="%(function)s('', %(expressions)s)",
**extra_context
)
def coalesce(self):
# null on either side results in null for expression, wrap with coalesce
c = self.copy()
c.set_source_expressions([
Coalesce(expression, Value('')) for expression in c.get_source_expressions()
])
return c
class Concat(Func):
"""
Concatenate text fields together. Backends that result in an entire
null expression when any arguments are null will wrap each argument in
coalesce functions to ensure a non-null result.
"""
function = None
template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError('Concat must take at least two expressions')
paired = self._paired(expressions)
super().__init__(paired, **extra)
def _paired(self, expressions):
# wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))
class Left(Func):
function = 'LEFT'
arity = 2
output_field = CharField()
def __init__(self, expression, length, **extra):
"""
expression: the name of a field, or an expression returning a string
length: the number of characters to return from the start of the string
"""
if not hasattr(length, 'resolve_expression'):
if length < 1:
raise ValueError("'length' must be greater than 0.")
super().__init__(expression, length, **extra)
def get_substr(self):
return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])
def as_oracle(self, compiler, connection, **extra_context):
return self.get_substr().as_oracle(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return self.get_substr().as_sqlite(compiler, connection, **extra_context)
class Length(Transform):
"""Return the number of characters in the expression."""
function = 'LENGTH'
lookup_name = 'length'
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
class Lower(Transform):
function = 'LOWER'
lookup_name = 'lower'
class LPad(Func):
function = 'LPAD'
output_field = CharField()
def __init__(self, expression, length, fill_text=Value(' '), **extra):
if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra)
class LTrim(Transform):
function = 'LTRIM'
lookup_name = 'ltrim'
class MD5(OracleHashMixin, Transform):
function = 'MD5'
lookup_name = 'md5'
class Ord(Transform):
function = 'ASCII'
lookup_name = 'ord'
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='ORD', **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
class Repeat(Func):
function = 'REPEAT'
output_field = CharField()
def __init__(self, expression, number, **extra):
if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions
length = None if number is None else Length(expression) * number
rpad = RPad(expression, length, expression)
return rpad.as_sql(compiler, connection, **extra_context)
class Replace(Func):
function = 'REPLACE'
def __init__(self, expression, text, replacement=Value(''), **extra):
super().__init__(expression, text, replacement, **extra)
class Reverse(Transform):
function = 'REVERSE'
lookup_name = 'reverse'
def as_oracle(self, compiler, connection, **extra_context):
# REVERSE in Oracle is undocumented and doesn't support multi-byte
# strings. Use a special subquery instead.
return super().as_sql(
compiler, connection,
template=(
'(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
'(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
'GROUP BY %(expressions)s)'
),
**extra_context
)
class Right(Left):
function = 'RIGHT'
def get_substr(self):
return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
class RPad(LPad):
function = 'RPAD'
class RTrim(Transform):
function = 'RTRIM'
lookup_name = 'rtrim'
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA1'
lookup_name = 'sha1'
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
function = 'SHA224'
lookup_name = 'sha224'
def as_oracle(self, compiler, connection, **extra_context):
raise NotSupportedError('SHA224 is not supported on Oracle.')
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA256'
lookup_name = 'sha256'
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA384'
lookup_name = 'sha384'
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = 'SHA512'
lookup_name = 'sha512'
class StrIndex(Func):
"""
Return a positive integer corresponding to the 1-indexed position of the
first occurrence of a substring inside another string, or 0 if the
substring is not found.
"""
function = 'INSTR'
arity = 2
output_field = IntegerField()
def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
class Substr(Func):
function = 'SUBSTRING'
output_field = CharField()
def __init__(self, expression, pos, length=None, **extra):
"""
expression: the name of a field, or an expression returning a string
pos: an integer > 0, or an expression returning an integer
length: an optional number of characters to return
"""
if not hasattr(pos, 'resolve_expression'):
if pos < 1:
raise ValueError("'pos' must be greater than 0")
expressions = [expression, pos]
if length is not None:
expressions.append(length)
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
class Trim(Transform):
function = 'TRIM'
lookup_name = 'trim'
class Upper(Transform):
function = 'UPPER'
lookup_name = 'upper'

View File

@@ -0,0 +1,108 @@
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
__all__ = [
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]
class CumeDist(Func):
function = 'CUME_DIST'
output_field = FloatField()
window_compatible = True
class DenseRank(Func):
function = 'DENSE_RANK'
output_field = IntegerField()
window_compatible = True
class FirstValue(Func):
arity = 1
function = 'FIRST_VALUE'
window_compatible = True
class LagLeadFunction(Func):
window_compatible = True
def __init__(self, expression, offset=1, default=None, **extra):
if expression is None:
raise ValueError(
'%s requires a non-null source expression.' %
self.__class__.__name__
)
if offset is None or offset <= 0:
raise ValueError(
'%s requires a positive integer for the offset.' %
self.__class__.__name__
)
args = (expression, offset)
if default is not None:
args += (default,)
super().__init__(*args, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Lag(LagLeadFunction):
function = 'LAG'
class LastValue(Func):
arity = 1
function = 'LAST_VALUE'
window_compatible = True
class Lead(LagLeadFunction):
function = 'LEAD'
class NthValue(Func):
function = 'NTH_VALUE'
window_compatible = True
def __init__(self, expression, nth=1, **extra):
if expression is None:
raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
if nth is None or nth <= 0:
raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
super().__init__(expression, nth, **extra)
def _resolve_output_field(self):
sources = self.get_source_expressions()
return sources[0].output_field
class Ntile(Func):
function = 'NTILE'
output_field = IntegerField()
window_compatible = True
def __init__(self, num_buckets=1, **extra):
if num_buckets <= 0:
raise ValueError('num_buckets must be greater than 0.')
super().__init__(num_buckets, **extra)
class PercentRank(Func):
function = 'PERCENT_RANK'
output_field = FloatField()
window_compatible = True
class Rank(Func):
function = 'RANK'
output_field = IntegerField()
window_compatible = True
class RowNumber(Func):
function = 'ROW_NUMBER'
output_field = IntegerField()
window_compatible = True