Commit
[Feature] Support Qwen-7B, dynamic NTK scaling and logN scaling in turbomind (InternLM#230)

* qwen support

* dynamic ntk & logn attn

* fix ntk & add chat template

* fix ntk scaling & stop words

* fix lint

* add tiktoken to requirements.txt

* fix tokenizer, set model format automatically

* update model.py

* update readme

* fix lint
lzhangzz authored Aug 18, 2023
1 parent 62b60db commit 4a60b45
Showing 27 changed files with 619 additions and 1,939 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
- \[2023/08\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models.
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -13,6 +13,7 @@ ______________________________________________________________________

## 更新 🎉

- \[2023/08\] TurboMind 支持 Qwen-7B,动态NTK-RoPE缩放,动态logN缩放
- \[2023/08\] TurboMind 支持 Windows (tp=1)
- \[2023/08\] TurboMind 支持 4-bit 推理,速度是 FP16 的 2.4 倍,是目前最快的开源实现🚀。部署方式请看[这里](./docs/zh_cn/w4a16.md)
- \[2023/08\] LMDeploy 开通了 [HuggingFace Hub](https://huggingface.co/lmdeploy) ,提供开箱即用的 4-bit 模型
3 changes: 2 additions & 1 deletion examples/cpp/llama/tokenizer.py
@@ -16,7 +16,8 @@ def __init__(self, model_file: str):
self.pad_id = self.model.pad_id()
else:
from transformers import AutoTokenizer
self.model = AutoTokenizer.from_pretrained(model_file)
self.model = AutoTokenizer.from_pretrained(model_file,
trust_remote_code=True)
self.vocab_size = self.model.vocab_size
self.start_id = self.model.bos_token_id
self.end_id = self.model.eos_token_id
30 changes: 30 additions & 0 deletions lmdeploy/model.py
@@ -161,6 +161,36 @@ def get_prompt(self, prompt, sequence_start=True):
return f'{self.b_inst} {prompt} {self.e_inst} '


@MODELS.register_module(name='qwen-7b')
class Qwen7BChat(BaseModel):
"""Chat template for Qwen-7B-Chat."""

def __init__(self):
super().__init__()
self.session_len = 8192
self.top_p = 0.5
self.top_k = 40
self.temperature = 1.0

self.im_start = '<|im_start|>'
self.im_end = '<|im_end|>'
self.system = 'You are a helpful assistant.'

def get_prompt(self, prompt, sequence_start=True):
if sequence_start:
return f'{self.im_start}system\n{self.system}{self.im_end}' \
f'\n{self.im_start}user\n{prompt}{self.im_end}' \
f'\n{self.im_start}assistant\n'

return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
f'\n{self.im_start}assistant\n'

@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [151645] # <|im_end|>


def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
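
As an aside (not part of the diff), a minimal sketch of how the new template could be exercised once this module is installed. The expected ChatML rendering is shown in comments; 151645 is the token id the diff registers for <|im_end|>:

from lmdeploy.model import Qwen7BChat

template = Qwen7BChat()
print(template.get_prompt('Hello', sequence_start=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
print(template.stop_words)  # [151645], i.e. <|im_end|>
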
161 changes: 155 additions & 6 deletions lmdeploy/serve/turbomind/deploy.py
@@ -16,7 +16,7 @@
import lmdeploy
from lmdeploy.model import MODELS

supported_formats = ['llama', 'hf', 'awq']
supported_formats = ['llama', 'hf', 'awq', 'qwen']


def get_package_root_path():
@@ -84,7 +84,7 @@ def copy_triton_model_templates(_path: str):
return None


def tokenizer_info(model_path: str):
def tokenizer_info_sp(model_path: str):
"""Return the vocabulary size, bos token id and eos token id.
Args:
@@ -101,6 +101,13 @@ def tokenizer_info(model_path: str):
return n_words, bos_id, eos_id


def tokenizer_info_qwen(model_dir: str):
    """Return the vocabulary size, bos token id and eos token id.

    The values are hard-coded for Qwen's tiktoken vocabulary; 151643 is
    <|endoftext|>, which Qwen uses as the eos token.
    """
    n_words = 151851
    bos_id = 0
    eos_id = 151643
    return n_words, bos_id, eos_id


def export(model_name: str,
num_layer: int,
norm_eps: float,
@@ -111,7 +118,11 @@ def export(model_name: str,
tp: int,
size_per_head: int = 128,
group_size: int = 0,
weight_type: str = 'fp16'):
weight_type: str = 'fp16',
max_position_embeddings: int = 0,
use_dynamic_ntk: int = 0,
use_logn_attn: int = 0,
tokenizer_info=tokenizer_info_sp):
"""Export deploying information to a config file.
Args:
@@ -191,7 +202,7 @@ def save_bin(param: torch.Tensor, name):
head_num=head_num,
kv_head_num=kv_head_num,
size_per_head=size_per_head,
vocab_size=vocab_size,
vocab_size=_vocab_size,
num_layer=num_layer,
rotary_embedding=size_per_head,
inter_size=inter_size,
@@ -210,7 +221,11 @@ def save_bin(param: torch.Tensor, name):
cache_chunk_size=1,
use_context_fmha=1,
quant_policy=0,
tensor_para_size=tp))
tensor_para_size=tp,
# extra attention params
max_position_embeddings=max_position_embeddings,
use_dynamic_ntk=int(use_dynamic_ntk),
use_logn_attn=int(use_logn_attn)))

config = configparser.ConfigParser()
for section, key_values in cfg.items():
@@ -725,6 +740,134 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size=group_size)


def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
                triton_models_path: str, tp: int):
    """Deploy a Qwen model in huggingface transformers' format.
    Args:
        model_name (str): the name of the to-be-deployed model
        model_path (str): the path of the directory where the model weight
            files are
        tokenizer_path (str): the path of the tokenizer model
        triton_models_path (str): the path of the exported triton models
        tp (int): the tensor parallelism degree
    """

if osp.exists(model_path):
shutil.copy(osp.join(model_path, 'qwen.tiktoken'),
osp.join(triton_models_path, 'tokenizer'))
for _file in os.listdir(model_path):
if _file.endswith('.json') or _file.endswith('.py'):
json_path = osp.join(model_path, _file)
shutil.copy(json_path,
osp.join(triton_models_path, 'tokenizer', _file))
with get_package_root_path() as root_path:
shutil.copy(osp.join(root_path, 'turbomind/tokenizer.py'),
osp.join(triton_models_path, 'tokenizer'))
else:
print(f'model path {model_path} does not exist')
exit(-1)

# read model arguments from config.json
try:
params_path = osp.join(model_path, 'config.json')
with open(params_path) as f:
config = json.load(f)
num_layer = config['num_hidden_layers']
norm_eps = config['layer_norm_epsilon']
if 'num_key_value_heads' in config:
kv_head_num = config['num_key_value_heads']
else:
kv_head_num = config['num_attention_heads']
seq_length = config['seq_length']
use_dynamic_ntk = config['use_dynamic_ntk']
use_logn_attn = config['use_logn_attn']
except Exception as e:
print(f'reading model arguments from {params_path} failed: {e}')
return False

# convert weights from hf to turbomind
model_params = {}

_files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
_files = sorted(_files)
print(_files)

_params = {}
for _file in _files:
_tmp = torch.load(osp.join(model_path, _file), map_location='cpu')
_params.update(_tmp)

def get_tensor(name, trans=True):
"""Return a transposed tensor according to its name."""
if trans:
return _params[name].cuda().t()
else:
return _params[name].cuda()

for i in range(num_layer):
print(i)

# qkv weights
qkv_w = get_tensor(f'transformer.h.{i}.attn.c_attn.weight')
q_w, k_w, v_w = torch.split(qkv_w, qkv_w.size(-1) // 3, dim=-1)
q_w, k_w = permute(q_w), permute(k_w)
qkv_w = merge_qkv(q_w, k_w, v_w, tp, dim=2)
model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv_w

# qkv bias
qkv_b = get_tensor(f'transformer.h.{i}.attn.c_attn.bias')
q_b, k_b, v_b = torch.split(qkv_b, qkv_b.size(-1) // 3)
q_b, k_b = permute(q_b), permute(k_b)
qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b

# o weights
o_w = get_tensor(f'transformer.h.{i}.attn.c_proj.weight')
model_params[f'layers.{i}.attention.wo.weight'] = o_w
model_params[f'layers.{i}.attention.wo.bias'] = torch.zeros_like(q_b)

# ffn weights
# ours: w2(silu(w1(x)) * w3(x))
# qwen: c_proj(w1(x) * silu(w2(x)))
w1 = get_tensor(f'transformer.h.{i}.mlp.w2.weight')
w3 = get_tensor(f'transformer.h.{i}.mlp.w1.weight')
w2 = get_tensor(f'transformer.h.{i}.mlp.c_proj.weight')
model_params[f'layers.{i}.feed_forward.w1.weight'] = w1
model_params[f'layers.{i}.feed_forward.w2.weight'] = w2
model_params[f'layers.{i}.feed_forward.w3.weight'] = w3

# norm weights
attn_norm = get_tensor(f'transformer.h.{i}.ln_1.weight')
ffn_norm = get_tensor(f'transformer.h.{i}.ln_2.weight')

model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm

other = [('tok_embeddings.weight', 'transformer.wte.weight'),
('norm.weight', 'transformer.ln_f.weight'),
('output.weight', 'lm_head.weight')]
for ft, hf in other:
model_params[ft] = get_tensor(hf, trans=False)

return export(model_name,
num_layer,
norm_eps,
kv_head_num,
model_params,
model_path,
triton_models_path,
tp,
max_position_embeddings=seq_length,
use_dynamic_ntk=use_dynamic_ntk,
use_logn_attn=use_logn_attn,
tokenizer_info=tokenizer_info_qwen)


def pack_model_repository(workspace_path: str):
"""package the model repository.
@@ -752,7 +895,7 @@ def pack_model_repository(workspace_path: str):

def main(model_name: str,
model_path: str,
model_format: str = 'hf',
model_format: str = None,
tokenizer_path: str = None,
dst_path: str = './workspace',
tp: int = 1,
@@ -777,6 +920,9 @@ def main(model_name: str,
f"'{model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'

if model_format is None:
model_format = 'qwen' if model_name == 'qwen-7b' else 'hf'

if model_format not in supported_formats:
print(f'the model format "{model_format}" is not supported. '
f'The supported format are: {supported_formats}')
@@ -803,6 +949,9 @@ def main(model_name: str,
elif model_format == 'awq':
res = deploy_awq(model_name, model_path, tokenizer_path,
triton_models_path, tp, quant_path, group_size)
elif model_format == 'qwen':
res = deploy_qwen(model_name, model_path, tokenizer_path,
triton_models_path, tp)

# update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
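
To summarize the conversion above (an illustration, not code from the diff): deploy_qwen renames Qwen's HF tensors to TurboMind's LLaMA-style names, splits the fused c_attn tensor into equal Q/K/V thirds, permutes Q and K, and swaps the MLP w1/w2 names because TurboMind computes w2(silu(w1(x)) * w3(x)) while Qwen computes c_proj(w1(x) * silu(w2(x))). Roughly, per layer i:

# Sketch of the per-layer name mapping performed by deploy_qwen; '{i}' is
# the layer index. The split, permutation and tensor-parallel re-merge of
# the fused c_attn tensor are omitted here.
QWEN_TO_TURBOMIND = {
    'transformer.h.{i}.attn.c_attn.weight': 'layers.{i}.attention.w_qkv.weight',
    'transformer.h.{i}.attn.c_attn.bias': 'layers.{i}.attention.w_qkv.bias',
    'transformer.h.{i}.attn.c_proj.weight': 'layers.{i}.attention.wo.weight',  # wo.bias is zero-filled
    'transformer.h.{i}.mlp.w2.weight': 'layers.{i}.feed_forward.w1.weight',  # gate projection
    'transformer.h.{i}.mlp.w1.weight': 'layers.{i}.feed_forward.w3.weight',  # up projection
    'transformer.h.{i}.mlp.c_proj.weight': 'layers.{i}.feed_forward.w2.weight',  # down projection
    'transformer.h.{i}.ln_1.weight': 'layers.{i}.attention_norm.weight',
    'transformer.h.{i}.ln_2.weight': 'layers.{i}.ffn_norm.weight',
    'transformer.wte.weight': 'tok_embeddings.weight',
    'transformer.ln_f.weight': 'norm.weight',
    'lm_head.weight': 'output.weight',
}

With the new default handling in main(), deployment presumably reduces to something like "python3 -m lmdeploy.serve.turbomind.deploy qwen-7b /path/to/Qwen-7B-Chat", since model_format now falls back to 'qwen' automatically when model_name is 'qwen-7b'.
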
8 changes: 8 additions & 0 deletions lmdeploy/turbomind/tokenizer.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import Sequence, Union

@@ -99,6 +100,13 @@ def __init__(self, model_dir: str):
if hasattr(self.model, 'backend_tokenizer'):
self.model.backend_tokenizer.save(backend_tokenizer_file)

if self.model.eos_token_id is None:
generation_config_file = osp.join(model_dir,
'generation_config.json')
with open(generation_config_file, 'r') as f:
cfg = json.load(f)
self.model.eos_token_id = cfg['eos_token_id']

@property
def vocab_size(self):
"""vocabulary size."""
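
Background for this fallback (an assumption about Qwen-7B's released files, not stated in the diff): Qwen's custom tokenizer leaves eos_token_id unset on the HF side, while the checkpoint's generation_config.json does carry it. A minimal sketch of the lookup, with a hypothetical config value:

import json
import os.path as osp

def eos_from_generation_config(model_dir: str) -> int:
    # generation_config.json is expected to hold a field such as
    # {"eos_token_id": 151643}; 151643 is <|endoftext|> in Qwen's vocabulary.
    with open(osp.join(model_dir, 'generation_config.json'), 'r') as f:
        return json.load(f)['eos_token_id']
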
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,6 +8,7 @@ pybind11
safetensors
sentencepiece
setuptools
tiktoken
torch
transformers
tritonclient[all]
31 changes: 9 additions & 22 deletions src/turbomind/kernels/decoder_masked_multihead_attention.h
@@ -60,13 +60,6 @@ struct Multihead_attention_params_base {
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;

// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;

// scales
const float* query_weight_output_scale = nullptr;
const float* attention_qk_scale = nullptr;
@@ -108,10 +101,6 @@ struct Multihead_attention_params_base {
// The slope per head of linear position bias to attention score (H).
const T* linear_bias_slopes = nullptr;

const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;

const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
@@ -123,17 +112,15 @@

template<typename T>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
// allows to exist attention eary
bool* finished = nullptr;

// required in case of masked attention with different length
const int* length_per_sample = nullptr;

T** k_cache_per_sample = nullptr;
T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0;
bool k_cache_interleaved = true;
int num_kv_heads = 0;
bool* finished = nullptr;
const int* length_per_sample = nullptr;
T** k_cache_per_sample = nullptr;
T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0;
int num_kv_heads = 0;
int max_position_embeddings = 0;
bool use_dynamic_ntk = false;
bool use_logn_attn = false;
};

template<class T>
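
For context on the three new fields (a sketch of the scheme used in Qwen's reference modeling code, which the TurboMind kernel is assumed to follow): dynamic NTK scaling enlarges the RoPE base once the sequence grows past max_position_embeddings, and logN attention scaling multiplies queries at positions beyond the training length by log_max_pos(position):

import math

def ntk_scaled_rope_base(base: float, dim: int, seq_len: int, max_pos: int) -> float:
    """RoPE base enlarged for contexts longer than the training length.

    Qwen steps the expansion factor alpha in powers of two minus one.
    """
    if seq_len <= max_pos:
        return base
    alpha = max(2 ** math.ceil(math.log(seq_len / max_pos, 2) + 1) - 1, 1)
    return base * alpha ** (dim / (dim - 2))

def logn_query_scale(position: int, max_pos: int) -> float:
    """Scale factor applied to the query at a given position."""
    return 1.0 if position <= max_pos else math.log(position) / math.log(max_pos)

# e.g. head dim 128, base 10000, a 2048-token training length and an
# 8192-token context give alpha = 7 and base 10000 * 7 ** (128 / 126).
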
@@ -41,15 +41,12 @@

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
// constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
const int tlength = params.timestep;

FT_CHECK(params.cache_indir == nullptr);
const int tlength = params.timestep;

if (params.int8_mode == 4) {
if (tlength < 32) {
(diff for the remaining changed files not shown)
