Skip to content

Commit

Permalink
fix type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jun 28, 2021
1 parent 87c61d2 commit 3021151
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions TTS/tts/utils/speakers.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def get_x_vectors_by_speaker(self, speaker_idx: str) -> List[List]:
"""
return [x["embedding"] for x in self.x_vectors.values() if x["name"] == speaker_idx]

def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.Array:
def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
"""Get mean x_vector of a speaker ID.
Args:
Expand All @@ -250,7 +250,7 @@ def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize
randomize (bool, optional): Pick random `num_samples`of x_vectors. Defaults to False.
Returns:
np.Array: Mean x_vector.
np.ndarray: Mean x_vector.
"""
x_vectors = self.get_x_vectors_by_speaker(speaker_idx)
if num_samples is None:
Expand Down Expand Up @@ -315,11 +315,11 @@ def _compute(wav_file: str):
x_vector = _compute(wav_file)
return x_vector[0].tolist()

def compute_x_vector(self, feats: Union[torch.Tensor, np.Array]) -> List:
def compute_x_vector(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
"""Compute x_vector from features.
Args:
feats (Union[torch.Tensor, np.Array]): Input features.
feats (Union[torch.Tensor, np.ndarray]): Input features.
Returns:
List: computed x_vector.
Expand Down

0 comments on commit 3021151

Please sign in to comment.