diff --git a/README.md b/README.md
index dd3ed5c8..7cf8125f 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,10 @@ strategies.
 | Anthropic | ✅ | ✅ | ✅ |✅|
 | MidJourney | | ✅ | |✅|
 
+To support different models in a custom function (e.g. model comparison), [follow our example](https://github.com/YiVal/YiVal/blob/litellm_complete/demo/configs/model_compare.yml).
+
+To support different models in evaluators and generators, [check our config](https://github.com/YiVal/YiVal/blob/litellm_complete/demo/configs/headline_generation.yml).
+
 ## Installation
 
 ```sh
diff --git a/demo/configs/headline_generation.yml b/demo/configs/headline_generation.yml
index 95f11fa7..48d92bf0 100644
--- a/demo/configs/headline_generation.yml
+++ b/demo/configs/headline_generation.yml
@@ -5,6 +5,7 @@ dataset:
   openai_prompt_data_generator:
     chunk_size: 100000
     diversify: true
+    # model_name specifies the LLM model, e.g. a16z-infra/llama-2-13b-chat:9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3
     model_name: gpt-4
     prompt: "Please provide a concrete and realistic test case as a dictionary
       for function invocation using the ** operator.
@@ -18,7 +19,7 @@ dataset:
       name: headline_generation_for_business
       parameters:
         tech_startup_business: str
-    number_of_examples: 1
+    number_of_examples: 2
     output_path: null
   source_type: machine_generated
 
@@ -31,8 +32,9 @@ variations:
       variation_id: null
     generator_name: openai_prompt_based_variation_generator
     generator_config:
-      openai_model_name: gpt-4
-      number_of_variations: 1
+      # model_name specifies the LLM model, e.g. a16z-infra/llama-2-13b-chat:9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3
+      model_name: gpt-4
+      number_of_variations: 2
       diversify: true
       variables: null
       prompt:
@@ -65,6 +67,7 @@ evaluators:
         D It meets the criterion very well.
         E It meets the criterion exceptionally well, with little to no room for improvement.
       choices: ["A", "B", "C", "D", "E"]
+      # model_name specifies the LLM model, e.g. a16z-infra/llama-2-13b-chat:9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3
       model_name: gpt-4
       description: "evaluate the quality of the landing page headline"
       scale_description: "0-4"
diff --git a/src/yival/schemas/varation_generator_configs.py b/src/yival/schemas/varation_generator_configs.py
index 562b2734..f841dbe9 100644
--- a/src/yival/schemas/varation_generator_configs.py
+++ b/src/yival/schemas/varation_generator_configs.py
@@ -20,7 +20,7 @@ class OpenAIPromptBasedVariationGeneratorConfig(BaseVariationGeneratorConfig):
     """
     Generate variation using chatgpt. Currently only support openai models.
     """
-    openai_model_name: str = "gpt-4"
+    model_name: str = "gpt-4"
     prompt: Union[str, List[Dict[str, str]]] = ""
     diversify: bool = False
     variables: Optional[List[str]] = None
diff --git a/src/yival/variation_generators/openai_prompt_based_variation_generator.py b/src/yival/variation_generators/openai_prompt_based_variation_generator.py
index bbe85071..eb658668 100644
--- a/src/yival/variation_generators/openai_prompt_based_variation_generator.py
+++ b/src/yival/variation_generators/openai_prompt_based_variation_generator.py
@@ -3,11 +3,12 @@
 import pickle
 from typing import Any, Dict, Iterator, List
 
-import openai
 from tqdm import tqdm
 
 from ..common import utils
+from ..common.model_utils import llm_completion
 from ..schemas.experiment_config import WrapperVariation
+from ..schemas.model_configs import Request
 from ..schemas.varation_generator_configs import (
     OpenAIPromptBasedVariationGeneratorConfig,
 )
@@ -109,7 +110,7 @@ def generate_variations(self) -> Iterator[List[WrapperVariation]]:
             responses = asyncio.run(
                 utils.parallel_completions(
                     message_batches,
-                    self.config.openai_model_name,
+                    self.config.model_name,
                     self.config.max_tokens,
                     pbar=pbar
                 )
             )
@@ -131,13 +132,17 @@
                 desc="Generating Variations",
                 unit="variation"
             ) as pbar:
-                output = openai.ChatCompletion.create(
-                    model=self.config.openai_model_name,
-                    messages=messages,
-                    temperature=1,
-                    presence_penalty=1,
-                    max_tokens=self.config.max_tokens,
-                )
+                output = llm_completion(
+                    Request(
+                        model_name=self.config.model_name,
+                        prompt=messages,
+                        params={
+                            "temperature": 1,
+                            "presence_penalty": 1,
+                            "max_tokens": self.config.max_tokens,
+                        }
+                    )
+                ).output
                 if self.config.variables and not validate_output(
                     output.choices[0].message.content, self.config.variables