Skip to content

Commit

Permalink
pathlib and fstrings Signed-off-by: Steven Zimmerman <[email protected]>
Browse files Browse the repository at this point in the history
  • Loading branch information
SZim92 committed Aug 12, 2024
1 parent a233228 commit 279d649
Showing 1 changed file with 98 additions and 105 deletions.
203 changes: 98 additions & 105 deletions service/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import base64
import math
from typing import IO
from typing import IO, Optional, Union, Callable
from pathlib import Path
from PIL import Image
import io
import os
import hashlib
import torch
from typing import Optional, Union, Callable
import numpy as np

# Import model_config here for better organization and to avoid repeated imports within functions
import model_config
import realesrgan
import realesrgan


def image_to_base64(image: Image.Image) -> str:
"""
Expand All @@ -24,9 +24,7 @@ def image_to_base64(image: Image.Image) -> str:
"""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return "data:image/png;base64,{}".format(
base64.b64encode(buffered.getvalue()).decode("utf-8")
)
return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"


def generate_mask_image(mask_flag_bytes: bytes, width: int, height: int) -> Image.Image:
Expand All @@ -41,12 +39,8 @@ def generate_mask_image(mask_flag_bytes: bytes, width: int, height: int) -> Imag
Returns:
Image.Image: The generated mask image as a PIL Image.
"""
from PIL import Image
import numpy as np

np_data = np.frombuffer(mask_flag_bytes, dtype=np.uint8)
image = Image.fromarray(np_data.reshape((height, width)), mode="L").convert("RGB")

return image


Expand Down Expand Up @@ -82,7 +76,7 @@ def get_image_shape_ceil(image: Image.Image) -> float:
return get_shape_ceil(H, W)


def check_model_exist(type: int, repo_id: str) -> bool: # Corrected typo in function name: "mmodel" to "model".
def check_model_exist(type: int, repo_id: str) -> bool:
"""
Check if a model of a specified type exists in the configured model directory.
Expand All @@ -104,121 +98,122 @@ def check_model_exist(type: int, repo_id: str) -> bool: # Corrected typo in fun
folder_name = repo_id.replace("/", "---") # Replace '/' with '---' for folder names.

if type == 0: # Check for LLM models
model_path = model_config.config.get("llm")
return os.path.exists(os.path.join(model_path, folder_name, "config.json"))
if type == 1: # Check for Stable Diffusion models
model_path = model_config.config.get("stableDiffusion")
model_path = Path(model_config.config.get("llm"))
return (model_path / folder_name / "config.json").exists()
elif type == 1: # Check for Stable Diffusion models
model_path = Path(model_config.config.get("stableDiffusion"))
if is_single_file(repo_id):
return os.path.exists(os.path.join(model_path, repo_id))
return os.path.exists(os.path.join(model_path, folder_name, "model_index.json"))
if type == 2: # Check for LORA models
model_path = model_config.config.get("lora")
return (model_path / repo_id).exists()
return (model_path / folder_name / "model_index.json").exists()
elif type == 2: # Check for LORA models
model_path = Path(model_config.config.get("lora"))
if is_single_file(repo_id):
return os.path.exists(os.path.join(model_path, repo_id))
return (model_path / repo_id).exists()
return (
os.path.exists(os.path.join(model_path, folder_name, "pytorch_lora_weights.safetensors"))
or os.path.exists(os.path.join(model_path, folder_name, "pytorch_lora_weights.bin"))
(model_path / folder_name / "pytorch_lora_weights.safetensors").exists()
or (model_path / folder_name / "pytorch_lora_weights.bin").exists()
)
if type == 3: # Check for VAE models
model_path = model_config.config.get("vae")
return os.path.exists(os.path.join(model_path, folder_name))
if type == 4: # Check for ESRGAN models
model_path = model_config.config.get("ESRGAN")
return os.path.exists(os.path.join(model_path, realesrgan.ESRGAN_MODEL_URL.split("/")[-1]))
if type == 5: # Check for Embedding models
model_path = model_config.config.get("embedding")
return os.path.exists(os.path.join(model_path, folder_name))
if type == 6: # Check for Inpaint models
model_path = model_config.config.get("inpaint")
elif type == 3: # Check for VAE models
model_path = Path(model_config.config.get("vae"))
return (model_path / folder_name).exists()
elif type == 4: # Check for ESRGAN models
model_path = Path(model_config.config.get("ESRGAN"))
return (model_path / Path(realesrgan.ESRGAN_MODEL_URL).name).exists()
elif type == 5: # Check for Embedding models
model_path = Path(model_config.config.get("embedding"))
return (model_path / folder_name).exists()
elif type == 6: # Check for Inpaint models
model_path = Path(model_config.config.get("inpaint"))
if is_single_file(repo_id):
return os.path.exists(os.path.join(model_path, repo_id))
return os.path.exists(os.path.join(model_path, folder_name, "model_index.json"))
if type == 7: # Check for Preview models
model_path = model_config.config.get("preview")
return (model_path / repo_id).exists()
return (model_path / folder_name / "model_index.json").exists()
elif type == 7: # Check for Preview models
model_path = Path(model_config.config.get("preview"))
return (
os.path.exists(os.path.join(model_path, folder_name, "config.json"))
or os.path.exists(os.path.join(model_path, f"{repo_id}.safetensors"))
or os.path.exists(os.path.join(model_path, f"{repo_id}.bin"))
(model_path / folder_name / "config.json").exists()
or (model_path / f"{repo_id}.safetensors").exists()
or (model_path / f"{repo_id}.bin").exists()
)
raise Exception(f"Unknown model type value: {type}") # Corrected typo: "uwnkown" to "unknown".
else:
raise Exception(f"Unknown model type value: {type}") # Corrected typo: "uwnkown" to "unknown".


