support transformers-lib loading (#134)
* support transformers-lib loading

* ppl script
wejoncy authored Aug 26, 2024
1 parent db0430b commit 4c67f07
Showing 11 changed files with 311 additions and 55 deletions.
20 changes: 10 additions & 10 deletions README.md
@@ -66,15 +66,15 @@ python setup.py install
## Quantize llama2
```bash
# Quantize and Save compressed model, method can be one of [gptq/awq/hqq]
- python -m qllm --model=meta-llama/Llama-2-7b-hf --method=gptq --nsamples=64 --wbits=4 --groupsize=128 --save ./Llama-2-7b-4bit
- python -m qllm --model=meta-llama/Llama-2-7b-hf --method=awq --dataset=pileval --nsamples=16 --wbits=4 --groupsize=128 --save ./Llama-2-7b-4bit
- python -m qllm --model=meta-llama/Llama-2-7b-hf --method=hqq --wbits=4 --groupsize=128 --save ./Llama-2-7b-4bit
+ python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=gptq --nsamples=64 --wbits=4 --groupsize=128 --save ./Llama-2-7b-4bit
+ python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=awq --dataset=pileval --nsamples=16 --wbits=4 --groupsize=128 --save ./Llama-2-7b-4bit
+ python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=hqq --wbits=4 --groupsize=128 --save ./Llama-2-7b-4bit
```

## Convert to onnx model
use `--export_onnx ./onnx_model` to export and save the ONNX model
```
- python -m qllm --model meta-llama/Llama-2-7b-chat-hf --method=gptq --dataset=pileval --nsamples=16 --save ./Llama-2-7b-chat-hf_awq_q4/ --export_onnx ./Llama-2-7b-chat-hf_awq_q4_onnx/
+ python -m qllm --model meta-llama/Llama-2-7b-chat-hf --quant_method=gptq --dataset=pileval --nsamples=16 --save ./Llama-2-7b-chat-hf_awq_q4/ --export_onnx ./Llama-2-7b-chat-hf_awq_q4_onnx/
```
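After export, a quick sanity check is to open the graph with onnxruntime. This is only a sketch: the exact `.onnx` file name inside the export directory depends on what the exporter writes, so it is discovered by listing the directory rather than assumed.

```python
# Hedged sketch: verify the exported ONNX model loads and inspect its inputs.
# The export directory path matches the command above; the .onnx file name is
# discovered by listing the directory instead of being hard-coded.
import os
import onnxruntime as ort

export_dir = "./Llama-2-7b-chat-hf_awq_q4_onnx"
onnx_files = [f for f in os.listdir(export_dir) if f.endswith(".onnx")]
sess = ort.InferenceSession(os.path.join(export_dir, onnx_files[0]),
                            providers=["CPUExecutionProvider"])
print([(i.name, i.shape) for i in sess.get_inputs()])  # expected input tensors
```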
or you can convert an existing model from the HF Hub
```
@@ -85,7 +85,7 @@ python -m qllm --load TheBloke/Llama-2-7B-Chat-GPTQ --export_onnx=./onnx
## (NEW) Quantize model with mix bits/groupsize for higher precision (PPL)
```bash
# Quantize and Save compressed model
- python -m qllm --model=meta-llama/Llama-2-7b-hf --method=gptq --save ./Llama-2-7b-4bit --allow_mix_bits --true-sequential
+ python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=gptq --save ./Llama-2-7b-4bit --allow_mix_bits --true-sequential
```
### NOTE:
1. only support GPTQ
@@ -97,15 +97,15 @@ python -m qllm --model=meta-llama/Llama-2-7b-hf --method=gptq --save ./Llama-2-7
## Quantize model for vLLM
Due to the difference in how zeros are stored, you need to set an env variable if you set pack_mode to GPTQ, whether the quant method is awq or gptq
```bash
- COMPATIBLE_WITH_AUTOGPTQ=1 python -m qllm --model=meta-llama/Llama-2-7b-hf --method=gptq --save ./Llama-2-7b-4bit --pack_mode=GPTQ
+ COMPATIBLE_WITH_AUTOGPTQ=1 python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=gptq --save ./Llama-2-7b-4bit --pack_mode=GPTQ
```
If you use the GEMM pack_mode, you don't have to set the variable
```bash
- python -m qllm --model=meta-llama/Llama-2-7b-hf --method=gptq --save ./Llama-2-7b-4bit --pack_mode=GEMM
+ python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=gptq --save ./Llama-2-7b-4bit --pack_mode=GEMM
```

```bash
- python -m qllm --model=meta-llama/Llama-2-7b-hf --method=awq --save ./Llama-2-7b-4bit --pack_mode=GEMM
+ python -m qllm --model=meta-llama/Llama-2-7b-hf --quant_method=awq --save ./Llama-2-7b-4bit --pack_mode=GEMM
```
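Once the checkpoint is packed and saved, it can be handed to vLLM. A minimal sketch, assuming a vLLM build with the matching quantization backend (`"gptq"` here; GEMM-packed AWQ checkpoints typically use `"awq"`) and the save path used above:

```python
# Hedged sketch: serve the packed checkpoint with vLLM. The quantization flag
# must match the pack format chosen above; the model path mirrors the --save
# directory from the commands and is an assumption about your local layout.
from vllm import LLM, SamplingParams

llm = LLM(model="./Llama-2-7b-4bit", quantization="gptq", dtype="float16")
outputs = llm.generate(["What does 4-bit quantization trade away?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```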
## Conversion among AWQ, GPTQ and MarLin
```bash
@@ -174,10 +174,10 @@ pip install fschat accelerate
use `--use_plugin` to enable a chatbot plugin

```
- python -m qllm --model meta-llama/Llama-2-7b-chat-hf --method=awq --dataset=pileval --nsamples=16 --use_plugin --save ./Llama-2-7b-chat-hf_awq_q4/
+ python -m qllm --model meta-llama/Llama-2-7b-chat-hf --quant_method=awq --dataset=pileval --nsamples=16 --use_plugin --save ./Llama-2-7b-chat-hf_awq_q4/
or
- python -m qllm --model meta-llama/Llama-2-7b-chat-hf --method=gptq --dataset=pileval --nsamples=16 --use_plugin --save ./Llama-2-7b-chat-hf_gptq_q4/
+ python -m qllm --model meta-llama/Llama-2-7b-chat-hf --quant_method=gptq --dataset=pileval --nsamples=16 --use_plugin --save ./Llama-2-7b-chat-hf_gptq_q4/
```

## use QLLM with API
2 changes: 1 addition & 1 deletion qllm/args_config.py
@@ -1,6 +1,6 @@
class FakeArgs:
def __init__(self, **entries):
self.method = "gptq"
self.quant_method = "gptq"
self.dataset = "wikitext2"
self.seed = 0
self.nsamples = 128
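With this rename, code that read `args.method` from a `FakeArgs` instance must switch to `args.quant_method`. A minimal sketch of the new default:

```python
# Hedged sketch: the defaults holder now exposes `quant_method` instead of `method`.
from qllm.args_config import FakeArgs

args = FakeArgs()
print(args.quant_method)  # "gptq", per the default shown in this diff
# Callers still reading `args.method` may no longer find the attribute
# unless they pass it in explicitly.
```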
27 changes: 19 additions & 8 deletions qllm/auto_model_quantization.py
@@ -63,6 +63,17 @@ def eval_model(self, model, pack_mode, dev):
"compared with awq, gptq is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_length=50)

+ # from .plugin import perplexity_utils
+ # ppl = perplexity_utils.Perplexity(
+ # model,
+ # self.tokenizer,
+ # "wikitext",
+ # None,
+ # "test",
+ # "text",
+ # )
+ # ppl.calculate_perplexity(512, 512)
+
model.to('cpu')
print(self.tokenizer.decode(out[0]))
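The commented-out block above appears to be the "ppl script" hook from the commit message. Run standalone it would look roughly like the sketch below; the positional argument meanings (dataset name, config, split, text column) are inferred from the call, so treat them as assumptions.

```python
# Hedged sketch of the perplexity check left commented out in eval_model().
# `model` and `tokenizer` are assumed to already hold the quantized model and
# its tokenizer; the 512/512 chunk and stride mirror the commented call.
from qllm.plugin import perplexity_utils

ppl = perplexity_utils.Perplexity(
    model,       # quantized model under evaluation (assumed defined)
    tokenizer,   # its tokenizer (assumed defined)
    "wikitext",  # dataset name
    None,        # dataset config
    "test",      # split
    "text",      # text column
)
ppl.calculate_perplexity(512, 512)
```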

@@ -74,10 +85,10 @@ def pack_model(self, model, quantizers, pack_mode):
quant_config_by_layer = {key: {
"wbits": value[-2], "groupsize": value[-1]} for key, value in quantizers.items()}
meta_info = model.quant_config.to_meta
- wbits, method = meta_info.bits, meta_info.method
- quant_config_by_layer["method"] = model.quant_config.method
+ wbits, quant_method = meta_info.bits, meta_info.quant_method
+ quant_config_by_layer["quant_method"] = model.quant_config.quant_method

- target_layer = select_quant_linear(pack_mode, wbits, method)
+ target_layer = select_quant_linear(pack_mode, wbits, quant_method)

make_mixbits_quant_linear(model, quantizers, quant_config_by_layer, target_layer=target_layer)
qlayers = find_layers(model, [target_layer])
@@ -106,8 +117,8 @@ def repack_to_new_mode(self, model, new_pack_mode):
return model
meta_info = model.quant_config.to_meta
bits = meta_info.bits
- source_layer = select_quant_linear(old_pack_mode, bits, meta_info.method)
- target_layer = select_quant_linear(new_pack_mode, bits, meta_info.method)
+ source_layer = select_quant_linear(old_pack_mode, bits, meta_info.quant_method)
+ target_layer = select_quant_linear(new_pack_mode, bits, meta_info.quant_method)
if source_layer == target_layer:
return model
model.quant_config.version = new_pack_mode
Expand Down Expand Up @@ -174,14 +185,14 @@ def run(self, args):
set_seed(args.seed)

if args.pack_mode == "AUTO" and args.allow_mix_bits:
assert args.method == "gptq", "only gptq support allow_mix_bits mode"
assert args.quant_method == "gptq", "only gptq support allow_mix_bits mode"
args.pack_mode = "GPTQ"
if args.allow_mix_bits and args.pack_mode != "GPTQ":
raise ValueError("allow_mix_bits only support GPTQ packing mode")
if not isinstance(args.load, str):
args.load = args.load.as_posix()

if args.method == "awq" and args.nsamples > 64:
if args.quant_method == "awq" and args.nsamples > 64:
logger.warning("as the memory blast, AWQ will limit to 32 samples for quantization")
args.nsamples = 64

@@ -202,7 +213,7 @@ def run(self, args):
Please run with `-h` to refer the usage.")

if not args.load and args.wbits < 16:
if args.method == "hqq":
if args.quant_method == "hqq":
inputs_dataloader = None
else:
inputs_dataloader = self.get_datasets(args.tokenizer, args.dataset, args.nsamples, args.seed)
43 changes: 32 additions & 11 deletions qllm/modeling/base.py
@@ -47,9 +47,12 @@ def no_init_weights(attrs: list = None):
setattr(torch.Tensor, attr, old_attr[idx])

def get_no_split_layer_type_name(model:torch.nn.Module):
- for name,mod in model.named_modules():
- if '.0' in name and name.count('.0') == 1:
- return [mod.__class__.__name__]
+ try:
+ return model._get_no_split_modules("auto")
+ except: # noqa: E722
+ for name,mod in model.named_modules():
+ if '.0' in name and name.count('.0') == 1:
+ return [mod.__class__.__name__]

def _hf_weight_generator(hf_weights_files, is_safetensors:bool):
if is_safetensors:
@@ -172,11 +175,31 @@ def from_pretrained(
cls.disable_double_init()
trust_remote_code = kwargs.pop("trust_remote_code", False)
attn_implementation = kwargs.pop("attn_implementation", None)
+ max_memory = kwargs.pop("max_memory", None)

+ # with accelerate.init_empty_weights():
+ # auto_conf = transformers.AutoConfig.from_pretrained(
+ # pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
+ # model = AutoModelForCausalLM.from_config(auto_conf, trust_remote_code=trust_remote_code)
+ # llm = accelerate.load_checkpoint_and_dispatch(
+ # model,
+ # checkpoint=pretrained_model_name_or_path,
+ # device_map="auto",
+ # max_memory={0: 1 * 1024 * 1024 * 1024, "cpu": 5 * 1024 * 1024 * 1024},
+ # dtype=torch.float16,
+ # no_split_module_classes=get_no_split_layer_type_name(model),
+ # offload_folder="/tmp/a2",
+ # )

llm = AutoModelForCausalLM.from_pretrained(
- pretrained_model_name_or_path, torch_dtype=torch.float16, trust_remote_code=trust_remote_code,
- attn_implementation=attn_implementation)
+ pretrained_model_name_or_path,
+ torch_dtype=torch.float16,
+ trust_remote_code=trust_remote_code,
+ attn_implementation=attn_implementation,
+ # device_map="auto",
+ # low_cpu_mem_usage=True,
+ # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
+ # offload_folder="/tmp/a2"
+ )
return llm

@classmethod
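The commented-out `accelerate` path in this hunk sketches an alternative low-memory loading route. A self-contained version of that idea follows; the memory budget and offload folder are illustrative placeholders, not values the library requires, and `checkpoint` must resolve to a local weights directory.

```python
# Hedged sketch of the accelerate-based loading path left commented out above.
# The memory budget and offload folder are illustrative only.
import accelerate
import torch
import transformers

from qllm.modeling.base import get_no_split_layer_type_name  # helper from this file


def load_dispatched(pretrained_model_name_or_path, trust_remote_code=False):
    # Build the model skeleton without allocating real weights.
    with accelerate.init_empty_weights():
        auto_conf = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        model = transformers.AutoModelForCausalLM.from_config(
            auto_conf, trust_remote_code=trust_remote_code)
    # Stream the checkpoint in, splitting layers across GPU, CPU and disk.
    return accelerate.load_checkpoint_and_dispatch(
        model,
        checkpoint=pretrained_model_name_or_path,  # must be a local weights folder
        device_map="auto",
        max_memory={0: 1 * 1024 ** 3, "cpu": 5 * 1024 ** 3},  # illustrative budget
        dtype=torch.float16,
        no_split_module_classes=get_no_split_layer_type_name(model),
        offload_folder="/tmp/offload",  # hypothetical scratch path
    )
```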
@@ -241,24 +264,22 @@ def from_quantized(
del layers[layer_name]

target_layer = utils.modelutils.select_quant_linear(
- quant_config.version, quant_config.bits(), quant_config.method)
+ quant_config.version, quant_config.bits(), quant_config.quant_method)
torch.set_default_device("cuda")
utils.modelutils.make_mixbits_quant_linear(
model, layers, quant_config.quant_config_by_op, target_layer=target_layer)
torch.set_default_device("cpu")
if quant_config.method == "awq":
if quant_config.quant_method == "awq":
from ..quantization.quant_awq import scale_activations
scale_activations(model)
del layers
# if low_cpu_mem_usage:
# model = model.cuda()
model.tie_weights() # works with init_empty_weights and load_checkpoint_and_dispatch
try:
# bias issue
no_split_module_classes = get_no_split_layer_type_name(model)
assert no_split_module_classes is None
if torch.cuda.mem_get_info()[1]/1024/1024/1024 < 8:
- model = accelerate.big_modeling.load_checkpoint_and_dispatch(
+ model = accelerate.load_checkpoint_and_dispatch(
model,
checkpoint=_get_resolved_weight_or_index_file(model_name_or_path, quant_config),
device_map=device_map,
8 changes: 4 additions & 4 deletions qllm/modeling/config.py
@@ -11,7 +11,7 @@ class BaseQuantizeConfig:
def __init__(self):
self.quant_config = {}
self.quant_config_by_op = {}
- self.method = None
+ self.quant_method = None
self.COMPATIBLE_WITH_AUTOGPTQ = False

def groupsize(self, layer_name: str = None):
@@ -26,7 +26,7 @@ def bits(self, layer_name: str = None):

@property
def to_meta(self):
- return MetaConfig(self.bits(), self.groupsize(), self.method)
+ return MetaConfig(self.bits(), self.groupsize(), self.quant_method)

@property
def version(self):
@@ -105,13 +105,13 @@ def load_quant_config(self, model_name_or_path):
if quant_config.get('COMPATIBLE_WITH_AUTOGPTQ', None):
self.COMPATIBLE_WITH_AUTOGPTQ = True
if "version" not in quant_config:
self.method = "gptq"
self.quant_method = "gptq"
quant_config["version"] = "GPTQ"
self.COMPATIBLE_WITH_AUTOGPTQ = True
import os
os.environ["COMPATIBLE_WITH_AUTOGPTQ"] = '1' # FixMe: hacky
else: # FIXME is it correct?
- self.method = quant_config.get("method", "awq")
+ self.quant_method = quant_config.get("quant_method", "awq")
self.quant_config = quant_config

@classmethod
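For orientation, the keys this loader actually inspects are `COMPATIBLE_WITH_AUTOGPTQ`, `version`, and the renamed `quant_method`. A hypothetical config dict that would take the non-legacy branch (the remaining keys are illustrative only):

```python
# Hypothetical example of a dict flowing through load_quant_config().
# Only "COMPATIBLE_WITH_AUTOGPTQ", "version" and "quant_method" are inspected
# by the code shown above; the other entries are illustrative placeholders.
quant_config = {
    "quant_method": "awq",  # read via quant_config.get("quant_method", "awq")
    "version": "GEMM",      # presence of "version" skips the legacy GPTQ fallback
    "w_bit": 4,             # illustrative
    "q_group_size": 128,    # illustrative
}
```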
