forked from deepfakes/faceswap
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path plugin_loader.py
105 lines (91 loc) · 4.4 KB
/
plugin_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
""" Plugin loader for extract, training and model tasks """
import logging
import os
from importlib import import_module
from lib.utils import get_backend
# Logger named after this module; lowercase is the project-wide convention, hence the waiver
logger = logging.getLogger(name=__name__)  # pylint: disable=invalid-name
class PluginLoader():
    """ Plugin loader for extract, training and model tasks

    All methods are static; the class only serves as a namespace. The ``get_<type>`` methods
    import a single plugin class from the ``plugins`` package. The ``get_available_*`` methods
    list the plugin modules found on disk in sub-folders of the directory holding this file.
    """

    @staticmethod
    def get_detector(name, disable_logging=False):
        """ Return requested detector plugin

        Parameters
        ----------
        name: str
            The plugin name. Hyphens are converted to underscores to obtain the module name
        disable_logging: bool, optional
            ``True`` to suppress the "Loading ..." info message. Default: ``False``

        Returns
        -------
        class
            The ``Detect`` class from ``plugins.extract.detect.<name>``
        """
        return PluginLoader._import("extract.detect", name, disable_logging)

    @staticmethod
    def get_aligner(name, disable_logging=False):
        """ Return requested aligner plugin

        Parameters
        ----------
        name: str
            The plugin name. Hyphens are converted to underscores to obtain the module name
        disable_logging: bool, optional
            ``True`` to suppress the "Loading ..." info message. Default: ``False``

        Returns
        -------
        class
            The ``Align`` class from ``plugins.extract.align.<name>``
        """
        return PluginLoader._import("extract.align", name, disable_logging)

    @staticmethod
    def get_model(name, disable_logging=False):
        """ Return requested model plugin

        Parameters
        ----------
        name: str
            The plugin name. Hyphens are converted to underscores to obtain the module name
        disable_logging: bool, optional
            ``True`` to suppress the "Loading ..." info message. Default: ``False``

        Returns
        -------
        class
            The ``Model`` class from ``plugins.train.model.<name>``
        """
        return PluginLoader._import("train.model", name, disable_logging)

    @staticmethod
    def get_trainer(name, disable_logging=False):
        """ Return requested trainer plugin

        Parameters
        ----------
        name: str
            The plugin name. Hyphens are converted to underscores to obtain the module name
        disable_logging: bool, optional
            ``True`` to suppress the "Loading ..." info message. Default: ``False``

        Returns
        -------
        class
            The ``Trainer`` class from ``plugins.train.trainer.<name>``
        """
        return PluginLoader._import("train.trainer", name, disable_logging)

    @staticmethod
    def get_converter(category, name, disable_logging=False):
        """ Return the converter sub plugin

        Parameters
        ----------
        category: str
            The sub-folder of ``plugins/convert`` that holds the plugin
        name: str
            The plugin name. Hyphens are converted to underscores to obtain the module name
        disable_logging: bool, optional
            ``True`` to suppress the "Loading ..." info message. Default: ``False``

        Returns
        -------
        class
            The class named ``category.title()`` from ``plugins.convert.<category>.<name>``
        """
        return PluginLoader._import("convert.{}".format(category), name, disable_logging)

    @staticmethod
    def _import(attr, name, disable_logging):
        """ Import a plugin's module and return its plugin class

        Parameters
        ----------
        attr: str
            Dotted sub-package of ``plugins`` that holds the plugin, e.g. ``"extract.detect"``.
            Its last component, title-cased, is the name of the class that is returned
        name: str
            The plugin name. Hyphens are converted to underscores to obtain the module name
        disable_logging: bool
            ``True`` to suppress the "Loading ..." info message

        Returns
        -------
        class
            The plugin class, e.g. ``plugins.extract.detect.<name>.Detect``

        Raises
        ------
        ImportError
            If the plugin module cannot be imported
        AttributeError
            If the plugin module does not define the expected class
        """
        name = name.replace("-", "_")
        ttl = attr.split(".")[-1].title()
        if not disable_logging:
            logger.info("Loading %s from %s plugin...", ttl, name.title())
        # The class name is title-cased, but the package path is always lower case
        mod = ".".join(("plugins", attr.lower(), name))
        module = import_module(mod)
        return getattr(module, ttl)

    @staticmethod
    def _list_plugins(plugin_dir, exclude=()):
        """ Return the sorted, user-facing names of the plugin modules in a folder

        A file counts as a plugin when its name ends in ``.py``, does not start with an
        underscore, does not end in ``defaults.py`` and is not listed in `exclude`.

        Parameters
        ----------
        plugin_dir: str
            Full path to the folder to scan
        exclude: tuple, optional
            Exact file names (e.g. ``"manual.py"``) to leave out of the results

        Returns
        -------
        list
            Names with the ``.py`` extension removed and underscores replaced by hyphens
            (the reverse of the hyphen-to-underscore mapping used by :func:`_import`),
            sorted alphabetically
        """
        suffix = ".py"
        plugins = []
        for item in os.scandir(plugin_dir):
            fname = item.name
            if (fname.startswith("_")
                    or fname.endswith("defaults.py")
                    or not fname.endswith(suffix)
                    or fname in exclude):
                continue
            # Strip only the trailing extension; ``str.replace`` would also remove ".py"
            # occurring elsewhere in the name
            plugins.append(fname[:-len(suffix)].replace("_", "-"))
        return sorted(plugins)

    @staticmethod
    def get_available_extractors(extractor_type):
        """ Return a list of available aligners/detectors

        Parameters
        ----------
        extractor_type: str
            The sub-folder of ``plugins/extract`` to scan, e.g. ``"detect"`` or ``"align"``

        Returns
        -------
        list
            Sorted plugin names. ``manual.py`` is never listed. Where an ``<name>-amd`` variant
            exists, only the variant matching the active backend (from ``get_backend()``)
            is kept
        """
        extractpath = os.path.join(os.path.dirname(__file__), "extract", extractor_type)
        extractors = PluginLoader._list_plugins(extractpath, exclude=("manual.py",))
        # TODO Remove this hacky fix when we move them to the same models
        amd_suffix = "-amd"
        amd_variants = [name for name in extractors if name.endswith(amd_suffix)]
        if get_backend() == "amd":
            # Hide the generic version of any plugin that has an AMD-specific variant.
            # A set is used so a variant without a generic counterpart does not raise
            hidden = {name[:-len(amd_suffix)] for name in amd_variants}
        else:
            hidden = set(amd_variants)
        return [name for name in extractors if name not in hidden]

    @staticmethod
    def get_available_models():
        """ Return a list of available models

        Returns
        -------
        list
            Sorted names of the plugins found in ``plugins/train/model``
        """
        modelpath = os.path.join(os.path.dirname(__file__), "train", "model")
        return PluginLoader._list_plugins(modelpath)

    @staticmethod
    def get_default_model():
        """ Return the default model

        Returns
        -------
        str
            ``"original"`` if that model is available, otherwise the first available model
            in alphabetical order

        Raises
        ------
        IndexError
            If no model plugins are available
        """
        models = PluginLoader.get_available_models()
        return "original" if "original" in models else models[0]

    @staticmethod
    def get_available_convert_plugins(convert_category, add_none=True):
        """ Return a list of available converter plugins for a category

        Parameters
        ----------
        convert_category: str
            The sub-folder of ``plugins/convert`` to scan
        add_none: bool, optional
            ``True`` to prepend the option ``"none"`` to the returned list. Default: ``True``

        Returns
        -------
        list
            Sorted plugin names, preceded by ``"none"`` when `add_none` is ``True``
        """
        convertpath = os.path.join(os.path.dirname(__file__), "convert", convert_category)
        converters = PluginLoader._list_plugins(convertpath)
        if add_none:
            converters.insert(0, "none")
        return converters