new autogptq config format && parallel load (#110)
* new autogptq config format && parallel load

* fix
wejoncy authored Mar 25, 2024
1 parent 408c79f commit bb63f8b
Showing 2 changed files with 21 additions and 6 deletions.
10 changes: 7 additions & 3 deletions qllm/modeling/base.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import torch
 import transformers
 from transformers import AutoModelForCausalLM
@@ -85,7 +86,7 @@ def _get_resolved_weight_or_index_file(model_name_or_path):
     for possible_weight_file in ["model.safetensors", "pytorch_model.bin"]:
         weight_or_index_file = cached_file(model_name_or_path, possible_weight_file)
         if weight_or_index_file:break
-    return weight_or_index_file
+    return str(weight_or_index_file)
 
 
 def _load_check_point(model, model_name_or_path, get_keys_only: bool = False):
@@ -99,7 +100,10 @@ def _load_check_point(model, model_name_or_path, get_keys_only: bool = False):
         if "weight_map" in index:
             index = index["weight_map"]
         checkpoint_files = sorted(list(set(index.values())))
-        checkpoint_files = [cached_file(model_name_or_path, f) for f in checkpoint_files]
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_to_checkpoint_files = {executor.submit(cached_file, model_name_or_path, f): f for f in checkpoint_files}
+            checkpoint_files = [future.result() for future in concurrent.futures.as_completed(future_to_checkpoint_files)]
+        #checkpoint_files = [cached_file(model_name_or_path, f) for f in checkpoint_files]
     else:
         checkpoint_files = [weight_or_index_file]
 
@@ -284,7 +288,7 @@ def save_pretrained(model, tokenizer, save_directory: Union[str, Path], pack_mod
     quant_config_by_layer, quant_config = model.quant_config_by_layer, model.quant_config
     if pack_mode != quant_config.version and pack_mode != "AUTO":
         repack_func()
-
+    model.config.quantization_config = model.quant_config.quant_config
     model.save_pretrained(save_directory, save_serialization=save_serialization)
     tokenizer is not None and tokenizer.save_pretrained(save_directory)
 
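For context, here is a minimal, self-contained sketch of the parallel shard-fetch pattern this commit adds to _load_check_point: every shard listed in the weight index is resolved through transformers' cached_file from a ThreadPoolExecutor, so the I/O-bound downloads overlap instead of running one by one. The helper name resolve_shards_parallel and the shard names in the usage comment are illustrative, not part of qllm.

# Sketch only, assuming transformers.utils.cached_file is available.
import concurrent.futures

from transformers.utils import cached_file


def resolve_shards_parallel(model_name_or_path, shard_names):
    """Resolve (download or locate in cache) all checkpoint shards concurrently."""
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(cached_file, model_name_or_path, name): name
            for name in shard_names
        }
        # as_completed yields futures in completion order, so the returned list
        # is not ordered like shard_names.
        return [future.result() for future in concurrent.futures.as_completed(futures)]


# Hypothetical usage with a two-shard checkpoint:
# files = resolve_shards_parallel("some-org/some-model",
#                                 ["model-00001-of-00002.safetensors",
#                                  "model-00002-of-00002.safetensors"])

A thread pool is a reasonable fit here because cached_file spends most of its time waiting on network or disk I/O, so the threads overlap those waits even with the GIL.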
17 changes: 14 additions & 3 deletions qllm/modeling/config.py
@@ -72,13 +72,24 @@ def load_quant_op_config(self, model_name_or_path):
 
     def load_quant_config(self, model_name_or_path):
         config_file = self.get_resolved_base_dir(model_name_or_path, "quant_config.json")
+        quant_config = None
         if config_file is None:
             # GPTQ-for-llama/AutoGPTQ
             config_file = self.get_resolved_base_dir(model_name_or_path, "quantize_config.json")
+        if config_file is not None:
+            with open(config_file) as fp:
+                quant_config = json.load(fp)
+
+        if config_file is None:
+            config_file = self.get_resolved_base_dir(model_name_or_path, "config.json")
+            if config_file is not None:
+                with open(config_file) as fp:
+                    quant_config = json.load(fp)
+                quant_config = quant_config.get("quantization_config", None)
+                assert quant_config.get('use_exllama', False) == False, "use_exllama is not supported yet"
+
+        assert quant_config is not None, ("quant_config.json/quantize_config.json not found in checkpoint directory")
 
-        assert config_file is not None, ("quant_config.json/quantize_config.json not found in checkpoint directory")
-        with open(config_file) as fp:
-            quant_config = json.load(fp)
         wbits = quant_config.get("w_bit", quant_config.get("bits", None))
         groupsize = quant_config.get("q_group_size", quant_config.get("group_size", None))
         assert wbits is not None and groupsize is not None
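For reference, a sketch of the three configuration layouts load_quant_config now falls back through, and of the dual field naming its tail reconciles. The values below are made up for the example, not taken from a real checkpoint.

# Illustrative layouts only; every value here is invented.

# 1) quant_config.json (w_bit / q_group_size naming)
quant_config_json = {"w_bit": 4, "q_group_size": 128}

# 2) GPTQ-for-llama / AutoGPTQ quantize_config.json (bits / group_size naming)
quantize_config_json = {"bits": 4, "group_size": 128}

# 3) new format: config.json carries the settings under "quantization_config"
config_json = {
    "quantization_config": {"bits": 4, "group_size": 128, "use_exllama": False},
}


def extract_bits_and_groupsize(quant_config):
    """Mirror the dual key naming handled at the end of load_quant_config."""
    wbits = quant_config.get("w_bit", quant_config.get("bits", None))
    groupsize = quant_config.get("q_group_size", quant_config.get("group_size", None))
    assert wbits is not None and groupsize is not None
    return wbits, groupsize


for cfg in (quant_config_json, quantize_config_json,
            config_json["quantization_config"]):
    assert extract_bits_and_groupsize(cfg) == (4, 128)

Note that the config.json path also rejects checkpoints that request use_exllama, matching the assert added in the diff above.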
