Skip to content

Commit

Permalink
Replace vector_l2sq_dist with overloaded l2sq_dist (lanterndata#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
ezra-varady authored and Ngalstyan4 committed Sep 25, 2023
1 parent d511abf commit d3ba5b5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
16 changes: 3 additions & 13 deletions sql/lantern.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ BEGIN

IF pgvector_exists THEN
-- taken from pgvector so our index can work with pgvector types
CREATE FUNCTION vector_l2sq_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
CREATE FUNCTION l2sq_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME', 'vector_l2sq_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE OPERATOR CLASS dist_vec_l2sq_ops
DEFAULT FOR TYPE vector USING lantern_hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2sq_dist(vector, vector);
FUNCTION 1 l2sq_dist(vector, vector);
END IF;


Expand All @@ -46,16 +46,6 @@ BEGIN
-- create access method
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnsw_handler;
COMMENT ON ACCESS METHOD hnsw IS 'LanternDB access method for vector embeddings, based on the hnsw algorithm';
IF pgvector_exists THEN
-- An older version of pgvector exists, which does not have hnsw yet
-- So, there is no naming conflict. We still add a compatibility operator class
-- so pgvector vector type can be used with our index
-- taken from pgvector so our index can work with pgvector types
CREATE OPERATOR CLASS dist_vec_l2sq_ops
DEFAULT FOR TYPE vector USING hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 vector_l2sq_dist(vector, vector);
END IF;
END IF;
END;
$BODY$
Expand Down
15 changes: 15 additions & 0 deletions sql/updates/0.0.4-latest.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- drop the old implementation
DROP FUNCTION IF EXISTS vector_l2sq_dist(vector, vector);

-- replace is with overloaded version
-- Check if the vector type from pgvector exists
SELECT EXISTS (
SELECT 1
FROM pg_type
WHERE typname = 'vector'
) INTO pgvector_exists;

IF pgvector_exists THEN
CREATE FUNCTION l2sq_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME', 'vector_l2sq_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
END IF
6 changes: 3 additions & 3 deletions test/expected/hnsw_vector.out
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;
Order By: (v <-> '[0,1,0]'::vector)
(3 rows)

SELECT ROUND(vector_l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;
dist
------
Expand All @@ -111,7 +111,7 @@ FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;
2.00
(7 rows)

EXPLAIN (COSTS FALSE) SELECT ROUND(vector_l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
EXPLAIN (COSTS FALSE) SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;
QUERY PLAN
---------------------------------------------------------
Expand Down Expand Up @@ -171,5 +171,5 @@ ERROR: Operator <-> has no standalone meaning and is reserved for use in vector
-- Expect error due to mismatching vector dimensions
SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1;
ERROR: Expected vector with dimension 3, got 4
SELECT vector_l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);
SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);
ERROR: expected equally sized vectors but got vectors with dimensions 2 and 3
6 changes: 3 additions & 3 deletions test/sql/hnsw_vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;
EXPLAIN (COSTS FALSE) SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;

SELECT ROUND(vector_l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;
EXPLAIN (COSTS FALSE) SELECT ROUND(vector_l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
EXPLAIN (COSTS FALSE) SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7;

-- Verify that index creation on a large vector produces an error
Expand Down Expand Up @@ -77,4 +77,4 @@ SELECT ARRAY[1,2,3] <-> ARRAY[3,2,1];

-- Expect error due to mismatching vector dimensions
SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1;
SELECT vector_l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);
SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);

0 comments on commit d3ba5b5

Please sign in to comment.