-
Notifications
You must be signed in to change notification settings - Fork 14
/
train.py
97 lines (91 loc) · 4.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from encoder.bert_encoder import BERTEncoder
from models.rifre_sentence import RIFRE_SEN
from models.rifre_triple import RIFRE_TR
from framework.sentence_re import Sentence_RE
from framework.triple_re import Triple_RE
from configs import Config
from utils import count_params
import numpy as np
import torch
import random, argparse
# Select CUDA device 0 as the current device. This runs at module import time,
# before any argument parsing, and the device index is hard-coded.
# NOTE(review): presumably this raises on a host without a usable GPU -- confirm
# before running this script in a CPU-only environment.
torch.cuda.set_device(0)
def seed_torch(seed=2020):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to each generator's seeding function.
            Defaults to 2020.
    """
    # Seeding order: stdlib ``random``, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty value -- including ``"False"`` -- turns into ``True``.
    This converter maps the usual spellings explicitly instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive). The empty string maps to ``False``, which is
            what ``bool('')`` yields.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` makes ``argparse``
            read ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``train`` (bool, default ``True``) and
        ``dataset`` (str, default ``'webnlg'``).
    """
    parser = argparse.ArgumentParser(description='Model Controller')
    parser.add_argument('--train', default=True, type=str2bool)
    parser.add_argument('--dataset', default='webnlg', type=str,
                        help='specify the dataset from ["nyt","webnlg","semeval"]')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script flow: parse options, seed the RNGs when the config provides a
    # seed, build the encoder/model/framework for the chosen dataset,
    # optionally train, then load the dataset's checkpoint and run the test
    # evaluation.
    args = parse_args()
    dataset = args.dataset
    is_train = args.train
    config = Config()
    if config.seed is not None:
        print(config.seed)
        seed_torch(config.seed)

    if dataset == 'semeval':
        print('train--' + dataset)
        config.class_nums = config.semeval_class
        sentence_encoder = BERTEncoder(pretrain_path=config.bert_base)
        model = RIFRE_SEN(sentence_encoder, config)
        count_params(model)
        framework = Sentence_RE(model,
                                train_path=config.semeval_train,
                                val_path=config.semeval_val,
                                test_path=config.semeval_test,
                                rel2id=config.semeval_rel2id,
                                pretrain_path=config.bert_base,
                                ckpt=config.semeval_ckpt,
                                batch_size=config.batch_size,
                                max_epoch=config.epoch,
                                lr=config.lr)
        # With --train false, skip training and evaluate the checkpoint that
        # is already stored at config.semeval_ckpt.
        if is_train:
            framework.train_semeval_model()
        framework.load_state_dict(config.semeval_ckpt)
        print('test:')
        framework.eval_semeval(framework.test_loader)
    elif dataset == 'webnlg':
        print('train--' + dataset + config.webnlg_ckpt)
        config.class_nums = config.webnlg_class
        sentence_encoder = BERTEncoder(pretrain_path=config.bert_base_case)
        model = RIFRE_TR(sentence_encoder, config)
        count_params(model)
        framework = Triple_RE(model,
                              train=config.webnlg_train,
                              val=config.webnlg_val,
                              test=config.webnlg_test,
                              rel2id=config.webnlg_rel2id,
                              pretrain_path=config.bert_base_case,
                              ckpt=config.webnlg_ckpt,
                              batch_size=config.batch_size,
                              max_epoch=config.epoch,
                              lr=config.lr,
                              num_workers=4)
        if is_train:
            framework.train_model()
        framework.load_state_dict(config.webnlg_ckpt)
        print('test:' + config.webnlg_ckpt)
        framework.test_set.metric(framework.model)
    elif dataset == 'nyt':
        print('train--' + dataset)
        config.class_nums = config.nyt_class
        sentence_encoder = BERTEncoder(pretrain_path=config.bert_base_case)
        model = RIFRE_TR(sentence_encoder, config)
        count_params(model)
        framework = Triple_RE(model,
                              train=config.nyt_train,
                              val=config.nyt_val,
                              test=config.nyt_test,
                              rel2id=config.nyt_rel2id,
                              pretrain_path=config.bert_base_case,
                              ckpt=config.nyt_ckpt,
                              batch_size=config.batch_size,
                              max_epoch=config.epoch,
                              lr=config.lr)
        # Only the NYT branch passes an output path to the test metric.
        output_path = 'save_result/nyt_result.json'
        if is_train:
            framework.train_model()
        framework.load_state_dict(config.nyt_ckpt)
        print('test:' + config.nyt_ckpt)
        framework.test_set.metric(framework.model, output_path=output_path)
    else:
        print('unknown dataset')