make the offload function customizable for DatasetFromList
Summary:
Pull Request resolved: facebookresearch#4626

Previously we used `serialize: bool` to control whether to offload the `DatasetFromList` storage to numpy. This diff generalizes "serialize" to "offload" and makes the "offload function" customizable so that we can switch between implementations.

The "offload function" is set via a context manager to avoid passing this argument all the way down the call stack.
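
For illustration, a minimal sketch of both customization points (using `torch.tensor` as a stand-in for an arbitrary offload callable):

```python
import torch

from detectron2.data.common import (
    DatasetFromList,
    set_default_dataset_from_list_serialize_method,
)

# Per-instance: pass the offload callable directly.
ds = DatasetFromList([1, 2, 3], serialize=torch.tensor)

# Process-wide: swap the default offload implementation for every
# DatasetFromList created inside the context, without threading the
# argument through intermediate callers.
with set_default_dataset_from_list_serialize_method(torch.tensor):
    ds = DatasetFromList([1, 2, 3], serialize=True)
```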

Reviewed By: sstsai-adl

Differential Revision: D40818736

fbshipit-source-id: ed1b47eea86546def6c06f78bc12d6edf267df28
Yanghan Wang authored and facebook-github-bot committed Nov 3, 2022
1 parent c54429b commit 2b98c27
Showing 2 changed files with 95 additions and 30 deletions.
107 changes: 78 additions & 29 deletions detectron2/data/common.py
@@ -1,17 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import contextlib
import copy
import itertools
import logging
import numpy as np
import pickle
import random
from typing import Callable, Union
import torch.utils.data as data
from torch.utils.data.sampler import Sampler

from detectron2.utils.serialize import PicklableWrapper

__all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]

logger = logging.getLogger(__name__)


def _shard_iterator_dataloader_worker(iterable):
# Shard the iterable if we're currently inside pytorch dataloader worker.
@@ -106,56 +110,101 @@ def __getitem__(self, idx):
)


class NumpySerializedList(object):
"""
A list-like object whose items are serialized and stored in a Numpy Array. When
forking a process that has NumpySerializedList, subprocesses can read the same list
without triggering copy-on-access, so they will share RAM for the list. This
avoids the issue in https://github.com/pytorch/pytorch/issues/13246
"""

def __init__(self, lst: list):
self._lst = lst

def _serialize(data):
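# Pickle with the highest available protocol (-1) and expose the
# resulting bytes as a flat numpy uint8 array.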
buffer = pickle.dumps(data, protocol=-1)
return np.frombuffer(buffer, dtype=np.uint8)

logger.info(
"Serializing {} elements to byte tensors and concatenating them all ...".format(
len(self._lst)
)
)
self._lst = [_serialize(x) for x in self._lst]
self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
self._addr = np.cumsum(self._addr)
self._lst = np.concatenate(self._lst)
logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))

def __len__(self):
return len(self._addr)

def __getitem__(self, idx):
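# self._addr holds cumulative end offsets, so the bytes for item `idx`
# occupy the slice [addr[idx - 1], addr[idx]) of the flat buffer.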
start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
end_addr = self._addr[idx].item()
bytes = memoryview(self._lst[start_addr:end_addr])

# @lint-ignore PYTHONPICKLEISBAD
return pickle.loads(bytes)
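
# Usage sketch (hypothetical): the round-trip restores arbitrary
# picklable items from the shared flat buffer.
#   items = NumpySerializedList([{"a": 1}, [2, 3]])
#   assert len(items) == 2 and items[1] == [2, 3]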


_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = NumpySerializedList


@contextlib.contextmanager
def set_default_dataset_from_list_serialize_method(new):
"""
Context manager for using a custom serialize function when creating DatasetFromList
"""

global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
try:
    yield
finally:
    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig


class DatasetFromList(data.Dataset):
"""
Wrap a list into a torch Dataset. It produces elements of the list as data.
"""

def __init__(self, lst: list, copy: bool = True, serialize: bool = True):
def __init__(
self,
lst: list,
copy: bool = True,
serialize: Union[bool, Callable] = True,
):
"""
Args:
lst (list): a list which contains elements to produce.
copy (bool): whether to deepcopy the element when producing it,
so that the result can be modified in place without affecting the
source in the list.
serialize (bool): whether to hold memory using serialized objects, when
enabled, data loader workers can use shared RAM from master
process instead of making a copy.
serialize (bool or callable): whether to serialize the storage to another
backend. If `True`, the default serialize method is used; if given
a callable, the callable is used as the serialize method.
"""
self._lst = lst
self._copy = copy
self._serialize = serialize

def _serialize(data):
buffer = pickle.dumps(data, protocol=-1)
return np.frombuffer(buffer, dtype=np.uint8)
if not isinstance(serialize, (bool, Callable)):
raise TypeError(f"Unsupported type for argument `serailzie`: {serialize}")
self._serialize = serialize is not False

if self._serialize:
logger = logging.getLogger(__name__)
logger.info(
"Serializing {} elements to byte tensors and concatenating them all ...".format(
len(self._lst)
)
serialize_method = (
serialize
if isinstance(serialize, Callable)
else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
)
self._lst = [_serialize(x) for x in self._lst]
self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
self._addr = np.cumsum(self._addr)
self._lst = np.concatenate(self._lst)
logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
logger.info(f"Serializing the dataset using: {serialize_method}")
self._lst = serialize_method(self._lst)

def __len__(self):
if self._serialize:
return len(self._addr)
else:
return len(self._lst)
return len(self._lst)

def __getitem__(self, idx):
if self._serialize:
start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
end_addr = self._addr[idx].item()
bytes = memoryview(self._lst[start_addr:end_addr])
return pickle.loads(bytes)
elif self._copy:
if self._copy and not self._serialize:
return copy.deepcopy(self._lst[idx])
else:
return self._lst[idx]
18 changes: 17 additions & 1 deletion tests/data/test_dataset.py
@@ -19,7 +19,10 @@
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.data.common import AspectRatioGroupedDataset
from detectron2.data.common import (
AspectRatioGroupedDataset,
set_default_dataset_from_list_serialize_method,
)
from detectron2.data.samplers import InferenceSampler, TrainingSampler


@@ -41,6 +44,19 @@ def test_using_lazy_path(self):
self.assertTrue(isinstance(path, LazyPath))
self.assertEqual(os.fspath(path), _a_slow_func(i))

def test_alternative_serialize_method(self):
dataset = [1, 2, 3]
dataset = DatasetFromList(dataset, serialize=torch.tensor)
self.assertEqual(dataset[2], torch.tensor(3))

def test_change_default_serialize_method(self):
dataset = [1, 2, 3]
with set_default_dataset_from_list_serialize_method(torch.tensor):
dataset_1 = DatasetFromList(dataset, serialize=True)
self.assertEqual(dataset_1[2], torch.tensor(3))
dataset_2 = DatasetFromList(dataset, serialize=True)
self.assertEqual(dataset_2[2], 3)


class TestMapDataset(unittest.TestCase):
@staticmethod
