Commit

benching
Aaryan0404 committed Aug 16, 2024
1 parent 30636c3 commit 7e41031
Showing 3 changed files with 540 additions and 0 deletions.
28 changes: 28 additions & 0 deletions examples/lolcats_hh/custom_llama_model.py
@@ -0,0 +1,28 @@
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaForCausalLM
from transformers import LlamaConfig
from torch import nn
from typing import Optional, Tuple, Union, List
import torch

from tk_hedgehog_window_attention import TKHedgehogWindowAttention


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """Standard Llama decoder layer with its self-attention swapped for TK Hedgehog window attention."""

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = TKHedgehogWindowAttention(config=config, layer_idx=layer_idx)


class CustomLlamaModel(LlamaModel):
    """LlamaModel whose decoder stack is built from CustomLlamaDecoderLayer."""

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [CustomLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )


class CustomLlamaForCausalLM(LlamaForCausalLM):
    """Causal LM head on top of CustomLlamaModel."""

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.model = CustomLlamaModel(config)
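
For reference, a minimal loading sketch that mirrors the call made in inference.py below. The checkpoint name is a placeholder, tk_hedgehog_window_attention must be importable, and, depending on how TKHedgehogWindowAttention names its parameters, some attention weights may be newly initialized rather than loaded from the checkpoint.

# Minimal usage sketch (checkpoint name is a placeholder).
import torch
from transformers import AutoConfig, AutoTokenizer

from custom_llama_model import CustomLlamaForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B"  # placeholder checkpoint
config = AutoConfig.from_pretrained(model_name)

# Weights load into the standard Llama modules; TK Hedgehog attention parameters
# that have no match in the checkpoint are left to their default initialization.
model = CustomLlamaForCausalLM.from_pretrained(
    model_name,
    config=config,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
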
226 changes: 226 additions & 0 deletions examples/lolcats_hh/inference.py
@@ -0,0 +1,226 @@
import fire
import os
import sys
import time
import gradio as gr

import torch
from transformers import AutoTokenizer, AutoConfig

from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
from llama_recipes.inference.model_utils import load_peft_model

from accelerate.utils import is_xpu_available

from custom_llama_model import CustomLlamaForCausalLM
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig

import matplotlib.pyplot as plt

def load_custom_llama_model(model_name, quantization, use_fast_kernels):
    config = AutoConfig.from_pretrained(model_name)

    # if quantization:
    #     model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    if use_fast_kernels:
        # Baseline: stock Llama served with FlashAttention-2 kernels.
        print("Using fast kernels")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            return_dict=True,
            load_in_8bit=quantization,
            device_map="auto",
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
    else:
        # Custom Llama whose attention layers are TK Hedgehog window attention.
        model = CustomLlamaForCausalLM.from_pretrained(
            model_name,
            config=config,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
        )

    return model

def main(
    model_name,
    peft_model: str = None,
    quantization: bool = False,
    max_new_tokens: int = 20,
    prompt_file: str = None,
    seed: int = 42,
    do_sample: bool = True,
    min_length: int = None,
    use_cache: bool = True,
    top_p: float = 1.0,
    temperature: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
    enable_azure_content_safety: bool = False,
    enable_sensitive_topics: bool = False,
    enable_salesforce_content_safety: bool = True,
    enable_llamaguard_content_safety: bool = False,
    max_padding_length: int = None,
    use_fast_kernels: bool = False,
    **kwargs
):
    def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, batch_size, use_fa, **kwargs):
        # Screen the prompt with the configured safety checkers before generating.
        safety_checker = get_safety_checker(
            enable_azure_content_safety,
            enable_sensitive_topics,
            enable_salesforce_content_safety,
            enable_llamaguard_content_safety
        )

        safety_results = [check(user_prompt) for check in safety_checker]
        are_safe = all(r[1] for r in safety_results)
        if are_safe:
            print("User prompt deemed safe.")
            # print(f"User prompt:\n{user_prompt}")
        else:
            print("User prompt deemed unsafe.")
            for method, is_safe, report in safety_results:
                if not is_safe:
                    print(method)
                    print(report)
            print("Skipping the inference as the prompt is not safe.")
            sys.exit(1)  # Exit the program with an error status

        if is_xpu_available():
            torch.xpu.manual_seed(seed)
        else:
            torch.cuda.manual_seed(seed)
        torch.manual_seed(seed)

        # use_fa=True benchmarks the FlashAttention-2 baseline;
        # use_fa=False benchmarks the TK Hedgehog window-attention model.
        model = load_custom_llama_model(model_name, quantization, use_fa)
        if peft_model:
            model = load_peft_model(model, peft_model)

        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        # Replicate the prompt across the batch dimension.
        user_prompt = [str(user_prompt)] * batch_size

        batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt", is_split_into_words=False)
        if is_xpu_available():
            batch = {k: v.to("xpu") for k, v in batch.items()}
        else:
            batch = {k: v.to("cuda") for k, v in batch.items()}

        start = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(
                **batch,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                min_length=min_length,
                use_cache=use_cache,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                **kwargs
            )
        e2e_inference_time = (time.perf_counter() - start) * 1000
        print(f"The inference time is {e2e_inference_time} ms: batch size = {batch_size}, prompt length = {len(user_prompt[0])} chars")
        output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Safety-check the generated output (pass the original prompt string, not the batched list).
        safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt[0]) for check in safety_checker]
        are_safe = all(r[1] for r in safety_results)
        if are_safe:
            print("User input and model output deemed safe.")
            # print(f"Model output:\n{output_text}")
        else:
            print("Model output deemed unsafe.")
            for method, is_safe, report in safety_results:
                if not is_safe:
                    print(method)
                    print(report)

        return e2e_inference_time, len(user_prompt[0])
        # return output_text
    user_prompt = "In the year 2045, humanity has achieved remarkable technological advancements. Flying cars zoom through the skies of megacities, while underwater colonies thrive in the depths of the oceans. Artificial intelligence has become an integral part of daily life, assisting in everything from healthcare to space exploration. Sarah, a brilliant neuroscientist, has been working on a groundbreaking project to merge human consciousness with AI. Her goal is to create a symbiotic relationship between organic brains and artificial neural networks, potentially unlocking unprecedented cognitive abilities and extending human lifespan. As Sarah prepares to present her findings at the World Science Summit, she encounters an ethical dilemma. Her research has revealed unforeseen consequences that could fundamentally alter the course of human evolution. She must decide whether to proceed with her work or suppress her discoveries for the greater good. Meanwhile, on Mars, the first human colony is facing a crisis. A mysterious illness is spreading among the settlers, and communication with Earth has been disrupted by an intense solar storm. The colony's leader, Commander Chen, must make difficult decisions to ensure the survival of his team and the future of Mars exploration. Back on Earth, deep in the Amazon rainforest, a team of environmentalists has made an astounding discovery. They've found a previously unknown species of plant with extraordinary properties. Initial tests suggest it could revolutionize medicine and potentially solve the global energy crisis. However, they soon realize that harvesting the plant could disrupt the delicate ecosystem and potentially lead to unforeseen ecological disasters. As these events unfold across the solar system, a young journalist named Alex embarks on a dangerous investigation. They've uncovered evidence of a secretive organization that seems to be manipulating global events from the shadows. Alex's pursuit of the truth will take them from the neon-lit streets of Tokyo to the hidden bunkers beneath the Antarctic ice. In this complex web of scientific breakthroughs, ethical challenges, and hidden agendas, the fate of humanity hangs in the balance. The decisions made by Sarah, Commander Chen, the environmentalists, and Alex will shape the future of our species and our place in the universe. As the world stands on the brink of a new era,"

    # Batch sizes and prompt lengths (in characters) to sweep over.
    batch_sizes = [1, 2, 4, 8, 16, 32]
    seq_lengths = [64, 128, 256, 512, 1024]

    flash_batch_size_results = {bs: [] for bs in batch_sizes}
    flash_seq_length_results = {sl: [] for sl in seq_lengths}
    tk_batch_size_results = {bs: [] for bs in batch_sizes}
    tk_seq_length_results = {sl: [] for sl in seq_lengths}

    # Baseline without TK (using flash_attention_2); the prompt is truncated to sl characters.
    for bs in batch_sizes:
        for sl in seq_lengths:
            inference_time, actual_seq_length = inference(user_prompt[:sl], temperature, top_p, top_k, max_new_tokens, bs, True)
            flash_batch_size_results[bs].append(inference_time)
            flash_seq_length_results[sl].append(inference_time)

    # TK Hedgehog window-attention version.
    for bs in batch_sizes:
        for sl in seq_lengths:
            inference_time, actual_seq_length = inference(user_prompt[:sl], temperature, top_p, top_k, max_new_tokens, bs, False)
            tk_batch_size_results[bs].append(inference_time)
            tk_seq_length_results[sl].append(inference_time)

    # Plotting
    plt.figure(figsize=(20, 10))

    # Batch size plot
    plt.subplot(1, 2, 1)
    for bs in batch_sizes:
        plt.plot(seq_lengths, flash_batch_size_results[bs], marker='o', linestyle='-', label=f'Flash Attn BS {bs}')
        plt.plot(seq_lengths, tk_batch_size_results[bs], marker='s', linestyle='--', label=f'TK BS {bs}')
    plt.xlabel('Sequence Length')
    plt.ylabel('Inference Time (ms)')
    plt.title('Inference Time vs Sequence Length')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xscale('log')
    plt.yscale('log')

    # Sequence length plot
    plt.subplot(1, 2, 2)
    for sl in seq_lengths:
        plt.plot(batch_sizes, flash_seq_length_results[sl], marker='o', linestyle='-', label=f'Flash Attn SL {sl}')
        plt.plot(batch_sizes, tk_seq_length_results[sl], marker='s', linestyle='--', label=f'TK SL {sl}')
    plt.xlabel('Batch Size')
    plt.ylabel('Inference Time (ms)')
    plt.title('Inference Time vs Batch Size')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xscale('log')
    plt.yscale('log')

    plt.tight_layout()
    plt.savefig('inference_time_comparison_plots.png', bbox_inches='tight')
    plt.close()

    print("Plots have been saved as 'inference_time_comparison_plots.png'")

    # if prompt_file is not None:
    #     assert os.path.exists(prompt_file), f"Provided Prompt file does not exist {prompt_file}"
    #     with open(prompt_file, "r") as f:
    #         user_prompt = "\n".join(f.readlines())
    #     inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
    # elif not sys.stdin.isatty():
    #     user_prompt = "\n".join(sys.stdin.readlines())
    #     inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
    # else:
    #     gr.Interface(
    #         fn=inference,
    #         inputs=[
    #             gr.components.Textbox(lines=9, label="User Prompt", placeholder="none"),
    #             gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Temperature"),
    #             gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Top p"),
    #             gr.components.Slider(minimum=0, maximum=100, step=1, value=50, label="Top k"),
    #             gr.components.Slider(minimum=1, maximum=2000, step=1, value=200, label="Max tokens"),
    #         ],
    #         outputs=[
    #             gr.components.Textbox(lines=5, label="Output"),
    #         ],
    #         title="Meta Llama3 Playground",
    #         description="https://github.com/facebookresearch/llama-recipes",
    #     ).queue().launch(server_name="0.0.0.0", share=True)

if __name__ == "__main__":
    fire.Fire(main)
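
Because the entry point is wrapped in fire.Fire, main's keyword arguments double as command-line flags. A minimal usage sketch, assuming the script is run from examples/lolcats_hh and the checkpoint name is replaced with a real Llama model:

# Run the benchmark from the command line (flags mirror main's keyword arguments):
#   python inference.py --model_name meta-llama/Meta-Llama-3-8B --max_new_tokens 20
#
# Or call it directly from Python. Both attention variants (the FlashAttention-2
# baseline and TK Hedgehog) are swept internally over the batch-size /
# prompt-length grid, and the comparison plot is written to
# inference_time_comparison_plots.png.
from inference import main

main(model_name="meta-llama/Meta-Llama-3-8B", max_new_tokens=20)  # placeholder checkpoint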