Commit 903d524 (0 parents)
Showing 6 changed files with 1,110 additions and 0 deletions.
@@ -0,0 +1 @@
from .pipeline import prepare_unet, BKSDM
@@ -0,0 +1,102 @@
import torch
import gc
import diffusers
import transformers


class BKSDM:
    def __init__(
        self,
        model_id="runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        model_type="base",
        device="cuda",
        **kwargs
    ):
        assert model_type in ["base", "tiny", "midless", "small"]
        # Fall back to CPU and fp32 when CUDA is unavailable
        if not torch.cuda.is_available():
            device = "cpu"
            torch_dtype = torch.float32
        self.pipe = diffusers.StableDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch_dtype
        )
        self.pipe.to(device)
        # Drop the mid block for every variant except "base"
        if model_type != "base":
            self.pipe.unet.mid_block = None
        # Remove resnets/attentions inside the U-Net ("base", "small", "tiny")
        if model_type != "midless":
            # Down blocks: drop the second resnet/attention pair of the first three blocks
            for i in range(3):
                delattr(self.pipe.unet.down_blocks[i].resnets, "1")
                delattr(self.pipe.unet.down_blocks[i].attentions, "1")

            if model_type == "tiny":
                # "tiny" also removes the innermost down block and its downsampler
                delattr(self.pipe.unet.down_blocks, "3")
                self.pipe.unet.down_blocks[2].downsamplers = None
            else:
                # Other variants keep the innermost down block but drop its second resnet
                delattr(self.pipe.unet.down_blocks[3].resnets, "1")

            # Up blocks: drop the middle of the three resnet/attention pairs
            self.pipe.unet.up_blocks[0].resnets[1] = self.pipe.unet.up_blocks[
                0
            ].resnets[2]
            delattr(self.pipe.unet.up_blocks[0].resnets, "2")
            for i in range(1, 4):
                self.pipe.unet.up_blocks[i].attentions[1] = self.pipe.unet.up_blocks[
                    i
                ].attentions[2]
                delattr(self.pipe.unet.up_blocks[i].attentions, "2")
                delattr(self.pipe.unet.up_blocks[i].resnets, "1")
            if model_type == "tiny":
                # "tiny" drops the first (lowest-resolution) up block entirely
                for i in range(3):
                    self.pipe.unet.up_blocks[i] = self.pipe.unet.up_blocks[i + 1]
                delattr(self.pipe.unet.up_blocks, "3")
        # Release memory held by the removed modules
        torch.cuda.empty_cache()
        gc.collect()

    def __call__(
        self, prompt, num_inference_steps=50, guidance_scale=7.5, negative_prompt=None
    ):
        return self.pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
        )


def prepare_unet(unet, model_type):
    assert model_type in ["base", "tiny", "midless", "small"]
    # Drop the mid block for every variant except "base"
    if model_type != "base":
        unet.mid_block = None
    # Remove resnets/attentions inside the U-Net ("base", "small", "tiny")
    if model_type != "midless":
        # Down blocks: drop the second resnet/attention pair of the first three blocks
        for i in range(3):
            delattr(unet.down_blocks[i].resnets, "1")
            delattr(unet.down_blocks[i].attentions, "1")

        if model_type == "tiny":
            # "tiny" also removes the innermost down block and its downsampler
            delattr(unet.down_blocks, "3")
            unet.down_blocks[2].downsamplers = None
        else:
            # Other variants keep the innermost down block but drop its second resnet
            delattr(unet.down_blocks[3].resnets, "1")

        # Up blocks: drop the middle of the three resnet/attention pairs
        unet.up_blocks[0].resnets[1] = unet.up_blocks[0].resnets[2]
        delattr(unet.up_blocks[0].resnets, "2")
        for i in range(1, 4):
            unet.up_blocks[i].resnets[1] = unet.up_blocks[i].resnets[2]
            unet.up_blocks[i].attentions[1] = unet.up_blocks[i].attentions[2]
            delattr(unet.up_blocks[i].attentions, "2")
            delattr(unet.up_blocks[i].resnets, "2")
        if model_type == "tiny":
            # "tiny" drops the first (lowest-resolution) up block entirely
            for i in range(3):
                unet.up_blocks[i] = unet.up_blocks[i + 1]
            delattr(unet.up_blocks, "3")
    # Release memory held by the removed modules
    torch.cuda.empty_cache()
    gc.collect()
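For context, a quick usage sketch (not part of this commit): BKSDM loads a Stable Diffusion pipeline and prunes its U-Net in one step, while prepare_unet prunes a U-Net that has already been loaded. The package name below is an assumption, since the diff does not show the directory layout.

# Hypothetical usage sketch, not part of this commit. "bksdm" stands in for
# whatever the package directory is called; the diff only shows that its
# __init__.py re-exports prepare_unet and BKSDM from .pipeline.
import torch
import diffusers
from bksdm import BKSDM, prepare_unet

# Option 1: let BKSDM load, prune, and run the pipeline.
model = BKSDM(model_id="runwayml/stable-diffusion-v1-5", model_type="small")
out = model("a watercolor painting of a lighthouse", num_inference_steps=25)
out.images[0].save("sample.png")

# Option 2: prune the U-Net of a pipeline you have already loaded.
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
prepare_unet(pipe.unet, model_type="small")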
@@ -0,0 +1,56 @@
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import rand
from img2dataset import download
import shutil
import os


def repartition():
    spark = SparkSession.builder.config("spark.driver.memory", "16G").master("local[16]").appName("spark-repart").getOrCreate()
    df = spark.read.parquet("dataset/data.parquet")
    # df = df.filter((df.WIDTH >= 1024) & (df.HEIGHT >= 1024))
    # df = df.filter((df.AESTHETIC_SCORE > 7))
    df = df.orderBy(rand(seed=0))  # shuffle the dataset before writing shards
    print(df.count())
    df.repartition(10).write.parquet("dataset/laion_small")


def download_images(output_dir="dataset/laion_small_images", url="dataset/laion_small/part-00002-195faf27-0776-429e-a03b-a6aba71d4f16-c000.snappy.parquet"):
    output_dir = os.path.abspath(output_dir)

    # Start from a clean output directory
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    spark = (
        SparkSession.builder.config("spark.driver.memory", "16G").master("local[16]").appName("spark-stats").getOrCreate()
    )

    download(
        processes_count=1,
        thread_count=32,
        url_list=url,
        image_size=512,
        output_folder=output_dir,
        output_format="webdataset",
        input_format="parquet",
        url_col="url",
        caption_col="generated_caption",
        enable_wandb=True,
        number_sample_per_shard=1000,
        distributor="pyspark",
    )


if __name__ == "__main__":
    repartition()
    download_images()


"""
Common error handling:
1. If there is an error connecting to the Java port, run `export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64` in the terminal.
2. Untar the shards: `for f in dataset/laion_small_images/*.tar; do tar -xvf "$f" -C data/; done`
3. Copy files: `for f in full_images/*.txt; do cp -v "$f" new_images/"${f//\//_}"; done` (this copies the .txt captions; use .jpg instead for the images).
"""
@@ -0,0 +1,24 @@
import torch
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
from torch import Generator


path = 'Warlord-K/BKSDM-Base-45K'
prompt = "Faceshot Portrait of pretty young (18-year-old) Caucasian wearing a high neck sweater, (masterpiece, extremely detailed skin, photorealistic, heavy shadow, dramatic and cinematic lighting, key light, fill light), sharp focus, BREAK epicrealism"
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"

torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

with torch.inference_mode():
    # Fixed seed so the generations are reproducible
    gen = Generator("cuda")
    gen.manual_seed(1674753452)
    pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False)
    pipe.to('cuda')
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.to(device='cuda', dtype=torch.float16, memory_format=torch.channels_last)

    # Generate and save three 512x512 samples
    for i in range(3):
        img = pipe(prompt=prompt, negative_prompt=negative_prompt, width=512, height=512, num_inference_steps=25, guidance_scale=7, num_images_per_prompt=1, generator=gen).images[0]
        img.save(f"{i}.png")