finetunes: add mapping for status events to be displayed to user (cohere-ai#345)

* add mapping
innainu authored Dec 13, 2023
1 parent 709ac28 commit fbb426e
Showing 5 changed files with 114 additions and 72 deletions.
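
In short: the SDK now keeps two vocabularies for custom-model (fine-tune) statuses and translates between them at the client boundary, so users only ever see the public names. A minimal sketch of the behaviour the diffs below introduce; the dict names and values are taken from cohere/responses/custom_model.py in this commit, while the assert-style usage is illustrative only:

from cohere.responses.custom_model import (
    CUSTOM_MODEL_INTERNAL_STATUS_MAPPING,
    REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING,
)

# Outgoing: user-facing statuses are translated to the API's internal names
# before being sent in a list_custom_models query.
assert CUSTOM_MODEL_INTERNAL_STATUS_MAPPING["TRAINING"] == "FINETUNING"
assert CUSTOM_MODEL_INTERNAL_STATUS_MAPPING["DEPLOYING"] == "DEPLOYING_API"

# Incoming: internal statuses returned by the API are translated back to the
# user-facing names before they appear on a CustomModel.
assert REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING["FINETUNING"] == "TRAINING"
assert REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING["DEPLOYING_API"] == "DEPLOYING"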
cohere/client.py: 8 changes (6 additions, 2 deletions)
@@ -35,6 +35,7 @@
from cohere.responses.cluster import ClusterJobResult
from cohere.responses.connector import Connector
from cohere.responses.custom_model import (
CUSTOM_MODEL_INTERNAL_STATUS_MAPPING,
CUSTOM_MODEL_PRODUCT_MAPPING,
CUSTOM_MODEL_STATUS,
CUSTOM_MODEL_TYPE,
@@ -1348,10 +1349,13 @@ def list_custom_models(
before = before.replace(tzinfo=before.tzinfo or timezone.utc)
if after:
after = after.replace(tzinfo=after.tzinfo or timezone.utc)

internal_statuses = []
if statuses:
for status in statuses:
internal_statuses.append(CUSTOM_MODEL_INTERNAL_STATUS_MAPPING[status])
json = {
"query": {
"statuses": statuses,
"statuses": internal_statuses,
"before": before.isoformat(timespec="seconds") if before else None,
"after": after.isoformat(timespec="seconds") if after else None,
"orderBy": order_by,
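After this change, a call like the one below keeps the public status names in the caller's hands while the request body carries the internal ones. This is a hypothetical invocation (the client object and API key are placeholders), not part of the diff:

import cohere

co = cohere.Client("<api-key>")  # placeholder key
# The caller still passes user-facing statuses...
models = co.list_custom_models(statuses=["TRAINING", "READY"])
# ...but the JSON query now carries {"statuses": ["FINETUNING", "READY"]},
# built via CUSTOM_MODEL_INTERNAL_STATUS_MAPPING in the loop shown above.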
cohere/client_async.py: 10 changes (7 additions, 3 deletions)
@@ -47,6 +47,7 @@
from cohere.responses.classify import Example as ClassifyExample
from cohere.responses.cluster import AsyncClusterJobResult
from cohere.responses.custom_model import (
CUSTOM_MODEL_INTERNAL_STATUS_MAPPING,
CUSTOM_MODEL_PRODUCT_MAPPING,
CUSTOM_MODEL_STATUS,
CUSTOM_MODEL_TYPE,
@@ -1024,7 +1025,7 @@ async def list_custom_models(
"""List custom models of your organization.
Args:
-            statuses (CUSTOM_MODEL_STATUS, optional): search for fintunes which are in one of these states
+            statuses (CUSTOM_MODEL_STATUS, optional): search for finetunes which are in one of these states
before (datetime, optional): search for custom models that were created before this timestamp
after (datetime, optional): search for custom models that were created after this timestamp
order_by (Literal["asc", "desc"], optional): sort custom models by created at, either asc or desc
@@ -1035,10 +1036,13 @@
before = before.replace(tzinfo=before.tzinfo or timezone.utc)
if after:
after = after.replace(tzinfo=after.tzinfo or timezone.utc)

internal_statuses = []
if statuses:
for status in statuses:
internal_statuses.append(CUSTOM_MODEL_INTERNAL_STATUS_MAPPING[status])
json = {
"query": {
"statuses": statuses,
"statuses": internal_statuses,
"before": before.isoformat(timespec="seconds") if before else None,
"after": after.isoformat(timespec="seconds") if after else None,
"orderBy": order_by,
cohere/responses/custom_model.py: 31 changes (26 additions, 5 deletions)
@@ -14,19 +14,40 @@
CUSTOM_MODEL_STATUS = Literal[
"UNKNOWN",
"CREATED",
"DATA_PROCESSING",
"TRAINING",
"DEPLOYING",
"READY",
"FAILED",
"DELETED",
"TEMPORARILY_OFFLINE",
"PAUSED",
"QUEUED",
]
INTERNAL_CUSTOM_MODEL_STATUS = Literal[
"UNKNOWN",
"CREATED",
"FINETUNING",
"EXPORTING_MODEL",
"DEPLOYING_API",
"READY",
"FAILED",
"DELETED",
"DELETE_FAILED",
"CANCELLED",
"TEMPORARILY_OFFLINE",
"PAUSED",
"QUEUED",
]
CUSTOM_MODEL_INTERNAL_STATUS_MAPPING: Dict[CUSTOM_MODEL_STATUS, INTERNAL_CUSTOM_MODEL_STATUS] = {
"UNKNOWN": "UNKNOWN",
"CREATED": "CREATED",
"TRAINING": "FINETUNING",
"DEPLOYING": "DEPLOYING_API",
"READY": "READY",
"FAILED": "FAILED",
"DELETED": "DELETED",
"TEMPORARILY_OFFLINE": "TEMPORARILY_OFFLINE",
"PAUSED": "PAUSED",
"QUEUED": "QUEUED",
}
REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING = {v: k for k, v in CUSTOM_MODEL_INTERNAL_STATUS_MAPPING.items()}

INTERNAL_CUSTOM_MODEL_TYPE = Literal["GENERATIVE", "CLASSIFICATION", "RERANK", "CHAT"]
CUSTOM_MODEL_TYPE = Literal["GENERATIVE", "CLASSIFY", "RERANK", "CHAT"]
@@ -131,7 +152,7 @@ def from_dict(cls, data: Dict[str, Any], wait_fn) -> "BaseCustomModel":
wait_fn=wait_fn,
id=data["id"],
name=data["name"],
-            status=data["status"],
+            status=REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING[data["status"]],
model_type=REVERSE_CUSTOM_MODEL_PRODUCT_MAPPING[data["settings"]["finetuneType"]],
created_at=_parse_date(data["created_at"]),
completed_at=_parse_date(data["completed_at"]) if "completed_at" in data else None,
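Because the reverse map is derived mechanically from the forward map (see REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING above) and the forward map's values are all distinct, the translation round-trips for every user-facing status. A small illustrative check, not part of the commit:

from cohere.responses.custom_model import (
    CUSTOM_MODEL_INTERNAL_STATUS_MAPPING,
    REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING,
)

# Every user-facing status survives the user -> internal -> user round trip.
for public_status, internal_status in CUSTOM_MODEL_INTERNAL_STATUS_MAPPING.items():
    assert REVERSE_CUSTOM_MODEL_INTERNAL_STATUS_MAPPING[internal_status] == public_status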
tests/sync/test_finetune.py: 19 changes (10 additions, 9 deletions)
@@ -1,27 +1,28 @@
import os
import unittest

import pytest
from utils import get_api_key

import cohere

API_KEY = get_api_key()
client = cohere.Client(API_KEY)

IN_CI = os.getenv("CI", "").lower() in ["true", "1"]


class TestFinetuneClient(unittest.TestCase):
def test_list(self):
-        self.assertTrue(len(client.list_custom_models()) > 0)
+        models = client.list_custom_models()
+        # there should always be a model, but make sure tests don't randomly break
+        if models:
+            self.assertTrue(len(client.list_custom_models()) > 0)

def test_get(self):
-        first = client.list_custom_models()[0]
-        by_id = client.get_custom_model(first.id)
-        self.assertEqual(first.id, by_id.id)
+        models = client.list_custom_models()
+        # there should always be a model, but make sure tests don't randomly break
+        if models:
+            first = models[0]
+            by_id = client.get_custom_model(first.id)
+            self.assertEqual(first.id, by_id.id)

@pytest.mark.skipif(IN_CI, reason="flaky in CI for some reason")
def test_metrics(self):
models = client.list_custom_models(statuses=["PAUSED", "READY"])
# there should always be a model, but make sure tests don't randomly break
tests/test_custom_model.py: 118 changes (65 additions, 53 deletions)
@@ -1,3 +1,4 @@
import copy
from datetime import datetime, timezone

from cohere.responses.custom_model import (
@@ -8,6 +9,58 @@
_parse_date_with_variable_seconds,
)

sample_finetune_dict = {
"id": "dd942318-dac4-44c2-866e-ce95396e3b00",
"name": "test-response-2",
"creator_id": "91d102b7-b2b9-464a-aa5e-85569a49aa6d",
"organization_id": "7e17242a-8489-4650-9631-f9bcb49319bd",
"organization_name": "",
"status": "QUEUED",
"created_at": "2023-11-17T17:24:36.769824Z",
"updated_at": "2023-11-17T17:24:36.769824Z",
"settings": {
"datasetID": "",
"trainFiles": [
{
"path": "gs://cohere-dev/blobheart-uploads/staging/91d102b7-b2b9-464a-aa5e-85569a49aa6d/wnvz7i/GENERATIVE/train.csv",
"separator": "",
"switchColumns": False,
"hasHeader": False,
"delimiter": ",",
}
],
"evalFiles": [],
"baseModel": "medium",
"finetuneType": "GENERATIVE",
"faxOverride": None,
"finetuneStrategy": "TFEW",
"hyperparameters": {
"earlyStoppingPatience": 6,
"earlyStoppingThreshold": 0.01,
"trainBatchSize": 16,
"trainSteps": 2,
"trainEpochs": 1,
"learningRate": 0.01,
},
"baseVersion": "14.2.0",
"multiLabel": False,
},
"model": {
"name": "mike-test-response-2",
"route": "dd942318-dac4-44c2-866e-ce95396e3b00-ft",
"endpoints": ["generate"],
"isFinetune": True,
"isProtected": False,
"languages": None,
},
"data_metrics": {
"train_files": [{"name": "train.csv", "totalExamples": 32, "size": 140}],
"total_examples": 32,
"trainable_token_count": 192,
},
"billing": {"numTrainingTokens": 192, "epochs": 1, "unitPrice": 0.000001, "totalCost": 0.000192},
}


def test_custom_model_from_dict_with_all_fields_set():
as_dict = {
@@ -70,59 +123,7 @@ def test_finetune_billing():


def test_finetune_response_with_billing():
-    response_dict = {
-        "finetune": {
-            "id": "dd942318-dac4-44c2-866e-ce95396e3b00",
-            "name": "test-response-2",
-            "creator_id": "91d102b7-b2b9-464a-aa5e-85569a49aa6d",
-            "organization_id": "7e17242a-8489-4650-9631-f9bcb49319bd",
-            "organization_name": "",
-            "status": "QUEUED",
-            "created_at": "2023-11-17T17:24:36.769824Z",
-            "updated_at": "2023-11-17T17:24:36.769824Z",
-            "settings": {
-                "datasetID": "",
-                "trainFiles": [
-                    {
-                        "path": "gs://cohere-dev/blobheart-uploads/staging/91d102b7-b2b9-464a-aa5e-85569a49aa6d/wnvz7i/GENERATIVE/train.csv",
-                        "separator": "",
-                        "switchColumns": False,
-                        "hasHeader": False,
-                        "delimiter": ",",
-                    }
-                ],
-                "evalFiles": [],
-                "baseModel": "medium",
-                "finetuneType": "GENERATIVE",
-                "faxOverride": None,
-                "finetuneStrategy": "TFEW",
-                "hyperparameters": {
-                    "earlyStoppingPatience": 6,
-                    "earlyStoppingThreshold": 0.01,
-                    "trainBatchSize": 16,
-                    "trainSteps": 2,
-                    "trainEpochs": 1,
-                    "learningRate": 0.01,
-                },
-                "baseVersion": "14.2.0",
-                "multiLabel": False,
-            },
-            "model": {
-                "name": "mike-test-response-2",
-                "route": "dd942318-dac4-44c2-866e-ce95396e3b00-ft",
-                "endpoints": ["generate"],
-                "isFinetune": True,
-                "isProtected": False,
-                "languages": None,
-            },
-            "data_metrics": {
-                "train_files": [{"name": "train.csv", "totalExamples": 32, "size": 140}],
-                "total_examples": 32,
-                "trainable_token_count": 192,
-            },
-            "billing": {"numTrainingTokens": 192, "epochs": 1, "unitPrice": 0.000001, "totalCost": 0.000192},
-        }
-    }
+    response_dict = {"finetune": sample_finetune_dict}
actual = CustomModel.from_dict(response_dict["finetune"], None)
expect = CustomModel(
id="dd942318-dac4-44c2-866e-ce95396e3b00",
@@ -146,3 +147,14 @@ def test_finetune_response_with_billing():
wait_fn=None,
)
assert actual.__dict__ == expect.__dict__


def test_statuses_mapping():
    statuses_finetune_dict = copy.deepcopy(sample_finetune_dict)
    statuses_finetune_dict["status"] = "FINETUNING"
    response_dict_training = {"finetune": statuses_finetune_dict}
    actual_training = CustomModel.from_dict(response_dict_training["finetune"], None)
    assert actual_training.__dict__.get("status") == "TRAINING"
    statuses_finetune_dict["status"] = "DEPLOYING_API"
    response_dict_deploying = {"finetune": statuses_finetune_dict}
    actual_deploying = CustomModel.from_dict(response_dict_deploying["finetune"], None)
    assert actual_deploying.__dict__.get("status") == "DEPLOYING"
