Commit
Showing 5 changed files with 45,986 additions and 1 deletion.
README.md
@@ -10,6 +10,7 @@ git clone git@github.com:YaoJiayi/CacheBlend.git
 cd CacheBlend/vllm_blend
 pip install -e .
 cd ..
+pip install -r requirements.txt
 ```

@@ -24,4 +25,5 @@ python example/blend.py
 ```
 python example/blend_musique.py
 ```
+To run datasets other than musique, please replace `musique` with `samsum` or `wikimqa` in the above command.
 ## References
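Taken together, the run matrix described above amounts to the following (the `samsum` script is confirmed by the new example file in this commit; the `wikimqa` script name is inferred from the substitution rule, not shown in this diff):

```
python example/blend_musique.py
python example/blend_samsum.py
python example/blend_wikimqa.py
```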
example/blend_samsum.py (new file)
@@ -0,0 +1,144 @@
from vllm import LLM, SamplingParams
import torch
import json
import numpy as np
from transformers import AutoTokenizer
from utils import load_dataset, normalize_question, build_fewshot_prompt, compute_rl
from pathlib import Path
from itertools import chain

eval_dataset = load_dataset("inputs/samsum.json")

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", gpu_memory_utilization=0.5)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
llm.set_tokenizer(tokenizer)
prefix_prompt = "Summarize the dialogue into a few short sentences. The following are some examples.\n\n"

ttft_blend = []
ttft_full = []
rl_blend = []
rl_full = []

max_ctx_len = 3400
# TODO (Jiayi): fix filler tokens at the beginning or pass in tokenizer
for sample_idx, ex in enumerate(eval_dataset):
    answers = ex["answers"]
    doc_prompts, q_prompt = build_fewshot_prompt(ex)
    # [1:] drops the BOS token the tokenizer prepends to each chunk
    doc_chunk_ids = [tokenizer.encode(doc)[1:] for doc in doc_prompts]
    q_ids = tokenizer.encode(q_prompt)[1:]

    # drop few-shot examples from the middle while the context exceeds max_ctx_len
    while len(list(chain.from_iterable(doc_chunk_ids))) > max_ctx_len:
        del_idx = int(len(doc_chunk_ids) / 2)
        del doc_chunk_ids[del_idx]

    # skip if all ctxs are dropped
    if len(doc_chunk_ids) == 0:
        continue

    # Prefill-only sampling params: one token is enough to populate the KV cache.
    sampling_params = SamplingParams(temperature=0, max_tokens=1)

    # cache_fuse_metadata is exposed by the patched model in vllm_blend.
    cache_fuse_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.cache_fuse_metadata
    cache_fuse_metadata['collect'] = False
    cache_fuse_metadata['check'] = False
    cache_fuse_metadata['attn_bias'] = None

    s_start_full = tokenizer.encode(prefix_prompt)[1:]
    s_start_len = len(s_start_full) + 1  # +1 accounts for the BOS token at position 0

    s_start = []
    s_start_1_len = len(s_start) + 1

    s_end = []
    s_end_len = len(s_end)

    # chunk layout: [instruction prefix] + [few-shot example chunks] + [question chunk]
    doc_chunk_ids = [s_start + chunk_ids for chunk_ids in doc_chunk_ids]
    doc_chunk_ids = [s_start_full] + doc_chunk_ids
    doc_chunk_ids = doc_chunk_ids + [s_start + q_ids + s_end]

    last_len = len(q_ids + s_end)
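    # A sketch of the collection pass below, as read from this commit (not an
    # authoritative description of the vllm_blend fork): each chunk is prefilled
    # on its own with 'collect' enabled, the patched attention stashes its K/V
    # tensors in self_attn.hack_kv, and those per-chunk KVs are concatenated
    # layer by layer into old_kvs for the fused generation later.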
    cache_fuse_metadata['collect'] = True
    cache_fuse_metadata["check"] = False
    num_layer = 32  # Mistral-7B-Instruct-v0.2 has 32 decoder layers
    chunk_past_key_values = []
    shift = 0
    # Concatenate old KVs
    for i in range(len(doc_chunk_ids)):
        prompts = [tokenizer.decode(doc_chunk_ids[i])]
        llm.generate(prompts, sampling_params)
        shift += len(doc_chunk_ids[i])
        llm_layers = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.layers
        for j in range(num_layer):
            past_key_values = llm_layers[j].self_attn.hack_kv
            if i == 0:
                temp_k = past_key_values[0][:s_start_len].clone()  # do not change with s_start_1
                temp_v = past_key_values[1][:s_start_len].clone()
            else:
                temp_k = past_key_values[0][s_start_1_len:len(doc_chunk_ids[i]) + 1].clone()
                temp_v = past_key_values[1][s_start_1_len:len(doc_chunk_ids[i]) + 1].clone()

            if i == 0:
                chunk_past_key_values.append([temp_k, temp_v])
            else:
                chunk_past_key_values[j][0] = torch.cat((chunk_past_key_values[j][0], temp_k), dim=0)
                chunk_past_key_values[j][1] = torch.cat((chunk_past_key_values[j][1], temp_v), dim=0)
            llm_layers[j].self_attn.hack_kv = None
    llm.llm_engine.model_executor.driver_worker.model_runner.model.model.old_kvs = chunk_past_key_values
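    # Rebuild the full prompt token sequence in the same order the KVs were
    # concatenated; chunks after the first drop their first s_start_1_len - 1
    # tokens so nothing is duplicated at chunk boundaries.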
    input_ids = []

    for i in range(len(doc_chunk_ids)):
        if i == 0:
            temp_ids = doc_chunk_ids[i]
        else:
            temp_ids = doc_chunk_ids[i][s_start_1_len - 1:]
        input_ids += temp_ids

    input_prompt = tokenizer.decode(input_ids)

    sampling_params = SamplingParams(temperature=0, max_tokens=128)
    cache_fuse_metadata["check"] = True
    cache_fuse_metadata['collect'] = False
    cache_fuse_metadata['recomp_ratio'] = 0.18
    cache_fuse_metadata['fast_attention'] = True
    cache_fuse_metadata['suffix_len'] = last_len
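    # The flags just set above switch the patched attention from collecting KVs
    # to reusing old_kvs; 'recomp_ratio' appears to set the fraction of tokens
    # recomputed to repair cross-chunk attention (0.18 here). Exact semantics
    # live in the vllm_blend fork.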
    print(f"Sample idx: {sample_idx}")
    output = llm.generate([input_prompt], sampling_params)
    res = output[0].outputs[0].text
    # TODO(Jiayi): please move this to utils
    res = res.lstrip('\n').split('\n')[0]
    print(f"Cached generation: {res}")
    ttft = output[0].metrics.first_token_time - output[0].metrics.first_scheduled_time
    print(f"TTFT with cache: {ttft}")
    ttft_blend.append(ttft)
    rl = max([compute_rl(res, answer) for answer in answers])
    rl_blend.append(rl)

    # Baseline: same prompt with a normal full prefill (no KV reuse).
    sampling_params = SamplingParams(temperature=0, max_tokens=128)
    cache_fuse_metadata["check"] = False
    cache_fuse_metadata['collect'] = False
    output = llm.generate([input_prompt], sampling_params)
    res = output[0].outputs[0].text
    res = res.lstrip('\n').split('\n')[0]
    print(f"Normal generation: {res}")
    ttft = output[0].metrics.first_token_time - output[0].metrics.first_scheduled_time
    print(f"TTFT with full prefill: {ttft}")
    ttft_full.append(ttft)
    rl = max([compute_rl(res, answer) for answer in answers])
    rl_full.append(rl)
    print("------------")

print("---------------Result Summary---------------------")
print(f"TTFT with cache: {np.mean(ttft_blend)}")
print(f"TTFT with full prefill: {np.mean(ttft_full)}")
print(f"rl with cache: {np.mean(rl_blend)}")
print(f"rl with full prefill: {np.mean(rl_full)}")
(The diffs of two further changed files are not rendered here: large diffs are hidden by default, and one failed to load.)
requirements.txt (new file)
@@ -0,0 +1 @@
rouge_score |
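This single dependency presumably backs the `compute_rl` helper imported by the example script above.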