# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
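
"""Post-training quantization (PTQ) entry point for SpinQuant.

Loads a Llama checkpoint, applies PTQ via ``eval_utils.main.ptq_model``, and
reports WikiText-2 perplexity. Arguments are parsed by ``process_args_ptq``.
"""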

import datetime
from logging import Logger

import torch
import torch.distributed as dist
import transformers
from transformers import LlamaTokenizerFast

from eval_utils.main import ptq_model
from eval_utils.modeling_llama import LlamaForCausalLM
from utils import data_utils, eval_utils, utils
from utils.process_args import process_args_ptq

log: Logger = utils.get_logger("spinquant")


def train() -> None:
    # Collectives can be idle for a long time during PTQ, so raise the
    # NCCL timeout to 8 hours.
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    model_args, training_args, ptq_args = process_args_ptq()
    local_rank = utils.get_local_rank()

    log.info("Local rank: {}".format(local_rank))
    torch.distributed.barrier()
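
    # Load the config first so that tied word embeddings can be untied
    # before the model weights are instantiated.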
    config = transformers.AutoConfig.from_pretrained(
        model_args.input_model, token=model_args.access_token
    )
    # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
    # so untie them and clone lm_head from embed_tokens after loading.
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True

    dtype = torch.bfloat16 if training_args.bf16 else torch.float16
    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        token=model_args.access_token,
    )
    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model.cuda()
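
    # ptq_model applies the SpinQuant rotations and the quantization
    # configured in ptq_args, and returns the quantized model.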
    model = ptq_model(ptq_args, model, model_args)
    model.seqlen = training_args.model_max_length
    if local_rank == 0:
        log.info("Model PTQ completed: {}".format(model))
        log.info("Loading tokenizer...")

    tokenizer = LlamaTokenizerFast.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        add_eos_token=False,
        add_bos_token=False,
        token=model_args.access_token,
    )
    log.info("Tokenizer loading complete.")
    # Disable the KV cache; it is not needed for perplexity evaluation.
    model.config.use_cache = False

    testloader = data_utils.get_wikitext2(
        seed=ptq_args.seed,
        seqlen=2048,
        tokenizer=tokenizer,
        eval_mode=True,
    )
    dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, ptq_args)
    log.info("WikiText-2 perplexity: {}".format(dataset_ppl))
    dist.barrier()


if __name__ == "__main__":
    train()
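
# A sketch of a typical launch (hypothetical flag values; the actual argument
# names and defaults are defined by process_args_ptq, not guaranteed here):
#
#   torchrun --nnodes=1 --nproc_per_node=1 ptq.py \
#       --input_model meta-llama/Llama-3.2-1B \
#       --model_max_length 2048 \
#       --bf16 True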