
Commit

Fix issues caused by the model-structure update in huggingface/transformers
AlongWY committed Sep 3, 2020
1 parent 82059cf commit 8f1335f
Showing 4 changed files with 22 additions and 4 deletions.
ltp/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*_
 # Author: Yunlong Feng <[email protected]>
-__version__ = '4.0.8'
+__version__ = '4.0.9'
 
 from .core import Registrable
 from .data import Dataset
ltp/ltp.py (6 changes: 5 additions & 1 deletion)
@@ -5,11 +5,15 @@
 import torch
 import itertools
 import regex as re
+import transformers
 from typing import Union, List
 
+from packaging import version
 from transformers import AutoTokenizer, cached_path, TensorType, BatchEncoding
 from transformers.file_utils import is_remote_url
 
+transformers_version = version.parse(transformers.__version__)
+
 from ltp.models import Model
 from ltp.utils import length_to_mask, eisner, split_sentence
 from ltp.utils import get_entities
@@ -118,7 +122,7 @@ def __init__(self, path: str = 'small', device=None, **kwargs):
         ckpt['model_config']['init'].pop('pretrained')
         self.cache_dir = path
         self.model = Model.from_params(ckpt['model_config'], config=ckpt['pretrained_config']).to(self.device)
-        self.model.load_state_dict(ckpt['model'])
+        self.model.load_state_dict(ckpt['model'], strict=transformers_version < version.parse("3.1.0"))
         self.model.eval()
         # todo fp16
         self.max_length = self.model.pretrained.config.max_position_embeddings
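
The strict flag above is the core of the fix. Transformers 3.1.0 changed the serialized model structure (apparently adding state-dict keys such as the position_ids buffer on BERT-style embeddings), so checkpoints exported against earlier releases no longer pass strict key checking. A minimal standalone sketch of the same guard, with a hypothetical helper name and the 3.1.0 cause stated as an assumption:

import torch
import transformers
from packaging import version

def load_ltp_checkpoint(model: torch.nn.Module, state_dict: dict) -> None:
    # Assumption: transformers >= 3.1.0 adds keys (registered buffers) that
    # older checkpoints lack, so strict checking must be relaxed there.
    strict = version.parse(transformers.__version__) < version.parse("3.1.0")
    result = model.load_state_dict(state_dict, strict=strict)
    if not strict and result.missing_keys:
        # With strict=False the new buffers keep their freshly initialized
        # values rather than load_state_dict raising a RuntimeError.
        print("ignoring missing keys:", result.missing_keys)
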
ltp/nn/bilinear.py (14 changes: 14 additions & 0 deletions)
@@ -28,14 +28,28 @@ def reset_parameters(self):
         bound = 1 / math.sqrt(self.weight.size(1))
         nn.init.uniform_(self.weight, -bound, bound)
 
+    def onnx_forward(self, x1: Tensor, x2: Tensor):
+        if self.bias_x:
+            x1 = torch.cat((x1, torch.ones_like(x1[..., :1])), -1)
+        if self.bias_y:
+            x2 = torch.cat((x2, torch.ones_like(x2[..., :1])), -1)
+        x1 = x1.unsqueeze(1)
+        x2 = x2.unsqueeze(1)
+        s: Tensor = x1 @ self.weight @ x2.transpose(-1, -2)
+        if s.size(1) == 1:
+            s = s.squeeze(1)
+        return s
+
     def forward(self, x1: Tensor, x2: Tensor):
+        res = self.onnx_forward(x1, x2)
         if self.bias_x:
             x1 = torch.cat((x1, torch.ones_like(x1[..., :1])), -1)
         if self.bias_y:
             x2 = torch.cat((x2, torch.ones_like(x2[..., :1])), -1)
         if self.expand:
             # [batch_size, n_out, seq_len, seq_len]
             s = torch.einsum('bxi,oij,byj->boxy', x1, self.weight, x2)
+            test = torch.sum(s - res)
             return s
         # [batch_size, n_out, seq_len]
         return F.bilinear(x1, x2, self.weight, None)
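
The new onnx_forward computes the same bilinear scores as the einsum in forward, but via broadcasted matrix products that the ONNX exporter of the time could trace (torch.einsum was not exportable then, which is the apparent motivation); the added res and test lines read as a temporary equivalence check between the two paths. A self-contained sketch of that equivalence, with shapes assumed from the diff (weight is [n_out, d1, d2] after the optional bias columns):

import torch

batch, seq, dim, n_out = 2, 6, 4, 3
# Inputs as they look after the bias-ones concatenation in both methods.
x1 = torch.randn(batch, seq, dim + 1)
x2 = torch.randn(batch, seq, dim + 1)
weight = torch.randn(n_out, dim + 1, dim + 1)

# forward's einsum path: [batch, n_out, seq, seq].
einsum_s = torch.einsum('bxi,oij,byj->boxy', x1, weight, x2)
# onnx_forward's path: [b, 1, s, d] @ [o, d, d] @ [b, 1, d, s] broadcasts
# to the same [batch, n_out, seq, seq] tensor.
matmul_s = x1.unsqueeze(1) @ weight @ x2.unsqueeze(1).transpose(-1, -2)
assert torch.allclose(einsum_s, matmul_s, atol=1e-5)
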
setup.py (4 changes: 2 additions & 2 deletions)
@@ -22,11 +22,11 @@
     install_requires=[
         'torch>=1.2.0',
         'torchtext==0.5.0',
-        'transformers>=3.0',
+        'transformers>=3.0, <3.2.*',
         'pygtrie==2.3.3',
         'tqdm',
         'toml',
-        'fire',
+        'fire'
     ],
     classifiers=[
         'Development Status :: 1 - Planning',
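
The pin complements the runtime guard in ltp/ltp.py by keeping installations out of the 3.2 series. Note that PEP 440 defines the ".*" wildcard only for == and !=, so '<3.2.*' relies on lenient parsing by pip/setuptools; a sketch (not from the repo) probing the apparently intended range with the strictly conformant '<3.2':

from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=3.0,<3.2")
for candidate in ["2.11.0", "3.0.2", "3.1.0", "3.2.0"]:
    print(candidate, candidate in spec)
# Prints: 2.11.0 False / 3.0.2 True / 3.1.0 True / 3.2.0 False
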
