Skip to content

Commit

Permalink
修改了部分代码,填加了一个例子
Browse files Browse the repository at this point in the history
  • Loading branch information
920232796 committed Nov 23, 2020
1 parent 7440a17 commit 4c97410
Show file tree
Hide file tree
Showing 10 changed files with 634 additions and 335 deletions.
Binary file modified .DS_Store
Binary file not shown.
57 changes: 36 additions & 21 deletions bert_seq2seq/bert_relation_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ def extrac_subject(self, output, subject_ids):
## 抽取subject的向量表征
batch_size = output.shape[0]
hidden_size = output.shape[-1]
start = torch.gather(output, index=subject_ids[:, :1].unsqueeze(1).expand((batch_size, 1, hidden_size)), dim=1)
end = torch.gather(output, index=subject_ids[:, 1: ].unsqueeze(1).expand((batch_size, 1, hidden_size)), dim=1)
subject = torch.cat((start, end), dim=-1)
return subject[:, 0]
# print("out shape is " + str(output.shape))
start_end = torch.gather(output, index=subject_ids.unsqueeze(-1).expand((batch_size, 2, hidden_size)), dim=1)
# print("start_end shape is " + str(start_end.shape))
# start = torch.gather(output, index=subject_ids[:, :1].unsqueeze(1).expand((batch_size, 1, hidden_size)), dim=1)
# end = torch.gather(output, index=subject_ids[:, 1: ].unsqueeze(1).expand((batch_size, 1, hidden_size)), dim=1)
subject = torch.cat((start_end[:, 0], start_end[:, 1]), dim=-1)
# print("subject shape is " + str(subject.shape))
return subject

def forward(self, text, subject_ids, position_enc=None, subject_labels=None, object_labels=None, use_layer_num=-1, device="cpu"):
if use_layer_num != -1:
Expand All @@ -71,15 +75,23 @@ def forward(self, text, subject_ids, position_enc=None, subject_labels=None, obj
enc_layers, _ = self.bert(text,
output_all_encoded_layers=True)
squence_out = enc_layers[use_layer_num]
sub_out = enc_layers[-2]
# print(squence_out.shape)

transform_out = self.layer_norm(squence_out)
subject_pred_out = self.subject_pred(transform_out)
# transform_out = self.layer_norm(squence_out)
subject_pred_out = self.subject_pred(squence_out)

subject_pred_act = self.activation(subject_pred_out)
subject_vec = self.extrac_subject(squence_out, subject_ids)
object_layer_norm = self.layer_norm_cond([squence_out, subject_vec])

subject_pred_act = subject_pred_act**2

subject_vec = self.extrac_subject(sub_out, subject_ids)
object_layer_norm = self.layer_norm_cond([sub_out, subject_vec])
object_pred_out = self.object_pred(object_layer_norm)
object_pred_act = self.activation(object_pred_out)

object_pred_act = object_pred_act**4

batch_size, seq_len, target_size = object_pred_act.shape

object_pred_act = object_pred_act.reshape((batch_size, seq_len, int(target_size/2), 2))
Expand All @@ -94,7 +106,7 @@ def forward(self, text, subject_ids, position_enc=None, subject_labels=None, obj
else :
return predictions

def predict_subject(self, text, position_enc=None, use_layer_num=-1, device="cpu"):
def predict_subject(self, text,use_layer_num=-1, device="cpu"):
if use_layer_num != -1:
if use_layer_num < 0 or use_layer_num > 7:
# 越界
Expand All @@ -105,15 +117,17 @@ def predict_subject(self, text, position_enc=None, use_layer_num=-1, device="cpu
self.target_mask = (text > 0).float()
enc_layers, _ = self.bert(text, output_all_encoded_layers=True)
squence_out = enc_layers[use_layer_num]

transform_out = self.layer_norm(squence_out)
subject_pred_out = self.subject_pred(transform_out)

sub_out = enc_layers[-2]
# transform_out = self.layer_norm(squence_out)
subject_pred_out = self.subject_pred(squence_out)
subject_pred_act = self.activation(subject_pred_out)

subject_pred_act = subject_pred_act**2

# subject_pred_act = (subject_pred_act > 0.5).long()
return subject_pred_act

def predict_object_predicate(self, text, subject_ids, position_enc=None, subject_labels=None, object_labels=None, use_layer_num=-1, device="cpu"):
def predict_object_predicate(self, text, subject_ids, use_layer_num=-1, device="cpu"):
if use_layer_num != -1:
if use_layer_num < 0 or use_layer_num > 7:
# 越界
Expand All @@ -122,17 +136,18 @@ def predict_object_predicate(self, text, subject_ids, position_enc=None, subject
text = text.to(device)
subject_ids = subject_ids.to(device)

self.target_mask = (text > 0).float()
enc_layers, _ = self.bert(text, output_all_encoded_layers=True)
squence_out = enc_layers[use_layer_num]

subject_vec = self.extrac_subject(squence_out, subject_ids)
object_layer_norm = self.layer_norm_cond([squence_out, subject_vec])
sub_out = enc_layers[-2]
subject_vec = self.extrac_subject(sub_out, subject_ids)
object_layer_norm = self.layer_norm_cond([sub_out, subject_vec])
object_pred_out = self.object_pred(object_layer_norm)
object_pred_act = self.activation(object_pred_out)
batch_size, seq_len, target_size = object_pred_act.shape

object_pred_act = object_pred_act.reshape((batch_size, seq_len, int(target_size/2), 2))
object_pred_act = object_pred_act**4

batch_size, seq_len, target_size = object_pred_act.shape
object_pred_act = object_pred_act.view((batch_size, seq_len, int(target_size/2), 2))
# print(object_pred_act.shape)
predictions = object_pred_act
return predictions
return predictions
4 changes: 2 additions & 2 deletions bert_seq2seq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def load_model_params(model, pretrain_model_path, keep_tokens=None):
torch.cuda.empty_cache()
print("{} loaded!".format(pretrain_model_path))

def load_recent_model(model, recent_model_path):
checkpoint = torch.load(recent_model_path)
def load_recent_model(model, recent_model_path, device="cuda"):
checkpoint = torch.load(recent_model_path, map_location=device)
model.load_state_dict(checkpoint)
torch.cuda.empty_cache()
print(str(recent_model_path) + " loaded!")
Expand Down
Loading

0 comments on commit 4c97410

Please sign in to comment.