From fbeef45aa2db8e1fe793521c5dc5863088284aee Mon Sep 17 00:00:00 2001 From: ClownsharkBatwing Date: Sun, 15 Dec 2024 21:51:29 -0500 Subject: [PATCH 1/2] Add files via upload Bugfix for regional differential conditioning. --- conditioning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conditioning.py b/conditioning.py index 2bb5909..d226ac4 100644 --- a/conditioning.py +++ b/conditioning.py @@ -446,7 +446,7 @@ def __call__(self, transformer_options, threshold, threshold_inv=False, dtype=to if self.mask_type == "gradient": return self.mask.clone().to(sigma.device).to(dtype) elif self.mask_type == "differential": - return self.mask.clone().to(sigma.device) >= threshold + return self.mask.clone().to(sigma.device) > threshold if threshold_inv==False: mask_tmp = self.mask.to(dtype).clone().to(sigma.device) >= threshold return mask_tmp From 3d950a666428c8d2c73081a2381d9d1c79f716c6 Mon Sep 17 00:00:00 2001 From: ClownsharkBatwing Date: Sun, 15 Dec 2024 23:56:26 -0500 Subject: [PATCH 2/2] Add files via upload --- flux/model.py | 25 +++++++++++++++++++++---- rk_sampler.py | 5 +++++ samplers.py | 1 + 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/flux/model.py b/flux/model.py index 6aece28..219d16a 100644 --- a/flux/model.py +++ b/flux/model.py @@ -96,16 +96,27 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten pe = self.pe_embedder(ids) weight = transformer_options['reg_cond_weight'] if 'reg_cond_weight' in transformer_options else 0.0 + threshold_factor_double = transformer_options['reg_cond_diff_threshold_factor_double'] if 'reg_cond_diff_threshold_factor_double' in transformer_options else 1.0 + threshold_factor_single = transformer_options['reg_cond_diff_threshold_factor_single'] if 'reg_cond_diff_threshold_factor_single' in transformer_options else 0.5 + + threshold_factor_double_absolute = transformer_options['reg_cond_diff_threshold_factor_double_absolute'] if 'reg_cond_diff_threshold_factor_double_absolute' in transformer_options else -1 + threshold_factor_single_absolute = transformer_options['reg_cond_diff_threshold_factor_single_absolute'] if 'reg_cond_diff_threshold_factor_single_absolute' in transformer_options else -1 for i, block in enumerate(self.double_blocks): mask = None mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None) - threshold = i / 56 + threshold = i / 57 if mask_obj is not None and weight >= threshold: #if self.threshold_inv: # threshold = 1 - threshold - mask = mask_obj[0](transformer_options, threshold) #, self.threshold_inv) + #threshold = (56-i)/57 + if threshold_factor_double_absolute < 0: + mask_threshold = threshold * threshold_factor_double + else: + mask_threshold = threshold_factor_double_absolute + mask = mask_obj[0](transformer_options, mask_threshold) #, self.threshold_inv) #mask = (mask >= threshold).to(mask.dtype) + #print("i = ", i, threshold, weight) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, timestep=timesteps, transformer_options=transformer_options, mask=mask) #, mask=mask) @@ -120,13 +131,19 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten for i, block in enumerate(self.single_blocks): mask = None mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None) - threshold = (1+18)/56 + threshold = (i+18)/57 if mask_obj is not None and weight >= threshold: #if self.threshold_inv: # threshold = 1 - threshold - mask = mask_obj[0](transformer_options, threshold) #, self.threshold_inv) + #threshold = (56-(i+18))/57 + if threshold_factor_single_absolute < 0: + mask_threshold = threshold * threshold_factor_single + else: + mask_threshold = threshold_factor_single_absolute + mask = mask_obj[0](transformer_options, mask_threshold) #, self.threshold_inv) #threshold = (i+18)/56 #mask = (mask >= threshold).to(mask.dtype) + #print("i2 = ", i, threshold, weight) img = block(img, vec=vec, pe=pe, timestep=timesteps, transformer_options=transformer_options, mask=mask) diff --git a/rk_sampler.py b/rk_sampler.py index 7b5d38b..5638d84 100644 --- a/rk_sampler.py +++ b/rk_sampler.py @@ -145,6 +145,10 @@ def post_cfg_function(args): model_options = extra_args.get("model_options", {}).copy() extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_double'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_double", "0.5", extra_options)) + extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_single'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_single", "0.5", extra_options)) + extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_double_absolute'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_double_absolute", "-1", extra_options)) + extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_single_absolute'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_single_absolute", "-1", extra_options)) if extra_options_flag("cfg_cw", extra_options): cfg_cw = float(get_extra_options_kv("cfg_cw", "1.0", extra_options)) extra_args = rk.init_cfg_channelwise(x, cfg_cw, **extra_args) @@ -162,6 +166,7 @@ def post_cfg_function(args): else: #model.inner_model.model_options['transformer_options']['patches']['regional_conditioning_weight'] = 0.0 extra_args['model_options']['transformer_options']['regional_conditioning_weight'] = 0.0 + eta = eta_var = etas[step] if etas is not None else eta s_noise = s_noises[step] if s_noises is not None else s_noise diff --git a/samplers.py b/samplers.py index 067fc43..430b241 100644 --- a/samplers.py +++ b/samplers.py @@ -103,6 +103,7 @@ def main(self, model, cfg, sampler_mode, scheduler, steps, denoise=1.0, denoise_ model.set_model_patch(regional_mask, 'regional_conditioning_mask') if "extra_options" in sampler.extra_options: + extra_options += " " extra_options += sampler.extra_options['extra_options'] sampler.extra_options['extra_options'] = extra_options