Skip to content
Snippets Groups Projects
Commit b47cdc45 authored by Grießhaber Daniel's avatar Grießhaber Daniel :squid:
Browse files

fix error for openai model after refactoring

parent 1083a08d
No related branches found
No related tags found
2 merge requests!2remove is_chat argument,!1Refactor models
......@@ -73,9 +73,6 @@ class LLMModel(ABC):
history: list[dict[str, str]] | None = None,
**kwargs: Any,
) -> tuple[str, ModelUsage]:
if chat is None:
chat = self.chat
# create prompt
prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
messages = [self._get_user_message(prompt)]
......@@ -331,13 +328,16 @@ class OpenAI(LLMModel):
usage = ModelUsage(**response.usage.__dict__)
return response_text, usage
def build_model_input(self, **kwargs):
    """Assemble the keyword arguments for an OpenAI API call.

    The base OpenAI model needs no transformation: every keyword
    argument is forwarded verbatim as the request payload.
    """
    # Pass-through: the freshly-built kwargs mapping *is* the model input.
    return kwargs
@classmethod
def register_arguments(cls, parser: ArgumentParser):
    """Register the OpenAI-specific command-line options on *parser*.

    Adds an "OpenAI model arguments" group containing the
    ``--openai-model`` / ``-m`` option (default: ``gpt-3.5-turbo``).
    """
    openai_group = parser.add_argument_group("OpenAI model arguments")
    openai_group.add_argument(
        "--openai-model",
        "-m",
        type=str,
        default="gpt-3.5-turbo",
    )
def OpenAiChat(OpenAI):
class OpenAiChat(OpenAI):
def _create_completion(
self,
......@@ -372,4 +372,3 @@ argument_group.add_argument(
type=int,
help="Maximum number of tokens being generated from LLM. ",
)
argument_group.add_argument("--chat", "-c", action="store_true")
......@@ -61,13 +61,8 @@ if __name__ == "__main__":
if debug:
logger.info("DEBUG mode: Do a quick run")
# set up evolution model
evolution_model_name = (
(options.evolution_engine + "chat")
if options.chat
else options.evolution_engine
)
evolution_model = LLMModel.get_model(evolution_model_name, options=options)
# # set up evolution model
evolution_model = LLMModel.get_model(options.evolution_engine, options=options)
match options.evolution_engine:
case "llama":
......@@ -81,13 +76,12 @@ if __name__ == "__main__":
logger.info("Using Llama as the evaluation engine")
evaluation_model: LLMModel
match options.evolution_engine:
case "llama":
case "llama" | "llamachat":
evaluation_model = evolution_model
case "openai":
if not options.chat:
evaluation_model = Llama(options)
else:
evaluation_model = LlamaChat(options)
evaluation_model = Llama(options)
case "openaichat":
evaluation_model = LlamaChat(options)
task = get_task(options.task, evaluation_model, **options.__dict__)
logger.info(f"Running with task {task.__class__.__name__}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment