Skip to content
Snippets Groups Projects
Commit ad25c5a5 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

remove OpenAI (non-chat) model

parent 33f58a73
No related branches found
No related tags found
2 merge requests!2remove is_chat argument,!1Refactor models
......@@ -300,7 +300,7 @@ class LlamaChat(Llama):
return {"messages": messages}, messages
class OpenAI(LLMModel):
class OpenAiChat(LLMModel):
"""Queries an OpenAI model using its API."""
def __init__(
......@@ -314,41 +314,6 @@ class OpenAI(LLMModel):
super().__init__(options, **kwargs)
def _create_completion(
self,
use_cache: bool = False,
**kwargs,
):
response = self._call_model(
self.openai_client.completions.create,
model=self.model,
use_cache=use_cache,
**kwargs,
)
response_text = response.choices[0].text
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
def build_model_input(
self,
prompt: str,
system_message: str | None,
messages: list[dict[str, str]],
history: list[dict[str, str]] | None = None,
):
if system_message is not None:
prompt = system_message + prompt
return {"prompt": prompt}, messages
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments")
group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
class OpenAiChat(OpenAI):
def _create_completion(
self,
use_cache: bool = False,
......@@ -382,6 +347,11 @@ class OpenAiChat(OpenAI):
return {"messages": messages}, messages
@classmethod
def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments")
group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
argument_group = argument_parser.add_argument_group("Model arguments")
argument_group.add_argument(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment