-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathAppData.cs
282 lines (251 loc) · 11.1 KB
/
AppData.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
using System.Data;
using AiServer.ServiceInterface.Generation;
using Microsoft.Extensions.Logging;
using ServiceStack;
using ServiceStack.OrmLite;
using AiServer.ServiceModel;
using AiServer.ServiceModel.Types;
using Microsoft.Extensions.Hosting;
using ServiceStack.Model;
namespace AiServer.ServiceInterface;
/// <summary>
/// Application-wide catalog of model metadata and configured AI/media providers.
/// Model catalogs are read from JSON files under <c>wwwroot/lib/data</c> and may be
/// overridden by same-named files under <c>App_Data/overrides</c>; providers are read
/// from the database. Everything is (re)loaded by <see cref="Reload"/>.
/// </summary>
public class AppData(ILogger<AppData> log,
    AiProviderFactory aiFactory,
    MediaProviderFactory mediaProviderFactory,
    IHostEnvironment env)
{
    /// <summary>
    /// Shared instance. Not assigned anywhere in this class — presumably set during
    /// application startup (TODO confirm).
    /// </summary>
    public static AppData Instance { get; set; }

    // OpenAI/standard-specific properties
    public PocoDataSource<AiModel> AiModels { get; set; } = new([]);
    public PocoDataSource<AiType> AiTypes { get; set; } = new([]);
    public AiProvider[] AiProviders { get; set; } = [];
    public MediaProvider[] MediaProviders { get; set; } = [];
    public PocoDataSource<MediaType> MediaTypes { get; set; } = new([]);
    public PocoDataSource<TextToSpeechVoice> TextToSpeechVoices { get; set; } = new([]);
    public MediaModel[] MediaModels { get; set; } = [];
    public Dictionary<string, MediaModel> MediaModelsMap { get; set; } = [];
    public PocoDataSource<Prompt> Prompts { get; set; } = new([]);

    // Shared properties

    // NOTE(review): cts is never assigned within this class and StoppedAt has a private
    // setter that is never called here, so Token is always CancellationToken.None and
    // IsStopped is always false unless this is wired up elsewhere in the future.
    private CancellationTokenSource? cts;
    public CancellationToken Token => cts?.Token ?? CancellationToken.None;
    public DateTime? StoppedAt { get; private set; }
    public bool IsStopped => StoppedAt != null;

    /// <summary>Returns the AI provider with the given name.</summary>
    /// <exception cref="NotSupportedException">No AI provider has that name.</exception>
    public AiProvider AssertAiProvider(string name) => AiProviders.FirstOrDefault(x => x.Name == name)
        ?? throw new NotSupportedException($"AI Provider {name} not found");

    /// <summary>Returns the media provider with the given name.</summary>
    /// <exception cref="NotSupportedException">No media provider has that name.</exception>
    public MediaProvider AssertMediaProvider(string name) => MediaProviders.FirstOrDefault(x => x.Name == name)
        ?? throw new NotSupportedException($"Media Provider {name} not found");

    /// <summary>Returns the Comfy media provider with the given name.</summary>
    /// <exception cref="NotSupportedException">No Comfy provider has that name.</exception>
    public MediaProvider AssertComfyProvider(string name) => ComfyProviders.FirstOrDefault(x => x.Name == name)
        ?? throw new NotSupportedException($"Comfy Provider {name} not found");

    /// <summary>
    /// Media providers whose media type is served by <see cref="AiServiceProvider.Comfy"/>.
    /// A new array is built on every access.
    /// </summary>
    public MediaProvider[] ComfyProviders => MediaProviders
        .Where(x => x.MediaType.Provider == AiServiceProvider.Comfy)
        .ToArray();

    /// <summary>
    /// Reads a text file relative to the host content root, or returns null if it does not exist.
    /// </summary>
    string? ReadTextFile(string path)
    {
        var fullPath = Path.Combine(env.ContentRootPath, path);
        return File.Exists(fullPath)
            ? File.ReadAllText(fullPath)
            : null;
    }

    /// <summary>
    /// Loads the JSON model catalog <paramref name="name"/> from <c>wwwroot/lib/data</c> and
    /// applies overrides from <c>App_Data/overrides</c>: an override whose Id matches a default
    /// entry replaces it in place; the remaining overrides are prepended in override-file order.
    /// </summary>
    /// <remarks>
    /// A missing/empty default file or an override file that deserializes to null is treated
    /// as an empty list (previously this threw a NullReferenceException).
    /// </remarks>
    T[] LoadModels<T>(string name) where T : IHasId<string>
    {
        var models = ReadTextFile($"wwwroot/lib/data/{name}")?.FromJson<List<T>>();
        if (models == null)
        {
            log.LogWarning("Default model catalog wwwroot/lib/data/{Name} is missing or empty", name);
            models = [];
        }

        var overrideModels = ReadTextFile($"App_Data/overrides/{name}")?.FromJson<List<T>>();
        if (overrideModels is { Count: > 0 })
        {
            var insertModels = new List<T>();
            foreach (var model in overrideModels)
            {
                var index = models.FindIndex(x => x.Id == model.Id);
                if (index >= 0)
                {
                    models[index] = model;
                }
                else
                {
                    insertModels.Add(model);
                }
            }
            // New (non-replacing) overrides take precedence by appearing first.
            models.InsertRange(0, insertModels);
        }
        return models.ToArray();
    }

    /// <summary>
    /// Reloads all model catalogs from disk and all providers from <paramref name="db"/>,
    /// then logs the enabled AI providers. Catalogs are loaded first because provider
    /// resolution looks up <see cref="AiTypes"/> and <see cref="MediaTypes"/>.
    /// </summary>
    public void Reload(IDbConnection db)
    {
        Prompts = PocoDataSource.Create(LoadModels<Prompt>("prompts.json")
            .Where(x => !string.IsNullOrEmpty(x.Value)).ToList());
        AiModels = PocoDataSource.Create(LoadModels<AiModel>("ai-models.json"));
        AiTypes = PocoDataSource.Create(LoadModels<AiType>("ai-types.json"));
        MediaTypes = PocoDataSource.Create(LoadModels<MediaType>("media-types.json"));
        TextToSpeechVoices = PocoDataSource.Create(LoadModels<TextToSpeechVoice>("tts-voices.json"));
        ResetAiProviders(db);
        ResetMediaProviders(db);
        LoadModelDefaults();
        LogWorkerInfo(AiProviders, "API");
    }

    /// <summary>
    /// Loads AI providers from the database ordered by descending Priority then Id, and
    /// attaches each provider's <see cref="AiType"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">A provider references an unknown AiTypeId.</exception>
    public void ResetAiProviders(IDbConnection db)
    {
        AiProviders = db.Select<AiProvider>()
            .OrderByDescending(x => x.Priority)
            .ThenBy(x => x.Id)
            .ToArray();
        var aiTypes = AiTypes.GetAll();
        AiProviders.Each(x => x.AiType = aiTypes.FirstOrDefault(t => t.Id == x.AiTypeId)
            ?? throw new NotSupportedException($"Could not found AiType {x.AiTypeId}"));
    }

    /// <summary>
    /// Loads media providers (with references) from the database ordered by descending
    /// Priority then Id, and attaches each provider's <see cref="MediaType"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">A provider references an unknown MediaTypeId.</exception>
    public void ResetMediaProviders(IDbConnection db)
    {
        MediaProviders = db.LoadSelect<MediaProvider>()
            .OrderByDescending(x => x.Priority)
            .ThenBy(x => x.Id)
            .ToArray();
        var mediaTypes = MediaTypes.GetAll();
        MediaProviders.Each(x => x.MediaType = mediaTypes.FirstOrDefault(t => t.Id == x.MediaTypeId)
            ?? throw new NotSupportedException($"Could not found MediaType {x.MediaTypeId}"));
    }

    /// <summary>
    /// Loads <c>media-models.json</c>, keeping only models that have at least one API model
    /// mapping and a supported <see cref="ModelType"/>, and indexes them by Id.
    /// </summary>
    private void LoadModelDefaults()
    {
        MediaModels = LoadModels<MediaModel>("media-models.json")
            .Where(x => x is { ApiModels.Keys.Count: > 0, ModelType:
                ModelType.TextToImage or
                ModelType.TextToSpeech or
                ModelType.SpeechToText or
                ModelType.ImageUpscale or
                ModelType.TextToAudio or
                ModelType.TextEncoder or
                ModelType.ImageToImage or
                ModelType.ImageWithMask or
                ModelType.ImageToText
            })
            .ToArray();
        // ToDictionary throws on duplicate Ids.
        MediaModelsMap = MediaModels.ToDictionary(x => x.Id);
    }

    /// <summary>
    /// Returns the provider-specific API model name for <paramref name="modelId"/>, or null
    /// if the model has no mapping for the provider's media type.
    /// </summary>
    /// <exception cref="HttpError">404 when <paramref name="modelId"/> is not a known media model.</exception>
    public string? GetMediaApiModel(MediaProvider provider, string modelId)
    {
        if (!MediaModelsMap.TryGetValue(modelId, out var mediaModel))
            throw HttpError.NotFound("Model does not exist: " + modelId.SafeVarName());
        return mediaModel.ApiModels.GetValueOrDefault(provider.MediaType!.Id);
    }

    /// <summary>
    /// Finds the media model whose API model name for the provider's media type equals
    /// <paramref name="apiModel"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="apiModel"/> or <paramref name="provider"/> is null, or the provider has no MediaType.
    /// </exception>
    /// <exception cref="HttpError">404 when no media model maps to <paramref name="apiModel"/>.</exception>
    public MediaModel GetMediaModelByApiModel(MediaProvider provider, string apiModel)
    {
        ArgumentNullException.ThrowIfNull(apiModel);
        ArgumentNullException.ThrowIfNull(provider);
        var providerType = provider.MediaType?.Id
            ?? throw new ArgumentNullException(nameof(provider), "MediaProvider.MediaType is not set");
        foreach (var mediaModel in MediaModelsMap.Values)
        {
            if (mediaModel.ApiModels.TryGetValue(providerType, out var candidate) && candidate == apiModel)
                return mediaModel;
        }
        throw HttpError.NotFound($"{apiModel} is not a supported model for {provider.Name} ({providerType})");
    }

    /// <summary>
    /// Returns the Id of the first media model of <paramref name="modelType"/> that maps any
    /// provider to <paramref name="apiModel"/>, or null if none does.
    /// </summary>
    public string? GetQualifiedMediaModel(ModelType modelType, string apiModel)
    {
        foreach (var mediaModel in MediaModels
            .Where(x => x.ModelType == modelType))
        {
            foreach (var entry in mediaModel.ApiModels)
            {
                if (entry.Value == apiModel)
                    return mediaModel.Id;
            }
        }
        return null;
    }

    /// <summary>
    /// Returns the API model name of the first known media model for <paramref name="taskType"/>
    /// that the provider both maps and lists in its Models.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="taskType"/> has no model type mapping.</exception>
    /// <exception cref="HttpError">404 when the provider supports no model for the task.</exception>
    public string? GetDefaultMediaApiModel(MediaProvider provider, AiTaskType taskType)
    {
        ArgumentNullException.ThrowIfNull(provider);
        var modelType = GetModelTypeByAiTaskType(taskType);
        var providerModels = provider.Models;
        if (providerModels != null)
        {
            var providerType = provider.MediaType!.Id;
            foreach (var mediaModel in MediaModelsMap.Values)
            {
                if (mediaModel.ModelType == modelType
                    && mediaModel.ApiModels.TryGetValue(providerType, out var apiModel)
                    && providerModels.Contains(apiModel))
                {
                    return apiModel;
                }
            }
        }
        throw HttpError.NotFound($"No supported models found for {provider.Name} ({provider.MediaType?.Id})");
    }

    /// <summary>Returns the Ids of all known media models that serve <paramref name="taskType"/>.</summary>
    /// <exception cref="NotSupportedException"><paramref name="taskType"/> has no model type mapping.</exception>
    public List<string> GetSupportedModels(AiTaskType taskType)
    {
        var modelType = GetModelTypeByAiTaskType(taskType);
        return MediaModelsMap
            .Where(x => x.Value.ModelType == modelType)
            .Select(x => x.Key)
            .ToList();
    }

    /// <summary>
    /// True when the provider maps, and lists in its Models, at least one media model for
    /// <paramref name="taskType"/>.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="taskType"/> has no model type mapping.</exception>
    public bool ProviderHasModelForTask(MediaProvider provider, AiTaskType taskType)
    {
        ArgumentNullException.ThrowIfNull(provider);
        var modelType = GetModelTypeByAiTaskType(taskType);
        var providerModels = provider.Models;
        if (providerModels == null)
            return false;
        var providerType = provider.MediaType!.Id;
        return MediaModelsMap.Values
            .Any(x => x.ModelType == modelType &&
                      x.ApiModels.TryGetValue(providerType, out var apiModel) &&
                      providerModels.Contains(apiModel));
    }

    /// <summary>True when <paramref name="modelId"/> is a known media model for <paramref name="taskType"/>.</summary>
    public bool ModelSupportsTask(string modelId, AiTaskType taskType)
    {
        return MediaModelsMap.TryGetValue(modelId, out var modelSettings) &&
               modelSettings.ModelType == GetModelTypeByAiTaskType(taskType);
    }

    /// <summary>Maps an <see cref="AiTaskType"/> to the <see cref="ModelType"/> that serves it.</summary>
    /// <exception cref="NotSupportedException">The task type has no mapping.</exception>
    /// <remarks>
    /// NOTE(review): LoadModelDefaults also keeps TextToAudio and TextEncoder models, which no
    /// task type maps to here — confirm whether that is intentional.
    /// </remarks>
    private ModelType GetModelTypeByAiTaskType(AiTaskType taskType)
    {
        return taskType switch
        {
            AiTaskType.TextToImage => ModelType.TextToImage,
            AiTaskType.TextToSpeech => ModelType.TextToSpeech,
            AiTaskType.SpeechToText => ModelType.SpeechToText,
            AiTaskType.ImageToImage => ModelType.ImageToImage,
            AiTaskType.ImageUpscale => ModelType.ImageUpscale,
            AiTaskType.ImageWithMask => ModelType.ImageWithMask,
            AiTaskType.ImageToText => ModelType.ImageToText,
            _ => throw new NotSupportedException($"Unsupported task type: {taskType}")
        };
    }

    /// <summary>
    /// Logs one line per enabled provider with its status, concurrency and accepted models.
    /// Only enabled providers are logged, so the Enabled column always reads "Enabled".
    /// </summary>
    private void LogWorkerInfo(AiProvider[] apiProviders, string workerType)
    {
        foreach (var worker in apiProviders.Where(x => x.Enabled))
        {
            log.LogInformation(
                """
                [{Type}] [{Name}] is {Enabled}, currently {Online} at concurrency {Concurrency}, accepting models:
                    {Models}
                """,
                workerType,
                worker.Name,
                worker.Enabled ? "Enabled" : "Disabled",
                worker.OfflineDate != null ? "Offline" : "Online",
                worker.Concurrency,
                string.Join("\n    ", worker.Models.Select(x => x.Model)));
        }
    }

    /// <summary>Resolves the OpenAI-compatible client for the provider's AiType.</summary>
    public IOpenAiProvider GetOpenAiProvider(AiProvider aiProvider) =>
        aiFactory.GetOpenAiProvider(aiProvider.AiType.Provider);

    /// <summary>Resolves the generation client for the provider's MediaType.</summary>
    public IAiProvider GetGenerationProvider(MediaProvider apiProvider) =>
        mediaProviderFactory.GetProvider(apiProvider.MediaType.Provider);

    /// <summary>
    /// Resolves a user-supplied model name against <see cref="AiModels"/>:
    /// <list type="bullet">
    /// <item><c>model</c> (no tag) or <c>model:latest</c> → <c>model:{Latest}</c>, or just <c>model</c> when the model has no Latest tag.</item>
    /// <item><c>model:tag</c> → <c>model:tag</c> if <c>tag</c> is listed in the model's Tags.</item>
    /// </list>
    /// Returns null when the model Id is unknown or the tag is not listed.
    /// </summary>
    public string? GetQualifiedModel(string model)
    {
        var hasTag = model.Contains(':');
        var modelGroup = hasTag ? model.LeftPart(':') : model;
        var aiModel = AiModels.GetAll().FirstOrDefault(x => x.Id == modelGroup);
        if (aiModel == null)
            return null;

        if (!hasTag)
            return QualifyWithLatest(aiModel);

        var modelTag = model.RightPart(':');
        if (modelTag == "latest")
            return QualifyWithLatest(aiModel);

        return aiModel.Tags.Contains(modelTag)
            ? aiModel.Id + ":" + modelTag
            : null;
    }

    /// <summary>Returns <c>Id:Latest</c>, or just <c>Id</c> when the model has no Latest tag.</summary>
    private static string QualifyWithLatest(AiModel aiModel) =>
        aiModel.Id + (aiModel.Latest != null ? ":" + aiModel.Latest : "");
}