Skip to content

Commit

Permalink
Fixed #31640 -- Made Trunc() truncate datetimes to Date/TimeField in …
Browse files Browse the repository at this point in the history
…a specific timezone.
  • Loading branch information
David-Wobrock authored and felixxm committed Oct 14, 2020
1 parent 8d01823 commit ee00532
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 32 deletions.
18 changes: 12 additions & 6 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ def date_extract_sql(self, lookup_type, field_name):
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')

def date_trunc_sql(self, lookup_type, field_name):
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
truncates the given date field field_name to a date object with only
the given specificity.
truncates the given date or datetime field field_name to a date object
with only the given specificity.
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')

Expand Down Expand Up @@ -138,11 +141,14 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname):
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')

def time_trunc_sql(self, lookup_type, field_name):
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
that truncates the given time field field_name to a time object with
only the given specificity.
that truncates the given time or datetime field field_name to a time
object with only the given specificity.
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')

Expand Down
8 changes: 5 additions & 3 deletions django/db/backends/mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def date_extract_sql(self, lookup_type, field_name):
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)

def date_trunc_sql(self, lookup_type, field_name):
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
'year': '%%Y-01-01',
'month': '%%Y-%%m-01',
Expand All @@ -82,7 +83,7 @@ def _prepare_tzname_delta(self, tzname):
return tzname

def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ and self.connection.timezone_name != tzname:
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
field_name,
self.connection.timezone_name,
Expand Down Expand Up @@ -128,7 +129,8 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname):
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql

def time_trunc_sql(self, lookup_type, field_name):
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
'hour': '%%H:00:00',
'minute': '%%H:%%i:00',
Expand Down
8 changes: 5 additions & 3 deletions django/db/backends/oracle/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def date_extract_sql(self, lookup_type, field_name):
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)

def date_trunc_sql(self, lookup_type, field_name):
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
Expand All @@ -114,7 +115,7 @@ def _prepare_tzname_delta(self, tzname):
return tzname

def _convert_field_to_tz(self, field_name, tzname):
if not settings.USE_TZ:
if not (settings.USE_TZ and tzname):
return field_name
if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname)
Expand Down Expand Up @@ -161,10 +162,11 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname):
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql

def time_trunc_sql(self, lookup_type, field_name):
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
# The implementation is similar to `datetime_trunc_sql` as both
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == 'minute':
Expand Down
8 changes: 5 additions & 3 deletions django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def date_extract_sql(self, lookup_type, field_name):
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)

def date_trunc_sql(self, lookup_type, field_name):
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)

Expand All @@ -50,7 +51,7 @@ def _prepare_tzname_delta(self, tzname):
return tzname

def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ:
if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
return field_name

Expand All @@ -71,7 +72,8 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)

def time_trunc_sql(self, lookup_type, field_name):
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)

def deferrable_sql(self):
Expand Down
22 changes: 13 additions & 9 deletions django/db/backends/sqlite3/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,13 @@ def get_new_connection(self, conn_params):
else:
create_deterministic_function = conn.create_function
create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
create_deterministic_function('django_date_trunc', 2, _sqlite_date_trunc)
create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
create_deterministic_function('django_time_trunc', 2, _sqlite_time_trunc)
create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
Expand Down Expand Up @@ -445,8 +445,8 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
return dt


def _sqlite_date_trunc(lookup_type, dt):
dt = _sqlite_datetime_parse(dt)
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == 'year':
Expand All @@ -463,13 +463,17 @@ def _sqlite_date_trunc(lookup_type, dt):
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)


def _sqlite_time_trunc(lookup_type, dt):
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
if dt is None:
return None
try:
dt = backend_utils.typecast_time(dt)
except (ValueError, TypeError):
return None
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt_parsed is None:
try:
dt = backend_utils.typecast_time(dt)
except (ValueError, TypeError):
return None
else:
dt = dt_parsed
if lookup_type == 'hour':
return "%02i:00:00" % dt.hour
elif lookup_type == 'minute':
Expand Down
18 changes: 13 additions & 5 deletions django/db/backends/sqlite3/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,22 @@ def format_for_duration_arithmetic(self, sql):
"""Do nothing since formatting is handled in the custom function."""
return sql

def date_trunc_sql(self, lookup_type, field_name):
return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_date_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
)

def time_trunc_sql(self, lookup_type, field_name):
return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_time_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
)

