Skip to content

Commit

Permalink
updated main.py to use recommend optimizer params
Browse files Browse the repository at this point in the history
  • Loading branch information
bodokaiser committed Apr 18, 2017
1 parent eb225c2 commit aa063d6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
21 changes: 15 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,24 @@
def train(args, model):
model.train(True)

if args.steps_plot > 0:
board = Dashboard(args.port)

loader = DataLoader(VOC12(args.datadir, input_transform, target_transform),
num_workers=args.num_workers, batch_size=args.batch_size)

optimizer = SGD(model.parameters(), lr=.0001, momentum=.9, weight_decay=.04)
num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
criterion = CrossEntropyLoss2d()

optimizer = Adam(model.parameters())
if args.model.startswith('FCN'):
optimizer = SGD(model.parameters(), 1e-4, .9, 2e-5)
if args.model.startswith('PSP'):
optimizer = SGD(model.parameters(), 1e-2, .9, 1e-4)
if args.model.startswith('Seg'):
optimizer = SGD(model.parameters(), 1e-3, .9)

if args.cuda:
criterion = criterion.cuda()

if args.steps_plot > 0:
board = Dashboard(args.port)

for epoch in range(1, args.num_epochs+1):
epoch_loss = []

Expand Down
18 changes: 15 additions & 3 deletions piwise/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@
import torch.nn.functional as F

from torchvision import models
from torchvision.transforms import Normalize

def vgg_normalize(images):
images[:, 0] -= .485
images[:, 0] /= .229
images[:, 1] -= .456
images[:, 1] /= .224
images[:, 2] -= .406
images[:, 2] /= .225

class FCN(nn.Module):

MEAN = [.485, .456, .406]
STD = [.229, .224, .225]

def __init__(self, num_classes):
super().__init__()

Expand All @@ -27,11 +37,13 @@ def __init__(self, num_classes):
nn.Dropout(),
)
self.score_fconn = nn.Conv2d(4096, num_classes, 1)
self.normalize = Normalize([.485, .456, .406], [.229, .224, .225])

def normalize(self, x):
for i in range(3):
x[:, i] = (x[:, i] - self.MEAN[i]) / self.STD[i]

def forward(self, x):
x = self.normalize(x)
self.normalize(x)
x = self.feat1(x)
x = self.feat2(x)
x = self.feat3(x)
Expand Down
2 changes: 1 addition & 1 deletion piwise/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def __call__(self, gray_image):
color_image[1][mask] = self.cmap[label][1]
color_image[2][mask] = self.cmap[label][2]

return color_image
return color_image

0 comments on commit aa063d6

Please sign in to comment.