Skip to content

Commit

Permalink
ENH: Allow default headers to be passed to OpenAI API (chroma-core#1397)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Allows users to pass custom headers to OpenAI API, enabling
intermediary proxies with different authentication methods.
 - New functionality
- New optional `default_headers` input at the `OpenAIEmbeddingFunction`
class.

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*

Since this is a relatively specific feature, I believe it won't require
an usage example in the docs.

Co-authored-by: Gustavo Antoniassi <[email protected]>
  • Loading branch information
GusAntoniassi and Gustavo Antoniassi authored Nov 16, 2023
1 parent 7840b51 commit 9e4e838
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import tarfile
import requests
from typing import Any, Dict, List, Union, cast
from typing import Any, Dict, List, Mapping, Union, cast
import numpy as np
import numpy.typing as npt
import importlib
Expand Down Expand Up @@ -86,6 +86,7 @@ def __init__(
api_type: Optional[str] = None,
api_version: Optional[str] = None,
deployment_id: Optional[str] = None,
default_headers: Optional[Mapping[str, str]] = None,
):
"""
Initialize the OpenAIEmbeddingFunction.
Expand All @@ -105,6 +106,7 @@ def __init__(
it will use the api version for the OpenAI API. This can be used to
point to a different deployment, such as an Azure deployment.
deployment_id (str, optional): Deployment ID for Azure OpenAI.
default_headers (Mapping, optional): A mapping of default headers to be sent with each API request.
"""
try:
Expand Down Expand Up @@ -141,12 +143,14 @@ def __init__(
self._client = openai.AzureOpenAI(
api_key=api_key,
api_version=api_version,
azure_endpoint=api_base
azure_endpoint=api_base,
default_headers=default_headers
).embeddings
else:
self._client = openai.OpenAI(
api_key=api_key,
base_url=api_base
base_url=api_base,
default_headers=default_headers
).embeddings
else:
self._client = openai.Embedding
Expand Down

0 comments on commit 9e4e838

Please sign in to comment.