-
Notifications
You must be signed in to change notification settings - Fork 11
/
hubconf.py
39 lines (30 loc) · 1.22 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Torch Hub script for accessing the hand segmentation model outside the repo.
"""
##################################################
# Imports
##################################################
# Packages Torch Hub should verify before loading entrypoints from this file.
# gdown is listed because the module imports it to fetch the pretrained checkpoint.
dependencies = ['torch', 'pytorch_lightning', 'gdown']
import torch
from model import HandSegModel
import gdown
import os
def hand_segmentor(pretrained=True, *args, **kwargs):
    """
    Hand segmentor based on a DeepLabV3 model with a ResNet50 encoder.
    DeeplabV3: https://arxiv.org/abs/1706.05587
    ResNet50: https://arxiv.org/abs/1512.03385

    Args:
        pretrained: If True, download the released checkpoint to
            ``./checkpoint/checkpoint.ckpt`` (relative to the current working
            directory) unless it is already there, and load it with weights
            mapped onto the CPU. If False, return a freshly constructed
            ``HandSegModel``.
        *args: Positional arguments for ``HandSegModel``. Only supported when
            ``pretrained`` is False, because ``load_from_checkpoint`` cannot take
            positional model arguments.
        **kwargs: Keyword arguments for ``HandSegModel``. When loading the
            checkpoint they are forwarded to ``load_from_checkpoint``
            (presumably a Lightning classmethod that passes them on to the
            model constructor - confirm against the installed Lightning version).

    Returns:
        A ``HandSegModel`` instance.

    Raises:
        TypeError: If positional ``args`` are combined with ``pretrained=True``.
    """
    if not pretrained:
        return HandSegModel(*args, **kwargs)

    # Previously these args were spliced in positionally after the checkpoint path,
    # where they collided with ``map_location``; fail with a clear message instead.
    if args:
        raise TypeError(
            'hand_segmentor() does not accept positional model arguments when '
            'pretrained=True; pass them as keyword arguments instead.'
        )

    checkpoint_path = './checkpoint/checkpoint.ckpt'
    # Skip the (large) download when an earlier call already fetched the checkpoint.
    if not os.path.isfile(checkpoint_path):
        _download_file_from_google_drive('1w7dztGAsPHD_fl_Kv_a8qHL4eW92rlQg', checkpoint_path)

    # load_from_checkpoint builds a new model itself, so call it on the class:
    # newer Lightning releases reject calls made through an instance.
    return HandSegModel.load_from_checkpoint(
        checkpoint_path, map_location=torch.device('cpu'), **kwargs
    )
def _download_file_from_google_drive(id, destination):
    """
    Download a shared Google Drive file to ``destination`` using gdown.

    Args:
        id: Google Drive file id, used as the ``id=`` query value of the
            ``drive.google.com/uc`` URL. (The name shadows the builtin but is
            kept for backward compatibility with keyword callers.)
        destination: Local file path to write to. Missing parent directories
            are created; a bare filename writes into the current directory.

    Returns:
        The value returned by ``gdown.download`` (per gdown's docs, the output
        filename).
    """
    url = f'https://drive.google.com/uc?id={id}'
    parent_dir = os.path.dirname(destination)
    # dirname() is '' for a bare filename and os.makedirs('') would raise;
    # exist_ok=True also avoids a race between checking and creating the directory.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    return gdown.download(url, destination, quiet=False)