Skip to content

Commit

Permalink
TaskSerializer extension support (HumanSignal#1678)
Browse files Browse the repository at this point in the history
  • Loading branch information
triklozoid authored Nov 3, 2021
1 parent f54e67f commit 7b897cb
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 32 deletions.
2 changes: 2 additions & 0 deletions label_studio/core/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@
GET_OBJECT_WITH_CHECK_AND_LOG = 'core.utils.get_object.get_object_with_check_and_log'
SAVE_USER = 'users.functions.save_user'
USER_SERIALIZER = 'users.serializers.BaseUserSerializer'
TASK_SERIALIZER = 'tasks.serializers.BaseTaskSerializer'
EXPORT_DATA_SERIALIZER = 'data_export.serializers.BaseExportDataSerializer'
DATA_MANAGER_GET_ALL_COLUMNS = 'data_manager.functions.get_all_columns'
DATA_MANAGER_ANNOTATIONS_MAP = {}
DATA_MANAGER_ACTIONS = {}
Expand Down
7 changes: 6 additions & 1 deletion label_studio/data_export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ class ExportAPI(generics.RetrieveAPIView):
def get_queryset(self):
return Project.objects.filter(organization=self.request.user.active_organization)

def get_task_queryset(self, queryset):
return queryset

def get(self, request, *args, **kwargs):
project = self.get_object()
export_type = (
Expand Down Expand Up @@ -170,7 +173,9 @@ def get(self, request, *args, **kwargs):
logger.debug('Serialize tasks for export')
tasks = []
for _task_ids in batch(task_ids, 1000):
tasks += ExportDataSerializer(query.filter(id__in=_task_ids), many=True, expand=['drafts']).data
tasks += ExportDataSerializer(
self.get_task_queryset(query.filter(id__in=_task_ids)), many=True, expand=['drafts']
).data
logger.debug('Prepare export files')

export_stream, content_type, filename = DataExport.generate_export_file(
Expand Down
50 changes: 24 additions & 26 deletions label_studio/data_export/mixins.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
from datetime import datetime
from functools import reduce
import hashlib
import io
import json
import logging
import pathlib
from datetime import datetime
from functools import reduce
import shutil

from core.utils.io import get_all_files_from_dir, get_temp_dir, read_bytes_stream
from data_manager.models import View
from django.core.files import File
from django.db.models import Prefetch
from django.db import transaction
from django.db.models import Prefetch
from django.db.models.query_utils import Q
from django.utils import dateformat, timezone
import django_rq
from label_studio_converter import Converter
from projects.models import Project
from tasks.models import Task, Annotation
from core.utils.common import batch

from core.redis import redis_connected
import django_rq
from core.utils.common import batch
from core.utils.io import get_all_files_from_dir, get_temp_dir, read_bytes_stream
from data_manager.models import View
from projects.models import Project
from tasks.models import Annotation, Task


ONLY = 'only'
EXCLUDE = 'exclude'
Expand Down Expand Up @@ -128,10 +129,16 @@ def _get_export_serializer_option(self, serialization_options):

return options

def get_serializer_class(self):
from .serializers import ExportDataSerializer

return ExportDataSerializer
def get_task_queryset(self, ids, annotation_filter_options):
annotations_qs = self._get_filtered_annotations_queryset(
annotation_filter_options=annotation_filter_options
)
return Task.objects.filter(id__in=ids).prefetch_related(
Prefetch(
"annotations",
queryset=annotations_qs,
)
)

def get_export_data(self, task_filter_options=None, annotation_filter_options=None, serialization_options=None):
"""
Expand All @@ -156,6 +163,8 @@ def get_export_data(self, task_filter_options=None, annotation_filter_options=No
})
})
"""
from .serializers import ExportDataSerializer

with transaction.atomic():
# TODO: make counters from queryset
# counters = Project.objects.with_counts().filter(id=self.project.id)[0].get_counters()
Expand All @@ -173,25 +182,14 @@ def get_export_data(self, task_filter_options=None, annotation_filter_options=No
base_export_serializer_option = self._get_export_serializer_option(serialization_options)
i = 0
BATCH_SIZE = 1000
serializer_class = self.get_serializer_class()
annotations_qs = self._get_filtered_annotations_queryset(
annotation_filter_options=annotation_filter_options
)
for ids in batch(task_ids, BATCH_SIZE):
i += 1
tasks = list(
Task.objects.filter(id__in=ids).prefetch_related(
Prefetch(
"annotations",
queryset=annotations_qs,
)
)
)
tasks = list(self.get_task_queryset(ids, annotation_filter_options))
logger.debug(f'Batch: {i*BATCH_SIZE}')
if isinstance(task_filter_options, dict) and task_filter_options.get('only_with_annotations'):
tasks = [task for task in tasks if task.annotations.all()]

serializer = serializer_class(tasks, many=True, **base_export_serializer_option)
serializer = ExportDataSerializer(tasks, many=True, **base_export_serializer_option)
result += serializer.data

counters['task_number'] = len(result)
Expand Down
11 changes: 8 additions & 3 deletions label_studio/data_export/serializers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
from django.conf import settings
from rest_flex_fields import FlexFieldsModelSerializer
from rest_framework import serializers

from core.label_config import replace_task_data_undefined_with_config_field
from rest_flex_fields import FlexFieldsModelSerializer
from core.utils.common import load_func
from tasks.models import Annotation, Task
from tasks.serializers import PredictionSerializer, AnnotationDraftSerializer
from tasks.serializers import AnnotationDraftSerializer, PredictionSerializer
from users.models import User
from users.serializers import UserSimpleSerializer

Expand All @@ -29,7 +31,7 @@ class Meta:
}


class ExportDataSerializer(FlexFieldsModelSerializer):
class BaseExportDataSerializer(FlexFieldsModelSerializer):
annotations = AnnotationSerializer(many=True, read_only=True)
file_upload = serializers.ReadOnlyField(source='file_upload_name')
drafts = serializers.PrimaryKeyRelatedField(many=True, read_only=True)
Expand Down Expand Up @@ -108,3 +110,6 @@ class Meta(ExportSerializer.Meta):
task_filter_options = TaskFilterOptionsSerializer(required=False, default=None)
annotation_filter_options = AnnotationFilterOptionsSerializer(required=False, default=None)
serialization_options = SerializationOptionsSerializer(required=False, default=None)


ExportDataSerializer = load_func(settings.EXPORT_DATA_SERIALIZER)
7 changes: 5 additions & 2 deletions label_studio/tasks/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from projects.models import Project
from tasks.models import Task, Annotation, AnnotationDraft, Prediction
from tasks.validation import TaskValidator
from core.utils.common import get_object_with_check_and_log, retry_database_locked
from core.utils.common import get_object_with_check_and_log, retry_database_locked, load_func
from core.label_config import replace_task_data_undefined_with_config_field
from users.serializers import UserSerializer
from core.utils.common import load_func
Expand Down Expand Up @@ -124,7 +124,7 @@ class Meta:
fields = '__all__'


class TaskSerializer(ModelSerializer):
class BaseTaskSerializer(ModelSerializer):
""" Task Serializer with project scheme configs validation
"""
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -382,8 +382,11 @@ def post_process_annotations(db_annotations):
class Meta:
model = Task
fields = "__all__"


TaskSerializer = load_func(settings.TASK_SERIALIZER)


class TaskWithAnnotationsSerializer(TaskSerializer):
"""
"""
Expand Down

0 comments on commit 7b897cb

Please sign in to comment.