forked from nomic-ai/gpt4all
-
Notifications
You must be signed in to change notification settings - Fork 0
/
chat.cpp
465 lines (393 loc) · 13.9 KB
/
chat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
#include "chat.h"
#include "chatlistmodel.h"
#include "mysettings.h"
#include "network.h"
#include "server.h"
#include <QDataStream>
#include <QDateTime>
#include <QDebug>
#include <QLatin1String>
#include <QMap>
#include <QString>
#include <QStringList>
#include <QTextStream>
#include <Qt>
#include <QtGlobal>
#include <QtLogging>
#include <utility>
// Builds a regular chat: fresh unique id, translated default name "New Chat",
// creation date taken from the current time, and a ChatLLM as the model object.
Chat::Chat(QObject *parent)
    : QObject(parent)
    , m_id(Network::globalInstance()->generateUniqueId())
    , m_name(tr("New Chat"))
    , m_chatModel(new ChatModel(this))
    , m_responseState(Chat::ResponseStopped)
    , m_creationDate(QDateTime::currentSecsSinceEpoch())
    , m_llmodel(new ChatLLM(this))
    , m_collectionModel(new LocalDocsCollectionsModel(this))
{
    // All signal/slot wiring lives in connectLLM().
    connectLLM();
}
// Builds the server flavour of a chat: named "Server Chat" and backed by a
// Server object instead of a plain ChatLLM.
//
// NOTE(review): the isServer argument is not read; m_isServer is always set to
// true here. Confirm no caller relies on passing false.
Chat::Chat(bool isServer, QObject *parent)
    : QObject(parent)
    , m_id(Network::globalInstance()->generateUniqueId())
    , m_name(tr("Server Chat"))
    , m_chatModel(new ChatModel(this))
    , m_responseState(Chat::ResponseStopped)
    , m_creationDate(QDateTime::currentSecsSinceEpoch())
    , m_llmodel(new Server(this))
    , m_isServer(true)
    , m_collectionModel(new LocalDocsCollectionsModel(this))
{
    // All signal/slot wiring lives in connectLLM().
    connectLLM();
}
// Destroys the model object explicitly and clears the pointer afterwards.
Chat::~Chat()
{
    delete m_llmodel;
    m_llmodel = nullptr; // never leave a dangling pointer behind
}
// Establishes every signal/slot link between this chat, its model object and the
// collections model.
void Chat::connectLLM()
{
    // Original note: the chat and the model should be in different threads, so
    // every link to or from the model is explicitly queued.
    const auto queued = Qt::QueuedConnection;

    // Model -> chat notifications.
    connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, queued);
    connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, queued);
    connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, queued);
    connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, queued);
    connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, queued);
    connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, queued);
    connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, queued);
    connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, queued);
    connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, queued);
    connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, queued);
    connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, queued);
    connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, queued);
    connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, queued);

    // Chat -> model requests.
    connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, queued);
    connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, queued);
    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, queued);
    connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, queued);
    connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, queued);
    connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, queued);
    connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, queued);
    connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, queued);
    connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, queued);

    // Chat -> collections model; this one keeps the default connection type.
    connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
}
// Resets the chat: stops generation, removes the saved chat file, asks the model
// to reset its context, assigns a new id and clears the chat model.
void Chat::reset()
{
    stopGenerating();

    // The chat gets a brand new id below, so its current on-disk representation
    // is erased first.
    ChatListModel::globalInstance()->removeChatFile(this);
    emit resetContextRequested();

    m_id = Network::globalInstance()->generateUniqueId();
    emit idChanged(m_id);

    // NOTE: The name and creation date are deliberately NOT reset, to indicate that
    // this was originally an older chat that was reset for another purpose. Resetting
    // them would change the chat name label back to 'New Chat' and make the chat show
    // up as a 'New Chat' further down the chat model list, which might surprise the
    // user. In the future we might get rid of the "reset context" button in the UI.
    // Right now, changing the model in the combobox dropdown effectively performs a
    // reset context, and we *have* to do that when switching between different types
    // of models. The only way around it would be a very long recalculate that rebuilds
    // the context on such a switch. Probably the right fix is to allow switching
    // models but show a dialog warning users that a long recalculation will ensue
    // when switching between model types.
    m_chatModel->clear();
}
// Forwards the request by emitting processSystemPromptRequested().
void Chat::processSystemPrompt()
{
    emit processSystemPromptRequested();
}
// Puts the chat into the "response in progress" state and picks the initial
// response state depending on whether any collections are attached.
void Chat::resetResponseState()
{
    // Nothing to do when a response is already in the LocalDocs retrieval phase.
    const bool alreadyRetrieving = m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval;
    if (alreadyRetrieving)
        return;

    m_tokenSpeed = QString();
    emit tokenSpeedChanged();

    m_responseInProgress = true;
    if (m_collections.empty())
        m_responseState = Chat::PromptProcessing;
    else
        m_responseState = Chat::LocalDocsRetrieval;

    emit responseInProgressChanged();
    emit responseStateChanged();
}
// Resets the response state, then emits promptRequested() with the attached
// collections and the given prompt text.
void Chat::prompt(const QString &prompt)
{
    resetResponseState();
    emit promptRequested(m_collections, prompt);
}
// Clears the sources of the last chat item and requests a regenerated response.
void Chat::regenerateResponse()
{
    const int lastItem = m_chatModel->count() - 1;
    m_chatModel->updateSources(lastItem, QList<ResultInfo>());
    emit regenerateResponseRequested();
}
// Delegates directly to the model object's stopGenerating().
void Chat::stopGenerating()
{
    m_llmodel->stopGenerating();
}
// Returns the most recently stored response text.
QString Chat::response() const
{
    return m_response;
}
// Returns the current response state.
Chat::ResponseState Chat::responseState() const
{
    return m_responseState;
}
// Stores new response text, mirrors it into the last chat item and notifies
// listeners. Switches the state to ResponseGeneration if it was not already.
void Chat::handleResponseChanged(const QString &response)
{
    const bool enteringGeneration = m_responseState != Chat::ResponseGeneration;
    if (enteringGeneration) {
        m_responseState = Chat::ResponseGeneration;
        emit responseStateChanged();
    }

    m_response = response;

    const int lastItem = m_chatModel->count() - 1;
    m_chatModel->updateValue(lastItem, this->response());
    emit responseChanged();
}
// Records a new loading percentage and emits the change signals for the
// percentage and for the derived "loading" and "loaded" flags.
void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
{
    // A delete requested by unloadAndDeleteLater() is scheduled here; the rest
    // of the handler still runs.
    if (m_shouldDeleteLater)
        deleteLater();

    if (loadingPercentage == m_modelLoadingPercentage)
        return;

    // Snapshot the derived flags so only real transitions are signalled.
    const bool loadingBefore = isCurrentlyLoading();
    const bool loadedBefore = isModelLoaded();

    m_modelLoadingPercentage = loadingPercentage;
    emit modelLoadingPercentageChanged();

    if (loadingBefore != isCurrentlyLoading())
        emit isCurrentlyLoadingChanged();

    if (loadedBefore != isModelLoaded())
        emit isModelLoadedChanged();
}
// Sets the response state to LocalDocsProcessing when database results are
// present and to PromptProcessing otherwise, then signals the change.
void Chat::promptProcessing()
{
    if (databaseResults().isEmpty())
        m_responseState = Chat::PromptProcessing;
    else
        m_responseState = Chat::LocalDocsProcessing;

    emit responseStateChanged();
}
// Finishes a response: clears the token speed, leaves the in-progress state,
// requests a generated name if there is none yet, and reports a
// "response_complete" event.
void Chat::responseStopped(qint64 promptResponseMs)
{
    m_tokenSpeed = QString();
    emit tokenSpeedChanged();
    emit responseChanged();

    m_responseInProgress = false;
    m_responseState = Chat::ResponseStopped;
    emit responseInProgressChanged();
    emit responseStateChanged();

    if (m_generatedName.isEmpty())
        emit generateNameRequested();

    // The duration is reported in seconds; the argument arrives in milliseconds.
    const auto durationSecs = promptResponseMs / 1000.;
    Network::globalInstance()->trackChatEvent("response_complete", {
        {"first", m_firstResponse},
        {"message_count", chatModel()->count()},
        {"$duration", durationSecs},
    });

    // Later responses are no longer reported as the first one.
    m_firstResponse = false;
}
// Returns a copy of the model info stored for this chat.
ModelInfo Chat::modelInfo() const
{
    return m_modelInfo;
}
// Stores a new model info and requests a model change. Does nothing when the
// same model info is already set and the model is loaded.
void Chat::setModelInfo(const ModelInfo &modelInfo)
{
    const bool sameAndLoaded = (m_modelInfo == modelInfo) && isModelLoaded();
    if (sameAndLoaded)
        return;

    m_modelInfo = modelInfo;
    emit modelInfoChanged();
    emit modelChangeRequested(modelInfo);
}
// Starts a new prompt/response exchange: resets the response state, updates the
// current-response flag of the last item to false, appends the prompt and a
// response entry, and asks the model to reset its response.
void Chat::newPromptResponsePair(const QString &prompt)
{
    resetResponseState();

    const int previousLast = m_chatModel->count() - 1;
    m_chatModel->updateCurrentResponse(previousLast, false);

    m_chatModel->appendPrompt(tr("Prompt: "), prompt);
    m_chatModel->appendResponse(tr("Response: "), prompt);

    emit resetResponseRequested();
}
// Same bookkeeping as newPromptResponsePair(), except that
// resetResponseRequested() is not emitted.
void Chat::serverNewPromptResponsePair(const QString &prompt)
{
    resetResponseState();

    const int previousLast = m_chatModel->count() - 1;
    m_chatModel->updateCurrentResponse(previousLast, false);

    m_chatModel->appendPrompt(tr("Prompt: "), prompt);
    m_chatModel->appendResponse(tr("Response: "), prompt);
}
// Returns the model object's isRecalc() value.
bool Chat::isRecalc() const
{
    return m_llmodel->isRecalc();
}
// Schedules this chat for deletion. With a loaded model the deletion is deferred:
// a flag is set and the model is unloaded first; handleModelLoadingPercentageChanged()
// calls deleteLater() once the flag is set.
void Chat::unloadAndDeleteLater()
{
    if (isModelLoaded()) {
        m_shouldDeleteLater = true;
        unloadModel();
    } else {
        // No model to unload, so the chat can be scheduled for deletion right away.
        deleteLater();
    }
}
// Sets the marked-for-deletion flag on the model object.
void Chat::markForDeletion()
{
    m_llmodel->setMarkedForDeletion(true);
}
// Stops any generation, then tells the model object it should not be loaded.
void Chat::unloadModel()
{
    stopGenerating();
    m_llmodel->setShouldBeLoaded(false);
}
// Tells the model object that it should be loaded.
void Chat::reloadModel()
{
    m_llmodel->setShouldBeLoaded(true);
}
// Like unloadModel(), but sets the force-unload flag on the model object first.
void Chat::forceUnloadModel()
{
    stopGenerating();
    m_llmodel->setForceUnloadModel(true);
    m_llmodel->setShouldBeLoaded(false);
}
// Like reloadModel(), but sets the force-unload flag on the model object first.
void Chat::forceReloadModel()
{
    m_llmodel->setForceUnloadModel(true);
    m_llmodel->setShouldBeLoaded(true);
}
// Sets the in-progress value to 1, signals the change, and asks the model
// object to try switching context.
void Chat::trySwitchContextOfLoadedModel()
{
    m_trySwitchContextInProgress = 1;
    emit trySwitchContextInProgressChanged();
    m_llmodel->requestTrySwitchContext();
}
void Chat::generatedNameChanged(const QString &name)
{
// Only use the first three words maximum and remove newlines and extra spaces
m_generatedName = name.simplified();
QStringList words = m_generatedName.split(' ', Qt::SkipEmptyParts);
int wordCount = qMin(7, words.size());
m_name = words.mid(0, wordCount).join(' ');
emit nameChanged();
}
// Reports a "recalc_context" event carrying the chat model's item count, then
// emits recalcChanged().
void Chat::handleRecalculating()
{
    Network::globalInstance()->trackChatEvent("recalc_context", {
        {"length", m_chatModel->count()},
    });
    emit recalcChanged();
}
void Chat::handleModelLoadingError(const QString &error)
{
if (!error.isEmpty()) {
auto stream = qWarning().noquote() << "ERROR:" << error << "id";
stream.quote() << id();
}
m_modelLoadingError = error;
emit modelLoadingErrorChanged();
}
// Stores the reported token speed text and signals the change.
void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
{
    m_tokenSpeed = tokenSpeed;
    emit tokenSpeedChanged();
}
// Returns the model object's deviceBackend() value.
QString Chat::deviceBackend() const
{
    return m_llmodel->deviceBackend();
}
// Returns the model object's device() value.
QString Chat::device() const
{
    return m_llmodel->device();
}
// Returns the model object's fallbackReason() value.
QString Chat::fallbackReason() const
{
    return m_llmodel->fallbackReason();
}
// Stores the new database results and sets them as the sources of the last
// chat item.
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{
    m_databaseResults = results;

    const int lastItem = m_chatModel->count() - 1;
    m_chatModel->updateSources(lastItem, m_databaseResults);
}
// Adopts the model info reported by the model object and signals the change;
// an identical value is ignored.
void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
{
    const bool unchanged = (m_modelInfo == modelInfo);
    if (unchanged)
        return;

    m_modelInfo = modelInfo;
    emit modelInfoChanged();
}
// Stores the value reported for the context-switch attempt and signals the change.
void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value)
{
    m_trySwitchContextInProgress = value;
    emit trySwitchContextInProgressChanged();
}
// Writes this chat to the stream in the layout selected by `version`.
//
// Field order: creation date, id, name, user name, model identifier (model id
// for version > 4, filename otherwise), collections (version > 2), the
// save-context flag (version > 5), then the model object's state and the chat model.
// Returns true only if both nested serializers succeed and the stream status is Ok.
bool Chat::serialize(QDataStream &stream, int version) const
{
    stream << m_creationDate << m_id << m_name << m_userName;

    if (version <= 4)
        stream << m_modelInfo.filename();
    else
        stream << m_modelInfo.id();

    if (version > 2)
        stream << m_collections;

    // The setting is read for every version, but only written for version > 5.
    const bool saveContext = MySettings::globalInstance()->saveChatsContext();
    if (version > 5)
        stream << saveContext;

    return m_llmodel->serialize(stream, version, saveContext)
        && m_chatModel->serialize(stream, version)
        && stream.status() == QDataStream::Ok;
}
// Restores this chat from the stream, reading fields in the order serialize()
// writes them for the given `version`, and emitting the matching change signals.
// Returns false if either nested deserializer fails; otherwise returns whether
// the stream status is Ok.
bool Chat::deserialize(QDataStream &stream, int version)
{
    stream >> m_creationDate >> m_id;
    emit idChanged(m_id);

    stream >> m_name >> m_userName;
    // A non-empty placeholder keeps responseStopped() from requesting a generated
    // name for a restored chat.
    m_generatedName = QLatin1String("nonempty");
    emit nameChanged();

    // The stored identifier is a model id for version > 4 and a filename before that.
    // m_modelInfo is left untouched when the model list does not know it.
    QString modelId;
    stream >> modelId;
    if (version > 4) {
        if (ModelList::globalInstance()->contains(modelId))
            m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
    } else if (ModelList::globalInstance()->containsByFilename(modelId)) {
        m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
    }

    if (!m_modelInfo.id().isEmpty())
        emit modelInfoChanged();

    const bool noModel = m_modelInfo.id().isEmpty();
    // Prior to version 2, gptj models had a bug that fixed the kv_cache to F32 instead
    // of F16, so unfortunately those caches cannot be deserialized.
    const bool staleGptjCache = version < 2 && m_modelInfo.filename().contains("gpt4all-j");
    const bool discardKV = noModel || staleGptjCache;

    if (version > 2) {
        stream >> m_collections;
        emit collectionListChanged(m_collections);
    }

    // Streams older than version 6 carry no flag; the default is true.
    bool deserializeKV = true;
    if (version > 5)
        stream >> deserializeKV;

    m_llmodel->setModelInfo(m_modelInfo);
    if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV)
        || !m_chatModel->deserialize(stream, version))
        return false;

    m_llmodel->setStateFromText(m_chatModel->text());
    emit chatModelChanged();

    return stream.status() == QDataStream::Ok;
}
// Returns a copy of the collection names attached to this chat.
QList<QString> Chat::collectionList() const
{
    return m_collections;
}
// Returns true when the given collection name is attached to this chat.
bool Chat::hasCollection(const QString &collection) const
{
    return m_collections.contains(collection);
}
// Attaches a collection to this chat and signals the change; a collection that
// is already attached is ignored.
void Chat::addCollection(const QString &collection)
{
    const bool alreadyAttached = hasCollection(collection);
    if (alreadyAttached)
        return;

    m_collections.append(collection);
    emit collectionListChanged(m_collections);
}
// Detaches a collection from this chat and signals the change; a collection that
// is not attached is ignored.
void Chat::removeCollection(const QString &collection)
{
    const bool attached = hasCollection(collection);
    if (!attached)
        return;

    m_collections.removeAll(collection);
    emit collectionListChanged(m_collections);
}