Skip to content

Commit

Permalink
feat: 支持多语言模型
Browse files Browse the repository at this point in the history
  • Loading branch information
brzhang666 committed May 4, 2023
1 parent 18c4366 commit 37d98b4
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 66 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ release dir 可以直接下载 release 版本,也可以:
- [x] 全局状态管理
- [x] 多轮次对话
- [x] prompt 支持
- [ ] 支持 ChatGlm 等其他模型
- [ ] 一键导出会话支持
- [ ] 支持高级搜索
- [ ] 等你 issue 来支持
Expand Down
33 changes: 33 additions & 0 deletions lib/components/conversation.dart
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,39 @@ class _ConversationWindowState extends State<ConversationWindow> {
const SizedBox(
height: 28,
),
DropdownButtonFormField(
value:
BlocProvider.of<UserSettingCubit>(context).state.llm,
decoration: InputDecoration(
labelText: AppLocalizations.of(context)!.llmHint,
hintText: AppLocalizations.of(context)!.llmHint,
floatingLabelBehavior: FloatingLabelBehavior.auto,
contentPadding: const EdgeInsets.symmetric(
horizontal: 16, vertical: 8),
border: OutlineInputBorder(
borderRadius: BorderRadius.circular(5),
borderSide: BorderSide.none,
),
filled: true,
),
items: <String>['OpenAI', 'ChatGlm', 'IF']
.map<DropdownMenuItem<String>>((String value) {
return DropdownMenuItem<String>(
value: value,
child: Text(
value,
),
);
}).toList(),
onChanged: (String? newValue) {
if (newValue == null) return;
BlocProvider.of<UserSettingCubit>(context)
.setLlm(newValue);
},
),
const SizedBox(
height: 28,
),
TextFormField(
controller: controllerApiKey,
decoration: InputDecoration(
Expand Down
20 changes: 15 additions & 5 deletions lib/cubit/setting_cubit.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ part 'setting_state.dart';
class UserSettingCubit extends Cubit<UserSettingState> with HydratedMixin {
UserSettingCubit()
: super(UserSettingState(lightTheme, const Locale('en'), "",
"https://api.openai-proxy.com", false, "gpt-3.5-turbo")) {
"https://api.openai-proxy.com", false, "OpenAI", "gpt-3.5-turbo")) {
hydrate();
}

Expand All @@ -22,17 +22,18 @@ class UserSettingCubit extends Cubit<UserSettingState> with HydratedMixin {
state.key,
state.baseUrl,
state.useStream,
state.llm,
state.gptModel));
}

void setKey(String key) {
emit(UserSettingState(state.themeData, state.locale, key, state.baseUrl,
state.useStream, state.gptModel));
state.useStream, state.llm, state.gptModel));
}

void setProxyUrl(String baseUrl) {
emit(UserSettingState(state.themeData, state.locale, state.key, baseUrl,
state.useStream, state.gptModel));
state.useStream, state.llm, state.gptModel));
}

void switchLocale() {
Expand All @@ -42,17 +43,23 @@ class UserSettingCubit extends Cubit<UserSettingState> with HydratedMixin {
state.key,
state.baseUrl,
state.useStream,
state.llm,
state.gptModel));
}

void setUseStream(bool useStream) {
emit(UserSettingState(state.themeData, state.locale, state.key,
state.baseUrl, useStream, state.gptModel));
state.baseUrl, useStream, state.llm, state.gptModel));
}

void setGptModel(String value) {
emit(UserSettingState(state.themeData, state.locale, state.key,
state.baseUrl, state.useStream, value));
state.baseUrl, state.useStream, state.llm, value));
}

void setLlm(String newValue) {
emit(UserSettingState(state.themeData, state.locale, state.key,
state.baseUrl, state.useStream, newValue, state.gptModel));
}

