diff --git a/framework/src/framework/dsl/executor/query/query_executor.py b/framework/src/framework/dsl/executor/query/query_executor.py index 4e8e5c09..949746ef 100644 --- a/framework/src/framework/dsl/executor/query/query_executor.py +++ b/framework/src/framework/dsl/executor/query/query_executor.py @@ -162,12 +162,15 @@ def add_input( if vector := self.__get_looks_like_vector(index_node_id, looks_like_clause): add_input(index_node_id, vector, looks_like_clause.get_weight(), True) for similar_clause in query_descriptor.get_clauses_by_type(SimilarFilterClause): - value = similar_clause.get_value() - weight = similar_clause.get_weight() - if value is None or not weight: - continue - node_id = similar_clause.space._get_embedding_node(query_descriptor.schema).node_id - add_input(node_id, similar_clause.field_set._generate_space_input(value), weight, False) + if (result := similar_clause.evaluate()) is not None: + node_id = similar_clause.space._get_embedding_node(query_descriptor.schema).node_id + _, weighted_value = result + add_input( + node_id, + similar_clause.field_set._generate_space_input(weighted_value.item), + weighted_value.weight, + False, + ) return inputs diff --git a/framework/src/framework/dsl/query/query_clause.py b/framework/src/framework/dsl/query/query_clause.py index e7f8119f..e49baedf 100644 --- a/framework/src/framework/dsl/query/query_clause.py +++ b/framework/src/framework/dsl/query/query_clause.py @@ -31,6 +31,7 @@ ) from superlinked.framework.common.interface.evaluated import Evaluated from superlinked.framework.common.interface.has_annotation import HasAnnotation +from superlinked.framework.common.interface.weighted import Weighted from superlinked.framework.common.nlq.open_ai import OpenAIClientConfig from superlinked.framework.common.schema.schema_object import SchemaField from superlinked.framework.common.util.generic_class_util import GenericClassUtil @@ -43,7 +44,6 @@ from superlinked.framework.dsl.query.predicate.binary_predicate import ( EvaluatedBinaryPredicate, LooksLikePredicate, - SimilarPredicate, ) from superlinked.framework.dsl.query.query_filter_validator import QueryFilterValidator from superlinked.framework.dsl.space.categorical_similarity_space import ( @@ -270,11 +270,10 @@ def value_accepted_type(self) -> type: @dataclass(frozen=True) class SimilarFilterClause( - WeightedQueryClause[tuple[Space, EvaluatedBinaryPredicate[SimilarPredicate]] | None], + WeightedQueryClause[tuple[Space, Weighted[PythonTypes]] | None], HasAnnotation, ): field_set: SpaceFieldSet - schema_field: SchemaField @property def space(self) -> Space: @@ -283,24 +282,20 @@ def space(self) -> Space: @override def evaluate( self, - ) -> tuple[Space, EvaluatedBinaryPredicate[SimilarPredicate]] | None: + ) -> tuple[Space, Weighted[PythonTypes]] | None: value = self.get_value() weight = self.get_weight() if value is None or weight == constants.DEFAULT_NOT_AFFECTING_WEIGHT: return None - node = self.space._get_embedding_node(self.schema_field.schema_obj) - similar_filter = EvaluatedBinaryPredicate( - SimilarPredicate(self.schema_field, cast(ParamInputType, value), weight, node) - ) - return self.space, similar_filter + return (self.space, Weighted(value, weight)) @override def get_default_value_param_name(self) -> str: - return f"similar_filter_{self.space}_{self.schema_field.name}_value_param__" + return f"similar_filter_{self.space}_{self.field_set.fields_id}_value_param__" @override def get_default_weight_param_name(self) -> str: - return f"similar_filter_{self.space}_{self.schema_field.name}_weight_param__" + return f"similar_filter_{self.space}_{self.field_set.fields_id}_weight_param__" @property @override diff --git a/framework/src/framework/dsl/query/query_descriptor.py b/framework/src/framework/dsl/query/query_descriptor.py index 9f5fa539..0a5c4862 100644 --- a/framework/src/framework/dsl/query/query_descriptor.py +++ b/framework/src/framework/dsl/query/query_descriptor.py @@ -14,7 +14,6 @@ from __future__ import annotations -from collections import defaultdict from collections.abc import Mapping import structlog @@ -55,7 +54,6 @@ from superlinked.framework.dsl.query.predicate.binary_predicate import ( EvaluatedBinaryPredicate, LooksLikePredicate, - SimilarPredicate, ) from superlinked.framework.dsl.query.query_clause import ( HardFilterClause, @@ -152,15 +150,17 @@ def similar( field_set = ( space_field_set.space_field_set if isinstance(space_field_set, HasSpaceFieldSet) else space_field_set ) - schema_field = field_set.get_field_for_schema(self.schema) - if not schema_field: - raise InvalidSchemaException(f"'find' ({type(self.schema)}) is not in similarity field's schema types.") + self.__validate_schema(field_set) value_param = self.__to_param(param) weight_param = self.__to_param(weight) - clause = SimilarFilterClause(value_param, weight_param, field_set, schema_field) + clause = SimilarFilterClause(value_param, weight_param, field_set) altered_query_descriptor = self.__append_clause(clause) return altered_query_descriptor + def __validate_schema(self, field_set: SpaceFieldSet) -> None: + if self.schema not in field_set.space._embedding_node_by_schema: + raise InvalidSchemaException(f"'find' ({type(self.schema)}) is not in similarity field's schema types.") + def limit(self, limit: IntParamType | None) -> QueryDescriptor: """ Set a limit to the number of results returned by the query. @@ -407,16 +407,9 @@ def get_looks_like_filter( looks_like_filter = looks_like_clause.evaluate() if looks_like_clause is not None else None return looks_like_filter - def get_similar_filters( - self, - ) -> dict[Space, list[EvaluatedBinaryPredicate[SimilarPredicate]]]: - similar_filters_by_space = defaultdict(list) - for clause in self.get_clauses_by_type(SimilarFilterClause): - space_and_similar_filter = clause.evaluate() - if space_and_similar_filter is not None: - space, similar = space_and_similar_filter - similar_filters_by_space[space].append(similar) - return dict(similar_filters_by_space) + def get_similar_filters_spaces(self) -> list[Space]: + evaluation_results = [clause.evaluate() for clause in self.get_clauses_by_type(SimilarFilterClause)] + return [result[0] for result in evaluation_results if result is not None] def get_context_time(self, default: int | Any) -> int: if (overridden_now_clause := self.get_clause_by_type(OverriddenNowClause)) is not None: @@ -451,7 +444,7 @@ def get_param_value_to_set_for_unset_space_weight_clauses(self) -> dict[str, flo } if self.get_looks_like_filter() is not None: return {param_name: constants.DEFAULT_WEIGHT for param_name in unset_space_by_param_name.keys()} - similar_filter_spaces = self.get_similar_filters().keys() + similar_filter_spaces = self.get_similar_filters_spaces() return { param_name: ( constants.DEFAULT_WEIGHT if space in similar_filter_spaces else constants.DEFAULT_NOT_AFFECTING_WEIGHT diff --git a/framework/src/framework/dsl/space/categorical_similarity_space.py b/framework/src/framework/dsl/space/categorical_similarity_space.py index fe52a7e3..e4ae2505 100644 --- a/framework/src/framework/dsl/space/categorical_similarity_space.py +++ b/framework/src/framework/dsl/space/categorical_similarity_space.py @@ -113,12 +113,12 @@ def __init__( InvalidSchemaException: If a schema object does not have a corresponding node in the similarity space, indicating a configuration or implementation error. """ - TypeValidator.validate_list_item_type(categories, str, "categories") # TODO FAI-2843 this type ignore is not needed but linting is flaky in CI super().__init__( category_input, String | StringList, # type: ignore[misc] # interface supports only one type ) + TypeValidator.validate_list_item_type(categories, str, "categories") self.__category = SpaceFieldSet[list[str]]( self, set(category_input if isinstance(category_input, list) else [category_input]), diff --git a/framework/src/framework/dsl/space/image_space.py b/framework/src/framework/dsl/space/image_space.py index c23af76e..1feb69c1 100644 --- a/framework/src/framework/dsl/space/image_space.py +++ b/framework/src/framework/dsl/space/image_space.py @@ -22,11 +22,11 @@ from superlinked.framework.common.dag.image_embedding_node import ImageEmbeddingNode from superlinked.framework.common.dag.schema_field_node import SchemaFieldNode from superlinked.framework.common.data_types import Vector -from superlinked.framework.common.schema.blob_information import BlobInformation from superlinked.framework.common.schema.image_data import ImageData from superlinked.framework.common.schema.schema_object import ( Blob, DescribedBlob, + SchemaField, SchemaObject, String, ) @@ -99,37 +99,19 @@ def __init__( InvalidSpaceParamException: If the image and description fields are not from the same schema. """ - described_blobs = [self._get_described_blob(img) for img in (image if isinstance(image, Sequence) else [image])] - image_fields = [described.blob for described in described_blobs] + self.__validate_field_schemas(image) + image_fields, description_fields = self._split_images_from_descriptions(image) super().__init__(image_fields, Blob) length = ImageEmbedding.init_manager(model_handler, model, model_cache_dir).calculate_length() self.image = ImageSpaceFieldSet(self, set(image_fields)) self.description = ImageDescriptionSpaceFieldSet( - self, set(described.description for described in described_blobs) + self, set(description for description in description_fields if description is not None) ) self._all_fields = self.image.fields | self.description.fields self._transformation_config = self._init_transformation_config(model, length, model_handler) - self._schema_field_nodes_by_schema: dict[ - SchemaObject, tuple[SchemaFieldNode[BlobInformation], SchemaFieldNode[str]] - ] = { - described_blob.blob.schema_obj: ( - SchemaFieldNode(described_blob.blob), - SchemaFieldNode(described_blob.description), - ) - for described_blob in described_blobs - } - self.__embedding_node_by_schema: dict[SchemaObject, EmbeddingNode[Vector, ImageData]] = { - schema: ImageEmbeddingNode( - image_blob_node=image_blob_node, - description_node=description_node, - transformation_config=self.transformation_config, - fields_for_identification=self._all_fields, - ) - for schema, ( - image_blob_node, - description_node, - ) in self._schema_field_nodes_by_schema.items() - } + self.__embedding_node_by_schema = self._init_embedding_node_by_schema( + image_fields, description_fields, self._all_fields, self.transformation_config + ) self._model = model def _get_described_blob(self, image: Blob | DescribedBlob) -> DescribedBlob: @@ -140,6 +122,26 @@ def _get_described_blob(self, image: Blob | DescribedBlob) -> DescribedBlob: description = String(DEFAULT_DESCRIPTION_FIELD_PREFIX + image.name, image.schema_obj) return DescribedBlob(image, description) + def __validate_field_schemas(self, images: Blob | DescribedBlob | Sequence[Blob | DescribedBlob]) -> None: + if any( + image.description.schema_obj != image.blob.schema_obj + for image in (images if isinstance(images, Sequence) else [images]) + if isinstance(image, DescribedBlob) + ): + raise InvalidSpaceParamException("ImageSpace image and description field must be in the same schema.") + + def _split_images_from_descriptions( + self, images: Blob | DescribedBlob | Sequence[Blob | DescribedBlob] + ) -> tuple[list[Blob], list[String | None]]: + images = images if isinstance(images, Sequence) else [images] + blobs, descriptions = zip( + *[ + (image.blob, image.description) if isinstance(image, DescribedBlob) else (image, None) + for image in images + ] + ) + return list(blobs), list(descriptions) + @property @override def transformation_config(self) -> TransformationConfig[Vector, ImageData]: @@ -185,3 +187,20 @@ def _init_transformation_config( aggregation_config = VectorAggregationConfig(Vector) normalization_config = L2NormConfig() return TransformationConfig(normalization_config, aggregation_config, embedding_config) + + def _init_embedding_node_by_schema( + self, + image_fields: Sequence[Blob], + description_fields: Sequence[String | None], + all_fields: set[SchemaField], + transformation_config: TransformationConfig[Vector, ImageData], + ) -> dict[SchemaObject, EmbeddingNode[Vector, ImageData]]: + return { + image_field.schema_obj: ImageEmbeddingNode( + image_blob_node=SchemaFieldNode(image_field), + description_node=SchemaFieldNode(description_field) if description_field is not None else None, + transformation_config=transformation_config, + fields_for_identification=all_fields, + ) + for image_field, description_field in zip(image_fields, description_fields) + } diff --git a/framework/src/framework/dsl/space/image_space_field_set.py b/framework/src/framework/dsl/space/image_space_field_set.py index 0ca3f256..d24dae15 100644 --- a/framework/src/framework/dsl/space/image_space_field_set.py +++ b/framework/src/framework/dsl/space/image_space_field_set.py @@ -29,7 +29,6 @@ @dataclass class ImageSpaceFieldSet(SpaceFieldSet[ImageData]): - @property @override def input_type(self) -> type: diff --git a/framework/src/framework/dsl/space/number_space.py b/framework/src/framework/dsl/space/number_space.py index c0a8fcdb..71efac9a 100644 --- a/framework/src/framework/dsl/space/number_space.py +++ b/framework/src/framework/dsl/space/number_space.py @@ -106,7 +106,6 @@ def __init__( # pylint: disable=too-many-arguments negative_filter (float): This is a value that will be set for everything that is equal or lower than the min_value. It can be a float. It defaults to 0 (No effect) """ - self._aggregation_mode = aggregation_mode # this must be set before super init for _handle_node_not_present self._embedding_config = NumberEmbeddingConfig( float, float(min_value), @@ -119,7 +118,7 @@ def __init__( # pylint: disable=too-many-arguments number_fields = number if isinstance(number, list) else [number] self.number = SpaceFieldSet[float](self, set(number_fields)) self._aggregation_config_type_by_mode = self.__init_aggregation_config_type_by_mode() - self._transformation_config = self._init_transformation_config(self._embedding_config, self._aggregation_mode) + self._transformation_config = self._init_transformation_config(self._embedding_config, aggregation_mode) self.__schema_node_map: dict[SchemaObject, EmbeddingNode[float, float]] = { number_field.schema_obj: NumberEmbeddingNode( parent=SchemaFieldNode(number_field), diff --git a/framework/src/framework/dsl/space/space.py b/framework/src/framework/dsl/space/space.py index b20ab395..7a8a578c 100644 --- a/framework/src/framework/dsl/space/space.py +++ b/framework/src/framework/dsl/space/space.py @@ -16,7 +16,7 @@ from abc import abstractmethod -from beartype.typing import Generic, TypeAlias, TypeVar +from beartype.typing import Generic, Sequence, TypeAlias, TypeVar from typing_extensions import override from superlinked.framework.common.dag.embedding_node import EmbeddingNode @@ -54,16 +54,16 @@ class Space( def __init__( self, - fields: SpaceSchemaFieldT | list[SpaceSchemaFieldT], + fields: SpaceSchemaFieldT | Sequence[SpaceSchemaFieldT], type_: type | TypeAlias, ) -> None: super().__init__() - field_list: list[SpaceSchemaFieldT] = fields if isinstance(fields, list) else [fields] + field_list = fields if isinstance(fields, Sequence) else [fields] TypeValidator.validate_list_item_type(field_list, type_, "field_list") self.__validate_fields(field_list) self._field_set = set(field_list) - def __validate_fields(self, field_list: list[SpaceSchemaFieldT]) -> None: + def __validate_fields(self, field_list: Sequence[SpaceSchemaFieldT]) -> None: if not self._allow_empty_fields and not field_list: raise InvalidSpaceParamException(f"{self.__class__.__name__} field input must not be empty.") schema_list = [field.schema_obj for field in field_list] diff --git a/framework/src/framework/dsl/space/space_field_set.py b/framework/src/framework/dsl/space/space_field_set.py index 8e9c0c1e..3d6aab83 100644 --- a/framework/src/framework/dsl/space/space_field_set.py +++ b/framework/src/framework/dsl/space/space_field_set.py @@ -35,13 +35,18 @@ class SpaceFieldSet(Generic[SIT]): fields: set[SchemaField] def __post_init__(self) -> None: - self.__schema_field_map = {field.schema_obj: field for field in self.fields} + self._schema_field_map = {field.schema_obj: field for field in self.fields} self._input_type: type[SIT] = GenericClassUtil.get_generic_types(self.space)[1] + self._fields_id = self.__generate_fields_id(self.fields) @property def input_type(self) -> type[SIT]: return self._input_type + @property + def fields_id(self) -> str: + return self._fields_id + @property def field_names_text(self) -> Sequence[str]: return ",".join([f"{field.schema_obj._schema_name}.{field.name}" for field in self.fields]) @@ -50,4 +55,9 @@ def _generate_space_input(self, value: PythonTypes) -> SIT: return cast(SIT, value) def get_field_for_schema(self, schema_: Any) -> SchemaField | None: - return self.__schema_field_map.get(schema_) + return self._schema_field_map.get(schema_) + + def __generate_fields_id(self, fields: set[SchemaField]) -> str: + field_ids = [f"{field.schema_obj._schema_name}_{field.name}" for field in fields] + field_ids.sort() + return "_".join(field_ids)