Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
idonahum1 committed May 21, 2024
1 parent dc09ccf commit 87bd2a8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
30 changes: 13 additions & 17 deletions run_tests_with_prompts.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#!/bin/bash

# Define paths
MODEL1_PATH="photoverse_058000.pt"
MODEL2_PATH="photoverse_arcface_042000.pt"
MODEL3_PATH="photoverse_facenet_074000.pt"
MODEL_PATH="photoverse_facenet_074000.pt"

# Define other common arguments
PRETRAINED_MODEL_NAME="runwayml/stable-diffusion-v1-5"
Expand All @@ -22,17 +20,15 @@ DENOISE_TIMESTEPS=25 # Update if you want to run for multiple timesteps
mkdir -p $OUTPUT_DIR

# Run the script for each model
for MODEL_PATH in "$MODEL1_PATH" "$MODEL2_PATH" "$MODEL3_PATH"; do
python test_prompts.py --pretrained_model_name_or_path $PRETRAINED_MODEL_NAME \
--pretrained_photoverse_path $MODEL_PATH \
--data_root_path $DATA_ROOT_PATH \
--img_subfolder $IMG_SUBFOLDER \
--output_dir $OUTPUT_DIR \
--batch_size $BATCH_SIZE \
--num_workers $NUM_WORKERS \
--denoise_timesteps $DENOISE_TIMESTEPS \
--guidance_scale $GUIDANCE_SCALE \
--device $DEVICE \
--resolution $RESOLUTION \
--max_gen_images $MAX_GEN_IMAGES
done
python test_prompts.py --pretrained_model_name_or_path $PRETRAINED_MODEL_NAME \
--pretrained_photoverse_path $MODEL_PATH \
--data_root_path $DATA_ROOT_PATH \
--img_subfolder $IMG_SUBFOLDER \
--output_dir $OUTPUT_DIR \
--batch_size $BATCH_SIZE \
--num_workers $NUM_WORKERS \
--denoise_timesteps $DENOISE_TIMESTEPS \
--guidance_scale $GUIDANCE_SCALE \
--device $DEVICE \
--resolution $RESOLUTION \
--max_gen_images $MAX_GEN_IMAGES
26 changes: 18 additions & 8 deletions test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,23 @@
from utils.image_utils import to_pil, denormalize, save_images_grid

PROMPTS = ['A photo of {}',
'{} in Ghibli anime style',
'{} in Disney & Pixar style',
'{} in Ghilbi anime style',
'{} in Disney/Pixar style',
'{} wears a red hat',
'{} on the beach',
'Manga drawing of {}',
'{} Funko Pop',
'{} latte art', ]
'{} as a Funko Pop figure',
'Latte art of {}',
'{} flower arrangement',
'Pointillism painting of {}',
'{} stained glass window',
'{} is camping in the mountains',
'{} is a character in a video game',
'Watercolor painting of {}',
'{} as a knight in plate',
'{} as a character in a comic book']

PROMPTS_NAMES = ['photo','ghibli', 'disney_pixar', 'red_hat', 'beach', 'manga', 'funko_pop', 'latte_art']
PROMPTS_NAMES = ['photo','ghibli', 'disney_pixar', 'red_hat', 'beach', 'manga', 'funko_pop', 'latte_art', 'flower_arrangement', 'pointillism', 'stained_glass', 'camping', 'video_game', 'watercolor', 'knight', 'comic_book']

def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
Expand Down Expand Up @@ -144,6 +152,9 @@ def main():
pixel_values = sample["pixel_values"].to(args.device)
input_images = [to_pil(denormalize(img)) for img in pixel_values]
grid_data = [("Input Images", input_images)]
for row_idx, input_image in enumerate(input_images):
os.makedirs(os.path.join(full_output_dir, f"grid_{batch_idx}_row_{row_idx}"), exist_ok=True)
input_image.save(os.path.join(full_output_dir, f"grid_{batch_idx}_row_{row_idx}", "input_image.png"))
for prompt, prompt_name in zip(PROMPTS, PROMPTS_NAMES):
sample_to_update = prepare_prompt(tokenizer, prompt, "*",
num_of_samples=len(pixel_values))
Expand All @@ -154,9 +165,8 @@ def main():
guidance_scale=args.guidance_scale,
timesteps=args.denoise_timesteps, token_index=0)
gen_images = [to_pil(denormalize(gen_tensor)) for gen_tensor in gen_tensors]
for sample_idx, (gen_image, input_image) in enumerate(zip(gen_images, input_images)):
gen_image.save(os.path.join(full_output_dir,
f"generated_{prompt_name}_img_batch_idx{batch_idx}_sample_idx{sample_idx}.png"))
for sample_idx, gen_image in enumerate(gen_images):
gen_image.save(os.path.join(full_output_dir, f"grid_{batch_idx}_row_{sample_idx}", f"{prompt_name}.png"))
grid_data.append((sample['text'][0], gen_images))
torch.cuda.empty_cache()

Expand Down

0 comments on commit 87bd2a8

Please sign in to comment.