forked from shibing624/pycorrector
add deep context model for error correction.
xuming06 committed Nov 15, 2018 · 1 parent 85653d2 · commit 9a0d2ec
Showing 11 changed files with 10,906 additions and 0 deletions.
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: network configuration
"""
import os

output_dir = 'output'

# chinese corpus
raw_train_paths = [
    # '../data/cn/CGED/CGED18_HSK_TrainingSet.xml',
    # '../data/cn/CGED/CGED17_HSK_TrainingSet.xml',
    '../data/cn/CGED/CGED16_HSK_TrainingSet.xml',
    '../data/cn/CGED/sample_HSK_TrainingSet.xml',
]

# Training data path.
train_path = os.path.join(output_dir, 'train.txt')
# Validation data path.
test_path = os.path.join(output_dir, 'test.txt')

emb_path = os.path.join(output_dir, 'emb.vec')
model_path = os.path.join(output_dir, 'deep_context.model')

# nets
word_embed_size = 200
hidden_size = 200
n_layers = 1
use_mlp = True
dropout = 0.0

# train
maxlen = 64
epochs = 2
batch_size = 64
min_freq = 3
ns_power = 0.75
learning_rate = 1e-3
gpu_id = 0

# evaluate with mscc data set
question_file = 'YOUR_DATASET_DIR/Holmes.machine_format.questions.txt'
answer_file = 'YOUR_DATASET_DIR/Holmes.machine_format.answers.txt'

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
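
The other modules in this diff pull their settings from this module by importing it rather than taking command-line flags. A minimal sketch of how the values are consumed (the import path pycorrector.deep_context.config is taken from the files below):

from pycorrector.deep_context import config

print(config.train_path)    # e.g. output/train.txt
print(config.batch_size)    # 64
print(config.model_path)    # e.g. output/deep_context.model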
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: read data and generate vocab, config
"""
from codecs import open
import json


def write_embedding(id2word, nn_embedding, use_cuda, filename):
    with open(filename, mode='w', encoding='utf-8') as f:
        f.write('{} {}\n'.format(nn_embedding.num_embeddings, nn_embedding.embedding_dim))
        if use_cuda:
            embeddings = nn_embedding.weight.data.cpu().numpy()
        else:
            embeddings = nn_embedding.weight.data.numpy()

        for word_id, vec in enumerate(embeddings):
            word = id2word[word_id]
            vec = ' '.join(list(map(str, vec)))
            f.write('{} {}\n'.format(word, vec))


def load_vocab(filename):
    with open(filename, mode='r', encoding='utf-8') as f:
        f.readline()
        itos = [str(field.split(' ', 1)[0]) for field in f]
        stoi = {token: i for i, token in enumerate(itos)}
        return itos, stoi


def write_config(filename, **kwargs):
    with open(filename, mode='w', encoding='utf-8') as f:
        json.dump(kwargs, f)


def read_config(filename):
    with open(filename, mode='r', encoding='utf-8') as f:
        return json.load(f)
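
The embedding file written by write_embedding doubles as the vocabulary file: the first line holds the matrix shape, and load_vocab skips that header and keeps only the first field of each remaining line. A minimal round-trip sketch, assuming PyTorch is available and that this file is pycorrector/deep_context/data_util.py as the imports in the next file suggest (the toy vocabulary and file name are illustrative only):

import torch.nn as nn
from pycorrector.deep_context.data_util import write_embedding, load_vocab

id2word = ['<unk>', '我', '你']                     # toy vocabulary, illustrative only
emb = nn.Embedding(num_embeddings=3, embedding_dim=4)

# write "3 4" header plus one "word v1 v2 v3 v4" line per row, then read it back
write_embedding(id2word, emb, use_cuda=False, filename='toy_emb.vec')
itos, stoi = load_vocab('toy_emb.vec')

print(itos)        # ['<unk>', '我', '你']
print(stoi['我'])  # 1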
@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: evaluate with mscc data set
The function create_mscc_dataset is Copyright 2016 Oren Melamud
Modifications copyright (C) 2018 Tatsuya Aoki
This code is based on https://github.com/orenmel/context2vec/blob/master/context2vec/eval/mscc_text_tokenize.py
Used to convert the Microsoft Sentence Completion Challenge (MSCC) learning corpus into a one-sentence-per-line format.
"""

import glob
import os
import sys
from codecs import open

import numpy
import torch
from nltk.tokenize import word_tokenize, sent_tokenize

from pycorrector.deep_context import config
from pycorrector.deep_context.data_util import load_vocab
from pycorrector.deep_context.infer import read_model


def create_mscc_dataset(input_dir, output_filename, lowercase=True):
    def write_paragraph_lines(paragraph_lines, file_obj):
        paragraph_str = ' '.join(paragraph_lines)
        for sent in sent_tokenize(paragraph_str):
            if lowercase:
                sent = sent.lower()
            file_obj.write(' '.join(word_tokenize(sent)) + '\n')

    if input_dir[-1] != '/':
        input_dir += '/'

    if not os.path.isdir(input_dir):
        raise NotADirectoryError

    print('Read files from', input_dir)
    print('Creating dataset to', output_filename)
    files = glob.glob(input_dir + '*.TXT')
    with open(output_filename, mode='w', encoding='utf-8') as output_file:
        for file in files:
            with open(file, mode='r', errors='ignore', encoding='utf-8') as input_file:
                paragraph_lines = []
                count = 0
                for i, line in enumerate(input_file):
                    if len(line.strip()) == 0 and len(paragraph_lines) > 0:
                        write_paragraph_lines(paragraph_lines, output_file)
                        paragraph_lines = []
                    else:
                        paragraph_lines.append(line)
                    count += 1
                if len(paragraph_lines) > 0:
                    write_paragraph_lines(paragraph_lines, output_file)
                print('Read {} lines'.format(count))


def read_mscc_questions(input_file, lower=True):
    with open(input_file, mode='r', encoding='utf-8') as f:
        questions = []
        for line in f:
            q_id, text = line.split(' ', 1)
            if lower:
                text = text.lower()
            text = text.strip().split()
            target_word = ''
            for index, token in enumerate(text):
                if token.startswith('[') and token.endswith(']'):
                    target_word = token[1:-1]
                    target_pos = index
            if not target_word:
                raise SyntaxError
            questions.append([text, q_id, target_word, target_pos])
    return questions


def print_mscc_score(gold_q_id: list, q_id_and_sim: list):
    assert len(q_id_and_sim) % 5 == 0

    gold = numpy.array(gold_q_id)
    answer = numpy.array([sorted(q_id_and_sim[5 * i:5 * (i + 1)], key=lambda x: x[1], reverse=True)
                          for i in range(int(len(q_id_and_sim) / 5))])[:, 0, 0]
    correct_or_not = (gold == answer)
    mid = int(len(correct_or_not) / 2)
    dev = correct_or_not[:mid]
    test = correct_or_not[mid:]

    print('Overall', float(sum(correct_or_not)) / len(correct_or_not))
    print('dev', float(sum(dev)) / len(dev))
    print('test', float(sum(test)) / len(test))


def mscc_evaluation(question_file,
                    answer_file,
                    output_file,
                    model,
                    stoi,
                    unk_token,
                    bos_token,
                    eos_token,
                    device):
    questions = read_mscc_questions(question_file)
    q_id_and_sim = []
    with open(question_file, mode='r', encoding='utf-8') as f, open(output_file, mode='w', encoding='utf-8') as w:
        for question, input_line in zip(questions, f):
            tokens, q_id, target_word, target_pos = question
            tokens[target_pos] = target_word
            tokens = [bos_token] + tokens + [eos_token]
            indexed_sentence = [stoi[token] if token in stoi else stoi[unk_token] for token in tokens]
            input_tokens = torch.tensor(indexed_sentence, dtype=torch.long, device=device).unsqueeze(0)
            indexed_target_word = input_tokens[0, target_pos + 1]
            similarity = model.run_inference(input_tokens, indexed_target_word, target_pos)
            q_id_and_sim.append((q_id, similarity))
            w.write(input_line.strip() + '\t' + str(similarity) + '\n')

    with open(answer_file, mode='r', encoding='utf-8') as f:
        gold_q_id = [line.split(' ', 1)[0] for line in f]

    print_mscc_score(gold_q_id, q_id_and_sim)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print('Please specify the input directory that contains the MSCC dataset.')
        print('(In most cases the directory is named `Holmes_Training_Data`.)')
        print('sample usage: python src/eval/mscc.py ~/dataset/Holmes_Training_Data/')
        quit()
    create_mscc_dataset(sys.argv[1], 'dataset/mscc_train.txt')

    gpu_id = config.gpu_id
    model_path = config.model_path
    emb_path = config.emb_path
    # device
    use_cuda = torch.cuda.is_available() and gpu_id > -1
    if use_cuda:
        device = torch.device('cuda:{}'.format(gpu_id))
        torch.cuda.set_device(gpu_id)
    else:
        device = torch.device('cpu')

    # load model
    model, config_dict = read_model(model_path, device)
    unk_token = config_dict['unk_token']
    bos_token = config_dict['bos_token']
    eos_token = config_dict['eos_token']

    # read vocab from word_emb path
    itos, stoi = load_vocab(emb_path)

    mscc_evaluation(config.question_file,
                    config.answer_file,
                    'mscc.result',
                    model,
                    stoi,
                    unk_token=unk_token,
                    bos_token=bos_token,
                    eos_token=eos_token,
                    device=device)
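
read_mscc_questions expects each line of the question file to start with a question id followed by the tokenized sentence, with the candidate word wrapped in square brackets; print_mscc_score then assumes the five candidates of each question arrive on five consecutive lines (hence the `% 5 == 0` assert) and keeps the highest-scoring one. A minimal parsing sketch with a made-up line (the real file is Holmes.machine_format.questions.txt from the MSCC release):

# illustrative line; id and bracketed candidate are invented for the example
line = '1a) i have seen it on him , and could [write] to it .'

q_id, text = line.split(' ', 1)          # '1a)' and the sentence
tokens = text.lower().strip().split()
# the bracketed token marks the candidate word and its position
target = [(i, t[1:-1]) for i, t in enumerate(tokens) if t.startswith('[') and t.endswith(']')]
print(q_id, target)                      # 1a) [(9, 'write')]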
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Inference
"""

import torch
from torch import optim

from pycorrector.deep_context import config
from pycorrector.deep_context.data_util import read_config, load_vocab
from pycorrector.deep_context.network import Context2vec


def inference(model_path,
              emb_path,
              gpu_id):
    def return_split_sentence(sentence):
        if ' ' not in sentence:
            print('sentence should contain whitespace to split it into tokens')
            raise SyntaxError
        elif '[]' not in sentence:
            print('sentence should contain `[]` that denotes the target')
            raise SyntaxError
        else:
            tokens = sentence.lower().strip().split()
            target_pos = tokens.index('[]')
            return tokens, target_pos

    # device
    use_cuda = torch.cuda.is_available() and gpu_id > -1
    if use_cuda:
        device = torch.device('cuda:{}'.format(gpu_id))
        torch.cuda.set_device(gpu_id)
    else:
        device = torch.device('cpu')

    # load model
    model, config_dict = read_model(model_path, device)
    unk_token = config_dict['unk_token']
    bos_token = config_dict['bos_token']
    eos_token = config_dict['eos_token']

    # read vocab from word_emb path
    itos, stoi = load_vocab(emb_path)

    # norm weight
    model.norm_embedding_weight(model.criterion.W)

    while True:
        sentence = input('>> ')
        try:
            tokens, target_pos = return_split_sentence(sentence)
        except SyntaxError:
            continue
        tokens[target_pos] = unk_token
        tokens = [bos_token] + tokens + [eos_token]
        indexed_sentence = [stoi[token] if token in stoi else stoi[unk_token] for token in tokens]
        input_tokens = torch.tensor(indexed_sentence, dtype=torch.long, device=device).unsqueeze(0)
        topv, topi = model.run_inference(input_tokens, target=None, target_pos=target_pos)
        for value, key in zip(topv, topi):
            print(value.item(), itos[key.item()])


def read_model(model_path, device):
    config_file = model_path + '.config.json'
    config_dict = read_config(config_file)
    model = Context2vec(vocab_size=config_dict['vocab_size'],
                        counter=[1] * config_dict['vocab_size'],
                        word_embed_size=config_dict['word_embed_size'],
                        hidden_size=config_dict['hidden_size'],
                        n_layers=config_dict['n_layers'],
                        bidirectional=config_dict['bidirectional'],
                        use_mlp=config_dict['use_mlp'],
                        dropout=config_dict['dropout'],
                        pad_index=config_dict['pad_index'],
                        device=device,
                        inference=True).to(device)
    model.load_state_dict(torch.load(model_path))
    optimizer = optim.Adam(model.parameters(), lr=config_dict['learning_rate'])
    optimizer.load_state_dict(torch.load(model_path + '.optim'))
    model.eval()
    return model, config_dict


if __name__ == "__main__":
    inference(config.model_path,
              config.emb_path,
              config.gpu_id)
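
For programmatic use, the body of inference() above can be followed without the input() loop: load the model and vocabulary, mark the blank with `[]`, and print the top candidates for that slot. A minimal sketch, assuming a trained model already exists at config.model_path and that the package layout matches the import paths used above; the example sentence is illustrative only:

import torch

from pycorrector.deep_context import config
from pycorrector.deep_context.data_util import load_vocab
from pycorrector.deep_context.infer import read_model

device = torch.device('cpu')
model, cfg = read_model(config.model_path, device)
itos, stoi = load_vocab(config.emb_path)
model.norm_embedding_weight(model.criterion.W)

# space-tokenized sentence with [] at the position to predict (illustrative)
tokens = '今天 天气 很 []'.split()
target_pos = tokens.index('[]')
tokens[target_pos] = cfg['unk_token']
tokens = [cfg['bos_token']] + tokens + [cfg['eos_token']]
ids = [stoi.get(t, stoi[cfg['unk_token']]) for t in tokens]
input_tokens = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)

topv, topi = model.run_inference(input_tokens, target=None, target_pos=target_pos)
for score, idx in zip(topv, topi):
    print(score.item(), itos[idx.item()])  # one candidate word per line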