Skip to content

Commit

Permalink
fix: map eos,bos hopefully stops
Browse files Browse the repository at this point in the history
  • Loading branch information
zanussbaum committed Mar 25, 2023
1 parent f51c5c8 commit cb43f53
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def generate(tokenizer, prompt, model, config):

outputs = model.generate(input_ids=input_ids, max_new_tokens=config["max_new_tokens"], temperature=config["temperature"])

print(outputs)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

return decoded[len(prompt):]
Expand All @@ -19,6 +20,7 @@ def generate(tokenizer, prompt, model, config):
def setup_model(config):
model = AutoModelForCausalLM.from_pretrained(config["model_name"], device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>"})

if config["lora"]:
model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16)
Expand All @@ -33,17 +35,22 @@ def setup_model(config):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--prompt", type=str)

args = parser.parse_args()

config = read_config(args.config)

print("setting up model")
if config["prompt"] is None and args.prompt is None:
raise ValueError("Prompt is required either in config or as argument")

prompt = config["prompt"] if args.prompt is None else args.prompt

print("Setting up model")
model, tokenizer = setup_model(config)

print("generating")
print("Generating")
start = time.time()
generation = generate(tokenizer, args.prompt, model, config)
print(f"done in {time.time() - start:.2f}s")
generation = generate(tokenizer, prompt, model, config)
print(f"Done in {time.time() - start:.2f}s")
print(generation)

0 comments on commit cb43f53

Please sign in to comment.