forked from thuml/Transfer-Learning-Library
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
727 additions
and
99 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
""" | ||
@author: Baixu Chen | ||
@contact: [email protected] | ||
""" | ||
import random | ||
import time | ||
import warnings | ||
import sys | ||
import argparse | ||
import shutil | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.backends.cudnn as cudnn | ||
from torch.optim import SGD | ||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as T | ||
import torch.nn.functional as F | ||
|
||
sys.path.append('../..') | ||
from common.modules.classifier import Classifier | ||
from common.vision.transforms import ResizeImage | ||
from common.utils.metric import accuracy | ||
from common.utils.meter import AverageMeter, ProgressMeter | ||
from common.utils.data import ForeverDataIterator | ||
from common.utils.logger import CompleteLogger | ||
|
||
sys.path.append('.') | ||
import utils | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
def main(args: argparse.Namespace):
    """Run the supervised baseline: fine-tune a classifier on the labeled
    subset and evaluate on the validation set after every epoch.

    When ``args.phase == 'test'``, the best checkpoint is loaded and only a
    single validation pass is performed.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seed Python and torch RNGs and force deterministic CUDNN kernels.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: ImageNet normalization statistics.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = T.Compose([
        T.RandomResizedCrop(224, scale=(0.2, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        ResizeImage(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    # Build datasets; the middle (unlabeled) split is unused by this baseline.
    labeled_train_dataset, _, val_dataset = utils.get_dataset(
        args.data, args.root, args.sample_rate, train_transform, val_transform)
    labeled_train_iter = ForeverDataIterator(
        DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                   num_workers=args.workers, drop_last=True))
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers)

    # Create the model: pretrained backbone + classification head.
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, labeled_train_dataset.num_classes,
                            pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # Optimizer: plain SGD with Nesterov momentum; per-group LRs come from
    # classifier.get_parameters (backbone vs. head).
    optimizer = SGD(classifier.get_parameters(args.lr), args.lr,
                    momentum=0.9, weight_decay=args.wd, nesterov=True)

    # Test-only phase: restore the best checkpoint and report accuracy.
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # Training loop.
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # Train for one epoch, then evaluate without gradients.
        train(labeled_train_iter, classifier, optimizer, epoch, args)
        with torch.no_grad():
            acc1 = utils.validate(val_loader, classifier, args, device)

        # Always save the latest weights; promote to 'best' on improvement.
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
|
||
|
||
def train(labeled_train_iter: ForeverDataIterator, model, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch of ``args.iters_per_epoch`` labeled batches.

    Updates ``model`` in place via SGD on the cross-entropy loss and prints
    running statistics every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # Switch to train mode (enables dropout / batch-norm updates).
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x, labels = next(labeled_train_iter)
        x = x.to(device)
        labels = labels.to(device)
        n = x.shape[0]

        # Time spent waiting on the data pipeline.
        data_time.update(time.time() - tic)

        # Forward pass; the model returns (logits, features) — features unused.
        logits, _ = model(x)
        loss = F.cross_entropy(logits, labels)

        # Record loss and top-1 accuracy, weighted by batch size.
        losses.update(loss.item(), n)
        cls_acc = accuracy(logits, labels)[0]
        cls_accs.update(cls_acc.item(), n)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Total time for this iteration.
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
|
||
|
||
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Baseline for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
                        metavar='W', help='weight decay (default:5e-4)')
    # Fix: help text previously claimed "default: 4" while the actual default is 2.
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=5, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/usr/bin/env bash
# Supervised baseline (ResNet50) on three fine-grained datasets at
# 15/30/50% label sample rates. One GPU (device 0) per run.

# $1 = data directory, $2 = dataset name, $3 = log-name prefix
run_baseline() {
    local sr
    for sr in 15 30 50; do
        CUDA_VISIBLE_DEVICES=0 python baseline.py "$1" -d "$2" -sr "${sr}" --seed 0 --log "logs/baseline/$3_${sr}"
    done
}

# ResNet50, CUB200
run_baseline data/cub200 CUB200 cub200

# ResNet50, StanfordCars
run_baseline data/stanford_cars StanfordCars car

# ResNet50, Aircraft
run_baseline data/aircraft Aircraft aircraft
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
@author: Baixu Chen | ||
@contact: [email protected] | ||
""" | ||
import numpy as np | ||
from PIL import Image | ||
from torchvision import datasets | ||
|
||
|
||
class CIFAR100(datasets.CIFAR100):
    """CIFAR-100 restricted to an optional subset of sample indices.

    Constructor arguments mirror ``torchvision.datasets.CIFAR100``, with one
    addition: ``idxes`` — an optional index sequence selecting which samples
    to keep (e.g. the labeled subset); ``None`` keeps the full split.
    """

    def __init__(self, root, idxes, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        if idxes is not None:
            # Filter both images and labels down to the requested subset.
            self.data = self.data[idxes]
            self.targets = np.array(self.targets)[idxes]

    def __getitem__(self, index):
        """Return the (transformed) PIL image and target at ``index``."""
        img = Image.fromarray(self.data[index])
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.