Skip to content

Commit

Permalink
Merge pull request #16 from developer0hye/patch-1
Browse files Browse the repository at this point in the history
Cleanup
  • Loading branch information
weiaicunzai authored Aug 2, 2020
2 parents 440781d + 0d41819 commit 06105d0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get_lr(self):
parser.add_argument('-base_lr', type=float, default=1e-7, help='min learning rate')
parser.add_argument('-max_lr', type=float, default=10, help='max learning rate')
parser.add_argument('-num_iter', type=int, default=100, help='num of iteration')
parser.add_argument('-gpu', type=bool, default=True, help='use gpu or not')
parser.add_argument('-gpus', nargs='+', type=int, default=0, help='gpu device')
args = parser.parse_args()

Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def eval_training(epoch):
parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate')
args = parser.parse_args()

net = get_network(args, use_gpu=args.gpu)

net = get_network(args)
#data preprocessing:
cifar100_training_loader = get_training_dataloader(
settings.CIFAR100_TRAIN_MEAN,
Expand Down
6 changes: 3 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.utils.data import DataLoader


def get_network(args, use_gpu=True):
def get_network(args):
""" return given network
"""

Expand Down Expand Up @@ -140,7 +140,7 @@ def get_network(args, use_gpu=True):
print('the network name you have entered is not supported yet')
sys.exit()

if use_gpu:
if args.gpu: #use_gpu
net = net.cuda()

return net
Expand Down Expand Up @@ -229,4 +229,4 @@ def get_lr(self):
"""we will use the first m batches, and set the learning
rate to base_lr * m / total_iters
"""
return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]

0 comments on commit 06105d0

Please sign in to comment.