forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathann_test.py
231 lines (204 loc) · 8.65 KB
/
ann_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.util import prod
from jax.config import config
# Executed at import time. Presumably feeds absl-parsed command-line flags into
# jax.config so flag-controlled test settings apply — TODO confirm against the
# jax.config documentation.
config.parse_flags_with_absl()
# `ignore_jit_of_pmap_warning(...)` is `jtu.ignore_warning(...)` with the
# `message` keyword pre-bound to a regex matching "jit-of-pmap" warnings;
# presumably applied as a decorator on tests that jit a pmapped function.
# NOTE(review): this name is not referenced anywhere in the visible file;
# either apply it where needed or remove it.
ignore_jit_of_pmap_warning = partial(
jtu.ignore_warning,message=".*jit-of-pmap.*")
def compute_recall(result_neighbors, ground_truth_neighbors) -> float:
  """Computes the recall of an approximate nearest neighbor search.

  For each query, the number of hits is the size of the intersection between
  the set of returned neighbor indices and the set of ground-truth neighbor
  indices. The hits are summed over all queries and divided by the total
  number of ground-truth entries. Treating each result row as a set means a
  neighbor index returned more than once for the same query is only counted
  once, so duplicates cannot inflate the recall (or push it above 1.0).

  Args:
    result_neighbors: int32 numpy array of shape [num_queries,
      neighbors_per_query] whose values are indices into the dataset.
    ground_truth_neighbors: int32 numpy array of shape [num_queries,
      ground_truth_neighbors_per_query] whose values are indices into the
      dataset.

  Returns:
    The recall, a float in [0, 1].

  Raises:
    AssertionError: if either input is not 2-D, or the two inputs disagree on
      num_queries.
    ZeroDivisionError: if ground_truth_neighbors has no elements.
  """
  assert len(
      result_neighbors.shape) == 2, "shape = [num_queries, neighbors_per_query]"
  assert len(ground_truth_neighbors.shape
            ) == 2, "shape = [num_queries, ground_truth_neighbors_per_query]"
  assert result_neighbors.shape[0] == ground_truth_neighbors.shape[0]
  # `.tolist()` converts numpy scalars to plain Python ints, so rows of
  # different integer dtypes (e.g. int32 vs int64) still compare equal.
  hits = sum(
      len(set(np.asarray(result_row).tolist())
          & set(np.asarray(gt_row).tolist()))
      for result_row, gt_row in zip(result_neighbors, ground_truth_neighbors))
  return hits / ground_truth_neighbors.size
