Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/drozbay/RES4LYF
Browse files Browse the repository at this point in the history
  • Loading branch information
drozbay committed Dec 17, 2024
2 parents d5ed478 + 9ec9c6f commit 132aff4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ def __call__(self, transformer_options, dtype=torch.bfloat16, *args, **kwargs):
sigma = transformer_options['sigmas'][0]
if self.start_percent <= 1 - sigma < self.end_percent:
if self.mask_type == "gradient":
return self.mask.clone().to(sigma.device).to(torch.bool)
return self.mask.clone().to(sigma.device).to(dtype)

return None
Expand Down
4 changes: 2 additions & 2 deletions flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten
for i, block in enumerate(self.double_blocks):
mask = None
mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None)
if mask_obj is not None and weight > 0:
if mask_obj is not None and weight >= 0:
mask = mask_obj[0](transformer_options)
text_len = mask_obj[0].text_len
mask[text_len:,text_len:] = torch.clamp(mask[text_len:,text_len:], min=1-weight.to(mask.device))
Expand All @@ -118,7 +118,7 @@ def forward_blocks(self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Ten
for i, block in enumerate(self.single_blocks):
mask = None
mask_obj = transformer_options.get('patches', {}).get('regional_conditioning_mask', None)
if mask_obj is not None and weight > 0:
if mask_obj is not None and weight >= 0:
mask = mask_obj[0](transformer_options)
text_len = mask_obj[0].text_len
mask[text_len:,text_len:] = torch.clamp(mask[text_len:,text_len:], min=1-weight.to(mask.device))
Expand Down

0 comments on commit 132aff4

Please sign in to comment.