Skip to content

Commit

Permalink
Add a way for nodes to validate their own inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Apr 23, 2023
1 parent f7a8218 commit ccad603
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 23 deletions.
21 changes: 11 additions & 10 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import nodes

import comfy.model_management
import folder_paths

def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
Expand Down Expand Up @@ -250,14 +249,15 @@ def validate_inputs(prompt, item):
if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x))

if isinstance(type_input, list):
is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]")
if is_annotated_path:
if not folder_paths.exists_annotated_filepath(val):
return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val))

elif val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True:
return (False, "{}, {}".format(class_type, ret))
else:
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
return (True, "")

def validate_prompt(prompt):
Expand All @@ -279,7 +279,8 @@ def validate_prompt(prompt):
m = validate_inputs(prompt, o)
valid = m[0]
reason = m[1]
except:
except Exception as e:
print(traceback.format_exc())
valid = False
reason = "Parsing error"

Expand Down
6 changes: 3 additions & 3 deletions folder_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_directory_by_type(type_name):

# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def touch_annotated_filepath(name):
def annotated_filepath(name):
if name.endswith("[output]"):
base_dir = get_output_directory()
name = name[:-9]
Expand All @@ -88,7 +88,7 @@ def touch_annotated_filepath(name):


def get_annotated_filepath(name, default_dir=None):
name, base_dir = touch_annotated_filepath(name)
name, base_dir = annotated_filepath(name)

if base_dir is None:
if default_dir is not None:
Expand All @@ -100,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None):


def exists_annotated_filepath(name):
name, base_dir = touch_annotated_filepath(name)
name, base_dir = annotated_filepath(name)

if base_dir is None:
base_dir = get_input_directory() # fallback path
Expand Down
32 changes: 23 additions & 9 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,8 +974,7 @@ def INPUT_TYPES(s):
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
input_dir = folder_paths.get_input_directory()
image_path = folder_paths.get_annotated_filepath(image, input_dir)
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
Expand All @@ -989,29 +988,35 @@ def load_image(self, image):

@classmethod
def IS_CHANGED(s, image):
input_dir = folder_paths.get_input_directory()
image_path = folder_paths.get_annotated_filepath(image, input_dir)
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()

@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)

return True

class LoadImageMask:
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
return {"required":
{"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),}
"channel": (s._color_channels, ),}
}

CATEGORY = "mask"

RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
def load_image(self, image, channel):
input_dir = folder_paths.get_input_directory()
image_path = folder_paths.get_annotated_filepath(image, input_dir)
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
if i.getbands() != ("R", "G", "B", "A"):
i = i.convert("RGBA")
Expand All @@ -1028,13 +1033,22 @@ def load_image(self, image, channel):

@classmethod
def IS_CHANGED(s, image, channel):
input_dir = folder_paths.get_input_directory()
image_path = folder_paths.get_annotated_filepath(image, input_dir)
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()

@classmethod
def VALIDATE_INPUTS(s, image, channel):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)

if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)

return True

class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
Expand Down
2 changes: 1 addition & 1 deletion web/scripts/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ export class ComfyApp {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
if (prop) {
prop.value = value;
prop.callback(value);
}
});
}
Expand Down

0 comments on commit ccad603

Please sign in to comment.