Commit b1d8293

Merge pull request cloneofsimo#164 from cloneofsimo/develop
v0.1.4
2 parents: 437cb62 + 82ba343

4 files changed: +74 -7 lines changed

README.md (+3, -3)

@@ -50,19 +50,19 @@
 
 # UPDATES & Notes
 
-### 2022/02/01
+### 2023/02/01
 
 - LoRA Joining is now available with `--mode=ljl` flag. Only three parameters are required : `path_to_lora1`, `path_to_lora2`, and `path_to_save`.
 
-### 2022/01/29
+### 2023/01/29
 
 - Dataset pipelines
 - LoRA Applied to Resnet as well, use `--use_extended_lora` to use it.
 - SVD distillation now supports resnet-lora as well.
 - Compvis format Conversion script now works with safetensors, and will for PTI it will return Textual inversion format as well, so you can use it in embeddings folder.
 - 🥳🥳, LoRA is now officially integrated into the amazing Huggingface 🤗 `diffusers` library! Check out the [Blog](https://huggingface.co/blog/lora) and [examples](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora)! (NOTE : It is CURRENTLY DIFFERENT FILE FORMAT)
 
-### 2022/01/09
+### 2023/01/09
 
 - Pivotal Tuning Inversion with extended latent
 - Better textual inversion with Norm prior
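
The `diffusers` integration mentioned in the changelog above uses its own LoRA weight format, which is why the README flags it as "CURRENTLY DIFFERENT FILE FORMAT". Purely as a point of reference (nothing below is part of this commit), a LoRA produced by the linked diffusers `text_to_image` example is typically attached through the UNet's attention-processor loader. The model ID, LoRA path, and prompt are placeholders; this is a sketch of the diffusers-side API of that era, not this repository's `.safetensors` loaders.

```python
# Sketch: loading a *diffusers-format* LoRA, not a lora_diffusion .safetensors file.
# The model ID, LoRA path, and prompt are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Attach LoRA attention-processor weights saved by the diffusers text_to_image example.
pipe.unet.load_attn_procs("path/to/diffusers_lora")

image = pipe("a photo of a person, digital art", num_inference_steps=50).images[0]
image.save("lora_sample.png")
```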

lora_diffusion/cli_lora_pti.py (+69, -3)

@@ -423,6 +423,12 @@ def perform_tuning(
     lora_unet_target_modules,
     lora_clip_target_modules,
     mask_temperature,
+    out_name: str,
+    tokenizer,
+    test_image_path: str,
+    log_wandb: bool = False,
+    wandb_log_prompt_cnt: int = 10,
+    class_token: str = "person",
 ):
 
     progress_bar = tqdm(range(num_steps))
@@ -434,6 +440,11 @@ def perform_tuning(
     unet.train()
     text_encoder.train()
 
+    if log_wandb:
+        preped_clip = prepare_clip_model_sets()
+
+    loss_sum = 0.0
+
     for epoch in range(math.ceil(num_steps / len(dataloader))):
         for batch in dataloader:
             lr_scheduler_lora.step()
@@ -450,6 +461,8 @@ def perform_tuning(
                 mixed_precision=True,
                 mask_temperature=mask_temperature,
             )
+            loss_sum += loss.detach().item()
+
             loss.backward()
             torch.nn.utils.clip_grad_norm_(
                 itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -493,15 +506,59 @@ def perform_tuning(
 
                 print("LORA CLIP Moved", moved)
 
+                if log_wandb:
+                    with torch.no_grad():
+                        pipe = StableDiffusionPipeline(
+                            vae=vae,
+                            text_encoder=text_encoder,
+                            tokenizer=tokenizer,
+                            unet=unet,
+                            scheduler=scheduler,
+                            safety_checker=None,
+                            feature_extractor=None,
+                        )
+
+                        # open all images in test_image_path
+                        images = []
+                        for file in os.listdir(test_image_path):
+                            if file.endswith(".png") or file.endswith(".jpg"):
+                                images.append(
+                                    Image.open(os.path.join(test_image_path, file))
+                                )
+
+                        wandb.log({"loss": loss_sum / save_steps})
+                        loss_sum = 0.0
+                        wandb.log(
+                            evaluate_pipe(
+                                pipe,
+                                target_images=images,
+                                class_token=class_token,
+                                learnt_token="".join(placeholder_tokens),
+                                n_test=wandb_log_prompt_cnt,
+                                n_step=50,
+                                clip_model_sets=preped_clip,
+                            )
+                        )
+
             if global_step >= num_steps:
-                return
+                break
+
+    save_all(
+        unet,
+        text_encoder,
+        placeholder_token_ids=placeholder_token_ids,
+        placeholder_tokens=placeholder_tokens,
+        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
+        target_replace_module_text=lora_clip_target_modules,
+        target_replace_module_unet=lora_unet_target_modules,
+    )
 
 
 def train(
     instance_data_dir: str,
     pretrained_model_name_or_path: str,
     output_dir: str,
-    train_text_encoder: bool = False,
+    train_text_encoder: bool = True,
     pretrained_vae_name_or_path: str = None,
     revision: Optional[str] = None,
     class_data_dir: Optional[str] = None,
@@ -555,7 +612,9 @@ def train(
     wandb_log_prompt_cnt: int = 10,
     wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
+    proxy_token: str = "person",
     enable_xformers_memory_efficient_attention: bool = False,
+    out_name: str = "final_lora",
 ):
     torch.manual_seed(seed)
 
@@ -566,7 +625,6 @@ def train(
             name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
             reinit=True,
             config={
-                "lr": learning_rate_ti,
                 **(extra_args if extra_args is not None else {}),
             },
         )
@@ -594,6 +652,8 @@ def train(
         placeholder_tokens
     ), "Unequal Initializer token for Placeholder tokens."
 
+    if proxy_token is not None:
+        class_token = proxy_token
     class_token = "".join(initializer_tokens)
 
     if placeholder_token_at_data is not None:
@@ -817,6 +877,12 @@ def train(
         lora_unet_target_modules=lora_unet_target_modules,
         lora_clip_target_modules=lora_clip_target_modules,
         mask_temperature=mask_temperature,
+        tokenizer=tokenizer,
+        out_name=out_name,
+        test_image_path=instance_data_dir,
+        log_wandb=log_wandb,
+        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
+        class_token=class_token,
     )
 
 
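
All of the new `perform_tuning` arguments are threaded through from `train`, so a caller only touches the `train` signature: `out_name` sets the filename of the final `save_all` checkpoint, `log_wandb` gates the averaged-loss logging and the `evaluate_pipe` run on the same cadence as the periodic `save_steps` checkpoints, and `proxy_token` is intended to feed the `class_token` used in the evaluation prompts. Below is a minimal sketch of driving the updated entry point from Python; the paths and model ID are placeholders, and every parameter not visible in this diff keeps whatever default the full file defines.

```python
# Sketch only: exercises the parameters added or changed in this commit.
# Paths and model ID are placeholders; parameters not shown in the diff keep their defaults.
from lora_diffusion.cli_lora_pti import train

train(
    instance_data_dir="./data/person",  # also passed to perform_tuning as test_image_path
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    output_dir="./output",
    train_text_encoder=True,            # new default in this commit
    log_wandb=True,                     # enables loss averaging + evaluate_pipe logging
    wandb_project_name="new_pti_project",
    wandb_entity="new_pti_entity",
    wandb_log_prompt_cnt=10,
    proxy_token="person",               # intended as the class_token for evaluation prompts
    out_name="final_lora",              # final weights saved as f"{out_name}.safetensors"
)
```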

lora_diffusion/dataset.py (+1)

@@ -223,6 +223,7 @@ def __init__(
                 transforms.ColorJitter(0.1, 0.1)
                 if color_jitter
                 else transforms.Lambda(lambda x: x),
+                transforms.CenterCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
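
The only dataset change is a `transforms.CenterCrop(size)` inserted between the optional color jitter and the tensor conversion, so every training image is center-cropped to the target resolution before normalization. A self-contained sketch of the resulting transform stack follows; the leading `Resize` step and the `size`/`color_jitter` arguments come from the surrounding constructor rather than this diff, so treat them as assumptions.

```python
# Sketch of the image transform stack after this commit.
# `size` and `color_jitter` mirror the dataset constructor arguments;
# the Resize step is assumed from the surrounding (unshown) code.
from torchvision import transforms

def build_image_transforms(size: int = 512, color_jitter: bool = False):
    return transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(0.1, 0.1)
            if color_jitter
            else transforms.Lambda(lambda x: x),
            transforms.CenterCrop(size),  # added in this commit
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
```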

setup.py (+1, -1)

@@ -6,7 +6,7 @@
 setup(
     name="lora_diffusion",
     py_modules=["lora_diffusion"],
-    version="0.1.3",
+    version="0.1.4",
     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
     author="Simo Ryu",
     packages=find_packages(),
