add enable_sft_conversations_dataset_v3
Signed-off-by: ftgreat <[email protected]>
ftgreat committed May 30, 2023
1 parent 257dfff commit e767882
Showing 2 changed files with 165 additions and 1 deletion.
164 changes: 163 additions & 1 deletion examples/gpt3_pretrain/llama/train_llama_bmtrain_datasets.py
@@ -174,7 +174,169 @@ def forward_step(data, model, mems=None):
print('*'*20, "model", model, flush=True)

## conversations_dataset
if env_args.enable_sft_conversations_dataset_v3:
    assert env_args.enable_sft_dataset_dir is not None and \
        env_args.enable_sft_dataset_file is not None

    cur_dir = env_args.enable_sft_dataset_dir
    jsonl_data = os.path.join(cur_dir, env_args.enable_sft_dataset_file)
    max_seq_len = 2048

    import jsonlines
    import numpy as np

    def read_file():
        conversations = []
        with jsonlines.open(jsonl_data) as reader:
            for line in reader:
                conversations.append(line)
        return conversations
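
    # Expected JSONL schema, one object per line (the field names are the ones
    # read by the code below; the example values are hypothetical):
    #   {"instruction": "Answer concisely.",
    #    "conversations": [{"from": "human", "value": "Hi"},
    #                      {"from": "gpt", "value": "Hello!"}]}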

    from examples.gpt3_pretrain.llama import ym_conversation as conversation_lib

    # Add speaker and start/end signal on each round.
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    unknown_role = "unknown"  # use default unknown role
    roles = {
        "human": conversation_lib.default_conversation.roles[0],  # human role
        "gpt": conversation_lib.default_conversation.roles[1],    # gpt role
    }

    def _add_speaker_and_signal(header, source, get_conversation=True):
        conversation = header

        if "instruction" in source and source["instruction"] is not None and len(source["instruction"]) > 0:
            source["instruction"] = (
                BEGIN_SIGNAL
                + conversation_lib.default_conversation.roles[2]
                + ": "
                + source["instruction"]
                + END_SIGNAL
            )
            if get_conversation:
                conversation += source["instruction"]
        for sentence in source["conversations"]:
            sentence_from = sentence["from"].lower()
            sentence["value"] = (
                BEGIN_SIGNAL
                + roles.get(sentence_from, unknown_role)
                + ": "
                + sentence["value"]
                + END_SIGNAL
            )
            if get_conversation:
                conversation += sentence["value"]
        return conversation
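
    # A sketch of the rendered text for one sample, assuming the default roles
    # are named along the lines of "Human" and "Assistant" (the exact names come
    # from ym_conversation, not from this diff):
    #   {system prompt}\n\n### {roles[2]}: {instruction}\n### Human: hi\n### Assistant: hello\n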

    class ConversationDatasetV3(Dataset):
        def __init__(self, conversations, tokenizer, maxlen=512):
            super(ConversationDatasetV3, self).__init__()
            self.conversations = conversations
            self.tokenizer = tokenizer
            self.maxlen = maxlen

        def __getitem__(self, i):
            header = f"{conversation_lib.default_conversation.system}\n\n"
            source = self.conversations[i]
            # mutates source in place, adding role prefixes to every turn
            _add_speaker_and_signal(header, source)

            source["chat_desc"] = header
            chat_desc = source['chat_desc']
            instruction = source['instruction']
            conversations = source['conversations']

            # chat_desc
            example = self.tokenizer.encode_plus(f"{chat_desc}", None, max_length=None)['input_ids']
            EOS_TOKEN = example[-1]
            example = example[:-1]  # remove eos
            # instruction
            instruction = self.tokenizer.encode_plus(f"{instruction}", None, max_length=None)['input_ids']
            instruction = instruction[1:-1]  # remove bos & eos
            example += instruction
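
            # At this point `example` is [BOS] + header tokens + instruction
            # tokens; every later piece is stripped of its bos/eos so the
            # sequence keeps a single leading BOS, with one EOS appended at the
            # very end.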

            import copy
            labels = copy.deepcopy(example)

            for conversation in conversations:
                role = conversation['from']
                content = conversation['value']

                if role == 'gpt':
                    # role == 'gpt' here; rebuild the prefix that
                    # _add_speaker_and_signal prepended to the content
                    prefix_gpt = BEGIN_SIGNAL + roles.get(role, unknown_role) + ": "
                    content_gpt = content[len(prefix_gpt):]

                    prefix_gpt = self.tokenizer.encode_plus(f"{prefix_gpt}", None, max_length=None)['input_ids']
                    prefix_gpt = prefix_gpt[1:-1]  # remove bos & eos
                    example += prefix_gpt
                    role_labels = [env_args.IGNORE_INDEX] * len(prefix_gpt)

                    content_gpt = self.tokenizer.encode_plus(f"{content_gpt}", None, max_length=None)['input_ids']
                    content_gpt = content_gpt[1:-1]  # remove bos & eos
                    example += content_gpt
                    # keep the masked prefix labels and supervise the content tokens
                    role_labels += copy.deepcopy(content_gpt)
                else:
                    content = self.tokenizer.encode_plus(f"{content}", None, max_length=None)['input_ids']
                    content = content[1:-1]  # remove bos & eos
                    example += content
                    # masking
                    role_labels = [env_args.IGNORE_INDEX] * len(content)
                labels += role_labels

            example.append(EOS_TOKEN)
            labels.append(EOS_TOKEN)
            assert len(example) == len(labels)
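
            # Net effect of the loop: only assistant ("gpt") content carries
            # real labels; the header, instruction, role prefixes and human
            # turns are filled with IGNORE_INDEX so they are skipped by the loss.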

            ## maxlen
            example = example[:self.maxlen]
            labels = labels[:self.maxlen]

            output = {
                "input_ids": example,
                "labels": labels,
            }
            return output

        def __len__(self):
            return len(self.conversations)

        @staticmethod
        def collate_fn(batch):
            def padding(indice, max_length, pad_idx=0):
                pad_indice = [
                    item + [pad_idx] * max(0, max_length - len(item)) for item in indice
                ]
                return torch.tensor(pad_indice)

            input_ids = [data["input_ids"] for data in batch]
            labels = [data["labels"] for data in batch]
            max_length = max_seq_len
            input_ids = padding(input_ids, max_length)[:, :max_length]
            labels = padding(labels, max_length, pad_idx=env_args.IGNORE_INDEX)[:, :max_length]

            data = {
                "input_ids": input_ids,
                "labels": labels
            }
            return data
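
            # input_ids are padded with 0 (assumed here to be the tokenizer's
            # pad id) while labels are padded with IGNORE_INDEX, so padded
            # positions never contribute to the loss; every batch is padded to
            # the fixed max_seq_len of 2048.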

    conversations = read_file()
    data_len = len(conversations)
    #train_size = int(data_len * 0.95)
    train_size = data_len
    train_conversations = conversations[:train_size]

    train_dataset = ConversationDatasetV3(train_conversations,
                                          tokenizer=tokenizer,
                                          maxlen=max_seq_len)
    #print(f"train_dataset \n {train_dataset[0]}")

    trainer.do_train(
        train_dataset=train_dataset,
        valid_dataset=None,
        collate_fn=ConversationDatasetV3.collate_fn,
        optimizer=None,
        rank_split=False)

elif env_args.enable_sft_conversations_dataset_v2:
    assert env_args.enable_sft_dataset_dir is not None and \
        env_args.enable_sft_dataset_file is not None

2 changes: 2 additions & 0 deletions flagai/env_args.py
@@ -87,6 +87,7 @@ def __init__(self,
                 enable_sft_dataset_jsonl=False,
                 enable_sft_conversations_dataset=False,
                 enable_sft_conversations_dataset_v2=False,
                 enable_sft_conversations_dataset_v3=False,
                 enable_weighted_dataset_v2=False,

                 enable_flash_attn_models=False,
@@ -168,6 +169,7 @@ def __init__(self,
        self.parser.add_argument('--enable_sft_dataset_jsonl', default=enable_sft_dataset_jsonl, type=str2bool, help='debug args')
        self.parser.add_argument('--enable_sft_conversations_dataset', default=enable_sft_conversations_dataset, type=str2bool, help='debug args')
        self.parser.add_argument('--enable_sft_conversations_dataset_v2', default=enable_sft_conversations_dataset_v2, type=str2bool, help='debug args')
        self.parser.add_argument('--enable_sft_conversations_dataset_v3', default=enable_sft_conversations_dataset_v3, type=str2bool, help='debug args')
        self.parser.add_argument('--enable_weighted_dataset_v2', default=enable_weighted_dataset_v2, type=str2bool, help='debug args')
        self.parser.add_argument('--IGNORE_INDEX', default=-100, type=int, help='label index ignored by the loss')
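
        # Example invocation enabling the new path (the launch command and data
        # paths are placeholders, not part of this diff):
        #   python train_llama_bmtrain_datasets.py --enable_sft_conversations_dataset_v3 True \
        #       --enable_sft_dataset_dir ./data --enable_sft_dataset_file conversations.jsonl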