def convert_model_type(type: int) -> str:
"""Converts a model type code (int) to its corresponding string representation.
"""
Converts a model type code (int) to its corresponding string representation.
Args:
type (int): An integer representing the model type.
Returns:
str: The string representation of the model type.
"""
if type == 0:
return "llm"
if type == 1:
return "stableDiffusion"
if type == 2:
return "lora"
if type == 3:
return "vae"
if type == 4:
return "ESRGAN"
if type == 5:
return "embedding"
if type == 6:
return "inpaint"
if type == 7:
return "preview"
raise Exception(f"Unknown model type value: {type}") # Corrected typo: "uwnkown" to "unknown".


def get_model_path(type: int) -> str:
"""Gets the file path associated with a given model type.
model_type_map = {
0: "llm",
1: "stableDiffusion",
2: "lora",
3: "vae",
4: "ESRGAN",
5: "embedding",
6: "inpaint",
7: "preview"
}

try:
return model_type_map[type]
except KeyError:
raise Exception(f"Unknown model type value: {type}") # Corrected typo: "uwnkown" to "unknown".


def get_model_path(type: int) -> Path:
"""
Gets the file path associated with a given model type.
Args:
type (int): The integer code representing the model type.
Returns:
str: The file path associated with the specified model type.
Path: The file path associated with the specified model type.
"""
return model_config.config.get(convert_model_type(type))
return Path(model_config.config.get(convert_model_type(type)))


def calculate_md5(file_path: str) -> str:
"""Calculates the MD5 hash of a file.
def calculate_md5(file_path: Path) -> str:
"""
Calculates the MD5 hash of a file.
Args:
file_path (str): The path to the file.
file_path (Path): The path to the file.
Returns:
str: The MD5 hash of the file.
"""
with open(file_path, "rb") as f:
with file_path.open("rb") as f:
file_hash = hashlib.md5()
while chunk := f.read(8192):
file_hash.update(chunk)
return file_hash.hexdigest()


def create_cache_path(md5: str, file_size: int) -> str:
"""Creates a path for caching a file based on its MD5 hash and size.
def create_cache_path(md5: str, file_size: int) -> Path:
"""
Creates a path for caching a file based on its MD5 hash and size.
Args:
md5 (str): The MD5 hash of the file.
md5 (str): The MD5 hash of the file.
file_size (int): The size of the file in bytes.
Returns:
str: The constructed cache path for the file.
Path: The constructed cache path for the file.
"""
cache_dir = "./cache"
sub_dirs = [md5[i : i + 4] for i in range(0, len(md5), 4)]
cache_path = os.path.abspath(
os.path.join(cache_dir, *sub_dirs, f"{md5}_{file_size}")
)
return cache_path
cache_dir = Path("./cache")
sub_dirs = [md5[i:i + 4] for i in range(0, len(md5), 4)]
return cache_dir.joinpath(*sub_dirs, f"{md5}_{file_size}").resolve()


