Skip to content

Commit

Permalink
Allow credential to be passed in
Browse files Browse the repository at this point in the history
  • Loading branch information
pamelafox committed Oct 3, 2024
1 parent c3bcb54 commit ea9c3d1
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/evaltools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logging.basicConfig(
level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
)
logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")
# We only set the level to INFO for our logger,
# to avoid seeing the noisy INFO level logs from the Azure SDKs
logger.setLevel(logging.INFO)
Expand Down
2 changes: 1 addition & 1 deletion src/evaltools/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .evaluate_metrics import metrics_by_name

logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")


def send_question_to_target(
Expand Down
2 changes: 1 addition & 1 deletion src/evaltools/eval/evaluate_metrics/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pandas as pd

logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")


class BaseMetric(ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/evaltools/eval/evaluate_metrics/code_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .base_metric import BaseMetric

logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")


class AnswerLengthMetric(BaseMetric):
Expand Down
2 changes: 1 addition & 1 deletion src/evaltools/eval/evaluate_metrics/prompt_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

PROMPT_TEMPLATE_DIR = Path(__file__).resolve().parent / "prompts"

logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")


class PromptBasedEvaluator:
Expand Down
2 changes: 1 addition & 1 deletion src/evaltools/gen/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from evaltools import service_setup

logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")


def generate_test_qa_data(
Expand Down
12 changes: 10 additions & 2 deletions src/evaltools/service_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from azure.identity import AzureDeveloperCliCredential, get_bearer_token_provider
from azure.search.documents import SearchClient

logger = logging.getLogger("scripts")
logger = logging.getLogger("evaltools")


def get_azd_credential(tenant_id: Union[str, None]) -> AzureDeveloperCliCredential:
if tenant_id:
logger.info("Using Azure Developer CLI Credential for tenant %s", tenant_id)
return AzureDeveloperCliCredential(tenant_id=tenant_id, process_timeout=60)
logger.info("Using Azure Developer CLI Credential for home tenant")
return AzureDeveloperCliCredential(process_timeout=60)


Expand Down Expand Up @@ -98,8 +100,14 @@ def get_search_client():
def get_openai_client(oai_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration]):
if "azure_deployment" in oai_config:
azure_token_provider = None
if not os.environ.get("AZURE_OPENAI_KEY"):
azure_credential = None
if "credential" in oai_config:
logger.info("Using Azure OpenAI Service with provided credential")
azure_credential = oai_config["credential"]
elif not os.environ.get("AZURE_OPENAI_KEY"):
logger.info("Using Azure OpenAI Service with Azure Developer CLI Credential")
azure_credential = get_azd_credential(os.environ.get("AZURE_OPENAI_TENANT_ID"))
if azure_credential is not None:
azure_token_provider = get_bearer_token_provider(
azure_credential, "https://cognitiveservices.azure.com/.default"
)
Expand Down

0 comments on commit ea9c3d1

Please sign in to comment.