Skip to content

Commit

Permalink
Remove unused test dl
Browse files Browse the repository at this point in the history
  • Loading branch information
enhuiz committed Jan 18, 2023
1 parent 2e9f503 commit d80ef1d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
43 changes: 14 additions & 29 deletions vall_e/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,20 +244,14 @@ def _load_train_val_paths():
return train_paths, val_paths


def _load_test_paths():
    """Collect every ``*.phn.txt`` file under the configured test data dirs.

    Returns:
        A sorted list of paths, aggregated across all ``cfg.test_data_dirs``.
    """
    paths = []
    for data_dir in cfg.test_data_dirs:
        # rglob walks the directory tree recursively.
        paths.extend(data_dir.rglob("*.phn.txt"))
    return sorted(paths)


@cfg.diskcache()
def create_datasets():
train_paths, val_paths = _load_train_val_paths()
test_paths = _load_test_paths()

train_dataset = VALLEDatset(train_paths, training=True)
train_dataset = VALLEDatset(
train_paths,
training=True,
)

val_dataset = VALLEDatset(
val_paths,
Expand All @@ -269,41 +263,32 @@ def create_datasets():
val_dataset.interleaved_reorder_(_get_spkr_name)
val_dataset.head_(cfg.max_num_val)

test_dataset = VALLEDatset(
test_paths,
train_dataset.phone_symmap,
train_dataset.spkr_symmap,
extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
)

return train_dataset, val_dataset, test_dataset
return train_dataset, val_dataset


def create_train_val_dataloader():
    """Build the train, subtrain, and validation dataloaders.

    ``subtrain`` is a small eval-mode copy of the training set — reordered to
    interleave speakers and capped at ``cfg.max_num_val`` — so that metrics on
    training data are directly comparable to the validation metrics.

    Returns:
        (train_dl, subtrain_dl, val_dl)
    """
    # NOTE(review): the scraped diff contained both the pre- and post-commit
    # lines of this function; this is the reconstructed post-commit version
    # (test dataloader removed, train_for_val renamed to subtrain).
    train_dataset, val_dataset = create_datasets()

    train_dl = _create_dl(train_dataset, training=True)
    val_dl = _create_dl(val_dataset, training=False)

    _logger.info(str(train_dataset.phone_symmap))
    _logger.info(str(train_dataset.spkr_symmap))

    _logger.info(f"#samples (train): {len(train_dataset)}.")
    _logger.info(f"#samples (val): {len(val_dataset)}.")

    # Deep-copy so reordering/truncation does not disturb the real train set.
    subtrain_dataset = copy.deepcopy(train_dataset)
    subtrain_dataset.interleaved_reorder_(_get_spkr_name)
    subtrain_dataset.head_(cfg.max_num_val)
    subtrain_dataset.training_(False)
    subtrain_dl = _create_dl(subtrain_dataset, training=False)
    assert isinstance(subtrain_dl.dataset, VALLEDatset)

    return train_dl, subtrain_dl, val_dl


if __name__ == "__main__":
    # Smoke test: build the loaders and print one training sample.
    # (Post-commit form — the unused test dataloader is no longer returned.)
    train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
    sample = train_dl.dataset[0]
    print(sample)
8 changes: 3 additions & 5 deletions vall_e/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def load_engines():
def main():
setup_logging(cfg.log_dir)

train_dl, train_for_val_dl, val_dl, test_dl = create_train_val_dataloader()
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

def train_feeder(engines, batch, name):
model = engines["model"]
Expand Down Expand Up @@ -68,8 +68,7 @@ def run_eval(engines, name, dl):
log_dir = cfg.log_dir / str(engines.global_step) / name
stats = defaultdict(list)
for batch in tqdm(dl):
batch: dict
batch = to_device(batch, cfg.device)
batch: dict = to_device(batch, cfg.device)

if cfg.model.startswith("ar"):
resp_list = model(
Expand Down Expand Up @@ -114,9 +113,8 @@ def run_eval(engines, name, dl):
_logger.info(f"{json.dumps(stats)}.")

def eval_fn(engines):
run_eval(engines, "train_for_val", train_for_val_dl)
run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl)
run_eval(engines, "test", test_dl)

trainer.train(
engines_loader=load_engines,
Expand Down

0 comments on commit d80ef1d

Please sign in to comment.