Skip to content

Commit

Permalink
separate source, target, validation noise model
Browse files Browse the repository at this point in the history
  • Loading branch information
yusuke-a-uchida committed Jul 22, 2018
1 parent c07651c commit 3185f7e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ def __getitem__(self, idx):


class ValGenerator(Sequence):
def __init__(self, image_dir, source_noise_model):
def __init__(self, image_dir, val_noise_model):
image_paths = list(Path(image_dir).glob("*.*"))
self.image_num = len(image_paths)
self.data = []

for image_path in image_paths:
y = cv2.imread(str(image_path))
x = source_noise_model(y)
x = val_noise_model(y)
self.data.append([np.expand_dims(x, axis=0), np.expand_dims(y, axis=0)])

def __len__(self):
Expand Down
7 changes: 5 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def get_args():
help="checkpoint dir")
parser.add_argument("--source_noise_model", type=str, default="gaussian,0,50",
help="noise model for source images")
parser.add_argument("--target_noise_model", type=str, default="gaussian,25,25",
parser.add_argument("--target_noise_model", type=str, default="gaussian,0,50",
help="noise model for target images")
parser.add_argument("--val_noise_model", type=str, default="gaussian,25,25",
help="noise model for target images")
args = parser.parse_args()

Expand All @@ -66,9 +68,10 @@ def main():
model.compile(optimizer=opt, loss="mse", metrics=[PSNR])
source_noise_model = get_noise_model(args.source_noise_model)
target_noise_model = get_noise_model(args.target_noise_model)
val_noise_model = get_noise_model(args.val_noise_model)
generator = NoisyImageGenerator(image_dir, source_noise_model, target_noise_model, batch_size=batch_size,
image_size=image_size)
val_generator = ValGenerator(test_dir, source_noise_model)
val_generator = ValGenerator(test_dir, val_noise_model)
output_path.mkdir(parents=True, exist_ok=True)
callbacks = [
LearningRateScheduler(schedule=Schedule(nb_epochs, lr)),
Expand Down

0 comments on commit 3185f7e

Please sign in to comment.