Skip to content
Snippets Groups Projects
Commit b47cdc45 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

fix error for openai model after refactoring

parent 1083a08d
No related branches found
No related tags found
2 merge requests!2remove is_chat argument,!1Refactor models
...@@ -73,9 +73,6 @@ class LLMModel(ABC): ...@@ -73,9 +73,6 @@ class LLMModel(ABC):
history: list[dict[str, str]] | None = None, history: list[dict[str, str]] | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[str, ModelUsage]: ) -> tuple[str, ModelUsage]:
if chat is None:
chat = self.chat
# create prompt # create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)] messages = [self._get_user_message(prompt)]
...@@ -331,13 +328,16 @@ class OpenAI(LLMModel): ...@@ -331,13 +328,16 @@ class OpenAI(LLMModel):
usage = ModelUsage(**response.usage.__dict__) usage = ModelUsage(**response.usage.__dict__)
return response_text, usage return response_text, usage
def build_model_input(self, **kwargs):
return kwargs
@classmethod @classmethod
def register_arguments(cls, parser: ArgumentParser): def register_arguments(cls, parser: ArgumentParser):
group = parser.add_argument_group("OpenAI model arguments") group = parser.add_argument_group("OpenAI model arguments")
group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo") group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
def OpenAiChat(OpenAI): class OpenAiChat(OpenAI):
def _create_completion( def _create_completion(
self, self,
...@@ -372,4 +372,3 @@ argument_group.add_argument( ...@@ -372,4 +372,3 @@ argument_group.add_argument(
type=int, type=int,
help="Maximum number of tokens being generated from LLM. ", help="Maximum number of tokens being generated from LLM. ",
) )
argument_group.add_argument("--chat", "-c", action="store_true")
...@@ -61,13 +61,8 @@ if __name__ == "__main__": ...@@ -61,13 +61,8 @@ if __name__ == "__main__":
if debug: if debug:
logger.info("DEBUG mode: Do a quick run") logger.info("DEBUG mode: Do a quick run")
# set up evolution model # # set up evolution model
evolution_model_name = ( evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
(options.evolution_engine + "chat")
if options.chat
else options.evolution_engine
)
evolution_model = LLMModel.get_model(evolution_model_name, options=options)
match options.evolution_engine: match options.evolution_engine:
case "llama": case "llama":
...@@ -81,13 +76,12 @@ if __name__ == "__main__": ...@@ -81,13 +76,12 @@ if __name__ == "__main__":
logger.info("Using Llama as the evaluation engine") logger.info("Using Llama as the evaluation engine")
evaluation_model: LLMModel evaluation_model: LLMModel
match options.evolution_engine: match options.evolution_engine:
case "llama": case "llama" | "llamachat":
evaluation_model = evolution_model evaluation_model = evolution_model
case "openai": case "openai":
if not options.chat: evaluation_model = Llama(options)
evaluation_model = Llama(options) case "openaichat":
else: evaluation_model = LlamaChat(options)
evaluation_model = LlamaChat(options)
task = get_task(options.task, evaluation_model, **options.__dict__) task = get_task(options.task, evaluation_model, **options.__dict__)
logger.info(f"Running with task {task.__class__.__name__}") logger.info(f"Running with task {task.__class__.__name__}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment