-
Notifications
You must be signed in to change notification settings - Fork 2
/
create_made_data.py
34 lines (24 loc) · 987 Bytes
/
create_made_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pickle
import os
import argparse
import numpy as np
def sequences_to_nhot(seqs, vocab_size):
    """Encode index sequences as rows of an n-hot (multi-hot) matrix.

    Args:
        seqs: A sized collection of sequences; every element of a sequence
            is used as a column index into the corresponding output row.
        vocab_size: Number of columns in the output matrix.

    Returns:
        A ``float32`` array of shape ``(len(seqs), vocab_size)`` in which
        entry ``[i, j]`` is 1 when ``j`` occurs in ``seqs[i]`` and 0
        otherwise. Repeated indices within a sequence simply set the same
        cell again.
    """
    nhot = np.zeros((len(seqs), vocab_size), dtype=np.float32)
    # Iterating the 2-D array yields row views, so writes land in `nhot`.
    for row, tokens in zip(nhot, seqs):
        for token in tokens:
            row[token] = 1
    return nhot
if __name__ == '__main__':
    # Script entry point: read pickled train/valid sequences and a pickled
    # vocabulary, convert the sequences to n-hot matrices, and write both
    # matrices into a single NumPy .npz archive.
    parser = argparse.ArgumentParser()
    parser.add_argument('train_data')
    parser.add_argument('valid_data')
    parser.add_argument('vocab_file')
    parser.add_argument('save_path')
    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code while deserialising;
    # only run this script on files from a trusted source.

    # The vocab pickle is unpacked as a 4-tuple; only its third element is
    # used, and its length determines the width of the n-hot matrices.
    with open(args.vocab_file, 'rb') as vocab_fh:
        _, _, token_vocab, _ = pickle.load(vocab_fh)
    attr_vocab_size = len(token_vocab)

    # Each data pickle is unpacked as a 5-tuple; only the first element
    # (the sequences) is used here. Context managers make sure every file
    # handle is closed as soon as its contents have been loaded.
    with open(args.train_data, 'rb') as train_fh:
        train_data, _, _, _, _ = pickle.load(train_fh)
    with open(args.valid_data, 'rb') as valid_fh:
        valid_data, _, _, _, _ = pickle.load(valid_fh)

    train_nhot = sequences_to_nhot(train_data, attr_vocab_size)
    valid_nhot = sequences_to_nhot(valid_data, attr_vocab_size)

    # Arrays are stored under the keys 'train_data' and 'valid_data';
    # np.savez appends '.npz' to save_path if it is not already present.
    np.savez(args.save_path, train_data=train_nhot, valid_data=valid_nhot)