forked from HIT-SCIR/ltp
1. Add file headers
2. Make the CRF optional for NER
3. Revise the SRL/NER distillation losses
4. Automatically prune model weights when exporting a deployable model
5. Optional formats for DEP/SDP results
6. Fix small bugs in DEP/SDP decoding
7. Restore mixed decoding for SDP
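For orientation, a minimal usage sketch of the new optional arguments (items 2 and 5), assuming the LTP pipeline class that the diff below modifies; the checkpoint path is a placeholder:

from ltp import LTP

ltp = LTP('path/to/model')  # placeholder checkpoint path, not from this commit
sentences, hidden = ltp.seg(['他叫汤姆去拿外衣。'])

entities = ltp.ner(hidden)                     # [(type, start, end), ...] per sentence
tags = ltp.ner(hidden, as_entities=False)      # raw tag sequences instead of entity spans
heads, rels = ltp.dep(hidden, as_tuple=False)  # parallel lists instead of (idx, head, rel) triples
graph = ltp.sdp(hidden, mode='mix')            # 'tree', 'graph', or 'mix' decoding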
Showing 46 changed files with 389 additions and 93 deletions.
@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>

 __version__ = '4.1.3'

 from . import const
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 from .eisner import eisner
 from .sent_split import split_sentence
 from .maximum_forward_matching import Trie
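A hedged usage sketch of one helper re-exported above; that split_sentence takes raw text and returns a list of sentence strings is an assumption based on its name, not something this diff confirms:

from ltp.utils import split_sentence

# assumed behavior: rule-based sentence splitting on Chinese punctuation
sents = split_sentence('他叫汤姆去拿外衣。汤姆去了。')
# expected: ['他叫汤姆去拿外衣。', '汤姆去了。']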
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 from .bio import Bio
 from .conllu import Conllu
 from .srl import Srl
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import logging

 import datasets
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import logging

 import os
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import logging

 import datasets
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 from .collate import collate
 from .iterator import iter_raw_lines, iter_lines, iter_blocks
 from .multitask_dataloader import MultiTaskDataloader
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 from torch.utils.data import Dataset, DataLoader

@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import torch
 from torch._six import int_classes, string_classes, container_abcs
 from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import codecs

@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>

 import numpy as np

@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>

 import os
 import torch
 import itertools
@@ -280,9 +281,13 @@ def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True

         word_input = torch.gather(hidden, dim=1, index=word_idx)  # the vector of each word's first char

-        word_cls_input = torch.cat([cls, word_input], dim=1)
-        word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
-        word_cls_mask[:, 0] = False
+        if len(self.dep_vocab) + len(self.sdp_vocab) > 0:
+            word_cls_input = torch.cat([cls, word_input], dim=1)
+            word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
+            word_cls_mask[:, 0] = False
+        else:
+            word_cls_input, word_cls_mask = None, None

         return sentences, {
             'word_cls': cls, 'word_input': word_input, 'word_length': word_length,
             'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
@@ -298,27 +303,32 @@ def pos(self, hidden: dict):
         Returns:
             pos: POS tagging results
         """
+        if len(self.pos_vocab) == 0:
+            return []
         postagger_output = self.model.pos_classifier(hidden['word_input']).logits
         postagger_output = torch.argmax(postagger_output, dim=-1).cpu().numpy()
         postagger_output = convert_idx_to_name(postagger_output, hidden['word_length'], self.pos_vocab)
         return postagger_output

     @no_gard
-    def ner(self, hidden: dict):
+    def ner(self, hidden: dict, as_entities=True):
         """
         Named entity recognition
         Args:
             hidden: the intermediate representations obtained during segmentation
+            as_entities: whether to return results as Entity(Type, Start, End)
         Returns:
             ner: named entity recognition results
         """
+        if len(self.ner_vocab) == 0:
+            return []
         ner_output = self.model.ner_classifier.forward(
             hidden['word_input'], word_attention_mask=hidden['word_cls_mask'][:, 1:]
-        ).logits
-        ner_output = torch.argmax(ner_output, dim=-1).cpu().numpy()
+        )
+        ner_output = ner_output.decoded or torch.argmax(ner_output.logits, dim=-1).cpu().numpy()
         ner_output = convert_idx_to_name(ner_output, hidden['word_length'], self.ner_vocab)
-        return [get_entities(ner) for ner in ner_output]
+        return [get_entities(ner) for ner in ner_output] if as_entities else ner_output

     @no_gard
     def srl(self, hidden: dict, keep_empty=True):
@@ -330,6 +340,8 @@ def srl(self, hidden: dict, keep_empty=True):
         Returns:
             srl: semantic role labeling results
         """
+        if len(self.srl_vocab) == 0:
+            return []
         srl_output = self.model.srl_classifier.forward(
             input=hidden['word_input'],
             word_attention_mask=hidden['word_cls_mask'][:, 1:]
@@ -350,24 +362,27 @@
         return srl_labels_res

     @no_gard
-    def dep(self, hidden: dict, fast=True):
+    def dep(self, hidden: dict, fast=True, as_tuple=True):
         """
         Dependency syntax tree
         Args:
             hidden: the intermediate representations obtained during segmentation
             fast: in fast mode, fewer constraints are applied to the result; decoding is faster, with a corresponding drop in accuracy
+            as_tuple: whether to return results as (idx, head, rel) triples; otherwise heads, rels are returned separately
         Returns:
             dependency parsing results
         """
+        if len(self.dep_vocab) == 0:
+            return []
         word_attention_mask = hidden['word_cls_mask']
         result = self.model.dep_classifier.forward(
             input=hidden['word_cls_input'],
             word_attention_mask=word_attention_mask[:, 1:]
         )
         dep_arc, dep_label = result.arc_logits, result.rel_logits
         dep_arc[:, 0, 1:] = float('-inf')
-        dep_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
+        dep_arc.diagonal(0, 1, 2).fill_(float('-inf'))
         dep_arc = dep_arc.argmax(dim=-1) if fast else eisner(dep_arc, word_attention_mask)

         dep_label = torch.argmax(dep_label, dim=-1)

@@ -376,49 +391,58 @@ def dep(self, hidden: dict, fast=True):
         dep_arc[~word_attention_mask] = -1
         dep_label[~word_attention_mask] = -1

-        arc_pred = [
+        head_pred = [
             [item for item in arcs if item != -1]
             for arcs in dep_arc[:, 1:].cpu().numpy().tolist()
         ]
         rel_pred = [
             [self.dep_vocab[item] for item in rels if item != -1]
             for rels in dep_label[:, 1:].cpu().numpy().tolist()
         ]

+        if not as_tuple:
+            return head_pred, rel_pred
         return [
-            [(idx + 1, arc, rel) for idx, (arc, rel) in enumerate(zip(arcs, rels))]
-            for arcs, rels in zip(arc_pred, rel_pred)
+            [(idx + 1, head, rel) for idx, (head, rel) in enumerate(zip(heads, rels))]
+            for heads, rels in zip(head_pred, rel_pred)
         ]

     @no_gard
-    def sdp(self, hidden: dict, graph=True):
+    def sdp(self, hidden: dict, mode: str = 'graph'):
         """
         Semantic dependency graph (or tree)
         Args:
             hidden: the intermediate representations obtained during segmentation
-            graph: whether to produce a semantic dependency graph rather than a tree
+            mode: ['tree', 'graph', 'mix']
         Returns:
             semantic dependency graph (tree) results
         """
+        if len(self.sdp_vocab) == 0:
+            return []
         word_attention_mask = hidden['word_cls_mask']
         result = self.model.sdp_classifier(
             input=hidden['word_cls_input'],
             word_attention_mask=word_attention_mask[:, 1:]
         )
         sdp_arc, sdp_label = result.arc_logits, result.rel_logits
         sdp_arc[:, 0, 1:] = float('-inf')
-        sdp_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))  # avoid self-loops
+        sdp_arc.diagonal(0, 1, 2).fill_(float('-inf'))  # avoid self-loops
         sdp_label = torch.argmax(sdp_label, dim=-1)

-        if graph:
-            # semantic dependency graph
-            sdp_arc = torch.sigmoid_(sdp_arc) > 0.5
-        else:
+        if mode == 'tree':
             # semantic dependency tree
             sdp_arc_idx = eisner(sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
-            sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(-1, sdp_arc_idx, True)
-        sdp_arc[~word_attention_mask] = False
-        sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)
+            sdp_arc_res = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(-1, sdp_arc_idx, True)
+        elif mode == 'mix':
+            # mixed decoding
+            sdp_arc_idx = eisner(sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
+            sdp_arc_res = (sdp_arc.sigmoid_() > 0.5).scatter_(-1, sdp_arc_idx, True)
+        else:
+            # semantic dependency graph
+            sdp_arc_res = torch.sigmoid_(sdp_arc) > 0.5
+
+        sdp_arc_res[~word_attention_mask] = False
+        sdp_label = get_graph_entities(sdp_arc_res, sdp_label, self.sdp_vocab)

         return sdp_label
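The three mode branches above differ only in how the boolean arc matrix is built before get_graph_entities reads labels off it. A self-contained sketch of that logic on dummy tensors; tree_heads stands in for the output of eisner(), which is not reimplemented here:

import torch

arc_logits = torch.randn(1, 4, 4)          # [batch, len, len] arc scores
tree_heads = torch.tensor([[0, 0, 1, 1]])  # stand-in for eisner(arc_logits, mask)
idx = tree_heads.unsqueeze(-1).expand_as(arc_logits)

# 'graph': keep every arc whose sigmoid score clears 0.5
graph_arcs = torch.sigmoid(arc_logits) > 0.5

# 'tree': keep exactly the one head per word chosen by the tree decoder
tree_arcs = torch.zeros_like(arc_logits, dtype=torch.bool).scatter_(-1, idx, True)

# 'mix': the graph arcs, with the tree arcs forced in so every word keeps a head
mix_arcs = (torch.sigmoid(arc_logits) > 0.5).scatter_(-1, idx, True)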
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import os
 import types
 from argparse import ArgumentParser
@@ -1,7 +1,10 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 import os
-import numpy
+import numpy as np
 import types
-import numpy as np
 from argparse import ArgumentParser
+
+import torch
@@ -81,8 +84,27 @@ def distill_linear(batch, result, target, temperature_scheduler, model: Model =
     logits_mask = batch['attention_mask'][:, 2:]
     active_logits = result.logits[logits_mask]
     active_target_logits = target[logits_mask]
-    temperature = temperature_scheduler(active_logits, active_target_logits)
-    return kd_ce_loss(active_logits, active_target_logits, temperature=temperature)
+
+    if result.decoded is not None and extra is not None:
+        start_transitions = torch.as_tensor(extra['start_transitions'], device=model.device)
+        transitions = torch.as_tensor(extra['transitions'], device=model.device)
+        end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)
+
+        temperature = temperature_scheduler(active_logits, active_target_logits)
+        kd_loss = kd_mse_loss(active_logits, active_target_logits, temperature)
+
+        transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
+        s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
+        e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)
+
+        crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
+                   kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
+                   kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)
+
+        return kd_loss + crf_loss
+    else:
+        temperature = temperature_scheduler(active_logits, active_target_logits)
+        return kd_ce_loss(active_logits, active_target_logits, temperature=temperature)


 def distill_matrix_dep(batch, result, target, temperature_scheduler, model: Model = None, extra=None) -> torch.Tensor:
@@ -149,16 +171,24 @@ def distill_matrix_crf(batch, result, target, temperature_scheduler, model: Mode
     s_rel, labels = result.arc_logits, result.labels
     t_rel = target

-    logits_loss = kd_mse_loss(s_rel[logits_mask], t_rel[logits_mask])
+    active_logits = s_rel[logits_mask]
+    active_target_logits = t_rel[logits_mask]
+
+    temperature = temperature_scheduler(active_logits, active_target_logits)
+    kd_loss = kd_mse_loss(active_logits, active_target_logits, temperature)

     start_transitions = torch.as_tensor(extra['start_transitions'], device=model.device)
     transitions = torch.as_tensor(extra['transitions'], device=model.device)
     end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

-    crf_loss = kd_mse_loss(transitions, model.srl_classifier.rel_crf.transitions) + \
-               kd_mse_loss(start_transitions, model.srl_classifier.rel_crf.start_transitions) + \
-               kd_mse_loss(end_transitions, model.srl_classifier.rel_crf.end_transitions)
-    return (logits_loss + crf_loss) / 2
+    transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
+    s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
+    e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)
+
+    crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
+               kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
+               kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)
+    return kd_loss + crf_loss


 distill_loss_map = {
@@ -179,7 +209,7 @@ def build_dataset(model, **kwargs):

     for task, task_data_dir in kwargs.items():
         task_distill_path = os.path.join(task_data_dir, task, 'output.npz')
-        task_distill_data = numpy.load(task_distill_path, allow_pickle=True)
+        task_distill_data = np.load(task_distill_path, allow_pickle=True)

         distill_datasets[task] = task_distill_data['data'].tolist()
         distill_datasets_extra[task] = task_distill_data.get('extra', None)
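Both distillation paths above soften student and teacher logits with a scheduler-chosen temperature before comparing them. A minimal sketch of what such a kd_mse_loss computes; the project most likely takes kd_mse_loss and kd_ce_loss from TextBrewer, whose reduction details may differ:

import torch
import torch.nn.functional as F

def kd_mse_loss(logits_s: torch.Tensor, logits_t: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # MSE between temperature-scaled student and teacher logits; the CRF
    # transition matrices above are distilled with the same form of loss.
    return F.mse_loss(logits_s / temperature, logits_t / temperature)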
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 from .module import BaseModule

 from .mlp import MLP
@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>

 import torch, math
 from torch import nn, Tensor
 from torch.nn import functional as F
@@ -1,3 +1,7 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# Author: Yunlong Feng <[email protected]>
+
 try:
     from pytorch_lightning import LightningModule as BaseModule
 except Exception as e:
@@ -1,7 +1,8 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
-from typing import Optional
+# Author: Yunlong Feng <[email protected]>
+# ref: https://github.com/fastnlp/TENER

 import torch, math
 from torch import Tensor, nn
 import torch.nn.functional as F