Skip to content

Commit

Permalink
Update finetune.py
Browse files — browse the repository at this point in the history
  • Loading branch information
pratyushmaini authored Feb 22, 2024
1 parent fe281c1 commit 164167d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion — finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def main(cfg):
)

model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=model_cfg["flash_attention2"]=="true", torch_dtype=torch.bfloat16, trust_remote_code = True)

# Hot fix for https://discuss.huggingface.co/t/help-with-llama-2-finetuning-setup/50035
model.generation_config.do_sample = True

if model_cfg["gradient_checkpointing"] == "true":
model.gradient_checkpointing_enable()

Expand Down Expand Up @@ -127,7 +131,7 @@ def main(cfg):
if cfg.LoRA.r != 0:
model = model.merge_and_unload()


model.save_pretrained(cfg.save_dir)
tokenizer.save_pretrained(cfg.save_dir)

Expand Down

0 comments on commit 164167d

Please sign in to comment.