support qwen1.5-1.8b
Isaachhh committed Mar 16, 2024
1 parent 4f630f3 commit f0b85e6
Showing 9 changed files with 2,235 additions and 2 deletions.
1 change: 1 addition & 0 deletions bunny/model/__init__.py
@@ -1,2 +1,3 @@
from .language_model.bunny_phi import BunnyPhiForCausalLM, BunnyPhiConfig
from .language_model.bunny_stablelm import BunnyStableLMForCausalLM, BunnyStableLMConfig
from .language_model.bunny_qwen import BunnyQwenForCausalLM, BunnyQwenConfig
13 changes: 12 additions & 1 deletion bunny/model/builder.py
@@ -11,7 +11,7 @@

def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
device_map="auto", device="cuda", **kwargs):
- if model_type not in {'phi-1.5', 'phi-2', 'stablelm-2'}:
+ if model_type not in {'phi-1.5', 'phi-2', 'stablelm-2', 'qwen1.5-1.8b'}:
raise ValueError(f"Unknown Model Type {model_type}")

kwargs = {"device_map": device_map, **kwargs}
@@ -48,6 +48,10 @@ def load_pretrained_model(model_path, model_base, model_name, model_type, load_8
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True, trust_remote_code=True)
model = BunnyStableLMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
config=lora_cfg_pretrained, **kwargs)
elif model_type == 'qwen1.5-1.8b':
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
model = BunnyQwenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained,
**kwargs)

token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
if model.lm_head.weight.shape[0] != token_num:
@@ -97,6 +101,10 @@ def load_from_hf(repo_id, filename, subfolder=None):
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True, trust_remote_code=True)
model = BunnyStableLMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
config=cfg_pretrained, **kwargs)
elif model_type == 'qwen1.5-1.8b':
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
model = BunnyQwenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
**kwargs)

mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
@@ -108,6 +116,9 @@ def load_from_hf(repo_id, filename, subfolder=None):
elif model_type == 'stablelm-2':
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
model = BunnyStableLMForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
elif model_type == 'qwen1.5-1.8b':
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = BunnyQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

model.resize_token_embeddings(len(tokenizer))

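For reference, a minimal sketch of how the new branch is exercised, assuming the builder keeps its LLaVA-style return values (tokenizer, model, image_processor, context_len); the checkpoint path and model name below are hypothetical, not part of this commit:

from bunny.model.builder import load_pretrained_model

# 'qwen1.5-1.8b' must match the model_type set checked at the top of load_pretrained_model
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path='checkpoints/bunny-qwen1.5-1.8b',  # hypothetical local path
    model_base=None,
    model_name='bunny-qwen1.5-1.8b',              # hypothetical
    model_type='qwen1.5-1.8b',
)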
100 changes: 100 additions & 0 deletions bunny/model/language_model/bunny_qwen.py
@@ -0,0 +1,100 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

from .qwen2 import Qwen2Model, Qwen2Config, Qwen2ForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM


class BunnyQwenConfig(Qwen2Config):
model_type = "bunny-qwen"


class BunnyQwenModel(BunnyMetaModel, Qwen2Model):
config_class = BunnyQwenConfig

def __init__(self, config: Qwen2Config):
super(BunnyQwenModel, self).__init__(config)


class BunnyQwenForCausalLM(Qwen2ForCausalLM, BunnyMetaForCausalLM):
config_class = BunnyQwenConfig

def __init__(self, config):
super(Qwen2ForCausalLM, self).__init__(config)
self.model = BunnyQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()

def get_model(self):
return self.model

def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images
)

return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
**kwargs):
images = kwargs.pop("images", None)

_inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
**kwargs
)

if images is not None:
_inputs['images'] = images
return _inputs


AutoConfig.register("bunny-qwen", BunnyQwenConfig)
AutoModelForCausalLM.register(BunnyQwenConfig, BunnyQwenForCausalLM)
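The two register() calls at the bottom make the new type resolvable through the Auto classes. A minimal smoke-test sketch; the tiny dimensions are chosen purely for illustration and are not values from this commit:

from transformers import AutoModelForCausalLM
from bunny.model import BunnyQwenConfig  # importing bunny.model runs the register() calls

config = BunnyQwenConfig(
    vocab_size=1000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = AutoModelForCausalLM.from_config(config)  # dispatches to BunnyQwenForCausalLM
print(type(model).__name__)                       # -> BunnyQwenForCausalLM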
80 changes: 80 additions & 0 deletions bunny/model/language_model/qwen2/__init__.py
@@ -0,0 +1,80 @@
# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
"configuration_qwen2": ["QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Qwen2Config"],
"tokenization_qwen2": ["Qwen2Tokenizer"],
}

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_qwen2"] = [
"Qwen2ForCausalLM",
"Qwen2Model",
"Qwen2PreTrainedModel",
"Qwen2ForSequenceClassification",
]


if TYPE_CHECKING:
from .configuration_qwen2 import QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP, Qwen2Config
from .tokenization_qwen2 import Qwen2Tokenizer

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_qwen2_fast import Qwen2TokenizerFast

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_qwen2 import (
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2Model,
Qwen2PreTrainedModel,
)


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
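This mirrors the lazy-import pattern transformers uses for its own subpackages: no submodule under qwen2/ is imported until its attributes are first accessed, and the tokenizer and model entries are only registered when the optional dependencies are installed. A short sketch of the effect:

import bunny.model.language_model.qwen2 as qwen2

cfg_cls = qwen2.Qwen2Config          # first access: _LazyModule imports configuration_qwen2
model_cls = qwen2.Qwen2ForCausalLM   # imports modeling_qwen2; available only if torch is installed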
144 changes: 144 additions & 0 deletions bunny/model/language_model/qwen2/configuration_qwen2.py
@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Qwen2 model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
}


class Qwen2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`Qwen2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, it will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention); the bottom layers use SWA while the top layers use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen2Model, Qwen2Config
>>> # Initializing a Qwen2 style configuration
>>> configuration = Qwen2Config()
>>> # Initializing a model from the Qwen2-7B style configuration
>>> model = Qwen2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "qwen2"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
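The defaults above describe the 7B scale; the 1.8B checkpoint this commit targets overrides them through its own config.json. A sketch with dimensions assumed from the published Qwen1.5-1.8B config, for illustration only (treat the exact numbers as assumptions, not part of this diff):

from bunny.model.language_model.qwen2 import Qwen2Config

config = Qwen2Config(
    vocab_size=151936,
    hidden_size=2048,        # assumed 1.8B width
    intermediate_size=5504,  # assumed
    num_hidden_layers=24,    # assumed
    num_attention_heads=16,  # assumed
    num_key_value_heads=16,  # assumed
    max_position_embeddings=32768,
)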
(Diffs for the remaining 4 changed files are not shown.)
