-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMSLT-extract.py
49 lines (42 loc) · 2.2 KB
/
MSLT-extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import re
import sys
import tarfile
def extract_text(filein):
    """Read a UTF-16 encoded stream and return its contents as UTF-8 bytes.

    Every carriage return is dropped, so CRLF line endings become LF.

    :param filein: binary file-like object whose ``read()`` returns the
        UTF-16 payload (the byte-order mark, if present, selects endianness).
    :return: the decoded text re-encoded as UTF-8 ``bytes``.
    """
    raw_bytes = filein.read()
    unicode_text = raw_bytes.decode("utf-16")
    without_cr = unicode_text.replace("\r", "")
    return without_cr.encode("utf-8")
def _member_pattern(category, source, side_suffix):
    """Compile a regex matching one side's files inside an MSLT archive.

    The member path must sit under a ``mslt_<category>_<source>_<digits>/``
    directory. Group 1 captures the path up to the ``.tN.`` marker; that
    prefix is shared by an utterance's source and target files and is used
    as the key that pairs them.

    :param category: corpus split, e.g. ``dev`` or ``test``.
    :param source: source-language code that names the corpus directory.
    :param side_suffix: regex fragment appended after the captured group,
        e.g. ``t2..*`` for source transcripts.
    :return: compiled pattern, to be matched against lowercased member names.
    """
    # Member names are lowercased before matching, so lowercase the
    # user-supplied parts too, and escape them so they match literally.
    return re.compile(r"(.*/mslt_" + re.escape(category.lower())
                      + "_" + re.escape(source.lower())
                      + r"_\d+/.+?)." + side_suffix)


def _collect_texts(archive_path, patterns):
    """Read every matching member of a ``.tgz`` archive in a single pass.

    :param archive_path: path to the gzip-compressed tar file.
    :param patterns: sequence of ``(side, compiled_regex)`` pairs, where
        ``side`` is the dict key (``"src"``/``"tgt"``) to store the text under.
    :return: dict mapping each utterance key (group 1 of the match) to a
        dict of ``side -> UTF-8 bytes``. If several members map to the same
        key and side, the last one wins.
    """
    texts = {}
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            name = member.name.lower()
            for side, pattern in patterns:
                match = pattern.match(name)
                if match is None:
                    continue
                fileobj = tar.extractfile(member)
                if fileobj is None:
                    # Directories and other non-file members have no content.
                    continue
                texts.setdefault(match.group(1), {})[side] = extract_text(fileobj)
    return texts


def _write_parallel(texts, output, source, target):
    """Write line-aligned files ``<output>.<source>`` and ``<output>.<target>``.

    Each kept utterance is stripped of surrounding whitespace and written as
    one newline-terminated entry on both sides, so the files stay aligned.
    Utterances missing either side are skipped with a warning on stderr.
    Pairs where either side is blank are dropped.
    """
    with open(output + "." + source, "wb") as out_src, \
            open(output + "." + target, "wb") as out_tgt:
        for key, sides in texts.items():
            if "src" not in sides or "tgt" not in sides:
                missing = "target" if "src" in sides else "source"
                sys.stderr.write("warning: skipping %s: missing %s text\n"
                                 % (key, missing))
                continue
            src = sides["src"].strip()
            tgt = sides["tgt"].strip()
            if src and tgt:
                out_src.write(src + b"\n")
                out_tgt.write(tgt + b"\n")


def main():
    """Command-line entry point.

    With ``--target``, writes paired source/target files using ``--output``
    as the prefix. Without it, prints every source transcript to stdout.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-f', '--file', required=True, help='input repacked tgz file')
    parser.add_argument('-s', '--source', required=True, help='source language (e.g. fr)')
    parser.add_argument('-t', '--target', required=False, help='target language (e.g. en)')
    parser.add_argument('-c', '--category', required=False, default="dev", help='dev or test?')
    parser.add_argument('-o', '--output', required=False, help='output file (used for parallel data)')
    args = parser.parse_args()
    if args.target and not args.output:
        parser.error("--output is required when --target is given")

    # Source transcripts are the ".t2." files. Translations are ".t3."-".t9."
    # files tagged with the target language; they live in the directory
    # named after the source language.
    patterns = [("src", _member_pattern(args.category, args.source, r"t2..*"))]
    if args.target:
        target_suffix = r"t[3-9]." + re.escape(args.target.lower()) + r".*"
        patterns.append(("tgt", _member_pattern(args.category, args.source,
                                                target_suffix)))
    texts = _collect_texts(args.file, patterns)

    if args.target:
        _write_parallel(texts, args.output, args.source, args.target)
    else:
        # The texts are UTF-8 bytes. Write them to the binary stdout on
        # Python 3 (sys.stdout.buffer) and to plain sys.stdout on Python 2.
        out = getattr(sys.stdout, "buffer", sys.stdout)
        for sides in texts.values():
            out.write(sides["src"].strip() + b"\n")


if __name__ == "__main__":
    main()