
Commit

Merge pull request mit-han-lab#33 from abhinavkulkarni/dev/more_models
Sakits authored Jul 10, 2023
2 parents ab536fb + d2a10bd commit 3b9f287
Showing 3 changed files with 14 additions and 3 deletions.
4 changes: 2 additions & 2 deletions awq/entry.py
@@ -73,9 +73,9 @@ def build_model_and_enc(model_path):
     # all hf model
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
     if "mpt" in config.__class__.__name__.lower():
-        enc = AutoTokenizer.from_pretrained(config.tokenizer_name)
+        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
     else:
-        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

     if args.load_quant:  # directly load quantized weights
         print("Loading pre-computed quantized weights...")
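
For context, not part of the diff: a minimal, self-contained sketch of the tokenizer-loading path this change touches. The model path below is an illustrative assumption (any MPT-style checkpoint that ships custom code would do); such checkpoints need trust_remote_code=True on every from_pretrained call along this path.

from transformers import AutoConfig, AutoTokenizer

# Hypothetical model path, not taken from this commit.
model_path = "mosaicml/mpt-7b"

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
    # MPT configs name a separate tokenizer repo via config.tokenizer_name;
    # trust_remote_code is passed through in case that repo ships custom tokenizer code.
    enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
    enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
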
7 changes: 6 additions & 1 deletion awq/quantize/auto_scale.py
@@ -107,11 +107,14 @@ def w_quantize_func(p): return p
     def _search_module_scale(block, linears2scale: list, x, kwargs={}):
         # w: co, ci
         # x: n, ci
-        x = x.to(next(block.parameters()).device)
         weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
         w_max = get_weight_scale(
             weight, q_group_size=q_config.get("q_group_size", -1))
+        # Clear GPU memory
+        del weight
+        torch.cuda.empty_cache()

+        x = x.to(next(block.parameters()).device)
         with torch.no_grad():
             org_out = block(x, **kwargs)
             if isinstance(org_out, tuple):
@@ -126,6 +129,8 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}):
         n_grid = 20
         history = []

+        # Clear GPU memory
+        torch.cuda.empty_cache()
         org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
         for ratio in range(n_grid):
             ratio = ratio * 1 / n_grid
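
For context, not part of the diff: the pattern behind these additions is to drop the temporary concatenated weight tensor and release the caching allocator's blocks before the calibration activations x are moved onto the same device. A minimal sketch of that pattern, with a hypothetical tensor standing in for the concatenated weights:

import torch

if torch.cuda.is_available():
    # Hypothetical stand-in for the concatenated per-layer weights.
    weight = torch.randn(8192, 8192, device="cuda")
    w_max = weight.abs().amax(dim=1)  # rough stand-in for get_weight_scale

    # Drop the temporary and hand its cached blocks back to CUDA so the
    # upcoming block(x) forward pass has more headroom.
    del weight
    torch.cuda.empty_cache()

    print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")
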
6 changes: 6 additions & 0 deletions awq/quantize/pre_quant.py
@@ -135,6 +135,9 @@ def cache_input_hook(m, x, y, name, feat_dict):
         # now solve for scaling and clipping
         input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

+        # Clear GPU memory
+        torch.cuda.empty_cache()
+
         if auto_scale:  # if it applies, we should also modify the input_feat with scales
             scales_list = auto_scale_block(
                 layer, layer_kwargs,
@@ -145,6 +148,9 @@ def cache_input_hook(m, x, y, name, feat_dict):
             apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
             # append prefix to make names global
             awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
+
+            # Clear GPU memory
+            torch.cuda.empty_cache()

         if mse_range:
             clip_list = auto_clip_block(layer,
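
Not part of the diff: torch.cuda.empty_cache() does not free live tensors; it returns blocks held by PyTorch's caching allocator to the driver, lowering reserved memory and fragmentation between the scaling and clipping passes. A small hypothetical helper to observe the difference:

import torch

def report_cuda_memory(tag: str) -> None:
    # Helper for illustration only (not from the repo):
    # allocated = memory backing live tensors, reserved = blocks the caching allocator holds.
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"[{tag}] allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    del x
    report_cuda_memory("after del")          # reserved stays up: blocks are still cached
    torch.cuda.empty_cache()
    report_cuda_memory("after empty_cache")  # reserved drops back toward allocated
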
