Skip to content

Commit

Permalink
Fixed #24629 -- Unified Transform and Expression APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
jarshwah committed Sep 21, 2015
1 parent 8dc3ba5 commit 534aaf5
Show file tree
Hide file tree
Showing 15 changed files with 522 additions and 377 deletions.
4 changes: 2 additions & 2 deletions django/contrib/postgres/fields/hstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __call__(self, *args, **kwargs):


@HStoreField.register_lookup
class KeysTransform(lookups.FunctionTransform):
class KeysTransform(Transform):
lookup_name = 'keys'
function = 'akeys'
output_field = ArrayField(TextField())


@HStoreField.register_lookup
class ValuesTransform(lookups.FunctionTransform):
class ValuesTransform(Transform):
lookup_name = 'values'
function = 'avals'
output_field = ArrayField(TextField())
6 changes: 3 additions & 3 deletions django/contrib/postgres/fields/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup):


@RangeField.register_lookup
class RangeStartsWith(lookups.FunctionTransform):
class RangeStartsWith(models.Transform):
lookup_name = 'startswith'
function = 'lower'

Expand All @@ -183,7 +183,7 @@ def output_field(self):


@RangeField.register_lookup
class RangeEndsWith(lookups.FunctionTransform):
class RangeEndsWith(models.Transform):
lookup_name = 'endswith'
function = 'upper'

Expand All @@ -193,7 +193,7 @@ def output_field(self):


@RangeField.register_lookup
class IsEmpty(lookups.FunctionTransform):
class IsEmpty(models.Transform):
lookup_name = 'isempty'
function = 'isempty'
output_field = models.BooleanField()
8 changes: 1 addition & 7 deletions django/contrib/postgres/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ def as_sql(self, qn, connection):
return '%s %s %s' % (lhs, self.operator, rhs), params


class FunctionTransform(Transform):
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return "%s(%s)" % (self.function, lhs), params


class DataContains(PostgresSimpleLookup):
lookup_name = 'contains'
operator = '@>'
Expand Down Expand Up @@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup):
operator = '?|'


class Unaccent(FunctionTransform):
class Unaccent(Transform):
bilateral = True
lookup_name = 'unaccent'
function = 'UNACCENT'
165 changes: 1 addition & 164 deletions django/db/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
# purposes.
from django.core.exceptions import FieldDoesNotExist # NOQA
from django.db import connection, connections, router
from django.db.models.lookups import (
Lookup, RegisterLookupMixin, Transform, default_lookups,
)
from django.db.models.query_utils import QueryWrapper
from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin
from django.utils import six, timezone
from django.utils.datastructures import DictWrapper
from django.utils.dateparse import (
Expand Down Expand Up @@ -120,7 +117,6 @@ class Field(RegisterLookupMixin):
'unique_for_date': _("%(field_label)s must be unique for "
"%(date_field_label)s %(lookup_type)s."),
}
class_lookups = default_lookups.copy()
system_check_deprecated_details = None
system_check_removed_details = None

Expand Down Expand Up @@ -1492,22 +1488,6 @@ def formfield(self, **kwargs):
return super(DateTimeField, self).formfield(**defaults)


@DateTimeField.register_lookup
class DateTimeDateTransform(Transform):
lookup_name = 'date'

@cached_property
def output_field(self):
return DateField()

def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params


class DecimalField(Field):
empty_strings_allowed = False
default_error_messages = {
Expand Down Expand Up @@ -2450,146 +2430,3 @@ def formfield(self, **kwargs):
}
defaults.update(kwargs)
return super(UUIDField, self).formfield(**defaults)


class DateTransform(Transform):
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
params.extend(tz_params)
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
return sql, params

@cached_property
def output_field(self):
return IntegerField()


class YearTransform(DateTransform):
lookup_name = 'year'


class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
return bounds


@YearTransform.register_lookup
class YearExact(YearLookup):
lookup_name = 'exact'

def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
bounds = self.year_lookup_bounds(connection, rhs_params[0])
params.extend(bounds)
return '%s BETWEEN %%s AND %%s' % lhs_sql, params


class YearComparisonLookup(YearLookup):
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, rhs_params[0])
params.append(self.get_bound(start, finish))
return '%s %s' % (lhs_sql, rhs_sql), params

def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs

def get_bound(self):
raise NotImplementedError(
'subclasses of YearComparisonLookup must provide a get_bound() method'
)


@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
lookup_name = 'gt'

def get_bound(self, start, finish):
return finish


@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
lookup_name = 'gte'

def get_bound(self, start, finish):
return start


@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
lookup_name = 'lt'

def get_bound(self, start, finish):
return start


@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
lookup_name = 'lte'

def get_bound(self, start, finish):
return finish


class MonthTransform(DateTransform):
lookup_name = 'month'


class DayTransform(DateTransform):
lookup_name = 'day'


class WeekDayTransform(DateTransform):
lookup_name = 'week_day'


class HourTransform(DateTransform):
lookup_name = 'hour'


class MinuteTransform(DateTransform):
lookup_name = 'minute'


class SecondTransform(DateTransform):
lookup_name = 'second'


DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)

TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)

DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
DateTimeField.register_lookup(HourTransform)
DateTimeField.register_lookup(MinuteTransform)
DateTimeField.register_lookup(SecondTransform)
14 changes: 9 additions & 5 deletions django/db/models/functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
Classes that represent database functions.
"""
from django.db.models import DateTimeField, IntegerField
from django.db.models.expressions import Func, Value
from django.db.models import (
DateTimeField, Func, IntegerField, Transform, Value,
)


class Coalesce(Func):
Expand Down Expand Up @@ -123,9 +124,10 @@ def as_sqlite(self, compiler, connection):
return super(Least, self).as_sql(compiler, connection, function='MIN')


class Length(Func):
class Length(Transform):
"""Returns the number of characters in the expression"""
function = 'LENGTH'
lookup_name = 'length'

def __init__(self, expression, **extra):
output_field = extra.pop('output_field', IntegerField())
Expand All @@ -136,8 +138,9 @@ def as_mysql(self, compiler, connection):
return super(Length, self).as_sql(compiler, connection)


class Lower(Func):
class Lower(Transform):
function = 'LOWER'
lookup_name = 'lower'

def __init__(self, expression, **extra):
super(Lower, self).__init__(expression, **extra)
Expand Down Expand Up @@ -188,8 +191,9 @@ def as_oracle(self, compiler, connection):
return super(Substr, self).as_sql(compiler, connection)


class Upper(Func):
class Upper(Transform):
function = 'UPPER'
lookup_name = 'upper'

def __init__(self, expression, **extra):
super(Upper, self).__init__(expression, **extra)
Loading

0 comments on commit 534aaf5

Please sign in to comment.