Merge pull request Vahe1994#14 from Vahe1994/integration
Transformers quantizers integration
BlackSamorez authored on Feb 22, 2024
2 parents 8f75dcd + 206da1f · commit 4ab2475
Showing 12 changed files with 24 additions and 4,429 deletions.
1 change: 0 additions & 1 deletion benchmark/benchmark_generate_cpu.py
```diff
@@ -12,7 +12,6 @@
 torch.set_num_threads(8)
 from torch import nn
-
 from transformers import AutoConfig, AutoModelForCausalLM
 
 if __name__ == "__main__":
```
1 change: 0 additions & 1 deletion benchmark/generate_benchmark.py
```diff
@@ -7,7 +7,6 @@
 import torch
 import torch.nn as nn
 from tqdm import trange
-
 from transformers import AutoConfig, AutoModelForCausalLM
 
 if __name__ == "__main__":
```
69 changes: 23 additions & 46 deletions convert_to_hf.py
```diff
@@ -5,7 +5,6 @@
 
 import torch
 from tqdm.auto import trange
-
 from transformers import AutoConfig, PretrainedConfig
 
 
@@ -43,8 +42,9 @@ def get_layers_prefix(config) -> str:
     raise NotImplementedError(f"Can't get layers prefix for {unknown_type}")
 
 
-def get_converted_state_dict(config, nbits: int, in_path: os.PathLike) -> dict:
+def get_converted_state_dict(config, nbits: int, in_path: os.PathLike) -> [dict, list[str]]:
     state_dict = {}
+    linear_weights_not_to_quantize = []
 
     num_layers = get_num_layers(config)
     layers_prefix = get_layers_prefix(config)
@@ -56,13 +56,17 @@ def get_converted_state_dict(config, nbits: int, in_path: os.PathLike) -> dict:
             p.data = p.data.half()
         else:
             p.data = pack_int_data(p.data, nbits)
-        name = re.sub("quantized_weight.", "", name)
+        if "quantized_weight." not in name:
+            linear_weights_not_to_quantize.append(f"{layers_prefix}.{i}.{name}")
+        else:
+            name = re.sub("quantized_weight.", "", name)
         state_dict[f"{layers_prefix}.{i}.{name}"] = p.data
 
     for key, value in torch.load(os.path.join(in_path, "not_quantized_weights.pt")).items():
         state_dict[key] = value.half()
+        linear_weights_not_to_quantize.append(key)
 
-    return state_dict
+    return state_dict, linear_weights_not_to_quantize
```
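The parameter-name bookkeeping above is easy to check in isolation: quantized tensors lose their `quantized_weight.` infix, while anything still stored in fp16 is recorded for the loader to skip. A minimal sketch using hypothetical parameter names (real names depend on the model architecture):

```python
import re

# Hypothetical parameter names inside decoder layer 0 (illustrative only).
layers_prefix, i = "model.layers", 0
names = [
    "self_attn.q_proj.quantized_weight.codes",      # quantized: infix stripped
    "self_attn.q_proj.quantized_weight.codebooks",  # quantized: infix stripped
    "input_layernorm.weight",                       # fp16: recorded as not-to-quantize
]

linear_weights_not_to_quantize = []
for name in names:
    if "quantized_weight." not in name:
        linear_weights_not_to_quantize.append(f"{layers_prefix}.{i}.{name}")
    else:
        name = re.sub("quantized_weight.", "", name)
    print(f"{layers_prefix}.{i}.{name}")

# model.layers.0.self_attn.q_proj.codes
# model.layers.0.self_attn.q_proj.codebooks
# model.layers.0.input_layernorm.weight
```

Everything that lands in `linear_weights_not_to_quantize` is written into the config (see `update_config` below), so the loader knows which linear weights stay in fp16.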


```diff
@@ -75,44 +79,17 @@ def get_metadata(in_path: os.PathLike) -> dict:
     }
 
 
-def update_config(old_config: PretrainedConfig, aqlm_metadata: dict[str, int]):
-    old_config_type = type(old_config)
-    old_model_type = old_config.model_type
-    new_model_type = f"{old_model_type}_aqlm"
-
-    class AqlmConfig(old_config_type):
-        model_type = new_model_type
-
-        def __init__(
-            self,
-            aqlm: dict[str, int] = {
-                "nbits_per_codebook": 16,
-                "num_codebooks": 1,
-                "out_group_size": 8,
-                "in_group_size": 1,
-            },
-            **kwargs,
-        ):
-            super().__init__(**kwargs)
-            self.aqlm = aqlm
-
-    config_dict = old_config.to_dict()
-    config_dict["auto_map"] = {
-        "AutoConfig": f"configuration_{new_model_type}.{old_config.__class__.__name__}",
-        "AutoModelForCausalLM": f"modeling_{new_model_type}.{config_dict['architectures'][0]}",
-    }
-    del config_dict["_name_or_path"]
-
-    new_config = AqlmConfig(
-        {
-            "nbits_per_codebook": aqlm_metadata["nbits_per_codebook"],
-            "num_codebooks": aqlm_metadata["num_codebooks"],
-            "out_group_size": aqlm_metadata["out_group_size"],
-            "in_group_size": aqlm_metadata["in_group_size"],
-        }
-    )
-    new_config.update(config_dict)
-    return new_config
+def update_config(config_dict: dict, aqlm_metadata: dict[str, int], linear_weights_not_to_quantize: list[str]):
+    config_dict["quantization_config"] = {
+        "quant_method": "aqlm",
+        "nbits_per_codebook": aqlm_metadata["nbits_per_codebook"],
+        "num_codebooks": aqlm_metadata["num_codebooks"],
+        "out_group_size": aqlm_metadata["out_group_size"],
+        "in_group_size": aqlm_metadata["in_group_size"],
+        "linear_weights_not_to_quantize": linear_weights_not_to_quantize,
+    }
+    config_dict["torch_dtype"] = "float16"
+    return config_dict
 
 
 def add_inference_code(model_type: str, save_path: os.PathLike):
```
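The effect of the rewritten `update_config` is that a converted checkpoint carries a standard `quantization_config` entry instead of an `auto_map` pointing at bundled remote code. As a sketch, using the 16-bit single-codebook defaults from the deleted `AqlmConfig` class (the `model_type` and the `linear_weights_not_to_quantize` entries below are illustrative), the emitted config.json would contain something like:

```json
{
  "model_type": "llama",
  "torch_dtype": "float16",
  "quantization_config": {
    "quant_method": "aqlm",
    "nbits_per_codebook": 16,
    "num_codebooks": 1,
    "out_group_size": 8,
    "in_group_size": 1,
    "linear_weights_not_to_quantize": [
      "model.layers.0.input_layernorm.weight",
      "lm_head.weight"
    ]
  }
}
```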
```diff
@@ -147,11 +124,11 @@ def add_inference_code(model_type: str, save_path: os.PathLike):
     old_config = AutoConfig.from_pretrained(args.model)
     metadata = get_metadata(args.in_path)
 
-    add_inference_code(old_config.model_type, args.out_path)
-
-    state_dict = get_converted_state_dict(old_config, metadata["nbits_per_codebook"], args.in_path)
+    state_dict, linear_weights_not_to_quantize = get_converted_state_dict(
+        old_config, metadata["nbits_per_codebook"], args.in_path
+    )
     torch.save(state_dict, os.path.join(args.out_path, "pytorch_model.bin"))
 
-    new_config = update_config(old_config, metadata)
+    new_config_dict = update_config(old_config.to_diff_dict(), metadata, linear_weights_not_to_quantize)
     with open(os.path.join(args.out_path, "config.json"), "w") as config_file:
-        json.dump(new_config.to_dict(), config_file, indent=4)
+        json.dump(new_config_dict, config_file, indent=4)
```
2 changes: 1 addition & 1 deletion main.py
```diff
@@ -8,6 +8,7 @@
 import torch.nn as nn
 from tqdm import trange
 from tqdm.auto import trange
+from transformers import PreTrainedModel
 
 from aq_engine import AQEngine
 from src.aq import QuantizedLinear
@@ -23,7 +24,6 @@
     get_sequential_groups,
 )
 from src.utils import using_tf32
-from transformers import PreTrainedModel
 
 try:
     import wandb
```
1 change: 0 additions & 1 deletion src/datautils.py
```diff
@@ -8,7 +8,6 @@
 from datasets import load_dataset
 from packaging import version
 from tqdm import trange
-
 from transformers import AutoTokenizer, LlamaTokenizer
```
1 change: 0 additions & 1 deletion src/modelutils.py
```diff
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 from tqdm import trange
-
 from transformers import AutoConfig, AutoModelForCausalLM
 
 MODEL_ERROR_MSG = "Unsupported model type {} - only 'llama', 'Yi', 'opt' and 'falcon' are supported"
```
18 changes: 0 additions & 18 deletions transformers/llama/configuration_llama_aqlm.py

This file was deleted.

(Diffs for the remaining 5 changed files did not load.)
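Taken together, the changes mean a converted model no longer ships its own `configuration_*_aqlm.py` and `modeling_*_aqlm.py` files: transformers reads `quantization_config`, sees `quant_method: "aqlm"`, and builds the quantized layers itself. A minimal loading sketch, assuming a transformers release that includes the AQLM quantizer integration and the `aqlm` inference package installed; the model path is hypothetical and tokenizer files are assumed to sit alongside the weights:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical output directory produced by convert_to_hf.py (or a hub repo id).
model_path = "path/to/converted-aqlm-model"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",  # config.json pins float16
    device_map="auto",
)

inputs = tokenizer("The quantized model says:", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```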
