Skip to content

Commit

Permalink
Merge pull request FlagAI-Open#165 from 920232796/master
Browse files Browse the repository at this point in the history
load model no network
  • Loading branch information
BAAI-OpenPlatform authored Dec 8, 2022
2 parents 7eceacb + 9926625 commit f139aaf
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,8 @@ def __init__(self,
download_path = os.path.join(model_dir, raw_model_name)
print("*" * 20, task_name, model_name)

model_id = _get_model_id(f"{raw_model_name}-{task_name}")
if model_id != 'null':
model_name_ = f"{raw_model_name}-{task_name}"
else:
model_name_ = raw_model_name
model_name_ = self.is_exist_finetuned_model(model_name, task_name)

self.model = getattr(LazyImport(self.model_name[0]),
self.model_name[1]).from_pretrain(
download_path=model_dir,
Expand Down Expand Up @@ -236,6 +233,19 @@ def __init__(self,
self.tokenizer = None
self.transform = None

def is_exist_finetuned_model(self, raw_model_name, task_name):
try:
model_id = _get_model_id(f"{raw_model_name}-{task_name}")
if model_id != 'null':
model_name_ = f"{raw_model_name}-{task_name}"
return model_name_
else :
return raw_model_name

except:
print("Model hub is not reachable.")
return raw_model_name

def get_task_name(self, brief_model_name):
all_model_task = list(ALL_TASK.keys())
model_tasks = [t for t in all_model_task if brief_model_name in t]
Expand Down

0 comments on commit f139aaf

Please sign in to comment.