Commit
update time performance experiments and visualization modules
huangzhengxiang committed Nov 13, 2024
1 parent 028f09a commit a6c6298
Showing 3 changed files with 95 additions and 9 deletions.
83 changes: 83 additions & 0 deletions transformers/llm/datasets/visualization/time.py
@@ -0,0 +1,83 @@
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from typing import List, Tuple
import pandas as pd
import numpy as np
import argparse
import os
import re
from io import StringIO

def split_by_turns(tag: str, content: str) -> List[pd.DataFrame]:
    # each <tag>...</tag> block holds one chat session's records as CSV
    pattern = "<{tag}>\n(.*?)</{tag}>\n".format(tag=tag)
    return [pd.read_csv(StringIO(item)) for item in re.findall(pattern, content, flags=re.DOTALL)]

def preprocess(file_path: str) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
    with open(file_path, "rt") as f:
        content = f.read()
    return split_by_turns("prefill", content), split_by_turns("decode", content)

def get_max_turn(no_reuse_prefill_record):
    # show at least 10 rounds, more if some session runs longer
    return max(10, max([len(record) for record in no_reuse_prefill_record]))

def draw_history_len(ax: Axes, no_reuse_prefill_record: List[pd.DataFrame]):
    max_round = get_max_turn(no_reuse_prefill_record)
    history_len = [0 for _ in range(0, max_round)]
    for turn in range(0, max_round):
        # history length = tokens already in context beyond the new prompt
        history_len[turn] = np.median([record["input_token"][turn] - record["prompt_token"][turn]
                                       for record in no_reuse_prefill_record if len(record) >= turn + 1]).item()
    ax.plot(np.arange(1, max_round + 1), history_len, label="median history len", marker=".", markersize=8)
    return

def draw_prefill_bar_chart(ax: Axes, no_reuse, reuse):
    offset = 0.2
    max_round = len(no_reuse)
    no_reuse_med = [np.median(turn) for turn in no_reuse]
    rects = ax.bar(np.arange(1, max_round + 1) + offset, no_reuse_med, offset * 2, label="no reuse kv", color="tomato")
    ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6)
    reuse_med = [np.median(turn) for turn in reuse]
    rects = ax.bar(np.arange(1, max_round + 1) - offset, reuse_med, offset * 2, label="reuse kv", color="springgreen")
    ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6)
    return

def compare_prefill_reuse_kv(no_reuse_prefill_record: List[pd.DataFrame],
                             reuse_prefill_record: List[pd.DataFrame]):
    plt.close()
    _, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    # plot history_len on the secondary axis
    draw_history_len(ax2, no_reuse_prefill_record)
    # gather per-turn prefill response speeds
    max_round = get_max_turn(no_reuse_prefill_record)
    no_reuse = [[] for _ in range(0, max_round)]
    for turn in range(0, max_round):
        no_reuse[turn] = [record["response_speed"][turn] for record in no_reuse_prefill_record if len(record) >= turn + 1]
    reuse = [[] for _ in range(0, max_round)]
    for turn in range(0, max_round):
        reuse[turn] = [record["response_speed"][turn] for record in reuse_prefill_record if len(record) >= turn + 1]
    # plot the per-round median bar chart
    draw_prefill_bar_chart(ax1, no_reuse, reuse)
    ax1.set_xticks(np.arange(1, max_round + 1), np.arange(1, max_round + 1), fontsize=9)
    ax1.set_ylim(0, 100)
    ax2.set_ylim(0, 1000)
    ax1.legend(loc='upper left', title="prefill response speed")
    ax2.legend(loc='upper right')
    ax1.set_ylabel("prefill\nresponse\nspeed", rotation=0, labelpad=12)
    ax2.set_ylabel("history\nlen", rotation=0, labelpad=8)
    ax1.set_xlabel("round")
    plt.title("KV cache reuse for multi-turn chat\neffects on ShareGPT")
    plt.tight_layout()
    os.makedirs("./pic", exist_ok=True)  # savefig fails if the directory is missing
    plt.savefig("./pic/fig.png", dpi=1200)
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="./data")
    parser.add_argument("--no_reuse", type=str, default="shareGPT_common_en_70k_noreuse.txt")
    parser.add_argument("--reuse", type=str, default="shareGPT_common_en_70k_reuse.txt")
    args = parser.parse_args()

    no_reuse_file_path = os.path.join(args.root, args.no_reuse)
    reuse_file_path = os.path.join(args.root, args.reuse)
    no_reuse_prefill_record, no_reuse_decode_record = preprocess(no_reuse_file_path)
    reuse_prefill_record, reuse_decode_record = preprocess(reuse_file_path)
    # visualize prefill
    compare_prefill_reuse_kv(no_reuse_prefill_record, reuse_prefill_record)
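
For reference, a minimal round-trip sketch for the parser above. The tag and column layout match what the updated print_speed emits (see the C++ diff below); the numeric values are invented for illustration:

# Hypothetical log excerpt and a sanity check for split_by_turns (illustrative values only).
sample_log = (
    "<prefill>\n"
    "prev_token,input_token,prompt_token,response_speed\n"
    "0,32,32,52.20\n"      # turn 1: no history yet
    "53,85,32,48.90\n"     # turn 2: 53 tokens of history re-prefilled
    "</prefill>\n"
    "<decode>\n"
    "prev_token,response_speed\n"
    "32,14.20\n"
    "33,14.15\n"
    "</decode>\n"
)
prefill_turns = split_by_turns("prefill", sample_log)  # one DataFrame per session
assert prefill_turns[0]["input_token"][1] - prefill_turns[0]["prompt_token"][1] == 53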
20 changes: 12 additions & 8 deletions transformers/llm/engine/src/LlmSessionInfo.cpp
@@ -58,25 +58,29 @@ int LlmSessionInfo::getTotalDecodeLen() {
     return mTimePerformance.decode_record_.size();
 }
 void LlmSessionInfo::print_speed(std::ostream* os) {
-    (*os) << "prefill " << mTimePerformance.prefill_record_.size() << std::endl;
+    // prefill statistics
+    (*os) << "<prefill>" << std::endl;
     if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
-        (*os) << "prev_token input_token speed(token/s)" << std::endl;
+        (*os) << "prev_token,input_token,response_speed" << std::endl;
         for (auto record : mTimePerformance.prefill_record_) {
-            (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
+            (*os) << record.prefill_prev_token_ << "," << record.prefill_token_ << "," << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
         }
     } else {
-        (*os) << "prev_token input_token prompt_token response_speed(token/s)" << std::endl;
+        (*os) << "prev_token,input_token,prompt_token,response_speed" << std::endl;
         for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
             auto record = mTimePerformance.prefill_record_[r];
             auto prompt_len = mTimePerformance.prompt_record_[r];
-            (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << prompt_len << " " << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
+            (*os) << record.prefill_prev_token_ << "," << record.prefill_token_ << "," << prompt_len << "," << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
         }
     }
-    (*os) << "decode " << mTimePerformance.decode_record_.size() << std::endl;
-    (*os) << "prev_token speed(token/s)" << std::endl;
+    (*os) << "</prefill>" << std::endl;
+    // decode statistics
+    (*os) << "<decode>" << std::endl;
+    (*os) << "prev_token,response_speed" << std::endl;
     for (auto record : mTimePerformance.decode_record_) {
-        (*os) << record.decode_prev_token_ << " " << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
+        (*os) << record.decode_prev_token_ << "," << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
     }
+    (*os) << "</decode>" << std::endl;
 }

 } // Transformer
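
Both the prefill and decode rows derive response_speed the same way: tokens divided by elapsed seconds. A small Python sketch of the arithmetic, assuming MICRO_TO_SEC is the microseconds-to-seconds factor 1e-6 (the sample inputs are invented):

# Assumption: MICRO_TO_SEC == 1e-6, i.e. it converts microseconds to seconds.
MICRO_TO_SEC = 1e-6

def prefill_speed(prompt_tokens: int, prefill_us: int) -> float:
    # prompt tokens processed per second of prefill wall time
    return prompt_tokens / (prefill_us * MICRO_TO_SEC)

def decode_speed(decode_us: int) -> float:
    # decode emits one token per step, so speed is 1 / per-token latency
    return 1.0 / (decode_us * MICRO_TO_SEC)

print(prefill_speed(32, 613_000))  # ~52.2 token/s
print(decode_speed(70_400))        # ~14.2 token/s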
1 change: 0 additions & 1 deletion transformers/llm/engine/src/perplexity.cpp
@@ -221,7 +221,6 @@ float ChatPPLMeasurer::perplexity_one(const std::vector<std::vector<PromptItem>>

     // record time performance to file
     if (perfOS != nullptr) {
-        (*perfOS) << "<chat>" << std::endl;
         mLlm->mLlmSessionInfos[0].print_speed(perfOS);
     }

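With the <chat> wrapper dropped, a performance file is now a flat sequence of <prefill>/<decode> blocks, one pair per chat session, which is the shape the regex in time.py's split_by_turns looks for. Schematically (row contents elided):

<prefill>
prev_token,input_token,prompt_token,response_speed
...one CSV row per turn...
</prefill>
<decode>
prev_token,response_speed
...one CSV row per generated token...
</decode>
<prefill>
...next session...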
