
Commit

Merge pull request Azure#258 from Azure/matt/docintel-custom
Add document intelligence custom skill sample
mattgotteiner authored Aug 6, 2024
2 parents b100946 + e57cb4e commit 881de7e
Showing 21 changed files with 1,836 additions and 0 deletions.
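The sample is an Azure Functions app exposing two custom skills: read, which extracts text or markdown from documents with Azure AI Document Intelligence, and markdownsplit, which chunks markdown on headers. For context, a skill like read would typically be attached to an Azure AI Search skillset as a WebApiSkill; the following is a sketch of such a definition, expressed here as a Python dict. The function URL and field paths are assumptions for illustration, not part of this commit.

# Sketch of an Azure AI Search skillset entry pointing at the deployed 'read' skill.
read_skill = {
    "@odata.type": "#Microsoft.Skills.Custom.WebApiSkill",
    "uri": "https://<function-app>.azurewebsites.net/api/read",  # assumed deployment URL
    "httpMethod": "POST",
    "timeout": "PT230S",
    "context": "/document",
    # '/document/file_data' requires allowSkillsetToReadFileData on the indexer.
    "inputs": [{ "name": "file_data", "source": "/document/file_data" }],
    "outputs": [{ "name": "content", "targetName": "content" }]
}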
@@ -0,0 +1,4 @@
.azure
.venv
scripts/.venv
models
@@ -0,0 +1,2 @@
.venv
.vscode
@@ -0,0 +1,51 @@
models

bin
obj
csx
.vs
edge
Publish

*.user
*.suo
*.cscfg
*.Cache
project.lock.json

/packages
/TestResults

/tools/NuGet.exe
/App_Data
/secrets
/data
.secrets
appsettings.json
local.settings.json

node_modules
dist

# Local python packages
.python_packages/

# Python Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
models

# Azurite artifacts
__blobstorage__
__queuestorage__
__azurite_db*__.json
@@ -0,0 +1,5 @@
{
  "recommendations": [
    "ms-azuretools.vscode-azurefunctions"
  ]
}
@@ -0,0 +1,276 @@
import azure.functions as func
import json
import logging
import os
from azure.ai.documentintelligence.models import AnalyzeDocumentRequest
from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
from azure.identity.aio import DefaultAzureCredential
import base64
import io
from azure.core.exceptions import HttpResponseError
from typing import Any
from langchain_text_splitters.markdown import MarkdownHeaderTextSplitter
from langchain_text_splitters.character import RecursiveCharacterTextSplitter

app = func.FunctionApp()

@app.function_name(name="markdownsplit")
@app.route(route="markdownsplit")
async def SplitMarkdownDocument(req: func.HttpRequest) -> func.HttpResponse:
    input = {}
    result_content = []
    try:
        req_body = req.get_json()
        # Read input values.
        # 'content' is required; 'chunkSize', 'chunkOverlap' and 'encoderModelName' are optional.
        if "values" in req_body:
            for value in req_body["values"]:
                record_id = value["recordId"]
                chunk_size = 512
                if "chunkSize" in value["data"]:
                    try:
                        chunk_size = int(value["data"]["chunkSize"])
                    except Exception:
                        result_content.append(
                            {
                                "recordId": record_id,
                                "data": {},
                                "errors": [
                                    { "message": "'chunkSize' must be an int" }
                                ]
                            }
                        )
                        continue
                chunk_overlap = 128
                if "chunkOverlap" in value["data"]:
                    try:
                        chunk_overlap = int(value["data"]["chunkOverlap"])
                    except Exception:
                        result_content.append(
                            {
                                "recordId": record_id,
                                "data": {},
                                "errors": [
                                    { "message": "'chunkOverlap' must be an int" }
                                ]
                            }
                        )
                        continue
                encoder_model_name = "text-embedding-3-large"
                if "encoderModelName" in value["data"]:
                    encoder_model_name = value["data"]["encoderModelName"]
                    if encoder_model_name not in ["text-embedding-ada-002", "text-embedding-3-large", "text-embedding-3-small"]:
                        result_content.append(
                            {
                                "recordId": record_id,
                                "data": {},
                                "errors": [
                                    { "message": f"Unknown encoder model {encoder_model_name}" }
                                ]
                            }
                        )
                        continue
                if "content" in value["data"]:
                    input[record_id] = { "content": value["data"]["content"], "chunkSize": chunk_size, "chunkOverlap": chunk_overlap, "encoderModelName": encoder_model_name }
                else:
                    result_content.append(
                        {
                            "recordId": record_id,
                            "data": {},
                            "errors": [
                                { "message": "Expected 'content'" }
                            ]
                        }
                    )
    except Exception:
        logging.exception("Could not get input data request body")
        return func.HttpResponse(
            "Invalid input",
            status_code=400
        )

    # Split the document into chunks based on markdown headers.
    # Max chunking goes 8 headers deep.
    headers_to_split_on = [
        ("#", "1"),
        ("##", "2"),
        ("###", "3"),
        ("####", "4"),
        ("#####", "5"),
        ("######", "6"),
        ("#######", "7"),
        ("########", "8")
    ]
    text_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, return_each_line=False, strip_headers=False)
    for record_id, data in input.items():
        encoder_model_name = data["encoderModelName"]
        character_splitter: RecursiveCharacterTextSplitter
        try:
            if encoder_model_name:
                character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(model_name=encoder_model_name, chunk_size=data["chunkSize"], chunk_overlap=data["chunkOverlap"])
            else:
                character_splitter = RecursiveCharacterTextSplitter(chunk_size=data["chunkSize"], chunk_overlap=data["chunkOverlap"])
        except Exception as e:
            logging.exception("Failed to load text splitter")
            result_content.append({
                "recordId": record_id,
                "data": {},
                "errors": [
                    { "message": f"Failed to split text: {e}" }
                ]
            })
            continue
        # Split markdown content into chunks based on headers
        md_chunks = text_splitter.split_text(data["content"])
        # Further split the markdown chunks down to the desired chunk size
        char_chunks = character_splitter.split_documents(md_chunks)
        # Return chunk content and headers
        chunks = [{ "content": document.page_content, "headers": [document.metadata[header] for header in sorted(document.metadata.keys())] } for document in char_chunks]
        result_content.append({ "recordId": record_id, "data": { "chunks": chunks } })

    response = { "values": result_content }
    return func.HttpResponse(body=json.dumps(response), mimetype="application/json", status_code=200)
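# For reference, an illustrative request/response pair for the markdownsplit skill
# (this is the shape the parser above expects; the values are made up):
#
#   Request:  { "values": [ { "recordId": "1",
#                             "data": { "content": "# Title\nSome text",
#                                       "chunkSize": 512, "chunkOverlap": 128,
#                                       "encoderModelName": "text-embedding-3-large" } } ] }
#   Response: { "values": [ { "recordId": "1",
#                             "data": { "chunks": [ { "content": "# Title\nSome text",
#                                                     "headers": ["Title"] } ] } } ] }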

@app.function_name(name="read")
@app.route(route="read")
async def ReadDocument(req: func.HttpRequest) -> func.HttpResponse:
input = {}
result_content = []
try:
req_body = req.get_json()
# Read input values
# Either file_data or metadata_storage_path and metadata_storage_sas_token
if "values" in req_body:
for value in req_body["values"]:
record_id = value["recordId"]

# Check if using markdown or text mode
if "mode" in value["data"]:
mode = value["data"]["mode"]
if mode not in ["markdown", "text"]:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "'mode' must be either 'text' or 'markdown'"
}
]
}
)
continue
else:
mode = "text"

if "file_data" in value["data"]:
input[record_id] = { "type": "file", "file_data": value["data"]["file_data"], "mode": mode }
elif "metadata_storage_path" in value["data"] and "metadata_storage_sas_token" in value["data"]:
input[record_id] = { "type": "sas", "sas_uri": f"{value['data']['metadata_storage_path']}{value['data']['metadata_storage_sas_token']}", "mode": mode }
else:
result_content.append(
{
"recordId": record_id,
"data": {},
"errors": [
{
"message": "Expected either 'file_data' or 'metadata_storage_path' and 'metadata_storage_sas_token'"
}
]
}
)
except Exception as e:
logging.exception("Could not get input data request body")
return func.HttpResponse(
"Invalid input",
status_code=400
)

async with DocumentIntelligenceClient(endpoint=os.environ["AZURE_DOCUMENTINTELLIGENCE_ENDPOINT"], credential=DefaultAzureCredential()) as client:
for record_id, data in input.items():
if "file_data" in data:
result_content.append(await process_file(client, record_id, data["file_data"], data["mode"])) # type: ignore
else:
result_content.append(await process_sas_uri(client, record_id, data["sas_uri"], data["mode"])) # type: ignore

response = { "values": result_content }
return func.HttpResponse(body=json.dumps(response), mimetype="application/json", status_code=200)
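# For reference, an illustrative request body for the read skill. Each record supplies
# either inline file data or a blob URL plus SAS token, and an optional output mode;
# the values below are made up:
#
#   { "values": [
#       { "recordId": "1",
#         "data": { "file_data": { "$type": "file", "data": "<base64 bytes>" },
#                   "mode": "markdown" } },
#       { "recordId": "2",
#         "data": { "metadata_storage_path": "https://<account>.blob.core.windows.net/docs/a.pdf",
#                   "metadata_storage_sas_token": "?sv=...",
#                   "mode": "text" } } ] }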

async def process_file(client: DocumentIntelligenceClient, record_id: str, file: Any, mode: str):
    # Validate the enriched file_data payload before decoding it.
    if not isinstance(file, dict) or \
            file.get("$type") != "file" or \
            "data" not in file:
        return {
            "recordId": record_id,
            "data": {},
            "errors": [
                { "message": "file_data is not in correct format" }
            ]
        }

    try:
        file_bytes = base64.b64decode(file["data"])
        file_data = io.BytesIO(file_bytes)
    except Exception:
        return {
            "recordId": record_id,
            "data": {},
            "errors": [
                { "message": "Failed to decode file content" }
            ]
        }

    try:
        poller = await client.begin_analyze_document(
            "prebuilt-layout", analyze_request=file_data, content_type="application/octet-stream", output_content_format=mode, features=["ocrHighResolution"]
        )
        result = await poller.result()
        return {
            "recordId": record_id, "data": { "content": result.content }
        }
    except HttpResponseError as e:
        logging.exception("Failed to read document")
        return {
            "recordId": record_id,
            "data": {},
            "errors": [
                { "message": f"Failed to process file content: {e.message}" }
            ]
        }

async def process_sas_uri(client: DocumentIntelligenceClient, record_id: str, sas_uri: str, mode: str):
    try:
        poller = await client.begin_analyze_document(
            "prebuilt-layout", analyze_request=AnalyzeDocumentRequest(url_source=sas_uri), output_content_format=mode, features=["ocrHighResolution"]
        )
        result = await poller.result()
        return {
            "recordId": record_id, "data": { "content": result.content }
        }
    except HttpResponseError as e:
        logging.exception("Failed to read document")
        return {
            "recordId": record_id,
            "data": {},
            "errors": [
                { "message": f"Failed to process file content: {e.message}" }
            ]
        }
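For local testing, a minimal client sketch such as the following can exercise the markdownsplit endpoint. It assumes the app is running via func start on the default port 7071 and that the requests package is available; the URL and payload values are illustrative, not part of this commit.

import requests

# Illustrative payload matching the shape SplitMarkdownDocument expects.
payload = {
    "values": [
        {
            "recordId": "1",
            "data": {
                "content": "# Intro\nHello world.\n## Details\nMore text.",
                "chunkSize": 256,
                "chunkOverlap": 32
            }
        }
    ]
}

# Default local endpoint when running `func start`; adjust host/port as needed.
resp = requests.post("http://localhost:7071/api/markdownsplit", json=payload)
resp.raise_for_status()
for record in resp.json()["values"]:
    for chunk in record["data"]["chunks"]:
        print(chunk["headers"], "->", chunk["content"][:60])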
@@ -0,0 +1,15 @@
{
  "version": "2.0",
  "logging": {
    "applicationInsights": {
      "samplingSettings": {
        "isEnabled": true,
        "excludedTypes": "Request"
      }
    }
  },
  "extensionBundle": {
    "id": "Microsoft.Azure.Functions.ExtensionBundle",
    "version": "[4.*, 5.0.0)"
  }
}
@@ -0,0 +1,10 @@
# Do not include azure-functions-worker in this file
# The Python Worker is managed by the Azure Functions platform
# Manually managing azure-functions-worker may cause unexpected issues

azure-functions
azure-identity
azure-ai-documentintelligence
aiohttp
langchain-text-splitters==0.2.2
tiktoken
@@ -0,0 +1,14 @@
# yaml-language-server: $schema=https://raw.githubusercontent.com/Azure/azure-dev/main/schemas/v1.0/azure.yaml.json

name: custom-embeddings
services:
  api:
    project: ./api/functions
    language: python
    host: function
hooks:
  postdeploy:
    shell: pwsh
    run: ./scripts/setup_search_service.ps1
    interactive: true
    continueOnError: false