Skip to content

Commit

Permalink
Fix inference on cpu device (babysor#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzxiang authored Nov 29, 2021
1 parent a4daf42 commit 4728863
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion synthesizer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def load(self):
stop_threshold=hparams.tts_stop_threshold,
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)

self._model.load(self.model_fpath)
self._model.load(self.model_fpath, self.device)
self._model.eval()

if self.verbose:
Expand Down
8 changes: 5 additions & 3 deletions synthesizer/models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,9 @@ def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_
# put after encoder
if hparams.use_gst and self.gst is not None:
if style_idx >= 0 and style_idx < 10:
query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
if device.type == 'cuda':
query = query.cuda()
gst_embed = torch.tanh(self.gst.stl.embed)
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
style_embed = self.gst.stl.attention(query, key)
Expand Down Expand Up @@ -539,9 +541,9 @@ def log(self, path, msg):
with open(path, "a") as f:
print(msg, file=f)

def load(self, path, optimizer=None):
def load(self, path, device, optimizer=None):
# Use device of model params as location for loaded state
checkpoint = torch.load(str(path))
checkpoint = torch.load(str(path), map_location=device)
self.load_state_dict(checkpoint["model_state"], strict=False)

if "optimizer_state" in checkpoint and optimizer is not None:
Expand Down
2 changes: 1 addition & 1 deletion synthesizer/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def run_synthesis(in_dir, out_dir, model_dir, hparams):
model_dir = Path(model_dir)
model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
print("\nLoading weights at %s" % model_fpath)
model.load(model_fpath)
model.load(model_fpath, device)
print("Tacotron weights loaded from step %d" % model.step)

# Synthesize using same reduction factor as the model is currently trained
Expand Down
2 changes: 1 addition & 1 deletion synthesizer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,

else:
print("\nLoading weights at %s" % weights_fpath)
model.load(weights_fpath, optimizer)
model.load(weights_fpath, device, optimizer)
print("Tacotron weights loaded from step %d" % model.step)

# Initialize the dataset
Expand Down

0 comments on commit 4728863

Please sign in to comment.