
remove is_chat argument

Merged: Grießhaber Daniel requested to merge remove-is-chat into refactor-models
1 unresolved thread
@@ -44,11 +44,15 @@ class LLMModel(ABC):
         self.kwargs = kwargs
 
         # set up caching for model calls
+        self._call_model_cached = None
         if not options.disable_cache:
             cache = Cache(Path(".cache_dir", self.model_cache_key))
-            self._call_model_cached = cache.memoize(typed=True, ignore=[0, "func"])(
-                self._call_model_cached
-            )
+
+            @cache.memoize(typed=True, ignore=["func"])
+            def _call_function(func, *args, **kwargs):
+                return func(*args, **kwargs)
+
+            self._call_model_cached = _call_function
 
     @abstractmethod
     def build_model_input(
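The refactor swaps the memoized bound method for a locally defined, memoized dispatcher; ignore=["func"] keeps the function object itself out of diskcache's key, so the cached result depends only on the remaining call arguments. A standalone sketch of that pattern (toy cache directory and completion function, not the project's classes; func is passed by keyword here so the name-based ignore entry clearly applies):

from pathlib import Path

from diskcache import Cache

cache = Cache(Path(".cache_dir", "toy_model"))  # hypothetical cache location


@cache.memoize(typed=True, ignore=["func"])
def call_cached(func, *args, **kwargs):
    # everything except the ignored "func" argument becomes part of the key
    return func(*args, **kwargs)


def completion_fn(prompt):
    print("cache miss, computing")  # only runs when no cached entry exists
    return f"echo: {prompt}"


print(call_cached(func=completion_fn, prompt="hi"))  # computes and stores
print(call_cached(func=completion_fn, prompt="hi"))  # served from disk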
@@ -73,13 +77,12 @@ class LLMModel(ABC):
         history: list[dict[str, str]] | None = None,
         **kwargs: Any,
     ) -> tuple[str, ModelUsage]:
-        if chat is None:
-            chat = self.chat
-
         # create prompt
         prompt = prompt_prefix + prompt + prompt_suffix + prompt_appendix
         messages = [self._get_user_message(prompt)]
-        model_input = self.build_model_input(prompt, system_message, messages, history)
+        model_input, messages = self.build_model_input(
+            prompt, system_message, messages, history
+        )
 
         reponse, usage = self._create_completion(
             **model_input,
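Every build_model_input implementation now returns a pair: the provider-specific completion kwargs and the message list it assembled, which the call site above unpacks. A self-contained sketch of that contract (plain function and made-up message dicts, not the project's classes):

def build_model_input(prompt, system_message, messages, history=None):
    # prepend either the history or a system message, then hand back both the
    # completion kwargs and the final message list
    if history is not None:
        messages = history + messages
    elif system_message is not None:
        messages = [{"role": "system", "content": system_message}] + messages
    return {"messages": messages}, messages


model_input, messages = build_model_input(
    "Hi", "You are terse.", [{"role": "user", "content": "Hi"}]
)
assert model_input["messages"] is messages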
@@ -117,16 +120,11 @@ class LLMModel(ABC):
             warnings.warn("Caching is disabled when a grammar is provided.")
             use_cache = False
 
-        if use_cache:
+        # use cached function call
+        if use_cache and self._call_model_cached is not None:
             return self._call_model_cached(model_completion_fn, **kwargs)
         else:
             return model_completion_fn(**kwargs)
 
-    def _call_model_cached(self, func, *args, **kwargs):
-        # `cache_key` is added to the cache key (e.g., to distinguish between different models), but it is not used in the function
-        return func(*args, **kwargs)
-
     @property
     def model_cache_key(self):
         unique_options_key = json.dumps(
@@ -137,7 +135,9 @@ class LLMModel(ABC):
             sort_keys=True,
         )
         cache_key = (
-            str(self.model_name) + hashlib.sha1(unique_options_key.encode()).hexdigest()
+            str(self.model_name).replace("/", "_")
+            + "/"
+            + hashlib.sha1(unique_options_key.encode()).hexdigest()
         )
         return cache_key
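With this change the cache key becomes "<model name with slashes replaced>/<sha1 of the serialized options>", so Path(".cache_dir", cache_key) yields one cache subdirectory per model. A short sketch with made-up values (model name and options are hypothetical examples):

import hashlib
import json

model_name = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical HF-style name
options = {"max_tokens": 256, "temperature": 0.0}  # hypothetical options

unique_options_key = json.dumps(options, sort_keys=True)
cache_key = (
    model_name.replace("/", "_")
    + "/"
    + hashlib.sha1(unique_options_key.encode()).hexdigest()
)
print(cache_key)  # meta-llama_Llama-2-7b-chat-hf/<40-hex-char digest>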
@@ -208,7 +208,7 @@ class Llama(LLMModel):
     ):
         if system_message is not None:
             prompt = system_message + prompt
-        return {"prompt": prompt}
+        return {"prompt": prompt}, messages
 
     def _create_completion(
         self,
@@ -292,13 +292,12 @@ class LlamaChat(Llama):
         # a history is prepended to the messages, and we assume that it also includes a system message, i.e., we never add a system message in this case
         # TODO is it better to check for a system message in the history?
         if history is not None:
-            messages = history + messages
+            [messages.insert(index, entry) for index, entry in enumerate(history)]
         elif system_message:
-            messages.insert(
-                0,
-                self._get_system_message(system_message),
-            )
-        return {"messages": messages}
+            messages = [self._get_system_message(system_message)] + messages
+        return {"messages": messages}, messages
 
 
 class OpenAI(LLMModel):
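Side note on the two prepend styles in this hunk: both yield the same message order; list concatenation rebinds the local name, while insert() mutates the caller's list in place. A quick standalone check (toy messages, not project code):

history = [{"role": "system", "content": "Be brief."}]
messages = [{"role": "user", "content": "Hi"}]

concatenated = history + messages  # builds a new list, leaves messages untouched
[messages.insert(index, entry) for index, entry in enumerate(history)]  # in place
assert concatenated == messages  # same contents and order either way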
@@ -331,13 +330,27 @@ class OpenAI(LLMModel):
         usage = ModelUsage(**response.usage.__dict__)
         return response_text, usage
 
+    def build_model_input(
+        self,
+        prompt: str,
+        system_message: str | None,
+        messages: list[dict[str, str]],
+        history: list[dict[str, str]] | None = None,
+    ):
+        return {
+            "prompt": prompt,
+            "system_message": system_message,
+            "messages": messages,
+            "history": history,
+        }, messages
+
     @classmethod
     def register_arguments(cls, parser: ArgumentParser):
         group = parser.add_argument_group("OpenAI model arguments")
         group.add_argument("--openai-model", "-m", type=str, default="gpt-3.5-turbo")
 
 
-def OpenAiChat(OpenAI):
+class OpenAiChat(OpenAI):
 
     def _create_completion(
         self,
@@ -372,4 +385,3 @@ argument_group.add_argument(
     type=int,
     help="Maximum number of tokens being generated from LLM. ",
 )
-argument_group.add_argument("--chat", "-c", action="store_true")