@override
Expand All @@ -62,6 +69,7 @@ class UserSettingCubit extends Cubit<UserSettingState> with HydratedMixin {
String key = json['user_key_value'] as String;
String baseUrl = json['user_proxy_url_value'] as String;
bool useStream = json['user_use_stream_value'] as bool;
String llm = json['user_llm_value'] as String;
String gptModel = json['user_gpt_model_value'] as String;

return UserSettingState(
Expand All @@ -70,6 +78,7 @@ class UserSettingCubit extends Cubit<UserSettingState> with HydratedMixin {
key,
baseUrl,
useStream,
llm.isEmpty ? "OpenAI" : llm,
gptModel.isEmpty ? "gpt-3.5-turbo" : gptModel);
}

Expand All @@ -96,6 +105,7 @@ class UserSettingCubit extends Cubit<UserSettingState> with HydratedMixin {
'user_key_value': state.key,
'user_proxy_url_value': state.baseUrl,
'user_use_stream_value': state.useStream,
'user_llm_value': state.llm,
'user_gpt_model_value': state.gptModel
};
}
Expand Down
3 changes: 2 additions & 1 deletion lib/cubit/setting_state.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class UserSettingState {
final String baseUrl;
final bool useStream;
final String gptModel;
final String llm; //大语言模型,OpenAi,ChatGlm,IF
const UserSettingState(this.themeData, this.locale, this.key, this.baseUrl,
this.useStream, this.gptModel);
this.useStream, this.llm, this.gptModel);
}
15 changes: 15 additions & 0 deletions lib/data/glm.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import 'package:flutter/foundation.dart';
import 'package:flutter_chatgpt/data/llm.dart';
import 'package:flutter_chatgpt/repository/conversation.dart';

/// Placeholder [LLM] backend for the ChatGLM model family.
///
/// Selected when the user picks "ChatGlm" in the settings dropdown; the
/// actual API integration has not been implemented yet, so calling
/// [getResponse] always throws [UnimplementedError].
class ChatGlM extends LLM {
  @override
  getResponse(
      List<Message> messages,
      ValueChanged<Message> onResponse,
      ValueChanged<Message> errorCallback,
      ValueChanged<Message> onSuccess) async {
    // TODO: implement getResponse
    throw UnimplementedError();
  }
}
15 changes: 15 additions & 0 deletions lib/data/if.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import 'package:flutter/foundation.dart';
import 'package:flutter_chatgpt/data/llm.dart';
import 'package:flutter_chatgpt/repository/conversation.dart';

class ChatIF extends LLM {
@override
getResponse(
List<Message> messages,
ValueChanged<Message> onResponse,
ValueChanged<Message> errorCallback,
ValueChanged<Message> onSuccess) async {
// TODO: implement getResponse
throw UnimplementedError();
}
}
76 changes: 76 additions & 0 deletions lib/data/llm.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import 'package:dart_openai/openai.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter_chatgpt/cubit/setting_cubit.dart';
import 'package:flutter_chatgpt/repository/conversation.dart';
import 'package:get_it/get_it.dart';

/// Common interface for large-language-model backends.
///
/// Implementations (e.g. [ChatGpt]) send the conversation [messages] to a
/// model and report results through callbacks rather than a return value:
/// - `onResponse`: called repeatedly with partial text while streaming,
/// - `errorCallback`: called with an error message when the request fails,
/// - `onSuccess`: called once with the final assistant message.
abstract class LLM {
  getResponse(List<Message> messages, ValueChanged<Message> onResponse,
      ValueChanged<Message> errorCallback, ValueChanged<Message> onSuccess);
}

/// [LLM] backend that talks to the OpenAI chat-completion API through the
/// `dart_openai` package, honoring the user's model and stream settings.
class ChatGpt extends LLM {
  /// Maximum accumulated character count of history forwarded to the API.
  static const int _maxHistoryLength = 1800;

