Exploding smoothing update
osoblanco committed Dec 24, 2021
1 parent d3d8062 commit cf3b498
Showing 3 changed files with 14 additions and 10 deletions.
config/LJSpeech/train.yaml (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@ optimizer:
   eps: 0.000000001
   weight_decay: 0.0
   grad_clip_thresh: 1.0
-  grad_acc_step: 10
+  grad_acc_step: 1
   warm_up_step: 4000
   anneal_steps: [300000, 400000, 500000]
   anneal_rate: 0.3
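
The yaml change drops grad_acc_step from 10 to 1, so gradients are no longer accumulated across batches and the optimizer steps after every backward pass. Below is a toy, self-contained sketch of how a grad_acc_step setting of this kind is typically consumed in a training loop; the model, data, and plain Adam optimizer are placeholders, not the repository's actual train.py.

# Toy illustration of gradient accumulation controlled by grad_acc_step
# (assumed semantics; model, data, and optimizer here are placeholders).
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
grad_acc_step = 1  # this commit lowers the value from 10 to 1

for step in range(1, 21):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    # Scale the loss so the accumulated gradient matches one large batch.
    loss = nn.functional.mse_loss(model(x), y) / grad_acc_step
    loss.backward()  # gradients accumulate in .grad until zero_grad()
    if step % grad_acc_step == 0:  # with grad_acc_step = 1 this fires every step
        optimizer.step()
        optimizer.zero_grad()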

model/modules.py (4 changes: 2 additions & 2 deletions)

@@ -121,7 +121,7 @@ def forward(
             x = x + pitch_embedding
         if self.energy_feature_level == "phoneme_level":
             energy_prediction, energy_embedding = self.get_energy_embedding(
-                x, energy_target, src_mask, p_control
+                x, energy_target, src_mask, e_control
             )
             x = x + energy_embedding

@@ -143,7 +143,7 @@ def forward(
             x = x + pitch_embedding
         if self.energy_feature_level == "frame_level":
             energy_prediction, energy_embedding = self.get_energy_embedding(
-                x, energy_target, mel_mask, p_control
+                x, energy_target, mel_mask, e_control
             )
             x = x + energy_embedding
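
Both hunks fix the same copy-paste bug: the energy branch of the variance adaptor was receiving the pitch control factor p_control, so the e_control argument never reached the energy predictor. For reference, a sketch of what a FastSpeech2-style get_energy_embedding helper with this signature usually does is shown below; it is inferred from the call site and from common FastSpeech2 implementations, not copied from this repository's modules.py, and it assumes the surrounding class defines energy_predictor, energy_embedding, and energy_bins.

# Assumed shape of the helper, inferred from the call site; the actual method in
# model/modules.py may differ in detail.
def get_energy_embedding(self, x, target, mask, control):
    prediction = self.energy_predictor(x, mask)
    if target is not None:
        # Training: embed the ground-truth energy values.
        embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
    else:
        # Inference: the control factor scales the predicted energy, which is why
        # passing p_control here silently tied energy control to the pitch setting.
        prediction = prediction * control
        embedding = self.energy_embedding(torch.bucketize(prediction, self.energy_bins))
    return prediction, embedding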

train.py (18 changes: 11 additions & 7 deletions)

@@ -21,9 +21,10 @@
 import PIL
 
 import matplotlib.pyplot as plt
-# import plotly
-# import plotly.plotly as py
-# import plotly.tools as tls
+from chart_studio import plotly
+import chart_studio.plotly as py
+import plotly.tools as tls
+import math
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -109,14 +110,17 @@ def main(args, configs):
                 # Backward
                 total_loss = total_loss / grad_acc_step
 
+
                 total_loss.backward()
 
                 if step % grad_acc_step == 0:
                     # Clipping gradients to avoid gradient explosion
-                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
+                    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
 
-                    # Update weights
-                    optimizer.step_and_update_lr()
+                    if math.isnan(grad_norm):
+                        print("grad_norm is nan. Not Updating.")
+                    else:
+                        optimizer.step_and_update_lr()
                     optimizer.zero_grad()
 
                 if step % log_step == 0:
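
The train.py change captures the total gradient norm returned by nn.utils.clip_grad_norm_ and skips the optimizer update whenever that norm is NaN, so a single exploded batch cannot write non-finite values into the weights. A self-contained sketch of the same clip-then-check pattern follows; a plain SGD optimizer stands in for the project's step_and_update_lr() scheduler wrapper.

# Self-contained sketch of the clip-then-check pattern added in this commit;
# plain SGD stands in for the project's scheduled optimizer.
import math
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_clip_thresh = 1.0

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# clip_grad_norm_ rescales gradients in place and returns the total norm measured
# *before* clipping, so a NaN here means backward produced non-finite gradients.
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

if math.isnan(grad_norm):
    print("grad_norm is nan. Not Updating.")
else:
    optimizer.step()
optimizer.zero_grad()

Note that math.isnan() does not catch an infinite norm; checking torch.isfinite(grad_norm) instead would cover both cases if that ever matters.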
