Skip to content

Commit

Permalink
Add logic for custom resolvers.
Browse files Browse the repository at this point in the history
Description:
This commit adds logic to use custom user resolvers instead of
using only default manager queryset.

Signed-off-by: Pavel Kirilin <[email protected]>
  • Loading branch information
s3rius committed Jun 28, 2021
1 parent b23c68f commit 0e8f630
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 109 deletions.
10 changes: 5 additions & 5 deletions graphene_django_extras/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def list_resolver(manager, filterset_class, filtering_args, root, info, **kwargs
qs = None

if qs is None:
qs = queryset_factory(manager, info.field_nodes, info.fragments, **kwargs)
qs = queryset_factory(manager, root, info, **kwargs)
qs = filterset_class(
data=filter_kwargs, queryset=qs, request=info.context
).qs
Expand Down Expand Up @@ -228,14 +228,14 @@ def __init__(
def model(self):
return self.type.of_type._meta.node._meta.model

def get_queryset(self, manager, info, **kwargs):
return queryset_factory(manager, info.field_nodes, info.fragments, **kwargs)
def get_queryset(self, manager, root, info, **kwargs):
return queryset_factory(manager, root, info, **kwargs)

def list_resolver(
self, manager, filterset_class, filtering_args, root, info, **kwargs
):
filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args}
qs = self.get_queryset(manager, info, **kwargs)
qs = self.get_queryset(manager, root, info, **kwargs)
qs = filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs

if root and is_valid_django_model(root._meta.model):
Expand Down Expand Up @@ -308,7 +308,7 @@ def list_resolver(
self, manager, filterset_class, filtering_args, root, info, **kwargs
):

qs = queryset_factory(manager, info.field_nodes, info.fragments, **kwargs)
qs = queryset_factory(manager, root, info, **kwargs)

filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args}

Expand Down
6 changes: 2 additions & 4 deletions graphene_django_extras/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def BaseType(cls):

class DjangoSerializerType(ObjectType):
"""
DjangoSerializerType definition
DjangoSerializerType definition
"""

ok = Boolean(description="Boolean field that return mutation result request.")
Expand Down Expand Up @@ -611,9 +611,7 @@ def retrieve(cls, manager, root, info, **kwargs):
@classmethod
def list(cls, manager, filterset_class, filtering_args, root, info, **kwargs):

qs = queryset_factory(
cls._meta.queryset or manager, info.field_nodes, info.fragments, **kwargs
)
qs = queryset_factory(cls._meta.queryset or manager, root, info, **kwargs)

filter_kwargs = {k: v for k, v in kwargs.items() if k in filtering_args}

Expand Down
24 changes: 21 additions & 3 deletions graphene_django_extras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def create_obj(django_model, new_obj_key=None, *args, **kwargs):

def clean_dict(d):
"""
Remove all empty fields in a nested dict
Remove all empty fields in a nested dict
"""

if not isinstance(d, (dict, list)):
Expand Down Expand Up @@ -246,6 +246,19 @@ def _get_queryset(klass):
return manager.all()


def _get_custom_resolver(info):
"""
Get custom user defined resolver for query.
This resolver must return QuerySet instance to be successfully resolved.
"""
parent = info.parent_type
custom_resolver_name = f"resolve_{to_snake_case(info.field_name)}"
if hasattr(parent.graphene_type, custom_resolver_name):
return getattr(parent.graphene_type, custom_resolver_name)
return None


def get_Object_or_None(klass, *args, **kwargs):
"""
Uses get() to return an object, or None if the object does not exist.
Expand Down Expand Up @@ -348,7 +361,7 @@ def recursive_params(
return select_related, prefetch_related


def queryset_factory(manager, fields_asts=None, fragments=None, **kwargs):
def queryset_factory(manager, root, info, **kwargs):

select_related = set()
prefetch_related = set()
Expand All @@ -365,15 +378,20 @@ def queryset_factory(manager, fields_asts=None, fragments=None, **kwargs):
select_related = list(select_related)
prefetch_related = list(prefetch_related)

fields_asts = info.field_nodes
if fields_asts:
select_related, prefetch_related = recursive_params(
fields_asts[0].selection_set,
fragments,
info.fragments,
available_related_fields,
select_related,
prefetch_related,
)

custom_resolver = _get_custom_resolver(info)
if custom_resolver is not None:
manager = custom_resolver(root, info, **kwargs)

if select_related and prefetch_related:
return _get_queryset(
manager.select_related(*select_related).prefetch_related(*prefetch_related)
Expand Down
Loading

0 comments on commit 0e8f630

Please sign in to comment.