Skip to content

Commit

Permalink
corrected error in metric and improved discriminator model
Browse files Browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Jul 20, 2022
1 parent 139563b commit d4c4b0d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
9 changes: 6 additions & 3 deletions src/downscaling/gan/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def wind_speed_rmse(real_output, fake_output):


def angular_cosine_distance(real_output, fake_output):
cos_sim = cosine_similarity(real_output, fake_output)
cos_sim = -cosine_similarity(real_output, fake_output)
# sometimes the keras function returns values just above 1 or below -1
bounded_cos_sim = tf.clip_by_value(cos_sim, -1, 1)
acd = tf.math.acos(bounded_cos_sim) / np.pi
Expand Down Expand Up @@ -160,11 +160,13 @@ def ks_stat_on_patch(patch1, patch2):


@tf.autograph.experimental.do_not_convert
def spatially_convolved_ks_stat(real_output, fake_output):
def spatially_convolved_ks_stat(real_output, fake_output, patch_size=None):
to_concat = []
patch_size = fake_output.shape[2] // 10
patch_size = patch_size or fake_output.shape[2] // 10
i = 0
for time in range(fake_output.shape[1]):
for ch in range(fake_output.shape[-1]):
print(f'Patch {i}/{fake_output.shape[1]*fake_output.shape[-1]}')
patch1 = tf.image.extract_patches(real_output[:, time, ..., ch:ch + 1],
sizes=(1, patch_size, patch_size, 1),
strides=(1, 5, 5, 1),
Expand All @@ -177,6 +179,7 @@ def spatially_convolved_ks_stat(real_output, fake_output):
padding='VALID')
ks_stat_for_time_step = ks_stat_on_patch(patch1, patch2)
to_concat.append(tf.reduce_mean(ks_stat_for_time_step, axis=(1, 2)))
i+=1
mean_ks_stat_img = tf.reduce_mean(to_concat, axis=0)
return mean_ks_stat_img

Expand Down
9 changes: 6 additions & 3 deletions src/downscaling/gan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,18 @@ def make_discriminator(
x = kl.LayerNormalization()(x)

shortcut = x
i = 0
while img_size(x) >= 4:
x = kl.TimeDistributed(kl.ZeroPadding2D())(x)
x = kl.TimeDistributed(
SpectralNormalization(kl.Conv2D(channels(x) * 2, (7, 7), strides=3,
activation=LeakyReLU(0.2))), name=f'conv_{img_size(x)}')(x)
x = kl.LayerNormalization()(x)
shortcut = shortcut_convolution(shortcut, x, channels(x))
# Split connection
x = kl.add([x, shortcut])
i += 1
if i > 1:
shortcut = shortcut_convolution(shortcut, x, channels(x))
# Split connection
x = kl.add([x, shortcut])

while img_size(x) > 2:
x = kl.TimeDistributed(
Expand Down

0 comments on commit d4c4b0d

Please sign in to comment.