Update the statistics class to use data attributes and use Statistics in merge-mono.py (#853)
gregtatum authored Sep 24, 2024
1 parent a8883aa commit 77479b3
Showing 5 changed files with 173 additions and 110 deletions.
81 changes: 46 additions & 35 deletions pipeline/clean/merge-mono.py
@@ -1,12 +1,17 @@
import argparse
import glob
import json
import os
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from pathlib import Path
from typing import Generator

from pipeline.common.datasets import WeakStringSet, shuffle_with_max_lines
from pipeline.common.datasets import (
CountingStep,
FilteringStep,
Statistics,
WeakStringSet,
shuffle_with_max_lines,
)
from pipeline.common.downloads import (
format_bytes,
get_human_readable_file_size,
@@ -23,37 +28,42 @@


@dataclass
class FilteringStatistics:
class FilteringStatistics(Statistics):
"""
Gather statistics about the filtering process.
"""

# The size of the merged parallel corpus.
parallel_corpus_lines: int = 0

# How much of the monolingual data was duplicated in the merged parallel corpus.
duplicates_of_parallel_corpus: int = 0
def __init__(self, dataset_path: Path) -> None:
super().__init__(dataset_path)
self.final_truncated_monolingual_lines = CountingStep(
"After truncation via the config's `experiment.mono-max-sentences-src.total`, "
"how many lines are left."
)

# How much of the monolingual data was duplicated across the monolingual datasets.
duplicates_of_monolingual_corpus: int = 0
self.final_truncated_monolingual_codepoints = CountingStep(
"The amount of codepoints in the final monolingual corpus."
)

# What was the size of the monolingual data that was filtered. This doesn't count the
# truncation of datasets at the datasets gathering time.
original_monolingual_lines: int = 0
self.parallel_corpus_lines = CountingStep(
"The size of the merged parallel corpus before truncation."
)

# After deduplication, how much monolingual data is left.
deduplicated_monolingual_lines: int = 0
self.duplicates_of_parallel_corpus = CountingStep(
"How much of the monolingual data was duplicated in the merged parallel corpus."
)

# After truncation via the config's `experiment.mono-max-sentences-src.total`,
# how many lines are left.
final_truncated_monolingual_lines: int = 0
self.duplicates_of_monolingual_corpus = CountingStep(
"How much of the monolingual data was duplicated across the monolingual datasets."
)

# The amount of codepoints in the final monolingual corpus.
final_truncated_monolingual_codepoints: int = 0
self.deduplicated_size = FilteringStep(
"What was the size of the monolingual data and how much was deduplicated. This "
"doesn't count the truncation of datasets at the datasets gathering time."
)

def save_json(self, path: Path) -> None:
with open(path, "w", encoding="utf-8") as json_file:
json.dump(asdict(self), json_file, indent=2)
self.deduplicated_monolingual_lines = CountingStep(
"After deduplication, how much monolingual data is left."
)


def filter_and_write_monolingual_data(
@@ -99,11 +109,13 @@ def deduplicate_lines(lines: Generator[str, None, None]) -> Generator[str, None,

yield line

stats.original_monolingual_lines = parallel_discards + mono_discards + retained
stats.duplicates_of_parallel_corpus = parallel_discards
stats.duplicates_of_monolingual_corpus = mono_discards
stats.deduplicated_monolingual_lines = retained
stats.parallel_corpus_lines = len(parallel_hashes)
stats.deduplicated_size.kept = retained
stats.deduplicated_size.filtered = parallel_discards + mono_discards
stats.deduplicated_monolingual_lines.value = retained

stats.duplicates_of_parallel_corpus.value = parallel_discards
stats.duplicates_of_monolingual_corpus.value = mono_discards
stats.parallel_corpus_lines.value = len(parallel_hashes)

# Estimate the byte size. The better the estimate, the better the data distribution will be.
# When filtering mono NLLB data against parallel NLLB data, roughly 70% is kept.
@@ -128,9 +140,9 @@ def deduplicate_lines(lines: Generator[str, None, None]) -> Generator[str, None,
log_memory(gc_collect=True)
logger.info(f"Write the final file: {output_path}")
with write_lines(output_path) as outfile:
stats.final_truncated_monolingual_lines = len(final_lines)
stats.final_truncated_monolingual_lines.value = len(final_lines)
for i, line in enumerate(final_lines):
stats.final_truncated_monolingual_codepoints += len(line)
stats.final_truncated_monolingual_codepoints.value += len(line)
outfile.write(line)
if i % 1_000_000 == 999_999:
logger.info(f"Wrote line {i+1:,} to {output_path}")
Expand All @@ -149,9 +161,8 @@ def deduplicate_lines(lines: Generator[str, None, None]) -> Generator[str, None,
outfile.write(line)

log_memory(gc_collect=True)
stats_path = output_path.parent / f"{output_path.stem}.stats.json"
logger.info(f"Save the stats: {stats_path}")
stats.save_json(stats_path)
stats_path = stats.save_json()
logger.info(f"Saved the stats: {stats_path}")


def compute_line_hashes(path: Path) -> WeakStringSet:
@@ -231,7 +242,7 @@ def main() -> None:
logger.info(f"Compute hashes of the parallel data: {path}")
line_hashes = compute_line_hashes(parallel_corpus)

stats = FilteringStatistics()
stats = FilteringStatistics(output_path)

filter_and_write_monolingual_data(
mono_dataset_paths, output_path, line_hashes, max_sentences, args.sample_size, stats
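As a rough sketch of the new bookkeeping pattern used above (not part of this commit): the per-step counters in merge-mono.py are now CountingStep/FilteringStep objects from pipeline.common.datasets instead of plain ints on a dataclass, so updates go through .value, .kept, and .filtered, and derived values are filled in when the JSON is generated. The counts below are invented, and the snippet assumes pipeline.common.datasets from this repository is importable.

from pipeline.common.datasets import CountingStep, FilteringStep

# Invented counts standing in for the real deduplication pass.
parallel_discards = 120
mono_discards = 80
retained = 800

deduplicated_size = FilteringStep(
    "What was the size of the monolingual data and how much was deduplicated."
)
deduplicated_size.kept = retained
deduplicated_size.filtered = parallel_discards + mono_discards

duplicates_of_parallel_corpus = CountingStep(
    "How much of the monolingual data was duplicated in the merged parallel corpus."
)
duplicates_of_parallel_corpus.value = parallel_discards

# as_json() runs update_derived_data(), so "visited" is derived as kept + filtered:
#   {"description": "...", "filtered": 200, "kept": 800, "visited": 1000}
print(deduplicated_size.as_json())
print(duplicates_of_parallel_corpus.as_json())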
123 changes: 80 additions & 43 deletions pipeline/common/datasets.py
@@ -3,7 +3,7 @@
import json
import os
import tempfile
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from io import TextIOWrapper
from pathlib import Path
from random import Random
@@ -284,64 +284,102 @@ def shuffle_in_temp_files(
print(f"Shuffled with {bucket_count} buckets.")


@dataclass
class Statistics:
"""
Base class for handling statistical data and JSON serialization in the pipeline. It
standardizes how the JSON is generated, and how it saves. Implement `as_json` for custom
JSON processing.
Base class for handling statistical data and JSON serialization in the pipeline. All
public data attributes in the implementing class will be saved as JSON. This class
standardizes how the JSON is generated, and where it is saved.
You can derive data at JSON generation time by providing an update_derived_data method.
For instance .save_json() for Statistics("nllb.en.zst") would produce "nllb.en.stats.json".
For instance stats.save_json() for Statistics("nllb.en.zst") would produce "nllb.en.stats.json".
"""

def __init__(self, dataset_path: Union[Path, str]) -> None:
self.dataset_path = Path(dataset_path)

def as_json(self) -> dict:
"""
Convert this data into JSON, and recurse into any other Statistics objects.
"""
data = asdict(self)
for key, value in enumerate(data):
if isinstance(value, Statistics):
data[key] = value.as_json()
return data
def __init__(self, dataset_path: Optional[Union[Path, str]] = None) -> None:
self._dataset_path = Path(dataset_path) if dataset_path else None

def save_json(self) -> Path:
"""
Standardizes how the JSON is saved, based on the dataset.
"""
path = self.dataset_path.parent / f"{self.dataset_path.stem}.stats.json"
if not self._dataset_path:
raise Exception("A dataset_path is required when saving to JSON.")

path = self._dataset_path.parent / f"{self._dataset_path.stem}.stats.json"
obj = self.as_json()
with open(path, "w", encoding="utf-8") as json_file:
json.dump(self.as_json(), json_file, indent=2)
json.dump(obj, json_file, indent=2)
json_file.write("\n")
return path

def _is_subclass(value: any):
"""
Determine if a child object is a subclass or not.
"""
try:
return issubclass(value.__class__, Statistics)
except AttributeError:
return False

def as_json(root: Union[int, str, float, list, "Statistics"]) -> Union[int, str, float, list]:
"""
Recursively walk the data attributes of the statistics.
"""
if Statistics._is_subclass(root):
stats: Statistics = root
stats.update_derived_data()
obj = {}
for key, value in stats.__dict__.items():
if key.startswith("_"):
continue
obj[key] = Statistics.as_json(value)

return obj

if isinstance(root, list):
return [Statistics.as_json(item) for item in root]

if isinstance(root, dict):
root_dict: dict = root
return {key: Statistics.as_json(value) for key, value in root_dict.items()}

if isinstance(root, (float, int, str)):
return root

return str(root)

def update_derived_data(self):
"""
Update any derived data in the sub values. Override this method if anything
needs to be derived.
"""
pass


@dataclass
class FilteringStep(Statistics):
"""
For each step for filtering, store how many were kept or filtered.
"""

filtered: int
kept: int
description: str
# "visited" is implied.

def __init__(self, dataset_path: Path, description: str, filtered=0, kept=0) -> None:
def __init__(
self, description: str, filtered=0, kept=0, dataset_path: Optional[Path] = None
) -> None:
super().__init__(dataset_path)
self.description = description
self.filtered = filtered
self.kept = kept
self.description = description
self.visited = 0

def as_json(self) -> dict:
return {
"description": self.description,
"filtered": self.filtered,
"kept": self.kept,
"visited": self.filtered + self.kept,
}
def update_derived_data(self):
super().update_derived_data()
# Only two of the values need to be kept up to date, the last can be computed.
if not self.visited:
self.visited = self.filtered + self.kept
elif self.filtered and not self.kept:
self.kept = self.visited - self.filtered
return
elif self.kept and not self.filtered:
self.filtered = self.visited - self.kept


@dataclass
@@ -353,16 +391,15 @@ class CountingStep(Statistics):
value: int
description: str

def __init__(self, dataset_path: Path, description: str, value=0) -> None:
def __init__(
self,
description: str,
value=0,
dataset_path: Optional[Path] = None,
) -> None:
super().__init__(dataset_path)
self.value = value
self.description = description

def as_json(self) -> dict:
return {
"description": self.description,
"value": self.value,
}
self.value = value


class WeakStringSet(Set):
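To make the serialization rules in the Statistics docstring above concrete, here is a minimal hypothetical subclass. It is not part of the commit: the class name, attribute names, and file paths are invented, and the snippet assumes pipeline.common.datasets from this repository is importable. Every public data attribute is walked recursively by as_json(), attributes prefixed with an underscore are skipped, and save_json() derives its output path from the dataset path.

from pathlib import Path

from pipeline.common.datasets import CountingStep, FilteringStep, Statistics


class ExampleStatistics(Statistics):
    """An invented subclass that only exists to illustrate the serialization rules."""

    def __init__(self, dataset_path: Path) -> None:
        super().__init__(dataset_path)
        # Public data attributes are serialized, including nested Statistics objects.
        self.documents = CountingStep("How many documents were visited.")
        self.lines = FilteringStep("How many lines were kept or filtered.")
        # Attributes starting with "_" are treated as private and skipped.
        self._scratch = []


stats = ExampleStatistics(Path("/tmp/nllb.en.zst"))
stats.documents.value = 3
stats.lines.kept = 250
stats.lines.filtered = 50

# Writes "/tmp/nllb.en.stats.json"; lines.visited is derived as 300 on the way out.
print(stats.save_json())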
18 changes: 3 additions & 15 deletions pipeline/data/importers/mono/hplt.py
@@ -47,41 +47,29 @@ def __post_init__(self):
self.lines = self.text.split("\n")


@dataclass
class FilteringStatistics(Statistics):
"""
Gather statistics about the filtering process.
"""

shards: FilteringStep
visited_lines: FilteringStep
document_count: CountingStep
duplicate_lines: CountingStep
final_lines: CountingStep

def __init__(self, dataset_path: Path) -> None:
super().__init__(dataset_path)
self.shards = FilteringStep(
dataset_path,
"How many shards were sampled from. Each shard contains a subset of the "
"total datasets available.",
)
self.visited_lines = FilteringStep(
dataset_path,
"How many lines were visited and kept from the HPLT documents.",
)
self.document_count = CountingStep(
dataset_path,
"How many documents were visited. This can help represent data diversity.",
)
self.final_lines = CountingStep(
dataset_path,
"How many lines were actually written. Smaller lines will be combined together.",
)
self.duplicate_lines = CountingStep(
dataset_path,
"Of the collected lines, this counts how many were duplicates and discarded.",
)
self.final_lines = CountingStep(
"How many lines were actually written. Smaller lines will be combined together.",
)

def count_shards_visited(self):
self.shards.filtered -= 1
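The shards step above is a FilteringStep, so it benefits from update_derived_data(): visited, filtered, and kept are redundant, and the missing one is computed when the JSON is generated. A hypothetical illustration of that derivation follows (the numbers are invented, and this is not necessarily how hplt.py initializes its counters); it assumes pipeline.common.datasets is importable.

from pipeline.common.datasets import FilteringStep

shards = FilteringStep(
    "How many shards were sampled from. Each shard contains a subset of the "
    "total datasets available."
)
shards.visited = 10  # total shards available (invented)
shards.filtered = 7  # shards that were skipped (invented)

# as_json() calls update_derived_data(), which derives kept = visited - filtered = 3:
#   {"description": "...", "filtered": 7, "kept": 3, "visited": 10}
print(shards.as_json())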
