add layer GPU offloading for hidden/target states
kallewoof committed May 22, 2024
1 parent b428e23 commit 0ece2f3
Showing 2 changed files with 14 additions and 11 deletions.
22 changes: 12 additions & 10 deletions conversion/measure.py
@@ -382,7 +382,7 @@ def print_status_box(*content_lines):
     print('-' * box_width)
 
 @torch.inference_mode()
-def measure_quant(job, save_fn, model):
+def measure_quant(job, save_fn, model, hidden_state_offload_layers):
 
     # vars for status box
     time_spent_list = []
@@ -418,8 +418,9 @@ def measure_quant(job, save_fn, model):
 
     hidden_states = []
     with safe_open(states_filename, framework = "pt", device = "cpu") as f:
-        for k in sorted(f.keys()):
-            hidden_states.append(f.get_tensor(k))
+        for i, k in enumerate(sorted(f.keys())):
+            t = f.get_tensor(k)
+            hidden_states.append(t.to("cuda:0") if i < hidden_state_offload_layers else t)
 
     index = job["last_module_idx"]
     while True:
@@ -515,18 +515,19 @@ def measure_quant(job, save_fn, model):
 
             x = hidden_states[i].to("cuda:0")
             outputs = module.forward(x, cache, attn_params, intermediates = True)
+            target_device = "cuda:0" if i < hidden_state_offload_layers else "cpu"
 
             # Hessians
 
             if mode == "self_attn":
                 quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
                 quantizers["o_proj"].add_batch(outputs["attn_output"])
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "mlp":
                 quantizers["up_proj"].add_batch(outputs["post_norm"])  # Reuse H for gate_proj
                 quantizers["down_proj"].add_batch(outputs["pre_down"])
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "block_sparse_moe":
                 for j in range(model.config.num_experts):
@@ -537,19 +539,19 @@ def measure_quant(job, save_fn, model):
                             uncalibrated_experts[j] += 1
                     else:
                         uncalibrated_experts[j] += 1
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "parallel_decoder":
                 quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K, V, up_proj and gate_proj
                 quantizers["o_proj"].add_batch(outputs["attn_output"])
                 quantizers["down_proj"].add_batch(outputs["pre_down"])
                 hidden_states[i] = outputs["post_norm"]
-                target_states_attn.append(outputs["hidden_states_attn"].to("cpu"))
-                target_states_mlp.append(outputs["hidden_states_mlp"].to("cpu"))
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states_attn.append(outputs["hidden_states_attn"].to(target_device))
+                target_states_mlp.append(outputs["hidden_states_mlp"].to(target_device))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "pos_emb":
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             # For MoE layers, warn if any layer received less than 10% of a calibration batch
 
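The ternary on both sides of this diff is the same per-index device split: cached states whose index falls below hidden_state_offload_layers are kept on cuda:0, while the rest stay on CPU and are copied back by x = hidden_states[i].to("cuda:0") on each pass. Since Tensor.to is a no-op for a tensor already on the target device, resident states skip that host-to-device copy, which is where the speed-up comes from. A minimal sketch of the placement rule in isolation (the helper name, shapes, and counts below are illustrative, not from the repo):

import torch

def place_states(states, offload_layers, device = "cuda:0"):
    # Keep the first `offload_layers` tensors resident in VRAM; leave the
    # rest on CPU, mirroring the ternary used in measure_quant above.
    return [t.to(device) if i < offload_layers else t
            for i, t in enumerate(states)]

if __name__ == "__main__":
    # Eight fake cached states of shape (batch = 1, seq = 2048, dim = 4096).
    states = [torch.empty(1, 2048, 4096) for _ in range(8)]
    if torch.cuda.is_available():
        states = place_states(states, offload_layers = 3)
    print([t.device.type for t in states])  # first 3 "cuda" when available

At float32, each resident state of this illustrative size costs 1 * 2048 * 4096 * 4 bytes, about 32 MiB of VRAM per state, which is the trade the new flag exposes.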
3 changes: 2 additions & 1 deletion convert.py
@@ -29,6 +29,7 @@
 parser.add_argument("-l", "--length", type = int, default = 2048, help = "Max no. tokens per sample")
 parser.add_argument("-ml", "--measurement_length", type = int, default = 2048, help = "Max no. tokens per sample when measuring")
 parser.add_argument("-so", "--status_output", action = "store_true", help = "Include machine-parseable status updates in console output")
+parser.add_argument("-hsol", "--hidden_state_offload_layers", type = int, default = 0, help = "Number of hidden/target states to keep in VRAM. Speed-up but increases VRAM usage")
 
 args = parser.parse_args()
 
@@ -242,7 +243,7 @@ def save_job():
     model = ExLlamaV2(config)
     model.load(lazy = True)
 
-    status = measure_quant(job, save_job, model)  # capturing the graceful exits
+    status = measure_quant(job, save_job, model, args.hidden_state_offload_layers)  # capturing the graceful exits
     if status == "interrupted":
         print("Process interrupted. Exiting gracefully.")
         save_job()
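A hypothetical invocation with the new flag (the -i and -o input/output options are assumed from parts of the script not shown in this hunk):

python convert.py -i /path/to/model -o /path/to/workdir -hsol 16

With the default of 0, the comparison i < hidden_state_offload_layers is never true and every state goes to the CPU, preserving the old behavior; a value at least as large as the number of cached states keeps everything in VRAM.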
