Set up the basic INT4 model architecture and constructors
zhaosiyuan1098 committed Nov 10, 2024
1 parent 6b0cdad commit 906875c
Showing 18 changed files with 269 additions and 15 deletions.
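
This commit adds the INT4 model skeleton: constructors for Int4llamaAttention, Int4llamaDecoderLayer, Int4llamaDecoder, and Int4LlamaForCausalLM, with the forward passes and the Linear_FP_int4 projections still commented out. As a minimal sketch of how the new classes chain together, mirroring the INT4 case added to src/main.cpp below (get_opt_model_config, LLaMA_7B, and opt_params.h come from the existing code; the model path is only illustrative):

#include <string>
#include "opt_params.h"
#include "llamaForCausalLM_int4.h"

int main() {
    // Illustrative path; main.cpp prefixes INT4 models with "../models/INT4/".
    std::string m_path = "../models/INT4/LLaMA_7B";
    // Constructor chain introduced by this commit:
    // Int4LlamaForCausalLM -> Int4llamaDecoder -> Int4llamaDecoderLayer -> Int4llamaAttention
    Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(LLaMA_7B));
    return 0;
}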
2 changes: 1 addition & 1 deletion .gitignore
@@ -5,6 +5,6 @@ build/
# MacOS Cache
.DS_Store

vscode/
.vscode/
models/

File renamed without changes.
51 changes: 51 additions & 0 deletions include/modules/llamaAttention_int4.h
@@ -0,0 +1,51 @@
#include <utility>
#include <cstdlib>
#include <string>
#include <vector>

#include "common.h"
#include "operators.h"
#include "utlis.h"

struct Int4llamaAttention_output {
Matrix3D<float> attn_output;
Matrix3D<float> attn_probs_reshaped;
std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
};
struct Int4llamaAttention_input {
Matrix3D<float> hidden_states;
Matrix3D<float> attention_mask;
Matrix3D<float> past_key, past_value;
bool has_past_key_value = false;
int layer_idx;

Int4llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
: hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}

Int4llamaAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<float> past_key_,
Matrix3D<float> past_value_, bool has_past_key_value_, int layer_idx_)
: hidden_states(hidden_states_),
attention_mask(attention_mask_),
past_key(past_key_),
past_value(past_value_),
has_past_key_value(has_past_key_value_),
layer_idx(layer_idx_) {}
};


class Int4llamaAttention {
public:
Int4llamaAttention(std::string param_path, const struct model_config config);
Int4llamaAttention() {}
// static void initialized_memory(const struct model_config config);
// struct Int4llamaAttention_output forward(const struct Int4llamaAttention_input &input);

private:
void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen);
void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int sqlen);
int embed_dim, num_heads, head_dim;
// Linear_FP_int4 k_proj, v_proj, q_proj, o_proj;
// RotaryPosEmb rotary_pos_emb;
// BMM_F32T qk_bmm, pv_bmm;
std::string profile_name = "Int4llamaAttention";
};
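
The two Int4llamaAttention_input constructors separate the first (prefill) call, which has no KV cache, from later decoding steps that pass the cached keys and values back in. A minimal sketch of both variants, assuming the Matrix3D<float> tensors are already populated (layer_idx 0 is arbitrary):

// Prefill: no cached key/value, so has_past_key_value stays false.
Int4llamaAttention_input prefill_input(hidden_states, attention_mask, /*layer_idx=*/0);

// Decode step: hand back the key/value tensors cached from the previous pass.
Int4llamaAttention_input decode_input(hidden_states, attention_mask, past_key, past_value,
                                      /*has_past_key_value=*/true, /*layer_idx=*/0);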
@@ -1,10 +1,10 @@
// #include <cstdlib>
// #include <string>
// #include <vector>
#include <cstdlib>
#include <string>
#include <vector>

// #include "llamaDecoderLayer.h"
// #include "common.h"
// #include "operators.h"
#include "llamaDecoderLayer_fp32.h"
#include "common.h"
#include "operators.h"

// struct Fp32llamaDecoder_output {
// Matrix3D<float> last_hidden_state;
37 changes: 37 additions & 0 deletions include/modules/llamaDecoder_int4.h
@@ -0,0 +1,37 @@
#include "llamaDecoderlayer_int4.h"

struct Int4llamaDecoder_output {
Matrix3D<float> last_hidden_state;
std::vector<Matrix3D<float>> past_keys, past_values;
};
struct Int4llamaDecoder_input {
Matrix3D<int> input_ids;
std::vector<Matrix3D<float>> past_keys, past_values;
bool has_past_keys_values;

Int4llamaDecoder_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
Int4llamaDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
std::vector<Matrix3D<float>> past_values_)
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
has_past_keys_values = true;
}
};

class Int4llamaDecoder {
public:
Int4llamaDecoder(std::string param_path, const struct model_config config);
Int4llamaDecoder(){};
// Matrix3D<float> prepare_decoder_attention_mask(int length, int past_length);
// struct Int4llamaDecoder_output forward(const struct Int4llamaDecoder_input& input);
// Embedding embed_tokens;
// LlamaRMSNorm norm;
int voc_size, embed_dim, padding_idx, hidden_dim, num_heads;
std::vector<Int4llamaDecoderLayer> layers;
std::string profile_name = "Int4llamaDecoder";

private:
float* attention_mask_buf;
float* pos_embeds_buf;
float* last_hidden_states_buf;
float* hidden_states_buf;
};
@@ -1,6 +1,5 @@
// #include "llamaAttention.h"
// #include "common.h"
// #include "operators.h"
#include "llamaAttention_fp32.h"


// struct Fp32llamaDecoderLayer_output {
// Matrix3D<float> hidden_states;
51 changes: 51 additions & 0 deletions include/modules/llamaDecoderlayer_int4.h
@@ -0,0 +1,51 @@
#include "llamaAttention_int4.h"

struct Int4llamaDecoderLayer_output
{
Matrix3D<float> hidden_states;
Matrix3D<float> attentions;
std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;

Int4llamaDecoderLayer_output(Matrix3D<float> hidden_states_, Matrix3D<float> attentions_,
std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value_)
{
hidden_states = hidden_states_;
attentions = attentions_;
past_key_value = past_key_value_;
};
};
struct Int4llamaDecoderLayer_input
{
Matrix3D<float> hidden_states;
Matrix3D<float> attention_mask;
Matrix3D<float> past_key, past_value;
bool has_past_key_value = false;

Int4llamaDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_)
{
hidden_states = hidden_states_;
attention_mask = attention_mask_;
has_past_key_value = false;
}
Int4llamaDecoderLayer_input(Matrix3D<float> &hidden_states_, Matrix3D<float> &attention_mask_,
Matrix3D<float> past_key_, Matrix3D<float> past_value_)
{
hidden_states = hidden_states_;
attention_mask = attention_mask_;
past_key = past_key_;
past_value = past_value_;
has_past_key_value = true;
}
};

class Int4llamaDecoderLayer {
public:
Int4llamaDecoderLayer(std::string param_path, const struct model_config config, int layer_idx);
struct Int4llamaDecoderLayer_output forward(const struct Int4llamaDecoderLayer_input &input);

int embed_dim, num_attention_heads, hidden_dim, layer_idx;
// LlamaRMSNorm input_layernorm, post_attention_layernorm; // from torch_int.nn
// Linear_FP_int4 gate_proj, down_proj, up_proj;
Int4llamaAttention attn;
std::string profile_name = "Int4llamaDecoderLayer";
};
@@ -1,4 +1,4 @@
// #include "llamaDecoder.h"
#include "llamaDecoder_fp32.h"

// struct Fp32LlamaForCausalLM_output {
// Matrix3D<float> logits;
@@ -31,3 +31,35 @@
// float* logits_output;
// float* lm_head_weight;
// };


struct Int4LlamaForCausalLM_output {
Matrix3D<float> logits;
std::vector<Matrix3D<float>> past_keys, past_values;
};
struct Int4LlamaForCausalLM_input {
Matrix3D<int> input_ids;
std::vector<Matrix3D<float>> past_keys, past_values;
bool has_past_keys_values;

Int4LlamaForCausalLM_input() {}
Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
std::vector<Matrix3D<float>> past_values_)
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
has_past_keys_values = true;
}
};

class Int4LlamaForCausalLM {
public:
Int4LlamaForCausalLM(std::string param_path, const struct model_config config);
struct Int4LlamaForCausalLM_output forward(const struct Int4LlamaForCausalLM_input& input);

private:
// Int4llamaDecoder decoder;
// Linear_FP_int4 lm_head;
std::string profile_name = "Int4LlamaForCausalLM";
float* logits_output;
uint8_t* lm_head_weight;
};
32 changes: 32 additions & 0 deletions include/modules/llamaForCausalLM_int4.h
@@ -0,0 +1,32 @@
#include "llamaDecoder_int4.h"

struct Int4LlamaForCausalLM_output {
Matrix3D<float> logits;
std::vector<Matrix3D<float>> past_keys, past_values;
};
struct Int4LlamaForCausalLM_input {
Matrix3D<int> input_ids;
std::vector<Matrix3D<float>> past_keys, past_values;
bool has_past_keys_values;

Int4LlamaForCausalLM_input() {}
Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
Int4LlamaForCausalLM_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
std::vector<Matrix3D<float>> past_values_)
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
has_past_keys_values = true;
}
};

class Int4LlamaForCausalLM {
public:
Int4LlamaForCausalLM(std::string param_path, const struct model_config config);
struct Int4LlamaForCausalLM_output forward(const struct Int4LlamaForCausalLM_input& input);

private:
Int4llamaDecoder decoder;
// Linear_FP_int4 lm_head;
std::string profile_name = "Int4LlamaForCausalLM";
float* logits_output;
uint8_t* lm_head_weight;
};
6 changes: 3 additions & 3 deletions src/main.cpp
@@ -3,7 +3,7 @@
#include "model.h"
#include "utlis.h"
#include "opt_params.h"
#include "llamaAttention.h"
#include "llamaForCausalLM_int4.h"

std::map<std::string, int> model_config = {{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B}, {"LLaMA_7B_AWQ", LLaMA_7B}, {"LLaMA_7B_2_chat", LLaMA_7B}};

@@ -105,7 +105,7 @@ int main(int argc, char **argv)
case FP32:
{
std::cout << m_path << std::endl;
Fp32llamaAttention a = Fp32llamaAttention(m_path, get_opt_model_config(model_id));
// Fp32llamaAttention a = Fp32llamaAttention(m_path, get_opt_model_config(model_id));
// Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id));
std::cout << "Finished!" << std::endl;

@@ -124,7 +124,7 @@
case INT4:
{
m_path = "../models/INT4/" + m_path;
Fp32llamaAttention a = Fp32llamaAttention(m_path, get_opt_model_config(model_id));
Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id));
std::cout << "Finished!" << std::endl;

// Get input from the user
@@ -1,4 +1,4 @@
#include"llamaAttention.h"
#include"llamaAttention_fp32.h"



14 changes: 14 additions & 0 deletions src/modules/llamaAttention_int4.cpp
@@ -0,0 +1,14 @@
#include "llamaAttention_int4.h"

Int4llamaAttention::Int4llamaAttention(std::string param_path, const struct model_config config)
{
std::cout << param_path << std::endl;
uint8_t *q_weight, *k_weight, *v_weight, *o_weight;
allocate_aligned_memory(q_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
allocate_aligned_memory(k_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
allocate_aligned_memory(v_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
allocate_aligned_memory(o_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
// this->q_proj =
// Linear_FP_int4(Matrix3D<uint8_t>(q_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/q_proj");
std::cout << "Allocated memory" << std::endl;
}
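
The division by 2 in each allocation is the INT4 packing: two 4-bit weights share one uint8_t, so an embed_dim x embed_dim projection occupies embed_dim * embed_dim / 2 bytes, which is also why the commented Linear_FP_int4 shape uses config.embed_dim / 2. A small sizing sketch (the helper name is illustrative, not part of the commit):

// Bytes needed for a rows x cols INT4 weight matrix, two values packed per byte.
inline size_t int4_weight_bytes(size_t rows, size_t cols) {
    return (rows * cols) / 2;
}
// e.g. with embed_dim = 4096 (LLaMA-7B), each of the q/k/v/o projections takes
// 4096 * 4096 / 2 = 8,388,608 bytes (8 MiB).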
File renamed without changes.
19 changes: 19 additions & 0 deletions src/modules/llamaDecoder_int4.cpp
@@ -0,0 +1,19 @@
#include "llamaDecoder_int4.h"

Int4llamaDecoder::Int4llamaDecoder(std::string param_path, const struct model_config config)
{
allocate_aligned_memory(attention_mask_buf, config.max_sqlen * config.max_sqlen * sizeof(float));
allocate_aligned_memory(pos_embeds_buf, config.max_sqlen * config.embed_dim * sizeof(float));
allocate_aligned_memory(last_hidden_states_buf, config.max_sqlen * config.embed_dim * sizeof(float));
allocate_aligned_memory(hidden_states_buf, config.max_sqlen * config.embed_dim * sizeof(float));

for (int layer_idx = 0; layer_idx < config.num_layers; layer_idx++) {
DEBUG_INS(std::cout << "Start loading layer:" << layer_idx << "..." << std::endl;)

std::string path = param_path + "/layer" + std::to_string(layer_idx);
Int4llamaDecoderLayer layer = Int4llamaDecoderLayer(path, config, layer_idx);

this->layers.push_back(layer);
}
std::cout << "Int4llamaDecoder init finished!" << std::endl;
}
File renamed without changes.
8 changes: 8 additions & 0 deletions src/modules/llamaDecoderlayer_int4.cpp
@@ -0,0 +1,8 @@
#include "llamaDecoderlayer_int4.h"


Int4llamaDecoderLayer::Int4llamaDecoderLayer(std::string param_path, const struct model_config config, int layer_idx) {

this->attn = Int4llamaAttention(param_path + "/self_attn", config);
std::cout << "Int4llamaDecoderLayer init finished! Layer index: " << layer_idx << std::endl;
}
File renamed without changes.
11 changes: 11 additions & 0 deletions src/modules/llamaForCausalLM_int4.cpp
@@ -0,0 +1,11 @@
#include "llamaForCausalLM_int4.h"

Int4LlamaForCausalLM::Int4LlamaForCausalLM(std::string param_path, const struct model_config config) {
allocate_aligned_memory(logits_output, config.max_sqlen * config.vocsize * sizeof(float));
allocate_aligned_memory(lm_head_weight, (config.embed_dim * config.vocsize * sizeof(uint8_t)) / 2);

this->decoder = Int4llamaDecoder(param_path + "/decoder", config);
// this->lm_head = Linear_FP_int4(Matrix3D<uint8_t>(lm_head_weight, 1, config.vocsize, config.embed_dim / 2),
// param_path + "/lm_head");
std::cout << "Int4LlamaForCausalLM init finished!" << std::endl;
}
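
Taken together, the constructors imply a parameter directory layout rooted at param_path: Int4LlamaForCausalLM appends "/decoder" (and "/lm_head" in the commented lm_head code), Int4llamaDecoder appends "/layer0" through "/layer{num_layers-1}", Int4llamaDecoderLayer appends "/self_attn", and the commented Int4llamaAttention code appends "/q_proj". A sketch of that expected layout (the k/v/o_proj entries are assumed by analogy with q_proj and are not spelled out in this commit):

<param_path>/
    lm_head/
    decoder/
        layer0/
            self_attn/
                q_proj/   (k_proj/, v_proj/, o_proj/ assumed)
        layer1/
        ...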
