forked from songweige/TATS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
download.py
76 lines (56 loc) · 2.23 KB
/
download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import requests
from tqdm import tqdm
import os
import torch
from .tats_vqgan import VQGAN
from .tats_transformer import Net2NetTransformer
def get_confirm_token(response):
    """Return the download-confirmation token carried in the response cookies.

    Args:
        response: Object exposing a ``cookies`` mapping (e.g. a
            ``requests.Response``).

    Returns:
        The value of the first cookie whose name starts with
        ``'download_warning'``, or ``None`` when no such cookie is present.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    # next() with a default yields the first match, or None if there is none.
    return next(matching_values, None)
def save_response_content(response, destination):
    """Stream the body of ``response`` to ``destination`` with a progress bar.

    The body is first written to ``<destination>.part`` and only moved onto
    ``destination`` once every chunk has been written. An interrupted
    transfer therefore never leaves a truncated file at ``destination``
    that a later existence check would mistake for a complete download.

    Args:
        response: Streaming response exposing ``iter_content(chunk_size)``
            (e.g. a ``requests.Response`` obtained with ``stream=True``).
        destination: Path of the file to create or overwrite.

    Raises:
        Whatever ``open``, ``response.iter_content`` or the file write raise;
        the partial file is removed before the exception propagates.
    """
    CHUNK_SIZE = 8192
    partial_path = os.fspath(destination) + '.part'

    try:
        # Both context managers guarantee the file handle and the progress
        # bar are closed even when the transfer fails midway.
        with open(partial_path, 'wb') as f, \
                tqdm(total=0, unit='iB', unit_scale=True) as pbar:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
                    pbar.update(len(chunk))
    except BaseException:
        # Also covers KeyboardInterrupt: drop the incomplete file, re-raise.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    # Atomic on the same filesystem: destination is either absent/old or
    # the complete new file, never a partial one.
    os.replace(partial_path, destination)
def download(id, fname, root=os.path.expanduser('./ckpts')):
    """Download a Google Drive file into ``root`` unless it is already there.

    Args:
        id: Google Drive file id, sent as the ``id`` query parameter.
            (The name shadows the builtin but is kept so existing keyword
            callers continue to work.)
        fname: File name to store the download under inside ``root``.
        root: Directory for downloaded files; created if missing.

    Returns:
        The path ``os.path.join(root, fname)``. If that path already exists
        it is returned immediately without any network access.

    Raises:
        requests.HTTPError: If the final response has a 4xx/5xx status, so
            an error page is not stored as if it were the requested file.
    """
    os.makedirs(root, exist_ok=True)
    destination = os.path.join(root, fname)

    # Cache hit: reuse the existing file.
    if os.path.exists(destination):
        return destination

    URL = 'https://drive.google.com/uc?export=download'

    # The context manager closes the session (and its pooled connections)
    # once the download finishes or fails.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if token:
            # A confirmation token was issued via cookie; release the first
            # streamed response and repeat the request with the token.
            response.close()
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        response.raise_for_status()
        save_response_content(response, destination)

    return destination
def load_vqgan(vqgan_ckpt, device=torch.device('cpu')):
    """Restore a ``VQGAN`` from a checkpoint, ready for inference.

    Args:
        vqgan_ckpt: Checkpoint location handed to
            ``VQGAN.load_from_checkpoint``.
        device: Device the restored model is moved to (CPU by default).

    Returns:
        The model returned by ``.to(device)``, switched to eval mode.
    """
    model = VQGAN.load_from_checkpoint(vqgan_ckpt)
    model = model.to(device)
    model.eval()
    return model
def load_transformer(gpt_ckpt, vqgan_ckpt, stft_vqgan_ckpt='', device=torch.device('cpu')):
    """Restore a ``Net2NetTransformer`` from a checkpoint, ready for inference.

    The VQGAN path(s) stored in the checkpoint's hyper-parameters are
    overwritten with the ones given here before the model is rebuilt.
    Presumably this lets the transformer locate its VQGAN weights on the
    local machine; confirm against ``Net2NetTransformer``.

    Args:
        gpt_ckpt: Location of the transformer checkpoint.
        vqgan_ckpt: Value written to ``hyper_parameters['args'].vqvae``.
        stft_vqgan_ckpt: If non-empty, value written to
            ``hyper_parameters['args'].stft_vqvae``.
        device: Device the restored model is moved to (CPU by default).
            Previously this argument was accepted but ignored.

    Returns:
        The restored transformer on ``device``, in eval mode.
    """
    # Imported lazily so importing this module does not require this
    # pytorch_lightning submodule unless a transformer is actually loaded.
    from pytorch_lightning.utilities.cloud_io import load as pl_load

    checkpoint = pl_load(gpt_ckpt)
    hparams_args = checkpoint['hyper_parameters']['args']
    hparams_args.vqvae = vqgan_ckpt
    if stft_vqgan_ckpt:
        hparams_args.stft_vqvae = stft_vqgan_ckpt

    # NOTE(review): _load_model_state is a private Lightning hook; verify it
    # still exists when upgrading pytorch_lightning.
    gpt = Net2NetTransformer._load_model_state(checkpoint)
    # Honor the requested device, consistent with the other loaders here.
    gpt = gpt.to(device)
    gpt.eval()
    return gpt
# Google Drive file id passed to download() to fetch the I3D weights.
_I3D_PRETRAINED_ID = '1mQK8KD8G6UWRa5t87SRMm5PVXtlpneJT'


def load_i3d_pretrained(device=torch.device('cpu')):
    """Build an ``InceptionI3d`` and load its downloaded pretrained weights.

    The weights file is fetched with ``download`` under the name
    ``'i3d_pretrained_400.pt'`` (reused if it already exists on disk).

    Args:
        device: Device the model is moved to and onto which the saved
            tensors are mapped when loading (CPU by default).

    Returns:
        The ``InceptionI3d`` instance with weights loaded, in eval mode.
    """
    from .fvd.pytorch_i3d import InceptionI3d

    model = InceptionI3d(400, in_channels=3)
    model = model.to(device)

    weights_path = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt')
    # NOTE(review): torch.load unpickles the downloaded file; only safe as
    # long as the remote file is trusted.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    model.eval()
    return model