Support prefix text for blip captioning
levi authored Feb 2, 2023
1 parent b1d8293 commit 872a192
Showing 1 changed file with 4 additions and 2 deletions.
lora_diffusion/preprocess_files.py: 6 changes (4 additions, 2 deletions)
@@ -121,6 +121,7 @@ def clipseg_mask_generator(
 @torch.no_grad()
 def blip_captioning_dataset(
     images: List[Image.Image],
+    text: Optional[str] = None,
     model_id: Literal[
         "Salesforce/blip-image-captioning-large",
         "Salesforce/blip-image-captioning-base",
@@ -139,7 +140,7 @@ def blip_captioning_dataset(
     captions = []

     for image in tqdm(images):
-        inputs = processor(image, return_tensors="pt").to("cuda")
+        inputs = processor(image, text=text, return_tensors="pt").to("cuda")
         out = model.generate(
             **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
         )
@@ -243,6 +244,7 @@ def _center_of_mass(mask: Image.Image):
 def load_and_save_masks_and_captions(
     files: Union[str, List[str]],
     output_dir: str,
+    caption_text: Optional[str] = None,
     target_prompts: Optional[Union[List[str], str]] = None,
     target_size: int = 512,
     crop_based_on_salience: bool = True,
@@ -277,7 +279,7 @@ def load_and_save_masks_and_captions(

     # captions
     print(f"Generating {len(images)} captions...")
-    captions = blip_captioning_dataset(images)
+    captions = blip_captioning_dataset(images, text=caption_text)

     if target_prompts is None:
         target_prompts = captions
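
The prefix works the way conditional captioning is exposed by the transformers BLIP processor: the optional text is tokenized alongside the image and generation continues from it, so every caption starts with that phrase. Below is a minimal standalone sketch of that behavior, assuming the standard BlipProcessor / BlipForConditionalGeneration classes; the checkpoint, image path, and prefix string are illustrative and not taken from this commit.

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Illustrative checkpoint; the diff lists both the base and large captioning models.
model_id = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id).to("cuda")

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image

# With text=None this is plain captioning; with a prefix, the generated
# caption is forced to start from that phrase (e.g. a rare-token identifier).
inputs = processor(image, text="a photo of sks person", return_tensors="pt").to("cuda")
out = model.generate(
    **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
)
print(processor.decode(out[0], skip_special_tokens=True))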
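For callers of the preprocessing entry point, the new caption_text argument is simply forwarded to blip_captioning_dataset as that prefix. A hypothetical invocation (the paths and prefix below are placeholders):

from lora_diffusion.preprocess_files import load_and_save_masks_and_captions

# Hypothetical paths and prefix; caption_text is forwarded to
# blip_captioning_dataset so every generated caption starts with it.
load_and_save_masks_and_captions(
    files="./instance_images",
    output_dir="./preprocessed",
    caption_text="a photo of sks person",
)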
