-
Notifications
You must be signed in to change notification settings - Fork 1
/
nearest.py
76 lines (58 loc) · 1.83 KB
/
nearest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
#
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple tool for inspecting nearest neighbors and analogies."""
from __future__ import print_function
import re
import sys
from getopt import GetoptError, getopt
from vecs import Vecs
def _parse_args(argv):
    """Parse the command-line flags of this tool.

    Args:
      argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
      A ``(vocab_path, embeddings_path)`` tuple. ``-v/--vocab`` defaults to
      ``'vocab.txt'`` and ``-e/--embeddings`` defaults to None. Positional
      arguments are accepted but ignored.

    On a malformed command line (unknown flag, flag missing its value) the
    getopt error is printed to stderr and the process exits with status 2.
    """
    try:
        opts, _ = getopt(argv, 'v:e:', ['vocab=', 'embeddings='])
    except GetoptError as e:
        print(e, file=sys.stderr)
        sys.exit(2)

    vocab_path = 'vocab.txt'
    embeddings_path = None
    for opt, value in opts:
        if opt in ('-v', '--vocab'):
            vocab_path = value
        elif opt in ('-e', '--embeddings'):
            embeddings_path = value
    return vocab_path, embeddings_path


def _answer(vecs, words):
    """Return the ranked neighbors for one parsed query, or None.

    One word asks for that word's nearest neighbors. Three words ``a b c``
    ask for an analogy ("a is to b as c is to ?") and are answered with the
    neighbors of the vector ``c - a + b``. Any other word count, or any word
    for which ``vecs.lookup`` returns None, is reported on stdout and
    yields None.

    Args:
      vecs: the loaded embeddings; must provide ``lookup`` and ``neighbors``.
      words: the whitespace-separated tokens of the query.
    """
    if len(words) not in (1, 3):
        print('use a single word to query neighbors, or three words for analogy')
        return None

    # Check vocabulary membership for every query shape, so that an unknown
    # single word gets the same feedback as an unknown analogy term instead of
    # silently producing no output.
    vectors = [vecs.lookup(w) for w in words]
    missing = [w for w, v in zip(words, vectors) if v is None]
    if missing:
        print('not in vocabulary: %s' % ', '.join(missing))
        return None

    if len(words) == 1:
        return vecs.neighbors(words[0])

    a, b, c = vectors
    return vecs.neighbors(c - a + b)


def _format_results(res, limit=20):
    """Format the first *limit* ``(word, similarity)`` pairs as display lines."""
    return ['%0.4f: %s' % (sim, word) for word, sim in res[:limit]]


def main(argv=None):
    """Load the embeddings and answer neighbor/analogy queries from stdin.

    Args:
      argv: flag list to parse; defaults to ``sys.argv[1:]``.

    The session ends on an empty line, at end of input, or on Ctrl-C at the
    prompt.
    """
    vocab_path, embeddings_path = _parse_args(
        sys.argv[1:] if argv is None else argv)
    vecs = Vecs(vocab_path, embeddings_path)

    while True:
        sys.stdout.write('query> ')
        sys.stdout.flush()

        try:
            query = sys.stdin.readline().strip()
        except KeyboardInterrupt:
            # Treat Ctrl-C at the prompt as a request to quit, not a crash.
            print()
            break
        # readline() returns '' at EOF; a blank line also ends the session.
        if not query:
            break

        res = _answer(vecs, re.split(r'\s+', query))
        if not res:
            continue

        for line in _format_results(res):
            print(line)
        print()


if __name__ == '__main__':
    main()