diff --git a/evoprompt/models.py b/evoprompt/models.py index 13493599ff6f8e47c2d9078544cd62f161ec0a86..2d651475d3caf0e95759bf5ab31c0b734d123f2d 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -73,9 +73,6 @@ class LLMModel(ABC): history: list[dict[str, str]] | None = None, **kwargs: Any, ) -> tuple[str, ModelUsage]: - if chat is None: - chat = self.chat - # create prompt prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix messages = [self._get_user_message(prompt)] @@ -331,13 +328,16 @@ class OpenAI(LLMModel): usage = ModelUsage(**response.usage.__dict__) return response_text, usage + def build_model_input(self, **kwargs): + return kwargs + @classmethod def register_arguments(cls, parser: ArgumentParser): group = parser.add_argument_group("OpenAI model arguments") group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo") -def OpenAiChat(OpenAI): +class OpenAiChat(OpenAI): def _create_completion( self, @@ -372,4 +372,3 @@ argument_group.add_argument( type=int, help="Maximum number of tokens being generated from LLM. ", ) -argument_group.add_argument("--chat", "-c", action="store_true") diff --git a/main.py b/main.py index 90bac4b39e4171a9e3de79abaf636c9eb6a2b7ae..e0b8faedc0485ebce7a65ac35969faa971580a74 100644 --- a/main.py +++ b/main.py @@ -61,13 +61,8 @@ if __name__ == "__main__": if debug: logger.info("DEBUG mode: Do a quick run") - # set up evolution model - evolution_model_name = ( - (options.evolution_engine + "chat") - if options.chat - else options.evolution_engine - ) - evolution_model = LLMModel.get_model(evolution_model_name, options=options) + # # set up evolution model + evolution_model = LLMModel.get_model(options.evolution_engine, options=options) match options.evolution_engine: case "llama": @@ -81,13 +76,12 @@ if __name__ == "__main__": logger.info("Using Llama as the evaluation engine") evaluation_model: LLMModel match options.evolution_engine: - case "llama": + case "llama" | "llamachat": evaluation_model = evolution_model case "openai": - if not options.chat: - evaluation_model = Llama(options) - else: - evaluation_model = LlamaChat(options) + evaluation_model = Llama(options) + case "openaichat": + evaluation_model = LlamaChat(options) task = get_task(options.task, evaluation_model, **options.__dict__) logger.info(f"Running with task {task.__class__.__name__}")