Update to MiniCPM-V 2.6
yiranyyu committed Aug 6, 2024
1 parent 1cb882d commit b1a1529
Showing 28 changed files with 3,692 additions and 191 deletions.
1,002 changes: 948 additions & 54 deletions README.md

Large diffs are not rendered by default.

1,002 changes: 948 additions & 54 deletions README_en.md

Large diffs are not rendered by default.

997 changes: 947 additions & 50 deletions README_zh.md

Large diffs are not rendered by default.

Binary file added assets/gif_cases/ai.gif
Binary file added assets/gif_cases/beer.gif
Binary file added assets/gif_cases/mb.gif
Binary file added assets/gif_cases/rabbit.gif
Binary file modified assets/gif_cases/ticket.gif
Binary file added assets/gif_cases/wfh.gif
Binary file added assets/gif_cases/zoo.gif
Binary file added assets/minicpmv2_6/ICL-Mem.png
Binary file added assets/minicpmv2_6/ICL-elec.png
Binary file added assets/minicpmv2_6/multi_img-bike.png
Binary file added assets/minicpmv2_6/multi_img-code.png
Binary file added assets/minicpmv2_6/multi_img-menu.png
Binary file added assets/minicpmv2_6/multiling-medal.png
Binary file added assets/minicpmv2_6/multiling-olympic.png
Binary file added assets/radar_final.png
76 changes: 75 additions & 1 deletion chat.py
@@ -183,13 +183,87 @@ def chat(self, input):
        )
        return answer

class MiniCPMV2_6:
    def __init__(self, model_path, multi_gpus=False) -> None:

        print('torch_version:', torch.__version__)
        if multi_gpus: # inference on multi-gpus
            from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
            with init_empty_weights():
                model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                                  attn_implementation='sdpa', torch_dtype=torch.bfloat16)

            device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
                                               no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
            device_id = device_map["llm.model.embed_tokens"]
            device_map["llm.lm_head"] = device_id # first and last layer of llm should be in the same device
            device_map["vpm"] = device_id
            device_map["resampler"] = device_id
            device_id2 = device_map["llm.model.layers.26"]
            device_map["llm.model.layers.8"] = device_id2
            device_map["llm.model.layers.9"] = device_id2
            device_map["llm.model.layers.10"] = device_id2
            device_map["llm.model.layers.11"] = device_id2
            device_map["llm.model.layers.12"] = device_id2
            device_map["llm.model.layers.13"] = device_id2
            device_map["llm.model.layers.14"] = device_id2
            device_map["llm.model.layers.15"] = device_id2
            device_map["llm.model.layers.16"] = device_id2
            print(device_map)

            self.model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
            self.model.eval()
        else:
            self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                                   attn_implementation='sdpa', torch_dtype=torch.bfloat16)
            self.model.eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    def chat(self, input):
        image = None
        if "image" in input and len(input["image"]) > 10: # legacy API
            try:
                image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
            except Exception as e:
                return "Image decode error"

        msgs = json.loads(input["question"])

        for msg in msgs:
            contents = msg.pop('content') # support str or List[Dict]
            if isinstance(contents, str):
                contents = [contents]

            new_cnts = []
            for c in contents:
                if isinstance(c, dict):
                    if c['type'] == 'text':
                        c = c['pairs']
                    elif c['type'] == 'image':
                        c = Image.open(io.BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
                    else:
                        raise ValueError("content type only support text and image.")
                new_cnts.append(c)
            msg['content'] = new_cnts
        print(f'msgs: {str(msgs)}')

        answer = self.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
        )
        return answer


class MiniCPMVChat:
    def __init__(self, model_path) -> None:
    def __init__(self, model_path, multi_gpus=False) -> None:
        if '12B' in model_path:
            self.model = OmniLMM12B(model_path)
        elif 'MiniCPM-Llama3-V' in model_path:
            self.model = MiniCPMV2_5(model_path)
        elif 'MiniCPM-V-2_6' in model_path:
            self.model = MiniCPMV2_6(model_path, multi_gpus)
        else:
            self.model = MiniCPMV(model_path)

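For reference, a minimal usage sketch of the new `MiniCPMV2_6` path (not part of this commit's diff; the message format follows the `chat()` method above, while the image path and the two-GPU setup are illustrative assumptions):

```python
import base64
import json

# multi_gpus=True triggers the accelerate-based dispatch shown above,
# which spreads the Qwen2 decoder layers across two ~10GB devices.
model = MiniCPMV2_6('openbmb/MiniCPM-V-2_6', multi_gpus=True)

with open('example.jpg', 'rb') as f:  # placeholder image path
    img_b64 = base64.b64encode(f.read()).decode('utf-8')

msgs = [{
    'role': 'user',
    'content': [
        {'type': 'image', 'pairs': img_b64},            # base64-encoded image
        {'type': 'text', 'pairs': 'Describe this image.'},
    ],
}]

answer = model.chat({'question': json.dumps(msgs)})
print(answer)
```

The updated `MiniCPMVChat('openbmb/MiniCPM-V-2_6', multi_gpus=True)` wrapper now routes `MiniCPM-V-2_6` checkpoints to this class.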
30 changes: 30 additions & 0 deletions docs/faqs.md
@@ -0,0 +1,30 @@
### FAQs

<details>
<summary>Q: How to choose between sampling and beam search for inference</summary>

The quality of results from beam search and sampling can vary with the scenario, so choose a decoding strategy based on the following:

If you have any of the following needs, consider sampling:

1. You need faster inference.
2. You want streaming generation.
3. Your task calls for open-ended responses.

If your task requires deterministic answers, try beam search and check whether it gives better results.
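
For example (a sketch: `sampling=True`/`sampling=False` selects between the two strategies in the standard `model.chat` interface, and `temperature`/`num_beams` are assumed to be forwarded to generation — check the model card of your version for the exact arguments):

```python
# Sampling: faster, supports streaming, better for open-ended answers
res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,
    temperature=0.7,
)

# Beam search: worth trying when you need deterministic answers
res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=False,
    num_beams=3,
)
```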
</details>


<details>
<summary>Q: How to ensure that the model generates results of sufficient length</summary>

We've observed that during multilingual inference with MiniCPM-V 2.6, generation sometimes ends prematurely. You can improve the results by passing a `min_new_tokens` parameter:
```python
res = model.chat(
image=None,
msgs=msgs,
tokenizer=tokenizer,
min_new_tokens=100
)
```
</details>
87 changes: 75 additions & 12 deletions finetune/dataset.py
@@ -105,7 +105,7 @@ def trim_and_pad(seq, batch_first, padding_value)
}


def conversation_to_ids(conversation, tokenizer, llm_type=None):
def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -115,6 +115,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
elif llm_type == "qwen2":
input_ids, context, raw_msg = conversation_to_ids_qwen2(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
@@ -125,6 +129,7 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):

# build target
target = torch.full_like(ids, -100, dtype=torch.int32)

for i in range(1, len(ids)):
if context[i] == 0:
target[i - 1] = ids[i]
@@ -133,14 +138,21 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id

# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if new_schema:
start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
else:
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
print("image start token != image end tokens")

if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
@@ -230,6 +242,46 @@ def conversation_to_ids_llama3(conversation, tokenizer):
return input_ids, context, raw_msg


def conversation_to_ids_qwen2(conversation, tokenizer):
    raw_msg = ""
    chat = []
    context = []
    for idx, msg in enumerate(conversation):
        role = msg["role"]
        message = msg["content"]
        assert role in ["user", "assistant"]
        if role == "user":
            prefix = "user"
        else:
            prefix = "assistant"
        chat.append({"role":prefix, "content":message})
        raw_msg += prefix + message
    assert set([i['role'] for i in chat]) & set(['assistant'])

    ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
    input_ids = np.array(input_ids)

    start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
    assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
    end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]

    context = np.ones_like(input_ids, dtype=np.int8)

    for assistant_idx in assistant_idxs:
        if assistant_idx-1 in set(start_idxs):
            st = assistant_idx + 1
            for end_idx in end_idxs:
                if end_idx > st:
                    context[st: end_idx + 1] = 0
                    break

    input_ids = np.hstack(input_ids)
    context = np.hstack(context)
    return input_ids, context, raw_msg
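
To make the masking above concrete, here is a hypothetical walk-through (it assumes the stock Qwen2 chat template; the exact rendered string depends on the tokenizer shipped with the checkpoint):

```python
# Not part of this commit - an illustration of conversation_to_ids_qwen2.
conversation = [
    {"role": "user", "content": "Describe this image"},
    {"role": "assistant", "content": "This is a cat."},
]
input_ids, context, raw_msg = conversation_to_ids_qwen2(conversation, tokenizer)

# With the stock Qwen2 template the conversation renders roughly as
#   <|im_start|>user\nDescribe this image<|im_end|>\n
#   <|im_start|>assistant\nThis is a cat.<|im_end|>\n
# context stays 1 (ignored) for the user turn and the assistant header,
# and is set to 0 (supervised) from just after "assistant" through the
# closing <|im_end|>; conversation_to_ids() later turns exactly those
# positions into next-token targets and leaves the rest at -100.
assert (context == 0).any()
```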



def preprocess(
image,
conversation,
@@ -256,8 +308,14 @@ def preprocess(
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
new_schema = False
use_image_id = False
if llm_type=='qwen2':
new_schema = True
use_image_id = True
if slice_config:
images = []
image_id_cnt = 0
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
@@ -270,9 +328,11 @@ def preprocess(
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])

if use_image_id:
image_placeholder = f'{tokenizer.im_id_start}{idx}{tokenizer.im_id_end}' + image_placeholder
image_id_cnt += 1
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums)
tokenizer, best_grid, query_nums, new_schema = new_schema)
images = [transform(i) for i in images]
else:
images = [transform(image)]
@@ -286,7 +346,7 @@ def preprocess(
image_placeholder + "\n" + conversation[0]["content"]
)

input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type, new_schema)

if batch_vision:
tgt_sizes = []
@@ -424,7 +484,7 @@ def split_to_patches(image, grid):
return patches


def get_grid_placeholder(tokenizer, grid, query_num):
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
)
@@ -437,7 +497,10 @@ def get_grid_placeholder(tokenizer, grid, query_num):
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + \
if new_schema:
slice_placeholder = '\n'.join(slices)
else:
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder

@@ -455,4 +518,4 @@ def reshape_by_patch(image_tensor, patch_size):
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches
return patches
17 changes: 16 additions & 1 deletion finetune/finetune.py
@@ -6,6 +6,8 @@
from functools import partial
from typing import Dict, List, Optional, Union, Literal, Tuple
from types import MethodType
from torchvision import transforms

import torch
import transformers
from accelerate.utils import DistributedType
@@ -130,6 +132,18 @@ def make_supervised_data_module(
)


def build_transform():
    IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_MEAN
    IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_STD
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
            ),
        ]
    )
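
As a quick sanity check of the transform above (a sketch; the dummy image size is arbitrary):

```python
# Not part of this commit - verifies the normalization used by build_transform().
from PIL import Image

transform_func = build_transform()
dummy = Image.new('RGB', (448, 448), color='white')
tensor = transform_func(dummy)
print(tensor.shape)         # torch.Size([3, 448, 448])
print(tensor.max().item())  # 1.0 for a pure white image: (1.0 - 0.5) / 0.5
```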

def get_parameter_number(model):
trainable_params, all_param = 0, 0
for param in model.parameters():
@@ -248,10 +262,11 @@ def get_input_embeddings(self):
else:
batch_vision = False

transform_func = build_transform()
data_module = make_supervised_data_module(
tokenizer=tokenizer,
data_args=data_args,
transform=model.transform,
transform=transform_func,
data_collator=data_collator,
slice_config=slice_config,
llm_type=llm_type,
19 changes: 11 additions & 8 deletions finetune/finetune_ds.sh
@@ -6,12 +6,15 @@ NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001

MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
MODEL="openbmb/MiniCPM-V-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the finetuning section of the README for more information.
DATA="path/to/training_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"



DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
@@ -28,10 +31,10 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--bf16 true \
--bf16_full_eval true \
--fp16 false \
--fp16_full_eval false \
--do_train \
--do_eval \
--tune_vision true \
Expand All @@ -40,8 +43,8 @@ torchrun $DISTRIBUTED_ARGS finetune.py \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output_minicpmv2 \
--logging_dir output/output_minicpmv2 \
--output_dir output/output_minicpmv26 \
--logging_dir output/output_minicpmv26 \
--logging_strategy "steps" \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \