Skip to content

Commit

Permalink
Add Training and Inference Code
Browse files Browse the repository at this point in the history
  • Loading branch information
Warlord-K committed Jul 21, 2023
0 parents commit 903d524
Show file tree
Hide file tree
Showing 6 changed files with 1,110 additions and 0 deletions.
1 change: 1 addition & 0 deletions BKSDM/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pipeline import prepare_unet, BKSDM
102 changes: 102 additions & 0 deletions BKSDM/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch
import gc
import diffusers
import transformers


class BKSDM:
    """Stable Diffusion pipeline wrapper whose U-Net is pruned on load.

    The underlying ``diffusers.StableDiffusionPipeline`` is exposed as
    ``self.pipe``; calling the instance forwards the call to it.
    """

    def __init__(
        self,
        model_id="runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        model_type="base",
        device="cuda",
        **kwargs
    ):
        """Load ``model_id`` and prune its U-Net according to ``model_type``.

        Args:
            model_id: Checkpoint identifier handed to
                ``StableDiffusionPipeline.from_pretrained``.
            torch_dtype: dtype used when loading the weights. Overridden with
                ``torch.float32`` when CUDA is not available.
            model_type: One of ``"base"``, ``"tiny"``, ``"midless"``,
                ``"small"``; selects which parts of the U-Net are removed
                (see ``prepare_unet``).
            device: Device the pipeline is moved to. Overridden with
                ``"cpu"`` when CUDA is not available.
            **kwargs: Accepted so existing call sites keep working;
                currently unused.

        Raises:
            AssertionError: If ``model_type`` is not a supported value.
        """
        # Validate before the (potentially slow) checkpoint load. An explicit
        # raise is used instead of ``assert`` so the check survives
        # ``python -O``; the exception type is unchanged for callers.
        if model_type not in ("base", "tiny", "midless", "small"):
            raise AssertionError(f"unsupported model_type: {model_type!r}")

        if not torch.cuda.is_available():
            device = "cpu"
            torch_dtype = torch.float32

        self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch_dtype
        )

        # Share the pruning logic with ``prepare_unet`` so both entry points
        # produce the same module layout. The previous inline copy removed
        # ``resnets["1"]`` from up blocks 1-3 and left the survivors under the
        # keys "0" and "2", which broke integer indexing and gave different
        # state-dict keys than ``prepare_unet``.
        prepare_unet(self.pipe.unet, model_type)

        # Move to the target device only after pruning so the discarded
        # blocks are never copied there.
        self.pipe.to(device)

    def __call__(
        self, prompt, num_inference_steps=50, guidance_scale=7.5, negative_prompt=None
    ):
        """Run the wrapped pipeline and return its output unchanged.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            num_inference_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``.
            negative_prompt: Forwarded as ``negative_prompt``.
        """
        return self.pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
        )


