
Commit

clean up
tsengalb99 committed Dec 19, 2023
1 parent 5cc9e04 commit f729d8a
Showing 3 changed files with 28 additions and 82 deletions.
1 change: 0 additions & 1 deletion gen_speed.py
@@ -14,7 +14,6 @@
 parser.add_argument('--batch_size', default=1, type=int)
 parser.add_argument('--seqlen', default=1, type=int)
 parser.add_argument('--samples', default=100, type=int)
-parser.add_argument('--max_tokens', default=400, type=int)
 parser.add_argument('--no_use_cuda_graph', action='store_true')
 parser.add_argument('--no_use_flash_attn', action='store_true')
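For reference, the remaining switches follow the standard argparse pattern: a store_true flag defaults to False, and the negated --no_* names presumably turn a feature off when passed. A minimal, self-contained sketch of that pattern (flag names taken from the hunk above; the surrounding script is illustrative, not gen_speed.py itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--seqlen', default=1, type=int)
parser.add_argument('--samples', default=100, type=int)
parser.add_argument('--no_use_cuda_graph', action='store_true')
parser.add_argument('--no_use_flash_attn', action='store_true')
args = parser.parse_args()

# store_true flags default to False, so both features stay enabled
# unless the corresponding --no_* switch is passed on the command line.
use_cuda_graph = not args.no_use_cuda_graph
use_flash_attn = not args.no_use_flash_attn
print(f'cuda_graph={use_cuda_graph} flash_attn={use_flash_attn} samples={args.samples}')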
103 changes: 27 additions & 76 deletions hfize_llama.py
@@ -48,10 +48,9 @@ def main(args):
     tokenizer = AutoTokenizer.from_pretrained(model_config._name_or_path)

     model_type = model_config.model_type
-    fused = model_config.quip_params.get('fused', True)
     model_config.quip_params['model_version'] = MODEL_VERSION

-    if model_type == 'llama' and fused:
+    if model_type == 'llama':
         model_cls = llama_fuse
     elif model_type == 'mistral':
         model_cls = MistralForCausalLM
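Here quip_params is a custom dict kept on the Hugging Face config object, and writing model_version into it presumably lets later loads tell which export format produced a checkpoint. A small sketch of stamping and reading such a field (MODEL_VERSION and the field names mirror the hunk; everything else is illustrative):

import transformers

MODEL_VERSION = 1  # illustrative value; the real constant is defined in the repository

# Any PretrainedConfig works for the sketch; LlamaConfig() builds a default one offline.
config = transformers.LlamaConfig()
# Extra attributes on a config are serialized by to_dict()/save_pretrained(),
# so custom metadata like this survives a save/load round trip.
config.quip_params = {'model_version': MODEL_VERSION}

# A later loader can check for the attribute before deciding how to unpack weights.
params = getattr(config, 'quip_params', None)
if params is not None and params.get('model_version') == MODEL_VERSION:
    print('checkpoint was exported with the current model version')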
@@ -69,80 +68,32 @@ def main(args):
         layer = model.model.layers[ii]
         cpu = torch.device('cpu')

-        if fused:
-            glog.info(f'loading layer {ii} qkv')
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_qkv.pt', map_location=cpu)
-            layer.self_attn.qkv_proj.fuse_scales[0].copy_(saved_layer['W_q_scale'])
-            layer.self_attn.qkv_proj.fuse_scales[1].copy_(saved_layer['W_k_scale'])
-            layer.self_attn.qkv_proj.fuse_scales[2].copy_(saved_layer['W_v_scale'])
-            layer.self_attn.qkv_proj.Wscale.copy_(saved_layer['Wscale'])
-            unpack_quip(layer.self_attn.qkv_proj, saved_layer, codebook_id, codesz)
-
-            glog.info(f'loading layer {ii} up')
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt', map_location=cpu)
-            layer.mlp.upgate_proj.fuse_scales[0].copy_(saved_layer['W_up_scale'])
-            layer.mlp.upgate_proj.fuse_scales[1].copy_(saved_layer['W_gate_scale'])
-            layer.mlp.upgate_proj.Wscale.copy_(saved_layer['Wscale'])
-            unpack_quip(layer.mlp.upgate_proj, saved_layer, codebook_id, codesz)
-
-            glog.info(f'loading layer {ii} o')
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt', map_location=cpu)
-            layer.self_attn.o_proj.Wscale.copy_(saved_layer['W_o_scale'] * saved_layer['Wscale'])
-            unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)
-
-            glog.info(f'loading layer {ii} down')
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt', map_location=cpu)
-            layer.mlp.down_proj.Wscale.copy_(saved_layer['W_down_scale'] * saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.mlp.down_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)
-
-        else:
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_q.pt', map_location=cpu)
-            layer.self_attn.q_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.self_attn.q_proj.ocs_dupe_inds.copy_(
-                    torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.self_attn.q_proj, saved_layer, codebook_id, codesz)
-
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_k.pt', map_location=cpu)
-            layer.self_attn.k_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.self_attn.k_proj.ocs_dupe_inds.copy_(
-                    torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.self_attn.k_proj, saved_layer, codebook_id, codesz)
-
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_v.pt', map_location=cpu)
-            layer.self_attn.v_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.self_attn.v_proj.ocs_dupe_inds.copy_(
-                    torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.self_attn.v_proj, saved_layer, codebook_id, codesz)
-
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt', map_location=cpu)
-            layer.self_attn.o_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.self_attn.o_proj.ocs_dupe_inds.copy_(
-                    torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)
-
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt', map_location=cpu)
-            layer.mlp.up_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.mlp.up_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.mlp.up_proj, saved_layer, codebook_id, codesz)
-
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_gate.pt', map_location=cpu)
-            layer.mlp.gate_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.mlp.gate_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.mlp.gate_proj, saved_layer, codebook_id, codesz)
-
-            saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt', map_location=cpu)
-            layer.mlp.down_scale.copy_(saved_layer['W_scale']*saved_layer['Wscale'])
-            if model_config.quip_params['outlier_channel_split']:
-                layer.mlp.down_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
-            unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)
+        glog.info(f'loading layer {ii} qkv')
+        saved_layer = torch.load(f'{args.quantized_path}/{ii}_qkv.pt', map_location=cpu)
+        layer.self_attn.qkv_proj.fuse_scales[0].copy_(saved_layer['W_q_scale'])
+        layer.self_attn.qkv_proj.fuse_scales[1].copy_(saved_layer['W_k_scale'])
+        layer.self_attn.qkv_proj.fuse_scales[2].copy_(saved_layer['W_v_scale'])
+        layer.self_attn.qkv_proj.Wscale.copy_(saved_layer['Wscale'])
+        unpack_quip(layer.self_attn.qkv_proj, saved_layer, codebook_id, codesz)
+
+        glog.info(f'loading layer {ii} up')
+        saved_layer = torch.load(f'{args.quantized_path}/{ii}_up.pt', map_location=cpu)
+        layer.mlp.upgate_proj.fuse_scales[0].copy_(saved_layer['W_up_scale'])
+        layer.mlp.upgate_proj.fuse_scales[1].copy_(saved_layer['W_gate_scale'])
+        layer.mlp.upgate_proj.Wscale.copy_(saved_layer['Wscale'])
+        unpack_quip(layer.mlp.upgate_proj, saved_layer, codebook_id, codesz)
+
+        glog.info(f'loading layer {ii} o')
+        saved_layer = torch.load(f'{args.quantized_path}/{ii}_o.pt', map_location=cpu)
+        layer.self_attn.o_proj.Wscale.copy_(saved_layer['W_o_scale'] * saved_layer['Wscale'])
+        unpack_quip(layer.self_attn.o_proj, saved_layer, codebook_id, codesz)
+
+        glog.info(f'loading layer {ii} down')
+        saved_layer = torch.load(f'{args.quantized_path}/{ii}_down.pt', map_location=cpu)
+        layer.mlp.down_proj.Wscale.copy_(saved_layer['W_down_scale'] * saved_layer['Wscale'])
+        if model_config.quip_params['outlier_channel_split']:
+            layer.mlp.down_proj.ocs_dupe_inds.copy_(torch.tensor(saved_layer['ocs_dupe_inds']))
+        unpack_quip(layer.mlp.down_proj, saved_layer, codebook_id, codesz)

     glog.info(f'saving model...')
     model.save_pretrained(args.hf_output_path, safe_serialization=True)
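The fused path that remains repeats one pattern per projection: torch.load the per-layer .pt file onto the CPU, copy each saved scale into the module's buffers with in-place copy_, then hand the packed weights to unpack_quip. A minimal sketch of the load-and-copy step, using an illustrative toy module rather than the repository's actual qkv_proj/upgate_proj classes:

import torch
import torch.nn as nn

class ToyFusedProj(nn.Module):
    # Illustrative stand-in for a fused projection module.
    def __init__(self, n_fused=3):
        super().__init__()
        # Registered buffers, so .copy_() updates them in place on load.
        self.register_buffer('fuse_scales', torch.ones(n_fused))
        self.register_buffer('Wscale', torch.ones(()))

def load_scales(module, path):
    # Same pattern as the hunk above: load on CPU, then copy scales in place.
    saved = torch.load(path, map_location=torch.device('cpu'))
    module.fuse_scales[0].copy_(saved['W_q_scale'])
    module.fuse_scales[1].copy_(saved['W_k_scale'])
    module.fuse_scales[2].copy_(saved['W_v_scale'])
    module.Wscale.copy_(saved['Wscale'])

Copying into registered buffers rather than reassigning attributes keeps the loaded values inside the module's state_dict, which is what a later save_pretrained call serializes.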
6 changes: 1 addition & 5 deletions lib/utils/unsafe_import.py
@@ -16,13 +16,9 @@ def maybe_wrap(use_cuda_graph):
     is_quantized = hasattr(bad_config, 'quip_params')
     model_type = bad_config.model_type
     if is_quantized:
-        fused = bad_config.quip_params.get('fused', True)
         if model_type == 'llama':
             model_str = transformers.LlamaConfig.from_pretrained(path)._name_or_path
-            if fused:
-                model_cls = llama_fuse
-            else:
-                raise Exception
+            model_cls = llama_fuse
         elif model_type == 'mistral':
             model_str = transformers.MistralConfig.from_pretrained(path)._name_or_path
             model_cls = MistralForCausalLM
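With the 'fused' flag gone, choosing the model class reduces to a plain dispatch on model_type. A small sketch of that kind of dispatch, with a hypothetical helper name and placeholder classes standing in for llama_fuse and MistralForCausalLM:

import transformers

class LlamaFusePlaceholder: ...
class MistralPlaceholder: ...

# Placeholder registry; the real code imports llama_fuse and
# MistralForCausalLM from the repository's model definitions.
_MODEL_CLS = {
    'llama': LlamaFusePlaceholder,
    'mistral': MistralPlaceholder,
}

def pick_model_cls(path):
    # Hypothetical helper: read the config, then dispatch on model_type.
    config = transformers.AutoConfig.from_pretrained(path)
    try:
        return _MODEL_CLS[config.model_type]
    except KeyError as e:
        raise ValueError(f'unsupported model type: {config.model_type}') from e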
