From aa1a4e306f911bb935965a037a5eb88eb6aa84d9 Mon Sep 17 00:00:00 2001 From: KY HU Date: Wed, 16 Mar 2022 16:55:19 +0800 Subject: [PATCH] Finish functional API ops draft (#835) * Finish functional API ops draft Signed-off-by: Kaiyuan Hu * Add test Signed-off-by: Kaiyuan Hu --- .../unittests/pipelines/test_pipeline_ops.py | 9 ++++- towhee/__init__.py | 33 +++++++++++-------- towhee/functional/entity.py | 3 ++ towhee/pipelines/alias_resolvers.py | 8 +++-- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/tests/unittests/pipelines/test_pipeline_ops.py b/tests/unittests/pipelines/test_pipeline_ops.py index 37c14913c3..67e9fc52b2 100644 --- a/tests/unittests/pipelines/test_pipeline_ops.py +++ b/tests/unittests/pipelines/test_pipeline_ops.py @@ -15,13 +15,13 @@ import unittest from towhee import ops +from towhee.functional.entity import Entity class TestPipelineOps(unittest.TestCase): """ tests for template build """ - def test_ops(self): # pylint: disable=protected-access op1 = ops.my.op1(arg1=1, arg2=2) @@ -34,9 +34,16 @@ def test_repo_op(self): res = test_op(1) self.assertEqual(res, 2) + e = Entity({'in': 1}) + test_op = ops.towhee.test_operator['in', 'out'](x=1) + res = test_op(e) + self.assertIsInstance(res, Entity) + self.assertEqual(res.out, 2) + # def test_image_embedding_pipeline(self): # pipe = image_embedding_pipeline(models = "xxx", ensemble = ops.my.ensemble_v1(agg='xxx', ....)) # pipe = image_embedding_pipeline(operators = [ops.my.embedding(model='xxx'), ops.my.embedding(model='xxx')]) + if __name__ == '__main__': unittest.main() diff --git a/towhee/__init__.py b/towhee/__init__.py index 0a4f4f548b..3c98e763a0 100644 --- a/towhee/__init__.py +++ b/towhee/__init__.py @@ -23,8 +23,7 @@ from towhee.engine.operator_loader import OperatorLoader from towhee.hparam import param_scope, auto_param -__all__ = ['DEFAULT_PIPELINES', 'pipeline', 'register', 'resolve', 'param_scope', 'auto_param', 'Build', 'Inject', - 'dataset'] +__all__ = ['DEFAULT_PIPELINES', 'pipeline', 'register', 'resolve', 'param_scope', 'auto_param', 'Build', 'Inject', 'dataset'] DEFAULT_PIPELINES = { 'image-embedding': 'towhee/image-embedding-resnet50', @@ -171,9 +170,7 @@ def dataset(name: str, *args, **kwargs) -> 'TorchDataSet': from torchvision import datasets # pylint: disable=import-outside-toplevel from towhee.data.dataset.dataset import TorchDataSet # pylint: disable=import-outside-toplevel dataset_construct_map = { - 'mnist': datasets.MNIST, - 'cifar10': datasets.cifar.CIFAR10, - 'fake': datasets.FakeData + 'mnist': datasets.MNIST, 'cifar10': datasets.cifar.CIFAR10, 'fake': datasets.FakeData # 'imdb': IMDB # ,() } torch_dataset = dataset_construct_map[name](*args, **kwargs) @@ -250,8 +247,9 @@ class _OperatorLazyWrapper: """ operator wrapper for lazy initialization. """ - def __init__(self, name: str, tag: str='main', **kws) -> None: + def __init__(self, name: str, index: Tuple[str], tag: str = 'main', **kws) -> None: self._name = name.replace('.', '/').replace('_', '-') + self._index = index self._tag = tag self._kws = kws self._op = None @@ -261,7 +259,15 @@ def __call__(self, *arg, **kws): with self._lock: if self._op is None: self._op = op(self._name, self._tag, **self._kws) - return self._op(*arg, **kws) + + if bool(self._index): + res = self._op(getattr(arg[0], self._index[0]), **kws) + setattr(arg[0], self._index[1], res) + arg[0].register(self._index[1]) + return arg[0] + else: + res = self._op(*arg, **kws) + return res def train(self, *arg, **kws): with self._lock: @@ -278,12 +284,11 @@ def init_args(self): return self._kws @staticmethod - def callback(name, index, *arg, **kws): - _ = index + def callback(name: str, index: Tuple[str], *arg, **kws): if len(arg) == 0: - return _OperatorLazyWrapper(name, **kws) + return _OperatorLazyWrapper(name, index, **kws) else: - return _OperatorLazyWrapper(name, arg[0], **kws) + return _OperatorLazyWrapper(name, index, arg[0], **kws) ops = param_scope().callholder(_OperatorLazyWrapper.callback) @@ -295,24 +300,26 @@ def callback(name, index, *arg, **kws): An instance of `my_namespace`/`my_operator_name` is created. """ + def _pipeline_callback(name, index, *arg, **kws): name = name.replace('.', '/').replace('_', '-') _ = index return Build(**kws).pipeline(name, *arg) + pipes = param_scope().callholder(_pipeline_callback) """ Entry point for creating pipeline instances, for example: >>> pipe_instance = pipes.my_namespace.my_pipeline_name(template_variable_1=xxx, template_variable_2=xxx) -An instance of `my_namespace`/`my_pipeline_name` is created, and template variables in the pipeline, +An instance of `my_namespace`/`my_pipeline_name` is created, and template variables in the pipeline,z `template_variable_1` and `template_variable_2` are replaced with given values. """ def plot(img1: Union[str, list], img2: list = None): - from towhee.utils.plot_utils import plot_img # pylint: disable=C + from towhee.utils.plot_utils import plot_img # pylint: disable=C if not img2: plot_img(img1) else: diff --git a/towhee/functional/entity.py b/towhee/functional/entity.py index 23a35c90ca..cfdc214b9f 100644 --- a/towhee/functional/entity.py +++ b/towhee/functional/entity.py @@ -63,3 +63,6 @@ def __repr__(self): content = str(self.info) content += f' at {getattr(self, "id", id(self))}' return f'<{self.__class__.__name__} {content.strip()}>' + + def register(self, index: str): + self._data.append(index) diff --git a/towhee/pipelines/alias_resolvers.py b/towhee/pipelines/alias_resolvers.py index 81d6787bd5..b5c3d8ac60 100644 --- a/towhee/pipelines/alias_resolvers.py +++ b/towhee/pipelines/alias_resolvers.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 from towhee import ops + + class AliasResolverBase: """ Base class for alias resolvers @@ -7,6 +9,7 @@ class AliasResolverBase: def resolve(self, name: str): pass + class LocalAliasResolver(AliasResolverBase): """ Resolve aliases with locally with builtin rules @@ -14,20 +17,21 @@ class LocalAliasResolver(AliasResolverBase): aliases = { 'efficientnet-b3': ops.filip_halt.timm_image_embedding(model_name='efficientnet_b3'), - 'regnety-004': ops.filip_halt.timm_image_embedding(model_name = 'regnety-004') + 'regnety-004': ops.filip_halt.timm_image_embedding(model_name='regnety-004') } def resolve(self, name: str): return LocalAliasResolver.aliases[name] + class RemoteAliasResolver(AliasResolverBase): """ Resolve aliases from towhee hub """ - def resolve(self, name: str): pass + def get_resolver(name: str) -> AliasResolverBase: if name == 'local': return LocalAliasResolver()