forked from langchain-ai/chat-langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfiguration.py
94 lines (73 loc) · 2.92 KB
/
configuration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Define the configurable parameters for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field, fields
from typing import Annotated, Any, Literal, Optional, Type, TypeVar
from langchain_core.runnables import RunnableConfig, ensure_config
MODEL_NAME_TO_RESPONSE_MODEL = {
"anthropic_claude_3_5_sonnet": "anthropic/claude-3-5-sonnet-20240620",
}
def _update_configurable_for_backwards_compatibility(
configurable: dict[str, Any],
) -> dict[str, Any]:
update = {}
if "k" in configurable:
update["search_kwargs"] = {"k": configurable["k"]}
if "model_name" in configurable:
update["response_model"] = MODEL_NAME_TO_RESPONSE_MODEL.get(
configurable["model_name"], configurable["model_name"]
)
if update:
return {**configurable, **update}
return configurable
@dataclass(kw_only=True)
class BaseConfiguration:
"""Configuration class for indexing and retrieval operations.
This class defines the parameters needed for configuring the indexing and
retrieval processes, including embedding model selection, retriever provider choice, and search parameters.
"""
embedding_model: Annotated[
str,
{"__template_metadata__": {"kind": "embeddings"}},
] = field(
default="openai/text-embedding-3-small",
metadata={
"description": "Name of the embedding model to use. Must be a valid embedding model name."
},
)
retriever_provider: Annotated[
Literal["weaviate"],
{"__template_metadata__": {"kind": "retriever"}},
] = field(
default="weaviate",
metadata={"description": "The vector store provider to use for retrieval."},
)
search_kwargs: dict[str, Any] = field(
default_factory=dict,
metadata={
"description": "Additional keyword arguments to pass to the search function of the retriever."
},
)
# for backwards compatibility
k: int = field(
default=6,
metadata={
"description": "The number of documents to retrieve. Use search_kwargs instead."
},
)
@classmethod
def from_runnable_config(
cls: Type[T], config: Optional[RunnableConfig] = None
) -> T:
"""Create an IndexConfiguration instance from a RunnableConfig object.
Args:
cls (Type[T]): The class itself.
config (Optional[RunnableConfig]): The configuration object to use.
Returns:
T: An instance of IndexConfiguration with the specified configuration.
"""
config = ensure_config(config)
configurable = config.get("configurable") or {}
configurable = _update_configurable_for_backwards_compatibility(configurable)
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields})
T = TypeVar("T", bound=BaseConfiguration)