Skip to content

Commit

Permalink
refactored save_grid function to save any kind of grid with text
Browse files Browse the repository at this point in the history
  • Loading branch information
idonahum1 committed May 7, 2024
1 parent 8650797 commit 0fdc4f6
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions utils/image_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
import numpy as np

Expand Down Expand Up @@ -28,12 +28,42 @@ def to_pil(image):
return Image.fromarray(image)


def save_images_grid(grid_data, img_grid_file, header_height=50, font_size=36,
                     placeholder="S*"):
    """Save captioned columns of images as one grid image.

    Every entry of ``grid_data`` describes one column of the grid: element
    ``0`` is the caption text and element ``1`` is the sequence of images
    shown in that column, top to bottom. Row ``r`` of the grid is the
    ``r``-th image of every column placed side by side; rows are then
    stacked vertically and a white strip carrying the captions is added on
    top. Columns are combined with ``zip``, so the grid has as many rows as
    the shortest column.

    Args:
        grid_data: Iterable of ``(caption, images)`` entries, one per column.
            Captions are passed through ``str.format(placeholder)``, so a
            ``{}`` in a caption is replaced by ``placeholder``.
        img_grid_file: Destination passed to ``PIL.Image.Image.save``.
        header_height: Height in pixels of the white caption strip.
        font_size: Requested caption font size.
        placeholder: Value substituted into each caption via ``str.format``.

    Raises:
        ValueError: If ``grid_data`` yields no images to lay out.
    """
    grid_data = list(grid_data)  # iterated twice below; tolerate generators
    captions = [column[0] for column in grid_data]
    columns = [column[1] for column in grid_data]

    # One list of pixel arrays per grid row (the r-th image of every column).
    rows = [
        [_to_rgb_array(image) for image in row_images]
        for row_images in zip(*columns)
    ]
    if not rows:
        raise ValueError(
            "grid_data must contain at least one column with at least one image"
        )

    # Concatenate each row horizontally, then all rows vertically.
    grid = np.concatenate(
        [np.concatenate(row, axis=1) for row in rows], axis=0
    )
    # Reserve a white strip above the images for the captions.
    grid = np.pad(
        grid,
        ((header_height, 0), (0, 0), (0, 0)),
        mode="constant",
        constant_values=255,
    )
    # The array is uint8 H x W x 3 here, so the RGB mode is inferred.
    img_grid = Image.fromarray(grid.astype("uint8"))

    draw = ImageDraw.Draw(img_grid)
    font = _load_grid_font(font_size)

    # Centre each caption over its own column. Offsets are accumulated from
    # the actual column widths so columns of different widths stay aligned.
    column_left = 0
    for caption, column_width in zip(captions, (a.shape[1] for a in rows[0])):
        text = caption.format(placeholder)
        left, top, right, bottom = font.getbbox(text)
        # getbbox is relative to the drawing origin; subtracting left/top
        # centres the rendered glyphs rather than the origin point.
        text_x = column_left + (column_width - (right - left)) // 2 - left
        text_y = (header_height - (bottom - top)) // 2 - top
        draw.text((text_x, text_y), text, font=font, fill="black")
        column_left += column_width

    img_grid.save(img_grid_file)


def _to_rgb_array(image):
    """Return ``image`` as a pixel array, converting PIL images to RGB first."""
    if isinstance(image, Image.Image):
        # Normalise grayscale/RGBA/palette images so every cell has 3 channels.
        image = image.convert("RGB")
    return np.asarray(image)


def _load_grid_font(size):
    """Return a caption font of about ``size`` pixels, with fallbacks."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # Arial is not installed on this system; use Pillow's built-in font.
        pass
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow releases do not accept ``size`` for the default font.
        return ImageFont.load_default()

0 comments on commit 0fdc4f6

Please sign in to comment.