A fairly large update 😁
1. Add file headers
2. Make the CRF layer optional for NER
3. Revise the SRL/NER distillation losses
4. Automatically prune model weights when exporting a deployable model
5. Optional output formats for DEP/SDP results (a short usage sketch follows this list)
6. Fix a minor bug in DEP/SDP decoding
7. Restore mixed decoding for SDP
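A minimal usage sketch of the options introduced here, assuming the LTP 4.1.x Python frontend (the default model loaded by LTP() and the example sentence are placeholders, not part of this commit):

from ltp import LTP

ltp = LTP()  # assumed to load a default pretrained model; pass a path for a local one
seg, hidden = ltp.seg(["他叫汤姆去拿外衣。"])
pos = ltp.pos(hidden)

# NER: raw BIO-style tag sequences instead of (type, start, end) entities
tags = ltp.ner(hidden, as_entities=False)

# DEP: heads and relations as two parallel lists instead of (idx, head, rel) tuples
heads, rels = ltp.dep(hidden, as_tuple=False)

# SDP: decode as a tree, a graph, or with the restored mixed mode
sdp_graph = ltp.sdp(hidden, mode='mix')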
AlongWY committed Dec 25, 2020
1 parent b6782fa commit 6c481b3
Showing 46 changed files with 389 additions and 93 deletions.
1 change: 1 addition & 0 deletions ltp/__init__.py
@@ -1,6 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

__version__ = '4.1.3'

from . import const
4 changes: 4 additions & 0 deletions ltp/algorithms/__init__.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

from .eisner import eisner
from .sent_split import split_sentence
from .maximum_forward_matching import Trie
4 changes: 4 additions & 0 deletions ltp/data/dataset/__init__.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

from .bio import Bio
from .conllu import Conllu
from .srl import Srl
4 changes: 4 additions & 0 deletions ltp/data/dataset/bio.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import logging

import datasets
4 changes: 4 additions & 0 deletions ltp/data/dataset/conllu.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import logging

import os
4 changes: 4 additions & 0 deletions ltp/data/dataset/srl.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import logging

import datasets
4 changes: 4 additions & 0 deletions ltp/data/utils/__init__.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

from .collate import collate
from .iterator import iter_raw_lines, iter_lines, iter_blocks
from .multitask_dataloader import MultiTaskDataloader
4 changes: 4 additions & 0 deletions ltp/data/utils/batched_dataset.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

from torch.utils.data import Dataset, DataLoader


4 changes: 4 additions & 0 deletions ltp/data/utils/collate.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import torch
from torch._six import int_classes, string_classes, container_abcs
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
4 changes: 4 additions & 0 deletions ltp/data/utils/iterator.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import codecs


1 change: 1 addition & 0 deletions ltp/data/utils/multitask_dataloader.py
@@ -1,6 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import numpy as np


70 changes: 47 additions & 23 deletions ltp/frontend.py
@@ -1,6 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import os
import torch
import itertools
@@ -280,9 +281,13 @@ def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True

word_input = torch.gather(hidden, dim=1, index=word_idx)  # vector of each word's first char

word_cls_input = torch.cat([cls, word_input], dim=1)
word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
word_cls_mask[:, 0] = False
if len(self.dep_vocab) + len(self.sdp_vocab) > 0:
word_cls_input = torch.cat([cls, word_input], dim=1)
word_cls_mask = length_to_mask(torch.as_tensor(word_length, device=self.device) + 1)
word_cls_mask[:, 0] = False
else:
word_cls_input, word_cls_mask = None, None

return sentences, {
'word_cls': cls, 'word_input': word_input, 'word_length': word_length,
'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
@@ -298,27 +303,32 @@ def pos(self, hidden: dict):
Returns:
pos: part-of-speech tagging results
"""
if len(self.pos_vocab) == 0:
return []
postagger_output = self.model.pos_classifier(hidden['word_input']).logits
postagger_output = torch.argmax(postagger_output, dim=-1).cpu().numpy()
postagger_output = convert_idx_to_name(postagger_output, hidden['word_length'], self.pos_vocab)
return postagger_output

@no_gard
def ner(self, hidden: dict):
def ner(self, hidden: dict, as_entities=True):
"""
Named entity recognition
Args:
hidden: intermediate representations produced during segmentation
as_entities: whether to return results as Entity(Type, Start, End) tuples
Returns:
pos: named entity recognition results
"""
if len(self.ner_vocab) == 0:
return []
ner_output = self.model.ner_classifier.forward(
hidden['word_input'], word_attention_mask=hidden['word_cls_mask'][:, 1:]
).logits
ner_output = torch.argmax(ner_output, dim=-1).cpu().numpy()
)
ner_output = ner_output.decoded or torch.argmax(ner_output.logits, dim=-1).cpu().numpy()
ner_output = convert_idx_to_name(ner_output, hidden['word_length'], self.ner_vocab)
return [get_entities(ner) for ner in ner_output]
return [get_entities(ner) for ner in ner_output] if as_entities else ner_output

@no_gard
def srl(self, hidden: dict, keep_empty=True):
@@ -330,6 +340,8 @@ def srl(self, hidden: dict, keep_empty=True):
Returns:
pos: semantic role labeling results
"""
if len(self.srl_vocab) == 0:
return []
srl_output = self.model.srl_classifier.forward(
input=hidden['word_input'],
word_attention_mask=hidden['word_cls_mask'][:, 1:]
@@ -350,24 +362,27 @@ def srl(self, hidden: dict, keep_empty=True):
return srl_labels_res

@no_gard
def dep(self, hidden: dict, fast=True):
def dep(self, hidden: dict, fast=True, as_tuple=True):
"""
Dependency parse tree
Args:
hidden: intermediate representations produced during segmentation
fast: fast mode loosens the constraints on the output; decoding is faster, with a corresponding drop in accuracy
as_tuple: whether to return results in (idx, head, rel) format; otherwise heads, rels are returned
Returns:
dependency parsing results
"""
if len(self.dep_vocab) == 0:
return []
word_attention_mask = hidden['word_cls_mask']
result = self.model.dep_classifier.forward(
input=hidden['word_cls_input'],
word_attention_mask=word_attention_mask[:, 1:]
)
dep_arc, dep_label = result.arc_logits, result.rel_logits
dep_arc[:, 0, 1:] = float('-inf')
dep_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
dep_arc.diagonal(0, 1, 2).fill_(float('-inf'))
dep_arc = dep_arc.argmax(dim=-1) if fast else eisner(dep_arc, word_attention_mask)

dep_label = torch.argmax(dep_label, dim=-1)
@@ -376,49 +391,58 @@ def dep(self, hidden: dict, fast=True):
dep_arc[~word_attention_mask] = -1
dep_label[~word_attention_mask] = -1

arc_pred = [
head_pred = [
[item for item in arcs if item != -1]
for arcs in dep_arc[:, 1:].cpu().numpy().tolist()
]
rel_pred = [
[self.dep_vocab[item] for item in rels if item != -1]
for rels in dep_label[:, 1:].cpu().numpy().tolist()
]

if not as_tuple:
return head_pred, rel_pred
return [
[(idx + 1, arc, rel) for idx, (arc, rel) in enumerate(zip(arcs, rels))]
for arcs, rels in zip(arc_pred, rel_pred)
[(idx + 1, head, rel) for idx, (head, rel) in enumerate(zip(heads, rels))]
for heads, rels in zip(head_pred, rel_pred)
]

@no_gard
def sdp(self, hidden: dict, graph=True):
def sdp(self, hidden: dict, mode: str = 'graph'):
"""
Semantic dependency graph (tree)
Args:
hidden: intermediate representations produced during segmentation
graph: choose between semantic dependency graph and semantic dependency tree output
mode: ['tree', 'graph', 'mix']
Returns:
semantic dependency graph (tree) results
"""
if len(self.sdp_vocab) == 0:
return []

word_attention_mask = hidden['word_cls_mask']
result = self.model.sdp_classifier(
input=hidden['word_cls_input'],
word_attention_mask=word_attention_mask[:, 1:]
)
sdp_arc, sdp_label = result.arc_logits, result.rel_logits
sdp_arc[:, 0, 1:] = float('-inf')
sdp_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))  # avoid self-arcs
sdp_arc.diagonal(0, 1, 2).fill_(float('-inf'))  # avoid self-arcs
sdp_label = torch.argmax(sdp_label, dim=-1)

if graph:
# semantic dependency graph
sdp_arc = torch.sigmoid_(sdp_arc) > 0.5
else:
if mode == 'tree':
# semantic dependency tree
sdp_arc_idx = eisner(sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
sdp_arc = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(-1, sdp_arc_idx, True)
sdp_arc[~word_attention_mask] = False
sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)
sdp_arc_res = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(-1, sdp_arc_idx, True)
elif mode == 'mix':
# mixed decoding
sdp_arc_idx = eisner(sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
sdp_arc_res = (sdp_arc.sigmoid_() > 0.5).scatter_(-1, sdp_arc_idx, True)
else:
# semantic dependency graph
sdp_arc_res = torch.sigmoid_(sdp_arc) > 0.5

sdp_arc_res[~word_attention_mask] = False
sdp_label = get_graph_entities(sdp_arc_res, sdp_label, self.sdp_vocab)

return sdp_label
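To make the restored mixed decoding above concrete, here is a small self-contained sketch of the arc-combination step on toy tensors; the hand-written head assignment stands in for the eisner() result, and the snippet is an illustration rather than the library API:

import torch

# toy arc scores: batch = 1, 4 nodes (index 0 is the virtual root)
arc_logits = torch.randn(1, 4, 4)
arc_logits[:, 0, 1:] = float('-inf')                # the root takes no head
arc_logits.diagonal(0, 1, 2).fill_(float('-inf'))   # forbid self-arcs

# stand-in for the eisner() tree: head index of each node
tree_heads = torch.tensor([[0, 0, 1, 1]])

# graph part: keep every arc whose sigmoid score clears 0.5
graph_arcs = torch.sigmoid(arc_logits) > 0.5

# mix: scatter the tree arcs back in, so every word keeps at least its tree head
# (in the real code, padding and the root row are masked out afterwards)
mixed_arcs = graph_arcs.scatter(-1, tree_heads.unsqueeze(-1), True)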
4 changes: 4 additions & 0 deletions ltp/multitask.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import os
import types
from argparse import ArgumentParser
50 changes: 40 additions & 10 deletions ltp/multitask_distill.py
@@ -1,7 +1,10 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import os
import numpy
import numpy as np
import types
import numpy as np
from argparse import ArgumentParser

import torch
@@ -81,8 +84,27 @@ def distill_linear(batch, result, target, temperature_scheduler, model: Model =
logits_mask = batch['attention_mask'][:, 2:]
active_logits = result.logits[logits_mask]
active_target_logits = target[logits_mask]
temperature = temperature_scheduler(active_logits, active_target_logits)
return kd_ce_loss(active_logits, active_target_logits, temperature=temperature)

if result.decoded is not None and extra is not None:
start_transitions = torch.as_tensor(extra['start_transitions'], device=model.device)
transitions = torch.as_tensor(extra['transitions'], device=model.device)
end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

temperature = temperature_scheduler(active_logits, active_target_logits)
kd_loss = kd_mse_loss(active_logits, active_target_logits, temperature)

transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)

crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)

return kd_loss + crf_loss
else:
temperature = temperature_scheduler(active_logits, active_target_logits)
return kd_ce_loss(active_logits, active_target_logits, temperature=temperature)


def distill_matrix_dep(batch, result, target, temperature_scheduler, model: Model = None, extra=None) -> torch.Tensor:
@@ -149,16 +171,24 @@ def distill_matrix_crf(batch, result, target, temperature_scheduler, model: Mode
s_rel, labels = result.arc_logits, result.labels
t_rel = target

logits_loss = kd_mse_loss(s_rel[logits_mask], t_rel[logits_mask])
active_logits = s_rel[logits_mask]
active_target_logits = t_rel[logits_mask]

temperature = temperature_scheduler(active_logits, active_target_logits)
kd_loss = kd_mse_loss(active_logits, active_target_logits, temperature)

start_transitions = torch.as_tensor(extra['start_transitions'], device=model.device)
transitions = torch.as_tensor(extra['transitions'], device=model.device)
end_transitions = torch.as_tensor(extra['end_transitions'], device=model.device)

crf_loss = kd_mse_loss(transitions, model.srl_classifier.rel_crf.transitions) + \
kd_mse_loss(start_transitions, model.srl_classifier.rel_crf.start_transitions) + \
kd_mse_loss(end_transitions, model.srl_classifier.rel_crf.end_transitions)
return (logits_loss + crf_loss) / 2
transitions_temp = temperature_scheduler(model.srl_classifier.crf.transitions, transitions)
s_transitions_temp = temperature_scheduler(model.srl_classifier.crf.start_transitions, start_transitions)
e_transitions_temp = temperature_scheduler(model.srl_classifier.crf.end_transitions, end_transitions)

crf_loss = kd_mse_loss(transitions, model.srl_classifier.crf.transitions, transitions_temp) + \
kd_mse_loss(start_transitions, model.srl_classifier.crf.start_transitions, s_transitions_temp) + \
kd_mse_loss(end_transitions, model.srl_classifier.crf.end_transitions, e_transitions_temp)
return kd_loss + crf_loss


distill_loss_map = {
@@ -179,7 +209,7 @@ def build_dataset(model, **kwargs):

for task, task_data_dir in kwargs.items():
task_distill_path = os.path.join(task_data_dir, task, 'output.npz')
task_distill_data = numpy.load(task_distill_path, allow_pickle=True)
task_distill_data = np.load(task_distill_path, allow_pickle=True)

distill_datasets[task] = task_distill_data['data'].tolist()
distill_datasets_extra[task] = task_distill_data.get('extra', None)
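The CRF branches above distill not only the emission logits but also the student CRF's transition parameters toward the teacher's, each at its own scheduled temperature. A minimal sketch of that idea, using a plain temperature-scaled MSE as a stand-in for kd_mse_loss and random tensors in place of the real transition matrices:

import torch
import torch.nn.functional as F

def mse_with_temperature(student, teacher, temperature=1.0):
    # stand-in for kd_mse_loss: MSE between temperature-scaled tensors
    return F.mse_loss(student / temperature, teacher / temperature)

num_tags = 10
student_transitions = torch.randn(num_tags, num_tags, requires_grad=True)  # e.g. model.srl_classifier.crf.transitions
teacher_transitions = torch.randn(num_tags, num_tags)                       # e.g. loaded from the distillation 'extra' payload

crf_loss = mse_with_temperature(student_transitions, teacher_transitions, temperature=2.0)
# in the diff, the analogous start/end-transition terms and the logits KD loss are summed into the total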
4 changes: 4 additions & 0 deletions ltp/nn/__init__.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

from .module import BaseModule

from .mlp import MLP
1 change: 1 addition & 0 deletions ltp/nn/bilinear.py
@@ -1,6 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

import torch, math
from torch import nn, Tensor
from torch.nn import functional as F
4 changes: 4 additions & 0 deletions ltp/nn/module.py
@@ -1,3 +1,7 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>

try:
from pytorch_lightning import LightningModule as BaseModule
except Exception as e:
3 changes: 2 additions & 1 deletion ltp/nn/relative_transformer.py
@@ -1,7 +1,8 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*_
# Author: Yunlong Feng <[email protected]>
from typing import Optional
# ref: https://github.com/fastnlp/TENER

import torch, math
from torch import Tensor, nn
import torch.nn.functional as F