Skip to content

Commit

Permalink
added model argument to main.py, fixed some errors
Browse files Browse the repository at this point in the history
  • Loading branch information
bodokaiser committed Apr 15, 2017
1 parent 30b6d4e commit d4960af
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
30 changes: 25 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchvision.transforms import Compose, CenterCrop, ToTensor

from piwise.dataset import VOC12
from piwise.network import FCN8, FCN16, FCN32, SegNet, PSPNet, UNetSeg
from piwise.network import FCN8, FCN16, FCN32, UNet, PSPNet, SegNet1, SegNet2
from piwise.criterion import CrossEntropyLoss2d
from piwise.transform import Relabel, ToLabel, Colorize
from piwise.visualize import Dashboard
Expand Down Expand Up @@ -50,7 +50,27 @@ def evaluate(args, model, loader):
return outputs

def main(args):
model = FCN8(NUM_CHANNELS, NUM_CLASSES)
Net = None

if args.model == 'fcn8':
Net = FCN8
if args.model == 'fcn16':
Net = FCN16
if args.model == 'fcn32':
Net = FCN32
if args.model == 'fcn32':
Net = FCN32
if args.model == 'unet':
Net = UNet
if args.model == 'pspnet':
Net = PSPNet
if args.model == 'segnet1':
Net = SegNet1
if args.model == 'segnet2':
Net = SegNet2
assert Net is not None, f'model {args.model} not available'

model = Net(NUM_CHANNELS, NUM_CLASSES)

loader = DataLoader(VOC12(args.dataroot,
input_transform=Compose([
Expand All @@ -70,14 +90,14 @@ def main(args):

if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--port', type=int, default=80)
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--model', choices=['simple'])
parser.add_argument('--model', required=True)
parser.add_argument('--visualize', choices=['dashboard'])
parser.add_argument('--visualize-loss-steps', type=int, default=50)
parser.add_argument('--visualize-image-steps', type=int, default=50)
parser.add_argument('--num-epochs', type=int, default=32)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--dataroot', nargs='?', default='data')
main(parser.parse_args())
parser.add_argument('--port', type=int, default=80)
main(parser.parse_args())
50 changes: 25 additions & 25 deletions piwise/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def forward(self, x):
return self.down(x)


class UNetSeg(nn.Module):
class UNet(nn.Module):

def __init__(self, num_channels, num_classes):
super().__init__()
Expand Down Expand Up @@ -189,10 +189,10 @@ def forward(self, x):
up1 = self.up1(torch.cat([
up2, F.upsample_bilinear(down1, up2.size()[2:])], 1))

return self.final(up1)
return F.upsample_bilinear(self.final(up1), x.size()[2:])


class BasicSegNetUp(nn.Module):
class SegNet1Up(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
Expand All @@ -207,7 +207,7 @@ def forward(self, x):
return self.up(x)


class BasicSegNetDown(nn.Module):
class SegNet1Down(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
Expand All @@ -223,19 +223,19 @@ def forward(self, x):
return self.down(x)


class BasicSegNet(nn.Module):
class SegNet1(nn.Module):

def __init__(self, num_channels, num_classes):
super().__init__()

self.down1 = BasicSegNetDown(num_channels, 64)
self.down2 = BasicSegNetDown(64, 64)
self.down3 = BasicSegNetDown(64, 64)
self.down4 = BasicSegNetDown(64, 64)
self.up4 = BasicSegNetUp(64, 64)
self.up3 = BasicSegNetUp(128, 64)
self.up2 = BasicSegNetUp(128, 64)
self.up1 = BasicSegNetUp(128, 64)
self.down1 = SegNet1Down(num_channels, 64)
self.down2 = SegNet1Down(64, 64)
self.down3 = SegNet1Down(64, 64)
self.down4 = SegNet1Down(64, 64)
self.up4 = SegNet1Up(64, 64)
self.up3 = SegNet1Up(128, 64)
self.up2 = SegNet1Up(128, 64)
self.up1 = SegNet1Up(128, 64)
self.final = nn.Conv2d(64, num_classes, 1)

def forward(self, x):
Expand All @@ -251,7 +251,7 @@ def forward(self, x):
return self.final(up1)


class SegNetUp(nn.Module):
class SegNet2Up(nn.Module):

def __init__(self, in_channels, out_channels, layers):
super().__init__()
Expand All @@ -278,7 +278,7 @@ def forward(self, x):
return self.up(x)


class SegNetDown(nn.Module):
class SegNet2Down(nn.Module):

def __init__(self, in_channels, out_channels, layers):
super().__init__()
Expand All @@ -302,20 +302,20 @@ def forward(self, x):
return self.down(x)


class SegNet(nn.Module):
class SegNet2(nn.Module):

def __init__(self, num_channels, num_classes):
super().__init__()

self.down1 = SegNetDown(num_channels, 64, layers=1)
self.down2 = SegNetDown(64, 128, layers=1)
self.down3 = SegNetDown(128, 256, layers=2)
self.down4 = SegNetDown(256, 512, layers=2)
self.down5 = SegNetDown(512, 512, layers=2)
self.up5 = SegNetUp(512, 512, layers=1)
self.up4 = SegNetUp(1024, 256, layers=1)
self.up3 = SegNetUp(512, 128, layers=1)
self.up2 = SegNetUp(256, 64, layers=0)
self.down1 = SegNet2Down(num_channels, 64, layers=1)
self.down2 = SegNet2Down(64, 128, layers=1)
self.down3 = SegNet2Down(128, 256, layers=2)
self.down4 = SegNet2Down(256, 512, layers=2)
self.down5 = SegNet2Down(512, 512, layers=2)
self.up5 = SegNet2Up(512, 512, layers=1)
self.up4 = SegNet2Up(1024, 256, layers=1)
self.up3 = SegNet2Up(512, 128, layers=1)
self.up2 = SegNet2Up(256, 64, layers=0)
self.up1 = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(128, 64, 3, padding=1),
Expand Down

0 comments on commit d4960af

Please sign in to comment.