Commit ba360c8
1. Update the project structure
2. Remove the MMI generation method
yangjianxin1 committed May 26, 2021
1 parent ded94e0 commit ba360c8
Showing 14 changed files with 21,855 additions and 640 deletions.
307 changes: 74 additions & 233 deletions README.md

Large diffs are not rendered by default.

10 changes: 0 additions & 10 deletions config/model_config_dialogue_small.json

This file was deleted.

100 changes: 100 additions & 0 deletions data_parallel.py
@@ -0,0 +1,100 @@
from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                # Dump the shapes involved before bailing out, to make
                # mismatched chunk sizes easier to debug.
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for _ in target_gpus]  # replicate non-tensor objects, one per GPU

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


class BalancedDataParallel(DataParallel):
    """DataParallel variant that assigns GPU 0, which also gathers the
    outputs, a smaller chunk (gpu0_bsz) of each batch."""

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        # print('len(inputs)1: ', str(len(inputs)))
        # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            # Hand the remainder, one sample at a time, to the non-first GPUs.
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        # print('bsz: ', bsz)
        # print('num_dev: ', num_dev)
        # print('gpu0_bsz: ', gpu0_bsz)
        # print('bsz_unit: ', bsz_unit)
        # print('chunk_sizes: ', chunk_sizes)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
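
Note (not part of the commit): BalancedDataParallel behaves like torch.nn.DataParallel except that GPU 0, which also hosts the gathered outputs, can be given a smaller slice of each batch. A minimal usage sketch, assuming three visible GPUs; the model and sizes are illustrative:

import torch
from torch import nn
from data_parallel import BalancedDataParallel

# Hypothetical toy model; any nn.Module works the same way.
model = nn.Linear(512, 512).cuda()

# Give GPU 0 a chunk of 4 samples and split the rest over the other GPUs.
# With batch size 20 on 3 GPUs: bsz_unit = (20 - 4) // 2 = 8,
# so chunk_sizes = [4, 8, 8].
model = BalancedDataParallel(4, model, dim=0)

x = torch.randn(20, 512).cuda()
y = model(x)  # scatter -> replicate -> parallel_apply -> gather
print(y.shape)  # torch.Size([20, 512])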
12 changes: 7 additions & 5 deletions dataset.py
@@ -7,13 +7,15 @@ class MyDataset(Dataset):
"""

def __init__(self, data_list):
self.data_list = data_list
def __init__(self, input_list, max_len):
self.input_list = input_list
self.max_len = max_len

def __getitem__(self, index):
input_ids = self.data_list[index].strip()
input_ids = [int(token_id) for token_id in input_ids.split()]
input_ids = self.input_list[index]
input_ids = input_ids[:self.max_len]
input_ids = torch.tensor(input_ids, dtype=torch.long)
return input_ids

def __len__(self):
return len(self.data_list)
return len(self.input_list)
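
Note (not part of the commit): after this change MyDataset expects pre-tokenized id lists and truncates each one to max_len. A minimal sketch of how it could feed a DataLoader; the collate_fn and the sample ids are illustrative assumptions:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from dataset import MyDataset

# Illustrative token-id lists; real ids come from the tokenizer.
input_list = [[101, 872, 1962, 102], [101, 518, 102]]
dataset = MyDataset(input_list, max_len=150)

def collate_fn(batch):
    # Pad every sequence in the batch to the longest one (pad id 0).
    return pad_sequence(batch, batch_first=True, padding_value=0)

loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
for input_ids in loader:
    print(input_ids.shape)  # (batch_size, longest_sequence_in_batch)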
6 changes: 3 additions & 3 deletions generate_dialogue_subset.py
@@ -13,7 +13,7 @@ def generate_subset():
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='raw training corpus')
-    parser.add_argument('--subset_size', default=500000, type=int, required=False, help='size of the dialogue subset to extract')
+    parser.add_argument('--subset_size', default=1000000, type=int, required=False, help='size of the dialogue subset to extract')
     parser.add_argument('--subset_data_path', default='data', type=str, required=False,
                         help='path for the subset file; this specifies the parent directory of the file')
     args = parser.parse_args()
@@ -23,7 +23,7 @@ def generate_subset():
     subset_size = min(len(dialogues), args.subset_size)

     with open(join(args.subset_data_path, "train_{}w.txt".format(int(subset_size / 10000))), "w", encoding="utf8") as f:
-        print("generating subset, please wait a few seconds")
+        print("generating subset, please wait a few minutes")
         for dialogue_index, dialogue in enumerate(dialogues):
             if dialogue_index >= subset_size:
                 break
@@ -64,4 +64,4 @@ def compute_dialogue_length():


 if __name__ == '__main__':
-    compute_dialogue_length()
+    generate_subset()
Binary file removed image/chitchat_demo.png
Binary file not shown.
62 changes: 34 additions & 28 deletions interact.py
@@ -10,42 +10,43 @@
 from tqdm import tqdm
 from torch.nn import DataParallel
 import logging
-from transformers.modeling_gpt2 import GPT2Config, GPT2LMHeadModel
-from transformers import BertTokenizer
+from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
+from transformers import BertTokenizerFast
+# from transformers import BertTokenizer
 from os.path import join, exists
 from itertools import zip_longest, chain
 # from chatbot.model import DialogueGPT2Model
 from dataset import MyDataset
 from torch.utils.data import Dataset, DataLoader
 from torch.nn import CrossEntropyLoss
 from sklearn.model_selection import train_test_split
-from train import create_model
+from train_origin import create_model
 import torch.nn.functional as F

 PAD = '[PAD]'
 pad_id = 0


-def set_interact_args():
+def set_args():
     """
-    Sets up the training arguments.
+    Sets up the arguments.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--device', default='0', type=str, required=False, help='device used for generation')
     parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
     parser.add_argument('--topk', default=8, type=int, required=False, help='keep only the k highest-probability tokens')
     parser.add_argument('--topp', default=0, type=float, required=False, help='cumulative-probability threshold (nucleus sampling)')
-    parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
-                        help='model configuration')
-    parser.add_argument('--log_path', default='data/interacting.log', type=str, required=False, help='where to store the interact log')
-    parser.add_argument('--voca_path', default='vocabulary/vocab_small.txt', type=str, required=False, help='vocabulary file')
-    parser.add_argument('--dialogue_model_path', default='dialogue_model_path/', type=str, required=False, help='dialogue model path')
+    # parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
+    #                     help='model configuration')
+    parser.add_argument('--log_path', default='data/interact.log', type=str, required=False, help='where to store the interact log')
+    parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, help='vocabulary file')
+    parser.add_argument('--model_path', default='model/epoch40', type=str, required=False, help='dialogue model path')
     parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="directory in which chat logs are saved")
     parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                         help="repetition penalty; raise it if generated replies repeat themselves")
-    parser.add_argument('--seed', type=int, default=None, help='random seed, to make results reproducible')
+    # parser.add_argument('--seed', type=int, default=None, help='random seed, to make results reproducible')
     parser.add_argument('--max_len', type=int, default=25, help='maximum length of each utterance; longer ones are truncated')
-    parser.add_argument('--max_history_len', type=int, default=5, help="maximum number of utterances kept as dialogue history")
+    parser.add_argument('--max_history_len', type=int, default=3, help="maximum number of utterances kept as dialogue history")
     parser.add_argument('--no_cuda', action='store_true', help='do not use the GPU for prediction')
     return parser.parse_args()

@@ -79,7 +80,7 @@ def create_logger(args):
 def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
-            logits: logits distribution shape (vocabulary size)
+            logits: logits distribution shape (vocab size)
             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
@@ -110,45 +111,50 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):


 def main():
-    args = set_interact_args()
+    args = set_args()
     logger = create_logger(args)
     # use the GPU if the user asks for it and one is available
     args.cuda = torch.cuda.is_available() and not args.no_cuda
     device = 'cuda' if args.cuda else 'cpu'
     logger.info('using device:{}'.format(device))
     os.environ["CUDA_VISIBLE_DEVICES"] = args.device
-    tokenizer = BertTokenizer(vocab_file=args.voca_path)
-    model = GPT2LMHeadModel.from_pretrained(args.dialogue_model_path)
-    model.to(device)
+    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
+    # tokenizer = BertTokenizer(vocab_file=args.voca_path)
+    model = GPT2LMHeadModel.from_pretrained(args.model_path)
+    model = model.to(device)
     model.eval()
     if args.save_samples_path:
         if not os.path.exists(args.save_samples_path):
             os.makedirs(args.save_samples_path)
         samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
         samples_file.write("chat log {}:\n".format(datetime.now()))
     # chat history; each utterance is stored as a list of token ids
     history = []
     print('Start chatting with the chatbot (press CTRL + Z to exit)')

     while True:
         try:
             text = input("user:")
+            # text = "你好"
             if args.save_samples_path:
                 samples_file.write("user:{}\n".format(text))
-            history.append(tokenizer.encode(text))
+            text_ids = tokenizer.encode(text, add_special_tokens=False)
+            history.append(text_ids)
             input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]

             for history_id, history_utr in enumerate(history[-args.max_history_len:]):
                 input_ids.extend(history_utr)
                 input_ids.append(tokenizer.sep_token_id)
-            curr_input_tensor = torch.tensor(input_ids).long().to(device)
-            generated = []
+            input_ids = torch.tensor(input_ids).long().to(device)
+            input_ids = input_ids.unsqueeze(0)
+            response = []  # the response generated from the context
             # generate at most max_len tokens
             for _ in range(args.max_len):
-                outputs = model(input_ids=curr_input_tensor)
-                next_token_logits = outputs[0][-1, :]
+                outputs = model(input_ids=input_ids)
+                logits = outputs.logits
+                next_token_logits = logits[0, -1, :]
                 # apply a repetition penalty to every token already generated, lowering its probability
-                for id in set(generated):
+                for id in set(response):
                     next_token_logits[id] /= args.repetition_penalty
                 next_token_logits = next_token_logits / args.temperature
                 # set the logit of [UNK] to -inf, so the model can never predict the [UNK] token
@@ -158,12 +164,12 @@ def main():
                 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                 if next_token == tokenizer.sep_token_id:  # a [SEP] token marks the end of the response
                     break
-                generated.append(next_token.item())
-                curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=0)
+                response.append(next_token.item())
+                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
                 # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
                 # print("his_text:{}".format(his_text))
-            history.append(generated)
-            text = tokenizer.convert_ids_to_tokens(generated)
+            history.append(response)
+            text = tokenizer.convert_ids_to_tokens(response)
             print("chatbot:" + "".join(text))
             if args.save_samples_path:
                 samples_file.write("chatbot:{}\n".format("".join(text)))
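
Note (not part of the commit): the body of top_k_top_p_filtering is truncated above. A self-contained sketch of the technique it implements, following the widely circulated Hugging Face example that this repository adapts; the toy logits are illustrative:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # Top-k: mask out everything below the k-th largest logit.
    if top_k > 0:
        threshold = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < threshold] = filter_value
    # Top-p (nucleus): keep the smallest set of tokens whose
    # cumulative probability exceeds top_p.
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Toy example over a 6-token vocabulary: keep the top 3 logits, then sample.
logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0, -2.0])
filtered = top_k_top_p_filtering(logits.clone(), top_k=3)
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)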
6 changes: 3 additions & 3 deletions interact_mmi.py
@@ -18,7 +18,7 @@
 from torch.utils.data import Dataset, DataLoader
 from torch.nn import CrossEntropyLoss
 from sklearn.model_selection import train_test_split
-from train import create_model
+from train_origin import create_model
 import torch.nn.functional as F
 import copy
@@ -39,7 +39,7 @@ def set_interact_args():
                         help='model configuration')
     parser.add_argument('--log_path', default='data/interacting_mmi.log', type=str, required=False,
                         help='where to store the interact_mmi log')
-    parser.add_argument('--voca_path', default='vocabulary/vocab_small.txt', type=str, required=False, help='vocabulary file')
+    parser.add_argument('--voca_path', default='vocab/vocab_small.txt', type=str, required=False, help='vocabulary file')
     parser.add_argument('--dialogue_model_path', default='dialogue_model/', type=str, required=False,
                         help='dialogue_model path')
     parser.add_argument('--mmi_model_path', default='mmi_model/', type=str, required=False,
@@ -85,7 +85,7 @@ def create_logger(args):
 def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
-            logits: logits distribution shape (vocabulary size)
+            logits: logits distribution shape (vocab size)
             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
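
Note (not part of the commit): conceptually, interact_mmi.py has the dialogue model propose several candidate responses and uses the MMI model, trained on reversed dialogues, to rerank them by how well each candidate predicts the history. A sketch of that reranking step; the function, its arguments, and the concatenation order are illustrative assumptions, not the committed code:

import torch

def mmi_rerank(candidates, history_ids, mmi_model, sep_token_id, device):
    # Pick the candidate whose reversed sequence (response first, then
    # history) gets the lowest language-modeling loss under the MMI model,
    # i.e. the response that best "explains" the context.
    best, best_loss = None, float('inf')
    for cand in candidates:
        input_ids = cand + [sep_token_id] + history_ids
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            loss = mmi_model(input_ids=input_ids, labels=input_ids).loss
        if loss.item() < best_loss:
            best, best_loss = cand, loss.item()
    return best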
