From 83f6815dda2cc6114a5f575d06e2795d3afc5628 Mon Sep 17 00:00:00 2001 From: Shohan Dutta Roy Date: Mon, 19 Feb 2024 17:15:08 +0530 Subject: [PATCH 1/2] feat: Add distinct method to FakeQuerySet --- modelcluster/queryset.py | 16 ++++++++++++++++ tests/tests/test_cluster.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/modelcluster/queryset.py b/modelcluster/queryset.py index bd021ac..ffec014 100644 --- a/modelcluster/queryset.py +++ b/modelcluster/queryset.py @@ -3,6 +3,7 @@ import datetime import re +from django.db import NotSupportedError, connection from django.db.models import Model, prefetch_related_objects from modelcluster.utils import extract_field_value, get_model_field, sort_by_fields @@ -516,6 +517,21 @@ def order_by(self, *fields): clone = self.get_clone(results=self.results[:]) sort_by_fields(clone.results, fields) return clone + + def distinct(self, *fields): + if fields and connection.vendor != 'postgresql': + raise NotSupportedError("DISTINCT ON fields is not supported by this database backend") + + unique_results = [] + if not fields: + fields = [field.name for field in self.model._meta.fields if not field.primary_key] + seen_keys = set() + for result in self.results: + key = '$$$'.join([str(extract_field_value(result, field)) for field in fields]) + if key not in seen_keys: + seen_keys.add(key) + unique_results.append(result) + return self.get_clone(results=unique_results) # a standard QuerySet will store the results in _result_cache on running the query; # this is effectively the same as self.results on a FakeQuerySet, and so we'll make diff --git a/tests/tests/test_cluster.py b/tests/tests/test_cluster.py index 7abd3ad..eb27e1c 100644 --- a/tests/tests/test_cluster.py +++ b/tests/tests/test_cluster.py @@ -2,9 +2,10 @@ import datetime import itertools +from unittest.mock import patch from django.test import TestCase -from django.db import IntegrityError +from django.db import IntegrityError, NotSupportedError, connection from django.db.models import Prefetch from modelcluster.models import get_all_child_relations @@ -796,6 +797,36 @@ def test_meta_ordering(self): albums = [album.name for album in beatles.albums.all()] self.assertEqual(['With The Beatles', 'Please Please Me', 'Abbey Road'], albums) + def test_distinct_with_no_fields(self): + beatles = Band(name='The Beatles', albums=[ + Album(name='Please Please Me', sort_order=1), + Album(name='With The Beatles', sort_order=2), + Album(name='Abbey Road', sort_order=2), + ]) + + albums = [album.name for album in beatles.albums.order_by('sort_order').distinct()] + self.assertEqual(['Please Please Me', 'With The Beatles', 'Abbey Road'], albums) + + def test_distinct_with_fields(self): + beatles = Band(name='The Beatles', albums=[ + Album(name='Please Please Me', sort_order=1), + Album(name='With The Beatles', sort_order=2), + Album(name='Abbey Road', sort_order=2), + ]) + + for vendor in ['sqlite', 'mysql', 'oracle']: + with patch.object(connection, 'vendor', vendor): + with self.assertRaises(NotSupportedError): + beatles.albums.order_by('sort_order').distinct('sort_order') + + # patch db.connection.vendor to pass the vendor check + with patch.object(connection, 'vendor', 'postgresql'): + albums = [album.name for album in beatles.albums.order_by('sort_order').distinct('sort_order')] + self.assertEqual(['Please Please Me', 'With The Beatles'], albums) + + albums = [album.name for album in beatles.albums.order_by('sort_order').distinct('name')] + self.assertEqual(['Please Please Me', 'With The Beatles', 'Abbey Road'], albums) + def test_parental_key_checks_clusterable_model(self): from django.core import checks from django.db import models From a58451a82085cc2cad37f3108d135ef8cddbae9a Mon Sep 17 00:00:00 2001 From: Shohan Dutta Roy Date: Fri, 23 Feb 2024 14:13:53 +0530 Subject: [PATCH 2/2] feat: Remove check for postgres database for distinct with fields list --- modelcluster/queryset.py | 6 +----- tests/tests/test_cluster.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/modelcluster/queryset.py b/modelcluster/queryset.py index ffec014..00ffade 100644 --- a/modelcluster/queryset.py +++ b/modelcluster/queryset.py @@ -3,7 +3,6 @@ import datetime import re -from django.db import NotSupportedError, connection from django.db.models import Model, prefetch_related_objects from modelcluster.utils import extract_field_value, get_model_field, sort_by_fields @@ -519,15 +518,12 @@ def order_by(self, *fields): return clone def distinct(self, *fields): - if fields and connection.vendor != 'postgresql': - raise NotSupportedError("DISTINCT ON fields is not supported by this database backend") - unique_results = [] if not fields: fields = [field.name for field in self.model._meta.fields if not field.primary_key] seen_keys = set() for result in self.results: - key = '$$$'.join([str(extract_field_value(result, field)) for field in fields]) + key = tuple(str(extract_field_value(result, field)) for field in fields) if key not in seen_keys: seen_keys.add(key) unique_results.append(result) diff --git a/tests/tests/test_cluster.py b/tests/tests/test_cluster.py index eb27e1c..be55a25 100644 --- a/tests/tests/test_cluster.py +++ b/tests/tests/test_cluster.py @@ -2,10 +2,9 @@ import datetime import itertools -from unittest.mock import patch from django.test import TestCase -from django.db import IntegrityError, NotSupportedError, connection +from django.db import IntegrityError from django.db.models import Prefetch from modelcluster.models import get_all_child_relations @@ -813,19 +812,11 @@ def test_distinct_with_fields(self): Album(name='With The Beatles', sort_order=2), Album(name='Abbey Road', sort_order=2), ]) - - for vendor in ['sqlite', 'mysql', 'oracle']: - with patch.object(connection, 'vendor', vendor): - with self.assertRaises(NotSupportedError): - beatles.albums.order_by('sort_order').distinct('sort_order') - - # patch db.connection.vendor to pass the vendor check - with patch.object(connection, 'vendor', 'postgresql'): - albums = [album.name for album in beatles.albums.order_by('sort_order').distinct('sort_order')] - self.assertEqual(['Please Please Me', 'With The Beatles'], albums) + albums = [album.name for album in beatles.albums.order_by('sort_order').distinct('sort_order')] + self.assertEqual(['Please Please Me', 'With The Beatles'], albums) - albums = [album.name for album in beatles.albums.order_by('sort_order').distinct('name')] - self.assertEqual(['Please Please Me', 'With The Beatles', 'Abbey Road'], albums) + albums = [album.name for album in beatles.albums.order_by('sort_order').distinct('name')] + self.assertEqual(['Please Please Me', 'With The Beatles', 'Abbey Road'], albums) def test_parental_key_checks_clusterable_model(self): from django.core import checks