
Commit

Merge pull request AUTOMATIC1111#3197 from AUTOMATIC1111/training-help-text

Training UI Changes
AUTOMATIC1111 authored Oct 21, 2022
2 parents 2273e75 + 0c5522e commit e487772
Showing 6 changed files with 50 additions and 23 deletions.
2 changes: 1 addition & 1 deletion modules/hypernetworks/hypernetwork.py
@@ -396,7 +396,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 Loss: {mean_loss:.7f}<br/>
 Step: {hypernetwork.step}<br/>
 Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
5 changes: 3 additions & 2 deletions modules/hypernetworks/ui.py
@@ -10,9 +10,10 @@
 from modules.hypernetworks import hypernetwork


-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"

     if type(layer_structure) == str:
         layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
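The guard above is the core of the new overwrite option: the existence assertion now runs only when overwriting is disabled, and the same pattern is applied to create_embedding in modules/textual_inversion/textual_inversion.py below. A minimal standalone sketch of the behavior (the function and file names here are illustrative, not part of the repository):

import os

def create_checkpoint_file(path, overwrite_old=False):
    # Mirror of the new guard: refuse to clobber an existing file unless explicitly allowed.
    if not overwrite_old:
        assert not os.path.exists(path), f"file {path} already exists"
    with open(path, "wb") as f:
        f.write(b"")  # placeholder payload

create_checkpoint_file("example.pt")                      # raises AssertionError if example.pt already exists
create_checkpoint_file("example.pt", overwrite_old=True)  # silently replaces example.pt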
40 changes: 28 additions & 12 deletions modules/textual_inversion/preprocess.py
@@ -11,7 +11,7 @@
 import modules.deepbooru as deepbooru


-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     try:
         if process_caption:
             shared.interrogator.load()
@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
             db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
             deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)

-        preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)

     finally:

@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_



-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)
@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)

-    def save_pic_with_caption(image, index):
+    def save_pic_with_caption(image, index, existing_caption=None):
         caption = ""

         if process_caption:
@@ -66,17 +66,26 @@ def save_pic_with_caption(image, index):
         basename = f"{index:05}-{subindex[0]}-{filename_part}"
         image.save(os.path.join(dst, f"{basename}.png"))

+        if preprocess_txt_action == 'prepend' and existing_caption:
+            caption = existing_caption + ' ' + caption
+        elif preprocess_txt_action == 'append' and existing_caption:
+            caption = caption + ' ' + existing_caption
+        elif preprocess_txt_action == 'copy' and existing_caption:
+            caption = existing_caption
+
+        caption = caption.strip()
+
         if len(caption) > 0:
             with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
                 file.write(caption)

         subindex[0] += 1

-    def save_pic(image, index):
-        save_pic_with_caption(image, index)
+    def save_pic(image, index, existing_caption=None):
+        save_pic_with_caption(image, index, existing_caption=existing_caption)

         if process_flip:
-            save_pic_with_caption(ImageOps.mirror(image), index)
+            save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)

     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]
@@ -86,6 +95,13 @@ def save_pic(image, index):
         except Exception:
             continue

+        existing_caption = None
+
+        try:
+            existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
+        except Exception as e:
+            print(e)
+
         if shared.state.interrupted:
             break

@@ -97,20 +113,20 @@ def save_pic(image, index):
             img = img.resize((width, height * img.height // img.width))

             top = img.crop((0, 0, width, height))
-            save_pic(top, index)
+            save_pic(top, index, existing_caption=existing_caption)

             bot = img.crop((0, img.height - height, width, img.height))
-            save_pic(bot, index)
+            save_pic(bot, index, existing_caption=existing_caption)
         elif process_split and is_wide:
             img = img.resize((width * img.width // img.height, height))

             left = img.crop((0, 0, width, height))
-            save_pic(left, index)
+            save_pic(left, index, existing_caption=existing_caption)

             right = img.crop((img.width - width, 0, img.width, height))
-            save_pic(right, index)
+            save_pic(right, index, existing_caption=existing_caption)
         else:
             img = images.resize_image(1, img, width, height)
-            save_pic(img, index)
+            save_pic(img, index, existing_caption=existing_caption)

         shared.state.nextjob()
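The new prepend/append/copy branches decide how a caption generated during preprocessing is combined with a caption read from a .txt file that already sits next to the source image; 'ignore', the UI default, keeps only the newly generated caption. A standalone sketch of that decision (the helper name is illustrative, not from the repository):

def combine_captions(generated, existing, preprocess_txt_action):
    # existing is None when no .txt file could be read for the image.
    caption = generated
    if preprocess_txt_action == 'prepend' and existing:
        caption = existing + ' ' + caption
    elif preprocess_txt_action == 'append' and existing:
        caption = caption + ' ' + existing
    elif preprocess_txt_action == 'copy' and existing:
        caption = existing
    return caption.strip()

print(combine_captions("a photo of a cat", "tabby, indoors", "append"))   # a photo of a cat tabby, indoors
print(combine_captions("a photo of a cat", "tabby, indoors", "copy"))     # tabby, indoors
print(combine_captions("a photo of a cat", None, "prepend"))              # a photo of a cat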
5 changes: 3 additions & 2 deletions modules/textual_inversion/textual_inversion.py
@@ -153,7 +153,7 @@ def find_embedding_at_position(self, tokens, offset):
         return None, None


-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"

     embedding = Embedding(vec, name)
     embedding.step = 0
4 changes: 2 additions & 2 deletions modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@
 from modules import sd_hijack, shared


-def create_embedding(name, initialization_text, nvpt):
-    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)

     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
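Note that the wrapper now forwards overwrite_old positionally, ahead of the init_text keyword argument. Assuming the webui modules are importable and a model is loaded, a direct call to the lower-level function now has this shape (the embedding name and vector count are made-up values):

from modules.textual_inversion import textual_inversion

filename = textual_inversion.create_embedding(
    "my-style",     # name
    4,              # num_vectors_per_token
    True,           # overwrite_old: replace my-style.pt if it already exists
    init_text="*",
)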
17 changes: 13 additions & 4 deletions modules/ui.py
@@ -1211,6 +1211,7 @@ def refresh():
 new_embedding_name = gr.Textbox(label="Name")
 initialization_text = gr.Textbox(label="Initialization text", value="*")
 nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")

 with gr.Row():
 with gr.Column(scale=3):
@@ -1224,6 +1225,7 @@ def refresh():
 new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
 new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
 new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
 new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])

 with gr.Row():
@@ -1238,6 +1240,7 @@ def refresh():
 process_dst = gr.Textbox(label='Destination directory')
 process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
 process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])

 with gr.Row():
 process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1253,14 +1256,17 @@ def refresh():
 run_preprocess = gr.Button(value="Preprocess", variant='primary')

 with gr.Tab(label="Train"):
-gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
 with gr.Row():
 train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
 create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
 with gr.Row():
 train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
 create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
-learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+with gr.Row():
+embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
 batch_size = gr.Number(label='Batch size', value=1, precision=0)
 dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
 log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1294,6 +1300,7 @@ def refresh():
 new_embedding_name,
 initialization_text,
 nvpt,
+overwrite_old_embedding,
 ],
 outputs=[
 train_embedding_name,
@@ -1307,6 +1314,7 @@ def refresh():
 inputs=[
 new_hypernetwork_name,
 new_hypernetwork_sizes,
+overwrite_old_hypernetwork,
 new_hypernetwork_layer_structure,
 new_hypernetwork_add_layer_norm,
 new_hypernetwork_activation_func,
@@ -1326,6 +1334,7 @@ def refresh():
 process_dst,
 process_width,
 process_height,
+preprocess_txt_action,
 process_flip,
 process_split,
 process_caption,
@@ -1342,7 +1351,7 @@ def refresh():
 _js="start_training_textual_inversion",
 inputs=[
 train_embedding_name,
-learn_rate,
+embedding_learn_rate,
 batch_size,
 dataset_directory,
 log_directory,
@@ -1367,7 +1376,7 @@ def refresh():
 _js="start_training_textual_inversion",
 inputs=[
 train_hypernetwork_name,
-learn_rate,
+hypernetwork_learn_rate,
 batch_size,
 dataset_directory,
 log_directory,
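Gradio passes the current values of the components listed in inputs to the click callback positionally, which is why each new control above (the overwrite checkboxes, the caption-action dropdown, the split learning-rate textboxes) is inserted into its handler's inputs list at the position the target function expects. A minimal, self-contained sketch of that wiring pattern (the component names and callback below are illustrative, not taken from the repository):

import gradio as gr

def create_thing(name, overwrite_old):
    # Stand-in callback: receives the textbox value first, then the checkbox value,
    # matching the order of the inputs list below.
    return f"would create {name!r} (overwrite_old={overwrite_old})"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    overwrite_old = gr.Checkbox(value=False, label="Overwrite Old")
    create = gr.Button(value="Create")
    status = gr.HTML(value="")

    create.click(
        fn=create_thing,
        inputs=[name, overwrite_old],  # positional order must match create_thing's parameters
        outputs=[status],
    )

# demo.launch()  # uncomment to serve the UI locally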
