complete trainer RLHF config
erfanzar committed Sep 4, 2023
1 parent 60fd6e5 commit 395bfd4
Showing 2 changed files with 102 additions and 89 deletions.
179 changes: 94 additions & 85 deletions EasyDel/rlhf/trainer.py
@@ -2,25 +2,25 @@
import pathlib
import pprint

import datasets
from fjutils import StreamingCheckpointer
from jax import numpy as jnp
from flax import linen as nn

import flax
import jax
from fjutils import optimizers

import collections
from typing import Union, Optional, List, Dict, OrderedDict, NamedTuple, Callable, Any, Sequence, assert_type
from typing import Union, Optional, Callable

from jax._src.maps import Mesh
from jax.sharding import Mesh
from jax.experimental.mesh_utils import create_device_mesh
from transformers import PretrainedConfig
from .utils import AVAILABLE_MODELS_FOR_RLHF, AVAILABLE_MODELS_CONFIG_FOR_RLHF
from .utils import AVAILABLE_MODELS_FOR_RLHF
from .reward import RewardModel
from .ppo import ActorCritic
from datasets import DatasetDict, Dataset, IterableDatasetDict, IterableDataset
import fjutils

import wandb


@@ -58,7 +58,7 @@ def __init__(self,
super().__init__(**kwargs)
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.precision = precision if not isinstance(precision, str) else jax.lax.Precision(precision)
self.scheduler = scheduler
self.extra_optimizer_kwargs = extra_optimizer_kwargs
self.actor_lr = actor_lr
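
With this change, `precision` can be passed either as a `jax.lax.Precision` value or as a plain string such as `'fastest'`; the constructor now coerces the string form itself. A minimal standalone sketch of the two accepted forms (not tied to the trainer class):

```python
import jax

# Either form is accepted after this change; the string is coerced in __init__.
precision_enum = jax.lax.Precision.HIGHEST
precision_str = jax.lax.Precision('highest')  # built from the string alias
assert precision_enum == precision_str
```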
@@ -131,110 +131,119 @@ def __repr__(self):
def get_mesh_names():
return 'dp', 'fsdp', 'mp'

def get_optimizer_and_scheduler(self, steps=None):
steps = self.max_steps or steps
assert steps is not None, 'if you haven\'t passed max_steps at init, you should pass steps to this function'
@staticmethod
def get_optimizer_and_scheduler(
steps,
optimizer: str = 'adam',
scheduler: str = 'cosine',

learning_rate: float = 5e-5,
learning_rate_end: float = 6e-5,
gradient_accumulation_steps: int = 8,
weight_decay: float = 1e-2,
**kwargs
):

if self.optimizer == 'adafactor':
if self.scheduler == 'linear':
tx, sc = fjutils.optimizers.get_adafactor_with_linear_scheduler(
learning_rate_start=self.learning_rate,
learning_rate_end=self.learning_rate_end,
gradient_accumulation_steps=self.gradient_accumulation_steps,
if optimizer == 'adafactor':
if scheduler == 'linear':
tx, sc = optimizers.get_adafactor_with_linear_scheduler(
learning_rate_start=learning_rate,
learning_rate_end=learning_rate_end,
gradient_accumulation_steps=gradient_accumulation_steps,
steps=steps,
**self.extra_optimizer_kwargs
**kwargs
)
elif self.scheduler == 'cosine':
tx, sc = fjutils.optimizers.get_adafactor_with_cosine_scheduler(
learning_rate=self.learning_rate,
elif scheduler == 'cosine':
tx, sc = optimizers.get_adafactor_with_cosine_scheduler(
learning_rate=learning_rate,
steps=steps,
gradient_accumulation_steps=self.gradient_accumulation_steps,
weight_decay=self.weight_decay,
**self.extra_optimizer_kwargs
gradient_accumulation_steps=gradient_accumulation_steps,
weight_decay=weight_decay,
**kwargs
)
elif self.scheduler == 'none':
tx, sc = fjutils.optimizers.get_adafactor_with_linear_scheduler(
learning_rate_start=self.learning_rate,
learning_rate_end=self.learning_rate,
elif scheduler == 'none':
tx, sc = optimizers.get_adafactor_with_linear_scheduler(
learning_rate_start=learning_rate,
learning_rate_end=learning_rate,
steps=steps,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
elif self.scheduler == 'warm_up_cosine':
tx, sc = fjutils.optimizers.get_adafactor_with_warm_up_cosine_scheduler(
learning_rate=self.learning_rate,
elif scheduler == 'warm_up_cosine':
tx, sc = optimizers.get_adafactor_with_warm_up_cosine_scheduler(
learning_rate=learning_rate,
steps=steps,
weight_decay=self.weight_decay,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
weight_decay=weight_decay,
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
else:
raise ValueError('seems like you have chosen a wrong or unavailable scheduler')
elif self.optimizer == 'lion':
if self.scheduler == 'linear':
tx, sc = fjutils.optimizers.get_lion_with_linear_scheduler(
learning_rate_start=self.learning_rate,
learning_rate_end=self.learning_rate_end,
elif optimizer == 'lion':
if scheduler == 'linear':
tx, sc = optimizers.get_lion_with_linear_scheduler(
learning_rate_start=learning_rate,
learning_rate_end=learning_rate_end,
steps=steps,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
elif self.scheduler == 'cosine':
tx, sc = fjutils.optimizers.get_lion_with_cosine_scheduler(
learning_rate=self.learning_rate,
gradient_accumulation_steps=self.gradient_accumulation_steps,
elif scheduler == 'cosine':
tx, sc = optimizers.get_lion_with_cosine_scheduler(
learning_rate=learning_rate,
gradient_accumulation_steps=gradient_accumulation_steps,
steps=steps,
**self.extra_optimizer_kwargs
**kwargs
)
elif self.scheduler == 'none':
tx, sc = fjutils.optimizers.get_lion_with_linear_scheduler(
learning_rate_start=self.learning_rate,
learning_rate_end=self.learning_rate,
elif scheduler == 'none':
tx, sc = optimizers.get_lion_with_linear_scheduler(
learning_rate_start=learning_rate,
learning_rate_end=learning_rate,
steps=steps,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
elif self.scheduler == 'warm_up_cosine':
tx, sc = fjutils.optimizers.get_lion_with_warm_up_cosine_scheduler(
learning_rate=self.learning_rate,
elif scheduler == 'warm_up_cosine':
tx, sc = optimizers.get_lion_with_warm_up_cosine_scheduler(
learning_rate=learning_rate,
steps=steps,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
else:
raise ValueError('seems like you have chosen a wrong or unavailable scheduler')
elif self.optimizer == 'adamw':
if self.scheduler == 'linear':
tx, sc = fjutils.optimizers.get_adamw_with_linear_scheduler(
learning_rate_start=self.learning_rate,
learning_rate_end=self.learning_rate_end,
elif optimizer == 'adamw':
if scheduler == 'linear':
tx, sc = optimizers.get_adamw_with_linear_scheduler(
learning_rate_start=learning_rate,
learning_rate_end=learning_rate_end,
steps=steps,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
elif self.scheduler == 'cosine':
tx, sc = fjutils.optimizers.get_adamw_with_cosine_scheduler(
learning_rate=self.learning_rate,
gradient_accumulation_steps=self.gradient_accumulation_steps,
elif scheduler == 'cosine':
tx, sc = optimizers.get_adamw_with_cosine_scheduler(
learning_rate=learning_rate,
gradient_accumulation_steps=gradient_accumulation_steps,
steps=steps,
weight_decay=self.weight_decay,
**self.extra_optimizer_kwargs
weight_decay=weight_decay,
**kwargs
)
elif self.scheduler == 'none':
tx, sc = fjutils.optimizers.get_adamw_with_linear_scheduler(
learning_rate_start=self.learning_rate,
learning_rate_end=self.learning_rate,
gradient_accumulation_steps=self.gradient_accumulation_steps,
elif scheduler == 'none':
tx, sc = optimizers.get_adamw_with_linear_scheduler(
learning_rate_start=learning_rate,
learning_rate_end=learning_rate,
gradient_accumulation_steps=gradient_accumulation_steps,
steps=steps,
**self.extra_optimizer_kwargs
**kwargs
)
elif self.scheduler == 'warm_up_cosine':
tx, sc = fjutils.optimizers.get_adamw_with_warm_up_cosine_scheduler(
learning_rate=self.learning_rate,
elif scheduler == 'warm_up_cosine':
tx, sc = optimizers.get_adamw_with_warm_up_cosine_scheduler(
learning_rate=learning_rate,
steps=steps,
weight_decay=self.weight_decay,
gradient_accumulation_steps=self.gradient_accumulation_steps,
**self.extra_optimizer_kwargs
weight_decay=weight_decay,
gradient_accumulation_steps=gradient_accumulation_steps,
**kwargs
)
else:
raise ValueError('seems like you have chosen a wrong or unavailable scheduler')
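
Since `get_optimizer_and_scheduler` is now a `@staticmethod` that takes everything it needs as arguments, it can be called without instantiating the trainer config first. A minimal usage sketch, assuming the enclosing class is importable as `TrainerConfig` (the class name is not visible in this diff) and that, as with the `fjutils` helpers, the returned `tx` is an optax `GradientTransformation`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state

# `TrainerConfig` is a placeholder for whichever class in trainer.py defines the static method.
tx, scheduler = TrainerConfig.get_optimizer_and_scheduler(
    steps=10_000,                 # now passed explicitly instead of read from self.max_steps
    optimizer='adamw',
    scheduler='cosine',
    learning_rate=5e-5,
    gradient_accumulation_steps=8,
    weight_decay=1e-2,
)

# The optax transformation plugs into a standard flax TrainState.
toy_model = nn.Dense(features=8)
params = toy_model.init(jax.random.PRNGKey(0), jnp.ones((1, 16)))
state = train_state.TrainState.create(apply_fn=toy_model.apply, params=params, tx=tx)
```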
@@ -270,9 +279,9 @@ def setup(self) -> None:
model=self.model,
critic_model=self.critic_model,
pooled_values=False,
dtype=jnp.float32,
param_dtype=jnp.float32,
precision=jax.lax.Precision('fastest')
dtype=self.config.dtype,
param_dtype=self.config.param_dtype,
precision=self.config.precision

)
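
With the hard-coded `float32` / `'fastest'` arguments gone, whatever dtypes and precision the trainer config carries now reach `ActorCritic`. A rough sketch of the intent, with a stand-in config object (only the attribute names `dtype`, `param_dtype`, and `precision` are taken from the diff):

```python
import jax
from jax import numpy as jnp

# Stand-in for the trainer config; real values would come from the config class above.
class _ConfigStub:
    dtype = jnp.bfloat16                      # run activations in bf16
    param_dtype = jnp.float32                 # keep parameters in f32
    precision = jax.lax.Precision('fastest')  # matmul precision hint

config = _ConfigStub()

# These kwargs mirror what setup() now forwards to ActorCritic instead of fixed float32.
actor_critic_kwargs = dict(
    dtype=config.dtype,
    param_dtype=config.param_dtype,
    precision=config.precision,
)
```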

12 changes: 8 additions & 4 deletions requirements.txt
@@ -1,7 +1,7 @@
jax>=0.4.9
flax~=0.7.1
fjutils~=0.0.16
numpy
numpy~=1.25.2
typing~=3.7.4.3
transformers>=4.31.0
einops~=0.6.1
@@ -12,7 +12,11 @@ tqdm==4.65.0
datasets==2.14.3
setuptools~=68.0.0
torch>=2.0.1
fastapi
gradio
fastapi~=0.103.0
gradio~=3.41.2
distrax
rlax
rlax
EasyDeL~=0.0.30
wandb~=0.15.9
uvicorn~=0.23.2
pydantic~=2.3.0