def prepare_unet(unet, model_type):
    """Prune ``unet`` in place into one of the reduced architectures.

    What each ``model_type`` removes:

    * ``"midless"``: only the mid block.
    * ``"base"``: the second resnet/attention pair of down blocks 0-2, the
      second resnet of down block 3, and the middle resnet (and, for up
      blocks 1-3, the middle attention) of every up block. The mid block
      is kept.
    * ``"small"``: everything ``"base"`` removes, plus the mid block.
    * ``"tiny"``: like ``"small"``, but down block 3 is removed entirely
      (and the downsamplers of down block 2 with it), and the first up
      block is dropped by shifting up blocks 1-3 into positions 0-2.

    Args:
        unet: Module exposing ``mid_block``, ``down_blocks`` and
            ``up_blocks``; the code indexes four down and four up blocks
            whose ``resnets``/``attentions`` are ``ModuleList``-like.
        model_type: One of ``"base"``, ``"tiny"``, ``"midless"``,
            ``"small"``.

    Raises:
        AssertionError: If ``model_type`` is not a supported value.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped by
    # ``python -O``; AssertionError is kept so callers see the same type.
    if model_type not in ("base", "tiny", "midless", "small"):
        raise AssertionError(f"unsupported model_type: {model_type!r}")

    # Every variant except "base" drops the mid block.
    if model_type != "base":
        unet.mid_block = None

    # "midless" stops here; the other variants also thin out the blocks.
    if model_type != "midless":
        # Down blocks 0-2: keep only the first resnet/attention pair.
        for i in range(3):
            delattr(unet.down_blocks[i].resnets, "1")
            delattr(unet.down_blocks[i].attentions, "1")

        if model_type == "tiny":
            # Drop the last down block; block 2 becomes the last one and
            # must therefore no longer downsample.
            delattr(unet.down_blocks, "3")
            unet.down_blocks[2].downsamplers = None
        else:
            delattr(unet.down_blocks[3].resnets, "1")

        # Up blocks: keep the first and last entry, re-keyed as 0 and 1 so
        # the lists stay contiguous and indexable.
        unet.up_blocks[0].resnets[1] = unet.up_blocks[0].resnets[2]
        delattr(unet.up_blocks[0].resnets, "2")
        for i in range(1, 4):
            unet.up_blocks[i].resnets[1] = unet.up_blocks[i].resnets[2]
            unet.up_blocks[i].attentions[1] = unet.up_blocks[i].attentions[2]
            delattr(unet.up_blocks[i].attentions, "2")
            delattr(unet.up_blocks[i].resnets, "2")

        if model_type == "tiny":
            # Shift up blocks 1-3 into slots 0-2, discarding the first one.
            for i in range(3):
                unet.up_blocks[i] = unet.up_blocks[i + 1]
            delattr(unet.up_blocks, "3")

    # Collect garbage first so the memory of the removed modules is actually
    # free by the time the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
56 changes: 56 additions & 0 deletions data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import rand
from img2dataset import download
import shutil
import os


def repartition(
    input_path="dataset/data.parquet",
    output_path="dataset/laion_small",
    num_partitions=10,
    seed=0,
):
    """Shuffle a parquet dataset and rewrite it as ``num_partitions`` parts.

    Called with no arguments it behaves as before: it reads
    ``dataset/data.parquet``, shuffles with seed 0 and writes 10 partitions
    to ``dataset/laion_small``.

    Args:
        input_path: Parquet file (or directory) to read.
        output_path: Directory the repartitioned parquet files are written to.
        num_partitions: Number of partitions to write.
        seed: Seed for the random ordering, so the shuffle is reproducible.
    """
    spark = (
        SparkSession.builder.config("spark.driver.memory", "16G")
        .master("local[16]")
        .appName("spark-repart")
        .getOrCreate()
    )
    df = spark.read.parquet(input_path)
    # Optional filters, kept for reference:
    # df = df.filter((df.WIDTH >= 1024) & (df.HEIGHT >= 1024))
    # df = df.filter((df.AESTHETIC_SCORE > 7))
    # Ordering by a random column shuffles the rows; this matters so that
    # every output partition is a random sample of the dataset.
    df = df.orderBy(rand(seed=seed))
    print(df.count())
    df.repartition(num_partitions).write.parquet(output_path)


def download_images(
    output_dir="dataset/laion_small_images",
    url="dataset/laion_small/part-00002-195faf27-0776-429e-a03b-a6aba71d4f16-c000.snappy.parquet",
):
    """Download the images listed in a parquet file into webdataset shards.

    Any existing ``output_dir`` is deleted first, so each run starts from an
    empty folder.

    Args:
        output_dir: Folder that receives the shards; resolved to an absolute
            path before use.
        url: Path of the parquet file passed to ``img2dataset.download`` as
            ``url_list``. It must provide the ``url`` and
            ``generated_caption`` columns.
    """
    target_dir = os.path.abspath(output_dir)
    if os.path.exists(target_dir):
        # Start clean: remove the results of any previous run.
        shutil.rmtree(target_dir)

    # Create the local Spark session up front.
    # NOTE(review): presumably the "pyspark" distributor below picks this
    # session up - confirm against img2dataset.
    session_builder = SparkSession.builder.config("spark.driver.memory", "16G")
    session_builder = session_builder.master("local[16]").appName("spark-stats")
    _session = session_builder.getOrCreate()

    download_options = dict(
        processes_count=1,
        thread_count=32,
        url_list=url,
        image_size=512,
        output_folder=target_dir,
        output_format="webdataset",
        input_format="parquet",
        url_col="url",
        caption_col="generated_caption",
        enable_wandb=True,
        number_sample_per_shard=1000,
        distributor="pyspark",
    )
    download(**download_options)

if __name__ == "__main__":
    # Script entry point: first shuffle and repartition the parquet metadata,
    # then download the images using download_images()'s default arguments.
    repartition()
    download_images()


"""
Common errors and handy commands:
1. If you get an error connecting to the Java port, run this in the terminal: `export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64`
2. To untar the downloaded files: `for f in dataset/laion_small_images/*.tar; do tar -xvf "$f" -C data/; done`
3. To copy the files: `for f in full_images/*.txt; do cp -v "$f" new_images/"${f//\//_}"; done` (this copies the text files; use `.jpg` instead of `.txt` for the images)
"""
24 changes: 24 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
from torch import Generator


# Checkpoint to load and the prompt pair used for every generated sample.
path = 'Warlord-K/BKSDM-Base-45K'
prompt = "Faceshot Portrait of pretty young (18-year-old) Caucasian wearing a high neck sweater, (masterpiece, extremely detailed skin, photorealistic, heavy shadow, dramatic and cinematic lighting, key light, fill light), sharp focus, BREAK epicrealism"
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"

# Inference only: let cuDNN benchmark kernels and turn autograd off globally.
torch.backends.cudnn.benchmark = True
torch.set_grad_enabled(False)

with torch.inference_mode():
    # A fixed seed makes the run reproducible; the same generator is reused
    # for all samples, so each sample continues its random stream.
    generator = Generator("cuda")
    generator.manual_seed(1674753452)

    pipe = DiffusionPipeline.from_pretrained(
        path,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.to('cuda')
    # Replace the checkpoint's scheduler with DPM-Solver multistep, built
    # from the existing scheduler's config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.to(
        device='cuda',
        dtype=torch.float16,
        memory_format=torch.channels_last,
    )

    # Generate three samples and save them as 0.png, 1.png and 2.png.
    for index in range(3):
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=512,
            height=512,
            num_inference_steps=25,
            guidance_scale=7,
            num_images_per_prompt=1,
            generator=generator,
        )
        result.images[0].save(f"{index}.png")
Loading

0 comments on commit 903d524

Please sign in to comment.