Add interfaces to the pipeline to obtain logits and ppl (InternLM#1652)

* pipeline ppl

* turbomind decode support embeddings input

* pipeline get_logits support embeddings input

* add prepare_inputs

* update docs

* fix long session ppl

* fix lint

* fix unequal session_len of turbomind and pipeline

* reduce memory

* fix pytorch engine crash

* pytorch engine decode embeddings

* remove do_preprocess

* Revert "fix unequal session_len of turbomind and pipeline"

This reverts commit 0b0508a.

* fix template

* fix size

* fix

* update docs

* fix steps

* remove convert to numpy

* update docs
irexyc authored Jun 25, 2024
1 parent a5aeee3 commit c59a704
Showing 15 changed files with 433 additions and 75 deletions.
15 changes: 7 additions & 8 deletions docs/en/advance/long_context.md
@@ -81,20 +81,19 @@ passkey_retrival(session_len, 5)
The following code demonstrates how to use LMDeploy to calculate perplexity.

```python
-from datasets import load_dataset
-from lmdeploy import TurbomindEngineConfig
-from lmdeploy.turbomind import TurboMind
+from transformers import AutoTokenizer
+from lmdeploy import TurbomindEngineConfig, pipeline
import numpy as np

# load model and tokenizer
-engine_config = TurbomindEngineConfig(rope_scaling_factor=2.0, session_len=160000)
-engine = TurboMind.from_pretrained('internlm/internlm2-chat-7b', engine_config)
-tokenizer = engine.tokenizer
-generator = engine.create_instance()
+model_repoid_or_path = 'internlm/internlm2-chat-7b'
+backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, session_len=160000)
+pipe = pipeline(model_repoid_or_path, backend_config=backend_config)
+tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# get perplexity
text = 'Use a long prompt to replace this sentence'
input_ids = tokenizer.encode(text)
-loss = generator.get_ppl(input_ids)[0]
+loss = pipe.get_ppl(input_ids)[0]
ppl = np.exp(loss)
```
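
For reference, a quick numeric sketch of the loss-to-perplexity relationship used above; the loss value here is made up purely for illustration:

```python
import numpy as np

# Illustrative only: an assumed average per-token cross-entropy loss of 2.0
# corresponds to a perplexity of e^2.0 ~ 7.39, i.e. the model is on average
# about as uncertain as a uniform choice over ~7.4 tokens.
loss = 2.0
ppl = np.exp(loss)
print(ppl)  # ~7.389
```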
20 changes: 20 additions & 0 deletions docs/en/inference/pipeline.md
@@ -116,6 +116,26 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
print(item)
```

- **An example to calculate logits & ppl:**

```python
from transformers import AutoTokenizer
from lmdeploy import pipeline
model_repoid_or_path = 'internlm/internlm2-chat-7b'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# logits
messages = [
{"role": "user", "content": "Hello, how are you?"},
]
input_ids = tokenizer.apply_chat_template(messages)
logits = pipe.get_logits(input_ids)

# ppl
ppl = pipe.get_ppl(input_ids)
```
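
A minimal follow-up sketch for inspecting the returned logits, assuming they can be viewed as a `[seq_len, vocab_size]` array (the exact return type of `get_logits` may differ):

```python
import torch

# Assumption: the `logits` returned above can be converted to a
# [seq_len, vocab_size] tensor; unwrap the first element if the engine
# returns a list of per-request tensors instead.
first = logits[0] if isinstance(logits, (list, tuple)) else logits
logits_t = torch.as_tensor(first).squeeze()
next_token_probs = torch.softmax(logits_t[-1].float(), dim=-1)
top_prob, top_id = next_token_probs.max(dim=-1)
print(tokenizer.decode([int(top_id)]), float(top_prob))
```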

- **Below is an example for the pytorch backend. Please install triton first.**

18 changes: 18 additions & 0 deletions docs/en/inference/vl_pipeline.md
@@ -148,6 +148,24 @@ response = pipe(('describe this image', image))
print(response)
```

### Calculate logits

Custom inputs are supported. Users can call `prepare_inputs` to see how the inputs are organized.

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
pipe = pipeline('internlm/internlm-xcomposer2-7b', backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))

# logits
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
inputs = pipe.prepare_inputs(('describe this image', image))
input_ids = inputs['input_ids']
embeddings = inputs['input_embeddings']
embedding_ranges = inputs['input_embedding_ranges']
logits = pipe.get_logits(input_ids, embeddings, embedding_ranges)
```
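
A small sanity-check sketch, under the assumption that each entry of `input_embedding_ranges` is a `(begin, end)` pair marking the token positions in `input_ids` that the corresponding embedding tensor fills:

```python
# Assumption: embeddings and embedding_ranges are parallel lists, and each
# (begin, end) range spans exactly as many token positions as its embedding
# has rows; adjust the unwrapping if the pipeline returns batched lists.
print('prompt length (tokens):', len(input_ids))
for emb, (begin, end) in zip(embeddings, embedding_ranges):
    print(f'embedding of shape {tuple(emb.shape)} fills token positions [{begin}, {end})')
```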

## Multi-image inference

When dealing with multiple images, you can put them all in one list. Keep in mind that multiple images will lead to a higher number of input tokens, and as a result, the size of the [context window](#set-context-window-size) typically needs to be increased.
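
A hedged sketch of multi-image usage following these docs' conventions; the model, repeated URL, and `session_len` value below are illustrative assumptions:

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

# Assumed settings: enlarge session_len so the extra image tokens fit.
pipe = pipeline('internlm/internlm-xcomposer2-7b',
                backend_config=TurbomindEngineConfig(session_len=8192))

image_urls = [
    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',
]
images = [load_image(url) for url in image_urls]
response = pipe(('describe these images', images))
print(response)
```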
17 changes: 8 additions & 9 deletions docs/zh_cn/advance/long_context.md
@@ -81,20 +81,19 @@ passkey_retrival(session_len, 5)
The following demonstrates how to use LMDeploy to calculate perplexity.

```python
-from datasets import load_dataset
-from lmdeploy import TurbomindEngineConfig
-from lmdeploy.turbomind import TurboMind
+from transformers import AutoTokenizer
+from lmdeploy import TurbomindEngineConfig, pipeline
import numpy as np

# load model and tokenizer
-engine_config = TurbomindEngineConfig(rope_scaling_factor=2.0, session_len=160000)
-engine = TurboMind.from_pretrained('internlm/internlm2-chat-7b', engine_config)
-tokenizer = engine.tokenizer
-generator = engine.create_instance()
+model_repoid_or_path = 'internlm/internlm2-chat-7b'
+backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, session_len=160000)
+pipe = pipeline(model_repoid_or_path, backend_config=backend_config)
+tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# get perplexity
-text = 'The grass is green. The sky is blue. The sun is yellow'
+text = 'Use a long prompt to replace this sentence'
input_ids = tokenizer.encode(text)
-loss = generator.get_ppl(input_ids)[0]
+loss = pipe.get_ppl(input_ids)[0]
ppl = np.exp(loss)
```
20 changes: 20 additions & 0 deletions docs/zh_cn/inference/pipeline.md
@@ -116,6 +116,26 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
print(item)
```

- **Calculate logits & ppl:**

```python
from transformers import AutoTokenizer
from lmdeploy import pipeline
model_repoid_or_path = 'internlm/internlm2-chat-7b'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# logits
messages = [
{"role": "user", "content": "Hello, how are you?"},
]
input_ids = tokenizer.apply_chat_template(messages)
logits = pipe.get_logits(input_ids)

# ppl
ppl = pipe.get_ppl(input_ids)
```

- **Use the pytorch backend**

Please install triton first.
18 changes: 18 additions & 0 deletions docs/zh_cn/inference/vl_pipeline.md
@@ -148,6 +148,24 @@ response = pipe(('describe this image', image))
print(response)
```

### Calculate logits

LMDeploy supports custom inputs. Users can call `prepare_inputs` to see how multimodal inputs are organized.

```python
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
pipe = pipeline('internlm/internlm-xcomposer2-7b', backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))

# logits
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
inputs = pipe.prepare_inputs(('describe this image', image))
input_ids = inputs['input_ids']
embeddings = inputs['input_embeddings']
embedding_ranges = inputs['input_embedding_ranges']
logits = pipe.get_logits(input_ids, embeddings, embedding_ranges)
```

## Multi-image inference

When dealing with multiple images, simply put them all in one list for inference. Keep in mind that more images mean more input tokens, so the [context window](#设置上下文长度) usually needs to be increased.
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -952,6 +952,10 @@ def decode(self,
input_ids (List[List[int]] | List[np.ndarray]): the batch of input
token ids
steps (List[int]): the offset of the k/v cache
input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
embedding features
input_embedding_ranges (List[List[Tuple[int, int]]]):
the begin/end offsets of input_embeddings within input_ids
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -490,6 +490,10 @@ def decode(self,
Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
embedding features
input_embedding_ranges (List[List[Tuple[int, int]]]):
the begin/end offsets of input_embeddings within input_ids
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
@@ -502,6 +506,10 @@ def __add_messages(session_ids, input_ids, adapter_names,
input_embeddings, input_embedding_ranges):
add_msgs = []
sampling_param = SamplingParam(max_new_tokens=0)
batch_size = len(input_ids)
if input_embeddings is None:
input_embeddings = [None] * batch_size
input_embedding_ranges = [None] * batch_size
for (session_id, token_id, adapter_name, input_emb,
input_ranges) in zip(session_ids, input_ids, adapter_names,
input_embeddings,
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/model_agent.py
@@ -289,7 +289,7 @@ def split(self, split_size: int, block_size: int):
if overlap:
block_end += 1

-block_offsets = self.block_offsets[:, :block_end]
+block_offsets = self.block_offsets
inp = ModelInputs(
input_ids=self.input_ids[:, start:end],
seq_length=input_ids.new_tensor([end - start]),
15 changes: 2 additions & 13 deletions lmdeploy/serve/async_engine.py
@@ -13,6 +13,7 @@
PytorchEngineConfig, Response,
TurbomindEngineConfig)
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
+from lmdeploy.serve.utils import LogitsMixin, _get_event_loop
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger

@@ -119,19 +120,7 @@ def __repr__(self) -> str:
return res


-def _get_event_loop():
-    """get event loop."""
-    try:
-        event_loop = asyncio.get_event_loop()
-    except Exception:
-        logger.warning('Can not found event loop in current thread.'
-                       ' Create a new event loop.')
-        event_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(event_loop)
-    return event_loop
-
-
-class AsyncEngine:
+class AsyncEngine(LogitsMixin):
"""Async inference engine. Maintaining a bunch of tm_model instances.
Args:
