Skip to content

Commit

Permalink
Finish functional API ops draft (towhee-io#835)
Browse files Browse the repository at this point in the history
* Finish functional API ops draft

Signed-off-by: Kaiyuan Hu <[email protected]>

* Add test

Signed-off-by: Kaiyuan Hu <[email protected]>
  • Loading branch information
Chiiizzzy authored Mar 16, 2022
1 parent a046754 commit aa1a4e3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 16 deletions.
9 changes: 8 additions & 1 deletion tests/unittests/pipelines/test_pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import unittest

from towhee import ops
from towhee.functional.entity import Entity


class TestPipelineOps(unittest.TestCase):
"""
tests for template build
"""

def test_ops(self):
# pylint: disable=protected-access
op1 = ops.my.op1(arg1=1, arg2=2)
Expand All @@ -34,9 +34,16 @@ def test_repo_op(self):
res = test_op(1)
self.assertEqual(res, 2)

e = Entity({'in': 1})
test_op = ops.towhee.test_operator['in', 'out'](x=1)
res = test_op(e)
self.assertIsInstance(res, Entity)
self.assertEqual(res.out, 2)

# def test_image_embedding_pipeline(self):
# pipe = image_embedding_pipeline(models = "xxx", ensemble = ops.my.ensemble_v1(agg='xxx', ....))
# pipe = image_embedding_pipeline(operators = [ops.my.embedding(model='xxx'), ops.my.embedding(model='xxx')])


if __name__ == '__main__':
unittest.main()
33 changes: 20 additions & 13 deletions towhee/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from towhee.engine.operator_loader import OperatorLoader
from towhee.hparam import param_scope, auto_param

__all__ = ['DEFAULT_PIPELINES', 'pipeline', 'register', 'resolve', 'param_scope', 'auto_param', 'Build', 'Inject',
'dataset']
__all__ = ['DEFAULT_PIPELINES', 'pipeline', 'register', 'resolve', 'param_scope', 'auto_param', 'Build', 'Inject', 'dataset']

DEFAULT_PIPELINES = {
'image-embedding': 'towhee/image-embedding-resnet50',
Expand Down Expand Up @@ -171,9 +170,7 @@ def dataset(name: str, *args, **kwargs) -> 'TorchDataSet':
from torchvision import datasets # pylint: disable=import-outside-toplevel
from towhee.data.dataset.dataset import TorchDataSet # pylint: disable=import-outside-toplevel
dataset_construct_map = {
'mnist': datasets.MNIST,
'cifar10': datasets.cifar.CIFAR10,
'fake': datasets.FakeData
'mnist': datasets.MNIST, 'cifar10': datasets.cifar.CIFAR10, 'fake': datasets.FakeData
# 'imdb': IMDB # ,()
}
torch_dataset = dataset_construct_map[name](*args, **kwargs)
Expand Down Expand Up @@ -250,8 +247,9 @@ class _OperatorLazyWrapper:
"""
operator wrapper for lazy initialization.
"""
def __init__(self, name: str, tag: str='main', **kws) -> None:
def __init__(self, name: str, index: Tuple[str], tag: str = 'main', **kws) -> None:
self._name = name.replace('.', '/').replace('_', '-')
self._index = index
self._tag = tag
self._kws = kws
self._op = None
Expand All @@ -261,7 +259,15 @@ def __call__(self, *arg, **kws):
with self._lock:
if self._op is None:
self._op = op(self._name, self._tag, **self._kws)
return self._op(*arg, **kws)

if bool(self._index):
res = self._op(getattr(arg[0], self._index[0]), **kws)
setattr(arg[0], self._index[1], res)
arg[0].register(self._index[1])
return arg[0]
else:
res = self._op(*arg, **kws)
return res

def train(self, *arg, **kws):
with self._lock:
Expand All @@ -278,12 +284,11 @@ def init_args(self):
return self._kws

@staticmethod
def callback(name, index, *arg, **kws):
_ = index
def callback(name: str, index: Tuple[str], *arg, **kws):
if len(arg) == 0:
return _OperatorLazyWrapper(name, **kws)
return _OperatorLazyWrapper(name, index, **kws)
else:
return _OperatorLazyWrapper(name, arg[0], **kws)
return _OperatorLazyWrapper(name, index, arg[0], **kws)


ops = param_scope().callholder(_OperatorLazyWrapper.callback)
Expand All @@ -295,24 +300,26 @@ def callback(name, index, *arg, **kws):
An instance of `my_namespace`/`my_operator_name` is created.
"""


def _pipeline_callback(name, index, *arg, **kws):
name = name.replace('.', '/').replace('_', '-')
_ = index
return Build(**kws).pipeline(name, *arg)


pipes = param_scope().callholder(_pipeline_callback)
"""
Entry point for creating pipeline instances, for example:
>>> pipe_instance = pipes.my_namespace.my_pipeline_name(template_variable_1=xxx, template_variable_2=xxx)
An instance of `my_namespace`/`my_pipeline_name` is created, and template variables in the pipeline,
An instance of `my_namespace`/`my_pipeline_name` is created, and template variables in the pipeline,z
`template_variable_1` and `template_variable_2` are replaced with given values.
"""


def plot(img1: Union[str, list], img2: list = None):
from towhee.utils.plot_utils import plot_img # pylint: disable=C
from towhee.utils.plot_utils import plot_img # pylint: disable=C
if not img2:
plot_img(img1)
else:
Expand Down
3 changes: 3 additions & 0 deletions towhee/functional/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ def __repr__(self):
content = str(self.info)
content += f' at {getattr(self, "id", id(self))}'
return f'<{self.__class__.__name__} {content.strip()}>'

def register(self, index: str):
self._data.append(index)
8 changes: 6 additions & 2 deletions towhee/pipelines/alias_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
#!/usr/bin/env python3
from towhee import ops


class AliasResolverBase:
"""
Base class for alias resolvers
"""
def resolve(self, name: str):
pass


class LocalAliasResolver(AliasResolverBase):
"""
Resolve aliases with locally with builtin rules
"""

aliases = {
'efficientnet-b3': ops.filip_halt.timm_image_embedding(model_name='efficientnet_b3'),
'regnety-004': ops.filip_halt.timm_image_embedding(model_name = 'regnety-004')
'regnety-004': ops.filip_halt.timm_image_embedding(model_name='regnety-004')
}

def resolve(self, name: str):
return LocalAliasResolver.aliases[name]


class RemoteAliasResolver(AliasResolverBase):
"""
Resolve aliases from towhee hub
"""

def resolve(self, name: str):
pass


def get_resolver(name: str) -> AliasResolverBase:
if name == 'local':
return LocalAliasResolver()
Expand Down

0 comments on commit aa1a4e3

Please sign in to comment.