Splitter compression (#836)
* Split run_task commands on ampersand

* Remove the compression commands from splitter.py

* Add split corpus test
gregtatum authored Sep 9, 2024
1 parent 3e921d2 commit 98f8f1c
Showing 5 changed files with 209 additions and 116 deletions.
9 changes: 9 additions & 0 deletions pipeline/common/downloads.py
@@ -474,6 +474,15 @@ def write_lines(path: Path | str):
        stack.close()


def count_lines(path: Path | str) -> int:
    """
    Similar to wc -l, this counts the lines in a file. However, it does so regardless
    of the compression strategy used on the file.
    """
    with read_lines(path) as lines:
        return sum(1 for _ in lines)


def get_file_size(location: Union[Path, str]) -> int:
    """Get the size of a file, whether it is remote or local."""
    if str(location).startswith("http://") or str(location).startswith("https://"):
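For context, count_lines builds on read_lines, which selects a decompression codec from the file extension. A minimal usage sketch, assuming write_lines compresses a ".zst" path the same way the splitter below relies on (the file name here is hypothetical):

from pipeline.common.downloads import count_lines, write_lines

# Write three lines to a zstd-compressed file, then count them back.
# The codec is inferred from the ".zst" extension (assumed behavior).
with write_lines("sample.txt.zst") as outfile:
    for i in range(3):
        outfile.write(f"line {i}\n")

assert count_lines("sample.txt.zst") == 3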
94 changes: 50 additions & 44 deletions pipeline/translate/splitter.py
@@ -12,58 +12,64 @@

import argparse
import os
import subprocess
from contextlib import ExitStack
from typing import Optional

from pipeline.common.downloads import compress_file
from pipeline.common.downloads import count_lines, read_lines, write_lines
from pipeline.common.logging import get_logger

logger = get_logger(__file__)


def split_file(mono_path: str, output_dir: str, num_parts: int, output_suffix: str = ""):
    """
    Split a file into a fixed number of chunks.

    For instance with:
        mono_path = "corpus.en.zst"
        output_dir = "artifacts"
        num_parts = 20
        output_suffix = ".ref"

    Outputs:
        .
        ├── corpus.en.zst
        └── artifacts
            ├── file.1.ref.zst
            ├── file.2.ref.zst
            ├── file.3.ref.zst
            ├── ...
            └── file.20.ref.zst
    """
    os.makedirs(output_dir, exist_ok=True)

    # Initialize the decompression command
    decompress_cmd = f"zstdmt -dc {mono_path}"

    # Use ExitStack to manage the cleanup of file handlers
    with ExitStack() as stack:
        decompressed = stack.enter_context(
            subprocess.Popen(decompress_cmd, shell=True, stdout=subprocess.PIPE)
        )
        total_lines = sum(1 for _ in decompressed.stdout)
        lines_per_part = (total_lines + num_parts - 1) // num_parts

        print(f"Splitting {mono_path} to {num_parts} chunks x {total_lines} lines")

        # Reset the decompression for actual processing
        decompressed = stack.enter_context(
            subprocess.Popen(decompress_cmd, shell=True, stdout=subprocess.PIPE)
        )
        current_file = None
        current_name = None
        current_line_count = 0
        file_index = 1

        for line in decompressed.stdout:
            # If the current file is full or doesn't exist, start a new one
            if current_line_count == 0 or current_line_count >= lines_per_part:
                if current_file is not None:
                    current_file.close()
                    compress_file(current_name, keep_original=False)

                current_name = f"{output_dir}/file.{file_index}{output_suffix}"
                current_file = stack.enter_context(open(current_name, "w"))
                print(f"A new file {current_name} created")
                file_index += 1
                current_line_count = 0

            current_file.write(line.decode())
            current_line_count += 1

    # Compress the last file after closing.
    compress_file(current_name, keep_original=False)

    print("Done")
    total_lines = count_lines(mono_path)
    lines_per_part = (total_lines + num_parts - 1) // num_parts
    logger.info(f"Splitting {mono_path} to {num_parts} chunks x {total_lines:,} lines")

    line_writer = None
    line_count = 0
    file_index = 1

    with read_lines(mono_path) as lines:
        with ExitStack() as chunk_stack:
            for line in lines:
                if not line_writer or line_count >= lines_per_part:
                    # The current file is full or doesn't exist, start a new one.
                    if line_writer:
                        chunk_stack.close()

                    chunk_name = f"{output_dir}/file.{file_index}{output_suffix}.zst"
                    logger.info(f"Writing to file chunk: {chunk_name}")
                    line_writer = chunk_stack.enter_context(write_lines(chunk_name))
                    file_index += 1
                    line_count = 0

                line_writer.write(line)
                line_count += 1

    logger.info("Done writing to files.")


def main(args: Optional[list[str]] = None) -> None:
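The chunk size is computed with ceiling division, so every chunk except possibly the last is full. A quick sketch of the arithmetic, assuming the 47-line corpus and 10 chunks exercised by the new tests below:

num_parts = 10  # assumed value of split_chunks in the test tasks
total_lines = 47  # corpus_line_count in tests/test_split_corpus.py

# Ceiling division: (47 + 10 - 1) // 10 == 5 lines per chunk.
lines_per_part = (total_lines + num_parts - 1) // num_parts
assert lines_per_part == 5

# Nine full chunks of 5 lines each, then a final partial chunk of 2 lines.
assert 9 * lines_per_part + 2 == total_lines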
4 changes: 2 additions & 2 deletions taskcluster/kinds/split-corpus/kind.yml
@@ -68,12 +68,12 @@ tasks:
            python3 $VCS_PATH/pipeline/translate/splitter.py
            --output_dir=$TASK_WORKDIR/artifacts
            --num_parts={split_chunks}
            fetches/corpus.{src_locale}.zst &&
            $TASK_WORKDIR/fetches/corpus.{src_locale}.zst &&
            python3 $VCS_PATH/pipeline/translate/splitter.py
            --output_dir=$TASK_WORKDIR/artifacts
            --num_parts={split_chunks}
            --output_suffix=.ref
            fetches/corpus.{trg_locale}.zst
            $TASK_WORKDIR/fetches/corpus.{trg_locale}.zst
        dependencies:
            merge-corpus: merge-corpus-{src_locale}-{trg_locale}
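Because the task command now chains two splitter invocations with &&, the test fixtures below learn to split such commands and run each invocation as a separate subprocess, as shown in split_on_ampersands_operator further down.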
161 changes: 91 additions & 70 deletions tests/fixtures/__init__.py
@@ -156,78 +156,82 @@ def run_task(
        if not fetches_dir:
            fetches_dir = self.path

        if extra_args:
            command_parts.extend(extra_args)

        final_env = {
            **os.environ,
            **task_env,
            "TASK_WORKDIR": work_dir,
            "MOZ_FETCHES_DIR": fetches_dir,
            "VCS_PATH": root_path,
            **env,
        }

        # Expand out environment variables in environment, for instance MARIAN=$MOZ_FETCHES_DIR
        # and FETCHES=./fetches will be expanded to MARIAN=./fetches
        for key, value in final_env.items():
            if not isinstance(value, str):
                continue
            expanded_value = final_env.get(value[1:])
            if value and value[0] == "$" and expanded_value:
                final_env[key] = expanded_value

        # Ensure the environment variables are sorted so that the longer variables get replaced first.
        sorted_env = sorted(final_env.items(), key=lambda kv: kv[0])
        sorted_env.reverse()

        for index, p in enumerate(command_parts):
            part = (
                p.replace("$TASK_WORKDIR/$VCS_PATH", root_path)
                .replace("$VCS_PATH", root_path)
                .replace("$TASK_WORKDIR", work_dir)
                .replace("$MOZ_FETCHES_DIR", fetches_dir)
            )

            # Apply the task environment.
            for key, value in sorted_env:
                env_var = f"${key}"
                if env_var in part:
                    part = part.replace(env_var, value)

            command_parts[index] = part

        # If using a venv, prepend the binary directory to the path so it is used.
        python_bin_dir, venv_dir = get_python_dirs(requirements)
        if python_bin_dir:
            final_env = {**final_env, "PATH": f'{python_bin_dir}:{os.environ.get("PATH", "")}'}
            if command_parts[0].endswith(".py"):
                # This script is relying on a shebang, add the python3 from the executable instead.
                command_parts.insert(0, os.path.join(python_bin_dir, "python3"))
        elif command_parts[0].endswith(".py"):
            # This script does not require a virtual environment.
            command_parts.insert(0, "python3")

        # We have to set the path to the C++ lib before the process is started
        # https://github.com/Helsinki-NLP/opus-fast-mosestokenizer/issues/6
        with open(requirements) as f:
            reqs_txt = f.read()
            if "opus-fast-mosestokenizer" in reqs_txt:
                lib_path = os.path.join(venv_dir, "lib/python3.10/site-packages/mosestokenizer/lib")
                print(f"Setting LD_LIBRARY_PATH to {lib_path}")
                final_env["LD_LIBRARY_PATH"] = lib_path

        print("┌──────────────────────────────────────────────────────────")
        print("│ run_task:", " ".join(command_parts))
        print("└──────────────────────────────────────────────────────────")

        result = subprocess.run(
            command_parts,
            env=final_env,
            cwd=root_path,
            check=False,
        )
        fail_on_error(result)

        for command_parts_split in split_on_ampersands_operator(command_parts):
            if extra_args:
                command_parts_split.extend(extra_args)

            final_env = {
                **os.environ,
                **task_env,
                "TASK_WORKDIR": work_dir,
                "MOZ_FETCHES_DIR": fetches_dir,
                "VCS_PATH": root_path,
                **env,
            }

            # Expand out environment variables in environment, for instance MARIAN=$MOZ_FETCHES_DIR
            # and FETCHES=./fetches will be expanded to MARIAN=./fetches
            for key, value in final_env.items():
                if not isinstance(value, str):
                    continue
                expanded_value = final_env.get(value[1:])
                if value and value[0] == "$" and expanded_value:
                    final_env[key] = expanded_value

            # Ensure the environment variables are sorted so that the longer variables get replaced first.
            sorted_env = sorted(final_env.items(), key=lambda kv: kv[0])
            sorted_env.reverse()

            for index, p in enumerate(command_parts_split):
                part = (
                    p.replace("$TASK_WORKDIR/$VCS_PATH", root_path)
                    .replace("$VCS_PATH", root_path)
                    .replace("$TASK_WORKDIR", work_dir)
                    .replace("$MOZ_FETCHES_DIR", fetches_dir)
                )

                # Apply the task environment.
                for key, value in sorted_env:
                    env_var = f"${key}"
                    if env_var in part:
                        part = part.replace(env_var, value)

                command_parts_split[index] = part

            # If using a venv, prepend the binary directory to the path so it is used.
            python_bin_dir, venv_dir = get_python_dirs(requirements)
            if python_bin_dir:
                final_env = {**final_env, "PATH": f'{python_bin_dir}:{os.environ.get("PATH", "")}'}
                if command_parts_split[0].endswith(".py"):
                    # This script is relying on a shebang, add the python3 from the executable instead.
                    command_parts_split.insert(0, os.path.join(python_bin_dir, "python3"))
            elif command_parts_split[0].endswith(".py"):
                # This script does not require a virtual environment.
                command_parts_split.insert(0, "python3")

            # We have to set the path to the C++ lib before the process is started
            # https://github.com/Helsinki-NLP/opus-fast-mosestokenizer/issues/6
            with open(requirements) as f:
                reqs_txt = f.read()
                if "opus-fast-mosestokenizer" in reqs_txt:
                    lib_path = os.path.join(
                        venv_dir, "lib/python3.10/site-packages/mosestokenizer/lib"
                    )
                    print(f"Setting LD_LIBRARY_PATH to {lib_path}")
                    final_env["LD_LIBRARY_PATH"] = lib_path

            print("┌──────────────────────────────────────────────────────────")
            print("│ run_task:", " ".join(command_parts_split))
            print("└──────────────────────────────────────────────────────────")

            result = subprocess.run(
                command_parts_split,
                env=final_env,
                cwd=root_path,
                check=False,
            )
            fail_on_error(result)

    def print_tree(self):
        """
@@ -256,6 +260,23 @@ def print_tree(self):
        print(f"└{span}┘")


def split_on_ampersands_operator(command_parts: list[str]) -> list[list[str]]:
    """Splits a command with the bash && operator into multiple lists of commands."""
    multiple_command_parts: list[list[str]] = []
    sublist: list[str] = []
    for part in command_parts:
        if part.strip().startswith("&&"):
            command_part = part.replace("&&", "").strip()
            if len(command_part):
                sublist.append(command_part)
            multiple_command_parts.append(sublist)
            sublist = []
        else:
            sublist.append(part)
    multiple_command_parts.append(sublist)
    return multiple_command_parts
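For illustration, here is roughly how the split-corpus command from kind.yml above would be divided; the argument lists are shortened and the values hypothetical:

command_parts = [
    "python3", "splitter.py", "--num_parts=10", "corpus.en.zst",
    "&&",
    "python3", "splitter.py", "--output_suffix=.ref", "corpus.ru.zst",
]

assert split_on_ampersands_operator(command_parts) == [
    ["python3", "splitter.py", "--num_parts=10", "corpus.en.zst"],
    ["python3", "splitter.py", "--output_suffix=.ref", "corpus.ru.zst"],
]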


def fail_on_error(result: CompletedProcess[bytes]):
    """When a process fails, surface the stderr."""
    if not result.returncode == 0:
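Stepping back to the environment handling in run_task above: values that name another variable (for instance MARIAN=$MOZ_FETCHES_DIR) are resolved in a first pass, and substitution into command parts then walks the keys in reverse-sorted order so that longer names sharing a prefix are replaced first. A minimal sketch with hypothetical values:

# Hypothetical environment illustrating the two passes in run_task.
final_env = {"MOZ_FETCHES_DIR": "./fetches", "MARIAN": "$MOZ_FETCHES_DIR"}

# Pass 1: values that reference another variable are resolved.
for key, value in final_env.items():
    expanded_value = final_env.get(value[1:])
    if value and value[0] == "$" and expanded_value:
        final_env[key] = expanded_value
assert final_env["MARIAN"] == "./fetches"

# Pass 2: reverse-sorted keys mean $MOZ_FETCHES_DIR is substituted
# before any shorter key that happens to be its prefix.
part = "marian=$MARIAN fetches=$MOZ_FETCHES_DIR"
for key, value in sorted(final_env.items(), key=lambda kv: kv[0], reverse=True):
    part = part.replace(f"${key}", value)
assert part == "marian=./fetches fetches=./fetches"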
57 changes: 57 additions & 0 deletions tests/test_split_corpus.py
@@ -0,0 +1,57 @@
from pathlib import Path

import pytest
from fixtures import DataDir

from pipeline.common.downloads import read_lines

# With 47 lines, there will be 5 lines per file, except the last file which should have 2.
corpus_line_count = 47


@pytest.mark.parametrize("task", ["src-en", "trg-ru"])
def test_split_mono(task: str):
    _side, locale = task.split("-")
    data_dir = DataDir("test_split_mono")
    data_dir.create_zst(
        f"mono.{locale}.zst", "\n".join([str(i) for i in range(corpus_line_count)]) + "\n"
    )
    data_dir.run_task(f"split-mono-{task}")
    data_dir.print_tree()

    for i in range(10):
        assert Path(data_dir.join(f"artifacts/file.{i+1}.zst")).exists()

    with read_lines(data_dir.join("artifacts/file.9.zst")) as lines:
        assert list(lines) == ["40\n", "41\n", "42\n", "43\n", "44\n"]

    with read_lines(data_dir.join("artifacts/file.10.zst")) as lines:
        assert list(lines) == ["45\n", "46\n"], "The last file has a partial chunk"


def test_split_corpus():
    data_dir = DataDir("test_split_corpus")
    data_dir.mkdir("fetches")
    data_dir.create_zst(
        "fetches/corpus.en.zst", "\n".join([f"en-{i}" for i in range(corpus_line_count)]) + "\n"
    )
    data_dir.create_zst(
        "fetches/corpus.ru.zst", "\n".join([f"ru-{i}" for i in range(corpus_line_count)]) + "\n"
    )
    data_dir.run_task("split-corpus-en-ru")
    data_dir.print_tree()

    for i in range(10):
        assert Path(data_dir.join(f"artifacts/file.{i + 1}.zst")).exists()

    with read_lines(data_dir.join("artifacts/file.9.zst")) as lines:
        assert list(lines) == ["en-40\n", "en-41\n", "en-42\n", "en-43\n", "en-44\n"]

    with read_lines(data_dir.join("artifacts/file.10.zst")) as lines:
        assert list(lines) == ["en-45\n", "en-46\n"], "The last file has a partial chunk"

    with read_lines(data_dir.join("artifacts/file.9.ref.zst")) as lines:
        assert list(lines) == ["ru-40\n", "ru-41\n", "ru-42\n", "ru-43\n", "ru-44\n"]

    with read_lines(data_dir.join("artifacts/file.10.ref.zst")) as lines:
        assert list(lines) == ["ru-45\n", "ru-46\n"], "The last file has a partial chunk"
