forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmerge_datasets.py
93 lines (69 loc) · 2.34 KB
/
merge_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import sys
import json
import argparse
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
from megatron.core.datasets.indexed_dataset import (
IndexedDataset,
IndexedDatasetBuilder,
get_bin_path,
get_idx_path,
)
def get_args(argv=None):
    """Parse and validate the command-line arguments for the merge tool.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets
            other scripts and tests drive the parser.

    Returns:
        An ``argparse.Namespace`` with ``input`` (str), ``output_prefix`` (str)
        and ``multimodal`` (bool).

    Exits via ``parser.error`` (status 2, usage printed to stderr) if
    ``--input`` is not an existing directory, or if the directory component of
    ``--output-prefix`` does not exist.
    """
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title="input data")
    group.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to directory containing all document files to merge",
    )

    group = parser.add_argument_group(title="output data")
    group.add_argument(
        "--output-prefix",
        type=str,
        required=True,
        help="Path to binary output file without suffix",
    )

    group = parser.add_argument_group(title="miscellaneous")
    group.add_argument(
        "--multimodal",
        action="store_true",
        help="Whether the datasets are assumed to be multimodal",
    )

    args = parser.parse_args(argv)

    # Validate with parser.error instead of assert: asserts are removed under
    # `python -O`, which would silently skip these checks.
    if not os.path.isdir(args.input):
        parser.error(f"{args.input} is not a directory or does not exist")

    # A bare prefix such as "merged" has an empty dirname, which means the
    # current working directory; os.path.isdir("") is False, so map it to ".".
    output_dir = os.path.dirname(args.output_prefix) or "."
    if not os.path.isdir(output_dir):
        parser.error(f"{output_dir} is not a directory or does not exist")

    return args
def main():
    """Merge every ``<prefix>.bin`` / ``<prefix>.idx`` pair in ``--input`` into one dataset.

    The output is written to ``get_bin_path(--output-prefix)`` and
    ``get_idx_path(--output-prefix)``. Pairs are appended in lexicographic
    order of their prefix string (so ``shard_10`` sorts before ``shard_2``).
    The dtype of the merged dataset is taken from the index of the first pair
    in that order.

    Files whose extension is neither ``.bin`` nor ``.idx`` are ignored.

    Raises:
        FileNotFoundError: if a ``.bin`` file has no matching ``.idx`` file,
            or vice versa.
        ValueError: if ``--input`` contains no ``.bin``/``.idx`` files.
    """
    args = get_args()

    prefixes = set()
    for basename in os.listdir(args.input):
        prefix, ext = os.path.splitext(basename)
        # Only .bin/.idx files belong to an indexed dataset. Skip anything else
        # (README, checksums, hidden files, ...) rather than demanding a
        # nonexistent partner file for it.
        if ext not in (".bin", ".idx"):
            continue
        if prefix in prefixes:
            continue
        if not os.path.isfile(os.path.join(args.input, basename)):
            continue
        ext_pair = ".bin" if ext == ".idx" else ".idx"
        if not os.path.isfile(os.path.join(args.input, prefix) + ext_pair):
            raise FileNotFoundError(
                f"ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}"
            )
        prefixes.add(prefix)

    # Without this check an empty directory would leave `builder` as None
    # and fail later with an unhelpful AttributeError on `finalize`.
    if not prefixes:
        raise ValueError(f"ERROR: no .bin/.idx file pairs found in {args.input}")

    builder = None
    for prefix in sorted(prefixes):
        path_prefix = os.path.join(args.input, prefix)
        if builder is None:
            # Open the first dataset only to read the dtype stored in its
            # index; the output builder is created with that dtype.
            # NOTE(review): whether add_index rejects inputs with a different
            # dtype depends on IndexedDatasetBuilder — confirm.
            dataset = IndexedDataset(path_prefix, multimodal=args.multimodal)
            builder = IndexedDatasetBuilder(
                get_bin_path(args.output_prefix),
                dtype=dataset.index.dtype,
                multimodal=args.multimodal,
            )
            del dataset
        builder.add_index(path_prefix)

    builder.finalize(get_idx_path(args.output_prefix))
# Run the merge only when this file is executed as a script, never on import.
if __name__ == "__main__":
    main()