Skip to content

Commit

Permalink
Added sharepoint connector (onyx-dot-app#963)
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen6835 authored Jan 25, 2024
1 parent e94fd8b commit d6d83e7
Show file tree
Hide file tree
Showing 12 changed files with 590 additions and 2 deletions.
1 change: 1 addition & 0 deletions backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ api_keys.py
.env
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule
1 change: 1 addition & 0 deletions backend/danswer/configs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class DocumentSource(str, Enum):
GOOGLE_SITES = "google_sites"
ZENDESK = "zendesk"
LOOPIO = "loopio"
SHAREPOINT = "sharepoint"


class DocumentIndexType(str, Enum):
Expand Down
16 changes: 16 additions & 0 deletions backend/danswer/connectors/cross_connector_utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,19 @@ def read_file(
file_content_raw += line

return file_content_raw, metadata


def is_text_file_extension(file_name: str) -> bool:
    """Return True if *file_name* ends with a known plain-text extension.

    Only the suffix is checked; file contents are never inspected.
    """
    # Deduplicated: the original tuple listed ".json" twice.
    # str.endswith accepts a tuple, so a single call covers every extension.
    extensions = (
        ".txt",
        ".mdx",
        ".md",
        ".conf",
        ".log",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
    )
    return file_name.endswith(extensions)
2 changes: 2 additions & 0 deletions backend/danswer/connectors/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.connectors.sharepoint.connector import SharepointConnector


class ConnectorMissingException(Exception):
Expand Down Expand Up @@ -68,6 +69,7 @@ def identify_connector_class(
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
DocumentSource.ZENDESK: ZendeskConnector,
DocumentSource.LOOPIO: LoopioConnector,
DocumentSource.SHAREPOINT: SharepointConnector,
}
connector_by_source = connector_map.get(source, {})

Expand Down
Empty file.
266 changes: 266 additions & 0 deletions backend/danswer/connectors/sharepoint/connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import io
import os
import tempfile
from datetime import datetime
from datetime import timezone
from typing import Any

import docx # type: ignore
import msal # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.file_utils import is_text_file_extension
from danswer.connectors.cross_connector_utils.file_utils import read_pdf_file
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger

UNSUPPORTED_FILE_TYPE_CONTENT = "" # idea copied from the google drive side of things


logger = setup_logger()


def get_text_from_xlsx_driveitem(driveitem_object: DriveItem) -> str:
    """Download an .xlsx drive item and flatten every worksheet to comma-joined rows."""
    raw_bytes = driveitem_object.get_content().execute_query().value
    workbook = openpyxl.load_workbook(io.BytesIO(raw_bytes), read_only=True)

    sheet_texts: list[str] = []
    for worksheet in workbook.worksheets:
        row_lines = []
        for row in worksheet.iter_rows(min_row=1, values_only=True):
            # str() is applied to every cell, so empty cells render as "None",
            # matching the original behavior.
            row_lines.append(",".join(str(cell) for cell in row))
        sheet_texts.append("\n".join(row_lines))

    return "\n".join(sheet_texts)


def get_text_from_docx_driveitem(driveitem_object: DriveItem) -> str:
    """Download a .docx drive item and return its paragraph text, newline-joined."""
    file_content = driveitem_object.get_content().execute_query().value

    with tempfile.TemporaryDirectory() as local_path:
        file_path = os.path.join(local_path, driveitem_object.name)
        # Close (and therefore flush) the file BEFORE handing it to python-docx.
        # The original parsed the file while it was still open for buffered
        # writing, so python-docx could see a partially written file.
        with open(file_path, "wb") as local_file:
            local_file.write(file_content)
        doc = docx.Document(file_path)
        return "\n".join(para.text for para in doc.paragraphs)


def get_text_from_pdf_driveitem(driveitem_object: DriveItem) -> str:
    """Download a .pdf drive item and extract its text via the shared PDF reader."""
    pdf_bytes = driveitem_object.get_content().execute_query().value
    return read_pdf_file(file=io.BytesIO(pdf_bytes), file_name=driveitem_object.name)


def get_text_from_txt_driveitem(driveitem_object: DriveItem) -> str:
    """Download a plain-text drive item and decode its bytes as UTF-8."""
    raw: bytes = driveitem_object.get_content().execute_query().value
    return raw.decode("utf-8")


def get_text_from_pptx_driveitem(driveitem_object: DriveItem) -> str:
    """Download a .pptx drive item and extract slide-by-slide shape text."""
    file_content = driveitem_object.get_content().execute_query().value
    with tempfile.NamedTemporaryFile() as temp:
        temp.write(file_content)
        # Flush before parsing: python-pptx reopens the file by name and would
        # otherwise read a partially written (buffered) file.
        temp.flush()
        presentation = pptx.Presentation(temp.name)
        extracted_text = ""
        for slide_number, slide in enumerate(presentation.slides, start=1):
            extracted_text += f"\nSlide {slide_number}:\n"
            for shape in slide.shapes:
                # Not every shape carries text (e.g. pictures).
                if hasattr(shape, "text"):
                    extracted_text += shape.text + "\n"
        return extracted_text


class SharepointConnector(LoadConnector, PollConnector):
    """Connector that indexes files stored in SharePoint document libraries.

    Authenticates against Azure AD with client credentials (MSAL) and walks
    the requested sites via the Microsoft Graph API, yielding Documents in
    batches of ``batch_size``.
    """

    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        sites: list[str] | None = None,
    ) -> None:
        self.batch_size = batch_size
        self.graph_client: GraphClient | None = None
        # `sites` previously defaulted to a mutable `[]`, which Python shares
        # across all calls; normalize None -> a fresh list instead.
        self.requested_site_list: list[str] = sites if sites is not None else []

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Build a GraphClient from Azure AD application credentials.

        Expects keys: ``aad_client_id``, ``aad_client_secret``,
        ``aad_directory_id``. Always returns None (no updated credentials).
        """
        aad_client_id = credentials["aad_client_id"]
        aad_client_secret = credentials["aad_client_secret"]
        aad_directory_id = credentials["aad_directory_id"]

        def _acquire_token_func() -> dict[str, Any]:
            """Acquire an app-only Graph token via the MSAL client-credentials flow."""
            authority_url = f"https://login.microsoftonline.com/{aad_directory_id}"
            app = msal.ConfidentialClientApplication(
                authority=authority_url,
                client_id=aad_client_id,
                client_credential=aad_client_secret,
            )
            token = app.acquire_token_for_client(
                scopes=["https://graph.microsoft.com/.default"]
            )
            return token

        self.graph_client = GraphClient(_acquire_token_func)
        return None

    def get_all_driveitem_objects(
        self,
        site_object_list: list[Site],
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> list[DriveItem]:
        """Collect every file from every drive of the given sites.

        When both `start` and `end` are provided, a server-side
        last-modified-time filter is applied.
        """
        filter_str = ""
        if start is not None and end is not None:
            # NOTE(review): Graph's OData property is `lastModifiedDateTime`;
            # presumably the office365 client maps this snake_case name — confirm.
            filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"

        driveitem_list = []
        for site_object in site_object_list:
            site_list_objects = site_object.lists.get().execute_query()
            for site_list_object in site_list_objects:
                try:
                    query = site_list_object.drive.root.get_files(True)
                    if filter_str:
                        query = query.filter(filter_str)
                    driveitems = query.execute_query()
                    driveitem_list.extend(driveitems)
                except Exception:
                    # Some site lists have no .drive.root (no document library
                    # behind them), so this is expected to fail for those —
                    # deliberate best-effort skip, no documents are lost.
                    pass

        return driveitem_list

    def get_all_site_objects(self) -> list[Site]:
        """Return the Site objects to index: the requested ones, or all sites."""
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Sharepoint")

        site_object_list: list[Site] = []

        sites_object = self.graph_client.sites.get().execute_query()

        if len(self.requested_site_list) > 0:
            for requested_site in self.requested_site_list:
                # Match on the URL tail; spaces are stripped because site URLs
                # do not contain them.
                adjusted_string = "/" + requested_site.replace(" ", "")
                for site_object in sites_object:
                    if site_object.web_url.endswith(adjusted_string):
                        site_object_list.append(site_object)
        else:
            site_object_list.extend(sites_object)

        return site_object_list

    def _fetch_from_sharepoint(
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        """Yield Document batches for all drive items in the configured sites."""
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Sharepoint")

        site_object_list = self.get_all_site_objects()

        driveitem_list = self.get_all_driveitem_objects(
            site_object_list=site_object_list,
            start=start,
            end=end,
        )

        # Convert drive items into Document objects and yield them in batches.
        doc_batch: list[Document] = []
        for driveitem_object in driveitem_list:
            doc_batch.append(
                self.convert_driveitem_object_to_document(driveitem_object)
            )

            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []
        # Flush the final partial batch; skip the empty trailing yield the
        # original emitted when the item count was a multiple of batch_size.
        if doc_batch:
            yield doc_batch

    def convert_driveitem_object_to_document(
        self,
        driveitem_object: DriveItem,
    ) -> Document:
        """Map a Graph DriveItem to a danswer Document (one section, full text)."""
        file_text = self.extract_driveitem_text(driveitem_object)
        doc = Document(
            id=driveitem_object.id,
            sections=[Section(link=driveitem_object.web_url, text=file_text)],
            source=DocumentSource.SHAREPOINT,
            semantic_identifier=driveitem_object.name,
            # Tag the timestamp as UTC without converting it — assumes Graph
            # returns naive UTC datetimes; TODO confirm.
            doc_updated_at=driveitem_object.last_modified_datetime.replace(
                tzinfo=timezone.utc
            ),
            primary_owners=[
                BasicExpertInfo(
                    display_name=driveitem_object.last_modified_by.user.displayName,
                    email=driveitem_object.last_modified_by.user.email,
                )
            ],
            metadata={},
        )
        return doc

    def extract_driveitem_text(self, driveitem_object: DriveItem) -> str:
        """Dispatch on file extension to the matching text extractor.

        Unknown types yield UNSUPPORTED_FILE_TYPE_CONTENT (empty string).
        """
        driveitem_name = driveitem_object.name
        driveitem_text = UNSUPPORTED_FILE_TYPE_CONTENT

        if driveitem_name.endswith(".docx"):
            driveitem_text = get_text_from_docx_driveitem(driveitem_object)
        elif driveitem_name.endswith(".pdf"):
            driveitem_text = get_text_from_pdf_driveitem(driveitem_object)
        elif driveitem_name.endswith(".xlsx"):
            driveitem_text = get_text_from_xlsx_driveitem(driveitem_object)
        elif driveitem_name.endswith(".pptx"):
            # BUG FIX: .pptx previously went through the xlsx extractor.
            driveitem_text = get_text_from_pptx_driveitem(driveitem_object)
        elif is_text_file_extension(driveitem_name):
            driveitem_text = get_text_from_txt_driveitem(driveitem_object)

        return driveitem_text

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Full index: fetch everything, no time filter."""
        return self._fetch_from_sharepoint()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Incremental index over [start, end], given as seconds since the epoch."""
        # datetime.utcfromtimestamp is deprecated (3.12) and produces naive
        # datetimes; use timezone-aware UTC so isoformat() carries an offset.
        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
        return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)


if __name__ == "__main__":
    # Manual smoke test: requires SITES (comma-separated site names) plus
    # Azure AD application credentials in the environment.
    connector = SharepointConnector(sites=os.environ["SITES"].split(","))

    connector.load_credentials(
        {
            "aad_client_id": os.environ["AAD_CLIENT_ID"],
            "aad_client_secret": os.environ["AAD_CLIENT_SECRET"],
            "aad_directory_id": os.environ["AAD_CLIENT_DIRECTORY_ID"],
        }
    )
    # Fetch and print only the first batch of documents.
    document_batches = connector.load_from_state()
    print(next(document_batches))
7 changes: 5 additions & 2 deletions backend/requirements/default.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ celery==5.3.4
chardet==5.2.0
dask==2023.8.1
distributed==2023.8.1
python-dateutil==2.8.2
fastapi==0.103.0
fastapi-users==11.0.0
fastapi-users-db-sqlalchemy==5.0.0
Expand All @@ -28,18 +27,22 @@ llama-index==0.9.8
Mako==1.2.4
nltk==3.8.1
docx2txt==0.8
openai==1.3.5
oauthlib==3.2.2
openai==1.3.5
openpyxl==3.1.2
playwright==1.40.0
psutil==5.9.5
psycopg2-binary==2.9.9
pycryptodome==3.19.1
pydantic==1.10.7
PyGithub==1.58.2
python-dateutil==2.8.2
python-gitlab==3.9.0
python-pptx==0.6.23
pypdf==3.17.0
pytest-mock==3.12.0
pytest-playwright==0.3.2
python-docx==1.1.0
python-dotenv==1.0.0
python-multipart==0.0.6
requests==2.31.0
Expand Down
Binary file added web/public/Sharepoint.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit d6d83e7

Please sign in to comment.