Skip to content

Commit

Permalink
framework/v18.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
slrelease committed Jan 10, 2025
1 parent d37943b commit 101b464
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 69 deletions.
15 changes: 9 additions & 6 deletions framework/src/framework/dsl/executor/query/query_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,15 @@ def add_input(
if vector := self.__get_looks_like_vector(index_node_id, looks_like_clause):
add_input(index_node_id, vector, looks_like_clause.get_weight(), True)
for similar_clause in query_descriptor.get_clauses_by_type(SimilarFilterClause):
value = similar_clause.get_value()
weight = similar_clause.get_weight()
if value is None or not weight:
continue
node_id = similar_clause.space._get_embedding_node(query_descriptor.schema).node_id
add_input(node_id, similar_clause.field_set._generate_space_input(value), weight, False)
if (result := similar_clause.evaluate()) is not None:
node_id = similar_clause.space._get_embedding_node(query_descriptor.schema).node_id
_, weighted_value = result
add_input(
node_id,
similar_clause.field_set._generate_space_input(weighted_value.item),
weighted_value.weight,
False,
)

return inputs

Expand Down
17 changes: 6 additions & 11 deletions framework/src/framework/dsl/query/query_clause.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from superlinked.framework.common.interface.evaluated import Evaluated
from superlinked.framework.common.interface.has_annotation import HasAnnotation
from superlinked.framework.common.interface.weighted import Weighted
from superlinked.framework.common.nlq.open_ai import OpenAIClientConfig
from superlinked.framework.common.schema.schema_object import SchemaField
from superlinked.framework.common.util.generic_class_util import GenericClassUtil
Expand All @@ -43,7 +44,6 @@
from superlinked.framework.dsl.query.predicate.binary_predicate import (
EvaluatedBinaryPredicate,
LooksLikePredicate,
SimilarPredicate,
)
from superlinked.framework.dsl.query.query_filter_validator import QueryFilterValidator
from superlinked.framework.dsl.space.categorical_similarity_space import (
Expand Down Expand Up @@ -270,11 +270,10 @@ def value_accepted_type(self) -> type:

@dataclass(frozen=True)
class SimilarFilterClause(
WeightedQueryClause[tuple[Space, EvaluatedBinaryPredicate[SimilarPredicate]] | None],
WeightedQueryClause[tuple[Space, Weighted[PythonTypes]] | None],
HasAnnotation,
):
field_set: SpaceFieldSet
schema_field: SchemaField

@property
def space(self) -> Space:
Expand All @@ -283,24 +282,20 @@ def space(self) -> Space:
@override
def evaluate(
self,
) -> tuple[Space, EvaluatedBinaryPredicate[SimilarPredicate]] | None:
) -> tuple[Space, Weighted[PythonTypes]] | None:
value = self.get_value()
weight = self.get_weight()
if value is None or weight == constants.DEFAULT_NOT_AFFECTING_WEIGHT:
return None
node = self.space._get_embedding_node(self.schema_field.schema_obj)
similar_filter = EvaluatedBinaryPredicate(
SimilarPredicate(self.schema_field, cast(ParamInputType, value), weight, node)
)
return self.space, similar_filter
return (self.space, Weighted(value, weight))

@override
def get_default_value_param_name(self) -> str:
return f"similar_filter_{self.space}_{self.schema_field.name}_value_param__"
return f"similar_filter_{self.space}_{self.field_set.fields_id}_value_param__"

@override
def get_default_weight_param_name(self) -> str:
return f"similar_filter_{self.space}_{self.schema_field.name}_weight_param__"
return f"similar_filter_{self.space}_{self.field_set.fields_id}_weight_param__"

@property
@override
Expand Down
27 changes: 10 additions & 17 deletions framework/src/framework/dsl/query/query_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

from collections import defaultdict
from collections.abc import Mapping

import structlog
Expand Down Expand Up @@ -55,7 +54,6 @@
from superlinked.framework.dsl.query.predicate.binary_predicate import (
EvaluatedBinaryPredicate,
LooksLikePredicate,
SimilarPredicate,
)
from superlinked.framework.dsl.query.query_clause import (
HardFilterClause,
Expand Down Expand Up @@ -152,15 +150,17 @@ def similar(
field_set = (
space_field_set.space_field_set if isinstance(space_field_set, HasSpaceFieldSet) else space_field_set
)
schema_field = field_set.get_field_for_schema(self.schema)
if not schema_field:
raise InvalidSchemaException(f"'find' ({type(self.schema)}) is not in similarity field's schema types.")
self.__validate_schema(field_set)
value_param = self.__to_param(param)
weight_param = self.__to_param(weight)
clause = SimilarFilterClause(value_param, weight_param, field_set, schema_field)
clause = SimilarFilterClause(value_param, weight_param, field_set)
altered_query_descriptor = self.__append_clause(clause)
return altered_query_descriptor

def __validate_schema(self, field_set: SpaceFieldSet) -> None:
    """Check that the query's schema is registered on the field set's space.

    Args:
        field_set: The field set whose space's `_embedding_node_by_schema`
            is searched for `self.schema`.

    Raises:
        InvalidSchemaException: If `self.schema` is not contained in
            `field_set.space._embedding_node_by_schema`.
    """
    registered_schemas = field_set.space._embedding_node_by_schema
    if self.schema in registered_schemas:
        return
    raise InvalidSchemaException(f"'find' ({type(self.schema)}) is not in similarity field's schema types.")

def limit(self, limit: IntParamType | None) -> QueryDescriptor:
"""
Set a limit to the number of results returned by the query.
Expand Down Expand Up @@ -407,16 +407,9 @@ def get_looks_like_filter(
looks_like_filter = looks_like_clause.evaluate() if looks_like_clause is not None else None
return looks_like_filter

def get_similar_filters(
self,
) -> dict[Space, list[EvaluatedBinaryPredicate[SimilarPredicate]]]:
similar_filters_by_space = defaultdict(list)
for clause in self.get_clauses_by_type(SimilarFilterClause):
space_and_similar_filter = clause.evaluate()
if space_and_similar_filter is not None:
space, similar = space_and_similar_filter
similar_filters_by_space[space].append(similar)
return dict(similar_filters_by_space)
def get_similar_filters_spaces(self) -> list[Space]:
    """Collect the space of every similar clause that evaluates to a result.

    Returns:
        The first element of each non-None `evaluate()` result of the
        `SimilarFilterClause` clauses, in clause order.
    """
    spaces: list[Space] = []
    for clause in self.get_clauses_by_type(SimilarFilterClause):
        evaluated = clause.evaluate()
        if evaluated is None:
            # Clauses that evaluate to None contribute no space.
            continue
        spaces.append(evaluated[0])
    return spaces

def get_context_time(self, default: int | Any) -> int:
if (overridden_now_clause := self.get_clause_by_type(OverriddenNowClause)) is not None:
Expand Down Expand Up @@ -451,7 +444,7 @@ def get_param_value_to_set_for_unset_space_weight_clauses(self) -> dict[str, flo
}
if self.get_looks_like_filter() is not None:
return {param_name: constants.DEFAULT_WEIGHT for param_name in unset_space_by_param_name.keys()}
similar_filter_spaces = self.get_similar_filters().keys()
similar_filter_spaces = self.get_similar_filters_spaces()
return {
param_name: (
constants.DEFAULT_WEIGHT if space in similar_filter_spaces else constants.DEFAULT_NOT_AFFECTING_WEIGHT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def __init__(
InvalidSchemaException: If a schema object does not have a corresponding node in the similarity space,
indicating a configuration or implementation error.
"""
TypeValidator.validate_list_item_type(categories, str, "categories")
# TODO FAI-2843 this type ignore is not needed but linting is flaky in CI
super().__init__(
category_input,
String | StringList, # type: ignore[misc] # interface supports only one type
)
TypeValidator.validate_list_item_type(categories, str, "categories")
self.__category = SpaceFieldSet[list[str]](
self,
set(category_input if isinstance(category_input, list) else [category_input]),
Expand Down
69 changes: 44 additions & 25 deletions framework/src/framework/dsl/space/image_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from superlinked.framework.common.dag.image_embedding_node import ImageEmbeddingNode
from superlinked.framework.common.dag.schema_field_node import SchemaFieldNode
from superlinked.framework.common.data_types import Vector
from superlinked.framework.common.schema.blob_information import BlobInformation
from superlinked.framework.common.schema.image_data import ImageData
from superlinked.framework.common.schema.schema_object import (
Blob,
DescribedBlob,
SchemaField,
SchemaObject,
String,
)
Expand Down Expand Up @@ -99,37 +99,19 @@ def __init__(
InvalidSpaceParamException: If the image and description fields are not
from the same schema.
"""
described_blobs = [self._get_described_blob(img) for img in (image if isinstance(image, Sequence) else [image])]
image_fields = [described.blob for described in described_blobs]
self.__validate_field_schemas(image)
image_fields, description_fields = self._split_images_from_descriptions(image)
super().__init__(image_fields, Blob)
length = ImageEmbedding.init_manager(model_handler, model, model_cache_dir).calculate_length()
self.image = ImageSpaceFieldSet(self, set(image_fields))
self.description = ImageDescriptionSpaceFieldSet(
self, set(described.description for described in described_blobs)
self, set(description for description in description_fields if description is not None)
)
self._all_fields = self.image.fields | self.description.fields
self._transformation_config = self._init_transformation_config(model, length, model_handler)
self._schema_field_nodes_by_schema: dict[
SchemaObject, tuple[SchemaFieldNode[BlobInformation], SchemaFieldNode[str]]
] = {
described_blob.blob.schema_obj: (
SchemaFieldNode(described_blob.blob),
SchemaFieldNode(described_blob.description),
)
for described_blob in described_blobs
}
self.__embedding_node_by_schema: dict[SchemaObject, EmbeddingNode[Vector, ImageData]] = {
schema: ImageEmbeddingNode(
image_blob_node=image_blob_node,
description_node=description_node,
transformation_config=self.transformation_config,
fields_for_identification=self._all_fields,
)
for schema, (
image_blob_node,
description_node,
) in self._schema_field_nodes_by_schema.items()
}
self.__embedding_node_by_schema = self._init_embedding_node_by_schema(
image_fields, description_fields, self._all_fields, self.transformation_config
)
self._model = model

def _get_described_blob(self, image: Blob | DescribedBlob) -> DescribedBlob:
Expand All @@ -140,6 +122,26 @@ def _get_described_blob(self, image: Blob | DescribedBlob) -> DescribedBlob:
description = String(DEFAULT_DESCRIPTION_FIELD_PREFIX + image.name, image.schema_obj)
return DescribedBlob(image, description)

def __validate_field_schemas(self, images: Blob | DescribedBlob | Sequence[Blob | DescribedBlob]) -> None:
    """Reject described blobs whose description lives in a different schema.

    Args:
        images: A single image input or a sequence of them. Only
            `DescribedBlob` items are checked; other items are ignored.

    Raises:
        InvalidSpaceParamException: If any `DescribedBlob` has a description
            whose `schema_obj` differs from its blob's `schema_obj`.
    """
    candidates = images if isinstance(images, Sequence) else [images]
    for candidate in candidates:
        if not isinstance(candidate, DescribedBlob):
            continue
        if candidate.description.schema_obj != candidate.blob.schema_obj:
            raise InvalidSpaceParamException("ImageSpace image and description field must be in the same schema.")

def _split_images_from_descriptions(
    self, images: Blob | DescribedBlob | Sequence[Blob | DescribedBlob]
) -> tuple[list[Blob], list[String | None]]:
    """Split image inputs into parallel lists of blobs and descriptions.

    Args:
        images: A single `Blob`/`DescribedBlob` or a sequence of them.

    Returns:
        Two lists of equal length: the blob of every input, and the matching
        description (`None` for inputs that are not a `DescribedBlob`).
        Both lists are empty when `images` is an empty sequence.
    """
    image_list = images if isinstance(images, Sequence) else [images]
    # Build the lists explicitly: unpacking `zip(*pairs)` into two names raises
    # ValueError when there are no pairs, which would mask the caller's own
    # handling of empty input.
    blobs: list[Blob] = []
    descriptions: list[String | None] = []
    for image in image_list:
        if isinstance(image, DescribedBlob):
            blobs.append(image.blob)
            descriptions.append(image.description)
        else:
            blobs.append(image)
            descriptions.append(None)
    return blobs, descriptions

@property
@override
def transformation_config(self) -> TransformationConfig[Vector, ImageData]:
Expand Down Expand Up @@ -185,3 +187,20 @@ def _init_transformation_config(
aggregation_config = VectorAggregationConfig(Vector)
normalization_config = L2NormConfig()
return TransformationConfig(normalization_config, aggregation_config, embedding_config)

def _init_embedding_node_by_schema(
    self,
    image_fields: Sequence[Blob],
    description_fields: Sequence[String | None],
    all_fields: set[SchemaField],
    transformation_config: TransformationConfig[Vector, ImageData],
) -> dict[SchemaObject, EmbeddingNode[Vector, ImageData]]:
    """Build one `ImageEmbeddingNode` per image field, keyed by its schema.

    Args:
        image_fields: Image fields; each one's `schema_obj` becomes a key.
        description_fields: Descriptions paired positionally with
            `image_fields`; a `None` entry yields a `None` description node.
        all_fields: Passed to every node as `fields_for_identification`.
        transformation_config: Passed unchanged to every node.

    Returns:
        Mapping from schema object to the embedding node built for it. If two
        image fields share a schema object, the later one wins.
    """
    node_by_schema: dict[SchemaObject, EmbeddingNode[Vector, ImageData]] = {}
    for blob_field, text_field in zip(image_fields, description_fields):
        schema = blob_field.schema_obj
        blob_node = SchemaFieldNode(blob_field)
        text_node = None if text_field is None else SchemaFieldNode(text_field)
        node_by_schema[schema] = ImageEmbeddingNode(
            image_blob_node=blob_node,
            description_node=text_node,
            transformation_config=transformation_config,
            fields_for_identification=all_fields,
        )
    return node_by_schema
1 change: 0 additions & 1 deletion framework/src/framework/dsl/space/image_space_field_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

@dataclass
class ImageSpaceFieldSet(SpaceFieldSet[ImageData]):

@property
@override
def input_type(self) -> type:
Expand Down
3 changes: 1 addition & 2 deletions framework/src/framework/dsl/space/number_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def __init__( # pylint: disable=too-many-arguments
negative_filter (float): This is a value that will be set for everything that is equal or
lower than the min_value. It can be a float. It defaults to 0 (No effect)
"""
self._aggregation_mode = aggregation_mode # this must be set before super init for _handle_node_not_present
self._embedding_config = NumberEmbeddingConfig(
float,
float(min_value),
Expand All @@ -119,7 +118,7 @@ def __init__( # pylint: disable=too-many-arguments
number_fields = number if isinstance(number, list) else [number]
self.number = SpaceFieldSet[float](self, set(number_fields))
self._aggregation_config_type_by_mode = self.__init_aggregation_config_type_by_mode()
self._transformation_config = self._init_transformation_config(self._embedding_config, self._aggregation_mode)
self._transformation_config = self._init_transformation_config(self._embedding_config, aggregation_mode)
self.__schema_node_map: dict[SchemaObject, EmbeddingNode[float, float]] = {
number_field.schema_obj: NumberEmbeddingNode(
parent=SchemaFieldNode(number_field),
Expand Down
8 changes: 4 additions & 4 deletions framework/src/framework/dsl/space/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from abc import abstractmethod

from beartype.typing import Generic, TypeAlias, TypeVar
from beartype.typing import Generic, Sequence, TypeAlias, TypeVar
from typing_extensions import override

from superlinked.framework.common.dag.embedding_node import EmbeddingNode
Expand Down Expand Up @@ -54,16 +54,16 @@ class Space(

def __init__(
self,
fields: SpaceSchemaFieldT | list[SpaceSchemaFieldT],
fields: SpaceSchemaFieldT | Sequence[SpaceSchemaFieldT],
type_: type | TypeAlias,
) -> None:
super().__init__()
field_list: list[SpaceSchemaFieldT] = fields if isinstance(fields, list) else [fields]
field_list = fields if isinstance(fields, Sequence) else [fields]
TypeValidator.validate_list_item_type(field_list, type_, "field_list")
self.__validate_fields(field_list)
self._field_set = set(field_list)

def __validate_fields(self, field_list: list[SpaceSchemaFieldT]) -> None:
def __validate_fields(self, field_list: Sequence[SpaceSchemaFieldT]) -> None:
if not self._allow_empty_fields and not field_list:
raise InvalidSpaceParamException(f"{self.__class__.__name__} field input must not be empty.")
schema_list = [field.schema_obj for field in field_list]
Expand Down
14 changes: 12 additions & 2 deletions framework/src/framework/dsl/space/space_field_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@ class SpaceFieldSet(Generic[SIT]):
fields: set[SchemaField]

def __post_init__(self) -> None:
self.__schema_field_map = {field.schema_obj: field for field in self.fields}
self._schema_field_map = {field.schema_obj: field for field in self.fields}
self._input_type: type[SIT] = GenericClassUtil.get_generic_types(self.space)[1]
self._fields_id = self.__generate_fields_id(self.fields)

# Read-only accessor for `_input_type`, which `__post_init__` sets to the element at
# index 1 of `GenericClassUtil.get_generic_types(self.space)`.
@property
def input_type(self) -> type[SIT]:
return self._input_type

# Read-only accessor for `_fields_id`, the string `__post_init__` builds from `self.fields`
# via `__generate_fields_id` (sorted "<schema name>_<field name>" parts joined with "_").
@property
def fields_id(self) -> str:
return self._fields_id

@property
def field_names_text(self) -> str:
    """Return the fields as one comma-separated string.

    Each field is rendered as `<schema name>.<field name>`. The parts are
    sorted so the text does not depend on the iteration order of the
    `fields` set.

    Returns:
        The sorted `<schema name>.<field name>` parts joined with ",".
    """
    return ",".join(sorted(f"{field.schema_obj._schema_name}.{field.name}" for field in self.fields))
Expand All @@ -50,4 +55,9 @@ def _generate_space_input(self, value: PythonTypes) -> SIT:
return cast(SIT, value)

def get_field_for_schema(self, schema_: Any) -> SchemaField | None:
return self.__schema_field_map.get(schema_)
return self._schema_field_map.get(schema_)

def __generate_fields_id(self, fields: set[SchemaField]) -> str:
    """Derive an order-independent identifier string for a set of fields.

    Args:
        fields: The fields to identify.

    Returns:
        The `<schema name>_<field name>` part of every field, sorted and
        joined with "_".
    """
    return "_".join(sorted(f"{field.schema_obj._schema_name}_{field.name}" for field in fields))

0 comments on commit 101b464

Please sign in to comment.