Skip to content

Commit

Permalink
updated loading code
Browse files Browse the repository at this point in the history
Signed-off-by: ftgreat <[email protected]>
  • Loading branch information
ftgreat committed Jun 30, 2023
1 parent 1089616 commit feaef08
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
19 changes: 11 additions & 8 deletions examples/Aquila/Aquila-chat/aquila_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,29 @@
flush=True)

checkpoints = env_args.pre_load_dir

model_name = env_args.model_name

print('*' * 20, "model_name", model_name, flush=True)

cache_dir = os.path.join(checkpoints, model_name)
print('*' * 20, "cache_dir", cache_dir)
tokenizer = Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
print('*' * 20, "tokenizer", tokenizer)

# avoid sync loading models in case of Mem OOM
if env_args.bmt_async_load:
import time
time.sleep(10 * 60 * (trainer.local_rank % 4))

config_file = os.path.join(cache_dir, 'config.json')
from flagai.model.aquila_model import AQUILAModel

model = AQUILAModel.init_from_json(config_file=config_file)
# print('*'*20, "model", model)
# from flagai.model.aquila_model import AQUILAModel

loader = AutoLoader("lm",
model_dir='./checkpoints_in/',
model_name=model_name,
use_cache=False,
fp16=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()
# print('*' * 20, "tokenizer", tokenizer)
print('*'*20, "model", model)

#lora
if env_args.lora:
Expand Down
2 changes: 1 addition & 1 deletion examples/Aquila/Aquila-chat/hostfile
Original file line number Diff line number Diff line change
@@ -1 +1 @@
192.168.21.4 slots=4
192.168.21.2 slots=4
2 changes: 1 addition & 1 deletion flagai/env_trainer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def do_train(self,
save_optim=self.save_optim,
save_dir=self.save_dir,
save_rng=self.save_rng,
iteration_in_epoch=iteration_)
iteration_in_epoch=99999999999)

# Evaluation #todo add train_args
if ((self.epochs == 0) or (self.eval_interval and
Expand Down

0 comments on commit feaef08

Please sign in to comment.