Skip to content

Commit

Permalink
feat: 修改QA对组件生成的结果 (feat: change the output produced by the QA-pair generation component)
Browse files Browse the repository at this point in the history
  • Loading branch information
zgqgit committed Jun 26, 2024
1 parent 7e56844 commit 708e73f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,18 @@ def _call(
[{"text": d.page_content} for d in docs], run_manager=run_manager
)
qa = ''
qa_i = 0
for res in results.generations:
try:
# response = json.loads(parse_json(res[0].text))
qa += res[0].text
qa_i += 1
except Exception as e:
logger.error(f"Failed to parse response: {res[0].text}. Error: {e}")
continue

if self.k is not None:
if len(qa) >= self.k:
if qa_i >= self.k:
break
return {self.output_key: qa}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import dataclass
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel

try:
from llama_index.node_parser import SimpleNodeParser
from llama_index.readers.schema import Document as LlamaindexDocument
Expand Down Expand Up @@ -72,7 +73,6 @@ def load_as_json(text):
"conditional": 0.0,
}


DataRow = namedtuple(
"DataRow",
[
Expand Down Expand Up @@ -108,7 +108,6 @@ def to_pandas(self) -> pd.DataFrame:


class TrainsetGenerator:

"""
Ragas Train Set Generator
Expand All @@ -127,13 +126,13 @@ class TrainsetGenerator:
"""

def __init__(
self,
generator_llm: BaseLanguageModel,
critic_llm: BaseLanguageModel,
trainset_distribution: t.Optional[t.Dict[str, float]] = None,
chunk_size: int = 1024,
seed: int = 42,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
self,
generator_llm: BaseLanguageModel,
critic_llm: BaseLanguageModel,
trainset_distribution: t.Optional[t.Dict[str, float]] = None,
chunk_size: int = 1024,
seed: int = 42,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
) -> None:
self.generator_llm = generator_llm
self.critic_llm = critic_llm
Expand All @@ -154,11 +153,11 @@ def __init__(

@classmethod
def from_default(
cls,
llm: BaseLanguageModel,
chunk_size: int = 512,
trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
cls,
llm: BaseLanguageModel,
chunk_size: int = 512,
trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
):
generator_llm = llm
critic_llm = llm
Expand Down Expand Up @@ -227,14 +226,14 @@ def _generate_answer(self, question: str, context: t.List[str]) -> t.List[str]:
]

def _remove_nodes(
        self, available_indices: t.List[BaseNode], node_idx: t.List
) -> t.List[BaseNode]:
    """Remove every node listed in *node_idx* from *available_indices*.

    The removal happens in place (the same list object is mutated and
    returned); a ``ValueError`` propagates if a node is not present,
    exactly as ``list.remove`` behaves.
    """
    remaining = available_indices
    for unwanted in node_idx:
        # O(n) removal by equality, one entry at a time.
        remaining.remove(unwanted)
    return remaining

def _generate_doc_nodes_map(
self, document_nodes: t.List[BaseNode]
self, document_nodes: t.List[BaseNode]
) -> t.Dict[str, t.List[BaseNode]]:
doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list)
for node in document_nodes:
Expand All @@ -244,7 +243,7 @@ def _generate_doc_nodes_map(
return doc_nodes_map # type: ignore

def _get_neighbour_node(
self, node: BaseNode, related_nodes: t.List[BaseNode]
self, node: BaseNode, related_nodes: t.List[BaseNode]
) -> t.List[BaseNode]:
if len(related_nodes) < 2:
warnings.warn("No neighbors exists")
Expand All @@ -254,9 +253,9 @@ def _get_neighbour_node(
return [related_nodes[idx] for idx in ids]

def generate(
self,
documents: t.List[LlamaindexDocument] | t.List[Document],
train_size: int,
self,
documents: t.List[LlamaindexDocument] | t.List[Document],
train_size: int,
) -> TrainDataset:
if not isinstance(documents[0], (LlamaindexDocument, Document)):
raise ValueError(
Expand Down Expand Up @@ -323,7 +322,7 @@ def generate(
is_conv = len(context) > 1
answer = self._generate_answer(question, context)
for i, (qstn, ctx, ans) in enumerate(
zip(question.split("\n"), context, answer)
zip(question.split("\n"), context, answer)
):
episode_done = False if is_conv and i == 0 else True
samples.append(
Expand All @@ -350,13 +349,13 @@ class QAGenerationChainV2(Chain):

@classmethod
def from_llm(
cls,
documents: List[Document],
llm: BaseLanguageModel,
k: Optional[int] = None,
chunk_size: int = 512,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
**kwargs: Any,
cls,
documents: List[Document],
llm: BaseLanguageModel,
k: Optional[int] = None,
chunk_size: int = 512,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
**kwargs: Any,
) -> QAGenerationChainV2:
"""
Create a QAGenerationChain from a language model.
Expand Down Expand Up @@ -385,9 +384,9 @@ def output_keys(self) -> List[str]:
return [self.output_key]

def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, List]:
for doc in self.documents:
doc.metadata = {}
Expand All @@ -396,20 +395,19 @@ def _call(
dataset = self.generator.generate(documents=self.documents, train_size=self.k)
df = dataset.to_pandas()
qa_pairs = df.to_dict("records")
qa = []
qa = ''
for pair in qa_pairs:
qa.append(
qa += json.dumps(
{
"question": pair["question"],
"answer": pair["ground_truth"][0],
}
)
}, ensure_ascii=False)
return {self.output_key: qa}

async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, List]:
    """Async chain entry point.

    There is no真 asynchronous work here: it simply delegates to the
    synchronous ``_call`` implementation and hands its result back
    unchanged.
    """
    return self._call(inputs, run_manager)

0 comments on commit 708e73f

Please sign in to comment.