class AnnTest(jtu.JaxTestCase):
  """Recall tests for `lax.approx_max_k` and `lax.approx_min_k`.

  Each test builds an exact ground truth with `lax.top_k` and checks that the
  approximate result, scored by `compute_recall`, exceeds the requested
  recall target.
  """

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall }
        for qy_shape in [(200, 128), (128, 128)]
        for db_shape in [(128, 500), (128, 3000)]
        for dtype in jtu.dtypes.all_floating
        for k in [1, 10, 50] for recall in [0.9, 0.95]))
  def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
    """approx_max_k over `qy @ db` beats `recall` against exact top_k."""
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    _, gt_args = lax.top_k(scores, k)
    _, ann_args = lax.approx_max_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall }
        for qy_shape in [(200, 128), (128, 128)]
        for db_shape in [(128, 500), (128, 3000)]
        for dtype in jtu.dtypes.all_floating
        for k in [1, 10, 50] for recall in [0.9, 0.95]))
  def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
    """approx_min_k over `qy @ db` beats `recall` against exact min-k."""
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    scores = lax.dot(qy, db)
    # Exact min-k is the top-k of the negated scores.
    _, gt_args = lax.top_k(-scores, k)
    _, ann_args = lax.approx_min_k(scores, k, recall_target=recall)
    self.assertEqual(k, len(ann_args[0]))
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shape={}_k={}_max_k={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), k, is_max_k),
          "shape": shape, "dtype": dtype, "k": k, "is_max_k": is_max_k }
        for dtype in [np.float32]
        for shape in [(4,), (5, 5), (2, 1, 4)]
        for k in [1, 3]
        for is_max_k in [True, False]))
  def test_autodiff(self, shape, dtype, k, is_max_k):
    """Checks 2nd-order fwd/rev gradients of the values returned by approx k."""
    # A shuffled arange has no ties, so the selected elements are unambiguous
    # and the finite-difference gradients are well defined.
    vals = np.arange(prod(shape), dtype=dtype)
    vals = self.rng().permutation(vals).reshape(shape)
    if is_max_k:
      fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
    else:
      fn = lambda vs: lax.approx_min_k(vs, k=k)[0]
    jtu.check_grads(fn, (vals,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_qy={}_db={}_k={}_recall={}".format(
                  jtu.format_shape_dtype_string(qy_shape, dtype),
                  jtu.format_shape_dtype_string(db_shape, dtype), k, recall),
          "qy_shape": qy_shape, "db_shape": db_shape, "dtype": dtype,
          "k": k, "recall": recall }
        for qy_shape in [(200, 128), (128, 128)]
        for db_shape in [(2048, 128)]
        for dtype in jtu.dtypes.all_floating
        for k in [1, 10] for recall in [0.9, 0.95]))
  def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
    """Sharded approx_min_k under pmap, merged afterwards, beats `recall`.

    The database rows are split evenly across `jax.device_count()` devices.
    Each device returns unaggregated candidates (aggregate_to_topk=False),
    shifted back to global row indices by the shard offset. All candidates are
    then sorted together and truncated to k.
    """
    num_devices = jax.device_count()
    db_size, feature_dim = db_shape
    if db_size % num_devices:
      # The row-wise reshape below needs equal-sized shards.
      self.skipTest("db_size={} is not divisible by {} devices".format(
          db_size, num_devices))
    rng = jtu.rand_default(self.rng())
    qy = rng(qy_shape, dtype)
    db = rng(db_shape, dtype)
    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
    _, gt_args = lax.top_k(-gt_scores, k)  # negate the score to get min-k
    db_per_device = db_size // num_devices
    # Use the real feature dimension rather than a hard-coded constant so the
    # test stays valid if db_shape changes.
    sharded_db = db.reshape(num_devices, db_per_device, feature_dim)
    db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device

    def parallel_topk(qy, db, db_offset):
      scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
      ann_vals, ann_args = lax.approx_min_k(
          scores,
          k=k,
          reduction_dimension=1,
          recall_target=recall,
          # Size the approximation for the full (unsharded) database.
          reduction_input_size_override=db_size,
          aggregate_to_topk=False)
      return (ann_vals, ann_args + db_offset)

    # shape = qy_size, num_devices, approx_dp
    ann_vals, ann_args = jax.pmap(
        parallel_topk,
        in_axes=(None, 0, 0),
        out_axes=(1, 1))(qy, sharded_db, db_offsets)
    # collapse num_devices and approx_dp
    ann_vals = lax.collapse(ann_vals, 1, 3)
    ann_args = lax.collapse(ann_args, 1, 3)
    _, ann_args = lax.sort_key_val(ann_vals, ann_args, dimension=1)
    ann_args = lax.slice_in_dim(ann_args, start_index=0, limit_index=k, axis=1)
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  def test_vmap_before(self):
    """approx_max_k under vmap with the batch axis leading (axis 0)."""
    batch = 4
    qy_size = 128
    db_size = 1024
    feature_dim = 32
    k = 10
    rng = jtu.rand_default(self.rng())
    qy = rng([batch, qy_size, feature_dim], np.float32)
    db = rng([batch, db_size, feature_dim], np.float32)
    recall = 0.95
    # Create ground truth
    gt_scores = lax.dot_general(qy, db, (([2], [2]), ([0], [0])))
    _, gt_args = lax.top_k(gt_scores, k)
    gt_args = lax.reshape(gt_args, [qy_size * batch, k])

    # test target
    def approx_max_k(qy, db):
      scores = qy @ db.transpose()
      return lax.approx_max_k(scores, k, recall_target=recall)

    _, ann_args = jax.vmap(approx_max_k, (0, 0))(qy, db)
    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)

  def test_vmap_after(self):
    """approx_max_k under vmap with the batch axis trailing (axis 2)."""
    batch = 4
    qy_size = 128
    db_size = 1024
    feature_dim = 32
    k = 10
    rng = jtu.rand_default(self.rng())
    qy = rng([qy_size, feature_dim, batch], np.float32)
    db = rng([db_size, feature_dim, batch], np.float32)
    recall = 0.95
    # Create ground truth
    gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
    _, gt_args = lax.top_k(gt_scores, k)
    gt_args = lax.transpose(gt_args, [2, 0, 1])
    gt_args = lax.reshape(gt_args, [qy_size * batch, k])

    # test target
    def approx_max_k(qy, db):
      scores = qy @ db.transpose()
      return lax.approx_max_k(scores, k, recall_target=recall)

    _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
    ann_args = lax.transpose(ann_args, [2, 0, 1])
    ann_args = lax.reshape(ann_args, [qy_size * batch, k])
    ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
    self.assertGreater(ann_recall, recall)
if __name__ == "__main__":
  # Hand absltest the jtu.JaxTestLoader instance used to collect the tests.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)