merge branch from remote
wildkid1024 committed Jun 26, 2023
1 parent 56ced61 commit f94ecdb
Showing 28 changed files with 1,010 additions and 562 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -24,7 +24,7 @@ endif()
message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
src/devices/cpu/cpudevice.cpp
src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/baichuan.cpp)
src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/basellm.cpp)

include_directories(include)
include_directories(include/utils)
9 changes: 5 additions & 4 deletions cli.py
@@ -3,11 +3,11 @@
import platform
import logging
import argparse
sys.path.append('./build-py')
import pyfastllm # or fastllm

logging.info(f"python gcc version:{platform.python_compiler()}")

sys.path.append('./build-py')

def args_parser():
parser = argparse.ArgumentParser(description='pyfastllm')
@@ -19,11 +19,12 @@ def args_parser():
return args

LLM_TYPE = ""
def print_back(idx:int, content: str):
def print_back(idx:int, content: bytearray):
content = content.decode(encoding="utf-8", errors="replace")
if idx == 0:
print(f"{LLM_TYPE}:{content}", end='')
elif idx > 0:
print(f"{content}", end='')
print(f"{content}", end='\n')
elif idx == -1:
print()

@@ -52,4 +53,4 @@ def main(args):

if __name__ == "__main__":
args = args_parser()
main(args)
main(args)
93 changes: 24 additions & 69 deletions example/benchmark/benchmark.cpp
@@ -2,25 +2,11 @@
// Created by huangyuyang on 6/9/23.
//

#include "factoryllm.h"
#include "model.h"
#include "utils.h"
#include "fstream"

static factoryllm fllm;
static int modeltype = 0;
static char* modelpath = NULL;
static fastllm::basellm* chatGlm = fllm.createllm(LLM_TYPE_CHATGLM);
static fastllm::basellm* moss = fllm.createllm(LLM_TYPE_MOSS);
static fastllm::basellm* vicuna = fllm.createllm(LLM_TYPE_VICUNA);
static int sRound = 0;
static std::string history;

std::map <std::string, int> modelDict = {
{"chatglm", 0}, {"moss", 1}, {"vicuna", 2}
};

struct BenchmarkConfig {
int model = LLM_TYPE_CHATGLM; // model type: 0 chatglm, 1 moss, 2 vicuna
std::string path = "chatglm-6b-int4.bin"; // path to the model file
int threads = 4; // number of threads to use
int limit = -1; // output token limit; < 0 means unlimited
@@ -32,7 +18,6 @@ struct BenchmarkConfig {
void Usage() {
std::cout << "Usage:" << std::endl;
std::cout << "[-h|--help]: 显示帮助" << std::endl;
std::cout << "<-m|--model> <args>: 模型类型,默认为0, 可以设置为0(chatglm),1(moss),2(vicuna)" << std::endl;
std::cout << "<-p|--path> <args>: 模型文件的路径" << std::endl;
std::cout << "<-t|--threads> <args>: 使用的线程数量" << std::endl;
std::cout << "<-l|--limit> <args>: 输出token数限制" << std::endl;
@@ -50,17 +35,9 @@ void ParseArgs(int argc, char **argv, BenchmarkConfig &config) {
Usage();
exit(0);
}
else if (sargv[i] == "-m" || sargv[i] == "--model") {
if (modelDict.find(sargv[i + 1]) != modelDict.end()) {
config.model = modelDict[sargv[++i]];
} else {
config.model = atoi(sargv[++i].c_str());
}
}
else if (sargv[i] == "-p" || sargv[i] == "--path") {
config.path = sargv[++i];
}
else if (sargv[i] == "-t" || sargv[i] == "--threads") {
} else if (sargv[i] == "-t" || sargv[i] == "--threads") {
config.threads = atoi(sargv[++i].c_str());
} else if (sargv[i] == "-l" || sargv[i] == "--limit") {
config.limit = atoi(sargv[++i].c_str());
@@ -77,47 +54,12 @@ void ParseArgs(int argc, char **argv, BenchmarkConfig &config) {
}
}

int initLLMConf(int model, const char* modelPath, int threads) {
fastllm::SetThreads(threads);
modeltype = model;
//printf("@@init llm:type:%d,path:%s\n", model, modelPath);
if (modeltype == 0) {
chatGlm->LoadFromFile(modelPath);
chatGlm->WarmUp();
}
if (modeltype == 1) {
moss->LoadFromFile(modelPath);
}
if (modeltype == 2) {
vicuna->LoadFromFile(modelPath);
}
return 0;
}


void uninitLLM()
{
if (chatGlm)
{
delete chatGlm;
chatGlm = NULL;
}
if (moss)
{
delete moss;
moss = NULL;
}
if (vicuna) {
delete vicuna;
vicuna = NULL;
}
}

int main(int argc, char **argv) {
BenchmarkConfig config;
ParseArgs(argc, argv, config);
initLLMConf(config.model, config.path.c_str(), config.threads);
chatGlm->output_token_limit = config.limit;
fastllm::SetThreads(config.threads);
auto model = fastllm::CreateLLMModelFromFile(config.path);
model->output_token_limit = config.limit;

std::vector <std::string> inputs;
if (config.file != "") {
@@ -143,17 +85,30 @@ int main(int argc, char **argv) {
if (inputs.size() > config.batch && config.batch != -1) {
inputs.resize(config.batch);
}
for (int i = 0; i < inputs.size(); i++) {
inputs[i] = model->MakeInput("", 0, inputs[i]);
}

std::vector <std::string> outputs;
static int tokens = 0;
auto st = std::chrono::system_clock::now();
chatGlm->ResponseBatch(inputs, outputs, [](int index, std::vector <std::string> &contents) {
if (index != -1) {
for (int i = 0; i < contents.size(); i++) {
tokens += (contents[i].size() > 0);

if (inputs.size() > 0) {
model->ResponseBatch(inputs, outputs, [](int index, std::vector<std::string> &contents) {
if (index != -1) {
for (int i = 0; i < contents.size(); i++) {
tokens += (contents[i].size() > 0);
}
}
}
});
});
} else {
outputs.push_back(model->Response(inputs[0], [](int index, const char *contents) {
if (index != -1) {
tokens++;
}
}));
}

float spend = fastllm::GetSpan(st, std::chrono::system_clock::now());

if (config.output != "") {
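The refactor above drops the hard-coded factoryllm model table and instead loads whatever model the path points to through fastllm::CreateLLMModelFromFile. A minimal sketch of the new flow, using only calls that appear in this diff (the model path, prompt, and token limit are placeholder values):

#include <cstdio>
#include <string>
#include "model.h"
#include "utils.h"

int main() {
    fastllm::SetThreads(4);                                       // worker threads, as the benchmark does
    auto model = fastllm::CreateLLMModelFromFile("model.bin");    // placeholder path to any supported model file
    model->output_token_limit = 256;                              // placeholder cap on generated tokens
    std::string prompt = model->MakeInput("", 0, "Hello");        // wrap raw text in the model's prompt format
    model->Response(prompt, [](int index, const char *content) {  // index == -1 signals the end of generation
        if (index != -1) {
            printf("%s", content);                                // stream tokens as they arrive
        }
    });
    return 0;
}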
11 changes: 10 additions & 1 deletion example/webui/web/index.html
@@ -802,6 +802,7 @@
const API_URL = "chat";
let data = [];
let input_str;
let uuid = "";
let loading = false;
const scrollToBottom = () => {
if (messagsEle.scrollHeight - messagsEle.scrollTop - messagsEle.clientHeight < 128) {
@@ -1004,7 +1005,15 @@
let progressData = "";
let resStr = "";
const reqWord = async (refresh) => {
let headers = {"Content-Type": "application/json"};
if (uuid == "") {
uuid = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function (c) {
var r = Math.random() * 16 | 0,
v = c == 'x' ? r : (r & 0x3 | 0x8);
return v.toString(16);
});
}

let headers = {"Content-Type": "application/json", "uuid" : uuid};
let idx = refresh ? refreshIdx : data.length;
let dataSlice = [data[idx - 1]];
const res = await fetch(API_URL, {
104 changes: 59 additions & 45 deletions example/webui/webui.cpp
@@ -10,70 +10,84 @@
#include <thread>
#include <stdlib.h>
#include <string>
#include <mutex>

std::string GBKToUTF8(const std::string& strGBK);
struct ChatSession {
std::string history = "";
std::string input = "";
std::string output = "";
int round = 0;
int status = 0; // 0: idle, 1: result ready, 2: result already written back
};

std::map <std::string, ChatSession*> sessions;
std::mutex locker;

int main(int argc, char** argv) {
fastllm::SetThreads(8);
fastllm::ChatGLMModel chatGlm;

std::string type;
std::cout << "Use chatglm-6b-v1.1-int4 or chatglm-6b-v1.1-int8 ? 8/4 (Default = 4) ";
std::getline(std::cin, type);
if (type == "8") {
chatGlm.LoadFromFile("chatglm-6b-v1.1-int8.bin");
} else if (type == "4" || type == "") {
chatGlm.LoadFromFile("chatglm-6b-v1.1-int4.bin");
}

std::string history = "";
int round = 0;
static std::string ss = "";
fastllm::ChatGLMModel model;
model.LoadFromFile(argv[1]);

httplib::Server svr;
std::atomic_bool waiting;
waiting = false;
std::string last_request = "";

auto chat = [&](std::string input) {
auto chat = [&](ChatSession *session, const std::string input) {
if (input == "reset" || input == "stop") {
history = "";
round = 0;
ss = "<eop>\n";
session->history = "";
session->round = 0;
session->output = "<eop>\n";
session->status = 2;
} else {
history += ("[Round " + std::to_string(round++) + "]\n问:" + input);
auto prompt = round > 1 ? history : input;

waiting = true;
std::string ret = chatGlm.Response(prompt, [](int index, const char* content) {
if (index == -1) {
ss += "<eop>\n";
} else {
ss += std::string(content);
}
});
waiting = false;

history += ("答:" + ret + "\n");
session->history += ("[Round " + std::to_string(session->round++) + "]\n问:" + input);
auto prompt = session->round > 1 ? session->history : input;
auto inputs = model.weight.tokenizer.Encode(prompt);
std::vector<int> tokens;
for (int i = 0; i < inputs.Count(0); i++) {
tokens.push_back(((float *) inputs.cpuData)[i]);
}
int handleId = model.LaunchResponseTokens(tokens);
std::vector<float> results;
while (true) {
auto result = model.FetchResponseTokens(handleId);
if (result.first == false) {
break;
} else {
results.clear();
results.push_back(result.second[0]);
session->output += model.weight.tokenizer.Decode(fastllm::Data (fastllm::DataType::FLOAT32, {(int)results.size()}, results));
}
if (session->status == 2) {
break;
}
}
session->history += ("答:" + session->output + "\n");
session->output += "<eop>\n";
session->status = 2;
}
};

svr.Post("/chat", [&](const httplib::Request &req, httplib::Response &res) {
if (req.body == last_request) {
res.set_content(ss, "text/plain");
return;
const std::string uuid = req.get_header_value("uuid");
locker.lock();
if (sessions.find(uuid) == sessions.end()) {
sessions[uuid] = new ChatSession();
}
if (waiting) {
res.set_content(ss, "text/plain");
auto *session = sessions[uuid];
locker.unlock();

if (session->status != 0) {
res.set_content(session->output, "text/plain");
if (session->status == 2) {
session->status = 0;
}
} else {
ss = "";
last_request = req.body;
std::thread chat_thread(chat, last_request);
session->output = "";
session->status = 1;
std::thread chat_thread(chat, session, req.body);
chat_thread.detach();
}
});

svr.set_mount_point("/", "web");
svr.set_mount_point("/", "../example/webui/web");
std::cout << ">>> please open http://127.0.0.1:8081\n";
svr.listen("0.0.0.0", 8081);

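With the rewritten handler, /chat becomes a small polling protocol keyed by the "uuid" request header: the first POST for a session starts generation in a detached background thread, and each later POST with the same header returns the text produced so far, with "<eop>\n" appended once the answer is complete. A rough client-side sketch of that loop, assuming cpp-httplib's Client from the same httplib.h the webui already uses (the session id, payload, and polling interval are illustrative):

#include <chrono>
#include <iostream>
#include <string>
#include <thread>
#include "httplib.h"

int main() {
    httplib::Client cli("127.0.0.1", 8081);
    httplib::Headers headers = {{"uuid", "2b1f0a58-7c3d-4e21-9f7a-0123456789ab"}}; // made-up session id
    std::string body = "Hello";                                           // prompt text, sent as the request body

    cli.Post("/chat", headers, body, "application/json");                 // first call: server starts generating
    while (true) {
        std::this_thread::sleep_for(std::chrono::milliseconds(200));
        auto res = cli.Post("/chat", headers, body, "application/json");  // poll for partial output
        if (!res) {
            break;                                                        // transport error, give up
        }
        std::cout << res->body << std::endl;                              // everything generated so far
        if (res->body.find("<eop>") != std::string::npos) {
            break;                                                        // generation finished
        }
    }
    return 0;
}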
4 changes: 4 additions & 0 deletions include/devices/cpu/cpudevice.h
@@ -107,6 +107,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuLlamaRotatePosition2DOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuRepeatPenaltyOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
17 changes: 17 additions & 0 deletions include/devices/cuda/cudadevice.h
@@ -24,6 +24,11 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaRMSNormOp : BaseOperator {
bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaLinearOp : BaseOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
@@ -58,6 +63,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaSiluOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaMulOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
@@ -66,6 +75,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaMulToOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaAttentionMaskOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
@@ -83,6 +96,10 @@ namespace fastllm {
class CudaRotatePosition2DOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaLlamaRotatePosition2DOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
}

#endif //FASTLLM_CUDADEVICE_H