forked from wyhsirius/g3an-project
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
yaowang
committed
Sep 10, 2020
1 parent
8e44c4c
commit 7743aff
Showing
164 changed files
with
862 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import argparse | ||
|
||
|
||
def parse_args(): | ||
|
||
parser = argparse.ArgumentParser('g3an training config') | ||
|
||
# train | ||
parser.add_argument('--max_epoch', type=int, default=5001, help='number of epochs of training') | ||
parser.add_argument('--batch_size', type=int, default=96, help='size of the batch') | ||
parser.add_argument('--g_lr', type=float, default=2e-4, help='learning rate of generator') | ||
parser.add_argument('--d_lr', type=float, default=2e-4, help='learning rate of discriminator') | ||
parser.add_argument('--d_za', type=int, default=128, help='appearance dim') | ||
parser.add_argument('--d_zm', type=int, default=10, help='motion dim') | ||
parser.add_argument('--num_workers', type=int, default=8, help='number of workers') | ||
parser.add_argument('--ch_g', type=int, default=64, help='base channels of generator') | ||
parser.add_argument('--ch_d', type=int, default=64, help='base channels of discriminator') | ||
parser.add_argument('--g_mode', type=str, default='1p2d', choices=['1p2d', '2p1d', '3d'], help='generator operation mode') | ||
parser.add_argument('--img_size', type=int, default=64, help='generate image size') | ||
parser.add_argument('--dataset', type=str, default='uva', choices=['uva'], help='dataset choice') | ||
parser.add_argument('--val_freq', type=int, default=50, help='validation frequence') | ||
parser.add_argument('--print_freq', type=int, default=100, help='log frequence') | ||
parser.add_argument('--save_freq', type=int, default=100, help='model save frequence') | ||
parser.add_argument('--exp_name', type=str, default='g3an') | ||
parser.add_argument('--save_path', type=str, default='./exps', help='model and log save path') | ||
parser.add_argument('--data_path', type=str, default='', help='dataset path') | ||
parser.add_argument('--use_attention', action='store_true', default=False, help='whether to use attention') | ||
parser.add_argument('--random_seed', type=int, default='12345') | ||
|
||
# test | ||
parser.add_argument('--n', type=int, default=64, help='number of random generation') | ||
parser.add_argument('--n_za_test', type=int, default=8, help='number of appearance') | ||
parser.add_argument('--n_zm_test', type=int, default=8, help='number of motion') | ||
parser.add_argument('--demo_name', type=str, default='v1', help='name of demo') | ||
parser.add_argument('--model_path', type=str, default='./pretrain', help='pre-trained model path') | ||
parser.add_argument('--demo_path', type=str, default='./demos', help='demos save path') | ||
|
||
args = parser.parse_args() | ||
|
||
return args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pandas as pd | ||
import random | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
import os | ||
import glob | ||
|
||
class UVA(Dataset): | ||
def __init__(self, data_path, transform=None): | ||
|
||
self.data_path = data_path | ||
self.step = [2, 3] | ||
|
||
self.vids = os.listdir(self.data_path) | ||
self.transform = transform | ||
|
||
def __getitem__(self, idx): | ||
|
||
video_path = os.path.join(self.data_path, self.vids[idx]) | ||
frames = sorted(glob.glob(video_path + '/*.jpg')) | ||
nframes = len(frames) | ||
step = random.sample(self.step, 1)[0] | ||
|
||
start_idx = random.randint(0, nframes-16 * step) | ||
vid = [Image.open(frames[start_idx + i * step]).convert('RGB') for i in range(16)] | ||
|
||
if self.transform is not None: | ||
vid = self.transform(vid) | ||
|
||
return vid | ||
|
||
def __len__(self): | ||
|
||
return len(self.vids) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
data_path = '/data/stars/user/yaowang/data/UVA/crop_faces/data/' | ||
|
||
dataset = UVA(data_path) | ||
for i in range(len(dataset)): | ||
vid = dataset.__getitem__(i) | ||
print(len(vid)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import absolute_import | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.utils.tensorboard import SummaryWriter | ||
from model.networks import Generator | ||
import cfg | ||
import skvideo.io | ||
import numpy as np | ||
import os | ||
|
||
|
||
def save_videos(path, vids, n_za, frames): | ||
|
||
for i in range(n_za): # appearance loop | ||
v = vids[i].permute(0,2,3,1).cpu().numpy() | ||
v *= 255 | ||
v = v.astype(np.uint8) | ||
skvideo.io.vwrite(os.path.join(path, "%d_%d.mp4"%(i, frames)), v, outputdict={"-vcodec":"libx264"}) | ||
|
||
return | ||
|
||
|
||
def main(): | ||
|
||
args = cfg.parse_args() | ||
|
||
# write into tensorboard | ||
log_path = os.path.join(args.demo_path, args.demo_name + '/log') | ||
vid_path = os.path.join(args.demo_path, args.demo_name + '/vids') | ||
if not os.path.exists(log_path) and not os.path.exists(vid_path): | ||
os.makedirs(log_path) | ||
os.makedirs(vid_path) | ||
writer = SummaryWriter(log_path) | ||
|
||
device = torch.device("cuda:0") | ||
|
||
G = Generator().to(device) | ||
G = nn.DataParallel(G) | ||
G.load_state_dict(torch.load(args.model_path)) | ||
|
||
with torch.no_grad(): | ||
G.eval() | ||
|
||
za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device) # appearance | ||
|
||
# generating frames from [16, 20, 24, 28, 32, 36, 40, 44, 48] | ||
for i in range(9): | ||
zm = torch.randn(args.n_zm_test, args.d_zm, (i+1), 1, 1).to(device) # 16+i*4 | ||
vid_fake = G(za, zm) | ||
vid_fake = vid_fake.transpose(2,1) | ||
vid_fake = ((vid_fake - vid_fake.min()) / (vid_fake.max() - vid_fake.min())).data | ||
writer.add_video(tag='generated_videos_%dframes'%(16+i*4), global_step=1, vid_tensor=vid_fake) | ||
writer.flush() | ||
|
||
print('saving videos') | ||
save_videos(vid_path, vid_fake, args.n_za_test, (16+i*4)) | ||
|
||
return | ||
|
||
if __name__ == '__main__': | ||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from __future__ import absolute_import | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.utils.tensorboard import SummaryWriter | ||
from model.networks import Generator | ||
import cfg | ||
import skvideo.io | ||
import numpy as np | ||
import os | ||
|
||
|
||
def save_videos(path, vids, n_za, n_zm): | ||
|
||
for i in range(n_za): # appearance loop | ||
for j in range(n_zm): # motion loop | ||
v = vids[n_za*i + j].permute(0,2,3,1).cpu().numpy() | ||
v *= 255 | ||
v = v.astype(np.uint8) | ||
skvideo.io.vwrite(os.path.join(path, "%d_%d.mp4"%(i, j)), v, outputdict={"-vcodec":"libx264"}) | ||
|
||
return | ||
|
||
|
||
def main(): | ||
|
||
args = cfg.parse_args() | ||
|
||
# write into tensorboard | ||
log_path = os.path.join(args.demo_path, args.demo_name + '/log') | ||
vid_path = os.path.join(args.demo_path, args.demo_name + '/vids') | ||
|
||
if not os.path.exists(log_path) and not os.path.exists(vid_path): | ||
os.makedirs(log_path) | ||
os.makedirs(vid_path) | ||
writer = SummaryWriter(log_path) | ||
|
||
device = torch.device("cuda:0") | ||
|
||
G = Generator().to(device) | ||
G = nn.DataParallel(G) | ||
G.load_state_dict(torch.load(args.model_path)) | ||
|
||
|
||
with torch.no_grad(): | ||
G.eval() | ||
|
||
za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device) | ||
zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device) | ||
|
||
n_za = za.size(0) | ||
n_zm = zm.size(0) | ||
za = za.unsqueeze(1).repeat(1, n_zm, 1, 1, 1, 1).contiguous().view(n_za*n_zm, -1, 1, 1, 1) | ||
zm = zm.repeat(n_za, 1, 1, 1, 1) | ||
|
||
vid_fake = G(za, zm) | ||
|
||
vid_fake = vid_fake.transpose(2,1) # bs x 16 x 3 x 64 x 64 | ||
vid_fake = ((vid_fake - vid_fake.min()) / (vid_fake.max() - vid_fake.min())).data | ||
|
||
writer.add_video(tag='generated_videos', global_step=1, vid_tensor=vid_fake) | ||
writer.flush() | ||
|
||
# save into videos | ||
print('==> saving videos...') | ||
save_videos(vid_path, vid_fake, n_za, n_zm) | ||
|
||
|
||
return | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import absolute_import | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.utils.tensorboard import SummaryWriter | ||
from model.networks import Generator | ||
import cfg | ||
import skvideo.io | ||
import numpy as np | ||
import os | ||
|
||
|
||
def save_videos(path, vids, n): | ||
|
||
for i in range(n): | ||
v = vids[i].permute(0,2,3,1).cpu().numpy() | ||
v *= 255 | ||
v = v.astype(np.uint8) | ||
skvideo.io.vwrite(os.path.join(path, "%d.mp4"%(i)), v, outputdict={"-vcodec":"libx264"}) | ||
|
||
return | ||
|
||
|
||
def main(): | ||
|
||
args = cfg.parse_args() | ||
|
||
# write into tensorboard | ||
log_path = os.path.join(args.demo_path, args.demo_name + '/log') | ||
vid_path = os.path.join(args.demo_path, args.demo_name + '/vids') | ||
|
||
if not os.path.exists(log_path) and not os.path.exists(vid_path): | ||
os.makedirs(log_path) | ||
os.makedirs(vid_path) | ||
writer = SummaryWriter(log_path) | ||
|
||
device = torch.device("cuda:0") | ||
|
||
G = Generator().to(device) | ||
G = nn.DataParallel(G) | ||
G.load_state_dict(torch.load(args.model_path)) | ||
|
||
with torch.no_grad(): | ||
G.eval() | ||
|
||
za = torch.randn(args.n, args.d_za, 1, 1, 1).to(device) | ||
zm = torch.randn(args.n, args.d_zm, 1, 1, 1).to(device) | ||
|
||
vid_fake = G(za, zm) | ||
|
||
vid_fake = vid_fake.transpose(2,1) # bs x 16 x 3 x 64 x 64 | ||
vid_fake = ((vid_fake - vid_fake.min()) / (vid_fake.max() - vid_fake.min())).data | ||
|
||
writer.add_video(tag='generated_videos', global_step=1, vid_tensor=vid_fake) | ||
writer.flush() | ||
|
||
# save into videos | ||
print('==> saving videos...') | ||
save_videos(vid_path, vid_fake, args.n) | ||
|
||
return | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
main() |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.