Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(nodes): invocation registration logic #7826

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 105 additions & 98 deletions invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from inspect import signature
from typing import (
TYPE_CHECKING,
Expand All @@ -27,7 +28,6 @@
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType

from invokeai.app.invocations.fields import (
FieldKind,
Expand Down Expand Up @@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""

_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False

@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True

@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes

@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter

@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())

@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
Expand Down Expand Up @@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel):
All invocations must use the `@invocation` decorator to provide their unique type.
"""

_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False

@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default

@classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True

@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter

@classmethod
def invalidate_typeadapter(cls) -> None:
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
the updated allowlist and denylist."""
cls._typeadapter_needs_update = True

@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations

@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in BaseInvocation.get_invocations()}

@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())

@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation

@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)

@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
Expand Down Expand Up @@ -340,6 +249,105 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)


class InvocationRegistry:
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()

@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()

@classmethod
@lru_cache(maxsize=1)
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.

This is used to parse serialized invocations into the correct invocation class.

This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.

@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])

@classmethod
def invalidate_invocation_typeadapter(cls) -> None:
"""Invalidates the cached invocation type adapter."""
cls.get_invocation_typeadapter.cache_clear()

@classmethod
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[type[BaseInvocation]] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations

@classmethod
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in cls.get_invocation_classes()}

@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in cls.get_invocation_classes())

@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)

@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()

@classmethod
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
"""Gets all invocation outputs."""
return cls._output_classes

@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.

This is used to parse serialized invocation outputs into the correct invocation output class.

This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.

@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])

@classmethod
def invalidate_output_typeadapter(cls) -> None:
"""Invalidates the cached invocation output type adapter."""
cls.get_output_typeadapter.cache_clear()

@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())


RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
"is_intermediate",
Expand Down Expand Up @@ -453,8 +461,8 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
node_pack = cls.__module__.split(".")[0]

# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None

Expand Down Expand Up @@ -539,8 +547,7 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
)
cls.__doc__ = docstring

# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
BaseInvocation.register_invocation(cls) # type: ignore
InvocationRegistry.register_invocation(cls)

return cls

Expand All @@ -565,7 +572,7 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')

if output_type in BaseInvocationOutput.get_output_types():
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')

validate_fields(cls.model_fields, output_type)
Expand All @@ -586,7 +593,7 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
)
cls.__doc__ = docstring

BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
InvocationRegistry.register_output(cls)

return cls

Expand Down
9 changes: 5 additions & 4 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationRegistry,
invocation,
invocation_output,
)
Expand Down Expand Up @@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)

return core_schema.no_info_plain_validator_function(validate_invocation)

Expand All @@ -294,7 +295,7 @@ def __get_pydantic_json_schema__(
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
Expand All @@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return InvocationRegistry.get_output_typeadapter().validate_python(v)

return core_schema.no_info_plain_validator_function(validate_invocation_output)

Expand All @@ -316,7 +317,7 @@ def __get_pydantic_json_schema__(
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually

oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
Expand Down
9 changes: 6 additions & 3 deletions invokeai/app/util/custom_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.baseinvocation import (
InvocationRegistry,
UIConfigBase,
)
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
Expand Down Expand Up @@ -56,14 +59,14 @@ def openapi() -> dict[str, Any]:
invocation_output_map_required: list[str] = []

# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema

# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
for invocation in InvocationRegistry.get_invocation_classes():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from pydantic import ValidationError

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.baseinvocation import InvocationRegistry
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_deny_nodes(patch_rootdir):

# We've changed the config, we need to invalidate the typeadapter cache so that the new config is used for
# subsequent graph validations
BaseInvocation.invalidate_typeadapter()
InvocationRegistry.invalidate_invocation_typeadapter()

# confirm graph validation fails when using denied node
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "integer"}}})
Expand All @@ -284,7 +284,7 @@ def test_deny_nodes(patch_rootdir):
Graph.model_validate({"nodes": {"1": {"id": "1", "type": "float"}}})

# confirm invocations union will not have denied nodes
all_invocations = BaseInvocation.get_invocations()
all_invocations = InvocationRegistry.get_invocation_classes()

has_integer = len([i for i in all_invocations if i.get_type() == "integer"]) == 1
has_string = len([i for i in all_invocations if i.get_type() == "string"]) == 1
Expand All @@ -296,4 +296,4 @@ def test_deny_nodes(patch_rootdir):

# Reset the config so that it doesn't affect other tests
get_config.cache_clear()
BaseInvocation.invalidate_typeadapter()
InvocationRegistry.invalidate_invocation_typeadapter()