Skip to content

Commit

Permalink
Feat/ibm memory (crewAIInc#1549)
Browse files Browse the repository at this point in the history
* Everything looks like it's working. Waiting for lorenze review.

* Update docs as well.

* clean up for PR
  • Loading branch information
bhancockio authored Nov 1, 2024
1 parent 34954e6 commit 3878daf
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 10 deletions.
25 changes: 25 additions & 0 deletions docs/concepts/memory.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,31 @@ my_crew = Crew(
)
```

### Using Watson embeddings

```python Code
from crewai import Crew, Agent, Task, Process

# Note: Ensure the `ibm_watsonx_ai` package is installed for Watson embeddings to work.

my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "watson",
"config": {
"model": "<model_name>",
"api_url": "<api_url>",
"api_key": "<YOUR_API_KEY>",
"project_id": "<YOUR_PROJECT_ID>",
}
}
)
```

### Resetting Memory

```shell
Expand Down
4 changes: 4 additions & 0 deletions src/crewai/memory/contextual/contextual_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _fetch_stm_context(self, query) -> str:
formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results]
)
print("formatted_results stm", formatted_results)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""

def _fetch_ltm_context(self, task) -> Optional[str]:
Expand All @@ -53,6 +54,8 @@ def _fetch_ltm_context(self, task) -> Optional[str]:
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")

print("formatted_results ltm", formatted_results)

return f"Historical Data:\n{formatted_results}" if ltm_results else ""

def _fetch_entity_context(self, query) -> str:
Expand All @@ -64,4 +67,5 @@ def _fetch_entity_context(self, query) -> str:
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
print("formatted_results em", formatted_results)
return f"Entities:\n{formatted_results}" if em_results else ""
62 changes: 52 additions & 10 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from typing import Any, Dict, List, Optional, cast

from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import cast
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path


@contextlib.contextmanager
Expand All @@ -21,9 +21,11 @@ def suppress_logging(
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
io.StringIO()
), contextlib.suppress(UserWarning):
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)

Expand Down Expand Up @@ -113,12 +115,52 @@ def __call__(self, input: Documents) -> Embeddings:
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
elif provider == "watson":
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import (
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e

class WatsonEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by IBM watsonx.ai text embeddings.

    Wraps ``ibm_watsonx_ai.foundation_models.Embeddings`` so Chroma can embed
    documents via Watson. Credentials, model id, endpoint URL, and project id
    are read from the enclosing ``config`` mapping (closure variable).
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` documents and return their embedding vectors.

        Args:
            input: A list of document strings (a bare string is promoted
                to a single-element list).

        Returns:
            The embeddings produced by Watson's ``embed_documents``.

        Raises:
            Exception: any error raised by the Watson SDK is re-raised
                unchanged for the caller to handle.
        """
        # Chroma may hand us a single string; normalize to a list.
        if isinstance(input, str):
            input = [input]

        # NOTE(review): TRUNCATE_INPUT_TOKENS=3 looks like a leftover
        # test value — confirm the intended truncation length.
        embed_params = {
            EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
            EmbedParams.RETURN_OPTIONS: {"input_text": True},
        }

        # A fresh client is built per call; config values come from the
        # enclosing embedder configuration.
        embedding = watson_models.Embeddings(
            model_id=config.get("model"),
            params=embed_params,
            credentials=Credentials(
                api_key=config.get("api_key"), url=config.get("api_url")
            ),
            project_id=config.get("project_id"),
        )

        try:
            embeddings = embedding.embed_documents(input)
            return cast(Embeddings, embeddings)
        except Exception:
            # Bare `raise` preserves the original traceback; the previous
            # debug print() was leftover development output and is removed.
            raise

self.embedder_config = WatsonEmbeddingFunction()
else:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
)
else:
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
validate_embedding_function(self.embedder_config)
self.embedder_config = self.embedder_config

def _initialize_app(self):
Expand Down

0 comments on commit 3878daf

Please sign in to comment.