
Update pylint to 2.10.2 and fix lint issues
erogol committed Aug 30, 2021
1 parent ccef20b commit 18da8f5
Showing 24 changed files with 60 additions and 105 deletions.
2 changes: 1 addition & 1 deletion TTS/__init__.py
@@ -1,6 +1,6 @@
import os

with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
with open(os.path.join(os.path.dirname(__file__), "VERSION"), 'r', encoding='utf-8') as f:
version = f.read().strip()

__version__ = version
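
The pattern repeated throughout this commit: without an explicit encoding, open() falls back to the user's locale, so the same text file can decode differently across machines; the pylint upgrade (presumably its unspecified-encoding / W1514 check) now flags that. A minimal before/after sketch, assuming a VERSION file sits next to the script:

import locale
import os

print(locale.getpreferredencoding(False))  # locale-dependent, e.g. "UTF-8" or "cp1252"

version_path = os.path.join(os.path.dirname(__file__), "VERSION")

# Implicit encoding: depends on the locale above; this is what pylint flags.
with open(version_path) as f:  # pylint: disable=unspecified-encoding
    version = f.read().strip()

# Explicit encoding: the form this commit switches to.
with open(version_path, "r", encoding="utf-8") as f:
    version = f.read().strip()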
2 changes: 1 addition & 1 deletion TTS/bin/compute_attention_masks.py
@@ -158,7 +158,7 @@
# ourput metafile
metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

with open(metafile, "w") as f:
with open(metafile, "w", encoding="utf-8") as f:
for p in file_paths:
f.write(f"{p[0]}|{p[1]}\n")
print(f" >> Metafile created: {metafile}")
2 changes: 1 addition & 1 deletion TTS/bin/extract_tts_spectrograms.py
@@ -215,7 +215,7 @@ def extract_spectrograms(
wav = ap.inv_melspectrogram(mel)
ap.save_wav(wav, wav_gl_path)

with open(os.path.join(output_path, metada_name), "w") as f:
with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
for data in export_metadata:
f.write(f"{data[0]}|{data[1]+'.npy'}\n")

18 changes: 7 additions & 11 deletions TTS/model.py
@@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
"""

@abstractmethod
def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
"""Forward pass for the model mainly used in training.
You can be flexible here and use different number of arguments and argument names since it is mostly used by
`train_step()` in training whitout exposing it to the out of the class.
You can be flexible here and use different number of arguments and argument names since it is intended to be
used by `train_step()` without exposing it out of the model.
Args:
text (torch.Tensor): Input text character sequence ids.
input (torch.Tensor): Input tensor.
aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
for the model.
Returns:
Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
Dict: Model outputs. Main model output must be named as "model_outputs".
"""
outputs_dict = {"model_outputs": None}
...
return outputs_dict

@abstractmethod
def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
"""Forward pass for inference.
After the model is trained this is the only function that connects the model the out world.
This function must only take a `text` input and a dictionary that has all the other model specific inputs.
We don't use `*kwargs` since it is problematic with the TorchScript API.
Args:
text (torch.Tensor): [description]
input (torch.Tensor): [description]
aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
Returns:
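
The reworded docstrings spell out the contract concrete models follow: forward() may take any extra arguments because only train_step() calls it, inference() keeps a fixed (input, aux_input) signature with no **kwargs so TorchScript export keeps working, and both return a dict whose main artifact is keyed "model_outputs". A hypothetical minimal subclass illustrating that contract (ToyModel and its single layer are not part of the repository):

from typing import Dict

import torch
from torch import nn


class ToyModel(nn.Module):
    """Illustrative stand-in for a TTS.model.BaseModel subclass."""

    def __init__(self, dim: int = 80):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    # Free-form signature: only train_step() calls this inside the class.
    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:  # pylint: disable=redefined-builtin,dangerous-default-value
        return {"model_outputs": self.net(input)}

    # Fixed signature, no **kwargs, to stay TorchScript-friendly.
    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:  # pylint: disable=redefined-builtin,dangerous-default-value
        with torch.no_grad():
            return {"model_outputs": self.net(input)}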
2 changes: 1 addition & 1 deletion TTS/speaker_encoder/losses.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn


# adapted from https://github.com/cvqluu/GE2E-Loss
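
The same one-line cleanup recurs in several files of this commit (resnet.py, transformer.py, gst_layers.py, align_tts.py below): "import torch.nn as nn" becomes "from torch import nn", the form pylint's consider-using-from-import check asks for; call sites are unchanged. A trivial self-contained illustration:

import torch

# import torch.nn as nn   # old form, flagged as consider-using-from-import
from torch import nn       # preferred form; usage stays the same

layer = nn.Linear(16, 16)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])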
2 changes: 1 addition & 1 deletion TTS/speaker_encoder/models/resnet.py
@@ -1,6 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
from torch import nn

from TTS.utils.io import load_fsspec

3 changes: 2 additions & 1 deletion TTS/speaker_encoder/utils/prepare_voxceleb.py
@@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
extract_path = zip_filepath.strip(".zip")

# check zip file md5sum
md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
with open(zip_filepath, "rb") as f_zip:
md5 = hashlib.md5(f_zip.read()).hexdigest()
if md5 != MD5SUM[subset]:
raise ValueError("md5sum of %s mismatch" % zip_filepath)

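
The md5 fix above closes the archive handle deterministically by moving the read into a with block (silencing pylint's consider-using-with). For multi-gigabyte VoxCeleb archives it can also be worth hashing in chunks rather than reading the whole zip into memory; file_md5 below is an illustrative helper, not part of the repository:

import hashlib


def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks to keep memory usage flat."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


# Usage mirroring the check in download_and_extract():
# if file_md5(zip_filepath) != MD5SUM[subset]:
#     raise ValueError("md5sum of %s mismatch" % zip_filepath)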
11 changes: 6 additions & 5 deletions TTS/trainer.py
@@ -631,13 +631,13 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_ti
outputs = outputs_per_optimizer

# update avg runtime stats
keep_avg_update = dict()
keep_avg_update = {}
keep_avg_update["avg_loader_time"] = loader_time
keep_avg_update["avg_step_time"] = step_time
self.keep_avg_train.update_values(keep_avg_update)

# update avg loss stats
update_eval_values = dict()
update_eval_values = {}
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_train.update_values(update_eval_values)
@@ -797,7 +797,7 @@ def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
loss_dict = self._detach_loss_dict(loss_dict)

# update avg stats
update_eval_values = dict()
update_eval_values = {}
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_eval.update_values(update_eval_values)
@@ -977,12 +977,13 @@ class Logger(object):
def __init__(self, print_to_terminal=True):
self.print_to_terminal = print_to_terminal
self.terminal = sys.stdout
self.log = open(log_file, "a")
self.log_file = log_file

def write(self, message):
if self.print_to_terminal:
self.terminal.write(message)
self.log.write(message)
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(message)

def flush(self):
# this flush method is needed for python 3 compatibility.
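
The Logger rewrite above swaps a file handle held open for the object's lifetime (self.log = open(log_file, "a")) for a stored path that is opened on every write, which is what pylint's consider-using-with pushes towards, and it also picks up an explicit encoding. A condensed sketch; log_file is passed to the constructor here only to keep the example self-contained (the real class reads it from its enclosing scope):

import sys


class Logger:
    def __init__(self, log_file, print_to_terminal=True):
        self.print_to_terminal = print_to_terminal
        self.terminal = sys.stdout
        self.log_file = log_file  # store the path, not an open handle

    def write(self, message):
        if self.print_to_terminal:
            self.terminal.write(message)
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(message)

    def flush(self):
        """No-op; needed so the object can stand in for sys.stdout."""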
2 changes: 1 addition & 1 deletion TTS/tts/datasets/__init__.py
@@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):

def load_attention_mask_meta_data(metafile_path):
"""Load meta data file created by compute_attention_masks.py"""
with open(metafile_path, "r") as f:
with open(metafile_path, "r", encoding="utf-8") as f:
lines = f.readlines()

meta_data = []
34 changes: 17 additions & 17 deletions TTS/tts/datasets/formatters.py
@@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "tweb"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
@@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = cols[1].strip()
@@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
continue
speaker_name = speaker_name_match.group("speaker_name")
print(" | > {}".format(csv_file))
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
if meta_files is None:
@@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file):
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
text = cols[2]
items.append([text, wav_file, speaker_name])
return items

@@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file):
for idx, line in enumerate(ttf):
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
text = cols[2]
items.append([text, wav_file, f"ljspeech-{idx}"])
return items

@@ -158,7 +158,7 @@ def css10(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
@@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "nancy"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
@@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
"""Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("client_id"):
continue
@@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):

for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split(".")[0]
with open(meta_file, "r") as ttf:
with open(meta_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("\t")
file_name = cols[0]
@@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
"""BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("wav_filename"):
continue
@@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
continue
with open(meta_file) as file_text:
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_" + speaker_id])
@@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
def mls(root_path, meta_files=None):
"""http://www.openslr.org/94/"""
items = []
with open(os.path.join(root_path, meta_files), "r") as meta:
with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
for line in meta:
file, text = line.split("\t")
text = text[:-1]
@@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):

# if not exists meta file, crawl recursively for 'wav' files
if meta_file is not None:
with open(str(meta_file), "r") as f:
with open(str(meta_file), "r", encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()]

elif not cache_to.exists():
@@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
text = None # VoxCel does not provide transciptions, and they are not needed for training the SE
meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
cnt += 1
with open(str(cache_to), "w") as f:
with open(str(cache_to), "w", encoding="utf-8") as f:
f.write("".join(meta_data))
if cnt < expected_count:
raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")

with open(str(cache_to), "r") as f:
with open(str(cache_to), "r", encoding="utf-8") as f:
return [x.strip().split("|") for x in f.readlines()]


@@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "baker"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
@@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "kokoro"
with open(txt_file, "r") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
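
On the cols[1] to cols[2] change in the ljspeech and ljspeech_test formatters above: each line of LJSpeech's metadata.csv carries three pipe-separated fields (file id, raw transcript, and normalized transcript with numbers and abbreviations spelled out), so column 2 is the text a TTS model should train on. A small sketch with an illustrative, not verbatim, line:

line = "LJ001-0001|Printed in 1469|Printed in fourteen sixty-nine"
cols = line.strip().split("|")

file_id, raw_text, normalized_text = cols
text = cols[2]  # the normalized transcript, as the fixed formatters now read
print(file_id, text)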
2 changes: 1 addition & 1 deletion TTS/tts/layers/generic/transformer.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn


class FFTransformer(nn.Module):
2 changes: 1 addition & 1 deletion TTS/tts/layers/tacotron/gst_layers.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn


class GST(nn.Module):
4 changes: 2 additions & 2 deletions TTS/tts/layers/tacotron/tacotron.py
@@ -388,8 +388,8 @@ def decode(self, inputs, mask=None):
decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))

# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
for idx, decoder_rnn in enumerate(self.decoder_rnns):
self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
# Residual connection
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
decoder_output = decoder_input
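
The decode() change above is a consider-using-enumerate fix: iterate (index, module) pairs instead of indexing the list via range(len(...)). A runnable miniature of the same loop with illustrative sizes (the real decoder uses its own hidden dimensions):

import torch
from torch import nn

# Stand-ins for self.decoder_rnns / self.decoder_rnn_hiddens.
decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
decoder_rnn_hiddens = [torch.zeros(1, 256) for _ in decoder_rnns]
decoder_input = torch.zeros(1, 256)

for idx, decoder_rnn in enumerate(decoder_rnns):
    decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, decoder_rnn_hiddens[idx])
    decoder_input = decoder_rnn_hiddens[idx] + decoder_input  # residual connection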
2 changes: 1 addition & 1 deletion TTS/tts/models/align_tts.py
@@ -2,8 +2,8 @@
from typing import Dict, Tuple

import torch
import torch.nn as nn
from coqpit import Coqpit
from torch import nn

from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.decoder import Decoder
4 changes: 2 additions & 2 deletions TTS/tts/models/vits.py
@@ -435,7 +435,7 @@ def forward(
attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g,
)
loss_duration = loss_duration/ torch.sum(x_mask)
loss_duration = loss_duration / torch.sum(x_mask)
else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor(
@@ -579,7 +579,7 @@ def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> T
scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"],
loss_duration=outputs["loss_duration"]
loss_duration=outputs["loss_duration"],
)

elif optimizer_idx == 1:
43 changes: 0 additions & 43 deletions TTS/utils/distribute.py
@@ -18,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):

# Initialize distributed communication
dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)


def apply_gradient_allreduce(module):

# sync model parameters
for p in module.state_dict().values():
if not torch.is_tensor(p):
continue
dist.broadcast(p, 0)

def allreduce_params():
if module.needs_reduction:
module.needs_reduction = False
# bucketing params based on value types
buckets = {}
for param in module.parameters():
if param.requires_grad and param.grad is not None:
tp = type(param.data)
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
coalesced /= dist.get_world_size()
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

for param in list(module.parameters()):

def allreduce_hook(*_):
Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access

if param.requires_grad:
param.register_hook(allreduce_hook)

def set_needs_reduction(self, *_):
self.needs_reduction = True

module.register_forward_hook(set_needs_reduction)
return module
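
The deleted apply_gradient_allreduce() hand-rolled what torch.nn.parallel.DistributedDataParallel provides out of the box: broadcasting parameters from rank 0 and averaging gradients across ranks during backward, which is presumably why the helper could go. A hedged sketch of the built-in replacement; wrap_for_ddp is illustrative and assumes the process group has already been set up (e.g. via init_distributed above) on a GPU rank:

from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: nn.Module, rank: int) -> nn.Module:
    model = model.to(rank)
    # DDP broadcasts parameters at construction time and all-reduces
    # gradients during backward, replacing the removed helper.
    return DDP(model, device_ids=[rank])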
