From 114afee9b3d29b24566d62c4032ac986aa1f0928 Mon Sep 17 00:00:00 2001
From: Chenxi
Date: Wed, 29 Jun 2022 12:05:38 +0000
Subject: [PATCH] replicate

---
 README.md  |   3 ++
 cog.yaml   |  17 ++++++++
 predict.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 135 insertions(+)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/README.md b/README.md
index e8a49f0..18c1957 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,9 @@ This is the official PyTorch implementation of the ICLR 2022 paper "Latent Image
 - lpips
 
 ## 1. Animation demo
+
+Try the Replicate web demo here [![Replicate](https://replicate.com/wyhsirius/lia/badge)](https://replicate.com/wyhsirius/lia)
+
 Download pre-trained checkpoints from [here](https://drive.google.com/drive/folders/1N4QcnqUQwKUZivFV-YeBuPyH4pGJHooc?usp=sharing) and put models under `./checkpoints`. We have provided several demo source images and driving videos in `./data`. To obtain demos, you could run following commands, generated results will be saved under `./res`.
 
 ```shell script
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..9551251
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,17 @@
+build:
+  cuda: "11.0"
+  gpu: true
+  python_version: "3.8"
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_packages:
+    - "ipython==7.21.0"
+    - "torch==1.10.1"
+    - "torchvision==0.11.2"
+    - "moviepy==1.0.3"
+    - "tensorboard==2.9.1"
+    - "tqdm==4.62.3"
+    - "av==9.2.0"
+
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..314ae12
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,115 @@
+import tempfile
+import torch
+from tqdm import tqdm
+from PIL import Image
+import torchvision
+import numpy as np
+from cog import BasePredictor, Path, Input
+
+
+from networks.generator import Generator
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        self.size = 256
+        channel_multiplier = 1
+        latent_dim_style = 512
+        latent_dim_motion = 20
+
+        model_weights = {
+            "vox": torch.load(
+                "checkpoints/vox.pt", map_location=lambda storage, loc: storage
+            )["gen"],
+            "taichi": torch.load(
+                "checkpoints/taichi.pt", map_location=lambda storage, loc: storage
+            )["gen"],
+            "ted": torch.load(
+                "checkpoints/ted.pt", map_location=lambda storage, loc: storage
+            )["gen"],
+        }
+        self.gen_models = {
+            k: Generator(
+                self.size, latent_dim_style, latent_dim_motion, channel_multiplier
+            ).cuda()
+            for k in model_weights.keys()
+        }
+
+        for k, v in self.gen_models.items():
+            v.load_state_dict(model_weights[k])
+            v.eval()
+
+    def predict(
+        self,
+        img_source: Path = Input(
+            description="Input source image.",
+        ),
+        driving_video: Path = Input(
+            description="Choose a driving video.",
+        ),
+        model: str = Input(
+            choices=["vox", "taichi", "ted"],
+            default="vox",
+            description="Choose a dataset.",
+        ),
+    ) -> Path:
+        gen = self.gen_models[model]
+        print("==> loading data")
+
+        img_source = img_preprocessing(str(img_source), self.size).cuda()
+        vid_target, fps = vid_preprocessing(str(driving_video))
+        vid_target = vid_target.cuda()
+
+        out_path = Path(tempfile.mkdtemp()) / "output.mp4"
+
+        with torch.no_grad():
+            vid_target_recon = []
+
+            if model == "ted":
+                h_start = None
+            else:
+                h_start = gen.enc.enc_motion(vid_target[:, 0, :, :, :])
+
+            for i in tqdm(range(vid_target.size(1))):
+                img_target = vid_target[:, i, :, :, :]
+                img_recon = gen(img_source, img_target, h_start)
+                vid_target_recon.append(img_recon.unsqueeze(2))
+
+            vid_target_recon = torch.cat(vid_target_recon, dim=2)
+            save_video(vid_target_recon, str(out_path), fps)
+
+        return out_path
+
+
+def load_image(filename, size):
+    img = Image.open(filename).convert("RGB")
+    img = img.resize((size, size))
+    img = np.asarray(img)
+    img = np.transpose(img, (2, 0, 1))  # 3 x 256 x 256
+
+    return img / 255.0
+
+
+def img_preprocessing(img_path, size):
+    img = load_image(img_path, size)  # [0, 1]
+    img = torch.from_numpy(img).unsqueeze(0).float()  # [0, 1]
+    imgs_norm = (img - 0.5) * 2.0  # [-1, 1]
+
+    return imgs_norm
+
+
+def vid_preprocessing(vid_path):
+    vid_dict = torchvision.io.read_video(vid_path, pts_unit="sec")
+    vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0)
+    fps = vid_dict[2]["video_fps"]
+    vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]
+
+    return vid_norm, fps
+
+
+def save_video(vid_target_recon, save_path, fps):
+    vid = vid_target_recon.permute(0, 2, 3, 4, 1)
+    vid = vid.clamp(-1, 1).cpu()
+    vid = ((vid - vid.min()) / (vid.max() - vid.min()) * 255).type("torch.ByteTensor")
+
+    torchvision.io.write_video(save_path, vid[0], fps=fps)
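
With `cog.yaml` and `predict.py` from this patch in place, the demo can be reproduced locally through Cog's CLI before it is published. A minimal sketch, assuming Cog and an NVIDIA GPU are available, the pretrained weights sit under `./checkpoints` as `setup()` expects, and the input file names under `./data` are placeholders:

```shell script
# Build the environment from cog.yaml and run one prediction;
# the -i flags map onto the inputs declared in predict.py's predict().
cog predict \
  -i img_source=@data/source.png \
  -i driving_video=@data/driving.mp4 \
  -i model=vox
```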
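
Publishing the built model to the page linked in the README is then a single push; the destination below assumes the model was created as `wyhsirius/lia` on Replicate:

```shell script
# Upload the built image so it backs https://replicate.com/wyhsirius/lia
cog push r8.im/wyhsirius/lia
```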