Skip to content

Commit

Permalink
Merge pull request FlagAI-Open#424 from Anhforth/fix_opt
Browse files Browse the repository at this point in the history
Fix issue461
  • Loading branch information
ftgreat authored Jun 19, 2023
2 parents da66998 + 7cd8ace commit dc93209
Show file tree
Hide file tree
Showing 14 changed files with 39 additions and 8 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions


name: Python application

on:
Expand All @@ -19,10 +20,10 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.8.8"
python-version: "3.9"
- name: Install dependencies
run: |
python -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_1.3b.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader

Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_125m.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_13b.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_2.7b.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader

Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_30b.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_350m.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_6.7b.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/generate_opt_66b.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/opt_30b_en_mutigpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
import os
import argparse
Expand Down
3 changes: 3 additions & 0 deletions examples/opt/opt_66b_en_mutigpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
import torch
import os
Expand Down
2 changes: 1 addition & 1 deletion flagai/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def init_from_json(cls, config_file='./config.json', **kwargs):
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = cls(change_json_to_cls(args), **kwargs)
torch.set_default_tensor_type(torch.FloatTensor)
else :
else:
model = cls(change_json_to_cls(args), **kwargs)

return model
Expand Down
4 changes: 1 addition & 3 deletions flagai/model/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,8 @@ class GPT2Model(BaseModel):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.config = config.json_config

# TODO Global Config
if type(config) is dict:
if type(self.config) is dict:
init_method_std = self.config.get("initializer_range", 0.002)
init_method = unscaled_init_method(init_method_std)
output_layer_init_method = None
Expand Down Expand Up @@ -294,7 +293,6 @@ def __init__(self, config, **kwargs):
self.config_gpt = config_gpt

self.parallel_output = True

self.transformer = GPT2Stack(self.config_gpt)
self.lm_head = nn.Linear(config_gpt.n_embd,
config_gpt.vocab_size,
Expand Down
6 changes: 4 additions & 2 deletions flagai/model/opt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flagai.model.layers.activations import ACT2FN
from flagai.model.gpt2_model import GPT2Model, GPT2Stack, GPT2Config
from torch.utils.checkpoint import checkpoint
from flagai.model.base_model import change_json_to_cls

OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/opt-125m",
Expand Down Expand Up @@ -96,15 +97,16 @@ def trans_opt_to_gpt_config(opt_config_json):
"word_embed_proj_dim": "n_embd",
"do_layer_norm_before": "do_layer_norm_before",
}
for k, v in opt_config_json.items():
for k, v in opt_config_json.json_config.items():
if k in trans_key:
trans_config_json[trans_key[k]] = v

return trans_config_json
return change_json_to_cls(trans_config_json)

class OPTModel(GPT2Model):

def __init__(self, config, **kwargs):

config = trans_opt_to_gpt_config(config)
super(OPTModel, self).__init__(config, **kwargs)
self.transformer = OPTStack(self.config_gpt)
Expand Down

0 comments on commit dc93209

Please sign in to comment.