Skip to content

Commit

Permalink
修复依存句法解码和语义依存在batch解码时出现无效值的问题 HIT-SCIR#429
Browse files Browse the repository at this point in the history
  • Loading branch information
AlongWY committed Nov 10, 2020
1 parent aa83d77 commit b998028
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ltp/ltp.py
Original file line number Diff line number Diff line change
@@ -371,6 +371,8 @@ def dep(self, hidden: dict, fast=False):

word_cls_mask = hidden['word_cls_mask']
word_cls_mask = word_cls_mask.unsqueeze(-1).expand(-1, -1, word_cls_mask.size(1))
+ word_cls_mask = word_cls_mask & word_cls_mask.transpose(-1, -2)
+ word_cls_mask[:, :, 0] = True
dep_arc = dep_arc & word_cls_mask
dep_label = get_graph_entities(dep_arc, dep_label, self.dep_vocab)

@@ -412,6 +414,8 @@ def sdp(self, hidden: dict, graph=True):

word_cls_mask = hidden['word_cls_mask']
word_cls_mask = word_cls_mask.unsqueeze(-1).expand(-1, -1, word_cls_mask.size(1))
+ word_cls_mask = word_cls_mask & word_cls_mask.transpose(-1, -2)
+ word_cls_mask[:, :, 0] = True
sdp_arc = sdp_arc & word_cls_mask
sdp_label = get_graph_entities(sdp_arc, sdp_label, self.sdp_vocab)

4 changes: 3 additions & 1 deletion utils/mini_run.py
Original file line number Diff line number Diff line change
@@ -86,7 +86,9 @@ def test(self, sentences: List[str] = None):
self.ltp.add_words("DMI与主机通讯中断")
if sentences is None:
sentences = [
- "他叫汤姆去拿外衣。"
+ "我们都是中国人。",
+ "遇到苦难不要放弃,加油吧!奥利给!",
+ "乔丹是一位出生在纽约的美国职业篮球运动员。"
]
res = self._predict([sentence.strip() for sentence in sentences])
print(json.dumps(res, indent=2, sort_keys=True, ensure_ascii=False))

0 comments on commit b998028

Please sign in to comment.