Skip to content

Commit

Permalink
修改train.py和部分脚本。
Browse files Browse the repository at this point in the history
  • Loading branch information
WanglifuCV committed Dec 5, 2020
1 parent ae13c5e commit f37c1fc
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 20 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
**/__pycache__
**/__pycache__
checkpoint/
runs/
5 changes: 5 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Launch CIFAR-100 training with ResNet-152.
# Flags are defined by utils.parse_argments: -gpu enables CUDA,
# --gpu_list selects the visible device(s) via CUDA_VISIBLE_DEVICES.
python train.py -net resnet152 \
-gpu \
-batch_size 128 \
-lr 0.1 \
--gpu_list 0
1 change: 1 addition & 0 deletions tensorboard.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Serve the TensorBoard UI for the runs/ log directory on port 6008
# (runs/ is where train.py's SummaryWriter emits event files).
tensorboard --logdir runs --port 6008
26 changes: 10 additions & 16 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import os
import sys
import argparse
import time
from datetime import datetime

Expand All @@ -18,12 +17,13 @@
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR
from utils import Args as args
from torchsummary import summary

def train(epoch):

Expand Down Expand Up @@ -56,7 +56,7 @@ def train(epoch):
loss.item(),
optimizer.param_groups[0]['lr'],
epoch=epoch,
trained_samples=batch_index * args.b + len(images),
trained_samples=batch_index * args.batch_size + len(images),
total_samples=len(cifar100_training_loader.dataset)
))

Expand Down Expand Up @@ -112,31 +112,25 @@ def eval_training(epoch):
return correct.float() / len(cifar100_test_loader.dataset)

if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument('-net', type=str, required=True, help='net type')
parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
parser.add_argument('-b', type=int, default=128, help='batch size for dataloader')
parser.add_argument('-warm', type=int, default=1, help='warm up training phase')
parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_list
net = get_network(args)

summary(net, (3, 32, 32))

#data preprocessing:
cifar100_training_loader = get_training_dataloader(
settings.CIFAR100_TRAIN_MEAN,
settings.CIFAR100_TRAIN_STD,
num_workers=4,
batch_size=args.b,
batch_size=args.batch_size,
shuffle=True
)

cifar100_test_loader = get_test_dataloader(
settings.CIFAR100_TRAIN_MEAN,
settings.CIFAR100_TRAIN_STD,
num_workers=4,
batch_size=args.b,
batch_size=args.batch_size,
shuffle=True
)

Expand Down
20 changes: 17 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import sys

import argparse
import numpy

import torch
Expand Down Expand Up @@ -167,7 +167,7 @@ def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=Tru
transforms.Normalize(mean, std)
])
#cifar100_training = CIFAR100Train(path, transform=transform_train)
cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=False, transform=transform_train)
cifar100_training_loader = DataLoader(
cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

Expand All @@ -190,7 +190,7 @@ def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
transforms.Normalize(mean, std)
])
#cifar100_test = CIFAR100Test(path, transform=transform_test)
cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=False, transform=transform_test)
cifar100_test_loader = DataLoader(
cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

Expand Down Expand Up @@ -230,3 +230,17 @@ def get_lr(self):
rate to base_lr * m / total_iters
"""
return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]


def parse_argments():
    """Parse the command-line options for CIFAR-100 training.

    Returns:
        argparse.Namespace with attributes: net (str, required),
        gpu (bool), batch_size (int), warm (int), lr (float),
        gpu_list (str).
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument  # shorthand; each option below mirrors run.sh
    add('-net', type=str, required=True, help='net type')
    add('-gpu', action='store_true', default=False, help='use gpu or not')
    add('-batch_size', type=int, default=128, help='batch size for dataloader')
    add('-warm', type=int, default=1, help='warm up training phase')
    add('-lr', type=float, default=0.1, help='initial learning rate')
    add('--gpu_list', type=str, default='0', help='GPU list')
    return cli.parse_args()

# NOTE(review): this parses sys.argv at import time, so importing utils
# from any non-CLI context (tests, notebooks) will exit if -net is absent.
# train.py relies on `from utils import Args`; consider lazy parsing later.
Args = parse_argments()

0 comments on commit f37c1fc

Please sign in to comment.