
Commit

Support importing Python custom-model definitions from a file
黄宇扬 committed Jul 22, 2024
1 parent 506fccf commit cb9ff7a
Showing 5 changed files with 91 additions and 42 deletions.
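This commit moves the Qwen2 graph definition out of the example script into its own module (example/python/qwen2.py), exports it through a `__model__` attribute, and teaches the command-line helpers to load such a file through a new `--custom` flag. A minimal sketch of the resulting usage, based only on the code below (the model path comes from the example and is illustrative):

    # custom_model.py after this commit: the graph class now lives in qwen2.py
    from ftllm import llm
    from qwen2 import Qwen2Model

    model = llm.model("/mnt/hfmodels/Qwen/Qwen2-7B-Instruct", graph = Qwen2Model)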
37 changes: 1 addition & 36 deletions example/python/custom_model.py
@@ -1,45 +1,10 @@
from ftllm import llm
from ftllm.llm import ComputeGraph
from qwen2 import Qwen2Model
import os
import math

root_path = "/mnt/hfmodels/"
model_path = os.path.join(root_path, "Qwen/Qwen2-7B-Instruct")

class Qwen2Model(ComputeGraph):
    def build(self):
        weight, data, config = self.weight, self.data, self.config
        head_dim = config["hidden_size"] // config["num_attention_heads"]
        self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"])
        self.DataTypeAs(data["hiddenStates"], data["atype"])
        for i in range(config["num_hidden_layers"]):
            pastKey = data["pastKey."][i]
            pastValue = data["pastValue."][i]
            layer = weight["model.layers."][i]
            self.RMSNorm(data["hiddenStates"], layer[".input_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
            self.Linear(data["attenInput"], layer[".self_attn.q_proj.weight"], layer[".self_attn.q_proj.bias"], data["q"])
            self.Linear(data["attenInput"], layer[".self_attn.k_proj.weight"], layer[".self_attn.k_proj.bias"], data["k"])
            self.Linear(data["attenInput"], layer[".self_attn.v_proj.weight"], layer[".self_attn.v_proj.bias"], data["v"])
            self.ExpandHead(data["q"], head_dim)
            self.ExpandHead(data["k"], head_dim)
            self.ExpandHead(data["v"], head_dim)
            self.LlamaRotatePosition2D(data["q"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
            self.LlamaRotatePosition2D(data["k"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
            self.FusedAttention(data["q"], pastKey, pastValue, data["k"], data["v"], data["attenInput"],
                                data["attentionMask"], data["attenOutput"], data["seqLens"], 1.0 / math.sqrt(head_dim))
            self.Linear(data["attenOutput"], layer[".self_attn.o_proj.weight"], layer[".self_attn.o_proj.bias"], data["attenLastOutput"])
            self.AddTo(data["hiddenStates"], data["attenLastOutput"])
            self.RMSNorm(data["hiddenStates"], layer[".post_attention_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
            self.Linear(data["attenInput"], layer[".mlp.gate_proj.weight"], layer[".mlp.gate_proj.bias"], data["w1"])
            self.Linear(data["attenInput"], layer[".mlp.up_proj.weight"], layer[".mlp.up_proj.bias"], data["w3"])
            self.Silu(data["w1"], data["w1"])
            self.MulTo(data["w1"], data["w3"])
            self.Linear(data["w1"], layer[".mlp.down_proj.weight"], layer[".mlp.down_proj.bias"], data["w2"])
            self.AddTo(data["hiddenStates"], data["w2"])
        self.SplitLastTokenStates(data["hiddenStates"], data["seqLens"], data["lastTokensStates"])
        self.RMSNorm(data["lastTokensStates"], weight["model.norm.weight"], config["rms_norm_eps"], data["lastTokensStates"])
        self.Linear(data["lastTokensStates"], weight["lm_head.weight"], weight["lm_head.bias"], data["logits"])

model = llm.model(model_path, graph = Qwen2Model)
prompt = "北京有什么景点?"
messages = [
40 changes: 40 additions & 0 deletions example/python/qwen2.py
@@ -0,0 +1,40 @@
from ftllm.llm import ComputeGraph
import math

class Qwen2Model(ComputeGraph):
    def build(self):
        weight, data, config = self.weight, self.data, self.config
        config["max_positions"] = 128000

        head_dim = config["hidden_size"] // config["num_attention_heads"]
        self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"])
        self.DataTypeAs(data["hiddenStates"], data["atype"])
        for i in range(config["num_hidden_layers"]):
            pastKey = data["pastKey."][i]
            pastValue = data["pastValue."][i]
            layer = weight["model.layers."][i]
            self.RMSNorm(data["hiddenStates"], layer[".input_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
            self.Linear(data["attenInput"], layer[".self_attn.q_proj.weight"], layer[".self_attn.q_proj.bias"], data["q"])
            self.Linear(data["attenInput"], layer[".self_attn.k_proj.weight"], layer[".self_attn.k_proj.bias"], data["k"])
            self.Linear(data["attenInput"], layer[".self_attn.v_proj.weight"], layer[".self_attn.v_proj.bias"], data["v"])
            self.ExpandHead(data["q"], head_dim)
            self.ExpandHead(data["k"], head_dim)
            self.ExpandHead(data["v"], head_dim)
            self.LlamaRotatePosition2D(data["q"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
            self.LlamaRotatePosition2D(data["k"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
            self.FusedAttention(data["q"], pastKey, pastValue, data["k"], data["v"], data["attenInput"],
                                data["attentionMask"], data["attenOutput"], data["seqLens"], 1.0 / math.sqrt(head_dim))
            self.Linear(data["attenOutput"], layer[".self_attn.o_proj.weight"], layer[".self_attn.o_proj.bias"], data["attenLastOutput"])
            self.AddTo(data["hiddenStates"], data["attenLastOutput"])
            self.RMSNorm(data["hiddenStates"], layer[".post_attention_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
            self.Linear(data["attenInput"], layer[".mlp.gate_proj.weight"], layer[".mlp.gate_proj.bias"], data["w1"])
            self.Linear(data["attenInput"], layer[".mlp.up_proj.weight"], layer[".mlp.up_proj.bias"], data["w3"])
            self.Silu(data["w1"], data["w1"])
            self.MulTo(data["w1"], data["w3"])
            self.Linear(data["w1"], layer[".mlp.down_proj.weight"], layer[".mlp.down_proj.bias"], data["w2"])
            self.AddTo(data["hiddenStates"], data["w2"])
        self.SplitLastTokenStates(data["hiddenStates"], data["seqLens"], data["lastTokensStates"])
        self.RMSNorm(data["lastTokensStates"], weight["model.norm.weight"], config["rms_norm_eps"], data["lastTokensStates"])
        self.Linear(data["lastTokensStates"], weight["lm_head.weight"], weight["lm_head.bias"], data["logits"])

__model__ = Qwen2Model
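The `__model__` attribute is the hook that the loader added to tools/fastllm_pytools/util.py (shown further down) looks for. A sketch of that loading path in isolation, mirroring the util.py change:

    # how --custom consumes a model-definition file (mirrors util.py below)
    import importlib.util

    spec = importlib.util.spec_from_file_location("custom_module", "example/python/qwen2.py")
    custom_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(custom_module)
    graph = getattr(custom_module, "__model__", None)  # -> Qwen2Model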
34 changes: 30 additions & 4 deletions src/models/graph/fastllmjson.cpp
@@ -3,21 +3,47 @@
namespace fastllm {
    class FastllmJsonModelConfig : GraphLLMModelConfig {
    public:
        json11::Json config;
        json11::Json json, graphJson, configJson, tokenizerConfigJson, generationConfigJson;

        void Init(const std::string &configString) {
            std::string error;
            config = json11::Json::parse(configString, error);
            json = json11::Json::parse(configString, error);
            graphJson = json["graph"];
            configJson = json["config"];
            tokenizerConfigJson = json["tokenizer_config"];
            generationConfigJson = json["generation_config"];
        }

        void InitParams(GraphLLMModel *model) {
            if (configJson["max_positions"].is_number()) {
                model->max_positions = configJson["max_positions"].int_value();
            }
            if (configJson["rope_base"].is_number()) {
                model->rope_base = configJson["rope_base"].number_value();
            }
            if (configJson["rope_factor"].is_number()) {
                model->rope_factor = configJson["rope_factor"].number_value();
            }

            if (configJson["pre_prompt"].is_string()) {
                model->pre_prompt = configJson["pre_prompt"].string_value();
            }
            if (configJson["user_role"].is_string()) {
                model->user_role = configJson["user_role"].string_value();
            }
            if (configJson["bot_role"].is_string()) {
                model->bot_role = configJson["bot_role"].string_value();
            }
            if (configJson["history_sep"].is_string()) {
                model->history_sep = configJson["history_sep"].string_value();
            }
        }
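On the Python side, these fields originate from whatever a graph's build() writes into self.config, which ComputeGraph.__str__ (see the llm.py change below) serializes into the "config" section. Of the keys handled here, only max_positions is actually set in this commit; the other values below are illustrative:

    # inside a ComputeGraph subclass's build(); every value except
    # max_positions is illustrative, not part of this commit
    config["max_positions"] = 128000
    config["rope_base"] = 10000.0
    config["user_role"] = "user"
    config["bot_role"] = "assistant"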

        std::map <std::string, std::vector <std::pair <std::string, DataType> > >
        GetTensorMap(GraphLLMModel *model, const std::vector <std::string> &tensorNames) {
            std::string embeddingName = "";
            std::map <std::string, std::vector <std::pair <std::string, DataType> > > ret;
            for (auto &op : config.array_items()) {
            for (auto &op : graphJson.array_items()) {
                std::string type = op["type"].string_value();
                std::map <std::string, std::string> weights;
                for (auto &it : op["nodes"].object_items()) {
@@ -54,7 +80,7 @@ namespace fastllm {
                wNodes[it.first] = ComputeGraphNode(it.first);
            }

            for (auto &op : config.array_items()) {
            for (auto &op : graphJson.array_items()) {
                std::string type = op["type"].string_value();
                std::map <std::string, std::string> datas;
                std::map <std::string, float> floatParams;
9 changes: 8 additions & 1 deletion tools/fastllm_pytools/llm.py
@@ -1020,7 +1020,14 @@ def __init__ (self):
        self.graph = []

    def __str__(self):
        return json.dumps(self.graph, indent = 4, default = lambda x: x.to_json())
        output = {"graph": self.graph}
        if (hasattr(self, "config")):
            output["config"] = self.config
        if (hasattr(self, "tokenizer_config")):
            output["tokenizer_config"] = self.tokenizer_config
        if (hasattr(self, "generation_config")):
            output["generation_config"] = self.generation_config
        return json.dumps(output, indent = 4, default = lambda x: x.to_json())

    def Print(self, input):
        self.graph.append({"type": "Print",
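Parsed back on the C++ side (fastllmjson.cpp above), this means the serialized description is no longer a bare op list but has the shape sketched below:

    # abbreviated sketch of what str(graph) now produces
    # {
    #     "graph": [ {"type": "Embedding", ...}, ... ],
    #     "config": { "max_positions": 128000 },
    #     "tokenizer_config": { ... },
    #     "generation_config": { ... }
    # }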
13 changes: 12 additions & 1 deletion tools/fastllm_pytools/util.py
@@ -12,6 +12,7 @@ def make_normal_parser(des: str) -> argparse.ArgumentParser:
    parser.add_argument('--kv_cache_limit', type = str, default = "auto", help = 'kv缓存最大使用量')
    parser.add_argument('--max_batch', type = int, default = -1, help = '每次最多同时推理的询问数量')
    parser.add_argument('--device', type = str, help = '使用的设备')
    parser.add_argument('--custom', type = str, default = "", help = '指定描述自定义模型的python文件')
    return parser

def make_normal_llm_model(args):
@@ -29,7 +30,17 @@ def make_normal_llm_model(args):
    llm.set_cpu_low_mem(args.low)
    if (args.cuda_embedding):
        llm.set_cuda_embedding(True)
    model = llm.model(args.path, dtype = args.dtype, tokenizer_type = "auto")
    graph = None
    if (args.custom != ""):
        import importlib.util
        spec = importlib.util.spec_from_file_location("custom_module", args.custom)
        if spec is None:
            raise ImportError(f"Cannot load module at {args.custom}")
        custom_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(custom_module)
        if (hasattr(custom_module, "__model__")):
            graph = getattr(custom_module, "__model__")
    model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto")
    model.set_atype(args.atype)
    if (args.max_batch > 0):
        model.set_max_batch(args.max_batch)
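Putting it together, a launcher built on these helpers can now be pointed at a definition file. A hedged sketch; the util module path and the -p model-path flag are assumptions, since neither is visible in this diff:

    from ftllm.util import make_normal_parser, make_normal_llm_model  # module path assumed

    parser = make_normal_parser("custom-model demo")
    args = parser.parse_args(["-p", "/mnt/hfmodels/Qwen/Qwen2-7B-Instruct",
                              "--custom", "example/python/qwen2.py"])
    model = make_normal_llm_model(args)  # imports qwen2.py and passes __model__ as graph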