  /// Sends [messages] to OpenAI and reports results via the callbacks.
  ///
  /// Streaming mode calls [onResponse] with the growing reply for every
  /// delta, then [onSuccess] when the stream completes; non-streaming mode
  /// calls [onSuccess] once (or [errorCallback] on failure).
  @override
  getResponse(
      List<Message> messages,
      ValueChanged<Message> onResponse,
      ValueChanged<Message> errorCallback,
      ValueChanged<Message> onSuccess) async {
    // Read the settings once instead of hitting GetIt on every use.
    final settings = GetIt.instance.get<UserSettingCubit>().state;

    // Reverse so iteration runs oldest-to-newest.
    messages = messages.reversed.toList();

    // Accumulate history until the character budget is exhausted; messages
    // past the budget are dropped. An int accumulator replaces the original
    // O(n^2) string concatenation, which only ever used the string's length.
    final openAIMessages = <OpenAIChatCompletionChoiceMessageModel>[];
    var accumulatedLength = 0;
    for (final message in messages) {
      accumulatedLength += message.text.length;
      if (accumulatedLength < _maxHistoryLength) {
        // Insert at index 0, matching the original ordering behavior.
        openAIMessages.insert(
          0,
          OpenAIChatCompletionChoiceMessageModel(
            content: message.text,
            role: message.role.asOpenAIChatMessageRole,
          ),
        );
      }
    }

    // Single mutable assistant message that accumulates the reply text.
    var message = Message(
        conversationId: messages.first.conversationId,
        text: "",
        role: Role.assistant); // only the first chunk carries the role
    if (settings.useStream) {
      Stream<OpenAIStreamChatCompletionModel> chatStream = OpenAI.instance.chat
          .createStream(model: settings.gptModel, messages: openAIMessages);
      chatStream.listen(
        (chatStreamEvent) {
          final delta = chatStreamEvent.choices.first.delta.content;
          if (delta != null) {
            message.text = message.text + delta;
            onResponse(message);
          }
        },
        onError: (error) {
          // NOTE(review): assumes the error object exposes `message` (e.g.
          // RequestFailedException) — verify against dart_openai's API.
          message.text = error.message;
          errorCallback(message);
        },
        onDone: () {
          onSuccess(message);
        },
      );
    } else {
      try {
        var response = await OpenAI.instance.chat.create(
          model: settings.gptModel,
          messages: openAIMessages,
        );
        message.text = response.choices.first.message.content;
        onSuccess(message);
      } catch (e) {
        message.text = e.toString();
        errorCallback(message);
      }
    }
  }
}
3 changes: 2 additions & 1 deletion lib/l10n/app_en.arb
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
"ok": "OK",
"cancel": "Cancel",
"useStreamApi": "Use Stream API",
"gptModel": "Select GPT Model"
"gptModel": "Select GPT Model",
"llmHint": "Select LLM Model"
}
3 changes: 2 additions & 1 deletion lib/l10n/app_zh.arb
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
"ok": "确定",
"cancel": "取消",
"useStreamApi": "使用流式(Stream) API",
"gptModel": "选择 GPT Model"
"gptModel": "选择 GPT Model",
"llmHint": "选择大语言模型"
}
75 changes: 17 additions & 58 deletions lib/repository/message.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import 'package:dart_openai/openai.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter_chatgpt/cubit/setting_cubit.dart';
import 'package:flutter_chatgpt/data/glm.dart';
import 'package:flutter_chatgpt/data/if.dart';
import 'package:flutter_chatgpt/data/llm.dart';
import 'package:flutter_chatgpt/repository/conversation.dart';
import 'package:get_it/get_it.dart';

Expand Down Expand Up @@ -32,64 +35,20 @@ class MessageRepository {
ValueChanged<Message> onResponse,
ValueChanged<Message> errorCallback,
ValueChanged<Message> onSuccess) async {
List<OpenAIChatCompletionChoiceMessageModel> openAIMessages = [];
//将messages反转
messages = messages.reversed.toList();
while (true) {
// 将messages里面的每条消息的内容取出来拼接在一起
String content = "";
for (Message message in messages) {
content = content + message.text;
if (content.length < 1800) {
// 插入到 openAIMessages 第一个位置
openAIMessages.insert(
0,
OpenAIChatCompletionChoiceMessageModel(
content: message.text,
role: message.role.asOpenAIChatMessageRole,
),
);
}
}
break;
}
var message = Message(
conversationId: messages.first.conversationId,
text: "",
role: Role.assistant); //仅仅第一个返回了角色
if (GetIt.instance.get<UserSettingCubit>().state.useStream) {
Stream<OpenAIStreamChatCompletionModel> chatStream = OpenAI.instance.chat
.createStream(
model: GetIt.instance.get<UserSettingCubit>().state.gptModel,
messages: openAIMessages);
chatStream.listen(
(chatStreamEvent) {
if (chatStreamEvent.choices.first.delta.content != null) {
message.text =
message.text + chatStreamEvent.choices.first.delta.content!;
onResponse(message);
}
},
onError: (error) {
message.text = error.message;
errorCallback(message);
},
onDone: () {
onSuccess(message);
},
);
} else {
try {
var response = await OpenAI.instance.chat.create(
model: GetIt.instance.get<UserSettingCubit>().state.gptModel,
messages: openAIMessages,
);
message.text = response.choices.first.message.content;
onSuccess(message);
} catch (e) {
message.text = e.toString();
errorCallback(message);
}
String llm = GetIt.instance.get<UserSettingCubit>().state.llm;

switch (llm.toUpperCase()) {
case "OPENAI":
ChatGpt().getResponse(messages, onResponse, errorCallback, onSuccess);
break;
case "CHATGLM":
ChatGlM().getResponse(messages, onResponse, errorCallback, onSuccess);
break;
case "IF":
ChatIF().getResponse(messages, onResponse, errorCallback, onSuccess);
break;
default:
ChatGpt().getResponse(messages, onResponse, errorCallback, onSuccess);
}
}

Expand Down

0 comments on commit 37d98b4

Please sign in to comment.