-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
65 lines (54 loc) · 2.21 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# script to run the app in HuggingFace space
import os, argparse
import numpy as np
from HF.UI import AppUI
from HF.Utils import toPILImage
from HF.NNHelper import modelsFrom, inference_from_config
from Utils.WandBUtils import CWBProject
def infer(models):
  """Build the image-processing callback used by the UI.

  Parameters:
    models: mapping of model name -> callable model. Each model is invoked as
      model(raw=<HxWx3 uint8 ndarray>, **kwargs) and returns a dict containing
      either a 'video' key, or 'upscaled' and 'input' keys (raw image arrays).

  Returns:
    A function (modelName, inputImage, **kwargs) -> dict for the UI to call.

  Raises:
    ValueError: if the model name is unknown or the input image is not an
      HxWx3 uint8 numpy array.
  """
  def _processImage(modelName, inputImage, **kwargs):
    model = models.get(modelName, None)
    # Validate with explicit raises instead of `assert`: asserts are stripped
    # under `python -O`, which would silently disable all input validation.
    if model is None:
      raise ValueError(f'Invalid model name: {modelName}')
    if not isinstance(inputImage, np.ndarray):
      raise ValueError(f'Invalid image type: {type(inputImage)}')
    if 3 != len(inputImage.shape):
      raise ValueError(f'Invalid image shape: {inputImage.shape}')
    if np.uint8 != inputImage.dtype:
      raise ValueError(f'Invalid image dtype: {inputImage.dtype}')
    if 3 != inputImage.shape[-1]:
      raise ValueError(f'Invalid image channels: {inputImage.shape[-1]}')

    res = model(raw=inputImage, **kwargs)
    # some models produce a video instead of images
    if ('video' in res):
      return dict(video=res['video'])
    upscaled = toPILImage(res['upscaled'], isBGR=False)
    # renamed from `input` to avoid shadowing the builtin
    original = toPILImage(res['input'], isBGR=False)
    return dict(upscaled=upscaled, input=original)
  return _processImage
def run2inference(run, runName=None):
  """Create inference model(s) from a W&B run.

  Derives a display name from the run (unless `runName` is supplied),
  attaches HuggingFace metadata to the run's config, and instantiates
  the models via `inference_from_config`.
  """
  baseName = run.name.replace('[Public] ', '') if runName is None else runName
  fullName = "%s (%s, loss: %.5f)" % (baseName, run.id, run.bestLoss)
  config = run.config
  # add corresponding HF info to the config before instantiation
  config['huggingface'] = {"name": fullName, "wandb": run.fullId}
  return list(inference_from_config(config))
def main(args):
  """Entry point: load the best public models from W&B and launch the app UI.

  Parameters:
    args: parsed CLI arguments providing `port` (int) and `host` (str | None).
  """
  WBProject = CWBProject('green_wizard/FranNet')
  # collect inference models from the best run of each W&B group
  # (NOTE: removed an unused `folder` variable that was never read)
  models = []
  bestGroups = WBProject.groups(onlyBest=True)
  for runName, run in bestGroups.items():
    # only runs explicitly marked as public are exposed in the app
    if runName.startswith('[Public] '):
      models.extend( run2inference(run) )
  # convert to dict keyed by model name
  models = {model.name: model for model in models}
  app = AppUI(
    processImage=infer(models),
    models=models,
  )
  app.queue() # enable queueing of requests/events
  app.launch(inline=False, server_port=args.port, server_name=args.host)
  return
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=7860)
parser.add_argument('--host', type=str, default=None)
args = parser.parse_args()
main(args)