
Commit

add samsum & update readme
YaoJiayi committed Sep 9, 2024
1 parent 499adb4 commit 0d617d9
Showing 5 changed files with 45,986 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
@@ -10,6 +10,7 @@ git clone git@github.com:YaoJiayi/CacheBlend.git
cd CacheBlend/vllm_blend
pip install -e .
cd ..
pip install -r requirements.txt
```


@@ -24,4 +25,5 @@ python example/blend.py
```
python example/blend_musique.py
```
To run datasets other than musique, please replace `musique` with `samsum` or `wikimqa` in the above command.
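For example, the SAMSum script added in this commit corresponds to:
```
python example/blend_samsum.py
```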
## References
144 changes: 144 additions & 0 deletions example/blend_samsum.py
@@ -0,0 +1,144 @@
from vllm import LLM, SamplingParams
import torch
import json
import numpy as np
from transformers import AutoTokenizer
from utils import load_dataset, normalize_question, build_fewshot_prompt, compute_rl
from pathlib import Path
from itertools import chain

eval_dataset = load_dataset("inputs/samsum.json")

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", gpu_memory_utilization=0.5)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
llm.set_tokenizer(tokenizer)

prefix_prompt = "Summarize the dialogue into a few short sentences. The following are some examples.\n\n"

ttft_blend = []
ttft_full = []
rl_blend = []
rl_full = []

max_ctx_len = 3400
# TODO (Jiayi): fix filler tokens at the beginning or pass in tokenizer
for sample_idx, ex in enumerate(eval_dataset):
    answers = ex["answers"]
    doc_prompts, q_prompt = build_fewshot_prompt(ex)
    doc_chunk_ids = [tokenizer.encode(doc)[1:] for doc in doc_prompts]
    q_ids = tokenizer.encode(q_prompt)[1:]

    # drop few-shot examples (from the middle) until the context fits max_ctx_len
    while len(list(chain.from_iterable(doc_chunk_ids))) > max_ctx_len:
        del_idx = int(len(doc_chunk_ids) / 2)
        del doc_chunk_ids[del_idx]

    # skip if all ctxs are dropped
    if len(doc_chunk_ids) == 0:
        continue

    # Sampling params for the KV-collection passes (only 1 token is generated).
    sampling_params = SamplingParams(temperature=0, max_tokens=1)

    # Get the model's cache-fuse metadata and reset its flags.
    cache_fuse_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.cache_fuse_metadata
    cache_fuse_metadata['collect'] = False
    cache_fuse_metadata['check'] = False
    cache_fuse_metadata['attn_bias'] = None

    s_start_full = tokenizer.encode(prefix_prompt)[1:]
    s_start_len = len(s_start_full) + 1

    s_start = []
    s_start_1_len = len(s_start) + 1

    s_end = []
    s_end_len = len(s_end)

    doc_chunk_ids = [s_start + chunk_ids for chunk_ids in doc_chunk_ids]
    doc_chunk_ids = [s_start_full] + doc_chunk_ids
    doc_chunk_ids = doc_chunk_ids + [s_start + q_ids + s_end]

    last_len = len(q_ids + s_end)

    cache_fuse_metadata['collect'] = True
    cache_fuse_metadata["check"] = False
    num_layer = 32
    chunk_past_key_values = []
    shift = 0
    # Concatenate old KVs: prefill each chunk once and stitch its per-layer KV cache together.
    for i in range(len(doc_chunk_ids)):
        prompts = [tokenizer.decode(doc_chunk_ids[i])]
        llm.generate(prompts, sampling_params)
        shift += len(doc_chunk_ids[i])
        llm_layers = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.layers
        for j in range(num_layer):
            past_key_values = llm_layers[j].self_attn.hack_kv
            if i == 0:
                temp_k = past_key_values[0][:s_start_len].clone()  # do not change with s_start_1
                temp_v = past_key_values[1][:s_start_len].clone()
            else:
                temp_k = past_key_values[0][s_start_1_len:len(doc_chunk_ids[i]) + 1].clone()
                temp_v = past_key_values[1][s_start_1_len:len(doc_chunk_ids[i]) + 1].clone()

            if i == 0:
                chunk_past_key_values.append([temp_k, temp_v])
            else:
                chunk_past_key_values[j][0] = torch.cat((chunk_past_key_values[j][0], temp_k), dim=0)
                chunk_past_key_values[j][1] = torch.cat((chunk_past_key_values[j][1], temp_v), dim=0)
            llm_layers[j].self_attn.hack_kv = None
    llm.llm_engine.model_executor.driver_worker.model_runner.model.model.old_kvs = chunk_past_key_values

    # Rebuild the full prompt token ids from the chunks.
    input_ids = []
    for i in range(len(doc_chunk_ids)):
        if i == 0:
            temp_ids = doc_chunk_ids[i]
        else:
            temp_ids = doc_chunk_ids[i][s_start_1_len - 1:]
        input_ids += temp_ids

    input_prompt = tokenizer.decode(input_ids)

    # Cached (blended) generation: reuse the stitched KVs and recompute only a fraction of tokens.
    sampling_params = SamplingParams(temperature=0, max_tokens=128)
    cache_fuse_metadata["check"] = True
    cache_fuse_metadata['collect'] = False
    cache_fuse_metadata['recomp_ratio'] = 0.18
    cache_fuse_metadata['fast_attention'] = True
    cache_fuse_metadata['suffix_len'] = last_len

    print(f"Sample idx: {sample_idx}")
    output = llm.generate([input_prompt], sampling_params)
    res = output[0].outputs[0].text
    # TODO(Jiayi): please move this to utils
    res = res.lstrip('\n').split('\n')[0]
    print(f"Cached generation: {res}")
    ttft = output[0].metrics.first_token_time - output[0].metrics.first_scheduled_time
    print(f"TTFT with cache: {ttft}")
    ttft_blend.append(ttft)
    rl = max([compute_rl(res, answer) for answer in answers])
    rl_blend.append(rl)

    # Baseline: full prefill with no KV reuse.
    sampling_params = SamplingParams(temperature=0, max_tokens=128)
    cache_fuse_metadata["check"] = False
    cache_fuse_metadata['collect'] = False
    output = llm.generate([input_prompt], sampling_params)
    res = output[0].outputs[0].text
    res = res.lstrip('\n').split('\n')[0]
    print(f"Normal generation: {res}")
    ttft = output[0].metrics.first_token_time - output[0].metrics.first_scheduled_time
    print(f"TTFT with full prefill: {ttft}")
    ttft_full.append(ttft)
    rl = max([compute_rl(res, answer) for answer in answers])
    rl_full.append(rl)
    print("------------")


print("---------------Result Summary---------------------")
print(f"TTFT with cache: {np.mean(ttft_blend)}")
print(f"TTFT with full prefill: {np.mean(ttft_full)}")
print(f"rl with cache: {np.mean(rl_blend)}")
print(f"rl with full prefill: {np.mean(rl_full)}")
14 changes: 13 additions & 1 deletion example/utils.py
@@ -2,6 +2,7 @@
import collections
import string
import re
from rouge_score import rouge_scorer

def load_dataset(dataset_path):
print("Loading dataset:", dataset_path)
@@ -47,6 +48,12 @@ def build_qa_prompt(example, query_prompt):
    q_prompt = f"{query_prompt}{q}\nAnswer:"
    return doc_prompts, q_prompt

def build_fewshot_prompt(example):
    q = "\n\n"+example["question"]
    doc_prompts = [f"{ctx['text']}" for ctx in example["ctxs"]]
    q_prompt = f"{q}"
    return doc_prompts, q_prompt

def compute_f1(a_pred, a_gold, tokenizer):
    a_pred = parse_generation(a_pred)
    gold_toks = tokenizer.encode(normalize_answer(a_gold))[1:]
@@ -64,4 +71,9 @@ def compute_f1(a_pred, a_gold, tokenizer):
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def compute_rl(pred, gold):
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rougeL = scorer.score(gold, pred)['rougeL'].fmeasure
    return rougeL
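A minimal usage sketch of the new `compute_rl` helper (the two summary strings below are made-up examples, not drawn from the samsum data; `rouge_score` must be installed, e.g. via `pip install -r requirements.txt`):

```
# Run from the example/ directory so that utils.py is importable.
from utils import compute_rl

# Hypothetical prediction and reference summaries, for illustration only.
pred = "Amanda baked cookies and will bring some to Jerry tomorrow."
gold = "Amanda baked cookies and will bring Jerry some tomorrow."

# compute_rl returns the ROUGE-L F-measure, a score between 0 and 1 (higher is better).
print(compute_rl(pred, gold))
```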
45,826 changes: 45,826 additions & 0 deletions inputs/samsum.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
rouge_score
