From b47cdc459f6f64b338324121281032311f55678e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Mon, 19 Aug 2024 11:05:29 +0200 Subject: [PATCH] fix error for openai model after refactoring --- evoprompt/models.py | 9 ++++----- main.py | 18 ++++++------------ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/evoprompt/models.py b/evoprompt/models.py index 1349359..2d65147 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -73,9 +73,6 @@ class LLMModel(ABC): history: list[dict[str, str]] | None = None, **kwargs: Any, ) -> tuple[str, ModelUsage]: - if chat is None: - chat = self.chat - # create prompt prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix messages = [self._get_user_message(prompt)] @@ -331,13 +328,16 @@ class OpenAI(LLMModel): usage = ModelUsage(**response.usage.__dict__) return response_text, usage + def build_model_input(self, **kwargs): + return kwargs + @classmethod def register_arguments(cls, parser: ArgumentParser): group = parser.add_argument_group("OpenAI model arguments") group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo") -def OpenAiChat(OpenAI): +class OpenAiChat(OpenAI): def _create_completion( self, @@ -372,4 +372,3 @@ argument_group.add_argument( type=int, help="Maximum number of tokens being generated from LLM. ", ) -argument_group.add_argument("--chat", "-c", action="store_true") diff --git a/main.py b/main.py index 90bac4b..e0b8fae 100644 --- a/main.py +++ b/main.py @@ -61,13 +61,8 @@ if __name__ == "__main__": if debug: logger.info("DEBUG mode: Do a quick run") - # set up evolution model - evolution_model_name = ( - (options.evolution_engine + "chat") - if options.chat - else options.evolution_engine - ) - evolution_model = LLMModel.get_model(evolution_model_name, options=options) + # # set up evolution model + evolution_model = LLMModel.get_model(options.evolution_engine, options=options) match options.evolution_engine: case "llama": @@ -81,13 +76,12 @@ if __name__ == "__main__": logger.info("Using Llama as the evaluation engine") evaluation_model: LLMModel match options.evolution_engine: - case "llama": + case "llama" | "llamachat": evaluation_model = evolution_model case "openai": - if not options.chat: - evaluation_model = Llama(options) - else: - evaluation_model = LlamaChat(options) + evaluation_model = Llama(options) + case "openaichat": + evaluation_model = LlamaChat(options) task = get_task(options.task, evaluation_model, **options.__dict__) logger.info(f"Running with task {task.__class__.__name__}") -- GitLab