Skip to content

Commit

Permalink
调整以适应fastllm新的提交
Browse files Browse the repository at this point in the history
  • Loading branch information
wildkid1024 committed Jun 18, 2023
1 parent 5f9d3f7 commit 854776a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 26 deletions.
31 changes: 19 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2 -march=native")
endif()

if (PY_API)
# add_compile_definitions(PY_API)
add_definitions(-DPY_API)
set(PYBIND third_party/pybind11)
add_subdirectory(${PYBIND})
pybind11_add_module(pyfastllm src/pybinding.cpp src/chatglm.cpp src/moss.cpp src/vicuna.cpp src/fastllm.cpp src/fastllm.cpp include/basellm.h include/chatglm.h include/factoryllm.h include/fastllm.h include/utils.h)
endif()

find_package(Python3 REQUIRED)
include_directories(include third_party/pybind11/include)

message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/devices/cpu/cpudevice.cpp src/executor.cpp src/chatglm.cpp src/moss.cpp src/vicuna.cpp)
Expand All @@ -39,9 +29,24 @@ if (USE_CUDA)
add_compile_definitions(USE_CUDA)
set(FASTLLM_CUDA_SOURCES src/fastllm-cuda.cu src/devices/cuda/cudadevice.cpp)
set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} cublas)
#set(CMAKE_CUDA_ARCHITECTURES "70")
set(CMAKE_CUDA_ARCHITECTURES "70")
endif()

if (PY_API)
set(PYBIND third_party/pybind11)
add_subdirectory(${PYBIND})
add_compile_definitions(PY_API)

set(Python3_ROOT_DIR "/usr/local/python3.10.6/bin/")
find_package(Python3 REQUIRED)

include_directories(include third_party/pybind11/include)
file(GLOB FASTLLM_CXX_HEADERS include/*.h)
add_library(pyfastllm MODULE src/pybinding.cpp ${FASTLLM_CXX_SOURCES} ${FASTLLM_CXX_HEADERS} ${FASTLLM_CUDA_SOURCES})
target_link_libraries(pyfastllm PUBLIC pybind11::module ${FASTLLM_LINKED_LIBS})
pybind11_extension(pyfastllm)
else()

add_library(fastllm OBJECT
${FASTLLM_CXX_SOURCES}
${FASTLLM_CUDA_SOURCES}
Expand All @@ -58,4 +63,6 @@ add_executable(webui example/webui/webui.cpp)
target_link_libraries(webui fastllm)

add_executable(benchmark example/benchmark/benchmark.cpp)
target_link_libraries(benchmark fastllm)
target_link_libraries(benchmark fastllm)

endif()
14 changes: 9 additions & 5 deletions cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import sys
import platform
import logging
Expand All @@ -6,7 +7,7 @@

logging.info(f"python gcc version:{platform.python_compiler()}")

sys.path.append('./build')
sys.path.append('./build-py')

def args_parser():
parser = argparse.ArgumentParser(description='pyfastllm')
Expand All @@ -28,14 +29,17 @@ def print_back(idx:int, content: str):
sys.stdout.flush()

def main(args):
    """Interactive chat REPL for a pyfastllm ChatGLM model.

    Loads model weights from ``args.path``, warms the model up, then
    repeatedly reads a prompt from stdin and streams the reply through
    the ``print_back`` callback until the user types ``exit``.

    Args:
        args: argparse.Namespace with at least a ``path`` attribute
              (path to the .bin model file).
    """
    model_path = args.path
    model = pyfastllm.ChatGLMModel()
    model.load_weights(model_path)
    # warmup() presumably pre-runs the graph so the first real response
    # is not slow -- TODO confirm against pyfastllm's WarmUp binding.
    model.warmup()

    while True:
        prompt = input("User: ")
        # Check the sentinel BEFORE dispatching: the previous version
        # sent "exit" to the model and only then left the loop.
        if prompt == "exit":
            break
        model.response(prompt, print_back)
        print()
        sys.stdout.flush()

if __name__ == "__main__":
args = args_parser()
Expand Down
10 changes: 4 additions & 6 deletions include/basellm.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
#pragma once
#include "fastllm.h"

<<<<<<< HEAD
// typedef void(*RuntimeResult) (int index, const char* content);//实时生成的内容回调 index: 0开始回复,-1本次回复结束

// typedef void(*RuntimeResult) (int index, const char* content); //实时生成的内容回调 index: 0开始回复,-1本次回复结束
// typedef void(*RuntimeResultBatch) (int index, std::vector <std::string> &contents); //实时生成的内容回调 index: 0开始回复,-1本次回复结束

using RuntimeResult = std::function<void(int index, const char* content)>;
=======
typedef void(*RuntimeResult) (int index, const char* content); //实时生成的内容回调 index: 0开始回复,-1本次回复结束
typedef void(*RuntimeResultBatch) (int index, std::vector <std::string> &contents); //实时生成的内容回调 index: 0开始回复,-1本次回复结束
>>>>>>> dev/master
using RuntimeResultBatch = std::function<void(int index, std::vector <std::string> &contents)>;

namespace fastllm {
class basellm {
Expand Down
18 changes: 15 additions & 3 deletions src/pybinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@ using namespace pybind11::literals;
#endif


const std::string VERSION_INFO = "0.0.1";

#ifdef PY_API

PYBIND11_MODULE(pyfastllm, m) {
m.doc() = "fastllm python bindings";

m.def("set_threads", &fastllm::SetThreads)
.def("get_threads", &fastllm::GetThreads)
.def("set_low_memory", &fastllm::SetLowMemMode)
.def("get_low_memory", &fastllm::GetLowMemMode);


py::class_<fastllm::ChatGLMModel>(m, "ChatGLMModel")
.def(py::init<>())
Expand All @@ -27,17 +35,21 @@ PYBIND11_MODULE(pyfastllm, m) {
.def(py::init<>())
.def("load_weights", &fastllm::MOSSModel::LoadFromFile)
.def("response", &fastllm::MOSSModel::Response)
.def("causal_mask", &fastllm::MOSSModel::WarmUp)
.def("warmup", &fastllm::MOSSModel::WarmUp)
.def("save_lowbit_model", &fastllm::MOSSModel::SaveLowBitModel);

py::class_<fastllm::VicunaModel>(m, "VicunaModel")
.def(py::init<>())
.def("load_weights", &fastllm::VicunaModel::LoadFromFile)
.def("response", &fastllm::VicunaModel::Response)
.def("causal_mask", &fastllm::VicunaModel::WarmUp)
.def("warmup", &fastllm::VicunaModel::WarmUp)
.def("save_lowbit_model", &fastllm::VicunaModel::SaveLowBitModel);

#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif

}

#endif

0 comments on commit 854776a

Please sign in to comment.