Skip to content

Commit edf4cf3

Browse files
committed
replace initializer by __init__ for models and datasets
1 parent e04e31e commit edf4cf3

16 files changed

+52
-60
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ This PyTorch implementation produces results comparable to or better than our or
1212

1313
**Note**: The current software works well with PyTorch 0.4+. Check out the older [branch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/tree/pytorch0.3.1) that supports PyTorch 0.1-0.3.
1414

15-
You may find useful information in [training/test tips](docs/tips.md) and [frequently asked questions](docs/qa.md). To implement your own model and dataset, check out our [templates](#implement-your-own-model-and-dataset).
15+
You may find useful information in [training/test tips](docs/tips.md) and [frequently asked questions](docs/qa.md). To implement your own model and dataset, check out our [templates](#implement-your-own-model-and-dataset).
1616

1717
**CycleGAN: [Project](https://junyanz.github.io/CycleGAN/) | [Paper](https://arxiv.org/pdf/1703.10593.pdf) | [Torch](https://github.com/junyanz/CycleGAN)**
1818
<img src="https://junyanz.github.io/CycleGAN/images/teaser_high_res.jpg" width="800"/>
@@ -29,12 +29,12 @@ You may find useful information in [training/test tips](docs/tips.md) and [frequ
2929

3030
If you use this code for your research, please cite:
3131

32-
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
32+
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
3333
[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](https://people.eecs.berkeley.edu/~isola/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros)
3434
In ICCV 2017. (* equal contributions) [[Bibtex]](https://junyanz.github.io/CycleGAN/CycleGAN.txt)
3535

3636

37-
Image-to-Image Translation with Conditional Adversarial Networks
37+
Image-to-Image Translation with Conditional Adversarial Networks.
3838
[Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros)
3939
In CVPR 2017. [[Bibtex]](http://people.csail.mit.edu/junyanz/projects/pix2pix/pix2pix.bib)
4040

data/__init__.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,13 @@ def get_option_setter(dataset_name):
3535

3636
def create_dataset(opt):
3737
dataset = find_dataset_using_name(opt.dataset_mode)
38-
instance = dataset()
39-
instance.initialize(opt)
38+
instance = dataset(opt)
4039
print("dataset [%s] was created" % (instance.name()))
4140
return instance
4241

4342

4443
def CreateDataLoader(opt):
45-
data_loader = CustomDatasetDataLoader()
46-
data_loader.initialize(opt)
44+
data_loader = CustomDatasetDataLoader(opt)
4745
return data_loader
4846

4947

@@ -53,8 +51,8 @@ class CustomDatasetDataLoader(BaseDataLoader):
5351
def name(self):
5452
return 'CustomDatasetDataLoader'
5553

56-
def initialize(self, opt):
57-
BaseDataLoader.initialize(self, opt)
54+
def __init__(self, opt):
55+
BaseDataLoader.__init__(self, opt)
5856
self.dataset = create_dataset(opt)
5957
self.dataloader = torch.utils.data.DataLoader(
6058
self.dataset,

data/aligned_dataset.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ class AlignedDataset(BaseDataset):
1212
def modify_commandline_options(parser, is_train):
1313
return parser
1414

15-
def initialize(self, opt):
16-
self.opt = opt
17-
self.root = opt.dataroot
15+
def __init__(self, opt):
16+
BaseDataset.__init__(self, opt)
1817
self.dir_AB = os.path.join(opt.dataroot, opt.phase)
1918
self.AB_paths = sorted(make_dataset(self.dir_AB))
2019
assert(opt.resize_or_crop == 'resize_and_crop') # only support this mode

data/base_data_loader.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
class BaseDataLoader():
2-
def __init__(self):
3-
pass
4-
5-
def initialize(self, opt):
2+
def __init__(self, opt):
63
self.opt = opt
74
pass
85

data/base_dataset.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55

66
class BaseDataset(data.Dataset):
7-
def __init__(self):
8-
super(BaseDataset, self).__init__()
7+
def __init__(self, opt):
8+
self.opt = opt
9+
self.root = opt.dataroot
910

1011
def name(self):
1112
return 'BaseDataset'
@@ -14,9 +15,6 @@ def name(self):
1415
def modify_commandline_options(parser, is_train):
1516
return parser
1617

17-
def initialize(self, opt):
18-
pass
19-
2018
def __len__(self):
2119
return 0
2220

data/colorization_dataset.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ def modify_commandline_options(parser, is_train):
1313
parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
1414
return parser
1515

16-
def initialize(self, opt):
17-
self.opt = opt
18-
self.root = opt.dataroot
16+
def __init__(self, opt):
17+
BaseDataset.__init__(self, opt)
1918
self.dir_A = os.path.join(opt.dataroot)
2019
self.A_paths = sorted(make_dataset(self.dir_A))
2120
assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')

data/single_dataset.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ class SingleDataset(BaseDataset):
88
def modify_commandline_options(parser, is_train):
99
return parser
1010

11-
def initialize(self, opt):
12-
self.opt = opt
13-
self.root = opt.dataroot
11+
def __init__(self, opt):
12+
BaseDataset.__init__(self, opt)
1413
self.A_paths = sorted(make_dataset(self.root))
1514
input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
1615
self.transform = get_transform(opt, input_nc == 1)

data/template_dataset.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
"""Dataset class template
22
33
This module provides a templete for users to implement custom datasets.
4+
You need to implement the following functions:
5+
<modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
6+
<__init__>: Initialize this dataset class.
7+
<__getitem__>: Return a data point and its metadata information.
8+
<__len__>: Return the number of images.
9+
<name>: Return the name of this dataset.
410
"""
511
from data.base_dataset import BaseDataset, get_transform
612
# from data.image_folder import make_dataset
@@ -30,19 +36,18 @@ def modify_commandline_options(parser, is_train):
3036
parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
3137
return parser
3238

33-
def initialize(self, opt):
34-
"""Initialize this dataset class
39+
def __init__(self, opt):
40+
"""Initialize this dataset class.
3541
3642
Parameters:
3743
opt -- training/test options
3844
A few things can be done here.
39-
- save the options.
45+
- save the options (have been done in BaseDataset)
4046
- get image paths and meta information of the dataset.
4147
- define the image transformation.
4248
"""
4349
# save the option and dataset root
44-
self.opt = opt
45-
self.root = opt.dataroot
50+
BaseDataset.__init__(self, opt)
4651
# get the image paths of your dataset;
4752
self.image_paths = [] # You can call <sorted(make_dataset(self.root))> to get all the image paths under the directory self.root
4853
# define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
@@ -68,9 +73,9 @@ def __getitem__(self, index):
6873
return {'data_A': data_A, 'data_B': data_B, 'path': path}
6974

7075
def __len__(self):
71-
"""Return the number of images"""
76+
"""Return the total number of images."""
7277
return len(self.image_paths)
7378

7479
def name(self):
75-
"""Return the name of this dataset"""
80+
"""Return the name of this dataset."""
7681
return 'TemplateDataset'

data/unaligned_dataset.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ class UnalignedDataset(BaseDataset):
1010
def modify_commandline_options(parser, is_train):
1111
return parser
1212

13-
def initialize(self, opt):
14-
self.opt = opt
15-
self.root = opt.dataroot
13+
def __init__(self, opt):
14+
BaseDataset.__init__(self, opt)
1615
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
1716
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
1817

models/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def get_option_setter(model_name):
3333

3434
def create_model(opt):
3535
model = find_model_using_name(opt.model)
36-
instance = model()
37-
instance.initialize(opt)
36+
instance = model(opt)
3837
print("model [%s] was created" % (instance.name()))
3938
return instance

models/base_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def modify_commandline_options(parser, is_train):
1515
def name(self):
1616
return 'BaseModel'
1717

18-
def initialize(self, opt):
18+
def __init__(self, opt):
1919
self.opt = opt
2020
self.gpu_ids = opt.gpu_ids
2121
self.isTrain = opt.isTrain

models/cycle_gan_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def modify_commandline_options(parser, is_train=True):
2020

2121
return parser
2222

23-
def initialize(self, opt):
24-
BaseModel.initialize(self, opt)
23+
def __init__(self, opt):
24+
BaseModel.__init__(self, opt)
2525
# specify the training losses you want to print out. The program will call base_model.get_current_losses
2626
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
2727
# specify the images you want to save/display. The program will call base_model.get_current_visuals

models/pix2pix_colorization_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ class Pix2PixColorizationModel(Pix2PixModel):
88
def name(self):
99
return 'Pix2PixColorizationModel'
1010

11-
def initialize(self, opt):
11+
def __init__(self, opt):
1212
# reuse the pix2pix model
13-
Pix2PixModel.initialize(self, opt)
13+
Pix2PixModel.__init__(self, opt)
1414
# specify the images to be visualized.
1515
self.visual_names = ['real_A', 'real_B_rgb', 'fake_B_rgb']
1616

models/pix2pix_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def modify_commandline_options(parser, is_train=True):
1919

2020
return parser
2121

22-
def initialize(self, opt):
23-
BaseModel.initialize(self, opt)
22+
def __init__(self, opt):
23+
BaseModel.__init__(self, opt)
2424
# specify the training losses you want to print out. The program will call base_model.get_current_losses
2525
self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
2626
# specify the images you want to save/display. The program will call base_model.get_current_visuals

models/template_model.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
min_<netG> ||netG(data_A) - data_B||_1
77
You need to implement the following functions:
88
<modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
9-
<initialize>: Initialize this model class
10-
<set_input>: Unpack input data and perform data pre-processing
11-
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>
12-
<backward>: Calculate gradients for network G
13-
<optimize_parameters>: Update network weights; it will be called in every training iteration
9+
<__init__>: Initialize this model class.
10+
<set_input>: Unpack input data and perform data pre-processing.
11+
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
12+
<backward>: Calculate gradients for network weights.
13+
<optimize_parameters>: Update network weights; it will be called in every training iteration.
1414
"""
1515
import torch
1616
from .base_model import BaseModel
@@ -39,8 +39,8 @@ def modify_commandline_options(parser, is_train=True):
3939

4040
return parser
4141

42-
def initialize(self, opt):
43-
"""Initialize this model class
42+
def __init__(self, opt):
43+
"""Initialize this model class.
4444
4545
Parameters:
4646
opt -- training/test options
@@ -49,7 +49,7 @@ def initialize(self, opt):
4949
- (required) call the initialization function of BaseModel
5050
- define loss function, visualization images, model names, and optimizers
5151
"""
52-
BaseModel.initialize(self, opt) # call the initialization method of BaseModel
52+
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
5353
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
5454
self.loss_names = ['loss_G']
5555
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
@@ -83,20 +83,19 @@ def set_input(self, input):
8383
self.path = input['path'] # get image path
8484

8585
def forward(self):
86-
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>"""
86+
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
8787
self.output = self.netG(self.data_A) # generate output image given the input data_A
8888

8989
def backward(self):
90-
"""calculate gradients for network G"""
90+
"""calculate gradients for network weights."""
9191
# caculate the intermediate results if necessary; here self.output has been computed during function <forward>
9292
# calculate loss given the input and intermediate results
9393
self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
9494
self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
9595

9696
def optimize_parameters(self):
97-
"""Update network weights; it will be called in every training iteration"""
97+
"""Update network weights; it will be called in every training iteration."""
9898
self.forward() # first call forward to calculate intermediate results
99-
# update network G
10099
self.optimizer.zero_grad() # clear network G's existing gradients
101100
self.backward() # calculate gradients for network G
102101
self.optimizer.step() # update gradients for network G

models/test_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def modify_commandline_options(parser, is_train=True):
1616

1717
return parser
1818

19-
def initialize(self, opt):
19+
def __init__(self, opt):
2020
assert(not opt.isTrain)
21-
BaseModel.initialize(self, opt)
21+
BaseModel.__init__(self, opt)
2222
# specify the training losses you want to print out. The program will call base_model.get_current_losses
2323
self.loss_names = []
2424
# specify the images you want to save/display. The program will call base_model.get_current_visuals

0 commit comments

Comments
 (0)