exposed scale to pti
cloneofsimo committed Jan 31, 2023
1 parent 22a8e9b commit 477f29b
Showing 2 changed files with 4 additions and 0 deletions.
lora_diffusion/cli_lora_pti.py: 2 additions, 0 deletions
@@ -530,6 +530,7 @@ def train(
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
     lora_dropout_p: float = 0.0,
+    lora_scale: float = 1.0,
     use_extended_lora: bool = False,
     clip_ti_decay: bool = True,
     learning_rate_unet: float = 1e-4,
@@ -729,6 +730,7 @@ def train(
             r=lora_rank,
             target_replace_module=lora_unet_target_modules,
             dropout_p=lora_dropout_p,
+            scale=lora_scale,
         )
     else:
         print("PTI : USING EXTENDED UNET!!!")
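The change is a straight pass-through: train() gains a lora_scale keyword and forwards it to inject_trainable_lora as scale, and the default of 1.0 leaves existing behaviour unchanged. A minimal usage sketch follows; the path arguments and other hyperparameters are assumptions about the usual PTI interface, not part of this commit, and only lora_scale is what it adds.

from lora_diffusion.cli_lora_pti import train

# Hypothetical invocation: the paths and most keywords below are placeholders
# for illustration only; lora_scale is the keyword exposed by this commit.
train(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # assumed base model
    instance_data_dir="./data/my_concept",                           # assumed dataset path
    output_dir="./out/my_concept_pti",                               # assumed output path
    lora_rank=4,
    lora_dropout_p=0.1,
    lora_scale=0.5,  # forwarded to inject_trainable_lora(scale=...)
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
)
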
lora_diffusion/lora.py: 2 additions, 0 deletions
@@ -259,6 +259,7 @@ def inject_trainable_lora(
     loras=None,  # path to lora .pt
     verbose: bool = False,
     dropout_p: float = 0.0,
+    scale: float = 1.0,
 ):
     """
     inject lora into model, and returns lora parameter groups.
@@ -284,6 +285,7 @@
             _child_module.bias is not None,
             r=r,
             dropout_p=dropout_p,
+            scale=scale,
         )
         _tmp.linear.weight = weight
         if bias is not None:
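Both new keywords default to 1.0, so existing callers of inject_trainable_lora are unaffected. The scale value is handed to the injected replacement module (the _tmp = ... constructor above), where it conventionally multiplies the low-rank branch of the forward pass. The class below is a self-contained sketch of that convention under the name ToyLoraLinear; it is an assumption about how the library's injected linear layer behaves, not a copy of it.

import torch
import torch.nn as nn

class ToyLoraLinear(nn.Module):
    # Stand-in for the injected module: a base Linear plus a rank-r update
    # whose contribution is multiplied by `scale`.
    def __init__(self, in_features, out_features, bias=True, r=4, dropout_p=0.0, scale=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.scale = scale
        nn.init.zeros_(self.lora_up.weight)  # LoRA branch starts as a no-op

    def forward(self, x):
        # Base projection plus the scaled low-rank correction.
        return self.linear(x) + self.dropout(self.lora_up(self.lora_down(x))) * self.scale

layer = ToyLoraLinear(768, 768, r=4, dropout_p=0.0, scale=0.5)
print(layer(torch.randn(2, 768)).shape)  # torch.Size([2, 768])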
