# -*- encoding: utf-8 -*-
# File: example_chat_hf.py
# Description: Example of multimodal chat (text, image, audio) with a model loaded
# through Hugging Face transformers using trust_remote_code.
import torch
from transformers import AutoModelForCausalLM
path = "/mnt/algorithm/user_dir/zhoudong/workspace/models/megrez-o" # Change this to the path of the model.
model = (
    AutoModelForCausalLM.from_pretrained(
        path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    .eval()
    .cuda()
)
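# Note (assumption, not from the original script): flash_attention_2 requires the
# separate flash-attn package and a compatible GPU. If it is unavailable, PyTorch's
# built-in scaled-dot-product attention backend can serve as a drop-in fallback:
# model = AutoModelForCausalLM.from_pretrained(
#     path,
#     trust_remote_code=True,
#     torch_dtype=torch.bfloat16,
#     attn_implementation="sdpa",  # stock PyTorch attention backend
# ).eval().cuda()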
prompt = "hi" * (128 - 1)  # Placeholder prompt; replace with your own text.
# Chat with text and image
messages = [
    {
        "role": "user",
        "content": {
            "text": prompt,
            "image": "./data/sample_image.jpg",
        },
    },
]
# Chat with audio and image
# messages = [
#     {
#         "role": "user",
#         "content": {
#             "image": "./data/sample_image.jpg",
#             "audio": "./data/sample_audio.m4a",
#         },
#     },
# ]
MAX_NEW_TOKENS = 100
response = model.chat(
    messages,
    sampling=False,
    max_new_tokens=MAX_NEW_TOKENS,
)
print(response)
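# A minimal multi-turn sketch (assumption: the model's custom `chat` method accepts
# prior assistant turns in the same {"role": ..., "content": {...}} format; verify
# against the model's remote code before relying on this):
# messages.append({"role": "assistant", "content": {"text": response}})
# messages.append({"role": "user", "content": {"text": "Tell me more."}})
# followup = model.chat(messages, sampling=False, max_new_tokens=MAX_NEW_TOKENS)
# print(followup)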