
Commit

fixes related to merge
AUTOMATIC1111 committed Oct 11, 2022
1 parent 5de8061 commit 530103b
Showing 9 changed files with 78 additions and 160 deletions.
103 changes: 0 additions & 103 deletions modules/hypernetwork.py

This file was deleted.
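(This commit removes the leftover top-level module in favor of the modules/hypernetwork package: the import fixes in sd_hijack_optimizations.py, shared.py, and xy_grid.py below all switch from "from modules import hypernetwork" to "from modules.hypernetwork import hypernetwork".)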

74 changes: 45 additions & 29 deletions modules/hypernetwork/hypernetwork.py
@@ -26,10 +26,11 @@ def __init__(self, dim, state_dict=None):
         if state_dict is not None:
             self.load_state_dict(state_dict, strict=True)
         else:
-            self.linear1.weight.data.fill_(0.0001)
-            self.linear1.bias.data.fill_(0.0001)
-            self.linear2.weight.data.fill_(0.0001)
-            self.linear2.bias.data.fill_(0.0001)
+            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+            self.linear1.bias.data.zero_()
+            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+            self.linear2.bias.data.zero_()
 
         self.to(devices.device)

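Why the new initialization matters: a constant fill_(0.0001) gives every neuron in a layer identical weights, so they receive identical gradients and can never differentiate during training; small Gaussian weights break that symmetry, while zero biases keep a freshly created hypernetwork close to a no-op. A minimal standalone sketch (the layer shapes and the residual forward are assumptions taken from the surrounding file, not from this hunk; device placement is omitted):

import torch

class HypernetworkModule(torch.nn.Module):
    def __init__(self, dim, state_dict=None):
        super().__init__()
        # layer shapes assumed from the surrounding file: expand 2x, project back
        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
        else:
            # new init: small random weights and zero biases instead of a
            # constant fill that left all neurons in a layer identical
            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
            self.linear1.bias.data.zero_()
            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
            self.linear2.bias.data.zero_()

    def forward(self, x):
        # residual form (an assumption from the file): near-zero layers make
        # a freshly created hypernetwork start as a near-identity mapping
        return x + self.linear2(self.linear1(x))

m = HypernetworkModule(768)
x = torch.randn(1, 77, 768)
print(m(x).shape)  # torch.Size([1, 77, 768]); output stays close to x at init
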
@@ -92,41 +93,54 @@ def load(self, filename):
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
 
 
-def load_hypernetworks(path):
+def list_hypernetworks(path):
     res = {}
+    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+        name = os.path.splitext(os.path.basename(filename))[0]
+        res[name] = filename
+    return res
 
-    for filename in glob.iglob(path + '**/*.pt', recursive=True):
-        try:
-            hn = Hypernetwork()
-            hn.load(filename)
-            res[hn.name] = hn
-        except Exception:
-            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
 
-    return res
+def load_hypernetwork(filename):
+    path = shared.hypernetworks.get(filename, None)
+    if path is not None:
+        print(f"Loading hypernetwork {filename}")
+        try:
+            shared.loaded_hypernetwork = Hypernetwork()
+            shared.loaded_hypernetwork.load(path)
+
+        except Exception:
+            print(f"Error loading hypernetwork {path}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+    else:
+        if shared.loaded_hypernetwork is not None:
+            print(f"Unloading hypernetwork")
+
+        shared.loaded_hypernetwork = None
 
 
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
-    h = self.heads
+def apply_hypernetwork(hypernetwork, context, layer=None):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
 
-    q = self.to_q(x)
-    context = default(context, x)
+    if hypernetwork_layers is None:
+        return context, context
 
-    hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
+    if layer is not None:
+        layer.hyper_k = hypernetwork_layers[0]
+        layer.hyper_v = hypernetwork_layers[1]
 
-    if hypernetwork_layers is not None:
-        hypernetwork_k, hypernetwork_v = hypernetwork_layers
+    context_k = hypernetwork_layers[0](context)
+    context_v = hypernetwork_layers[1](context)
+    return context_k, context_v
 
-        self.hypernetwork_k = hypernetwork_k
-        self.hypernetwork_v = hypernetwork_v
 
-        context_k = hypernetwork_k(context)
-        context_v = hypernetwork_v(context)
-    else:
-        context_k = context
-        context_v = context
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
     k = self.to_k(context_k)
     v = self.to_v(context_v)
@@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
 def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
     assert hypernetwork_name, 'embedding not selected'
 
-    shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
+    path = shared.hypernetworks.get(hypernetwork_name, None)
+    shared.loaded_hypernetwork = Hypernetwork()
+    shared.loaded_hypernetwork.load(path)
 
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
@@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
 
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
 
-    hypernetwork = shared.hypernetworks[hypernetwork_name]
+    hypernetwork = shared.loaded_hypernetwork
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
@@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     if ititial_step > steps:
         return hypernetwork, filename
 
-    pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, (x, text) in pbar:
         hypernetwork.step = i + ititial_step

10 changes: 5 additions & 5 deletions modules/hypernetwork/ui.py
@@ -6,24 +6,24 @@
 import modules.textual_inversion.textual_inversion
 import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
+from modules.hypernetwork import hypernetwork
 
 
 def create_hypernetwork(name):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     assert not os.path.exists(fn), f"file {fn} already exists"
 
-    hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
-    hypernetwork.save(fn)
+    hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+    hypernet.save(fn)
 
     shared.reload_hypernetworks()
-    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
 
     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
 
 
 def train_hypernetwork(*args):
 
-    initial_hypernetwork = shared.hypernetwork
+    initial_hypernetwork = shared.loaded_hypernetwork
@@ -38,6 +38,6 @@ def train_hypernetwork(*args):
     except Exception:
         raise
     finally:
-        shared.hypernetwork = initial_hypernetwork
+        shared.loaded_hypernetwork = initial_hypernetwork
         sd_hijack.apply_optimizations()

3 changes: 2 additions & 1 deletion modules/sd_hijack_optimizations.py
@@ -8,7 +8,8 @@
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, hypernetwork
+from modules import shared
+from modules.hypernetwork import hypernetwork
 
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:

13 changes: 11 additions & 2 deletions modules/shared.py
@@ -13,7 +13,8 @@
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetwork import hypernetwork
 from modules.paths import models_path, script_path, sd_path
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -29,6 +30,7 @@
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
 parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
 parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
 parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
 parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -82,10 +84,17 @@
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file
 
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+loaded_hypernetwork = None
+
+
+def reload_hypernetworks():
+    global hypernetworks
+
+    hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+    hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
 
 
 class State:
     skipped = False
     interrupted = False
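
The new --hypernetwork-dir flag makes the index location configurable (for example, launching the webui with --hypernetwork-dir /data/hypernetworks; the path is illustrative), and reload_hypernetworks gives callers such as create_hypernetwork in ui.py above a single entry point that re-scans that directory and re-applies whichever network opts.sd_hypernetwork names.
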
12 changes: 7 additions & 5 deletions modules/textual_inversion/textual_inversion.py
@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
             last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
 
+            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
-                prompt=text,
+                prompt=preview_text,
                 steps=20,
-                height=training_height,
-                width=training_width,
+                height=training_height,
+                width=training_width,
                 do_not_save_grid=True,
                 do_not_save_samples=True,
             )
@@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             shared.state.current_image = image
             image.save(last_saved_image)
 
-            last_saved_image += f", prompt: {text}"
+            last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = embedding.step
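
The preview_image_prompt plumbing added here means an empty field keeps the old behavior (previews use the caption of the current training sample), while a non-empty field pins every preview image to one fixed prompt, which makes training progress easier to compare across saves. A minimal sketch of the fallback (the helper name is hypothetical; the expression mirrors the inline one in train_embedding):

def pick_preview_text(text: str, preview_image_prompt: str) -> str:
    # empty preview prompt -> fall back to the current dataset caption
    return text if preview_image_prompt == "" else preview_image_prompt

assert pick_preview_text("a photo of X", "") == "a photo of X"
assert pick_preview_text("a photo of X", "portrait of X") == "portrait of X"
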
5 changes: 3 additions & 2 deletions modules/ui.py
@@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call):
                             gr.HTML(value="")
 
                         with gr.Column():
-                            create_embedding = gr.Button(value="Create", variant='primary')
+                            create_embedding = gr.Button(value="Create embedding", variant='primary')
 
                 with gr.Group():
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
@@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call):
                             gr.HTML(value="")
 
                         with gr.Column():
-                            create_hypernetwork = gr.Button(value="Create", variant='primary')
+                            create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
 
                 with gr.Group():
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
@@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call):
                 create_image_every,
                 save_embedding_every,
                 template_file,
+                preview_image_prompt,
             ],
             outputs=[
                 ti_output,

3 changes: 2 additions & 1 deletion scripts/xy_grid.py
@@ -10,7 +10,8 @@
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, hypernetwork
+from modules import images
+from modules.hypernetwork import hypernetwork
 from modules.processing import process_images, Processed, get_correct_sampler
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared

[diff for the ninth changed file did not load]
