Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ClownsharkBatwing authored Jul 11, 2024
1 parent 3e4c8a9 commit 59984f6
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 6 deletions.
3 changes: 2 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"LatentPhaseMagnitudeOffset": latents.LatentPhaseMagnitudeOffset,
"LatentPhaseMagnitudePower": latents.LatentPhaseMagnitudePower,

"ClownGuides": samplers.ClownGuides,
"ClownSampler": samplers.ClownSampler,
"SharkSampler": samplers.SharkSampler,
"SamplerDPMPP_DualSDE_Advanced": samplers.SamplerDPMPP_DUALSDE_MOMENTUMIZED_ADVANCED,
Expand Down Expand Up @@ -76,7 +77,7 @@
"Tan Scheduler 2": sigmas.tan_scheduler_2stage,
"Tan Scheduler 2 Simple": sigmas.tan_scheduler_2stage_simple,

"StableCascade_StageB_Conditioning64": samplers.StableCascade_StageB_Conditioning64,
"StableCascade_StageB_Conditioning64": conditioning.StableCascade_StageB_Conditioning64,

}
__all__ = ['NODE_CLASS_MAPPINGS']
56 changes: 56 additions & 0 deletions conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,37 @@
import comfy.sampler_helpers
import node_helpers

import functools
def cast_fp64(func):
    """Decorator: cast every tensor argument (including tensors nested inside
    dicts/lists/tuples) to float64 and move it to a common target device before
    calling *func*.

    The target device is taken from the first tensor found anywhere in the
    arguments. If no tensor is found, tensors discovered later during casting
    are converted to float64 in place on their current device (device=None is
    a no-op for ``Tensor.to``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        def find_device(data):
            # Recursively locate the first tensor to define the target device.
            # NOTE: the original implementation only scanned top-level
            # args/kwargs, so calls passing tensors inside dicts/lists never
            # established a device; this searches the full structure.
            if torch.is_tensor(data):
                return data.device
            if isinstance(data, dict):
                for v in data.values():
                    dev = find_device(v)
                    if dev is not None:
                        return dev
            elif isinstance(data, (list, tuple)):
                for v in data:
                    dev = find_device(v)
                    if dev is not None:
                        return dev
            return None

        target_device = find_device(args)
        if target_device is None:
            target_device = find_device(kwargs)

        def cast_and_move_to_device(data):
            # Single .to() call handles dtype and (possibly None) device at once.
            if torch.is_tensor(data):
                return data.to(dtype=torch.float64, device=target_device)
            if isinstance(data, dict):
                return {k: cast_and_move_to_device(v) for k, v in data.items()}
            # Non-tensor, non-dict values (incl. lists) pass through unchanged,
            # matching the original behavior for function arguments.
            return data

        new_args = [cast_and_move_to_device(arg) for arg in args]
        new_kwargs = {k: cast_and_move_to_device(v) for k, v in kwargs.items()}

        return func(*new_args, **new_kwargs)
    return wrapper


def initialize_or_scale(tensor, value, steps):
if tensor is None:
Expand Down Expand Up @@ -154,3 +185,28 @@ def main(self, conditioning_0, conditioning_1, ratio):
cond += node_helpers.conditioning_set_values(average, {"start_percent": percents[i]["start_percent"], "end_percent": percents[i]["end_percent"]})

return (cond,)


class StableCascade_StageB_Conditioning64:
    """ComfyUI node: attach a Stage-C latent as the 'stable_cascade_prior'
    entry on each conditioning item, for Stable Cascade Stage-B sampling.
    Tensors are promoted to float64 by the cast_fp64 decorator."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "conditioning": ("CONDITIONING",),
                              "stage_c": ("LATENT",),
                             }}
    RETURN_TYPES = ("CONDITIONING",)

    FUNCTION = "set_prior"

    CATEGORY = "conditioning/stable_cascade"

    @cast_fp64
    def set_prior(self, conditioning, stage_c):
        # Rebuild the conditioning list with a copied options dict per item
        # so the caller's original dicts are never mutated.
        prior = stage_c['samples']
        out = []
        for item in conditioning:
            opts = item[1].copy()
            opts['stable_cascade_prior'] = prior
            out.append([item[0], opts])
        return (out, )


