update training code
huyva2 committed Jul 16, 2024
1 parent ba0d86b commit b440b24
Showing 19 changed files with 1,270 additions and 28 deletions.
35 changes: 35 additions & 0 deletions INSTALL.md
@@ -0,0 +1,35 @@
# Setup SEMIKONG

This document explains how to set up the environment for training, evaluating, and running inference with the SEMIKONG model.

## Hardware Requirements
- CUDA Version: >= 10.x (ideally 11.x)
~~~
1. SEMIKONG 8B Chat
- CPU: Around 4 cores
- GPU: Any NVIDIA GPU model with at least 16GB VRAM (A100, A30, RTX 3090, etc.)
- Disk: At least 10GB of free space
- RAM: At least 16GB
2. SEMIKONG 70B Chat
- CPU: Around 8 cores
- GPU: NVIDIA GPUs with at least 150GB VRAM. High-end GPUs such as A100 or H100, or anything newer than the RTX 3000 series, are recommended
- Disk: At least 20GB of free space
- RAM: At least 64GB
~~~
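
The VRAM figures above are roughly in line with the size of the model weights alone. The sketch below works through that arithmetic. It assumes 16-bit weights (2 bytes per parameter), which this guide does not state, and it ignores activations and the KV cache, so treat the result as a lower bound rather than a measured requirement.

```python
# Rough weight-memory estimate.
# Assumption (not stated in this guide): weights are stored in 16-bit
# precision, i.e. 2 bytes per parameter.
BYTES_PER_PARAM_FP16 = 2


def weight_memory_gb(params_billion: int, bytes_per_param: int = BYTES_PER_PARAM_FP16) -> int:
    """Approximate size of the model weights in GB (1 GB = 1e9 bytes)."""
    # (params_billion * 1e9 parameters) * bytes_per_param / (1e9 bytes per GB)
    return params_billion * bytes_per_param


for name, size in (("SEMIKONG 8B Chat", 8), ("SEMIKONG 70B Chat", 70)):
    print(f"{name}: ~{weight_memory_gb(size)} GB for weights alone")
```

Under this assumption the 70B model needs about 140GB for its weights, which is consistent with the 150GB figure listed above.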

## Environment Setup

- Use `conda`, `poetry`, or `venv` to set up the virtual environment:
~~~
conda create --name semikong-env python=3.11
conda activate semikong-env
pip install -r requirements.txt
~~~

## Training
__TBA__

## Inference
1. Using OpenAI Client
2. Using vLLM API
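
Neither option is documented here yet. As a minimal sketch, assuming the model is served behind an OpenAI-compatible HTTP endpoint (which vLLM can provide), a chat request could look like the following. The server address and the model name are placeholders rather than values from this repository, and only the standard library is used so the request shape is visible.

```python
import json
import urllib.request

# Hypothetical values: adjust both to your own deployment.
BASE_URL = "http://localhost:8000/v1"  # assumed address of an OpenAI-compatible server
MODEL_NAME = "semikong-8b-chat"        # placeholder; use the name your server reports


def build_chat_request(prompt: str, model: str = MODEL_NAME, base_url: str = BASE_URL) -> urllib.request.Request:
    """Build a POST request for the chat completions route."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        url=f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(prompt: str) -> str:
    """Send the request and return the text of the first choice (needs a running server)."""
    with urllib.request.urlopen(build_chat_request(prompt)) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]
```

With a server running, `ask("What is plasma etching?")` would return the reply text. The `openai` package pinned in `requirements.txt` can send the same request by pointing its client at the same base URL.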
230 changes: 209 additions & 21 deletions requirements.txt
@@ -1,21 +1,209 @@
transformers>=4.41.2
datasets>=2.16.0
accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
gradio>=4.0.0
pandas>=2.0.0
scipy
einops
sentencepiece
tiktoken
protobuf
uvicorn
pydantic
fastapi
sse-starlette
matplotlib>=3.7.0
fire
packaging
pyyaml
numpy<2.0.0
accelerate==0.31.0
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.7.0
anyio==4.4.0
astroid==3.2.2
async-lru==2.0.4
attrs==23.2.0
backports.tarfile==1.2.0
bitsandbytes==0.43.1
black==24.4.2
certifi==2024.6.2
cffi==1.16.0
cfgv==3.4.0
charset-normalizer==3.3.2
click==8.1.7
cmake==3.29.5.1
coloredlogs==15.0.1
contourpy==1.2.1
cryptography==42.0.8
cycler==0.12.1
dataclasses-json==0.6.7
datasets==2.20.0
dill==0.3.8
diskcache==5.6.3
distlib==0.3.8
distro==1.9.0
docstring_parser==0.16
einops==0.8.0
fastapi==0.111.0
fastapi-cli==0.0.4
ffmpy==0.3.2
filelock==3.14.0
fire==0.6.0
flash_attn==2.5.9.post1
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.3.1
gekko==1.1.3
gitdb==4.0.11
GitPython==3.1.43
gradio==4.36.1
gradio_client==1.0.1
greenlet==3.0.3
h11==0.14.0
hjson==3.1.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.3
humanfriendly==10.0
identify==2.5.36
idna==3.7
importlib_metadata==7.1.0
importlib_resources==6.4.0
interegular==0.3.3
isodate==0.6.1
isort==5.13.2
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.1
jeepney==0.8.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
keyring==25.2.1
kiwisolver==1.4.5
langchain==0.2.5
langchain-community==0.2.5
langchain-core==0.2.9
langchain-text-splitters==0.2.1
langsmith==0.1.80
lark==1.1.9
llvmlite==0.43.0
lm-format-enforcer==0.10.1
loguru==0.7.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.21.3
matplotlib==3.9.0
mccabe==0.7.0
mdurl==0.1.2
more-itertools==10.3.0
mpmath==1.3.0
msal==1.28.1
msal-extensions==1.1.0
msgpack==1.0.8
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.3
ninja==1.11.1.1
nltk==3.8.1
nodeenv==1.9.1
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.555.43
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
openai==1.33.0
optimum==1.20.0
orjson==3.10.4
outlines==0.0.45
packaging==24.1
pandas==2.2.2
pathspec==0.12.1
peft==0.11.1
pillow==10.3.0
platformdirs==4.2.2
portalocker==2.8.2
pre-commit==3.7.1
prometheus-fastapi-instrumentator==7.0.0
prometheus_client==0.20.0
protobuf==5.27.1
psutil==5.9.8
py-cpuinfo==9.0.0
pyairports==2.1.1
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycountry==24.6.1
pycparser==2.22
pydantic==2.7.3
pydantic_core==2.18.4
pydub==0.25.1
Pygments==2.18.0
PyJWT==2.8.0
pylint==3.2.3
pynvml==11.5.0
pyparsing==3.1.2
PyPDF2==3.0.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
ray==2.24.0
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rich==13.7.1
rouge==1.0.1
rouge-chinese==1.0.3
rpds-py==0.18.1
ruff==0.4.8
safetensors==0.4.3
scipy==1.13.1
SecretStorage==3.3.3
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.5.1
setproctitle==1.3.3
shellingham==1.5.4
shtab==1.7.1
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
SQLAlchemy==2.0.31
sse-starlette==2.1.0
starlette==0.37.2
sympy==1.12.1
tenacity==8.4.1
termcolor==2.4.0
tiktoken==0.7.0
tokenizers==0.19.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.3.0
torchaudio==2.3.1
torchvision==0.18.1
tqdm==4.66.4
transformers==4.42.3
triton==2.3.0
trl==0.9.4
typer==0.12.3
typing-inspect==0.9.0
typing_extensions==4.12.2
tyro==0.8.4
tzdata==2024.1
ujson==5.10.0
urllib3==2.2.1
uvicorn==0.30.1
uvloop==0.19.0
virtualenv==20.26.2
vllm==0.5.0.post1
vllm-flash-attn==2.5.9
wandb==0.17.1
watchfiles==0.22.0
websockets==11.0.3
xformers==0.0.26.post1
xxhash==3.4.1
yarl==1.9.4
zipp==3.19.2
69 changes: 62 additions & 7 deletions setup.py
@@ -1,9 +1,64 @@
import os
import re

from setuptools import find_packages, setup

setup(
    name='semikong_finetune',
    packages=find_packages(include=['ultils']),
    version='0.1.0',
    description='Fine tuning Large Language Model for semiconductor industry',
    author='aic'
)

def get_version():
    with open(os.path.join("src", "semikong", "extras", "env.py"), "r", encoding="utf-8") as f:
        file_content = f.read()
        pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
        (version,) = re.findall(pattern, file_content)
        return version


def get_requires():
    with open("requirements.txt", "r", encoding="utf-8") as f:
        file_content = f.read()
        lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
        return lines


extra_require = {
    "torch": ["torch>=1.13.1"],
    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
    "vllm": ["vllm>=0.4.3"],
    "modelscope": ["modelscope"],
    "dev": ["ruff", "pytest"],
}


def main():
    setup(
        name="llamafactory",
        version=get_version(),
        author="hiyouga",
        author_email="hiyouga" "@" "buaa.edu.cn",
        description="Easy-to-use LLM fine-tuning framework",
        long_description=open("README.md", "r", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["SEMIKONG", "Llama", "LLM", "transformer", "pytorch", "deep learning"],
        license="Apache 2.0 License",
        url="https://github.com/aitomatic/semikong",
        package_dir={"": "src"},
        packages=find_packages("src"),
        python_requires=">=3.10.0",
        install_requires=get_requires(),
        extras_require=extra_require,
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Developers",
            "Intended Audience :: Education",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: Apache Software License",
            "Operating System :: OS Independent",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
    )


if __name__ == "__main__":
    main()
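
To illustrate what the two helpers in `setup.py` do, here is a small self-contained sketch that runs the same version regex on an in-memory string instead of reading `env.py`. The sample file contents are made up, and `parse_requirements` is a slightly stricter variant, not the code in this commit: it also drops blank lines and indented comments, which `get_requires` keeps because it checks for `#` before stripping.

```python
import re


def extract_version(file_content: str) -> str:
    """Pull the quoted value assigned to VERSION, as get_version() does."""
    pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
    # The one-element unpacking raises ValueError unless exactly one match is found.
    (version,) = re.findall(pattern, file_content)
    return version


def parse_requirements(file_content: str) -> list[str]:
    """Hypothetical stricter variant of get_requires(): strip first, then filter."""
    stripped = (line.strip() for line in file_content.splitlines())
    return [line for line in stripped if line and not line.startswith("#")]


# Made-up sample inputs, for illustration only.
sample_env = 'VERSION = "0.1.0"\n'
sample_reqs = "# core\ntorch==2.3.0\n\n  # indented comment\nvllm==0.5.0.post1\n"

print(extract_version(sample_env))      # 0.1.0
print(parse_requirements(sample_reqs))  # ['torch==2.3.0', 'vllm==0.5.0.post1']
```

Because of the one-element unpacking, a missing or duplicated `VERSION` assignment fails loudly at install time instead of silently picking a value.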
5 changes: 5 additions & 0 deletions src/configs/__init__.py
@@ -0,0 +1,5 @@
from configs.peft import lora_config, llama_adapter_config, prefix_config
from configs.fsdp import fsdp_config
from configs.training import train_config
from configs.wandb import wandb_config
from configs.quantization import quantization_configs
16 changes: 16 additions & 0 deletions src/configs/datasets.py
@@ -0,0 +1,16 @@
from dataclasses import dataclass

@dataclass
class alpaca_dataset:
    dataset: str = "semikong_dataset"
    train_split: str = "train"
    test_split: str = "val"
    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "custom_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"
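
These dataclasses only hold defaults, so a training script usually overrides a few fields at run time. The sketch below shows one way to do that with `dataclasses.replace`. The `override` helper and the `my_dataset.py` file name are hypothetical and not part of this commit; `custom_dataset` is redeclared here only so the example is self-contained.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "custom_dataset.py"
    train_split: str = "train"
    test_split: str = "validation"


def override(config, **kwargs):
    """Return a copy of `config` with the given fields replaced (hypothetical helper)."""
    known = {f.name for f in fields(config)}
    unknown = set(kwargs) - known
    if unknown:
        # Fail early on typos instead of silently ignoring them.
        raise KeyError(f"unknown config fields: {sorted(unknown)}")
    return replace(config, **kwargs)


cfg = override(custom_dataset(), file="my_dataset.py", test_split="test")
print(cfg)
```

Returning a new instance rather than mutating the default keeps the class-level defaults usable as a reference configuration.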
18 changes: 18 additions & 0 deletions src/configs/fsdp.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # Alternatives: HYBRID_SHARD (full shard within a node, DDP across nodes), SHARD_GRAD_OP (shard only gradients and optimizer states), NO_SHARD (similar to DDP).
    hsdp: bool = False  # Requires HYBRID_SHARD. Extends HYBRID_SHARD by sharding the model over a custom number of GPUs (the sharding group) and replicating it across sharding groups.
    sharding_group_size: int = 0  # Requires hsdp. Number of GPUs that your model fits into to form one replica of the model.
    replica_group_size: int = 0  # Requires hsdp. Number of replicas, which is world_size / sharding_group_size.
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size; FULL_STATE_DICT is the alternative.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
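
The comments on `sharding_group_size` and `replica_group_size` tie the two values together through the world size. As a sketch of that relationship, the helper below derives the replica group size from the other two numbers and rejects combinations that do not divide evenly. `resolve_hsdp_groups` is a hypothetical name, not a function from this repository or from PyTorch.

```python
def resolve_hsdp_groups(world_size: int, sharding_group_size: int) -> tuple[int, int]:
    """Return (sharding_group_size, replica_group_size) for an HSDP run.

    Hypothetical helper: replica_group_size = world_size / sharding_group_size,
    as described in the fsdp_config comments.
    """
    if sharding_group_size <= 0:
        raise ValueError("sharding_group_size must be positive when hsdp is enabled")
    if world_size % sharding_group_size != 0:
        raise ValueError("world_size must be a multiple of sharding_group_size")
    return sharding_group_size, world_size // sharding_group_size


# Example: 8 GPUs in total, with the model sharded across groups of 4.
print(resolve_hsdp_groups(world_size=8, sharding_group_size=4))  # (4, 2)
```

With 8 GPUs and a sharding group of 4, the model is sharded over 4 GPUs and replicated twice.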