generate_adapter.py
import json
import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch

from generate import generate
from lit_parrot import Tokenizer
from lit_parrot.adapter import Parrot, Config
from lit_parrot.utils import EmptyInitOnDevice, lazy_load, check_valid_checkpoint_dir
from scripts.prepare_alpaca import generate_prompt
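
# Example invocation (illustrative; the paths are this script's defaults and assume the base
# checkpoint has already been downloaded/converted and the adapter fine-tuned beforehand):
#   python generate_adapter.py \
#       --prompt "Recommend a movie to watch on the weekend." \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --adapter_path out/adapter/alpaca/lit_model_adapter_finetuned.pth
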
def main(
    prompt: str = "What food do lamas eat?",
    input: str = "",
    adapter_path: Path = Path("out/adapter/alpaca/lit_model_adapter_finetuned.pth"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    quantize: Optional[str] = None,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
) -> None:
"""Generates a response based on a given instruction and an optional input.
This script will only work with checkpoints from the instruction-tuned Parrot-Adapter model.
See `finetune_adapter.py`.
Args:
prompt: The prompt/instruction (Alpaca style).
adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
`finetune_adapter.py`.
checkpoint_dir: The path to the checkpoint folder with pretrained Parrot weights.
input: Optional input (Alpaca style).
quantize: Whether to quantize the model and using which method:
``"llm.int8"``: LLM.int8() mode,
``"gptq.int4"``: GPTQ 4-bit mode.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
"""
    check_valid_checkpoint_dir(checkpoint_dir)
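
    # Run on a single device; prefer bfloat16 on CUDA GPUs that support it, otherwise fall back to float32.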
    fabric = L.Fabric(devices=1)
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
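
    # lit_config.json holds the model's architecture hyperparameters and is produced by the checkpoint conversion step.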
    with open(checkpoint_dir / "lit_config.json") as fp:
        config = Config(**json.load(fp))
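
    # GPTQ 4-bit weights live in a separate file that must be created ahead of time with `quantize/gptq.py`.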
    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        model_file = "lit_model.pth"
    checkpoint_path = checkpoint_dir / model_file
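
    # Instantiate the model without materializing full-precision weights first: EmptyInitOnDevice allocates
    # parameters directly on the target device/dtype (and in the requested quantization mode, if any).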
print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.time()
with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
model = Parrot(config)
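    # The two checkpoints are merged into the same module: strict=False because the pretrained file
    # contains no adapter weights and the adapter file contains only the adapter weights.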
    with lazy_load(checkpoint_path) as pretrained_checkpoint, lazy_load(
        adapter_path
    ) as adapter_checkpoint:
        # 1. Load the pretrained weights
        model.load_state_dict(pretrained_checkpoint, strict=False)
        # 2. Load the fine-tuned adapter weights
        model.load_state_dict(adapter_checkpoint, strict=False)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
    model.eval()
    model = fabric.setup(model)

    tokenizer = Tokenizer(checkpoint_dir / "tokenizer.json", checkpoint_dir / "tokenizer_config.json")
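    # Wrap the raw instruction/input pair in the same Alpaca prompt template used during fine-tuning.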
    sample = {"instruction": prompt, "input": input}
    prompt = generate_prompt(sample)
    encoded = tokenizer.encode(prompt, device=model.device)
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens
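
    # Sample at most `max_new_tokens` tokens on top of the prompt; generation stops early if the EOS token is produced.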
    t0 = time.perf_counter()
    y = generate(
        model,
        encoded,
        max_returned_tokens,
        max_seq_length=max_returned_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )
    t = time.perf_counter() - t0
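
    # Clear the KV cache and print only the text after the "### Response:" marker of the Alpaca template.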
    model.reset_cache()
    output = tokenizer.decode(y)
    output = output.split("### Response:")[1].strip()
    print(output)

    tokens_generated = y.size(0) - prompt_length
    print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    warnings.filterwarnings(
        # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
        "ignore",
        message="ComplexHalf support is experimental and many operators don't support it yet",
    )
    CLI(main)