Skip to content

Commit

Permalink
BVDNet
Browse files Browse the repository at this point in the history
  • Loading branch information
HypoX64 committed Apr 18, 2021
1 parent 538ccbc commit 972dcc8
Show file tree
Hide file tree
Showing 10 changed files with 964 additions and 388 deletions.
18 changes: 10 additions & 8 deletions cores/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self):
def initialize(self):

#base
self.parser.add_argument('--use_gpu', type=int,default=0, help='if -1, use cpu')
self.parser.add_argument('--use_gpu', type=str,default='0', help='if -1, use cpu')
self.parser.add_argument('--media_path', type=str, default='./imgs/ruoruo.jpg',help='your videos or images path')
self.parser.add_argument('-ss', '--start_time', type=str, default='00:00:00',help='start position of video, default is the beginning of video')
self.parser.add_argument('-t', '--last_time', type=str, default='00:00:00',help='duration of the video, default is the entire video')
Expand Down Expand Up @@ -58,13 +58,15 @@ def getparse(self, test_flag = False):

model_name = os.path.basename(self.opt.model_path)
self.opt.temp_dir = os.path.join(self.opt.temp_dir, 'DeepMosaics_temp')

os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.use_gpu)
import torch
if torch.cuda.is_available() and self.opt.use_gpu > -1:
pass
else:
self.opt.use_gpu = -1


if self.opt.use_gpu != '-1':
os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.use_gpu)
import torch
if not torch.cuda.is_available():
self.opt.use_gpu = '-1'
# else:
# self.opt.use_gpu = '-1'

if test_flag:
if not os.path.exists(self.opt.media_path):
Expand Down
2 changes: 1 addition & 1 deletion make_datasets/make_pix2pix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
mask = mask_drawn
if 'irregular' in opt.mod:
mask_irr = impro.imread(irrpaths[random.randint(0,12000-1)],'gray')
mask_irr = data.random_transform_single(mask_irr, (img.shape[0],img.shape[1]))
mask_irr = data.random_transform_single_mask(mask_irr, (img.shape[0],img.shape[1]))
mask = mask_irr
if 'network' in opt.mod:
mask_net = runmodel.get_ROI_position(img,net,opt,keepsize=True)[0]
Expand Down
144 changes: 144 additions & 0 deletions models/BVDNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pix2pixHD_model import *


class Encoder2d(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, norm_layer=nn.BatchNorm2d):
super(Encoder2d, self).__init__()
activation = nn.ReLU(True)

model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.ReflectionPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0),
norm_layer(ngf * mult * 2), activation]

self.model = nn.Sequential(*model)

def forward(self, input):
return self.model(input)

class Encoder3d(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, norm_layer=nn.BatchNorm3d):
super(Encoder3d, self).__init__()
activation = nn.ReLU(True)

model = [nn.Conv3d(input_nc, ngf, kernel_size=3, padding=1), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv3d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]

self.model = nn.Sequential(*model)

def forward(self, input):
return self.model(input)

class BVDNet(nn.Module):
def __init__(self, N, n_downsampling=3, n_blocks=1, input_nc=3, output_nc=3):
super(BVDNet, self).__init__()

ngf = 64
padding_type = 'reflect'
norm_layer = nn.BatchNorm2d
self.N = N

# encoder
self.encoder3d = Encoder3d(input_nc,64,n_downsampling)
self.encoder2d = Encoder2d(input_nc,64,n_downsampling)

### resnet blocks
self.blocks = []
mult = 2**n_downsampling
for i in range(n_blocks):
self.blocks += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=nn.ReLU(True), norm_layer=norm_layer)]
self.blocks = nn.Sequential(*self.blocks)

### decoder
self.decoder = []
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
# self.decoder += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
# norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
self.decoder += [ nn.Upsample(scale_factor = 2, mode='nearest'),
nn.ReflectionPad2d(1),
nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
self.decoder += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
self.decoder = nn.Sequential(*self.decoder)
self.limiter = nn.Tanh()

def forward(self, stream, last):
this_shortcut = stream[:,:,self.N]
stream = self.encoder3d(stream)
stream = stream.reshape(stream.size(0),stream.size(1),stream.size(3),stream.size(4))
# print(stream.shape)
last = self.encoder2d(last)
x = stream + last
x = self.blocks(x)
x = self.decoder(x)
x = x+this_shortcut
x = self.limiter(x)
#print(x.shape)

# print(stream.shape,last.shape)
return x

class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()

self.vgg = Vgg19()
if gpu_ids != '-1' and len(gpu_ids) == 1:
self.vgg.cuda()
elif gpu_ids != '-1' and len(gpu_ids) > 1:
self.vgg = nn.DataParallel(self.vgg)
self.vgg.cuda()

self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss

from torchvision import models
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False

def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
10 changes: 10 additions & 0 deletions models/model_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
import torch.nn as nn

def save(net,path,gpu_id):
if isinstance(net, nn.DataParallel):
torch.save(net.module.cpu().state_dict(),path)
else:
torch.save(net.cpu().state_dict(),path)
if gpu_id != '-1':
net.cuda()
2 changes: 1 addition & 1 deletion train/add/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def loadimage(imagepaths,maskpaths,opt,test_flag = False):
for i in range(len(imagepaths)):
img = impro.resize(impro.imread(imagepaths[i]),opt.loadsize)
mask = impro.resize(impro.imread(maskpaths[i],mod = 'gray'),opt.loadsize)
img,mask = data.random_transform_image(img, mask, opt.finesize, test_flag)
img,mask = data.random_transform_pair_image(img, mask, opt.finesize, test_flag)
images[i] = (img.transpose((2, 0, 1))/255.0)
masks[i] = (mask.reshape(1,1,opt.finesize,opt.finesize)/255.0)
images = Totensor(images,opt.use_gpu)
Expand Down
Loading

0 comments on commit 972dcc8

Please sign in to comment.