Skip to content

Commit

Permalink
🎨 保留句子当中的任何空格
Browse files — browse the repository at this point in the history
  • Loading branch information
AlongWY committed Sep 11, 2022
1 parent 7846ee1 commit 6ff5a57
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions python/interface/ltp/nerual.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -242,21 +242,22 @@ def _cws_post(
logits = result.logits
attention_mask = result.attention_mask

text = []
char_idx = []
for raw_text, encodings in zip(inputs, tokenized.encodings):
char_pos = []
for sentence, encodings in zip(inputs, tokenized.encodings):
last = None
text.append([])
char_idx.append([])
for idx, current in enumerate(encodings.offsets[1:-1]):
if current == (0, 0):
char_pos.append([])
for idx, (start, end) in enumerate(encodings.offsets[1:-1]):
if start == 0 and end == 0:
break
elif current[0] == current[1]:
elif start == end:
continue
elif current != last:
text[-1].append(raw_text[current[0] : current[1]])
elif (start, end) != last:
char_idx[-1].append(idx)
last = current
char_pos[-1].append(start)
last = (start, end)
char_pos[-1].append(len(sentence))

if crf is None:
decoded = logits.argmax(dim=-1)
Expand All @@ -272,28 +273,27 @@ def _cws_post(
decoded = crf.decode(logits, attention_mask)
decoded = [[self.cws_vocab[tag] for tag in tags] for tags in decoded]
entities = [get_entities([d[i] for i in idx]) for d, idx in zip(decoded, char_idx)]
entities = [[(e[1], e[2]) for e in se] for se in entities]
# t: tag, s: start, e: end
entities = [[(s, e) for (t, s, e) in tse] for tse in entities]

words = [
["".join(sent[e[0] : e[1] + 1]) for e in sent_entities]
for sent, sent_entities in zip(text, entities)
[sent[pos[s] : pos[e + 1]] for s, e in sent_entities]
for sent, pos, sent_entities in zip(inputs, char_pos, entities)
]

if len(self.hook):
words = [self.hook.hook("".join(t), w) for t, w in zip(text, words)]
entities = []

char_len_cumsum = [np.cumsum([len(c) for c in s]) for s in text]
words = [self.hook.hook(t, w) for t, w in zip(inputs, words)]
words_len_cumsum = [np.cumsum([len(w) for w in s]) for s in words]

for char_end, word_end in zip(char_len_cumsum, words_len_cumsum):
entities = []
for char_end, word_end in zip(char_pos, words_len_cumsum):
entities.append([])
char_index = {cl: idx for idx, cl in enumerate(char_end)}
length2index = {cl: idx for idx, cl in enumerate(char_end[1:])}
for i, e in enumerate(word_end):
if i == 0:
entities[-1].append((0, char_index[e]))
entities[-1].append((0, length2index[e]))
else:
entities[-1].append((char_index[word_end[i - 1]] + 1, char_index[e]))
entities[-1].append((length2index[word_end[i - 1]] + 1, length2index[e]))

words_idx = torch.nn.utils.rnn.pad_sequence(
[
Expand Down

0 comments on commit 6ff5a57

Please sign in to comment.