Skip to content

Commit

Permalink
Merge branch 'ClownsharkBatwing:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
drozbay authored Dec 16, 2024
2 parents f87bd82 + 3d950a6 commit f3cee41
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def __call__(self, transformer_options, threshold, threshold_inv=False, dtype=to
if self.mask_type == "gradient":
return self.mask.clone().to(sigma.device).to(dtype)
elif self.mask_type == "differential":
return self.mask.clone().to(sigma.device) >= threshold
return self.mask.clone().to(sigma.device) > threshold
if threshold_inv==False:
mask_tmp = self.mask.to(dtype).clone().to(sigma.device) >= threshold
return mask_tmp
Expand Down
25 changes: 21 additions & 4 deletions flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,27 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten
pe = self.pe_embedder(ids)

weight = transformer_options['reg_cond_weight'] if 'reg_cond_weight' in transformer_options else 0.0
threshold_factor_double = transformer_options['reg_cond_diff_threshold_factor_double'] if 'reg_cond_diff_threshold_factor_double' in transformer_options else 1.0
threshold_factor_single = transformer_options['reg_cond_diff_threshold_factor_single'] if 'reg_cond_diff_threshold_factor_single' in transformer_options else 0.5

threshold_factor_double_absolute = transformer_options['reg_cond_diff_threshold_factor_double_absolute'] if 'reg_cond_diff_threshold_factor_double_absolute' in transformer_options else -1
threshold_factor_single_absolute = transformer_options['reg_cond_diff_threshold_factor_single_absolute'] if 'reg_cond_diff_threshold_factor_single_absolute' in transformer_options else -1

for i, block in enumerate(self.double_blocks):
mask = None
mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None)
threshold = i / 56
threshold = i / 57
if mask_obj is not None and weight >= threshold:
#if self.threshold_inv:
# threshold = 1 - threshold
mask = mask_obj[0](transformer_options, threshold) #, self.threshold_inv)
#threshold = (56-i)/57
if threshold_factor_double_absolute < 0:
mask_threshold = threshold * threshold_factor_double
else:
mask_threshold = threshold_factor_double_absolute
mask = mask_obj[0](transformer_options, mask_threshold) #, self.threshold_inv)
#mask = (mask >= threshold).to(mask.dtype)
#print("i = ", i, threshold, weight)

img, txt = block(img=img, txt=txt, vec=vec, pe=pe, timestep=timesteps, transformer_options=transformer_options, mask=mask) #, mask=mask)

Expand All @@ -120,13 +131,19 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten
for i, block in enumerate(self.single_blocks):
mask = None
mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None)
threshold = (1+18)/56
threshold = (i+18)/57
if mask_obj is not None and weight >= threshold:
#if self.threshold_inv:
# threshold = 1 - threshold
mask = mask_obj[0](transformer_options, threshold) #, self.threshold_inv)
#threshold = (56-(i+18))/57
if threshold_factor_single_absolute < 0:
mask_threshold = threshold * threshold_factor_single
else:
mask_threshold = threshold_factor_single_absolute
mask = mask_obj[0](transformer_options, mask_threshold) #, self.threshold_inv)
#threshold = (i+18)/56
#mask = (mask >= threshold).to(mask.dtype)
#print("i2 = ", i, threshold, weight)

img = block(img, vec=vec, pe=pe, timestep=timesteps, transformer_options=transformer_options, mask=mask)

Expand Down
5 changes: 5 additions & 0 deletions rk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def post_cfg_function(args):
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_double'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_double", "0.5", extra_options))
extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_single'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_single", "0.5", extra_options))
extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_double_absolute'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_double_absolute", "-1", extra_options))
extra_args['model_options']['transformer_options']['reg_cond_diff_threshold_factor_single_absolute'] = float(get_extra_options_kv("reg_cond_diff_threshold_factor_single_absolute", "-1", extra_options))
if extra_options_flag("cfg_cw", extra_options):
cfg_cw = float(get_extra_options_kv("cfg_cw", "1.0", extra_options))
extra_args = rk.init_cfg_channelwise(x, cfg_cw, **extra_args)
Expand All @@ -162,6 +166,7 @@ def post_cfg_function(args):
else:
#model.inner_model.model_options['transformer_options']['patches']['regional_conditioning_weight'] = 0.0
extra_args['model_options']['transformer_options']['regional_conditioning_weight'] = 0.0

eta = eta_var = etas[step] if etas is not None else eta
s_noise = s_noises[step] if s_noises is not None else s_noise

Expand Down
1 change: 1 addition & 0 deletions samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def main(self, model, cfg, sampler_mode, scheduler, steps, denoise=1.0, denoise_
model.set_model_patch(regional_mask, 'regional_conditioning_mask')

if "extra_options" in sampler.extra_options:
extra_options += " "
extra_options += sampler.extra_options['extra_options']
sampler.extra_options['extra_options'] = extra_options

Expand Down

0 comments on commit f3cee41

Please sign in to comment.