import json
import sys
import time
import warnings
from pathlib import Path
from typing import Literal, Optional

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.adapter_v2 import GPT, Block, Config
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, lazy_load, quantization
from scripts.prepare_alpaca import generate_prompt
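
# Example invocation, assuming the script lives at generate/adapter_v2.py and the default
# checkpoint/adapter paths below (adjust both to your own setup):
#   python generate/adapter_v2.py \
#       --prompt "What food do lamas eat?" \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --adapter_path out/adapter_v2/alpaca/lit_model_adapter_finetuned.pth
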
def main(
    prompt: str = "What food do lamas eat?",
    input: str = "",
    adapter_path: Path = Path("out/adapter_v2/alpaca/lit_model_adapter_finetuned.pth"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    strategy: str = "auto",
    devices: int = 1,
    precision: Optional[str] = None,
) -> None:
    """Generates a response based on a given instruction and an optional input.

    This script will only work with checkpoints from the instruction-tuned GPT-AdapterV2 model.
    See `finetune/adapter_v2.py`.

    Args:
        prompt: The prompt/instruction (Alpaca style).
        input: Optional input (Alpaca style).
        adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
            `finetune/adapter_v2.py`.
        checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            - gptq.int4: 4-bit quantization from GPTQ
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more
            random samples.
        strategy: Indicates the Fabric strategy setting to use.
        devices: How many devices to use.
        precision: Indicates the Fabric precision setting to use.
    """
    precision = precision or get_default_supported_precision(training=False)

    if strategy == "fsdp":
        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
    fabric.launch()
    check_valid_checkpoint_dir(checkpoint_dir)

    with open(checkpoint_dir / "lit_config.json") as fp:
        config = Config(**json.load(fp))
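    # Quantized inference is single-device only. For GPTQ, a pre-quantized weight file produced
    # by `quantize/gptq.py` is expected; otherwise the regular checkpoint file is used.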
    if quantize is not None and devices > 1:
        raise NotImplementedError
    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        model_file = "lit_model.pth"
    checkpoint_path = checkpoint_dir / model_file
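    # Build the model skeleton inside Fabric's init context; empty_init=True skips the usual
    # random weight initialization since real weights are loaded right after, and the
    # `quantization` context swaps in quantized layers when a `quantize` mode is selected.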
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True), quantization(quantize):
model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    t0 = time.perf_counter()
    with lazy_load(checkpoint_path) as checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
        checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
        model.load_state_dict(checkpoint, strict=quantize is None)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    model.eval()
    model = fabric.setup(model)

    tokenizer = Tokenizer(checkpoint_dir)
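    # Wrap the raw instruction/input pair in the Alpaca prompt template used during fine-tuning,
    # then tokenize it on the model's device. The generation budget is the prompt length plus
    # `max_new_tokens`.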
sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, device=model.device)
prompt_length = encoded.size(0)
max_returned_tokens = prompt_length + max_new_tokens
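    # Sample autoregressively with temperature/top-k sampling; generation stops early if the
    # tokenizer's EOS token is produced.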
    t0 = time.perf_counter()
    y = generate(
        model,
        encoded,
        max_returned_tokens,
        max_seq_length=max_returned_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )
    t = time.perf_counter() - t0
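    # Free the KV cache, decode the generated ids, and keep only the text after the Alpaca
    # "### Response:" marker before printing it along with throughput and memory statistics.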
    model.reset_cache()
    output = tokenizer.decode(y)
    output = output.split("### Response:")[1].strip()
    fabric.print(output)

    tokens_generated = y.size(0) - prompt_length
    fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    warnings.filterwarnings(
        # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31
        "ignore",
        message="ComplexHalf support is experimental and many operators don't support it yet",
    )
    CLI(main)