Skip to content

Commit

Permalink
normalize the train.py code
Browse files Browse the repository at this point in the history
  • Loading branch information
yurizzzzz committed Dec 15, 2021
1 parent 39a774a commit b32103f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 16 deletions.
8 changes: 5 additions & 3 deletions code/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sys import argv
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
Expand Down Expand Up @@ -29,10 +30,11 @@ def get_boundingbox(face, width, height, scale=1.3, minsize=None):


class LoadData(Dataset):
def __init__(self, img, label_dict, mode='train'):
def __init__(self, arg, img, label_dict, mode='train'):
self.img = img
self.label_dict = label_dict
self.mode = mode
self.args = arg

def __getitem__(self, item):
input_img = self.img[item]
Expand All @@ -42,7 +44,7 @@ def __getitem__(self, item):

if self.mode == 'train':
face_detect = dlib.get_frontal_face_detector()
img = cv2.imread('/home/fzw/face/image/train/' + input_img)
img = cv2.imread(self.args.train_dir + '/' + input_img)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
faces = face_detect(gray, 1)
if len(faces) != 0:
Expand All @@ -62,7 +64,7 @@ def __getitem__(self, item):

if self.mode == 'val':
face_detect = dlib.get_frontal_face_detector()
img = cv2.imread('/home/fzw/face/image/val/' + input_img)
img = cv2.imread(self.args.val_dir + '/' + input_img)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
faces = face_detect(gray, 1)
if len(faces) != 0:
Expand Down
8 changes: 6 additions & 2 deletions code/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import os
import cv2
from torch.serialization import load
import dlib
import torch
import numpy as np
Expand All @@ -24,6 +25,9 @@ def input_args():

parser.add_argument("--test_dir", type=str, default='/home/fzw/face/image/test/',
help="The testdata path")

parser.add_argument("--pre_model", type=str, default='home/fzw/face-forgery-detection-val/checkpoint/checkpoint_9.tar',
help="the path of pretraining model")

return parser.parse_args()

Expand All @@ -37,8 +41,8 @@ def input_args():
model = model_core.Two_Stream_Net()
model = model.cuda()

# model_state_dict = torch.load('/home/fzw/face-forgery-detection-val/checkpoint/checkpoint_9.tar', map_location='cuda:2')['state_dict']
# model.load_state_dict(model_state_dict)
model_state_dict = torch.load(args.pre_model, map_location='cuda:0')['state_dict']
model.load_state_dict(model_state_dict)

test_list = [file for file in os.listdir(args.test_dir) if file.endswith('.jpg')]
test_list = tqdm(test_list)
Expand Down
53 changes: 42 additions & 11 deletions code/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,65 @@
import torch.optim as optim
from torchvision import utils as vutils

from tqdm import tqdm
import csv
import dataset
import argparse
import model_core
from loss import am_softmax

torch.cuda.set_device(2)
def input_args():
parser = argparse.ArgumentParser()

parser.add_argument("--cuda_id", type=int, default=0,
help="The GPU ID")

parser.add_argument("--train_label", type=str, default='/home/fzw/face/train.labels.csv',
help="The traindata label path")

parser.add_argument("--train_dir", type=str, default='/home/fzw/face/image/train/',
help="The traindata path ")

parser.add_argument("--val_dir", type=str, default='/home/fzw/face/image/val/',
help="The valdata path ")

parser.add_argument("--load_model", type=bool, default=False,
help="Whether load pretraining model")

parser.add_argument("--pre_model", type=str, default='home/fzw/face-forgery-detection-val/checkpoint/checkpoint_9.tar',
help="the path of pretraining model")

parser.add_argument("--save_model", type=str, default='/home/fzw/face-forgery-detection-val/checkpoint/',
help="the path of saving model")

return parser.parse_args()


if __name__ == '__main__':
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
args = input_args()

torch.cuda.set_device(args.cuda_id)
device = torch.device("cuda:%d" % (args.cuda_id) if torch.cuda.is_available() else "cpu")

csvFile = open("/home/fzw/face/train.labels.csv", "r")
csvFile = open(args.train_label, "r")
reader = csv.reader(csvFile)
label_dict = dict()
for item in reader:
# key: filename
key = item[-1][:-2]
# value: the label (0 or 1) of file
value = item[-1][-1]
if value != 'l':
value = int(value)
label_dict.update({key: value})

train_list = [file for file in os.listdir('/home/fzw/face/image/train/') if file.endswith('.jpg')]
val_list = [file for file in os.listdir('/home/fzw/face/image/val/') if file.endswith('.jpg')]
TrainData = torch.utils.data.DataLoader(dataset.LoadData(train_list, label_dict, mode='train'),
train_list = [file for file in os.listdir(args.train_dir) if file.endswith('.jpg')]
val_list = [file for file in os.listdir(args.val_dir) if file.endswith('.jpg')]
TrainData = torch.utils.data.DataLoader(dataset.LoadData(args, train_list, label_dict, mode='train'),
batch_size=16,
shuffle=True,
num_workers=16,
drop_last=False)
ValData = torch.utils.data.DataLoader(dataset.LoadData(val_list, label_dict, mode='val'),
ValData = torch.utils.data.DataLoader(dataset.LoadData(args, val_list, label_dict, mode='val'),
batch_size=16,
shuffle=True,
num_workers=16,
Expand All @@ -48,8 +78,9 @@

model = model_core.Two_Stream_Net()
model = model.cuda()
model_state_dict = torch.load('/home/fzw/face-forgery-detection-val/checkpoint/checkpoint_9.tar', map_location='cuda:2')['state_dict']
model.load_state_dict(model_state_dict)
if args.load_model:
model_state_dict = torch.load(args.pre_model, map_location='cuda:0')['state_dict']
model.load_state_dict(model_state_dict)
optimizer = optim.Adam(model.parameters(), lr=0.0002, betas=(0.9, 0.999))

epoch = 0
Expand Down Expand Up @@ -110,7 +141,7 @@
val_bar.set_description(desc)
val_bar.update()

savename = '/home/fzw/face-forgery-detection-val/checkpoint/checkpoint' + '_' + str(epoch) + '.tar'
savename = args.save_model + '/checkpoint' + '_' + str(epoch) + '.tar'
torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, savename)
epoch = epoch + 1

Expand Down

0 comments on commit b32103f

Please sign in to comment.