Skip to content
This repository has been archived by the owner on Jun 12, 2024. It is now read-only.

Commit

Permalink
feat(image.py): add cookies in init
Browse files Browse the repository at this point in the history
  • Loading branch information
dsdanielpark committed Apr 28, 2024
1 parent 9e8d581 commit f40ea36
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions gemini/src/model/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Will be refactored.
import os
import random
import httpx
Expand All @@ -19,23 +20,25 @@ class GeminiImage(BaseModel):
alt (str): The alt text of the image. Defaults to "".
Methods:
validate_images(cls, images): Validates the input images list.
save(cls, images: List["GeminiImage"], save_path: str = "cached", cookies: Optional[dict] = None) -> Optional[Path]:
validate_images(self, images): Validates the input images list.
save(self, images: List["GeminiImage"], save_path: str = "cached", cookies: Optional[dict] = None) -> Optional[Path]:
Downloads and saves images asynchronously.
fetch_bytes(url: HttpUrl, cookies: Optional[dict] = None) -> Optional[bytes]:
Fetches bytes of an image asynchronously.
fetch_images_dict(cls, images: List["GeminiImage"], cookies: Optional[dict] = None) -> Dict[str, bytes]:
fetch_images_dict(self, images: List["GeminiImage"], cookies: Optional[dict] = None) -> Dict[str, bytes]:
Fetches images asynchronously and returns a dictionary of image data.
save_images(cls, image_data: Dict[str, bytes], save_path: str = "cached"):
save_images(self, image_data: Dict[str, bytes], save_path: str = "cached"):
Saves images locally.
"""

url: HttpUrl
title: str = "[Image]"
alt: str = ""

@classmethod
def validate_images(cls, images):
def __init__(self, cookies=None):
self.cookies = cookies

def validate_images(self, images):
"""
Validates the input images list.
Expand All @@ -51,12 +54,10 @@ def validate_images(cls, images):
)

# Async downloader
@classmethod
async def save(
cls,
self,
images: List["GeminiImage"],
save_path: str = "cached",
cookies: Optional[dict] = None,
) -> Optional[Path]:
"""
Downloads and saves images asynchronously.
Expand All @@ -69,9 +70,9 @@ async def save(
Returns:
Optional[Path]: The path to the directory where the images are saved, or None if saving fails.
"""
cls.validate_images(images)
image_data = await cls.fetch_images_dict(images, cookies)
await cls.save_images(image_data, save_path)
self.validate_images(images)
image_data = await self.fetch_images_dict(images, self.cookies)
await self.save_images(image_data, save_path)

@staticmethod
async def fetch_bytes(
Expand All @@ -96,9 +97,8 @@ async def fetch_bytes(
print(f"Failed to download {url}: {str(e)}")
return None

@classmethod
async def fetch_images_dict(
cls, images: List["GeminiImage"], cookies: Optional[dict] = None
self, images: List["GeminiImage"], cookies: Optional[dict] = None
) -> Dict[str, bytes]:
"""
Fetches images asynchronously and returns a dictionary of image data.
Expand All @@ -110,8 +110,8 @@ async def fetch_images_dict(
Returns:
Dict[str, bytes]: A dictionary containing image titles as keys and image bytes as values.
"""
cls.validate_images(images)
tasks = [cls.fetch_bytes(image.url, cookies=cookies) for image in images]
self.validate_images(images)
tasks = [self.fetch_bytes(image.url, cookies=cookies) for image in images]
results = await asyncio.gather(*tasks)
return {image.title: result for image, result in zip(images, results) if result}

Expand All @@ -137,10 +137,9 @@ async def save_images(image_data: Dict[str, bytes], save_path: str = "cached"):
print(f"Error saving {title}: {str(e)}")

# Sync downloader
@staticmethod
def save_sync(
self,
images: List["GeminiImage"],
cookies: Optional[dict] = None,
save_path: str = "cached",
) -> Optional[Path]:
"""Synchronously saves the image to the specified path.
Expand All @@ -154,7 +153,7 @@ def save_sync(
Returns:
Optional[Path]: The path where the image is saved, or None if saving fails.
"""
image_data = GeminiImage.fetch_images_dict_sync(images, cookies)
image_data = GeminiImage.fetch_images_dict_sync(images, self.cookies)
GeminiImage.validate_images(image_data)
GeminiImage.save_images_sync(image_data, save_path)

Expand Down

0 comments on commit f40ea36

Please sign in to comment.