Skip to content

Commit

Permalink
Merge pull request FlagAI-Open#194 from Anhforth/master
Browse files Browse the repository at this point in the history
localize safety checker
  • Loading branch information
ftgreat authored Jan 6, 2023
2 parents d095b87 + 11199af commit f0ee4a4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/gpt2_text_writting/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
if __name__ == '__main__':
loader = AutoLoader("seq2seq",
"GPT2-base-ch",
model_dir="./state_dict/")
model_dir="./checkpoints/")
model = loader.get_model()
tokenizer = loader.get_tokenizer()
predictor = Predictor(model, tokenizer)
Expand Down
28 changes: 24 additions & 4 deletions flagai/model/predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,38 @@
from itertools import islice
from transformers import AutoFeatureExtractor
import math

from flagai.model.file_utils import _get_model_id, _get_checkpoint_path, _get_vocab_path, _get_model_files
join = os.path.join

def download(model_name, download_path):
try:
model_id = _get_model_id(model_name)
except:
print("Model hub is not reachable!")
# prepare the download path
# downloading the files
if model_id and model_id != "null":
model_files = eval(_get_model_files(model_name))
print("model files:" + str(model_files))
for file_name in model_files:
if not file_name.endswith("bin"):
_get_vocab_path(os.path.join(download_path, model_name), file_name, model_id)
else :
_get_checkpoint_path(os.path.join(download_path, model_name), file_name, model_id)
return


def get_safety_checker():
# load safety model
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
safety_model_id = "CompVis/stable-diffusion-safety-checker"
path = os.getcwd() + "/checkpoints/"
if not os.path.exists(path+"SafetyChecker"):
download("SafetyChecker", path)
# safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id)
path+"SafetyChecker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id)
path+"SafetyChecker")
return safety_checker, safety_feature_extractor


Expand Down

0 comments on commit f0ee4a4

Please sign in to comment.