update generate.py
Signed-off-by: ftgreat <[email protected]>
ftgreat committed Jun 13, 2023
1 parent 6244f96 commit a421f03
Showing 4 changed files with 8 additions and 18 deletions.
examples/Aquila/Aquila-chat/generate_chat.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@
     model_dir=state_dict,
     model_name=model_name,
     use_cache=True,
-    use_fp16=True)
+    fp16=True)
 model = loader.get_model()
 tokenizer = loader.get_tokenizer()
 cache_dir = os.path.join(state_dict, model_name)
examples/Aquila/Aquila-code/generate_code.py (3 changes: 2 additions & 1 deletion)
@@ -4,7 +4,7 @@
 import torch
 import os
 from flagai import mpu
-import sys;sys.path.append("/data2/yzd/workspace/FlagAI")
+import sys
 from flagai.auto_model.auto_loader import AutoLoader
 import random
 import numpy as np
@@ -18,6 +18,7 @@
 print(f"building model...")
 loader = AutoLoader("lm", model_name="aquilacode-7b-nv",
                     use_cache=True,
+                    fp16=True,
                     model_dir=model_dir)

 model = loader.get_model()
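Besides adopting the renamed fp16 flag, this hunk drops a sys.path.append pointing at a developer's local checkout, so the example now assumes flagai is importable as an installed package. A minimal sketch of the resulting loader call, using only the calls shown in the diff; the model_dir value is an assumption, since its definition sits outside the shown hunk:

    from flagai.auto_model.auto_loader import AutoLoader

    # Assumption: checkpoints live under ./checkpoints_in/, as in the other examples.
    # No sys.path.append hack is needed once flagai is pip-installed.
    model_dir = "./checkpoints_in/"

    loader = AutoLoader("lm",
                        model_name="aquilacode-7b-nv",
                        use_cache=True,
                        fp16=True,          # renamed from use_fp16 in this commit
                        model_dir=model_dir)
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()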
examples/Aquila/Aquila-pretrain/generate.py (12 changes: 5 additions & 7 deletions)
@@ -1,8 +1,8 @@
 import os
 import torch
 from flagai.auto_model.auto_loader import AutoLoader
-from flagai.model.predictor.predictor import Predictor
 from flagai.data.tokenizer import Tokenizer
+from flagai.model.predictor.predictor import Predictor
 import bminf

 state_dict = "./checkpoints_in/"
@@ -13,21 +13,19 @@
     model_dir=state_dict,
     model_name=model_name,
     use_cache=True,
-    use_fp16=True)
+    fp16=True)
 model = loader.get_model()
 tokenizer = loader.get_tokenizer()

 model.eval()
-model.half()

 model.cuda()

-predictor = Predictor(model, tokenizer)
-
-
 texts = [
     "汽车EDR是什么",
 ]

+predictor = Predictor(model, tokenizer)
+
 for text in texts:
     print('-'*80)
     text = f'{text}'
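Taken together, the generate.py changes rename use_fp16 to fp16, drop the manual model.half() call (half-precision conversion is presumably handled by the loader when fp16=True), and construct the Predictor once, after the prompt list. A minimal sketch of the script as it stands after this commit, assuming the lines outside the shown hunks are unchanged; the model_name value and the generation call in the loop body are assumptions, since both sit outside the diff:

    import os
    import torch
    from flagai.auto_model.auto_loader import AutoLoader
    from flagai.data.tokenizer import Tokenizer
    from flagai.model.predictor.predictor import Predictor
    import bminf

    state_dict = "./checkpoints_in/"
    model_name = "aquila-7b"  # assumption: defined outside the shown hunks

    loader = AutoLoader("lm",
                        model_dir=state_dict,
                        model_name=model_name,
                        use_cache=True,
                        fp16=True)  # replaces the removed model.half() call
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()

    model.eval()
    model.cuda()

    texts = [
        "汽车EDR是什么",  # "What is an automotive EDR (event data recorder)?"
    ]

    predictor = Predictor(model, tokenizer)

    for text in texts:
        print('-' * 80)
        text = f'{text}'
        # The rest of the loop is outside the shown hunk; presumably it calls
        # the Predictor, e.g. predictor.predict_generate_randomsample(text, ...).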
flagai/model/predictor/aquila.py (9 changes: 0 additions & 9 deletions)
@@ -65,15 +65,6 @@ def aquila_generate(
     return decoded[0]


-def sample_top_p(probs, p):
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort[mask] = 0.0
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-    next_token = torch.multinomial(probs_sort, num_samples=1)
-    next_token = torch.gather(probs_idx, -1, next_token)
-    return next_token
 def sample_top_p(probs, p):
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
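The removed block was a verbatim duplicate of the sample_top_p definition that remains directly below it (the surviving copy is truncated in this view but starts identically). For reference, here is the deleted copy restated with comments added; this is nucleus (top-p) sampling exactly as the diff shows it, not new behavior:

    import torch

    def sample_top_p(probs, p):
        # Sort token probabilities in descending order, remembering original indices.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # Running total of probability mass over the sorted tokens.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Zero out a token once the mass *before* it already exceeds p, keeping
        # the smallest prefix of tokens whose cumulative probability reaches p.
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        # Renormalize the surviving probabilities and draw one token from them.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # Map the sampled position back to the original vocabulary index.
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token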