def _convert_tznames_to_sql(self, tzname):
if settings.USE_TZ:
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
return 'NULL', 'NULL'

Expand Down
10 changes: 7 additions & 3 deletions django/db/models/functions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,17 @@ def __init__(self, expression, output_field=None, tzinfo=None, is_dst=None, **ex

def as_sql(self, compiler, connection):
inner_sql, inner_params = compiler.compile(self.lhs)
if isinstance(self.output_field, DateTimeField):
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError('tzinfo can only be used with DateTimeField.')
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql)
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql)
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
else:
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
return sql, inner_params
Expand Down
4 changes: 4 additions & 0 deletions docs/releases/3.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,10 @@ backends.
* ``DatabaseOperations.random_function_sql()`` is removed in favor of the new
:class:`~django.db.models.functions.Random` database function.

* ``DatabaseOperations.date_trunc_sql()`` and
``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname``
argument in order to truncate in a specific timezone.

:mod:`django.contrib.admin`
---------------------------

Expand Down
81 changes: 81 additions & 0 deletions tests/db_functions/datetime/test_extract_trunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,18 @@ def test_time_kind(kind):
lambda m: (m.start_datetime, m.truncated)
)

def test_datetime_to_time_kind(kind):
self.assertQuerysetEqual(
DTModel.objects.annotate(
truncated=Trunc('start_datetime', kind, output_field=TimeField()),
).order_by('start_datetime'),
[
(start_datetime, truncate_to(start_datetime.time(), kind)),
(end_datetime, truncate_to(end_datetime.time(), kind)),
],
lambda m: (m.start_datetime, m.truncated),
)

test_date_kind('year')
test_date_kind('quarter')
test_date_kind('month')
Expand All @@ -688,6 +700,9 @@ def test_time_kind(kind):
test_datetime_kind('hour')
test_datetime_kind('minute')
test_datetime_kind('second')
test_datetime_to_time_kind('hour')
test_datetime_to_time_kind('minute')
test_datetime_to_time_kind('second')

qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
self.assertEqual(qs.count(), 2)
Expand Down Expand Up @@ -1205,6 +1220,60 @@ def test_datetime_kind(kind):
lambda m: (m.start_datetime, m.truncated)
)

def test_datetime_to_date_kind(kind):
self.assertQuerysetEqual(
DTModel.objects.annotate(
truncated=Trunc(
'start_datetime',
kind,
output_field=DateField(),
tzinfo=melb,
),
).order_by('start_datetime'),
[
(
start_datetime,
truncate_to(start_datetime.astimezone(melb).date(), kind),
),
(
end_datetime,
truncate_to(end_datetime.astimezone(melb).date(), kind),
),
],
lambda m: (m.start_datetime, m.truncated),
)

def test_datetime_to_time_kind(kind):
self.assertQuerysetEqual(
DTModel.objects.annotate(
truncated=Trunc(
'start_datetime',
kind,
output_field=TimeField(),
tzinfo=melb,
)
).order_by('start_datetime'),
[
(
start_datetime,
truncate_to(start_datetime.astimezone(melb).time(), kind),
),
(
end_datetime,
truncate_to(end_datetime.astimezone(melb).time(), kind),
),
],
lambda m: (m.start_datetime, m.truncated),
)

test_datetime_to_date_kind('year')
test_datetime_to_date_kind('quarter')
test_datetime_to_date_kind('month')
test_datetime_to_date_kind('week')
test_datetime_to_date_kind('day')
test_datetime_to_time_kind('hour')
test_datetime_to_time_kind('minute')
test_datetime_to_time_kind('second')
test_datetime_kind('year')
test_datetime_kind('quarter')
test_datetime_kind('month')
Expand All @@ -1216,3 +1285,15 @@ def test_datetime_kind(kind):

qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
self.assertEqual(qs.count(), 2)

def test_trunc_invalid_field_with_timezone(self):
melb = pytz.timezone('Australia/Melbourne')
msg = 'tzinfo can only be used with DateTimeField.'
with self.assertRaisesMessage(ValueError, msg):
DTModel.objects.annotate(
day_melb=Trunc('start_date', 'day', tzinfo=melb),
).get()
with self.assertRaisesMessage(ValueError, msg):
DTModel.objects.annotate(
hour_melb=Trunc('start_time', 'hour', tzinfo=melb),
).get()

0 comments on commit ee00532

Please sign in to comment.