Skip to content

Commit

Permalink
Merge branch 'ClownsharkBatwing:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
drozbay authored Dec 17, 2024
2 parents 132aff4 + 8355357 commit f4b4583
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
10 changes: 7 additions & 3 deletions conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,16 @@ def __init__(self, mask: torch.Tensor, conditioning: torch.Tensor, conditioning_
self.img_len = img_len
self.text_len = text_len

def __call__(self, transformer_options, dtype=torch.bfloat16, *args, **kwargs):
def __call__(self, transformer_options, weight=0, dtype=torch.bfloat16, *args, **kwargs):
sigma = transformer_options['sigmas'][0]
if self.start_percent <= 1 - sigma < self.end_percent:
if self.mask_type == "gradient":
return self.mask.clone().to(sigma.device).to(torch.bool)
return self.mask.clone().to(sigma.device).to(dtype)
mask = self.mask.clone().to(sigma.device)
mask[self.text_len:,self.text_len:] = mask[self.text_len:,self.text_len:] > 1-weight
#mask[self.text_len:,self.text_len:] = torch.clamp(mask[self.text_len:,self.text_len:], min=1-weight)

return mask.to(torch.bool)


return None

Expand Down
10 changes: 6 additions & 4 deletions flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten
mask = None
mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None)
if mask_obj is not None and weight >= 0:
mask = mask_obj[0](transformer_options)
mask = mask_obj[0](transformer_options, weight.item())
"""mask = mask_obj[0](transformer_options)
text_len = mask_obj[0].text_len
mask[text_len:,text_len:] = torch.clamp(mask[text_len:,text_len:], min=1-weight.to(mask.device))
mask[text_len:,text_len:] = torch.clamp(mask[text_len:,text_len:], min=1-weight.to(mask.device))"""

img, txt = block(img=img, txt=txt, vec=vec, pe=pe, timestep=timesteps, transformer_options=transformer_options, mask=mask) #, mask=mask)

Expand All @@ -119,9 +120,10 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten
mask = None
mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None)
if mask_obj is not None and weight >= 0:
mask = mask_obj[0](transformer_options)
mask = mask_obj[0](transformer_options, weight.item())
"""mask = mask_obj[0](transformer_options)
text_len = mask_obj[0].text_len
mask[text_len:,text_len:] = torch.clamp(mask[text_len:,text_len:], min=1-weight.to(mask.device))
mask[text_len:,text_len:] = torch.clamp(mask[text_len:,text_len:], min=1-weight.to(mask.device))"""

img = block(img, vec=vec, pe=pe, timestep=timesteps, transformer_options=transformer_options, mask=mask)

Expand Down

0 comments on commit f4b4583

Please sign in to comment.