forked from princeton-nlp/SimCSE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
584 lines (521 loc) · 23.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional, Union, List, Dict, Tuple
import torch
import collections
import random
from datasets import load_dataset
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorForLanguageModeling,
DataCollatorWithPadding,
HfArgumentParser,
Trainer,
TrainingArguments,
default_data_collator,
set_seed,
EvalPrediction,
BertModel,
BertForPreTraining,
RobertaModel
)
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase
from transformers.trainer_utils import is_main_process
from transformers.data.data_collator import DataCollatorForLanguageModeling
from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available
from simcse.models import RobertaForCL, BertForCL
from simcse.trainers import CLTrainer
# Logger shared by everything in this training script.
logger = logging.getLogger(__name__)

# Config classes registered in the masked-LM auto-mapping, and the matching
# ``model_type`` identifiers (listed in the ``--model_type`` help text).
MODEL_CONFIG_CLASSES = [*MODEL_FOR_MASKED_LM_MAPPING.keys()]
MODEL_TYPES = tuple([config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES])
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Besides the standard Hugging Face model arguments, this carries the SimCSE-specific
    options (temperature, pooler, hard-negative weight, MLM auxiliary loss).
    """

    # Huggingface's original arguments
    # NOTE(review): main() in this script raises NotImplementedError when this is unset,
    # so in practice it is required even though the help text mentions training from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    # SimCSE's arguments
    # Softmax temperature applied to the similarity logits of the contrastive loss.
    temp: float = field(
        default=0.05,
        metadata={
            "help": "Temperature for softmax."
        }
    )
    # Which sentence representation to pool from the encoder output.
    pooler_type: str = field(
        default="cls",
        metadata={
            "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)."
        }
    )
    # Float default so the field's value always matches its annotation.
    hard_negative_weight: float = field(
        default=0.0,
        metadata={
            "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)."
        }
    )
    # Enables the auxiliary masked-language-modelling loss.
    do_mlm: bool = field(
        default=False,
        metadata={
            "help": "Whether to use MLM auxiliary objective."
        }
    )
    mlm_weight: float = field(
        default=0.1,
        metadata={
            "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)."
        }
    )
    mlp_only_train: bool = field(
        default=False,
        metadata={
            "help": "Use MLP only during training"
        }
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # Huggingface's original arguments.
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # SimCSE's arguments
    # csv, json and txt are all accepted (see __post_init__).
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The training data file (.txt, .csv or .json)."}
    )
    max_seq_length: Optional[int] = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    mlm_probability: float = field(
        default=0.15,
        metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"}
    )

    def __post_init__(self):
        """Check that a training data source was given and that its file type is supported.

        Raises:
            ValueError: If neither ``dataset_name`` nor ``train_file`` is set, or if
                ``train_file`` does not end in ``.csv``, ``.json`` or ``.txt``.
        """
        # This class has no ``validation_file`` field, so only the dataset name and the
        # training file are considered (reading a missing attribute raised AttributeError).
        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Raise explicitly rather than assert so the check survives ``python -O``.
            if extension not in ("csv", "json", "txt"):
                raise ValueError("`train_file` should be a csv, a json or a txt file.")
@dataclass
class OurTrainingArguments(TrainingArguments):
    """``TrainingArguments`` plus SimCSE evaluation options and a custom device setup."""

    # Evaluation
    ## By default, we evaluate STS (dev) during training (for selecting best checkpoints) and evaluate
    ## both STS and transfer tasks (dev) at the end of training. Using --eval_transfer will allow evaluating
    ## both STS and transfer tasks (dev) during training.
    eval_transfer: bool = field(
        default=False,
        metadata={"help": "Evaluate transfer task dev sets (in validation)."}
    )

    @cached_property
    @torch_required
    def _setup_devices(self) -> "torch.device":
        """Pick the torch device for this process and record the GPU count in ``self._n_gpu``.

        Order of precedence: ``--no_cuda`` -> CPU; a TPU if available; single-process
        (``local_rank == -1``) -> ``cuda:0`` or CPU; otherwise distributed
        (deepspeed or NCCL) with one GPU per process.

        Raises:
            ImportError: If ``--deepspeed`` is requested but deepspeed is not installed.
        """
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_torch_tpu_available():
            # torch_xla only exists on TPU hosts, so import it lazily; ``xm`` was
            # previously referenced without ever being imported (NameError).
            import torch_xla.core.xla_model as xm

            device = xm.xla_device()
            self._n_gpu = 0
        elif self.local_rank == -1:
            # if n_gpu is > 1 we'll use nn.DataParallel.
            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
            # trigger an error that a device index is missing. Index 0 takes into account the
            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
            # will use the first GPU in that env, i.e. GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
            # the default value.
            self._n_gpu = torch.cuda.device_count()
        else:
            # Here, we'll use torch.distributed.
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            #
            # deepspeed performs its own DDP internally, and requires the program to be started with:
            # deepspeed ./program.py
            # rather than:
            # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
            if self.deepspeed:
                # This file runs as a standalone script, so the package-relative
                # ``from .integrations import ...`` could never resolve; import it
                # from transformers instead.
                from transformers.integrations import is_deepspeed_available

                if not is_deepspeed_available():
                    raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
                import deepspeed

                deepspeed.init_distributed()
            else:
                torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device
def main():
    """Parse arguments, build the SimCSE model and data pipeline, then train and/or evaluate.

    Arguments come from a single ``.json`` file when that is the only CLI argument,
    otherwise from the command line, parsed into ``ModelArguments``,
    ``DataTrainingArguments`` and ``OurTrainingArguments``.

    Returns:
        dict: Results of ``trainer.evaluate`` when ``--do_eval`` is set, otherwise ``{}``.

    Raises:
        ValueError: If the output directory is non-empty and ``--overwrite_output_dir``
            was not given, or if no tokenizer name/path is available.
        NotImplementedError: If ``--model_name_or_path`` is missing or names neither a
            RoBERTa nor a BERT checkpoint, or if the training data does not have 1, 2
            or 3 columns.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
    # behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        if extension == "csv":
            datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/", delimiter="\t" if "tsv" in data_args.train_file else ",")
        else:
            datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/")
    else:
        # No local file: DataTrainingArguments only allows this when --dataset_name is set.
        # Previously --dataset_name was ignored and `datasets` stayed unbound (NameError below).
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir="./data/")

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Only the contrastive BERT/RoBERTa wrappers are supported. Training from scratch was
    # never reachable (it followed an unconditional raise), so that dead code is gone.
    if not model_args.model_name_or_path:
        raise NotImplementedError("Training from scratch is not supported; pass --model_name_or_path.")
    if 'roberta' in model_args.model_name_or_path:
        model = RobertaForCL.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            model_args=model_args
        )
    elif 'bert' in model_args.model_name_or_path:
        model = BertForCL.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            model_args=model_args
        )
        if model_args.do_mlm:
            # Initialise the MLM head from the pretrained BERT prediction head.
            pretrained_model = BertForPreTraining.from_pretrained(model_args.model_name_or_path)
            model.lm_head.load_state_dict(pretrained_model.cls.predictions.state_dict())
    else:
        raise NotImplementedError

    model.resize_token_embeddings(len(tokenizer))

    # Prepare features
    column_names = datasets["train"].column_names
    sent2_cname = None
    if len(column_names) == 2:
        # Pair datasets
        sent0_cname = column_names[0]
        sent1_cname = column_names[1]
    elif len(column_names) == 3:
        # Pair datasets with hard negatives
        sent0_cname = column_names[0]
        sent1_cname = column_names[1]
        sent2_cname = column_names[2]
    elif len(column_names) == 1:
        # Unsupervised datasets
        sent0_cname = column_names[0]
        sent1_cname = column_names[0]
    else:
        raise NotImplementedError

    def prepare_features(examples):
        """Tokenize a batch of rows into per-example lists of 2 or 3 sentence encodings.

        For every tokenizer output key, ``features[key][i]`` is
        ``[sent0_i, sent1_i]`` (or ``[sent0_i, sent1_i, sent2_i]`` with hard negatives).
        ``None`` cells are replaced in place with a single space before tokenizing.
        """
        # padding = longest (default)
        #   If no sentence in the batch exceed the max length, then use
        #   the max sentence length in the batch, otherwise use the
        #   max sentence length in the argument and truncate those that
        #   exceed the max length.
        # padding = max_length (when pad_to_max_length, for pressure test)
        #   All sentences are padded/truncated to data_args.max_seq_length.
        total = len(examples[sent0_cname])

        # Avoid "None" fields
        for idx in range(total):
            if examples[sent0_cname][idx] is None:
                examples[sent0_cname][idx] = " "
            if examples[sent1_cname][idx] is None:
                examples[sent1_cname][idx] = " "

        sentences = examples[sent0_cname] + examples[sent1_cname]

        # If hard negative exists
        if sent2_cname is not None:
            for idx in range(total):
                if examples[sent2_cname][idx] is None:
                    examples[sent2_cname][idx] = " "
            sentences += examples[sent2_cname]

        sent_features = tokenizer(
            sentences,
            max_length=data_args.max_seq_length,
            truncation=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )

        # Sentences were tokenized as one flat list [all sent0 | all sent1 | all sent2];
        # regroup them per example by offsetting with `total`.
        features = {}
        if sent2_cname is not None:
            for key in sent_features:
                features[key] = [[sent_features[key][i], sent_features[key][i+total], sent_features[key][i+total*2]] for i in range(total)]
        else:
            for key in sent_features:
                features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]

        return features

    if training_args.do_train:
        train_dataset = datasets["train"].map(
            prepare_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Data collator
    @dataclass
    class OurDataCollatorWithPadding:
        """Pad multi-sentence features into ``(batch, num_sent, seq_len)`` tensors, optionally adding MLM inputs."""

        tokenizer: PreTrainedTokenizerBase
        padding: Union[bool, str, PaddingStrategy] = True
        max_length: Optional[int] = None
        pad_to_multiple_of: Optional[int] = None
        mlm: bool = True
        mlm_probability: float = data_args.mlm_probability

        def __call__(self, features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            """Flatten, pad and re-group a batch; returns ``None`` for an empty batch."""
            special_keys = ['input_ids', 'attention_mask', 'token_type_ids', 'mlm_input_ids', 'mlm_labels']
            bs = len(features)
            if bs > 0:
                num_sent = len(features[0]['input_ids'])
            else:
                return
            # One flat entry per sentence so tokenizer.pad can handle them uniformly.
            flat_features = []
            for feature in features:
                for i in range(num_sent):
                    flat_features.append({k: feature[k][i] if k in special_keys else feature[k] for k in feature})

            batch = self.tokenizer.pad(
                flat_features,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )
            if model_args.do_mlm:
                batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"])

            # Restore the per-example sentence axis; non-sentence keys keep only the first copy.
            batch = {k: batch[k].view(bs, num_sent, -1) if k in special_keys else batch[k].view(bs, num_sent, -1)[:, 0] for k in batch}

            if "label" in batch:
                batch["labels"] = batch["label"]
                del batch["label"]
            if "label_ids" in batch:
                batch["labels"] = batch["label_ids"]
                del batch["label_ids"]

            return batch

        def mask_tokens(
            self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

            Operates on a copy, so the caller's ``inputs`` tensor is left unmodified.
            """
            # Clone first: the masking below writes in place. Without this copy,
            # batch["input_ids"] (used by the contrastive loss) was corrupted and
            # batch["mlm_input_ids"] was the very same tensor.
            inputs = inputs.clone()
            labels = inputs.clone()
            # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
            probability_matrix = torch.full(labels.shape, self.mlm_probability)
            if special_tokens_mask is None:
                special_tokens_mask = [
                    self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
                ]
                special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
            else:
                special_tokens_mask = special_tokens_mask.bool()

            probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
            masked_indices = torch.bernoulli(probability_matrix).bool()
            labels[~masked_indices] = -100  # We only compute loss on masked tokens

            # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
            indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
            inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

            # 10% of the time, we replace masked input tokens with random word
            # (half of the remaining 20%).
            indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
            random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
            inputs[indices_random] = random_words[indices_random]

            # The rest of the time (10% of the time) we keep the masked input tokens unchanged
            return inputs, labels

    # default_data_collator only stacks already-padded features and never builds
    # mlm_input_ids/mlm_labels, so use our collator whenever the MLM objective is on
    # (padding already-max-length features is a no-op).
    if data_args.pad_to_max_length and not model_args.do_mlm:
        data_collator = default_data_collator
    else:
        data_collator = OurDataCollatorWithPadding(tokenizer)

    trainer = CLTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.model_args = model_args

    # Training
    if training_args.do_train:
        # Resume from model_name_or_path only when it is a local checkpoint directory.
        model_path = (
            model_args.model_name_or_path
            if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
            else None
        )
        train_result = trainer.train(model_path=model_path)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(train_result.metrics.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        results = trainer.evaluate(eval_senteval_transfer=True)

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in sorted(results.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

    return results
def _mp_fn(index):
    """Per-process entry point used when launching on TPUs via ``xla_spawn``.

    ``index`` is passed in by the launcher and is ignored here; each process just
    runs :func:`main` and its return value is discarded.
    """
    del index  # unused
    main()
if __name__ == "__main__":
    # Run training/evaluation only when executed as a script; a launcher that
    # imports this module is expected to call ``_mp_fn`` instead.
    main()