forked from PaddlePaddle/PaddleMIX
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support deepseek-vl2 infer (tiny and small) (PaddlePaddle#995)
- Loading branch information
Showing
14 changed files
with
3,937 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Deepseek-VL2 | ||
|
||
## 1. 模型介绍 | ||
[DeepSeek-VL2](https://github.com/deepseek-ai/DeepSeek-VL2)是一种基于大型混合专家(Mixture-of-Experts,MoE)视觉语言模型,相较于其前身DeepSeek-VL有了显著提升。DeepSeek-VL2在各种任务中展现出了卓越的能力,包括但不限于视觉问答、光学字符识别、文档/表格/图表理解以及视觉定位。我们的模型系列包含三种变体:DeepSeek-VL2-Tiny、DeepSeek-VL2-Small和DeepSeek-VL2,分别拥有10亿、28亿和45亿个激活参数。与现有的开源密集型和基于MoE的模型相比,DeepSeek-VL2在激活参数相似或更少的情况下,实现了具有竞争力甚至最先进的性能。 | ||
![Overview of DeepSeek-VL2](https://github.com/user-attachments/assets/926928a3-bad2-4c5b-8f45-d0c9a7661f34) | ||
注:以上为 DeepSeek-VL2 的整体架构图引用自论文。 | ||
|
||
**本仓库支持的模型权重:** | ||
|
||
| Model | | ||
|--------------------| | ||
| deepseek-ai/deepseek-vl2-tiny | | ||
| deepseek-ai/deepseek-vl2-small | | ||
|
||
注意:与huggingface权重同名,但权重为paddle框架的Tensor,使用`xxx.from_pretrained("deepseek-ai/deepseek-vl2-tiny")`即可自动下载该权重文件夹到缓存目录。 | ||
|
||
## 2 环境准备 | ||
|
||
1)[安装 PaddleMIX 环境依赖包](https://github.com/PaddlePaddle/PaddleMIX/tree/develop?tab=readme-ov-file#%E5%AE%89%E8%A3%85) | ||
|
||
2)pip install pillow tqdm paddlenlp==3.0.0b3 | ||
|
||
注意:Python版本最好为3.10及以上版本。 | ||
|
||
## 3 快速开始 | ||
|
||
### 推理 | ||
> 注:在V100上运行以下代码需要指定dtype="float16", 如果需要使用deepseek-vl2-small模型,需要修改model_path为"deepseek-ai/deepseek-vl2-small" | ||
```bash | ||
# Deepseek-vl2-tiny single image understanding | ||
python paddlemix/examples/deepseek_vl2/single_image_infer.py \ | ||
--model_path="deepseek-ai/deepseek-vl2-tiny" \ | ||
--image_file="paddlemix/demo_images/examples_image2.jpg" \ | ||
--question="The Panda" \ | ||
--dtype="bfloat16" | ||
|
||
# Deepseek-vl2-tiny multi image understanding | ||
python paddlemix/examples/deepseek_vl2/multi_image_infer.py \ | ||
--image_file_1="paddlemix/demo_images/examples_image1.jpg" \ | ||
--image_file_2="paddlemix/demo_images/examples_image2.jpg" \ | ||
--image_file_3="paddlemix/demo_images/twitter3.jpeg" \ | ||
--question="Can you tell me what are in the images?" \ | ||
--dtype="bfloat16" | ||
|
||
# Deepseek-vl2-tiny increment prefilling kv cache inference | ||
python paddlemix/examples/deepseek_vl2/increment_prefilling_infer.py \ | ||
--image_file_1="paddlemix/demo_images/examples_image1.jpg" \ | ||
--image_file_2="paddlemix/demo_images/examples_image2.jpg" \ | ||
--image_file_3="paddlemix/demo_images/twitter3.jpeg" \ | ||
--question="Can you tell me what are in the images?" \ | ||
--dtype="bfloat16" | ||
``` | ||
|
||
### 结果展示 | ||
1) DeepSeek-VL2-tiny Single Image Understanding | ||
|
||
![panda](https://github.com/user-attachments/assets/6f66021c-c2fe-4231-a466-6b3747c26f7c) | ||
``` | ||
<|User|>: <image> | ||
<|ref|>The Panda<|/ref|>. | ||
<|Assistant|>: <|ref|>The Panda<|/ref|><|det|>[[100, 192, 998, 998]]<|/det|><|end▁of▁sentence|> | ||
``` | ||
|
||
2) DeepSeek-VL2-tiny Multi Image Understanding | ||
``` | ||
<|User|>: This is image_1: <image> | ||
This is image_2: <image> | ||
This is image_3: <image> | ||
Can you tell me what are in the images? | ||
<|Assistant|>: The first image shows a red panda resting on a wooden platform. The second image features a giant panda sitting among bamboo plants. The third image captures a rocket launch at night, with the bright trail of the rocket illuminating the sky.<|end▁of▁sentence|> | ||
``` | ||
![mutli-infer](https://github.com/user-attachments/assets/4a1ade41-90ed-4d04-949a-90c3b54bdf78) | ||
## 参考文献 | ||
```BibTeX | ||
@misc{wu2024deepseekvl2mixtureofexpertsvisionlanguagemodels, | ||
title={DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding}, | ||
author={Zhiyu Wu and Xiaokang Chen and Zizheng Pan and Xingchao Liu and Wen Liu and Damai Dai and Huazuo Gao and Yiyang Ma and Chengyue Wu and Bingxuan Wang and Zhenda Xie and Yu Wu and Kai Hu and Jiawei Wang and Yaofeng Sun and Yukun Li and Yishi Piao and Kang Guan and Aixin Liu and Xin Xie and Yuxiang You and Kai Dong and Xingkai Yu and Haowei Zhang and Liang Zhao and Yisong Wang and Chong Ruan}, | ||
year={2024}, | ||
eprint={2412.10302}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CV}, | ||
url={https://arxiv.org/abs/2412.10302}, | ||
} | ||
``` |
112 changes: 112 additions & 0 deletions
112
paddlemix/examples/deepseek_vl2/increment_prefilling_infer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
import sys | ||
|
||
import paddle | ||
from paddlenlp.generation import GenerationConfig | ||
from paddlenlp.transformers.llama.tokenizer_fast import LlamaTokenizerFast | ||
|
||
from paddlemix.models.deepseek_vl2 import DeepseekVLV2Config, DeepseekVLV2ForCausalLM | ||
from paddlemix.processors.deepseek_vl2_processing import DeepseekVLV2Processor | ||
|
||
sys.path.append(os.path.dirname(__file__)) | ||
from utils import load_pil_images | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl2-tiny") | ||
parser.add_argument("--image_file_1", type=str, required=True) | ||
parser.add_argument("--image_file_2", type=str, required=True) | ||
parser.add_argument("--image_file_3", type=str, required=True) | ||
parser.add_argument("--question", type=str, default="What is shown in this image?") | ||
parser.add_argument("--dtype", type=str, default="bfloat16") | ||
|
||
args = parser.parse_args() | ||
|
||
model_path = args.model_path | ||
tokenizer = LlamaTokenizerFast.from_pretrained(model_path) | ||
config = DeepseekVLV2Config.from_pretrained(model_path) | ||
|
||
candidate_resolutions = config["candidate_resolutions"] | ||
patch_size = config.vision_config["patch_size"] | ||
downsample_ratio = config["downsample_ratio"] | ||
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor( | ||
tokenizer=tokenizer, | ||
candidate_resolutions=candidate_resolutions, | ||
patch_size=patch_size, | ||
downsample_ratio=downsample_ratio, | ||
) | ||
|
||
vl_gpt: DeepseekVLV2ForCausalLM = DeepseekVLV2ForCausalLM.from_pretrained(model_path, dtype=args.dtype).eval() | ||
|
||
# multiple images/interleaved image-text | ||
conversation = [ | ||
{ | ||
"role": "<|User|>", | ||
"content": "This is image_1: <image>\n" | ||
"This is image_2: <image>\n" | ||
f"This is image_3: <image>\n {args.question}", | ||
"images": [ | ||
f"{args.image_file_1}", | ||
f"{args.image_file_2}", | ||
f"{args.image_file_3}", | ||
], | ||
}, | ||
{"role": "<|Assistant|>", "content": ""}, | ||
] | ||
|
||
|
||
pil_images = load_pil_images(conversation) | ||
prepare_inputs = vl_chat_processor( | ||
conversations=conversation, images=pil_images, force_batchify=True, system_prompt="" | ||
) | ||
prepare_inputs.images = prepare_inputs.images.astype(args.dtype) | ||
|
||
with paddle.no_grad(): | ||
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) | ||
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling( | ||
input_ids=prepare_inputs.input_ids, | ||
images=prepare_inputs.images, | ||
images_seq_mask=prepare_inputs.images_seq_mask, | ||
images_spatial_crop=prepare_inputs.images_spatial_crop, | ||
attention_mask=prepare_inputs.attention_mask, | ||
chunk_size=512, | ||
) | ||
|
||
generation_config = GenerationConfig( | ||
pad_token_id=tokenizer.pad_token_id, | ||
bos_token_id=tokenizer.bos_token_id, | ||
eos_token_id=tokenizer.eos_token_id, | ||
max_new_tokens=512, | ||
do_sample=False, | ||
trunc_input=True, | ||
output_attentions=True, | ||
use_cache=True, # must true | ||
return_dict=True, | ||
) | ||
outputs = vl_gpt.generate( | ||
inputs_embeds=inputs_embeds, | ||
input_ids=prepare_inputs.input_ids, | ||
images=prepare_inputs.images, | ||
images_seq_mask=prepare_inputs.images_seq_mask, | ||
images_spatial_crop=prepare_inputs.images_spatial_crop, | ||
attention_mask=prepare_inputs.attention_mask, | ||
past_key_values=past_key_values, | ||
generation_config=generation_config, | ||
full_input_ids=prepare_inputs.input_ids, | ||
) | ||
answer = tokenizer.decode(outputs[0][0].cpu().tolist(), skip_special_tokens=False) | ||
print(f"{prepare_inputs['sft_format'][0]}", answer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
import sys | ||
|
||
from paddlenlp.generation import GenerationConfig | ||
from paddlenlp.transformers.llama.tokenizer_fast import LlamaTokenizerFast | ||
|
||
from paddlemix.models.deepseek_vl2 import DeepseekVLV2Config, DeepseekVLV2ForCausalLM | ||
from paddlemix.processors.deepseek_vl2_processing import DeepseekVLV2Processor | ||
|
||
sys.path.append(os.path.dirname(__file__)) | ||
from utils import load_pil_images | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl2-tiny") | ||
parser.add_argument("--image_file_1", type=str, required=True) | ||
parser.add_argument("--image_file_2", type=str, required=True) | ||
parser.add_argument("--image_file_3", type=str, required=True) | ||
parser.add_argument("--question", type=str, default="What is shown in this image?") | ||
parser.add_argument("--dtype", type=str, default="bfloat16") | ||
|
||
args = parser.parse_args() | ||
|
||
model_path = args.model_path | ||
|
||
tokenizer = LlamaTokenizerFast.from_pretrained(model_path) | ||
config = DeepseekVLV2Config.from_pretrained(model_path) | ||
|
||
candidate_resolutions = config["candidate_resolutions"] | ||
patch_size = config.vision_config["patch_size"] | ||
downsample_ratio = config["downsample_ratio"] | ||
|
||
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor( | ||
tokenizer=tokenizer, | ||
candidate_resolutions=candidate_resolutions, | ||
patch_size=patch_size, | ||
downsample_ratio=downsample_ratio, | ||
) | ||
|
||
tokenizer = vl_chat_processor.tokenizer | ||
vl_gpt: DeepseekVLV2ForCausalLM = DeepseekVLV2ForCausalLM.from_pretrained(model_path, dtype=args.dtype).eval() | ||
|
||
|
||
conversation = [ | ||
{ | ||
"role": "<|User|>", | ||
"content": "This is image_1: <image>\n" | ||
"This is image_2: <image>\n" | ||
f"This is image_3: <image>\n {args.question}", | ||
"images": [ | ||
f"{args.image_file_1}", | ||
f"{args.image_file_2}", | ||
f"{args.image_file_3}", | ||
], | ||
}, | ||
{"role": "<|Assistant|>", "content": ""}, | ||
] | ||
|
||
|
||
pil_images = load_pil_images(conversation) | ||
prepare_inputs = vl_chat_processor( | ||
conversations=conversation, images=pil_images, force_batchify=True, system_prompt="" | ||
) | ||
prepare_inputs.images = prepare_inputs.images.astype(args.dtype) | ||
|
||
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) | ||
|
||
generation_config = GenerationConfig( | ||
pad_token_id=tokenizer.pad_token_id, | ||
bos_token_id=tokenizer.bos_token_id, | ||
eos_token_id=tokenizer.eos_token_id, | ||
max_new_tokens=512, | ||
do_sample=False, | ||
trunc_input=True, | ||
output_attentions=True, | ||
use_cache=True, # must true | ||
return_dict=True, | ||
) | ||
outputs = vl_gpt.language.generate( | ||
generation_config=generation_config, | ||
inputs_embeds=inputs_embeds.astype(args.dtype), | ||
attention_mask=prepare_inputs.attention_mask, | ||
) | ||
|
||
answer = tokenizer.decode(outputs[0][0].cpu().tolist(), skip_special_tokens=False) | ||
print(f"{prepare_inputs['sft_format'][0]}", answer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
import sys | ||
|
||
from paddlenlp.generation import GenerationConfig | ||
from paddlenlp.transformers.llama.tokenizer_fast import LlamaTokenizerFast | ||
|
||
from paddlemix.models.deepseek_vl2 import DeepseekVLV2Config, DeepseekVLV2ForCausalLM | ||
from paddlemix.processors.deepseek_vl2_processing import DeepseekVLV2Processor | ||
|
||
sys.path.append(os.path.dirname(__file__)) | ||
from utils import load_pil_images | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl2-tiny") | ||
parser.add_argument("--image_file", type=str, required=True) | ||
parser.add_argument("--question", type=str, default="What is shown in this image?") | ||
parser.add_argument("--dtype", type=str, default="bfloat16") | ||
|
||
args = parser.parse_args() | ||
|
||
model_path = args.model_path | ||
tokenizer = LlamaTokenizerFast.from_pretrained(model_path) | ||
config = DeepseekVLV2Config.from_pretrained(model_path) | ||
|
||
candidate_resolutions = config["candidate_resolutions"] | ||
patch_size = config.vision_config["patch_size"] | ||
downsample_ratio = config["downsample_ratio"] | ||
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor( | ||
tokenizer=tokenizer, | ||
candidate_resolutions=candidate_resolutions, | ||
patch_size=patch_size, | ||
downsample_ratio=downsample_ratio, | ||
) | ||
tokenizer = vl_chat_processor.tokenizer | ||
vl_gpt: DeepseekVLV2ForCausalLM = DeepseekVLV2ForCausalLM.from_pretrained(model_path, dtype=args.dtype).eval() | ||
|
||
conversation = [ | ||
{ | ||
"role": "<|User|>", | ||
"content": f"<image>\n<|ref|>{args.question}<|/ref|>.", | ||
"images": [args.image_file], | ||
}, | ||
{"role": "<|Assistant|>", "content": ""}, | ||
] | ||
|
||
pil_images = load_pil_images(conversation) | ||
prepare_inputs = vl_chat_processor( | ||
conversations=conversation, images=pil_images, force_batchify=True, system_prompt="" | ||
) | ||
prepare_inputs.images = prepare_inputs.images.astype(args.dtype) | ||
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) | ||
|
||
generation_config = GenerationConfig( | ||
pad_token_id=tokenizer.pad_token_id, | ||
bos_token_id=tokenizer.bos_token_id, | ||
eos_token_id=tokenizer.eos_token_id, | ||
max_new_tokens=512, | ||
do_sample=False, | ||
trunc_input=True, | ||
output_attentions=True, | ||
use_cache=True, # must true for infer | ||
return_dict=True, | ||
) | ||
outputs = vl_gpt.language.generate( | ||
generation_config=generation_config, | ||
inputs_embeds=inputs_embeds.astype(args.dtype), | ||
attention_mask=prepare_inputs.attention_mask, | ||
) | ||
answer = tokenizer.decode(outputs[0][0].cpu().tolist()) | ||
print(f"{prepare_inputs['sft_format'][0]}", answer) |
Oops, something went wrong.