From 9615a3c8551c9858d3f44f24b56ff2c382eed79b Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 22 Jul 2023 07:05:37 -0400 Subject: [PATCH] fix AND linebreaks without replacing newlines with normal space (#33) * fix and linebreaks * fix token counter * comment * renamee --------- Co-authored-by: ljleb --- lib_neutral_prompt/prompt_parser_hijack.py | 16 ++++++++++------ lib_neutral_prompt/ui.py | 1 + scripts/main.py | 7 ------- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/lib_neutral_prompt/prompt_parser_hijack.py b/lib_neutral_prompt/prompt_parser_hijack.py index 0ec0d5e..a7a5352 100644 --- a/lib_neutral_prompt/prompt_parser_hijack.py +++ b/lib_neutral_prompt/prompt_parser_hijack.py @@ -12,13 +12,18 @@ ) -@prompt_parser_hijacker.hijack('get_multicond_learned_conditioning') -def get_multicond_learned_conditioning_hijack(model, prompts, steps, original_function): +# the only difference with the original `re_weight` is the prefix `\s*` in `^(\s*.*?)` (originally `^(.*?)`) +# this makes it possible to line break AND prompts without replacing all newlines with normal spaces +prompt_parser.re_weight = re.compile(r"^(\s*.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") + + +@prompt_parser_hijacker.hijack('get_multicond_prompt_list') +def get_multicond_prompt_list_hijack(prompts, original_function): if not global_state.is_enabled: - return original_function(model, prompts, steps) + return original_function(prompts) global_state.prompt_exprs = parse_prompts(prompts) - return original_function(model, transpile_exprs(global_state.prompt_exprs), steps) + return original_function(transpile_exprs(global_state.prompt_exprs)) def parse_prompts(prompts: List[str]) -> neutral_prompt_parser.PromptExpr: @@ -40,8 +45,7 @@ def transpile_exprs(exprs: neutral_prompt_parser.PromptExpr): class WebuiPromptVisitor: def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str: - prompt = re.sub(r'\s+', ' ', that.prompt).strip() - return f'{prompt} :{that.weight}' + return f'{that.prompt} :{that.weight}' def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str: return ' AND '.join(child.accept(self) for child in that.children) diff --git a/lib_neutral_prompt/ui.py b/lib_neutral_prompt/ui.py index ea755a4..a7fb288 100644 --- a/lib_neutral_prompt/ui.py +++ b/lib_neutral_prompt/ui.py @@ -85,6 +85,7 @@ def unpack_processing_args( def on_ui_settings(): section = ('neutral_prompt', 'Neutral Prompt') shared.opts.add_option('neutral_prompt_enabled', shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section)) + global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True) shared.opts.add_option('neutral_prompt_verbose', shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section)) shared.opts.onchange('neutral_prompt_verbose', update_verbose) diff --git a/scripts/main.py b/scripts/main.py index 2c3789f..8aeb697 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,11 +1,4 @@ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui -import importlib -importlib.reload(global_state) -importlib.reload(hijacker) -importlib.reload(neutral_prompt_parser) -importlib.reload(prompt_parser_hijack) -importlib.reload(cfg_denoiser_hijack) -importlib.reload(ui) from modules import scripts, processing, shared from typing import Dict import functools