forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_laboratory.py
82 lines (71 loc) · 3.15 KB
/
model_laboratory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Experiment with different models."""
from __future__ import annotations
from typing import List, Optional, Sequence
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
class ModelLaboratory:
    """Experiment with different models."""

    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
        """Initialize with chains to experiment with.

        Args:
            chains: list of chains to experiment with. Each chain must have
                exactly one input key and exactly one output key.
            names: Optional display names, one per chain, printed as the
                header for that chain's output in ``compare``. When omitted,
                ``str(chain)`` is used as the header instead.

        Raises:
            ValueError: If an element of ``chains`` is not a ``Chain``, if a
                chain does not have exactly one input key or exactly one
                output key, or if ``names`` is given with a length that
                differs from the number of chains.
        """
        # Snapshot the inputs: chain_colors is built once from the current
        # length, so later mutation of the caller's sequences would otherwise
        # desynchronize chains, names and colors (KeyError/IndexError in
        # compare).
        chains = list(chains)
        for chain in chains:
            if not isinstance(chain, Chain):
                raise ValueError(
                    "ModelLaboratory should now be initialized with Chains. "
                    "If you want to initialize with LLMs, use the `from_llms` method "
                    "instead (`ModelLaboratory.from_llms(...)`)"
                )
            if len(chain.input_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one input variable, "
                    f"got {chain.input_keys}"
                )
            if len(chain.output_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one output variable, "
                    f"got {chain.output_keys}"
                )
        if names is not None:
            names = list(names)
            if len(names) != len(chains):
                raise ValueError("Length of chains does not match length of names.")
        self.chains = chains
        # Colors are keyed by the chain's index rendered as a string.
        chain_range = [str(i) for i in range(len(self.chains))]
        self.chain_colors = get_color_mapping(chain_range)
        self.names = names

    @classmethod
    def from_llms(
        cls, llms: Sequence[BaseLLM], prompt: Optional[PromptTemplate] = None
    ) -> ModelLaboratory:
        """Initialize with LLMs to experiment with and optional prompt.

        Each LLM is wrapped in an ``LLMChain`` that uses the same prompt, and
        ``str(llm)`` becomes that chain's display name.

        Args:
            llms: list of LLMs to experiment with.
            prompt: Optional prompt to use to prompt the LLMs. Defaults to None,
                in which case a pass-through template (``"{_input}"``) is used,
                so the input text is the entire prompt.
                If a prompt is provided, it should only have one input variable.

        Returns:
            A laboratory with one chain per LLM, in the order given.
        """
        if prompt is None:
            # Single-variable template that forwards the input text verbatim.
            prompt = PromptTemplate(input_variables=["_input"], template="{_input}")
        chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
        names = [str(llm) for llm in llms]
        return cls(chains, names=names)

    def compare(self, text: str) -> None:
        """Compare model outputs on an input text.

        The text is passed as the single input of every chain (``chain.run``).
        For a laboratory built with ``from_llms``, this means that if a prompt
        was provided when starting the laboratory, the text is fed into that
        prompt. If no prompt was provided, the input text is the entire prompt.

        For each chain, its name is printed, followed by its output in the
        chain's assigned color.

        Args:
            text: input text to run all models on.
        """
        # \033[1m ... \033[0m renders "Input:" in bold on ANSI terminals.
        print(f"\033[1mInput:\033[0m\n{text}\n")
        for i, chain in enumerate(self.chains):
            name = self.names[i] if self.names is not None else str(chain)
            print_text(name, end="\n")
            output = chain.run(text)
            print_text(output, color=self.chain_colors[str(i)], end="\n\n")