// LLamaWeights.cs

using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama
{
    /// <summary>
    /// A set of model weights, loaded into memory.
    /// </summary>
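    /// <example>
    /// A minimal lifetime sketch (the model path is illustrative). The weights should
    /// generally outlive any contexts created from them:
    /// <code>
    /// var parameters = new ModelParams("path/to/model.gguf");
    /// using var weights = LLamaWeights.LoadFromFile(parameters);
    /// using var context = weights.CreateContext(parameters);
    /// // ... run inference with the context ...
    /// </code>
    /// </example>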
    public sealed class LLamaWeights
        : IDisposable
    {
        /// <summary>
        /// The native handle, which is used in the native APIs
        /// </summary>
        /// <remarks>Be careful how you use this!</remarks>
        public SafeLlamaModelHandle NativeHandle { get; }

        /// <summary>
        /// Total number of tokens in the vocabulary of this model
        /// </summary>
        public int VocabCount => NativeHandle.VocabCount;

        /// <summary>
        /// Total number of tokens in the context this model was trained on
        /// </summary>
        public int ContextSize => NativeHandle.ContextSize;

        /// <summary>
        /// Get the size of this model in bytes
        /// </summary>
        public ulong SizeInBytes => NativeHandle.SizeInBytes;

        /// <summary>
        /// Get the number of parameters in this model
        /// </summary>
        public ulong ParameterCount => NativeHandle.ParameterCount;

        /// <summary>
        /// Dimension of the embedding vectors produced by this model
        /// </summary>
        public int EmbeddingSize => NativeHandle.EmbeddingSize;

        /// <summary>
        /// Get the special tokens of this model
        /// </summary>
        public SafeLlamaModelHandle.ModelTokens Tokens => NativeHandle.Tokens;

        /// <summary>
        /// All metadata (key/value pairs) in this model
        /// </summary>
        public IReadOnlyDictionary<string, string> Metadata { get; set; }

        private LLamaWeights(SafeLlamaModelHandle weights)
        {
            NativeHandle = weights;
            Metadata = weights.ReadMetadata();
        }

        /// <summary>
        /// Load weights into memory
        /// </summary>
        /// <param name="params">Parameters to use to load the model</param>
        /// <returns>The loaded model weights</returns>
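        /// <example>
        /// A minimal sketch (the model and adapter paths are illustrative, and the
        /// <c>LoraAdapter</c> construction assumes a (path, scale) shape):
        /// <code>
        /// var parameters = new ModelParams("path/to/model.gguf");
        /// parameters.LoraAdapters.Add(new LoraAdapter("path/to/adapter.bin", 1.0f));
        /// using var weights = LLamaWeights.LoadFromFile(parameters);
        /// </code>
        /// </example>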
        public static LLamaWeights LoadFromFile(IModelParams @params)
        {
            using var pin = @params.ToLlamaModelParams(out var lparams);
            var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

            foreach (var adapter in @params.LoraAdapters)
            {
                // Skip invalid adapters (no path, or a scale of zero or less)
                if (string.IsNullOrEmpty(adapter.Path))
                    continue;
                if (adapter.Scale <= 0)
                    continue;

                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
            }

            return new LLamaWeights(weights);
        }

        /// <summary>
        /// Load weights into memory
        /// </summary>
        /// <param name="params">Parameters to use to load the model</param>
        /// <param name="token">A cancellation token that can interrupt model loading</param>
        /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
        /// <returns>The loaded model weights</returns>
        /// <exception cref="LoadWeightsFailedException">Thrown if the weights failed to load for any reason, e.g. an invalid file format or loading was cancelled.</exception>
        /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
        {
            // don't touch the @params object inside the task, it might be changed
            // externally! Save a copy of everything that we need later.
            var modelPath = @params.ModelPath;
            var loraBase = @params.LoraBase;
            var loraAdapters = @params.LoraAdapters.ToArray();

            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
            // slightly smaller range to allow some space for reporting LoRA loading too.
            var modelLoadProgressRange = 1f;
            if (loraAdapters.Length > 0)
                modelLoadProgressRange = 0.9f;

            using (@params.ToLlamaModelParams(out var lparams))
            {
#if !NETSTANDARD2_0
                // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
                if (token.CanBeCanceled || progressReporter != null)
                {
                    var internalCallback = lparams.progress_callback;
                    lparams.progress_callback = (progress, ctx) =>
                    {
                        // Update the progress reporter (remapping the value into the smaller range).
                        progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);

                        // If the user set a callback in the model params, call that and see if we should cancel
                        if (internalCallback != null && !internalCallback(progress, ctx))
                            return false;

                        // Check the cancellation token
                        if (token.IsCancellationRequested)
                            return false;

                        return true;
                    };
                }
#endif

                var model = await Task.Run(() =>
                {
                    try
                    {
                        // Load the model
                        var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);

                        // Apply the LoRA adapters
                        for (var i = 0; i < loraAdapters.Length; i++)
                        {
                            // Interrupt applying LoRAs if the token is cancelled
                            if (token.IsCancellationRequested)
                            {
                                weights.Dispose();
                                token.ThrowIfCancellationRequested();
                            }

                            // Don't apply invalid adapters
                            var adapter = loraAdapters[i];
                            if (string.IsNullOrEmpty(adapter.Path))
                                continue;
                            if (adapter.Scale <= 0)
                                continue;

                            weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);

                            // Report progress. Model loading reported progress from 0 -> 0.9, use
                            // the last 0.1 to represent all of the LoRA adapters being applied.
                            progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
                        }

                        // Update progress reporter to indicate completion
                        progressReporter?.Report(1);

                        return new LLamaWeights(weights);
                    }
                    catch (LoadWeightsFailedException)
                    {
                        // Convert a LoadWeightsFailedException into a cancellation exception if possible.
                        token.ThrowIfCancellationRequested();

                        // Ok the weights failed to load for some reason other than cancellation.
                        throw;
                    }
                }, token);

                return model;
            }
        }

        /// <inheritdoc />
        public void Dispose()
        {
            NativeHandle.Dispose();
        }

        /// <summary>
        /// Create a llama_context using this model
        /// </summary>
        /// <param name="params">Parameters to use to create the context</param>
        /// <param name="logger">An optional logger for the context to use</param>
        /// <returns>The newly created context</returns>
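        /// <example>
        /// A minimal sketch, reusing a <c>ModelParams</c> instance as the context
        /// parameters (the path is illustrative):
        /// <code>
        /// var parameters = new ModelParams("path/to/model.gguf");
        /// using var weights = LLamaWeights.LoadFromFile(parameters);
        /// using var context = weights.CreateContext(parameters);
        /// </code>
        /// </example>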
        public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null)
        {
            return new LLamaContext(this, @params, logger);
        }

        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text">The text to tokenize</param>
        /// <param name="add_bos">Whether to prepend the BOS (beginning of sequence) token</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <param name="encoding">The encoding to use when converting the text into bytes</param>
        /// <returns>The tokens representing the text</returns>
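        /// <example>
        /// A minimal sketch:
        /// <code>
        /// var tokens = weights.Tokenize("Hello, world!", add_bos: true, special: false, Encoding.UTF8);
        /// </code>
        /// </example>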
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            return NativeHandle.Tokenize(text, add_bos, special, encoding);
        }
    }
}