From a5d48a9fc0f041a25ecc2de20ceb587783f6046f Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Thu, 17 Oct 2024 11:57:45 +0200
Subject: [PATCH] Allow temperature to be controlled directly

---
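Notes (this part sits below the "---" marker and is dropped by `git am`):
the new `temperature` argument only takes effect when `use_randomness` is
True; without randomness, decoding stays deterministic at temperature 0.0.
A minimal sketch of the resolution logic now shared by all four backends
(the helper name below is illustrative only, not part of the commit):

    def resolve_temperature(use_randomness: bool, temperature: float | None) -> float:
        if not use_randomness:
            return 0.0   # deterministic decoding, temperature argument is ignored
        if temperature is None:
            return 0.5   # default from the EvoPrompt paper reference implementation
        return temperature   # caller-provided override

    resolve_temperature(True, None)   # -> 0.5 (unchanged behaviour)
    resolve_temperature(True, 0.8)    # -> 0.8 (new: explicit override)
    resolve_temperature(False, 0.8)   # -> 0.0 (randomness disabled wins)
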
 evoprompt/models.py | 39 ++++++++++++++++++++++++++++++---------
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 4b0b4ab..f80a582 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -101,6 +101,7 @@ class LLMModel(ABC):
         grammar: llama_cpp.LlamaGrammar | None = None,
         stop: str | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
         **kwargs,
     ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]:
         raise NotImplementedError("Non-chat models are currently not supported")
@@ -168,6 +169,7 @@ class LLMModel(ABC):
         stop: str | None = None,
         max_tokens: int | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
     ): ...
 
     def _get_user_message(self, content: Any) -> ChatMessage:
@@ -318,6 +320,7 @@ class Llama(LLMModel):
         stop: str | None = None,
         max_tokens: int | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
@@ -327,8 +330,11 @@ class Llama(LLMModel):
             "max_tokens": max_tokens,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
             model_call_kwargs["seed"] = random.randint(0, 2**32 - 1)
         else:
             model_call_kwargs["temperature"] = 0.0
@@ -400,6 +406,7 @@ class ChatModel:
         stop: str | None = None,
         history: ChatMessages | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
         **kwargs,
     ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]:
         if messages is None:
@@ -425,6 +432,7 @@ class ChatModel:
             use_cache=use_cache,
             max_tokens=self.max_tokens,
             use_randomness=use_randomness,
+            temperature=temperature,
         )
 
         return (
@@ -447,6 +455,7 @@ class LlamaChat(ChatModel, Llama):
         stop: str | None,
         max_tokens: int | None,
         use_randomness: bool,
+        temperature: float | None = None,
     ):
         # input(
         #     f"The input for a Llama3.x model will look like this:\n{format_llama3(messages).prompt}"
@@ -459,8 +468,11 @@ class LlamaChat(ChatModel, Llama):
             "max_tokens": max_tokens,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
             model_call_kwargs["seed"] = random.randint(0, 2**32 - 1)
         else:
             model_call_kwargs["temperature"] = 0.0
@@ -533,17 +545,22 @@ class HfChat(ChatModel, LLMModel):
         stop: str | None,
         max_tokens: int | None,
         use_randomness: bool,
+        temperature: float | None = None,
         **kwargs,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
             "text_inputs": messages,
             "stop": stop,
-            "max_length": max_tokens if max_tokens is not None else 16384,
+            # "max_length": max_tokens if max_tokens is not None else 16384,
+            "max_new_tokens": 1000,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
             model_call_kwargs["do_sample"] = True
         else:
             model_call_kwargs["do_sample"] = False
@@ -633,6 +650,7 @@ class OpenAIChat(ChatModel, LLMModel):
         stop: str | None,
         max_tokens: int | None,
         use_randomness: bool,
+        temperature: float | None = None,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
@@ -642,8 +660,11 @@ class OpenAIChat(ChatModel, LLMModel):
             "max_completion_tokens": max_tokens if max_tokens is not None else 1024,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
         else:
             model_call_kwargs["temperature"] = 0.0
 
-- 
GitLab