Skip to content

Commit

Permalink
Merge pull request jbellis#104 from jbellis/assemble-and-sum
Browse files Browse the repository at this point in the history
Add simd approach for summing the cached PQ products of each encoded vector
  • Loading branch information
tjake authored Sep 28, 2023
2 parents e624317 + fb445bb commit 351b849
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ protected CompressedDecoder(CompressedVectors cv) {
}

protected static abstract class CachingDecoder extends CompressedDecoder {
protected final float[][] partialSums;
protected final float[] partialSums;

protected CachingDecoder(CompressedVectors cv, float[] query, VectorSimilarityFunction vsf) {
super(cv);
Expand All @@ -24,16 +24,17 @@ protected CachingDecoder(CompressedVectors cv, float[] query, VectorSimilarityFu

float[] center = pq.getCenter();
var centeredQuery = center == null ? query : VectorUtil.sub(query, center);
for (var i = 0; i < partialSums.length; i++) {
for (var i = 0; i < pq.getSubspaceCount(); i++) {
int offset = pq.subvectorSizesAndOffsets[i][1];
int baseOffset = i * ProductQuantization.CLUSTERS;
for (var j = 0; j < ProductQuantization.CLUSTERS; j++) {
int offset = pq.subvectorSizesAndOffsets[i][1];
float[] centroidSubvector = pq.codebooks[i][j];
switch (vsf) {
case DOT_PRODUCT:
partialSums[i][j] = VectorUtil.dotProduct(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
partialSums[baseOffset + j] = VectorUtil.dotProduct(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
break;
case EUCLIDEAN:
partialSums[i][j] = VectorUtil.squareDistance(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
partialSums[baseOffset + j] = VectorUtil.squareDistance(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
break;
default:
throw new UnsupportedOperationException("Unsupported similarity function " + vsf);
Expand All @@ -43,14 +44,7 @@ protected CachingDecoder(CompressedVectors cv, float[] query, VectorSimilarityFu
}

protected float decodedSimilarity(byte[] encoded) {
// combining cached fragments is the same for dot product and euclidean; cosine is handled separately
float sum = 0.0f;
for (int m = 0; m < partialSums.length; ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded[m]);
var cachedValue = partialSums[m][centroidIndex];
sum += cachedValue;
}
return sum;
return VectorUtil.assembleAndSum(partialSums, ProductQuantization.CLUSTERS, encoded);
}
}

Expand All @@ -77,8 +71,8 @@ public float similarityTo(int node2) {
}

static class CosineDecoder extends CompressedDecoder {
protected final float[][] partialSums;
protected final float[][] aMagnitude;
protected final float[] partialSums;
protected final float[] aMagnitude;
protected final float bMagnitude;

public CosineDecoder(CompressedVectors cv, float[] query) {
Expand All @@ -95,11 +89,10 @@ public CosineDecoder(CompressedVectors cv, float[] query) {

for (int m = 0; m < pq.getSubspaceCount(); ++m) {
int offset = pq.subvectorSizesAndOffsets[m][1];

for (int j = 0; j < ProductQuantization.CLUSTERS; ++j) {
float[] centroidSubvector = pq.codebooks[m][j];
partialSums[m][j] = VectorUtil.dotProduct(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
aMagnitude[m][j] = VectorUtil.dotProduct(centroidSubvector, 0, centroidSubvector, 0, centroidSubvector.length);
partialSums[(m * ProductQuantization.CLUSTERS) + j] = VectorUtil.dotProduct(centroidSubvector, 0, centeredQuery, offset, centroidSubvector.length);
aMagnitude[(m * ProductQuantization.CLUSTERS) + j] = VectorUtil.dotProduct(centroidSubvector, 0, centroidSubvector, 0, centroidSubvector.length);
}

bMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, pq.subvectorSizesAndOffsets[m][0]);
Expand All @@ -121,8 +114,8 @@ protected float decodedCosine(int node2) {

for (int m = 0; m < partialSums.length; ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded[m]);
sum += partialSums[m][centroidIndex];
aMag += aMagnitude[m][centroidIndex];
sum += partialSums[(m * ProductQuantization.CLUSTERS) + centroidIndex];
aMag += aMagnitude[(m * ProductQuantization.CLUSTERS) + centroidIndex];
}

return (float) (sum / Math.sqrt(aMag * bMagnitude));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;

import java.io.DataOutput;
Expand All @@ -31,24 +32,16 @@
public class CompressedVectors
{
final ProductQuantization pq;
private final List<byte[]> compressedVectors;
private final ThreadLocal<float[][]> partialSums; // for dot product, euclidean, and cosine
private final ThreadLocal<float[][]> partialMagnitudes; // for cosine
private final byte[][] compressedVectors;
private final ThreadLocal<float[]> partialSums; // for dot product, euclidean, and cosine
private final ThreadLocal<float[]> partialMagnitudes; // for cosine

public CompressedVectors(ProductQuantization pq, List<byte[]> compressedVectors)
public CompressedVectors(ProductQuantization pq, byte[][] compressedVectors)
{
this.pq = pq;
this.compressedVectors = compressedVectors;
this.partialSums = ThreadLocal.withInitial(() -> initFloatFragments(pq));
this.partialMagnitudes = ThreadLocal.withInitial(() -> initFloatFragments(pq));
}

private static float[][] initFloatFragments(ProductQuantization pq) {
float[][] a = new float[pq.getSubspaceCount()][];
for (int i = 0; i < a.length; i++) {
a[i] = new float[ProductQuantization.CLUSTERS];
}
return a;
this.partialSums = ThreadLocal.withInitial(() -> new float[pq.getSubspaceCount() * ProductQuantization.CLUSTERS]);
this.partialMagnitudes = ThreadLocal.withInitial(() -> new float[pq.getSubspaceCount() * ProductQuantization.CLUSTERS]);
}

public void write(DataOutput out) throws IOException
Expand All @@ -57,7 +50,7 @@ public void write(DataOutput out) throws IOException
pq.write(out);

// compressed vectors
out.writeInt(compressedVectors.size());
out.writeInt(compressedVectors.length);
out.writeInt(pq.getSubspaceCount());
for (var v : compressedVectors) {
out.write(v);
Expand All @@ -73,13 +66,13 @@ public static CompressedVectors load(RandomAccessReader in, long offset) throws

// read the vectors
int size = in.readInt();
var compressedVectors = new ArrayList<byte[]>(size);
var compressedVectors = new byte[size][];
int compressedDimension = in.readInt();
for (int i = 0; i < size; i++)
{
byte[] vector = new byte[compressedDimension];
in.readFully(vector);
compressedVectors.add(vector);
compressedVectors[i] = vector;
}

return new CompressedVectors(pq, compressedVectors);
Expand All @@ -92,9 +85,9 @@ public boolean equals(Object o) {

CompressedVectors that = (CompressedVectors) o;
if (!Objects.equals(pq, that.pq)) return false;
if (compressedVectors.size() != that.compressedVectors.size()) return false;
return IntStream.range(0, compressedVectors.size()).allMatch((i) -> {
return Arrays.equals(compressedVectors.get(i), that.compressedVectors.get(i));
if (compressedVectors.length != that.compressedVectors.length) return false;
return IntStream.range(0, compressedVectors.length).allMatch((i) -> {
return Arrays.equals(compressedVectors[i], that.compressedVectors[i]);
});
}

Expand All @@ -117,14 +110,21 @@ public NeighborSimilarity.ApproximateScoreFunction approximateScoreFunctionFor(f
}

public byte[] get(int ordinal) {
return compressedVectors.get(ordinal);
return compressedVectors[ordinal];
}

float[][] reusablePartialSums() {
float[] reusablePartialSums() {
return partialSums.get();
}

float[][] reusablePartialMagnitudes() {
float[] reusablePartialMagnitudes() {
return partialMagnitudes.get();
}

/**
 * Returns the approximate heap footprint of this structure: the PQ codebooks
 * plus the per-vector compressed codes.
 *
 * @return estimated size in bytes
 */
public long memorySize() {
    // Codebooks are shared across all vectors.
    long size = pq.memorySize();
    // Guard: indexing compressedVectors[0] would throw if nothing was encoded.
    if (compressedVectors.length == 0) {
        return size;
    }
    // Every compressed vector has the same length (one byte per subspace),
    // so measure one representative array and scale by the vector count.
    long bsize = RamUsageEstimator.sizeOf(compressedVectors[0]);
    return size + (bsize * compressedVectors.length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.github.jbellis.jvector.disk.Io;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.VectorUtil;

import java.io.DataOutput;
Expand Down Expand Up @@ -100,8 +101,8 @@ public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv
/**
* Encodes the given vectors in parallel using the PQ codebooks.
*/
public List<byte[]> encodeAll(List<float[]> vectors) {
return vectors.stream().parallel().map(this::encode).collect(Collectors.toList());
public byte[][] encodeAll(List<float[]> vectors) {
return vectors.stream().parallel().map(this::encode).toArray(byte[][]::new);
}

/**
Expand Down Expand Up @@ -344,4 +345,13 @@ public int hashCode() {
public float[] getCenter() {
return globalCentroid;
}

/**
 * Returns the approximate heap footprint of the PQ codebooks: the sum of the
 * sizes of every centroid sub-vector across all subspaces.
 *
 * @return estimated size in bytes
 */
public long memorySize() {
    long total = 0;
    for (float[][] subspaceCodebook : codebooks) {
        for (float[] centroid : subspaceCodebook) {
            total += RamUsageEstimator.sizeOf(centroid);
        }
    }
    return total;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,14 @@ public float[] sub(float[] lhs, float[] rhs) {
}
return result;
}

@Override
public float assembleAndSum(float[] data, int dataBase, byte[] baseOffsets)
{
    // data is a 2D table flattened row-major with row stride dataBase;
    // for each row i, gather the entry at the unsigned-byte column offset
    // baseOffsets[i] and accumulate it into the sum.
    float total = 0f;
    for (int i = 0; i < baseOffsets.length; i++) {
        int index = dataBase * i + Byte.toUnsignedInt(baseOffsets[i]);
        total += data[index];
    }
    return total;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
package io.github.jbellis.jvector.vector;

/** Default provider returning scalar implementations. */
final class DefaultVectorizationProvider extends VectorizationProvider {
final public class DefaultVectorizationProvider extends VectorizationProvider {

private final VectorUtilSupport vectorUtilSupport;

DefaultVectorizationProvider() {
public DefaultVectorizationProvider() {
vectorUtilSupport = new DefaultVectorUtilSupport();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,7 @@ public static void addInPlace(float[] v1, float[] v2) {
public static float[] sub(float[] lhs, float[] rhs) {
return impl.sub(lhs, rhs);
}
/**
 * Gathers {@code data[dataBase * i + dataOffsets[i]]} (offsets read as
 * unsigned bytes) for each {@code i} and returns the sum; delegates to the
 * active {@code VectorUtilSupport} implementation.
 */
public static float assembleAndSum(float[] data, int dataBase, byte[] dataOffsets) {
    return impl.assembleAndSum(data, dataBase, dataOffsets);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,19 @@ public interface VectorUtilSupport {

/** @return lhs - rhs, element-wise */
public float[] sub(float[] lhs, float[] rhs);

/**
 * Assembles and sums selected entries from a flattened 2D lookup table.
 *
 * The data vector is conceptually a 2D matrix data[n][m] flattened row-major
 * into data[n * m]. For each i, the entry at data[baseIndex * i + baseOffsets[i]]
 * is gathered and added to the running sum: baseIndex is the row stride
 * (multiplied by the loop index, not added once) and baseOffsets[i], read as
 * an unsigned byte, selects the column within row i.
 *
 * @param data the flattened lookup table of all datapoints
 * @param baseIndex the row stride of the flattened table (e.g. the cluster
 *                  count per subspace); pass 0 to treat baseOffsets as
 *                  absolute indexes into data
 * @param baseOffsets unsigned byte column offsets, one per row
 * @return the sum of the gathered entries
 */
public float assembleAndSum(float[] data, int baseIndex, byte[] baseOffsets);
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,8 @@ private static void gridSearch(DataSet ds, List<Integer> mGrid, List<Integer> ef

start = System.nanoTime();
var quantizedVectors = pq.encodeAll(ds.baseVectors);
System.out.format("PQ encode %.2fs,%n", (System.nanoTime() - start) / 1_000_000_000.0);

var compressedVectors = new CompressedVectors(pq, quantizedVectors);
System.out.format("PQ encoded %d[%.2f MB] in %.2fs,%n", ds.baseVectors.size(), (compressedVectors.memorySize()/1024f/1024f) , (System.nanoTime() - start) / 1_000_000_000.0);

var testDirectory = Files.createTempDirectory("BenchGraphDir");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.github.jbellis.jvector.microbench;


import io.github.jbellis.jvector.vector.DefaultVectorizationProvider;
import io.github.jbellis.jvector.vector.VectorUtil;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
Expand All @@ -26,14 +27,19 @@

@Warmup(iterations = 2, time = 5)
@Measurement(iterations = 3, time = 10)
@Fork(warmups = 1, value = 1)
@Fork(warmups = 1, value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector"})
public class SimilarityBench {

static int SIZE = 256;
private static final DefaultVectorizationProvider java = new DefaultVectorizationProvider();

static int SIZE = 1536;
static final float[] q1 = new float[SIZE];
static final float[] q2 = new float[SIZE];

static final float[] q3 = new float[2];
static final float[] q3 = new float[4];

static final byte[] indexes = new byte[384];


static {
for (int i = 0; i < q1.length; i++) {
Expand All @@ -43,6 +49,11 @@ public class SimilarityBench {

q3[0] = ThreadLocalRandom.current().nextFloat();
q3[1] = ThreadLocalRandom.current().nextFloat();

int offsetSize = 4;
for (int i = 0; i < indexes.length; i++) {
indexes[i] = (byte)(i * offsetSize);
}
}

@State(Scope.Benchmark)
Expand All @@ -51,12 +62,28 @@ public static class Parameters {
}

@Benchmark
@BenchmarkMode(Mode.Throughput)
@Threads(8)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
// Measures assembleAndSum through the VectorUtil facade, which dispatches to
// the installed (SIMD-capable, when available) VectorUtilSupport.
public void zipAndSumSimd(Blackhole bh, Parameters p) {
    bh.consume(VectorUtil.assembleAndSum(q1, 0, indexes));
}

@Benchmark
@BenchmarkMode(Mode.Throughput)
@Threads(8)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
// Scalar baseline: calls the DefaultVectorizationProvider directly so the
// plain-Java loop can be compared against the SIMD path above.
public void zzipAndSumJava(Blackhole bh, Parameters p) {
    bh.consume(java.getVectorUtilSupport().assembleAndSum(q1, 0, indexes));
}

/* @Benchmark
@BenchmarkMode(Mode.Throughput)
@Threads(8)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void dotProduct(Blackhole bh, Parameters p) {
bh.consume(VectorUtil.dotProduct(q3, 0, q1, 22, q3.length));
}
}*/

public static void main(String[] args) throws Exception {
org.openjdk.jmh.Main.main(args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,30 @@ public void testSimilarityMetricsByte() {
Assert.assertEquals(a.getVectorUtilSupport().squareDistance(v1, v2), b.getVectorUtilSupport().squareDistance(v1, v2));
}
}

@Test
public void testAssembleAndSum() {
    Assume.assumeTrue(hasSimd);

    VectorizationProvider scalar = new DefaultVectorizationProvider();
    VectorizationProvider simd = VectorizationProvider.getInstance();

    for (int iter = 0; iter < 1000; iter++) {
        float[] source = GraphIndexTestCase.randomVector(getRandom(), 256);

        // Sample every (256/32)th element of source into a 32-element vector,
        // recording each sample's absolute index as a byte offset.
        int stride = 256 / 32;
        float[] sampled = new float[32];
        byte[] offsets = new byte[32];
        for (int c = 0; c < 32; c++) {
            sampled[c] = source[c * stride];
            offsets[c] = (byte) (c * stride);
        }

        // With baseIndex 0 the offsets are absolute, so assembleAndSum over
        // source must equal a plain sum of the sampled elements; both the
        // scalar and SIMD providers must agree.
        Assert.assertEquals(scalar.getVectorUtilSupport().sum(sampled), simd.getVectorUtilSupport().sum(sampled), 0.0001);
        Assert.assertEquals(scalar.getVectorUtilSupport().sum(sampled), scalar.getVectorUtilSupport().assembleAndSum(source, 0, offsets), 0.0001);
        Assert.assertEquals(simd.getVectorUtilSupport().sum(sampled), simd.getVectorUtilSupport().assembleAndSum(source, 0, offsets), 0.0001);
    }
}
}
Loading

0 comments on commit 351b849

Please sign in to comment.