
Commit
[Minor] Merge model initialization
Jiaming Tang committed Jul 4, 2023
1 parent e04d0ec commit 6371c3a
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions awq/entry.py
@@ -92,22 +92,14 @@ def build_model_and_enc(model_path):
         )
     else:  # fp16 to quantized
         args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+        # Init model on CPU:
+        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, config=config, trust_remote_code=True, **kwargs)
 
         if args.run_awq:
             assert args.dump_awq, "Please save the awq results with --dump_awq"
 
-            # Init model on CPU
-            def skip(*args, **kwargs):
-                pass
-
-            torch.nn.init.kaiming_normal_ = skip
-            torch.nn.init.kaiming_uniform_ = skip
-            torch.nn.init.uniform_ = skip
-            torch.nn.init.normal_ = skip
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path, config=config, trust_remote_code=True, **kwargs)
-
-
             awq_results = run_awq(
                 model, enc,
                 w_bit=args.w_bit, q_config=q_config,
@@ -121,11 +113,6 @@ def skip(*args, **kwargs):
             print("AWQ results saved at", args.dump_awq)
 
             exit(0)
-        else:
-            # Inference with fake quant
-            # Init model on CPU:
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path, config=config, trust_remote_code=True, **kwargs)
 
         if args.load_awq:
             print("Loading pre-computed AWQ results from", args.load_awq)
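For context, the deleted block used a common loading trick: monkey-patching the torch.nn.init functions to no-ops so that from_pretrained would not spend time randomly initializing weights that the checkpoint immediately overwrites. The merged code path relies on low_cpu_mem_usage=True instead, which already skips that random initialization inside transformers, so one loader serves both the run-awq and fake-quant branches. Below is a minimal standalone sketch of the merged initialization; the checkpoint name is a placeholder, since in entry.py model_path and config come from the command line and earlier setup.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    # Placeholder checkpoint for illustration; entry.py takes
    # model_path from its CLI arguments.
    model_path = "facebook/opt-125m"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # low_cpu_mem_usage=True makes transformers materialize weights
    # directly from the checkpoint rather than randomly initializing
    # them first, which is why the manual torch.nn.init monkey-patching
    # removed by this commit is no longer needed.
    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True, **kwargs)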