def calculate_md5_from_stream(file_stream: IO[bytes]):
"""Calculates the MD5 hash from a file stream.
def calculate_md5_from_stream(file_stream: IO[bytes]) -> str:
"""
Calculates the MD5 hash from a file stream.
Args:
file_stream (IO[bytes]): The file stream to read from.
Expand All @@ -232,47 +227,45 @@ def calculate_md5_from_stream(file_stream: IO[bytes]):
return file_hash.hexdigest()


def cache_file(file_path: Union[IO[bytes], str], file_size: int):
"""Caches a file based on its MD5 hash and size.
def cache_file(file_path: Union[IO[bytes], Path], file_size: int):
"""
Caches a file based on its MD5 hash and size.
The function calculates the MD5 hash of the provided file, creates a directory
structure for caching based on the hash, and moves the file to the cache location.
It then creates a hard link from the original file path to the cached file.
Args:
file_path (Union[IO[bytes], str]): The path to the file, either a file-like object or a string.
file_path (Union[IO[bytes], Path]): The path to the file, either a file-like object or a Path.
file_size (int): The size of the file in bytes.
"""
if isinstance(file_path, io.IOBase):
# If file_path is a stream, save it to a temporary file.
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
shutil.copyfileobj(file_path, temp_file)
temp_file_path = temp_file.name
temp_file_path = Path(temp_file.name)
md5 = calculate_md5(temp_file_path)
cache_path = create_cache_path(md5, file_size)

if not os.path.exists(cache_path):
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
# Since we already wrote to temp file, just rename
os.rename(temp_file_path, cache_path)
else:
# Calculate the MD5 checksum of the file.
if not cache_path.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
temp_file_path.rename(cache_path)
else:
md5 = calculate_md5(file_path)
cache_path = create_cache_path(md5, file_size)

cache_path = create_cache_path(md5, file_size) # Create the cache path based on the checksum and size.
if not cache_path.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
file_path.rename(cache_path)

if not os.path.exists(cache_path): # Check if the cache path already exists.
os.makedirs(os.path.dirname(cache_path), exist_ok=True) # Create necessary directories.
os.rename(file_path, cache_path) # Move the file to the cache location.
if file_path.exists():
file_path.unlink()

if os.path.exists(file_path): # If the original file still exists, remove it.
os.remove(file_path)
try:
os.link(cache_path, file_path) # Create a hard link from the cache path to the original location.
cache_path.link_to(file_path)
except OSError:
# If making a hardlink fails, fallback to copying the file
shutil.copyfile(cache_path, file_path)


def is_single_file(filename: str) -> bool:
"""
Checks if a filename corresponds to a single model file.
Expand All @@ -283,7 +276,7 @@ def is_single_file(filename: str) -> bool:
Returns:
bool: True if the filename corresponds to a single file (.safetensors or .bin), False otherwise.
"""
return filename.endswith('.safetensors') or filename.endswith('.bin')
return filename.endswith(('.safetensors', '.bin'))


def get_ESRGAN_size() -> int:
Expand All @@ -309,9 +302,9 @@ def get_support_graphics(env_type: str):
"""
device_count = torch.xpu.device_count() # Get the number of XPU devices available.
model_config.env_type = env_type # Set the environment type in the model configuration.
graphics = list() # Initialize the list to store supported devices.
graphics = []
for i in range(device_count): # Iterate over each device.
device_name = torch.xpu.get_device_name(i) # Get the name of the device.
if device_name == "Intel(R) Arc(TM) Graphics" or re.search("Intel\(R\) Arc\(TM\) [^ ]+ Graphics", device_name) is not None:
if device_name == "Intel(R) Arc(TM) Graphics" or re.search("Intel\\(R\\) Arc\\(TM\\) [^ ]+ Graphics", device_name) is not None:
graphics.append({"index": i, "name": device_name}) # If the device is supported, add it to the list.
return graphics # Return the list of supported devices.

0 comments on commit 279d649

Please sign in to comment.