Avoid unsupported init_module and ignore_modules combination
carmocca committed Oct 24, 2023
1 parent 9a23e73 commit d1c2a9f
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions tests/test_lora.py
@@ -376,14 +376,7 @@ def test_lora_merge_with_quantize():
         to_projection=True,
     )
     fabric = Fabric(devices=1, plugins=BitsandbytesPrecision("nf4", dtype=torch.bfloat16, ignore_modules={"lm_head"}))
-    with fabric.init_module(empty_init=False):
-        model = GPT(config)
-        model.apply(model._init_weights)
-
-    attn_proj = model.transformer.h[0].attn.proj
-    assert model.lm_head.linear.weight.dtype is torch.bfloat16
-    assert attn_proj.linear.weight.dtype is torch.bfloat16
-
+    model = GPT(config)
     mark_only_lora_as_trainable(model)
 
     from bitsandbytes.optim import PagedAdamW
@@ -393,10 +386,11 @@ def test_lora_merge_with_quantize():
 
     model.train()
 
+    attn_proj = model.transformer.h[0].attn.proj
     initial_weight = attn_proj.linear.weight.clone()
 
     # this was skipped
-    assert model.lm_head.linear.weight.dtype is torch.bfloat16
+    assert model.lm_head.linear.weight.dtype is torch.float32
     assert attn_proj.linear.weight.dtype is torch.uint8
 
     # perform an update to the LoRA weights
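For context, here is a minimal sketch of the pattern the test moves to, as suggested by the diff: the model is constructed outside of fabric.init_module, and quantization of everything except lm_head is left to the precision plugin when the model is set up. Only the Fabric/BitsandbytesPrecision usage and the plain GPT(config) construction are taken from the hunks above; the lit_gpt.lora import path, the Config field values, and the fabric.setup call are assumptions for illustration.

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision
from lit_gpt.lora import GPT, Config, mark_only_lora_as_trainable  # assumed import path

# NF4-quantize the model's Linear layers, but leave lm_head in full precision.
fabric = Fabric(devices=1, plugins=BitsandbytesPrecision("nf4", dtype=torch.bfloat16, ignore_modules={"lm_head"}))

# Illustrative LoRA config values; the test's actual values are not fully visible in the diff.
config = Config(block_size=8, vocab_size=8, n_layer=1, n_head=2, n_embd=8, r=8, alpha=8, to_projection=True)

# Unsupported combination (the pattern this commit removes): materializing the
# model under init_module while ignore_modules is set on the precision plugin.
# with fabric.init_module(empty_init=False):
#     model = GPT(config)

# Supported: instantiate the model normally and mark only the LoRA weights trainable.
model = GPT(config)
mark_only_lora_as_trainable(model)

# Assumed to happen later in the test: the plugin converts the eligible layers
# (everything except lm_head) to quantized storage when the model is set up.
model = fabric.setup(model)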
