
Commit

Update seq2seq.py
zhengyanzhao1997 authored Jan 21, 2021
1 parent 24f96c9 commit 29d9eee
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions model/train/NEZHA/seq2seq.py
@@ -118,7 +118,7 @@ def predict(self, inputs, output_ids, states):
                 states[i] += 1  # if the current ...
 
             if states[i] > 0:
-                ngrams = self.get_ngram_set(ides_temp, states[i])
+                ngrams = self.get_ngram_set(token_ids, states[i])
                 '''
                 if states == 1: we are at the start,
                 so ngrams = 1: all tokens
@@ -132,7 +132,7 @@ def predict(self, inputs, output_ids, states):
                 if prefix in ngrams:  # the prefix really is a valid ngram
                     candidates = ngrams[prefix]
                 else:  # no match, so fall back to the 1-gram
-                    ngrams = self.get_ngram_set(ides_temp, 1)
+                    ngrams = self.get_ngram_set(token_ids, 1)
                     candidates = ngrams[tuple()]
                     states[i] = 1
                 candidates = list(candidates)
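Note: get_ngram_set is defined elsewhere in seq2seq.py and does not appear in this diff. A minimal hypothetical sketch of what the two calls above assume, namely a map from every (n-1)-token prefix occurring in token_ids to the set of tokens that follow it, so that decoding can only continue with spans that actually appear in the source:

# Hypothetical reconstruction of get_ngram_set, not the repository's code:
# map each (n-1)-token prefix in token_ids to the set of next tokens.
def get_ngram_set(token_ids, n):
    ngrams = {}
    for i in range(len(token_ids) - n + 1):
        prefix = tuple(token_ids[i:i + n - 1])   # empty tuple when n == 1
        ngrams.setdefault(prefix, set()).add(token_ids[i + n - 1])
    return ngrams

Under that reading, with n == 1 every prefix is tuple(), so ngrams[tuple()] is the set of all tokens in token_ids, which is exactly the 1-gram fallback in the else branch above. The fix itself simply points both calls at token_ids rather than ides_temp.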
@@ -207,6 +207,8 @@ def __iter__(self, random=False):
             input_dict = self.tokenizer(source,target,max_length=self.Max_len,truncation=True,padding=True)
             len_ = len(input_dict['input_ids'])
             token_ids = random_masking(input_dict['input_ids'])
+            if self.tokenizer.vocab['[SEP]'] not in token_ids:
+                continue
             sep_index = token_ids.index(self.tokenizer.vocab['[SEP]']) + 1
             source_labels, target_labels = generate_copy_labels(token_ids[:sep_index],token_ids[sep_index:])
             labels = source_labels + target_labels[1:]
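The two added lines guard against random_masking masking out the [SEP] token itself, in which case token_ids.index(...) on the following line would raise ValueError and crash the generator mid-epoch. A small sketch of the failure mode; the random_masking body and the ids 101/102/103 for [CLS]/[SEP]/[MASK] (standard BERT vocab) are assumptions for illustration:

import random

def random_masking(token_ids, mask_id=103, mask_rate=0.15):
    # assumed behaviour; the repository's version may differ
    return [mask_id if random.random() < mask_rate else t for t in token_ids]

sep_id = 102
token_ids = random_masking([101, 7, 8, 102, 9, 10, 102])
if sep_id not in token_ids:   # the guard added in this commit
    print('skip sample: [SEP] was masked out')
else:
    sep_index = token_ids.index(sep_id) + 1
    print('split source/target at', sep_index)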
@@ -279,11 +281,11 @@ def main():
     config_path = os.path.join(pretrained_path,'config.json')
     config = BertConfig.from_json_file(config_path)
     config.model_type = 'NEZHA'
-    MAX_LEN = 1024
-    batch_size = 8
+    MAX_LEN = 820
+    batch_size = 1
     data = load_data('sfzy_seq2seq.json')
     fold = 0
-    num_folds = 15
+    num_folds = 100
     train_data = data_split(data, fold, num_folds, 'train')
     valid_data = data_split(data, fold, num_folds, 'valid')
     train_generator = data_generator(train_data,batch_size,MAX_LEN,tokenizer)
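data_split is also outside this diff. A plausible sketch, assuming a deterministic modulo split where sample i lands in validation iff i % num_folds == fold; under that assumption, raising num_folds from 15 to 100 shrinks the validation share from roughly 6.7% to 1% and leaves more samples for training:

# Hypothetical sketch of data_split, assuming a modulo-based fold split.
def data_split(data, fold, num_folds, mode):
    if mode == 'train':
        return [d for i, d in enumerate(data) if i % num_folds != fold]
    return [d for i, d in enumerate(data) if i % num_folds == fold]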
