diff --git a/evoprompt/models.py b/evoprompt/models.py
index 563979c34098093e1f519f8814ca755868ce4dcd..e6772a544dc6d0f999a62f5c74c19a43d33cda2b 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -300,7 +300,7 @@ class LlamaChat(Llama):
         return {"messages": messages}, messages
 
 
-class OpenAI(LLMModel):
+class OpenAiChat(LLMModel):
     """Queries an OpenAI model using its API."""
 
     def __init__(
@@ -314,41 +314,6 @@ class OpenAI(LLMModel):
         super().__init__(options, **kwargs)
-    def _create_completion(
-        self,
-        use_cache: bool = False,
-        **kwargs,
-    ):
-        response = self._call_model(
-            self.openai_client.completions.create,
-            model=self.model,
-            use_cache=use_cache,
-            **kwargs,
-        )
-        response_text = response.choices[0].text
-
-        usage = ModelUsage(**response.usage.__dict__)
-        return response_text, usage
-
-    def build_model_input(
-        self,
-        prompt: str,
-        system_message: str | None,
-        messages: list[dict[str, str]],
-        history: list[dict[str, str]] | None = None,
-    ):
-        if system_message is not None:
-            prompt = system_message + prompt
-        return {"prompt": prompt}, messages
-
-    @classmethod
-    def register_arguments(cls, parser: ArgumentParser):
-        group = parser.add_argument_group("OpenAI model arguments")
-        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
-
-
-class OpenAiChat(OpenAI):
-
     def _create_completion(
         self,
         use_cache: bool = False,
         **kwargs,
     ):
@@ -382,6 +347,11 @@ class OpenAiChat(OpenAI):
 
         return {"messages": messages}, messages
 
+    @classmethod
+    def register_arguments(cls, parser: ArgumentParser):
+        group = parser.add_argument_group("OpenAI model arguments")
+        group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
+
 
 argument_group = argument_parser.add_argument_group("Model arguments")
 argument_group.add_argument(
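
Reviewer note: this patch folds the legacy completions-based `OpenAI` class into `OpenAiChat`, and the `--openai-model` CLI flag the removed class used to register now lives on the surviving class. A minimal sketch of the retained hook; the throwaway `ArgumentParser` here is illustrative only (the module actually wires this through its own `argument_parser`, visible in the trailing context above):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Registers --openai-model / -m in the "OpenAI model arguments" group:
    OpenAiChat.register_arguments(parser)
    args = parser.parse_args([])
    assert args.openai_model == "gpt-3.5-turbo"  # default carried over from the removed class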