forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Retriever based on GCP DocAI Warehouse (langchain-ai#11400)
- **Description:** implements a retriever on top of DocAI Warehouse (to interact with existing enterprise documents) https://cloud.google.com/document-ai-warehouse?hl=en - **Issue:** new functionality @baskaryan --------- Co-authored-by: Bagatur <[email protected]>
- Loading branch information
Showing
4 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
libs/langchain/langchain/retrievers/google_cloud_documentai_warehouse.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Retriever wrapper for Google Cloud Document AI Warehouse.""" | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||
|
||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun | ||
from langchain.docstore.document import Document | ||
from langchain.pydantic_v1 import root_validator | ||
from langchain.schema import BaseRetriever | ||
from langchain.utils import get_from_dict_or_env | ||
|
||
if TYPE_CHECKING: | ||
from google.cloud.contentwarehouse_v1 import ( | ||
DocumentServiceClient, | ||
RequestMetadata, | ||
SearchDocumentsRequest, | ||
) | ||
from google.cloud.contentwarehouse_v1.services.document_service.pagers import ( | ||
SearchDocumentsPager, | ||
) | ||
|
||
|
||
class GoogleDocumentAIWarehouseRetriever(BaseRetriever): | ||
"""A retriever based on Document AI Warehouse. | ||
Documents should be created and documents should be uploaded | ||
in a separate flow, and this retriever uses only Document AI | ||
schema_id provided to search for revelant documents. | ||
More info: https://cloud.google.com/document-ai-warehouse. | ||
""" | ||
|
||
location: str = "us" | ||
"GCP location where DocAI Warehouse is placed." | ||
project_number: str | ||
"GCP project number, should contain digits only." | ||
schema_id: Optional[str] = None | ||
"DocAI Warehouse schema to queary against. If nothing is provided, all documents " | ||
"in the project will be searched." | ||
qa_size_limit: int = 5 | ||
"The limit on the number of documents returned." | ||
client: "DocumentServiceClient" = None #: :meta private: | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validates the environment.""" | ||
try: # noqa: F401 | ||
from google.cloud.contentwarehouse_v1 import ( | ||
DocumentServiceClient, | ||
) | ||
except ImportError as exc: | ||
raise ImportError( | ||
"google.cloud.contentwarehouse is not installed." | ||
"Please install it with pip install google-cloud-contentwarehouse" | ||
) from exc | ||
|
||
values["project_number"] = get_from_dict_or_env( | ||
values, "project_number", "PROJECT_NUMBER" | ||
) | ||
values["client"] = DocumentServiceClient() | ||
return values | ||
|
||
def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata": | ||
from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo | ||
|
||
user_info = UserInfo(id=f"user:{user_ldap}") | ||
return RequestMetadata(user_info=user_info) | ||
|
||
def _get_relevant_documents( | ||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any | ||
) -> List[Document]: | ||
request = self._prepare_search_request(query, **kwargs) | ||
response = self.client.search_documents(request=request) | ||
return self._parse_search_response(response=response) | ||
|
||
def _prepare_search_request( | ||
self, query: str, **kwargs: Any | ||
) -> "SearchDocumentsRequest": | ||
from google.cloud.contentwarehouse_v1 import ( | ||
DocumentQuery, | ||
SearchDocumentsRequest, | ||
) | ||
|
||
try: | ||
user_ldap = kwargs["user_ldap"] | ||
except KeyError: | ||
raise ValueError("Argument user_ldap should be provided!") | ||
|
||
request_metadata = self._prepare_request_metadata(user_ldap=user_ldap) | ||
schemas = [] | ||
if self.schema_id: | ||
schemas.append( | ||
self.client.document_schema_path( | ||
project=self.project_number, | ||
location=self.location, | ||
document_schema=self.schema_id, | ||
) | ||
) | ||
return SearchDocumentsRequest( | ||
parent=self.client.common_location_path(self.project_number, self.location), | ||
request_metadata=request_metadata, | ||
document_query=DocumentQuery( | ||
query=query, is_nl_query=True, document_schema_names=schemas | ||
), | ||
qa_size_limit=self.qa_size_limit, | ||
) | ||
|
||
def _parse_search_response( | ||
self, response: "SearchDocumentsPager" | ||
) -> List[Document]: | ||
documents = [] | ||
for doc in response.matching_documents: | ||
metadata = { | ||
"title": doc.document.title, | ||
"source": doc.document.raw_document_path, | ||
} | ||
documents.append( | ||
Document(page_content=doc.search_text_snippet, metadata=metadata) | ||
) | ||
return documents |
25 changes: 25 additions & 0 deletions
25
libs/langchain/tests/integration_tests/retrievers/test_google_docai_warehoure_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
"""Test Google Cloud Document AI Warehouse retriever.""" | ||
import os | ||
|
||
from langchain.retrievers import GoogleDocumentAIWarehouseRetriever | ||
from langchain.schema import Document | ||
|
||
|
||
def test_google_documentai_warehoure_retriever() -> None: | ||
"""In order to run this test, you should provide a project_id and user_ldap. | ||
Example: | ||
export USER_LDAP=... | ||
export PROJECT_NUMBER=... | ||
""" | ||
project_number = os.environ["PROJECT_NUMBER"] | ||
user_ldap = os.environ["USER_LDAP"] | ||
docai_wh_retriever = GoogleDocumentAIWarehouseRetriever( | ||
project_number=project_number | ||
) | ||
documents = docai_wh_retriever.get_relevant_documents( | ||
"What are Alphabet's Other Bets?", user_ldap=user_ldap | ||
) | ||
assert len(documents) > 0 | ||
for doc in documents: | ||
assert isinstance(doc, Document) |