Skip to content

Commit

Permalink
Suffix W&B runs with task group ID for offline Taskcluster publication from GCP (#799)
Browse files Browse the repository at this point in the history

* Use task group ID as suffix for offline Taskcluster publication from GCP

* Fix group_logs publication

* Add a mode to support GCP experiments from Taskcluster in a generic way

* Fix metrics path for GCP experiments that ran on Taskcluster

* Ignore old snakemake metrics that cannot be parsed

* Update tests

* Do not parse metrics name for new GCP experiments (taskcluster)

* Add tests for metrics filename parser

* Add a parser for GCP metrics filename support

* Support Taskcluster metrics structure in WandB.publish_group_logs

* Patch model name in group_logs

* Patch model suffix in group_logs

* Add details to value error exceptions

* Do continue on unsupported filename (Snakemake)

* Preserve legacy metrics dir for snakemake experiments

* Rework the GCP file structure browsing

* Fixes

* Include quantized metrics

* Update tests
  • Loading branch information
vrigal authored Sep 13, 2024
1 parent 98f8f1c commit 241f168
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 95 deletions.
39 changes: 19 additions & 20 deletions tests/test_tracking_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def test_taskcluster(wandb_mock, getargs_mock, disable_wandb, caplog, samples_di

@patch(
"translations_parser.cli.experiments.get_args",
return_value=argparse.Namespace(directory=Path(__file__).parent / "data" / "experiments_1_10"),
return_value=argparse.Namespace(
directory=Path(__file__).parent / "data" / "experiments_1_10",
mode="snakemake",
),
)
@patch("translations_parser.publishers.wandb")
def test_experiments_marian_1_10(
Expand All @@ -117,11 +120,12 @@ def test_experiments_marian_1_10(
assert set([(level, message) for _module, level, message in caplog.record_tuples]) == set(
[
(logging.INFO, "Reading 3 train.log data"),
# student
(
logging.INFO,
f"Parsing folder {samples_dir}/experiments_1_10/models/en-nl/prod/student",
f"Parsing folder {samples_dir}/experiments_1_10/models/en-nl/prod",
),
# student
(logging.INFO, "Handling training task student"),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Detected Marian version 1.10"),
(logging.INFO, "Reading Marian command line arguments."),
Expand All @@ -133,19 +137,13 @@ def test_experiments_marian_1_10(
(logging.INFO, "Found 550 training entries"),
(logging.INFO, "Found 108 validation entries"),
# teacher-finetuned0
(
logging.INFO,
f"Parsing folder {samples_dir}/experiments_1_10/models/en-nl/prod/teacher-finetuned0",
),
(logging.INFO, "Handling training task teacher-finetune-0"),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Successfully parsed 1944 lines"),
(logging.INFO, "Found 567 training entries"),
(logging.INFO, "Found 189 validation entries"),
# teacher-finetuned1
(
logging.INFO,
f"Parsing folder {samples_dir}/experiments_1_10/models/en-nl/prod/teacher-finetuned1",
),
(logging.INFO, "Handling training task teacher-finetune-1"),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Successfully parsed 1963 lines"),
(logging.INFO, "Found 573 training entries"),
Expand Down Expand Up @@ -225,7 +223,10 @@ def test_experiments_marian_1_10(

@patch(
"translations_parser.cli.experiments.get_args",
return_value=argparse.Namespace(directory=Path(__file__).parent / "data" / "experiments_1_12"),
return_value=argparse.Namespace(
directory=Path(__file__).parent / "data" / "experiments_1_12",
mode="snakemake",
),
)
@patch("translations_parser.publishers.wandb")
def test_experiments_marian_1_12(
Expand All @@ -243,24 +244,22 @@ def test_experiments_marian_1_12(
assert set([(level, message) for _module, level, message in caplog.record_tuples]) == set(
[
(logging.INFO, "Reading 2 train.log data"),
(logging.INFO, "Detected Marian version 1.12"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
f"Parsing folder {samples_dir}/experiments_1_12/models/fi-en/opusprod",
),
(logging.INFO, "Detected Marian version 1.12"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
f"Parsing folder {samples_dir}/experiments_1_12/models/fi-en/opusprod/student",
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
(logging.INFO, "Handling training task student"),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Successfully parsed 1533 lines"),
(logging.INFO, "Found 405 training entries"),
(logging.INFO, "Found 79 validation entries"),
(
logging.INFO,
f"Parsing folder {samples_dir}/experiments_1_12/models/fi-en/opusprod/student-finetuned",
),
(logging.INFO, "Handling training task student-finetune"),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Successfully parsed 1174 lines"),
(logging.INFO, "Found 330 training entries"),
Expand Down
82 changes: 81 additions & 1 deletion tests/test_tracking_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pytest
from fixtures import get_full_taskgraph

from tracking.translations_parser.utils import ParsedTaskLabel, build_task_name, parse_task_label
from tracking.translations_parser.utils import (
ParsedTaskLabel,
build_task_name,
parse_task_label,
parse_gcp_metric,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -129,3 +134,78 @@ def test_parse_labels_on_full_taskgraph():
def test_build_task_name(task_tags, values):
task = {"tags": task_tags}
assert build_task_name(task) == values


# Mapping of GCP metric filename stems to the expected
# (importer, augmentation, dataset) triple produced by parse_gcp_metric.
GCP_METRIC_CASES = {
    "flores_aug-mix_devtest": ("flores", "aug-mix", "devtest"),
    "flores_aug-title_devtest": ("flores", "aug-title", "devtest"),
    "flores_aug-title-strict_devtest": ("flores", "aug-title-strict", "devtest"),
    "flores_aug-typos_devtest": ("flores", "aug-typos", "devtest"),
    "flores_aug-upper_devtest": ("flores", "aug-upper", "devtest"),
    "flores_aug-upper-strict_devtest": ("flores", "aug-upper-strict", "devtest"),
    "flores_devtest": ("flores", None, "devtest"),
    "mtdata_aug-mix_Neulab-tedtalks_test-1-eng-lit": (
        "mtdata",
        "aug-mix",
        "Neulab-tedtalks_test-1-eng-lit",
    ),
    "mtdata_Neulab-tedtalks_test-1-eng-lit": (
        "mtdata",
        None,
        "Neulab-tedtalks_test-1-eng-lit",
    ),
    "sacrebleu_aug-mix_wmt19": ("sacrebleu", "aug-mix", "wmt19"),
    "sacrebleu_aug-title-strict_wmt19": ("sacrebleu", "aug-title-strict", "wmt19"),
    "sacrebleu_aug-title_wmt19": ("sacrebleu", "aug-title", "wmt19"),
    "sacrebleu_aug-typos_wmt19": ("sacrebleu", "aug-typos", "wmt19"),
    "sacrebleu_aug-upper-strict_wmt19": ("sacrebleu", "aug-upper-strict", "wmt19"),
    "sacrebleu_wmt19": ("sacrebleu", None, "wmt19"),
}


@pytest.mark.parametrize("filename, parsed_values", list(GCP_METRIC_CASES.items()))
def test_gcp_metric(filename, parsed_values):
    """Every supported GCP metric filename must parse into its expected triple."""
    assert tuple(parse_gcp_metric(filename)) == parsed_values


# Filenames that do not start with a recognized importer prefix and
# therefore cannot be parsed by parse_gcp_metric.
@pytest.mark.parametrize(
    "filename",
    ("devtest", "tc_Tatoeba-Challenge-v2021-08-07", "test"),
)
def test_wrong_gcp_metric(filename):
    """parse_gcp_metric must raise ValueError on unsupported filenames."""
    with pytest.raises(ValueError):
        parse_gcp_metric(filename)
137 changes: 85 additions & 52 deletions tracking/translations_parser/cli/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,34 @@
import argparse
import logging
import os
from enum import Enum
from itertools import groupby
from pathlib import Path

from translations_parser.data import Metric
from translations_parser.parser import TrainingParser
from translations_parser.publishers import WandB
from translations_parser.utils import parse_task_label
from translations_parser.utils import parse_task_label, parse_gcp_metric

logger = logging.getLogger(__name__)


# Origin of the experiment artifacts being published: legacy Snakemake
# runs or Taskcluster runs executed on GCP. Built via the Enum functional
# API; members and values are identical to the class-syntax definition.
ExperimentMode = Enum(
    "ExperimentMode",
    [("SNAKEMAKE", "snakemake"), ("TASKCLUSTER", "taskcluster")],
)


def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Publish multiple experiments to Weight & Biases")
parser.add_argument(
"--mode",
"-m",
help="Mode to publish experiments.",
type=ExperimentMode,
choices=ExperimentMode,
metavar=[e.value for e in ExperimentMode],
required=True,
)
parser.add_argument(
"--directory",
"-d",
Expand All @@ -40,6 +55,7 @@ def parse_experiment(
suffix: str,
logs_file: Path,
metrics_dir: Path | None = None,
mode=ExperimentMode,
) -> None:
"""
Parse logs from a Taskcluster dump and publish data to W&B.
Expand All @@ -48,8 +64,19 @@ def parse_experiment(
metrics = []
if metrics_dir:
for metrics_file in metrics_dir.glob("*.metrics"):
importer, dataset = metrics_file.stem.split("_", 1)
metrics.append(Metric.from_file(metrics_file, importer=importer, dataset=dataset))
try:
metric_attrs = parse_gcp_metric(metrics_file.stem)
except ValueError:
logger.error(f"Error parsing metric from GCP: {metrics_file.stem}. Skipping.")
else:
metrics.append(
Metric.from_file(
metrics_file,
importer=metric_attrs.importer,
dataset=metric_attrs.dataset,
augmentation=metric_attrs.augmentation,
)
)

with logs_file.open("r") as f:
lines = (line.strip() for line in f.readlines())
Expand All @@ -71,44 +98,56 @@ def parse_experiment(
def main() -> None:
args = get_args()
directory = args.directory
mode = args.mode

# Ignore files with a different name than "train.log"
file_groups = {
path: list(files)
for path, files in groupby(
sorted(directory.glob("**/train.log")), lambda path: path.parent
)
}
logger.info(f"Reading {len(file_groups)} train.log data")
prefix = os.path.commonprefix([path.parts for path in file_groups])
train_files = sorted(directory.glob("**/train.log"))

logger.info(f"Reading {len(train_files)} train.log data")
prefix = os.path.commonprefix([path.parts for path in train_files])

# Move on top of the main models (Snakemake) or logs (Taskcluster) folder
if "models" in prefix:
prefix = prefix[: prefix.index("models") + 1]
prefix = prefix[: prefix.index("models")]
if "logs" in prefix:
prefix = prefix[: prefix.index("logs")]

# First parent folder correspond to the run name, second one is the group
groups = groupby(train_files, lambda path: path.parent.parent)

last_index = None
existing_runs = []
for index, (path, files) in enumerate(file_groups.items(), start=1):
for path, files in groups:
logger.info(f"Parsing folder {path.resolve()}")
parents = path.parts[len(prefix) :]
if len(parents) < 3:
logger.warning(f"Skipping folder {path.resolve()}: Unexpected folder structure")
continue
project, group, *name = parents
base_name = name[0]
name = "_".join(name)
# Directly use group name as a suffix from GCP experiments, since we don't have access to the task group ID
suffix = f"_{group}"
try:
name = parse_task_label(f"train-{name}").model
except ValueError:
logger.error(f"Invalid tag extracted from file @{path}: '{name}'")
continue
*_, project, group = path.parts
if mode == ExperimentMode.TASKCLUSTER:
if len(group) < 22:
logger.error(
f"Skip folder {group} as it cannot contain a task group ID (too few caracters)."
)
continue
suffix = f"_{group[-22:-17]}"
else:
# Use the full experiment name as a suffix for old Snakemake experiments
suffix = f"_{group}"

# Publish a run for each file inside that group
published_runs = []
for file in files:
try:
tag = f"train-{file.parent.name}"
name = parse_task_label(tag).model
except ValueError:
logger.error(f"Invalid tag extracted from file @{path}: {tag}")
continue
logger.info(f"Handling training task {name}")

# Also publish metric files when available
metrics_path = Path("/".join([*prefix, project, group, "evaluation", base_name]))
metrics_path = Path(
"/".join([*prefix, "models", project, group, "evaluation", file.parent.name])
)
metrics_dir = metrics_path if metrics_path.is_dir() else None
if metrics_dir is None:
logger.warning("Evaluation metrics files not found, skipping.")
logger.warning(f"Evaluation metrics files not found for {name}.")

try:
parse_experiment(
project=project,
Expand All @@ -117,28 +156,22 @@ def main() -> None:
suffix=suffix,
logs_file=file,
metrics_dir=metrics_dir,
mode=mode,
)
existing_runs.append(name)
except Exception as e:
logger.error(f"An exception occured parsing {file}: {e}")
logger.error(f"An exception occured parsing training file {file}: {e}")
else:
published_runs.append(name)

# Try to publish related log files to the group on a last run named "group_logs"
if index == len(file_groups) or last_index and last_index != (project, group):
last_project, last_group = (
last_index
if last_index
# May occur when handling a single run
else (project, group)
)
logger.info(
f"Publishing '{last_project}/{last_group}' evaluation metrics and files (fake run 'group_logs')"
)
WandB.publish_group_logs(
logs_parent_folder=prefix,
project=last_project,
group=last_group,
suffix=suffix,
existing_runs=existing_runs,
)
existing_runs = []
last_index = (project, group)
logger.info(
f"Publishing '{project}/{group}' evaluation metrics and files (fake run 'group_logs')"
)
WandB.publish_group_logs(
logs_parent_folder=[*prefix, "logs"],
project=project,
group=group,
suffix=suffix,
existing_runs=published_runs,
snakemake=(mode == ExperimentMode.SNAKEMAKE.value),
)
Loading

0 comments on commit 241f168

Please sign in to comment.