# generate.py
# Forked from Morizeyao/GPT2-Chinese

import os

import torch
import torch.nn.functional as F
import pytorch_transformers
from pytorch_transformers import GPT2LMHeadModel
from tqdm import trange

from my_chinese_tokenizer import tokenization_bert

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # Select which GPUs this script is allowed to use


def top_k_logits(logits, k):
    """
    Masks everything but the top k entries with -1e10 (effectively -infinity),
    so that after softmax the masked entries contribute ~0 probability mass.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        # Smallest value still inside the top-k, per batch row
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
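

# Illustrative example (not part of the original script): for a single row
# logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]]) and k = 2, top_k_logits keeps
# 2.0 and 1.0 and replaces 0.5 and -1.0 with -1e10, so softmax then assigns
# essentially all probability mass to the first two tokens.
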

def top_filtering(logits, top_k=5, top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering
        Args:
            logits: logits distribution, shape (vocabulary size)
            top_k: <=0: no filtering, >0: keep only the top k tokens with highest probability.
            top_p: <=0.0: no filtering, >0.0: keep only the smallest subset S of candidates
                whose total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest-probability tokens whose cumulative
                probability mass exceeds the threshold top_p.
            threshold: a minimal threshold to keep logits
    """
    assert logits.dim() == 1  # Only works for batch size 1 for now - could update but it would obfuscate the code a bit
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the mask to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits
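

# Illustrative example (not part of the original script): with
# logits = torch.tensor([3.0, 2.0, 1.0, 0.0]) the softmax probabilities are roughly
# [0.64, 0.24, 0.09, 0.03] and the cumulative mass first exceeds top_p=0.9 at the
# third token. After the right-shift of the mask, the first three tokens are kept
# and only the last one is set to filter_value; the top-k step with the default
# top_k=5 removes nothing here because k exceeds the (toy) vocabulary size.
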

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1.0, top_k=0,
                    device='cuda', sample=True):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in trange(length):
            # Feed only the newly generated token(s); reuse cached key/value states via `past`
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = logits.squeeze(0)  # top_filtering expects a 1-D tensor (batch size 1)
            logits = top_filtering(logits, top_k=top_k)
            probs = F.softmax(logits, dim=-1)
            # probs[100] = 0  # uncomment to forbid a specific token id (e.g. [UNK])
            if sample:
                prev = torch.multinomial(probs, num_samples=1)
            else:
                _, prev = torch.topk(probs, k=1, dim=-1)
            prev = prev.view(1, 1)  # shape (batch=1, 1) so it can be fed back to the model
            output = torch.cat((output, prev), dim=1)
    return output
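

# Illustrative usage (not part of the original script; assumes `model` and `tokenizer`
# have been loaded as in main() below, with the model on the chosen device):
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('萧炎'))
#   out = sample_sequence(model, length=50, context=ids, batch_size=1,
#                         temperature=1.0, top_k=5, device='cpu')
#   print(''.join(tokenizer.convert_ids_to_tokens(out[0, len(ids):].tolist())))
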

def main():
    length = -1       # -1: default to half of the model's context window
    batch_size = 1
    nsamples = 18     # total number of samples to generate per prompt
    temperature = 1.0
    topk = 5
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file='cache/vocab_small.txt')
    model_config = pytorch_transformers.GPT2Config.from_json_file('config/model_config_small.json')
    model = GPT2LMHeadModel.from_pretrained('model/final_model', config=model_config)
    model.to(device)
    model.eval()

    if length == -1:
        length = model.config.n_ctx // 2
    elif length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    # NOTE: loops forever over a fixed prompt; stop with Ctrl-C
    while True:
        raw_text = '萧炎'
        context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
        generated = 0
        for _ in range(nsamples // batch_size):
            out = sample_sequence(
                model=model, length=length,
                context=context_tokens,
                start_token=None,
                batch_size=batch_size,
                temperature=temperature, top_k=topk, device=device
            )
            # Drop the prompt tokens, keep only the newly generated continuation
            out = out[:, len(context_tokens):].tolist()
            for i in range(batch_size):
                generated += 1
                text = tokenizer.convert_ids_to_tokens(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(''.join(text))
                print("=" * 80)


if __name__ == '__main__':
    main()
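
# Expected layout (paths as referenced above; adjust to your own setup):
#   cache/vocab_small.txt            vocabulary used by tokenization_bert.BertTokenizer
#   config/model_config_small.json   GPT-2 model configuration
#   model/final_model/               directory with the fine-tuned model weights
# Run with:  python generate.py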