Commit
add llama2 chat template (InternLM#140)
* add llama2 template

* update readme and fix lint

* update readme

* add bos

* add bos

* remove bos

* Update model.py

---------

Co-authored-by: grimoire <[email protected]>
grimoire and grimoire authored Jul 20, 2023
1 parent 8ba2d7c commit 406f8c9
Showing 3 changed files with 50 additions and 4 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -11,9 +11,10 @@ English | [简体中文](README_zh-CN.md)

______________________________________________________________________

## News
## News 🎉

\[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports llama2 7b/13b.

______________________________________________________________________

5 changes: 3 additions & 2 deletions README_zh-CN.md
@@ -11,9 +11,10 @@

______________________________________________________________________

## 更新
## 更新 🎉

\[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理
- \[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理
- \[2023/07\] TurboMind 支持 Llama2 7b/13b 模型

______________________________________________________________________

44 changes: 44 additions & 0 deletions lmdeploy/model.py
@@ -125,6 +125,50 @@ def stop_words(self):
return [45623]


@MODELS.register_module(name='llama2')
class Llama2:
    """Chat template of the LLaMA2 model."""

    def __init__(self):

        B_INST, E_INST = '[INST]', '[/INST]'
        B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'

        DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501

        self.b_inst = B_INST
        self.e_inst = E_INST
        self.b_sys = B_SYS
        self.e_sys = E_SYS
        self.default_sys_prompt = DEFAULT_SYSTEM_PROMPT

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt.
            sequence_start (bool): indicator for the first round of chat in
                a session sequence.

        Returns:
            str: the concatenated prompt.
        """
        if sequence_start:
            return f'<BOS>{self.b_inst} ' \
                   f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \
                   f'{prompt} {self.e_inst} '

        return f'{self.b_inst} {prompt} {self.e_inst} '

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return None


def main(model_name: str = 'test'):
    assert model_name in MODELS.module_dict.keys(), \
        f"'{model_name}' is not supported. " \
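The prompt assembly this commit adds can be exercised on its own. The sketch below is a hypothetical standalone version: the `Llama2Template` name, the omission of the `MODELS` registry decorator, and the one-line stand-in for the long default system prompt are all simplifications for illustration, not part of the committed code.

```python
# Hypothetical standalone sketch of the Llama2 chat template's prompt
# assembly (registry decorator omitted, system prompt shortened).

class Llama2Template:
    """Minimal stand-in for the Llama2 chat template."""

    def __init__(self, system_prompt='You are a helpful assistant.'):
        self.b_inst, self.e_inst = '[INST]', '[/INST]'
        self.b_sys, self.e_sys = '<<SYS>>\n', '\n<</SYS>>\n\n'
        self.default_sys_prompt = system_prompt

    def get_prompt(self, prompt, sequence_start=True):
        # First round of a session: BOS marker plus the <<SYS>> block.
        if sequence_start:
            return (f'<BOS>{self.b_inst} '
                    f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}'
                    f'{prompt} {self.e_inst} ')
        # Later rounds: only [INST] ... [/INST] around the user input.
        return f'{self.b_inst} {prompt} {self.e_inst} '


template = Llama2Template()
first = template.get_prompt('Hello!')
later = template.get_prompt('And now?', sequence_start=False)
```

Note that only the first round carries the system prompt; follow-up turns wrap the user input alone, which matches the two branches of `get_prompt` in the diff above.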
