Skip to content

Commit

Permalink
throw exception if metric is broken
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Jul 1, 2017
1 parent 9243892 commit c65f1d5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
35 changes: 16 additions & 19 deletions src/annoymodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,23 @@ typedef struct {
static PyObject *
py_an_new(PyTypeObject *type, PyObject *args, PyObject *kwds) {
py_annoy *self;

self = (py_annoy *)type->tp_alloc(type, 0);
if (self != NULL) {
self->f = 0;
self->ptr = NULL;
if (self == NULL) {
return NULL;
}
const char *metric = NULL;

if (!PyArg_ParseTuple(args, "i|s", &self->f, &metric))
return NULL;
if (!metric || !strcmp(metric, "angular")) {
self->ptr = new AnnoyIndex<int32_t, float, Angular, Kiss64Random>(self->f);
} else if (!strcmp(metric, "euclidean")) {
self->ptr = new AnnoyIndex<int32_t, float, Euclidean, Kiss64Random>(self->f);
} else if (!strcmp(metric, "manhattan")) {
self->ptr = new AnnoyIndex<int32_t, float, Manhattan, Kiss64Random>(self->f);
} else {
PyErr_SetString(PyExc_ValueError, "No such metric");
return NULL;
}

return (PyObject *)self;
Expand All @@ -59,21 +71,6 @@ py_an_new(PyTypeObject *type, PyObject *args, PyObject *kwds) {

static int
py_an_init(py_annoy *self, PyObject *args, PyObject *kwds) {
const char *metric;

if (!PyArg_ParseTuple(args, "is", &self->f, &metric))
return -1;
switch(metric[0]) {
case 'a':
self->ptr = new AnnoyIndex<int32_t, float, Angular, Kiss64Random>(self->f);
break;
case 'e':
self->ptr = new AnnoyIndex<int32_t, float, Euclidean, Kiss64Random>(self->f);
break;
case 'm':
self->ptr = new AnnoyIndex<int32_t, float, Manhattan, Kiss64Random>(self->f);
break;
}
return 0;
}

Expand Down
3 changes: 3 additions & 0 deletions test/annoy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,9 @@ def test_seed(self):
i.load('test/test.tree')
i.set_seed(42)

def test_unknown_distance(self):
self.assertRaises(Exception, AnnoyIndex, 10, 'banana')


class TypesTest(TestCase):
def test_numpy(self, n_points=1000, n_trees=10):
Expand Down

0 comments on commit c65f1d5

Please sign in to comment.