
Commit

Fix training.
haotian-liu committed Feb 1, 2024
1 parent c878cc3 commit 879ef31
Showing 4 changed files with 59 additions and 32 deletions.
39 changes: 15 additions & 24 deletions llava/train/llava_trainer.py
@@ -1,5 +1,6 @@
 import os
 import torch
+import torch.nn as nn

 from torch.utils.data import Sampler

@@ -9,7 +10,6 @@
     get_parameter_names,
     has_length,
     ALL_LAYERNORM_LAYERS,
-    ShardedDDPOption,
     logger,
 )
 from typing import List, Optional
@@ -156,8 +156,6 @@ def create_optimizer(self):
         """
         if is_sagemaker_mp_enabled():
             return super().create_optimizer()
-        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-            return super().create_optimizer()

         opt_model = self.model

@@ -212,27 +210,20 @@ def create_optimizer(self):

         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

-        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-            self.optimizer = OSS(
-                params=optimizer_grouped_parameters,
-                optim=optimizer_cls,
-                **optimizer_kwargs,
-            )
-        else:
-            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-            if optimizer_cls.__name__ == "Adam8bit":
-                import bitsandbytes
-
-                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-
-                skipped = 0
-                for module in opt_model.modules():
-                    if isinstance(module, nn.Embedding):
-                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
-                        logger.info(f"skipped {module}: {skipped/2**20}M params")
-                        manager.register_module_override(module, "weight", {"optim_bits": 32})
-                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
-                logger.info(f"skipped: {skipped/2**20}M params")
+        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+        if optimizer_cls.__name__ == "Adam8bit":
+            import bitsandbytes
+
+            manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+            skipped = 0
+            for module in opt_model.modules():
+                if isinstance(module, nn.Embedding):
+                    skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                    logger.info(f"skipped {module}: {skipped/2**20}M params")
+                    manager.register_module_override(module, "weight", {"optim_bits": 32})
+                    logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+            logger.info(f"skipped: {skipped/2**20}M params")

         return self.optimizer
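
For context on the Adam8bit branch that survives this cleanup: bitsandbytes keeps 8-bit optimizer state for most parameters but lets specific tensors be pinned to 32-bit state through GlobalOptimManager. A minimal standalone sketch of that pattern (not part of the commit; the toy model is illustrative and assumes a CUDA-enabled bitsandbytes install):

import torch.nn as nn
import bitsandbytes

# Toy model for illustration only.
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 10)).cuda()

# Pin embedding weights to 32-bit optimizer state, mirroring the
# register_module_override() call retained in create_optimizer() above.
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})

# Everything else still uses 8-bit Adam state.
optimizer = bitsandbytes.optim.Adam8bit(model.parameters(), lr=1e-4)
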
48 changes: 42 additions & 6 deletions llava/train/train.py
@@ -25,6 +25,7 @@
 import torch

 import transformers
+import tokenizers

 from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 from torch.utils.data import Dataset
@@ -45,6 +46,10 @@ def rank0_print(*args):
         print(*args)


+from packaging import version
+IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
+
+
 @dataclass
 class ModelArguments:
     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
@@ -57,6 +62,7 @@ class ModelArguments:
     mm_projector_type: Optional[str] = field(default='linear')
     mm_use_im_start_end: bool = field(default=False)
     mm_use_im_patch_token: bool = field(default=True)
+    mm_patch_merge_type: Optional[str] = field(default='flat')
     mm_vision_select_feature: Optional[str] = field(default="patch")


@@ -468,6 +474,10 @@ def preprocess_v1(
                 round_len = len(tokenizer(rou).input_ids)
                 instruction_len = len(tokenizer(parts[0]).input_ids) - 2

+            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
+                round_len -= 1
+                instruction_len -= 1
+
             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

             cur_len += round_len
@@ -490,6 +500,7 @@
 def preprocess_mpt(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
+    has_image: bool = False
 ) -> Dict:
     conv = conversation_lib.default_conversation.copy()
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -509,7 +520,18 @@
         conversations.append(conv.get_prompt())

     # Tokenize conversations
-    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+
+    if has_image:
+        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+    else:
+        input_ids = tokenizer(
+            conversations,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        ).input_ids
+
     targets = input_ids.clone()
     assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

@@ -532,8 +554,18 @@
             if len(parts) != 2:
                 break
             parts[0] += sep
-            round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
-            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
+
+            if has_image:
+                round_len = len(tokenizer_image_token(rou, tokenizer))
+                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
+            else:
+                round_len = len(tokenizer(rou).input_ids)
+                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+
+            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+                round_len += 1
+                instruction_len += 1
+
             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

             cur_len += round_len
@@ -594,7 +626,7 @@ def preprocess(
     if conversation_lib.default_conversation.version.startswith("v1"):
         return preprocess_v1(sources, tokenizer, has_image=has_image)
     if conversation_lib.default_conversation.version == "mpt":
-        return preprocess_mpt(sources, tokenizer)
+        return preprocess_mpt(sources, tokenizer, has_image=has_image)
     # add end signal and concatenate together
     conversations = []
     for source in sources:
@@ -753,7 +785,7 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                 data_collator=data_collator)


-def train():
+def train(attn_implementation=None):
     global local_rank

     parser = transformers.HfArgumentParser(
@@ -785,7 +817,7 @@ def train():
     if 'mpt' in model_args.model_name_or_path:
         config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
         config.attn_config['attn_impl'] = training_args.mpt_attn_impl
-        model = LlavaMPTForCausalLM.from_pretrained(
+        model = LlavaMptForCausalLM.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             cache_dir=training_args.cache_dir,
@@ -795,12 +827,16 @@ def train():
             model = LlavaLlamaForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
                 cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                 **bnb_model_from_pretrained_args
             )
         else:
             model = transformers.LlamaForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
                 cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                 **bnb_model_from_pretrained_args
             )
     model.config.use_cache = False
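
Two things in this file are easy to miss: the IS_TOKENIZER_GREATER_THAN_0_14 gate that drives the new round-length corrections, and the text-only branch preprocess_mpt gains when has_image=False, which tokenizes prompts as a padded batch instead of going through tokenizer_image_token. A rough, self-contained sketch of both (the checkpoint name and chat strings are placeholders, not taken from the commit):

import tokenizers
from packaging import version
from transformers import AutoTokenizer

# Same version gate as in train.py: newer tokenizers shift per-round token counts,
# so the masking code compensates by +/- 1 per round.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
print("tokenizers", tokenizers.__version__, "gate:", IS_TOKENIZER_GREATER_THAN_0_14)

# Placeholder tokenizer; LLaVA loads the one matching model_name_or_path.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
tokenizer.model_max_length = 2048

conversations = [
    "<|im_start|>user\nHello<|im_end|>",
    "<|im_start|>user\nHi there, how are you?<|im_end|>",
]
# Mirrors the has_image=False branch of preprocess_mpt.
input_ids = tokenizer(
    conversations,
    return_tensors="pt",
    padding="longest",
    max_length=tokenizer.model_max_length,
    truncation=True,
).input_ids
print(input_ids.shape)  # (batch, length of longest prompt)
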
2 changes: 1 addition & 1 deletion llava/train/train_mem.py
@@ -10,4 +10,4 @@
 from llava.train.train import train

 if __name__ == "__main__":
-    train()
+    train(attn_implementation="flash_attention_2")
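
The attn_implementation value set here is threaded through train() into the Hugging Face from_pretrained calls shown above. A hedged sketch of the resulting call (placeholder checkpoint; flash_attention_2 needs the flash-attn package, a recent transformers release, and half-precision weights):

import torch
from transformers import LlamaForCausalLM

# Placeholder checkpoint; LLaVA passes model_args.model_name_or_path instead.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # the value train_mem.py now passes down
    torch_dtype=torch.bfloat16,               # matches the torch_dtype added in train()
)
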
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "llava"
-version = "1.1.3"
+version = "1.2.0"
 description = "Towards GPT-4 like large language and visual assistant."
 readme = "README.md"
 requires-python = ">=3.8"