61 changes: 56 additions & 5 deletions refined_exp_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def sample_refined_exp_s_advanced(

if denoised2_prev is not None:
x_n[0][0] = denoised2_prev
x_next, x_hat, denoised2_prev = branch_mode_proc(x_n, x_h, denoised2, branch_mode, branch_depth, branch_width)
x_next, x_hat, denoised2_prev = branch_mode_proc(x_n, x_h, denoised2, latent_guide_2, branch_mode, branch_depth, branch_width)

d = to_d(x_hat, sigma_hat, x_next)
dt = sigma_next - sigma_hat
Expand Down Expand Up @@ -399,6 +399,7 @@ def sample_refined_exp_s_advanced(
def branch_mode_proc(
x_n, x_h,
denoised2,
latent,
branch_mode,
branch_depth,
branch_width,
Expand All @@ -418,6 +419,55 @@ def branch_mode_proc(
x_next, x_hat, d_next = select_perpendicular_cosine_trajectory(x_n, x_h, branch_depth, branch_width)
if branch_mode == 'cos_perpendicular_d':
x_next, x_hat, d_next = select_perpendicular_cosine_trajectory_d(x_n, x_h, denoised2, branch_depth, branch_width)

if branch_mode == 'latent_match':
distances = [torch.norm(tensor - latent).item() for tensor in x_n[branch_depth]]
closest_index = distances.index(min(distances))
x_next = x_n[branch_depth][closest_index]
x_hat = x_h[branch_depth][closest_index]
d_next = denoised2[branch_depth][closest_index]

if branch_mode == 'latent_match_d':
distances = [torch.norm(tensor - latent).item() for tensor in denoised2[branch_depth]]
closest_index = distances.index(min(distances))
x_next = x_n[branch_depth][closest_index]
x_hat = x_h[branch_depth][closest_index]
d_next = denoised2[branch_depth][closest_index]

if branch_mode == 'latent_match_sdxl_color_d':
relevant_latent = latent[:, 1:3, :, :]
denoised2_relevant = [tensor[:, 1:3, :, :] for tensor in denoised2[branch_depth]]

distances = [torch.norm(tensor - relevant_latent).item() for tensor in denoised2_relevant]
closest_index = distances.index(min(distances))

x_next = x_n[branch_depth][closest_index]
x_hat = x_h[branch_depth][closest_index]
d_next = denoised2[branch_depth][closest_index]

if branch_mode == 'latent_match_sdxl_luminosity_d':
relevant_latent = latent[:, 0:1, :, :]
denoised2_relevant = [tensor[:, 0:1, :, :] for tensor in denoised2[branch_depth]]

distances = [torch.norm(tensor - relevant_latent).item() for tensor in denoised2_relevant]
closest_index = distances.index(min(distances))

x_next = x_n[branch_depth][closest_index]
x_hat = x_h[branch_depth][closest_index]
d_next = denoised2[branch_depth][closest_index]

if branch_mode == 'latent_match_sdxl_pattern_d':
relevant_latent = latent[:, 3:4, :, :]
denoised2_relevant = [tensor[:, 3:4, :, :] for tensor in denoised2[branch_depth]]

distances = [torch.norm(tensor - relevant_latent).item() for tensor in denoised2_relevant]
closest_index = distances.index(min(distances))

x_next = x_n[branch_depth][closest_index]
x_hat = x_h[branch_depth][closest_index]
d_next = denoised2[branch_depth][closest_index]


if branch_mode == 'mean':
x_mean = torch.mean(torch.stack(x_n[branch_depth]), dim=0)
distances = [torch.norm(tensor - x_mean).item() for tensor in x_n[branch_depth]]
Expand All @@ -427,14 +477,15 @@ def branch_mode_proc(
d_next = denoised2[branch_depth][closest_index]

if branch_mode == 'mean_d':
x_mean = torch.mean(torch.stack(denoised2[branch_depth]), dim=0)
distances = [torch.norm(tensor - x_mean).item() for tensor in x_n[branch_depth]]
d_mean = torch.mean(torch.stack(denoised2[branch_depth]), dim=0)
distances = [torch.norm(tensor - d_mean).item() for tensor in denoised2[branch_depth]]
closest_index = distances.index(min(distances))
x_next = x_n[branch_depth][closest_index]
x_hat = x_h[branch_depth][closest_index]
d_next = denoised2[branch_depth][closest_index]

if branch_mode == 'median': #minimum median distance
d_n_3 = [tensor for tensor in denoised2[branch_depth] if tensor is not None]
x_n_3 = [tensor for tensor in x_n[branch_depth] if tensor is not None]
x_h_3 = [tensor for tensor in x_h[branch_depth] if tensor is not None]
num_tensors = len(x_n_3)
Expand Down Expand Up @@ -714,7 +765,7 @@ def select_most_linear_trajectory_d(x_n, x_h, denoised2, branch_depth, branch_wi
return x_next, x_hat, d_next


def select_perpendicular_cosine_trajectory(x_n, x_h, denoised2, branch_width, branch_depth):
def select_perpendicular_cosine_trajectory(x_n, x_h, denoised2, branch_depth, branch_width):
d_n_depth = [tensor for tensor in denoised2[branch_depth] if tensor is not None]
x_n_depth = [tensor for tensor in x_n[branch_depth] if tensor is not None]
x_h_depth = [tensor for tensor in x_h[branch_depth] if tensor is not None]
Expand Down Expand Up @@ -756,7 +807,7 @@ def select_perpendicular_cosine_trajectory(x_n, x_h, denoised2, branch_width, br



def select_perpendicular_cosine_trajectory_d(x_n, x_h, denoised2, branch_width, branch_depth):
def select_perpendicular_cosine_trajectory_d(x_n, x_h, denoised2, branch_depth, branch_width):
d_n_depth = [tensor for tensor in denoised2[branch_depth] if tensor is not None]
x_n_depth = [tensor for tensor in x_n[branch_depth] if tensor is not None]
x_h_depth = [tensor for tensor in x_h[branch_depth] if tensor is not None]
Expand Down

0 comments on commit 59984f6

Please sign in to comment.