Add argument for precisions
SyphonArch committed May 4, 2024
1 parent b70aef3 commit 8d65074
Showing 2 changed files with 46 additions and 23 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -189,14 +189,15 @@ We have provided a demo script to showcase the inference capabilities of the qua
To run the demo, execute the following command:

```bash
python demo.py
python demo.py -p 3 4 5 6 16
```

Note that the demo script requires the quantized `Llama-2-7b-chat-hf` model to be present in the cache directory.
Other models can be used by changing the `model_path` and `original_model_path` variables in the script.

The demo script will load the quantized model, and perform inference on a custom prompt, using precisions ranging from
the seed precision to the parent precision. The latency at each precision is measured and displayed.
The demo script will load the quantized model, and perform inference on a custom prompt, using the specified precisions.
Include 16 to measure the latency of the original model in fp16.
The latency at each precision will be measured and displayed.
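For example, `python demo.py -p 4 16` would benchmark only the 4-bit quantized model and the fp16 original.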

The demo will look like this when run properly:

62 changes: 42 additions & 20 deletions demo.py
@@ -3,10 +3,18 @@
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
import logging
import time
from argparse import ArgumentParser

# Logging with time sans date, level name, and message
logging.basicConfig(level=logging.INFO, format='[%(asctime)s | %(levelname)s] %(message)s', datefmt='%H:%M:%S')

parser = ArgumentParser()
parser.add_argument('-p', '--precisions', nargs='+', type=int, default=None,
help="The precisions to benchmark. If not specified, all available precisions will be benchmarked."
)

args = parser.parse_args()

if __name__ == '__main__':
model_path = './cache/packed/anyprec-(Llama-2-7b-chat-hf)-w8_orig3-gc1-c4_s100_blk512'
original_model_path = 'meta-llama/Llama-2-7b-chat-hf'
@@ -18,6 +26,18 @@
model = AnyPrecisionForCausalLM.from_quantized(model_path)
model = model.eval().cuda()

# Configure the precisions to benchmark
do_fp16 = True
if args.precisions is not None:
precisions = args.precisions
if 16 in precisions:
precisions.remove(16)
else:
do_fp16 = False
assert all(precision in model.precisions for precision in precisions), "Unsupported precision(s) specified."
else:
precisions = model.precisions

# Warm up CUDA cache for stable performance
print("~~~~~~~ Warming up CUDA cache ~~~~~~~")
input_context = "A CUDA cache warm-up is needed to"
@@ -36,7 +56,8 @@
input_ids = tokenizer.encode(input_context, return_tensors="pt").cuda()

results = {}
for precision in model.precisions:

for precision in precisions:
print(f"=============== generation with {precision}-bit precision ===============")
torch.cuda.synchronize()
start_time = time.time()
@@ -63,28 +84,29 @@
del model
torch.cuda.empty_cache()

# Benchmark the original model
print(f"=============== generation with fp16 precision ===============")
model = AutoModelForCausalLM.from_pretrained(original_model_path, torch_dtype=torch.float16).eval().cuda()
torch.cuda.synchronize()
start_time = time.time()
output = model.generate(
input_ids,
max_length=256,
pad_token_id=tokenizer.eos_token_id,
streamer=streamer,
)
torch.cuda.synchronize()
end_time = time.time()
if do_fp16:
# Benchmark the original model
print(f"=============== generation with fp16 precision ===============")
model = AutoModelForCausalLM.from_pretrained(original_model_path, torch_dtype=torch.float16).eval().cuda()
torch.cuda.synchronize()
start_time = time.time()
output = model.generate(
input_ids,
max_length=256,
pad_token_id=tokenizer.eos_token_id,
streamer=streamer,
)
torch.cuda.synchronize()
end_time = time.time()

# Calculate generation speed
token_count = len(output[0]) - len(input_ids[0])
tokens_per_second = token_count / (end_time - start_time)
ms_per_token = 1 / tokens_per_second * 1000
# Calculate generation speed
token_count = len(output[0]) - len(input_ids[0])
tokens_per_second = token_count / (end_time - start_time)
ms_per_token = 1 / tokens_per_second * 1000

results[16] = (tokens_per_second, ms_per_token)
results[16] = (tokens_per_second, ms_per_token)

print(f"\n( Generation speed: {tokens_per_second:.1f} tok/s | Latency: {ms_per_token:.2f} ms/tok )\n")
print(f"\n( Generation speed: {tokens_per_second:.1f} tok/s | Latency: {ms_per_token:.2f} ms/tok )\n")

print("=============== Summary ===============")
print(f"\nModel: {model_path}\n")
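For reference, here is a minimal, self-contained sketch of the precision-selection logic introduced in this commit. It is illustrative only: `select_precisions` is a hypothetical helper, and the supported bit-widths passed in the usage example are assumed rather than read from a real quantized model.

```python
from argparse import ArgumentParser


def select_precisions(cli_precisions, supported_precisions):
    """Mirror the demo's handling of --precisions: 16 selects the fp16 baseline,
    and every other value must be a bit-width supported by the quantized model."""
    do_fp16 = True
    if cli_precisions is not None:
        precisions = list(cli_precisions)
        if 16 in precisions:
            precisions.remove(16)
        else:
            do_fp16 = False
        assert all(p in supported_precisions for p in precisions), "Unsupported precision(s) specified."
    else:
        # No -p flag: benchmark every precision the quantized model provides, plus fp16.
        precisions = list(supported_precisions)
    return precisions, do_fp16


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('-p', '--precisions', nargs='+', type=int, default=None,
                        help="The precisions to benchmark. If not specified, all available "
                             "precisions will be benchmarked.")
    args = parser.parse_args()

    # Assumed for illustration: a w8_orig3 model exposes 3- through 8-bit kernels.
    precisions, do_fp16 = select_precisions(args.precisions, supported_precisions=[3, 4, 5, 6, 7, 8])
    print(f"Quantized precisions to benchmark: {precisions} | fp16 baseline: {do_fp16}")
```

Under these assumptions, passing `-p 3 4 16` yields `[3, 4]` as the quantized precisions with the fp16 baseline enabled, while omitting `-p` benchmarks every supported precision plus fp16.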
