Skip to content

Commit

Permalink
Consistent dataset names (#494)
Browse files Browse the repository at this point in the history
* Consistent evaluation tags parsing

* Add test

* Support backwards training task label

* Support evaluation task with suffix

* Support suffixes with form -1/2

---------

Co-authored-by: Evgeny Pavlov <[email protected]>
  • Loading branch information
vrigal and eu9ene authored Mar 28, 2024
1 parent 015a74d commit 418a1e4
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 48 deletions.
52 changes: 52 additions & 0 deletions tests/test_tracking_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

from tracking.translations_parser.utils import parse_tag


# Each entry pairs a task label with the tuple `parse_tag` must return for it:
# (model, importer, dataset, augmentation).
PARSE_TAG_CASES = [
    # Evaluation labels using "-" between the prefix, model and importer
    (
        "evaluate-teacher-flores-flores_aug-title_devtest-lt-en-1_2",
        ("teacher", "flores", "devtest", "aug-title"),
    ),
    (
        "evaluate-quantized-mtdata_aug-mix_Neulab-tedtalks_eng-lit-lt-en-1_2",
        ("quantized", "mtdata", "Neulab-tedtalks_eng-lit", "aug-mix"),
    ),
    (
        "evaluate-finetuned-student-sacrebleu-wmt19-lt-en",
        ("finetuned-student", "sacrebleu", "wmt19", None),
    ),
    (
        "evaluate-student-2-sacrebleu-wmt19-lt-en",
        ("student-2", "sacrebleu", "wmt19", None),
    ),
    # Training labels only carry a model name
    (
        "train-student-en-hu",
        ("student", None, None, None),
    ),
    # Evaluation labels using "_" between the prefix, model and importer
    (
        "eval_teacher-ensemble_mtdata_Neulab-tedtalks_test-1-eng-nld",
        ("teacher-ensemble", "mtdata", "Neulab-tedtalks_test-1-eng-nld", None),
    ),
    (
        "eval_student-finetuned_flores_devtest",
        ("student-finetuned", "flores", "devtest", None),
    ),
    (
        "eval_teacher-base0_flores_devtest",
        ("teacher-base0", "flores", "devtest", None),
    ),
    (
        "train-backwards-en-ca",
        ("backwards", None, None, None),
    ),
    # Trailing "1/2" style suffix
    (
        "evaluate-teacher-flores-flores_dev-en-ca-1/2",
        ("teacher", "flores", "dev", None),
    ),
]


@pytest.mark.parametrize("example, parsed_values", PARSE_TAG_CASES)
def test_parse_tag(example, parsed_values):
    """Check that `parse_tag` splits each known label into the expected tuple."""
    assert parse_tag(example) == parsed_values
3 changes: 2 additions & 1 deletion tracking/translations_parser/cli/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def parse_experiment(
metrics = []
if metrics_dir:
for metrics_file in metrics_dir.glob("*.metrics"):
metrics.append(Metric.from_file(metrics_file, dataset=metrics_file.stem))
importer, dataset = metrics_file.stem.split("_", 1)
metrics.append(Metric.from_file(metrics_file, importer=importer, dataset=dataset))

with logs_file.open("r") as f:
lines = (line.strip() for line in f.readlines())
Expand Down
12 changes: 8 additions & 4 deletions tracking/translations_parser/cli/taskcluster_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from translations_parser.data import Metric
from translations_parser.parser import TrainingParser, logger
from translations_parser.publishers import WandB
from translations_parser.utils import extract_dataset_from_tag
from translations_parser.utils import parse_tag

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_metrics_from_task(task: dict) -> list[Metric]:
with file.open("wb") as log_file:
log_file.write(log.tobytes())
log_file.flush()
metrics.append(Metric.from_file(Path(log_file.name), sep="-"))
metrics.append(Metric.from_file(Path(log_file.name)))

return metrics

Expand Down Expand Up @@ -229,7 +229,7 @@ def publish_task_group(group_id: str) -> None:
eval_label = eval_task["task"]["tags"].get("label", "")

try:
model_name, _, _ = extract_dataset_from_tag(eval_label, sep="-")
model_name, _, _, _ = parse_tag(eval_label)
except ValueError:
continue

Expand Down Expand Up @@ -283,6 +283,10 @@ def publish_task_group(group_id: str) -> None:

for metrics_task in metrics_tasks.values():
filename = metrics_task["task"]["tags"]["label"]
if re_match := MULTIPLE_TRAIN_SUFFIX.search(filename):
(suffix,) = re_match.groups()
filename = MULTIPLE_TRAIN_SUFFIX.sub(suffix, filename)

with (eval_folder / f"{filename}.log").open("wb") as log_file:
downloadArtifactToFile(
log_file,
Expand All @@ -299,7 +303,7 @@ def publish_task_group(group_id: str) -> None:
yaml.dump(config, config_file)

parents = str(logs_folder.resolve()).strip().split("/")
WandB.publish_group_logs(parents, project_name, group_name, existing_runs=[], tag_sep="-")
WandB.publish_group_logs(parents, project_name, group_name, existing_runs=[])


def list_dependent_group_ids(task_id: str, known: set[str]):
Expand Down
15 changes: 10 additions & 5 deletions tracking/translations_parser/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from typing import List, Sequence

from translations_parser.utils import extract_dataset_from_tag
from translations_parser.utils import parse_tag

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -53,6 +53,7 @@ class Metric:
"""Data extracted from a `.metrics` file"""

# Evaluation identifiers
importer: str
dataset: str
augmentation: str | None
# Scores
Expand All @@ -63,7 +64,7 @@ class Metric:
def from_file(
cls,
metrics_file: Path,
sep="_",
importer: str | None = None,
dataset: str | None = None,
augmentation: str | None = None,
):
Expand All @@ -85,17 +86,20 @@ def from_file(
except Exception as e:
raise ValueError(f"Metrics file could not be parsed: {e}")
bleu_detok, chrf = values
if dataset is None:
_, dataset, augmentation = extract_dataset_from_tag(metrics_file.stem, sep=sep)
if importer is None:
_, importer, dataset, augmentation = parse_tag(metrics_file.stem)
return cls(
importer=importer,
dataset=dataset,
augmentation=augmentation,
chrf=chrf,
bleu_detok=bleu_detok,
)

@classmethod
def from_tc_context(cls, dataset: str, lines: Sequence[str], augmentation: str | None = None):
def from_tc_context(
cls, importer: str, dataset: str, lines: Sequence[str], augmentation: str | None = None
):
"""
Try reading a metric from Taskcluster logs, looking for two
successive floats after a line matching METRIC_LOG_RE.
Expand All @@ -113,6 +117,7 @@ def from_tc_context(cls, dataset: str, lines: Sequence[str], augmentation: str |
continue
bleu_detok, chrf = values
return cls(
importer=importer,
dataset=dataset,
augmentation=augmentation,
chrf=chrf,
Expand Down
23 changes: 16 additions & 7 deletions tracking/translations_parser/publishers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import yaml

from translations_parser.data import Metric, TrainingEpoch, TrainingLog, ValidationEpoch
from translations_parser.utils import extract_dataset_from_tag
from translations_parser.utils import parse_tag

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -145,17 +145,22 @@ def handle_metrics(self, metrics: Sequence[Metric]) -> None:
if self.wandb is None:
return
for metric in metrics:
title = metric.importer
if metric.augmentation:
title = f"{title}_{metric.augmentation}"
if metric.dataset:
title = f"{title}_{metric.dataset}"
# Publish a bar chart (a table with values will also be available from W&B)
self.wandb.log(
{
metric.dataset: wandb.plot.bar(
title: wandb.plot.bar(
wandb.Table(
columns=["Metric", "Value"],
data=[[key, getattr(metric, key)] for key in ("bleu_detok", "chrf")],
),
"Metric",
"Value",
title=metric.dataset.capitalize(),
title=title,
)
}
)
Expand Down Expand Up @@ -183,7 +188,6 @@ def publish_group_logs(
project: str,
group: str,
existing_runs: list[str] | None = None,
tag_sep: str = "_",
) -> None:
"""
Publish files within `logs_dir` to W&B artifacts for a specific group.
Expand Down Expand Up @@ -215,14 +219,19 @@ def publish_group_logs(
metrics = defaultdict(list)
# Add "quantized" metrics
for file in quantized_metrics:
metrics["quantized"].append(Metric.from_file(file, dataset=file.stem))
importer, dataset = file.stem.split("_", 1)
metrics["quantized"].append(Metric.from_file(file, importer=importer, dataset=dataset))
# Add experiment (runs) metrics
for file in evaluation_metrics:
model_name, dataset, aug = extract_dataset_from_tag(file.stem, tag_sep)
model_name, importer, dataset, aug = parse_tag(file.stem)
with file.open("r") as f:
lines = f.readlines()
try:
metrics[model_name].append(Metric.from_tc_context(dataset, lines))
metrics[model_name].append(
Metric.from_tc_context(
importer=importer, dataset=dataset, lines=lines, augmentation=aug
)
)
except ValueError as e:
logger.error(f"Could not parse metrics from {file.resolve()}: {e}")

Expand Down
76 changes: 45 additions & 31 deletions tracking/translations_parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,51 @@
# Matches a trailing "-xx-yy" or "-xxx-yyy" at the end of a tag: two hyphen-prefixed
# groups of two word characters, or two groups of three.
# NOTE(review): presumably the language pair of the project — confirm against callers.
TAG_PROJECT_SUFFIX_REGEX = re.compile(r"((-\w{2}){2}|(-\w{3}){2})$")


def extract_dataset_from_tag(tag, sep="_") -> tuple[str, str, str | None]:
"""
Experiment tag usually has a structure like "<prefix>_<model_name>_<dataset>_<?augmentation>_<project>"
This function removes the prefix and suffix, and tries to split model, dataset and optional augmentation.

Returns a (model_name, dataset, augmentation) tuple.
Raises ValueError when the tag contains no `sep` to split the prefix on.
"""
# Split on the first separator only: everything after it is kept as a single string
prefix, *name = tag.split(sep, 1)
if len(name) != 1:
raise ValueError(f"Tag could not be parsed: '{tag}'.")
model_name = name[0]
# Eventually remove suffix
# NOTE(review): the result is stored in `name`, which is not read again below;
# `model_name` still holds the un-stripped value. Confirm whether this is intended.
name = TAG_PROJECT_SUFFIX_REGEX.sub("", model_name)
dataset = ""
aug = None
# Cut the string at the first known dataset keyword: left part is the model, right part the dataset
for keyword in DATASET_KEYWORDS:
if keyword in model_name:
index = model_name.index(keyword)
model_name, dataset = model_name[:index].rstrip(sep), model_name[index:]
break
else:
continue
if dataset:
# Look for augmentation information in the second part of the tag (dataset)
# NOTE(review): "aug" is searched in `model_name` but the resulting index is used
# to slice `dataset` — verify which string was meant.
if "aug" in model_name:
index = model_name.index("aug")
dataset, aug = dataset[:index].rstrip(sep), dataset[index:]
else:
logger.warning(
f"No dataset could be extracted from {tag}."
" Please ensure utils.DATASET_KEYWORDS is up to date."
)
return model_name, dataset, aug
# Model part shared by training and evaluation labels: one of the known model
# names, optionally followed by a numeric index (e.g. "student-2", "teacher-base0").
_MODEL_GROUP = (
    r"(?P<model>"
    r"(finetuned-student|student-finetuned|teacher-ensemble|teacher|teacher-base|teacher-finetuned"
    r"|student|quantized|backwards|backward)"
    r"(-?\d+)?"
    r")"
)
# Two-letter language pair such as "en-hu".
_LANG_GROUP = r"(?P<lang>[a-z]{2}-[a-z]{2})"
# Optional trailing run of digits, "_" and "/" (e.g. "1_2" or "1/2"), then end of string.
_SUFFIX_TAIL = r"-?(?P<suffix>[\d_\/]+)?$" r"$"

# Training label: "train-<model>-<lang pair>" with an optional numeric suffix.
TRAIN_LABEL_REGEX = re.compile(
    "".join(
        (
            r"^",
            r"train-",
            _MODEL_GROUP,
            r"[_-]",
            _LANG_GROUP,
            r"-?",
            _SUFFIX_TAIL,
        )
    )
)

# Evaluation label: "evaluate"/"eval" prefix, model, importer (optionally repeated),
# then optional augmentation, dataset, language pair and numeric suffix.
EVAL_REGEX = re.compile(
    "".join(
        (
            r"^",
            r"(evaluate|eval)[-_]",
            _MODEL_GROUP,
            r"[_-]",
            r"(?P<importer>flores|mtdata|sacrebleu)",
            r"(?P<extra_importer>-flores|-mtdata|-sacrebleu)?",
            r"[_-]",
            r"(?P<aug>aug-[^_]+)?",
            r"_?(?P<dataset>[-\w_]*?(-[a-z]{3}-[a-z]{3})?)?",
            r"-?" + _LANG_GROUP + r"?",
            _SUFFIX_TAIL,
        )
    )
)


def parse_tag(
    tag: str, sep: str = "_"
) -> tuple[str, str | None, str | None, str | None]:
    """
    Split a task label into (model, importer, dataset, augmentation).

    Two label shapes are tried, in this order:
      - a training label matching TRAIN_LABEL_REGEX, e.g. "train-student-en-hu",
        for which only the model is returned: ("student", None, None, None);
      - an evaluation label matching EVAL_REGEX, e.g.
        "evaluate-student-2-sacrebleu-wmt19-lt-en", which yields
        ("student-2", "sacrebleu", "wmt19", None).

    :param tag: Label to parse.
    :param sep: Not used by the parser (both regexes accept "-" and "_" as
        separators); it has no effect on the result.
    :return: The `model`, `importer`, `dataset` and `aug` groups of the match.
        The augmentation is None when the label carries none.
    :raises ValueError: If the label matches neither regex.
    """
    # First try to parse a simple training label
    match = TRAIN_LABEL_REGEX.match(tag)
    if match is not None:
        return match.group("model"), None, None, None
    # Else try to parse an evaluation label with importer, dataset and augmentation
    match = EVAL_REGEX.match(tag)
    if match is None:
        # Name the offending tag in a readable message instead of raising the bare value
        raise ValueError(f"Tag could not be parsed: '{tag}'.")
    # Match.group() with several names returns the values as a tuple, in order
    return match.group("model", "importer", "dataset", "aug")


def taskcluster_log_filter(headers: Sequence[Sequence[str]]) -> bool:
Expand Down

0 comments on commit 418a1e4

Please sign in to comment.