LLamaEmbedder.cs
using LLama.Native;
using System;
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
    /// <summary>
    /// The embedder for LLama, which supports getting embeddings from text.
    /// </summary>
    public sealed class LLamaEmbedder
        : IDisposable
    {
        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => Context.EmbeddingSize;

        /// <summary>
        /// LLama Context
        /// </summary>
        public LLamaContext Context { get; }

        /// <summary>
        /// Create a new embedder, using the given LLamaWeights
        /// </summary>
        /// <param name="weights">Model weights to create an embedding context from</param>
        /// <param name="params">Context parameters (embedding mode must be enabled)</param>
        /// <param name="logger">Optional logger</param>
        public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
        {
            if (!@params.EmbeddingMode)
                throw new ArgumentException("EmbeddingMode must be true", nameof(@params));

            Context = weights.CreateContext(@params, logger);
        }
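
        // Example usage (a minimal sketch, not compiled as part of this class). It assumes a
        // ModelParams type that implements IContextParams and exposes EmbeddingMode, and the
        // LLamaWeights.LoadFromFile helper; exact names may differ between library versions.
        //
        //     var @params = new ModelParams("model.gguf") { EmbeddingMode = true };
        //     using var weights = LLamaWeights.LoadFromFile(@params);
        //     using var embedder = new LLamaEmbedder(weights, @params);
        //     float[] embedding = await embedder.GetEmbeddings("Hello, world.");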

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">The text to embed</param>
        /// <param name="cancellationToken">Token used to cancel the request</param>
        /// <returns>The embedding vector</returns>
        /// <exception cref="RuntimeError"></exception>
        public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
        {
            return GetEmbeddings(text, true, cancellationToken);
        }

        /// <summary>
        /// Get the embeddings of the text.
        /// </summary>
        /// <param name="text">The text to embed</param>
        /// <param name="addBos">Whether to add the BOS (beginning-of-sequence) token to the text</param>
        /// <param name="cancellationToken">Token used to cancel the request</param>
        /// <returns>The L2-normalized embedding vector</returns>
        /// <exception cref="RuntimeError"></exception>
        public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
        {
            var tokens = Context.Tokenize(text, addBos);
            if (tokens.Length > Context.ContextSize)
                throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));

            // Evaluate the prompt in batch-size chunks
            var n_past = 0;
            var batch = new LLamaBatch();
            var batchSize = (int)Context.Params.BatchSize;
            for (var i = 0; i < tokens.Length; i += batchSize)
            {
                // Number of tokens to evaluate in this chunk
                var n_eval = tokens.Length - i;
                if (n_eval > batchSize)
                    n_eval = batchSize;

                batch.Clear();
                batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true);
                n_past += n_eval;

                var returnCode = await Context.DecodeAsync(batch, cancellationToken);
                if (returnCode != 0)
                    throw new LLamaDecodeError(returnCode);
            }

            var embeddings = GetEmbeddingsArray();

            // Remove everything we just evaluated from the context cache
            Context.NativeHandle.KvCacheClear();

            // Normalize the embeddings vector
            // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
            Normalize(embeddings);

            return embeddings;
        }
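
        // Because GetEmbeddings returns an L2-normalized vector, the cosine similarity of two
        // embeddings reduces to a dot product. A minimal sketch (this helper is illustrative
        // and not part of the class):
        //
        //     static float CosineSimilarity(float[] a, float[] b)
        //     {
        //         var dot = 0f;
        //         for (var i = 0; i < a.Length; i++)
        //             dot += a[i] * b[i];
        //         return dot;
        //     }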

        private float[] GetEmbeddingsArray()
        {
            unsafe
            {
                // Try the token-level embeddings first
                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);

                // If they are not available, fall back to the pooled embeddings for sequence 0
                if (embeddings == null)
                    embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);

                // No embeddings were produced at all
                if (embeddings == null)
                    return Array.Empty<float>();

                // Copy EmbeddingSize floats out of the native buffer
                return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
            }
        }

        private static void Normalize(Span<float> embeddings)
        {
            // Calculate length
            var lengthSqr = 0.0;
            foreach (var value in embeddings)
                lengthSqr += value * value;
            var length = (float)Math.Sqrt(lengthSqr);

            // Do not divide by length if it is zero
            if (length <= float.Epsilon)
                return;

            // Normalize
            for (var i = 0; i < embeddings.Length; i++)
                embeddings[i] /= length;
        }

        /// <inheritdoc />
        public void Dispose()
        {
            Context.Dispose();
        }
    }
}