bug fixes in partitioned data preprocessor
Mike Chrzanowski authored and jaredcasper committed Nov 29, 2022
1 parent 8ce8256 commit 84a43b1
Showing 2 changed files with 12 additions and 6 deletions.
megatron/tokenizer/tokenizer.py (4 changes: 3 additions & 1 deletion)
@@ -15,8 +15,10 @@ def build_tokenizer(args):
         print('> building {} tokenizer ...'.format(args.tokenizer_type),
               flush=True)
 
+    if args.tokenizer_type != 'SentencePieceTokenizer':
+        assert args.vocab_file is not None
+
     # Select and instantiate the tokenizer.
-    assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
         tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                             lower_case=True,
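Note on the tokenizer.py change, as a reading of the diff rather than anything stated in the commit: the SentencePiece tokenizer is built from a serialized tokenizer model rather than a plain vocab file, so asserting args.vocab_file unconditionally would reject otherwise valid SentencePiece runs. Below is a minimal, self-contained sketch of the guarded check; check_vocab_requirement is a hypothetical helper written only for illustration, not a function in the repository.

from argparse import Namespace

def check_vocab_requirement(args):
    # Mirrors the guard added above: only tokenizers that read a plain
    # vocab file need --vocab-file; the SentencePiece tokenizer does not.
    if args.tokenizer_type != 'SentencePieceTokenizer':
        assert args.vocab_file is not None, \
            '--vocab-file is required for {}'.format(args.tokenizer_type)

# A SentencePiece run may now omit the vocab file without tripping the assert.
check_vocab_requirement(Namespace(tokenizer_type='SentencePieceTokenizer',
                                  vocab_file=None))

# A BERT WordPiece run still requires one.
check_vocab_requirement(Namespace(tokenizer_type='BertWordPieceLowerCase',
                                  vocab_file='bert-vocab.txt'))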
tools/preprocess_data_partitions.py (14 changes: 9 additions & 5 deletions)
@@ -174,6 +174,7 @@ def process_json_file(self, file_name):
             self.print_processing_stats(i, proc_start, total_bytes_processed)
 
         fin.close()
+        builders[key].finalize(output_idx_files[key])
 
 
 def get_args():
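Why the added finalize call matters, again as a reading of the diff: the dataset builders stream tokenized documents into a binary payload, and the index needed to read that payload back is only written when the builder is finalized, so a partition processed without finalize cannot be loaded or merged later. The toy writer below illustrates that pattern in a self-contained way; ToyIndexedWriter is invented for this note and is not Megatron's indexed-dataset API.

import json
import struct

class ToyIndexedWriter:
    # Toy stand-in for an indexed-dataset builder: a payload file plus a
    # separate index that records where each document starts.
    def __init__(self, bin_path, idx_path):
        self.idx_path = idx_path
        self.bin = open(bin_path, 'wb')
        self.offsets = [0]

    def add_item(self, token_ids):
        payload = struct.pack('{}i'.format(len(token_ids)), *token_ids)
        self.bin.write(payload)
        self.offsets.append(self.offsets[-1] + len(payload))

    def finalize(self):
        # Without this step the index file is never written, so the payload
        # cannot be sliced back into documents.
        self.bin.close()
        with open(self.idx_path, 'w') as f:
            json.dump(self.offsets, f)

writer = ToyIndexedWriter('part_0.bin', 'part_0.idx')
writer.add_item([101, 2023, 102])
writer.add_item([101, 102])
writer.finalize()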
@@ -219,9 +220,8 @@ def get_args():
     args = parser.parse_args()
     args.keep_empty = False
 
-    if (args.tokenizer_type.lower().startswith('bert')
-        if not args.split_sentences:
-            print("Are you sure you don't want to split sentences?")
+    if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
+        print("Are you sure you don't want to split sentences?")
 
     # some default/dummy values for the tokenizer
     args.rank = 1
@@ -265,7 +265,11 @@ def main():
     if args.partitions == 1:
         file_name, extension = os.path.splitext(args.input)
         sentence_split_file = file_name + "_ss" + extension
-        in_ss_out_names.append((args.input, sentence_split_file, args.output_prefix))
+        file_names = {
+            'partition': args.input,
+            'sentence_split': sentence_split_file,
+            'output_prefix': args.output_prefix}
+        in_ss_out_names.append(file_names)
     else:
         in_file_names = glob.glob(args.input)
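Note on the single-partition fix: the rest of main() appears to look up each entry of in_ss_out_names by key, as the multi-partition branch and the merge loop later in the diff both do, so appending a plain tuple in the partitions == 1 path would break on keyed access. A small, self-contained sketch of the entry shape the fix restores; the file names and print statements are placeholders, not the script's real processing steps.

in_ss_out_names = []

# Fixed single-partition entry: the same dict shape the multi-partition branch uses.
file_names = {
    'partition': 'corpus.json',
    'sentence_split': 'corpus_ss.json',
    'output_prefix': 'corpus_out'}
in_ss_out_names.append(file_names)

# Downstream code indexes entries by key; a tuple entry would raise a TypeError here.
for name in in_ss_out_names:
    print('split sentences:', name['partition'], '->', name['sentence_split'])
    print('tokenize:       ', name['sentence_split'], '->', name['output_prefix'])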

@@ -358,7 +362,7 @@ def main():
             full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
                                                              key, level)
             builders[key].merge_file_(full_partition_output_prefix)
-        builder[key].finalize(output_idx_files[key])
+        builders[key].finalize(output_idx_files[key])
 
 
 if __name__ == '__main__':
