Skip to content

Commit

Permalink
Retriever based on GCP DocAI Warehouse (langchain-ai#11400)
Browse files Browse the repository at this point in the history
- **Description:** implements a retriever on top of DocAI Warehouse (to
interact with existing enterprise documents)
  https://cloud.google.com/document-ai-warehouse?hl=en
  - **Issue:** new functionality
 
@baskaryan

---------

Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
lkuligin and baskaryan authored Oct 12, 2023
1 parent 629d9b7 commit 2aba9ab
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 0 deletions.
17 changes: 17 additions & 0 deletions docs/docs/integrations/platforms/google.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ See a [usage example](/docs/integrations/retrievers/google_vertex_ai_search).
from langchain.retrievers import GoogleVertexAISearchRetriever
```

### Document AI Warehouse
> [Google Cloud Document AI Warehouse](https://cloud.google.com/document-ai-warehouse)
> allows enterprises to search, store, govern, and manage documents and their AI-extracted
> data and metadata in a single platform. Documents should be uploaded outside of LangChain.
>
```python
from langchain.retrievers import GoogleDocumentAIWarehouseRetriever
# Create a retriever bound to the GCP project that hosts the warehouse.
docai_wh_retriever = GoogleDocumentAIWarehouseRetriever(
project_number=...
)
query = ...
# `user_ldap` is required on every search: the retriever raises ValueError
# without it and sends it as the user info of the search request.
documents = docai_wh_retriever.get_relevant_documents(
query, user_ldap=...
)
```

## Tools
### Google Search

Expand Down
4 changes: 4 additions & 0 deletions libs/langchain/langchain/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from langchain.retrievers.docarray import DocArrayRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.google_cloud_documentai_warehouse import (
GoogleDocumentAIWarehouseRetriever,
)
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
Expand Down Expand Up @@ -74,6 +77,7 @@
"ContextualCompressionRetriever",
"ChaindeskRetriever",
"ElasticSearchBM25Retriever",
"GoogleDocumentAIWarehouseRetriever",
"GoogleCloudEnterpriseSearchRetriever",
"GoogleVertexAISearchRetriever",
"KayAiRetriever",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Retriever wrapper for Google Cloud Document AI Warehouse."""
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.pydantic_v1 import root_validator
from langchain.schema import BaseRetriever
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
from google.cloud.contentwarehouse_v1 import (
DocumentServiceClient,
RequestMetadata,
SearchDocumentsRequest,
)
from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
SearchDocumentsPager,
)


class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
    """A retriever based on Document AI Warehouse.

    Documents should be created and uploaded in a separate flow; this
    retriever only searches for relevant documents, optionally restricted to
    the Document AI Warehouse schema given by ``schema_id``.

    Every search must be called with a ``user_ldap`` keyword argument, which
    is sent to the service as the user info of the request.

    More info: https://cloud.google.com/document-ai-warehouse.
    """

    location: str = "us"
    "GCP location where DocAI Warehouse is placed."
    project_number: str
    "GCP project number, should contain digits only."
    schema_id: Optional[str] = None
    "DocAI Warehouse schema to query against. If nothing is provided, all documents "
    "in the project will be searched."
    qa_size_limit: int = 5
    "The limit on the number of documents returned."
    client: "DocumentServiceClient" = None  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment and create the service client.

        Looks up ``project_number`` in ``values`` or in the ``PROJECT_NUMBER``
        environment variable and stores a new ``DocumentServiceClient`` under
        ``values["client"]``.

        Raises:
            ImportError: If ``google-cloud-contentwarehouse`` is not installed.
        """
        try:
            from google.cloud.contentwarehouse_v1 import DocumentServiceClient
        except ImportError as exc:
            # The two literals are concatenated, so the first one must end
            # with a space to keep the sentences apart.
            raise ImportError(
                "google.cloud.contentwarehouse is not installed. "
                "Please install it with pip install google-cloud-contentwarehouse"
            ) from exc

        values["project_number"] = get_from_dict_or_env(
            values, "project_number", "PROJECT_NUMBER"
        )
        values["client"] = DocumentServiceClient()
        return values

    def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata":
        """Build the request metadata identifying the user behind the search."""
        from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo

        user_info = UserInfo(id=f"user:{user_ldap}")
        return RequestMetadata(user_info=user_info)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Search Document AI Warehouse and return the matching documents.

        Args:
            query: Natural-language query to search for.
            run_manager: Callback manager of the retriever run (unused here).
            **kwargs: Must contain ``user_ldap``.

        Raises:
            ValueError: If ``user_ldap`` is not provided.
        """
        request = self._prepare_search_request(query, **kwargs)
        response = self.client.search_documents(request=request)
        return self._parse_search_response(response=response)

    def _prepare_search_request(
        self, query: str, **kwargs: Any
    ) -> "SearchDocumentsRequest":
        """Assemble the ``SearchDocumentsRequest`` for ``query``.

        Raises:
            ValueError: If ``user_ldap`` is missing from ``kwargs``.
        """
        from google.cloud.contentwarehouse_v1 import (
            DocumentQuery,
            SearchDocumentsRequest,
        )

        # Plain guard instead of try/except so the ValueError is not reported
        # as raised "during handling" of an internal KeyError.
        if "user_ldap" not in kwargs:
            raise ValueError("Argument user_ldap should be provided!")

        request_metadata = self._prepare_request_metadata(
            user_ldap=kwargs["user_ldap"]
        )

        # An empty schema list leaves the query unrestricted by schema.
        schemas = []
        if self.schema_id:
            schemas.append(
                self.client.document_schema_path(
                    project=self.project_number,
                    location=self.location,
                    document_schema=self.schema_id,
                )
            )
        return SearchDocumentsRequest(
            parent=self.client.common_location_path(self.project_number, self.location),
            request_metadata=request_metadata,
            document_query=DocumentQuery(
                query=query, is_nl_query=True, document_schema_names=schemas
            ),
            qa_size_limit=self.qa_size_limit,
        )

    def _parse_search_response(
        self, response: "SearchDocumentsPager"
    ) -> List[Document]:
        """Convert the matching documents of ``response`` into ``Document`` objects.

        The search text snippet becomes the page content; the title and raw
        document path are stored in the metadata as ``title`` and ``source``.
        """
        return [
            Document(
                page_content=match.search_text_snippet,
                metadata={
                    "title": match.document.title,
                    "source": match.document.raw_document_path,
                },
            )
            for match in response.matching_documents
        ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Test Google Cloud Document AI Warehouse retriever."""
import os

from langchain.retrievers import GoogleDocumentAIWarehouseRetriever
from langchain.schema import Document


def test_google_documentai_warehoure_retriever() -> None:
    """Run a live search against Document AI Warehouse.

    The project number and the user LDAP are read from the environment, so
    both variables must be set before running this test.
    Example:
    export USER_LDAP=...
    export PROJECT_NUMBER=...
    """
    number = os.environ["PROJECT_NUMBER"]
    ldap = os.environ["USER_LDAP"]
    retriever = GoogleDocumentAIWarehouseRetriever(project_number=number)
    results = retriever.get_relevant_documents(
        "What are Alphabet's Other Bets?", user_ldap=ldap
    )
    # At least one hit is expected, and every hit must be a Document.
    assert len(results) > 0
    assert all(isinstance(item, Document) for item in results)

0 comments on commit 2aba9ab

Please sign in to comment.