# collator.py
import copy

import torch
from transformers import LlamaTokenizer

class Collator(object):
    """Training collator: tokenizes prompt + response pairs and builds labels."""

    def __init__(self, args, tokenizer):
        self.args = args
        self.only_train_response = args.only_train_response
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

    def __call__(self, batch):
        # Each sample stores the prompt text under "input_ids" and the full
        # prompt + response text under "labels"; an EOS token ends each target.
        input_texts = [d["input_ids"] for d in batch]
        full_texts = [d["labels"] + self.tokenizer.eos_token for d in batch]

        # Tokenize the full sequences as the model input and the bare prompts
        # as text_target, so inputs["labels"] holds the prompt token ids used
        # for masking below.
        inputs = self.tokenizer(
            text=full_texts,
            text_target=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )

        labels = copy.deepcopy(inputs["input_ids"])
        if self.only_train_response:
            # Exclude padding positions from the loss.
            labels[labels == self.tokenizer.pad_token_id] = -100
            # Exclude the prompt: every position covered by a (non-padding)
            # prompt token is masked so the loss covers the response only.
            labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100

        inputs["labels"] = labels
        return inputs
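
# A minimal usage sketch (illustrative, not from the original file): Collator
# plugs into a torch DataLoader as collate_fn. The dataset items are assumed
# to be dicts mapping "input_ids" to the prompt text and "labels" to the full
# prompt + response text, `args` only needs an only_train_response attribute
# here, and the checkpoint name is an assumption.
#
#   from types import SimpleNamespace
#   from torch.utils.data import DataLoader
#
#   args = SimpleNamespace(only_train_response=True)
#   tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
#   loader = DataLoader(train_dataset, batch_size=8,
#                       collate_fn=Collator(args, tokenizer))
#   batch = next(iter(loader))  # input_ids / attention_mask / labels tensors
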
class TestCollator(object):
    """Inference collator: tokenizes prompts only and passes raw targets through."""

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0
        if isinstance(self.tokenizer, LlamaTokenizer):
            # Left-pad so every prompt ends at the same position, which
            # allows batched generation with a decoder-only model.
            self.tokenizer.padding_side = "left"

    def __call__(self, batch):
        input_texts = [d["input_ids"] for d in batch]
        targets = [d["labels"] for d in batch]

        inputs = self.tokenizer(
            text=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )

        return (inputs, targets)
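
# A hedged end-to-end sketch: it runs only if the assumed Llama checkpoint is
# available; the checkpoint name, the toy samples, and the SimpleNamespace
# stand-in for args are illustrative assumptions, not part of the original repo.
if __name__ == "__main__":
    from types import SimpleNamespace

    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint
    samples = [
        {"input_ids": "What is 2 + 2? ", "labels": "What is 2 + 2? 4"},
        {"input_ids": "Name a color. ", "labels": "Name a color. Blue"},
    ]

    # Training batch: input_ids/attention_mask plus a masked labels tensor.
    train_batch = Collator(SimpleNamespace(only_train_response=True), tokenizer)(samples)
    print(train_batch["input_ids"].shape, train_batch["labels"].shape)

    # Inference batch: left-padded prompts plus the raw target strings.
    test_inputs, test_targets = TestCollator(SimpleNamespace(), tokenizer)(samples)
    print(test_inputs["input_ids"].shape, test_targets)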