Update by commit
threelittlemonkeys committed Apr 9, 2019
1 parent 6673af7 commit ae7ace7
Showing 4 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion parameters.py
@@ -1,6 +1,6 @@
 import torch
 
-UNIT = "word" # unit of tokenization (char, word)
+UNIT = "char" # unit of tokenization (char, word)
 RNN_TYPE = "LSTM" # LSTM or GRU
 NUM_DIRS = 2 # unidirectional: 1, bidirectional: 2
 NUM_LAYERS = 2
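Setting UNIT to "char" switches tokenization from word-level to character-level units, which the word-segmentation and IOB changes below rely on. The sketch below is only an assumption about what that distinction looks like; the repository's actual tokenize() helper is not part of this diff.

# Hypothetical sketch of the char/word distinction; not the repository's tokenize().
def tokenize_sketch(line, unit):
    if unit == "char":
        return [c for c in line if not c.isspace()]  # one token per non-space character
    return line.split(" ")  # one token per space-separated word

print(tokenize_sketch("the cat", "word"))  # ['the', 'cat']
print(tokenize_sketch("the cat", "char"))  # ['t', 'h', 'e', 'c', 'a', 't']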
2 changes: 1 addition & 1 deletion pos-tagging/prepare.py
@@ -32,7 +32,7 @@ def load_data():
         data.append(x + y)
     data.sort(key = lambda x: -len(x))
     fo.close()
-    return data, cti, wti, tti
+    return data, cti, wti, tti
 
 if __name__ == "__main__":
     if len(sys.argv) != 2:
12 changes: 8 additions & 4 deletions predict.py
@@ -18,16 +18,20 @@ def run_model(model, itt, batch):
     xc, xw = batchify(*zip(*[(x[2], x[3]) for x in batch]), sos = False, eos = False)
     result = model.decode(xc, xw)
     for i in range(batch_size):
-        batch[i].append(tuple(itt[j] for j in result[i]))
+        batch[i].append([itt[j] for j in result[i]])
     return [(x[1], x[4], x[5]) for x in sorted(batch[:batch_size])]
 
-def predict(filename, model, cti, wti, itt):
+def predict(filename, model, cti, wti, itt, iob = False):
     data = []
     fo = open(filename)
     for idx, line in enumerate(fo):
         line = line.strip()
-        if re.match("(\S+/\S+( |$))+", line):
-            x, y = zip(*[re.split("/(?=[^/]+$)", x) for x in line.split()])
+        if iob:
+            x, y = tokenize(line, UNIT), []
+            for w in line.split(" "):
+                y.extend(["B"] + ["I"] * (len(w) - 1))
+        elif re.match("(\S+/\S+( |$))+", line):
+            x, y = zip(*[re.split("/(?=[^/]+$)", x) for x in line.split(" ")])
             x = [normalize(x) for x in x]
         else:
             x, y = tokenize(line, UNIT), None
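The new iob flag lets predict() derive gold segmentation labels directly from whitespace: every space-separated word contributes one "B" tag for its first character and "I" tags for the rest. A small self-contained illustration of that expansion (the sample sentence is made up):

# Illustrative only: expands a whitespace-segmented line into B/I tags,
# mirroring the loop added to predict() above.
line = "this is a test"
y = []
for w in line.split(" "):
    y.extend(["B"] + ["I"] * (len(w) - 1))
print(y)  # ['B', 'I', 'I', 'I', 'B', 'I', 'B', 'B', 'I', 'I', 'I']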
11 changes: 5 additions & 6 deletions word-segmentation/prepare.py
@@ -15,17 +15,16 @@ def load_data():
     for line in fo:
         line = line.strip()
         tokens = line.split(" ")
-        seq = []
-        tags = []
+        x = []
+        y = []
         for word in tokens:
             if not KEEP_IDX:
                 for c in word:
                     if c not in cti:
                         cti[c] = len(cti)
-            ctags = ["B" if i == 0 else "I" for i in range(len(word))]
-            seq.extend(["%d:%d" % (cti[c], cti[c]) if c in cti else str(UNK_IDX) for c in word])
-            tags.extend([str(tti[t]) for t in ctags])
-        data.append(seq + tags)
+            x.extend(["%d:%d" % (cti[c], cti[c]) if c in cti else str(UNK_IDX) for c in word])
+            y.extend([str(tti["B"])] + [str(tti["I"])] * (len(word) - 1))
+        data.append(x + y)
     data.sort(key = lambda x: -len(x))
     fo.close()
     return data, cti, tti
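This refactor drops the intermediate ctags list in favor of building the tag sequence directly as ["B"] + ["I"] * (len(word) - 1); for any non-empty word the two constructions are identical, as this quick check (with a made-up example word) shows:

# Sanity check that the old and new tag constructions agree; "hello" is just an example.
word = "hello"
old_tags = ["B" if i == 0 else "I" for i in range(len(word))]
new_tags = ["B"] + ["I"] * (len(word) - 1)
assert old_tags == new_tags == ["B", "I", "I", "I", "I"]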
