-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample.py
70 lines (59 loc) · 1.8 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import torch
import tiktoken
from contextlib import nullcontext
from model import GPT
# ---------------------------------------------------------------------------
# Sampling configuration
# ---------------------------------------------------------------------------
init_from = "resume"  # "resume" (load a checkpoint from out_dir) or a "gpt2*" pretrained name
out_dir = "out"  # directory holding the checkpoint when init_from == "resume"
start = "\n"  # prompt text, or "FILE:<path>" to read the prompt from that file
num_samples = 10  # number of samples to draw
max_new_tokens = 500  # number of tokens generated in each sample
temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
# Fall back to CPU when no GPU is present (mirroring the availability check used
# for dtype below); a hard-coded "cuda" makes the script crash on CPU-only hosts.
# May be overridden with a specific device such as "cuda:1".
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = "cuda" if "cuda" in device else "cpu"  # generic device kind for autocast
compile = False  # wrap the model with torch.compile (note: shadows the builtin name)
# Mixed-precision dtype: bfloat16 where the GPU supports it, otherwise float16.
# 'float32' is also accepted by the lookup table below.
dtype = (
    "bfloat16"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "float16"
)
ptdtype = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}[dtype]
# Autocast only on GPU; on CPU run in default precision with a no-op context.
ctx = (
    nullcontext()
    if device_type == "cpu"
    else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
)
# ---------------------------------------------------------------------------
# Model loading: either a local checkpoint or a pretrained "gpt2*" model.
# ---------------------------------------------------------------------------
if init_from == "resume":
    checkpoint_path = os.path.join(out_dir, "model_15000.pt")
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be loaded
    # onto whatever device we are sampling on.
    # NOTE(review): if checkpoint["config"] is a custom class, torch>=2.6 (where
    # weights_only defaults to True) may refuse to unpickle it; passing
    # weights_only=False is only acceptable for trusted, locally produced files.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint["config"]
    model = GPT(config)
    state_dict = checkpoint["model"]
    # A state dict saved from a torch.compile'd model prefixes every key with
    # "_orig_mod."; strip it so the keys match the plain GPT module. No-op otherwise.
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
elif init_from.startswith("gpt2"):
    model = GPT.from_pretrained(init_from)
else:
    # Fail loudly here instead of with a NameError on `model` below.
    raise ValueError(
        f"unsupported init_from {init_from!r}; expected 'resume' or a 'gpt2*' name"
    )
model.eval()
# Use the full device string (e.g. "cuda:1"), not just the device kind, so the
# model lands on the same device as the input tensor.
model.to(device)
if compile:
    model = torch.compile(model)
# ---------------------------------------------------------------------------
# Encode the prompt and draw samples.
# ---------------------------------------------------------------------------
enc = tiktoken.get_encoding("gpt2")
# A prompt of the form "FILE:<path>" is read from that file instead of used verbatim.
if start.startswith("FILE:"):
    # Explicit UTF-8 so the prompt decodes the same regardless of platform locale.
    with open(start[len("FILE:"):], "r", encoding="utf-8") as f:
        start = f.read()
ids = enc.encode(start, allowed_special={"<|endoftext|>"})
# Add a leading batch dimension of size 1: shape (1, len(ids)).
x = torch.tensor(ids, dtype=torch.long, device=device)[None, ...]
# Inference only: no autograd graph; ctx applies autocast on GPU (no-op on CPU).
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature, top_k)
        # Decode the single sequence in the batch.
        print(enc.decode(y[0].tolist()))
        print("---------------")