Uniformly use NamedTuple as the result
AlongWY committed Dec 7, 2020
1 parent 38eff99 commit e79c488
Showing 11 changed files with 102 additions and 161 deletions.
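The commit replaces positional tuple returns from the model forward methods with named tuples, so the Lightning training and validation steps read fields by name (result.loss, result.logits, result.arc_logits, ...) instead of unpacking by position. Below is a minimal sketch of the pattern; GraphResult mirrors the definition added in ltp/transformer_biaffine.py, while the toy module and tensors are illustrative assumptions, not LTP's actual classifier:

from collections import namedtuple

import torch

# Mirrors the result type introduced in ltp/transformer_biaffine.py.
GraphResult = namedtuple('GraphResult', ['loss', 'arc_logits', 'rel_logits'])


class ToyParser(torch.nn.Module):
    """Hypothetical stand-in for the biaffine classifier, for illustration only."""

    def forward(self, x, labels=None):
        arc_logits = x @ x.transpose(-1, -2)   # fake arc scores: (batch, words, words)
        rel_logits = arc_logits.unsqueeze(-1)  # fake relation scores
        loss = arc_logits.mean() if labels is not None else None
        # Before this commit: return loss, (arc_logits, rel_logits)
        # After: one named result with stable field names.
        return GraphResult(loss=loss, arc_logits=arc_logits, rel_logits=rel_logits)


x = torch.randn(2, 5, 8)
result = ToyParser()(x, labels=torch.zeros(2, 5, dtype=torch.long))
print(result.loss, result.arc_logits.shape)  # fields accessed by name
loss, arc_logits, rel_logits = result        # positional unpacking still works

Because a NamedTuple is still a tuple, older call sites that unpack positionally keep working, while the updated task files can name exactly the field they need.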
11 changes: 7 additions & 4 deletions ltp/task_dependency_parsing.py
@@ -87,7 +87,10 @@ def tokenize(examples):

def validation_method(metric_func=None, loss_tag='val_loss', metric_tag=f'val_{task_info.metric_name}', log=True):
def step(self: pl.LightningModule, batch, batch_nb):
loss, (parc, prel) = self(**batch)
result = self(**batch)
loss = result.loss
parc = result.arc_logits
prel = result.rel_logits

mask: torch.Tensor = batch['word_attention_mask']

@@ -140,9 +143,9 @@ def train_dataloader(self):
return res

def training_step(self, batch, batch_nb):
loss, logits = self(**batch)
self.log("loss", loss.item())
return loss
result = self(**batch)
self.log("loss", result.loss.item())
return result.loss

def val_dataloader(self):
return torch.utils.data.DataLoader(
12 changes: 6 additions & 6 deletions ltp/task_named_entity_recognition.py
@@ -83,21 +83,21 @@ def validation_method(metric, loss_tag='val_loss', metric_tag=f'val_{task_info.m
metric_func, label_feature = metric

def step(self: pl.LightningModule, batch, batch_nb):
loss, logits = self(**batch)
result = self(**batch)

mask = batch['word_attention_mask'] == False

# acc
labels = batch['labels']
preds = torch.argmax(logits, dim=-1)
preds = torch.argmax(result.logits, dim=-1)

labels[mask] = -1
preds[mask] = -1

labels = [[label_feature[word] for word in sent if word != -1] for sent in labels.detach().cpu().numpy()]
preds = [[label_feature[word] for word in sent if word != -1] for sent in preds.detach().cpu().numpy()]

return {'loss': loss.item(), 'pred': preds, 'labels': labels}
return {'loss': result.loss.item(), 'pred': preds, 'labels': labels}

def epoch_end(self, outputs):
if isinstance(outputs, dict):
@@ -133,9 +133,9 @@ def train_dataloader(self):
return res

def training_step(self, batch, batch_nb):
loss, logits = self(**batch)
self.log("loss", loss.item())
return loss
result = self(**batch)
self.log("loss", result.loss.item())
return result.loss

def val_dataloader(self):
return torch.utils.data.DataLoader(
14 changes: 7 additions & 7 deletions ltp/task_part_of_speech.py
@@ -27,7 +27,7 @@ def build_dataset(model, data_dir):
datasets.Conllu,
data_dir=data_dir,
cache_dir=data_dir,
xpos=os.path.join(data_dir, "pos_labels.txt")
xpos=os.path.join(data_dir, "xpos_labels.txt")
)
dataset.remove_columns_(["id", "lemma", "upos", "feats", "head", "deprel", "deps", "misc"])
dataset.rename_column_('xpos', 'labels')
@@ -79,14 +79,14 @@ def tokenize(examples):

def validation_method(metric, loss_tag='val_loss', metric_tag=f'val_{task_info.metric_name}', log=True):
def step(self, batch, batch_nb):
loss, logits = self(**batch)
result = self(**batch)

mask = batch['logits_mask']
labels = batch['labels']
preds = torch.argmax(logits, dim=-1)
preds = torch.argmax(result.logits, dim=-1)
preds_true = preds[mask] == labels[mask]
return {
loss_tag: loss.item(),
loss_tag: result.loss.item(),
f'{metric_tag}/true': torch.sum(preds_true, dtype=torch.float).item(),
f'{metric_tag}/all': preds_true.numel()
}
@@ -123,9 +123,9 @@ def train_dataloader(self):
return res

def training_step(self, batch, batch_nb):
loss, logits = self(**batch)
self.log("loss", loss)
return loss
result = self(**batch)
self.log("loss", result.loss)
return result.loss

def val_dataloader(self):
return torch.utils.data.DataLoader(
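The part-of-speech validation step above returns raw counts under f'{metric_tag}/true' and f'{metric_tag}/all' rather than a per-batch ratio, so the epoch-level accuracy can presumably be micro-averaged by summing the counts before dividing. The epoch_end that does this is not part of this diff, so the aggregation below is only an assumed illustration, with val_acc standing in for the real metric_tag:

# Hypothetical aggregation of the per-step outputs; key names and values are assumed.
step_outputs = [
    {'val_loss': 0.41, 'val_acc/true': 18.0, 'val_acc/all': 20},
    {'val_loss': 0.37, 'val_acc/true': 15.0, 'val_acc/all': 16},
]
correct = sum(out['val_acc/true'] for out in step_outputs)
total = sum(out['val_acc/all'] for out in step_outputs)
val_acc = correct / total if total > 0 else 0.0
print(round(val_acc, 4))  # 0.9167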
12 changes: 6 additions & 6 deletions ltp/task_segmention.py
@@ -75,21 +75,21 @@ def validation_method(metric, loss_tag='val_loss', metric_tag=f'val_{task_info.m
label_mapper = ['I-W', 'B-W']

def step(self: pl.LightningModule, batch, batch_nb):
loss, logits = self(**batch)
result = self(**batch)

mask = batch['attention_mask'][:, 2:] != 1

# acc
labels = batch['labels']
preds = torch.argmax(logits, dim=-1)
preds = torch.argmax(result.logits, dim=-1)

labels[mask] = -1
preds[mask] = -1

labels = [[label_mapper[word] for word in sent if word != -1] for sent in labels.detach().cpu().numpy()]
preds = [[label_mapper[word] for word in sent if word != -1] for sent in preds.detach().cpu().numpy()]

return {'loss': loss.item(), 'pred': preds, 'labels': labels}
return {'loss': result.loss.item(), 'pred': preds, 'labels': labels}

def epoch_end(self: pl.LightningModule, outputs):
if isinstance(outputs, dict):
@@ -125,9 +125,9 @@ def train_dataloader(self):
return res

def training_step(self, batch, batch_nb):
loss, logits = self(**batch)
self.log("loss", loss.item())
return loss
result = self(**batch)
self.log("loss", result.loss.item())
return result.loss

def val_dataloader(self):
return torch.utils.data.DataLoader(
15 changes: 10 additions & 5 deletions ltp/task_semantic_dependency_parsing.py
@@ -107,7 +107,10 @@ def tokenize(examples):

def validation_method(metric_func=None, loss_tag='val_loss', metric_tag=f'val_{task_info.metric_name}', log=True):
def step(self: pl.LightningModule, batch, batch_nb):
loss, (parc, prel) = self(**batch)
result = self(**batch)
loss = result.loss
parc = result.arc_logits
prel = result.rel_logits

parc[:, 0, 1:] = float('-inf')
        parc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))  # avoid self-reference (a token must not head itself)
@@ -139,9 +142,11 @@ def epoch_end(self: pl.LightningModule, outputs):
r = correct / true if true > 0 else 0
f = 2 * p * r / (p + r) if (p + r > 0) else 0

prefix, appendix = metric_tag.split('_', maxsplit=1)

if log:
self.log_dict(
dictionary={loss_tag: loss, metric_tag: f},
dictionary={loss_tag: loss, f'{prefix}_p': p, f'{prefix}_r': r, metric_tag: f},
on_step=False, on_epoch=True, prog_bar=True, logger=True
)
else:
@@ -164,9 +169,9 @@ def train_dataloader(self):
return res

def training_step(self, batch, batch_nb):
loss, logits = self(**batch)
self.log("loss", loss.item())
return loss
result = self(**batch)
self.log("loss", result.loss.item())
return result.loss

def val_dataloader(self):
return torch.utils.data.DataLoader(
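The added per-metric logging in ltp/task_semantic_dependency_parsing.py derives the precision and recall keys from the existing metric tag. A small sketch of what the split produces; 'val_f1' and the metric values are assumed examples, since the real tag is built from f'val_{task_info.metric_name}':

# Assumed example tag and dummy metric values, for illustration only.
metric_tag = 'val_f1'
prefix, appendix = metric_tag.split('_', maxsplit=1)
# prefix == 'val', appendix == 'f1'
logged = {'val_loss': 0.52, f'{prefix}_p': 0.91, f'{prefix}_r': 0.88, metric_tag: 0.89}
print(logged)  # {'val_loss': 0.52, 'val_p': 0.91, 'val_r': 0.88, 'val_f1': 0.89}

So precision and recall are now logged alongside the F-score instead of only the single aggregate metric.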
14 changes: 7 additions & 7 deletions ltp/task_semantic_role_labeling.py
@@ -94,13 +94,13 @@ def validation_method(metric, loss_tag='val_loss', metric_tag=f'val_{task_info.m
metric_func, label_feature = metric

def step(self: pl.LightningModule, batch, batch_nb):
(logits, preds, labels), = self(**batch)
result = self(**batch)

sent_length = [len(sent) for sent in preds]
preds = [[label_feature[word] for word in sent] for sent in preds]
sent_length = [len(sent) for sent in result.decoded]
preds = [[label_feature[word] for word in sent] for sent in result.decoded]
labels = [
[label_feature[word] for word in sent[:sent_length[idx]]]
for idx, sent in enumerate(labels.detach().cpu().numpy())
for idx, sent in enumerate(result.labels.detach().cpu().numpy())
]

return {'pred': preds, 'labels': labels}
@@ -136,9 +136,9 @@ def train_dataloader(self):
return res

def training_step(self, batch, batch_nb):
loss, output = self(**batch)
self.log("loss", loss.item())
return loss
result = self(**batch)
self.log("loss", result.loss.item())
return result.loss

def val_dataloader(self):
return torch.utils.data.DataLoader(
22 changes: 10 additions & 12 deletions ltp/transformer_biaffine.py
@@ -4,8 +4,11 @@
from torch import nn
from transformers import AutoModel, AutoConfig

from collections import namedtuple
from ltp.nn import MLP, Bilinear, BaseModule

GraphResult = namedtuple('GraphResult', ['loss', 'arc_logits', 'rel_logits'])


def dep_loss(model, s_arc, s_rel, head, labels, logits_mask):
head_loss = nn.CrossEntropyLoss()
@@ -64,8 +67,7 @@ def __init__(self, input_size, label_num, dropout,

self.loss_func = loss_func

def forward(self, input, logits_mask=None, word_index=None,
word_attention_mask=None, head=None, labels=None, hidden_states=None):
def forward(self, input, logits_mask=None, word_index=None, word_attention_mask=None, head=None, labels=None):
if word_index is not None:
input = torch.cat([input[:, :1, :], torch.gather(
input[:, 1:, :], dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, input.size(-1))
@@ -81,7 +83,6 @@ def forward(self, input, logits_mask=None, word_index=None,
s_rel = self.rel_atten(rel_d, rel_h).permute(0, 2, 3, 1)

loss = None
loss_output = (s_arc, s_rel)
if labels is not None:
if logits_mask is None:
logits_mask = word_attention_mask
@@ -93,8 +94,7 @@ def forward(self, input, logits_mask=None, word_index=None,
activate_word_mask = activate_word_mask & activate_word_mask.transpose(-1, -2)
s_arc.masked_fill_(~activate_word_mask, float('-inf'))

output = ((loss_output,) + hidden_states[1:]) if hidden_states is not None else (loss_output,)
return ((loss,) + output) if loss is not None else output
return GraphResult(loss=loss, arc_logits=s_arc, rel_logits=s_rel)


class TransformerBiaffine(BaseModule):
@@ -140,9 +140,7 @@ def forward(
head_mask=None,
inputs_embeds=None,
head=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -156,8 +154,9 @@ def forward(
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states
output_attentions=False,
output_hidden_states=False,
return_dict=False,
)
sequence_output = hidden_states[0]
sequence_output = sequence_output[:, :-1, :]
@@ -169,6 +168,5 @@ def forward(
word_index=word_index,
word_attention_mask=word_attention_mask,
head=head,
labels=labels,
hidden_states=hidden_states
labels=labels
)
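Both transformer wrappers now call the backbone with output_attentions=False, output_hidden_states=False and return_dict=False, and read the token embeddings from index 0 of the returned tuple. A minimal sketch of that call pattern; the encoder name and sentence are placeholders for illustration, not LTP's actual pretrained model:

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder encoder; LTP's own backbone is configured elsewhere.
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
encoder = AutoModel.from_pretrained('bert-base-chinese')

batch = tokenizer(['他叫汤姆去拿外衣。'], return_tensors='pt')
with torch.no_grad():
    outputs = encoder(**batch, output_attentions=False,
                      output_hidden_states=False, return_dict=False)
sequence_output = outputs[0]  # with return_dict=False, index 0 is the last hidden state
print(sequence_output.shape)  # (batch, seq_len, hidden_size)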
48 changes: 22 additions & 26 deletions ltp/transformer_biaffine_crf.py
@@ -5,8 +5,11 @@
from torch.nn import functional as F
from transformers import AutoModel

from collections import namedtuple
from ltp.nn import BaseModule, MLP, Bilinear, CRF

SRLResult = namedtuple('SRLResult', ['loss', 'rel_logits', 'decoded', 'labels', 'transitions'])


class BiaffineCRFClassifier(nn.Module):
def __init__(self, input_size, label_num, dropout, hidden_size=300, eval_transitions=False):
@@ -27,8 +30,7 @@ def rel_forword(self, input):

return s_rel

def forward(self, input, logits_mask=None, word_index=None,
word_attention_mask=None, labels=None, hidden_states=None):
def forward(self, input, logits_mask=None, word_index=None, word_attention_mask=None, labels=None):
if word_index is not None:
input = torch.gather(input, dim=1, index=word_index.unsqueeze(-1).expand(-1, -1, input.size(-1)))

@@ -50,24 +52,20 @@ def forward(self, input, logits_mask=None, word_index=None,
if labels is not None:
labels = labels.flatten(end_dim=1)[index]

if self.training:
loss = - self.rel_crf.forward(emissions=crf_rel, tags=labels, mask=mask)
loss_output = (s_rel, None, labels)
elif self.eval_transitions:
loss_output = (
s_rel,
(
self.rel_crf.start_transitions,
self.rel_crf.transitions,
self.rel_crf.end_transitions
),
labels
)
else:
loss_output = (s_rel, self.rel_crf.decode(emissions=crf_rel, mask=mask), labels)
loss, decoded = None, None
loss = - self.rel_crf.forward(emissions=crf_rel, tags=labels, mask=mask)

output = ((loss_output,) + hidden_states[1:]) if hidden_states is not None else (loss_output,)
return ((loss,) + output) if loss is not None else output
if not self.training:
decoded = self.rel_crf.decode(emissions=crf_rel, mask=mask)

return SRLResult(
loss=loss, rel_logits=s_rel, decoded=decoded, labels=labels,
transitions=(
self.rel_crf.start_transitions,
self.rel_crf.transitions,
self.rel_crf.end_transitions
)
)


class TransformerBiaffineCRF(BaseModule):
@@ -110,9 +108,7 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
labels=None
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -126,8 +122,9 @@ def forward(
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states
output_attentions=False,
output_hidden_states=False,
return_dict=False,
)
sequence_output = hidden_states[0]
sequence_output = sequence_output[:, 1:-1, :]
@@ -138,6 +135,5 @@ def forward(
logits_mask=logits_mask,
word_index=word_index,
word_attention_mask=word_attention_mask,
labels=labels,
hidden_states=hidden_states
labels=labels
)
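In the reworked BiaffineCRFClassifier, SRLResult.decoded is only filled outside training: training needs just the CRF loss, while evaluation additionally runs CRF decoding. A toy sketch of that contract; the stand-in forward below uses fake tensors and no real CRF, purely to show how the two modes are consumed:

from collections import namedtuple

import torch

# Mirrors the result type introduced in ltp/transformer_biaffine_crf.py.
SRLResult = namedtuple('SRLResult', ['loss', 'rel_logits', 'decoded', 'labels', 'transitions'])


def fake_forward(training: bool) -> SRLResult:
    """Toy stand-in for BiaffineCRFClassifier.forward; shapes and values are made up."""
    loss = torch.tensor(1.23, requires_grad=True)
    decoded = None if training else [[0, 2, 1], [3, 1]]  # CRF decoding yields per-sentence tag lists
    return SRLResult(loss=loss, rel_logits=torch.zeros(2, 3, 4), decoded=decoded,
                     labels=torch.zeros(2, 3, dtype=torch.long), transitions=None)


train_result = fake_forward(training=True)
train_result.loss.backward()             # training_step only needs result.loss

eval_result = fake_forward(training=False)
assert eval_result.decoded is not None   # the validation step reads decoded paths and labels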