From ad25c5a523a8c0fdcf291b995b57ed19e80aadb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Grie=C3=9Fhaber?= <griesshaber@hdm-stuttgart.de> Date: Mon, 19 Aug 2024 14:19:48 +0200 Subject: [PATCH] remove OpenAI (non-chat) model --- evoprompt/models.py | 42 ++++++------------------------------------ 1 file changed, 6 insertions(+), 36 deletions(-) diff --git a/evoprompt/models.py b/evoprompt/models.py index 563979c..e6772a5 100644 --- a/evoprompt/models.py +++ b/evoprompt/models.py @@ -300,7 +300,7 @@ class LlamaChat(Llama): return {"messages": messages}, messages -class OpenAI(LLMModel): +class OpenAiChat(LLMModel): """Queries an OpenAI model using its API.""" def __init__( @@ -314,41 +314,6 @@ class OpenAI(LLMModel): super().__init__(options, **kwargs) - def _create_completion( - self, - use_cache: bool = False, - **kwargs, - ): - response = self._call_model( - self.openai_client.completions.create, - model=self.model, - use_cache=use_cache, - **kwargs, - ) - response_text = response.choices[0].text - - usage = ModelUsage(**response.usage.__dict__) - return response_text, usage - - def build_model_input( - self, - prompt: str, - system_message: str | None, - messages: list[dict[str, str]], - history: list[dict[str, str]] | None = None, - ): - if system_message is not None: - prompt = system_message + prompt - return {"prompt": prompt}, messages - - @classmethod - def register_arguments(cls, parser: ArgumentParser): - group = parser.add_argument_group("OpenAI model arguments") - group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo") - - -class OpenAiChat(OpenAI): - def _create_completion( self, use_cache: bool = False, @@ -382,6 +347,11 @@ class OpenAiChat(OpenAI): return {"messages": messages}, messages + @classmethod + def register_arguments(cls, parser: ArgumentParser): + group = parser.add_argument_group("OpenAI model arguments") + group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo") + argument_group = argument_parser.add_argument_group("Model arguments") argument_group.add_argument( -- GitLab