Skip to content

Commit

Permalink
Bug fix in freeze encoder (coqui-ai#1391)
Browse files Browse the repository at this point in the history
* Fix the bug in freeze encoder

* Remove emb_l definition for non-multilingual training

* Fix unit tests
  • Loading branch information
Edresson authored Mar 24, 2022
1 parent 464dc65 commit 37896e1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
1 change: 0 additions & 1 deletion TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,6 @@ def init_multilingual(self, config: Coqpit):
torch.nn.init.xavier_uniform_(self.emb_l.weight)
else:
self.embedded_language_dim = 0
self.emb_l = None

def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
Expand Down
8 changes: 4 additions & 4 deletions tests/tts_tests/test_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,25 @@ def test_init_multilingual(self):
model = Vits(args)
self.assertEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, 0)
self.assertEqual(model.emb_l, None)
assertHasNotAttr(self, model, "emb_l")

args = VitsArgs(language_ids_file=LANG_FILE)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, 0)
self.assertEqual(model.emb_l, None)
assertHasNotAttr(self, model, "emb_l")

args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
self.assertNotEqual(model.emb_l, None)
assertHasAttr(self, model, "emb_l")

args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
model = Vits(args)
self.assertNotEqual(model.language_manager, None)
self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
self.assertNotEqual(model.emb_l, None)
assertHasAttr(self, model, "emb_l")

def test_get_aux_input(self):
aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
Expand Down

0 comments on commit 37896e1

Please sign in to comment.