From a5d48a9fc0f041a25ecc2de20ceb587783f6046f Mon Sep 17 00:00:00 2001
From: Maximilian Kimmich <maximilian.kimmich@ims.uni-stuttgart.de>
Date: Thu, 17 Oct 2024 11:57:45 +0200
Subject: [PATCH] Allow directly controlling the temperature

---
 evoprompt/models.py | 39 ++++++++++++++++++++++++++++++---------
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/evoprompt/models.py b/evoprompt/models.py
index 4b0b4ab..f80a582 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -101,6 +101,7 @@ class LLMModel(ABC):
         grammar: llama_cpp.LlamaGrammar | None = None,
         stop: str | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
         **kwargs,
     ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]:
         raise NotImplementedError("Non-chat models are currently not supported")
@@ -168,6 +169,7 @@ class LLMModel(ABC):
         stop: str | None = None,
         max_tokens: int | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
     ): ...

     def _get_user_message(self, content: Any) -> ChatMessage:
@@ -318,6 +320,7 @@ class Llama(LLMModel):
         stop: str | None = None,
         max_tokens: int | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
@@ -327,8 +330,11 @@ class Llama(LLMModel):
             "max_tokens": max_tokens,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
             model_call_kwargs["seed"] = random.randint(0, 2**32 - 1)
         else:
             model_call_kwargs["temperature"] = 0.0
@@ -400,6 +406,7 @@ class ChatModel:
         stop: str | None = None,
         history: ChatMessages | None = None,
         use_randomness: bool = False,
+        temperature: float | None = None,
         **kwargs,
     ) -> tuple[str, ChatMessages | None, ChatMessages, ModelUsage]:
         if messages is None:
@@ -425,6 +432,7 @@ class ChatModel:
             use_cache=use_cache,
             max_tokens=self.max_tokens,
             use_randomness=use_randomness,
+            temperature=temperature,
         )

         return (
@@ -447,6 +455,7 @@ class LlamaChat(ChatModel, Llama):
         stop: str | None,
         max_tokens: int | None,
         use_randomness: bool,
+        temperature: float | None = None,
     ):
         # input(
         #     f"The input for a Llama3.x model will look like this:\n{format_llama3(messages).prompt}"
         # )
@@ -459,8 +468,11 @@ class LlamaChat(ChatModel, Llama):
             "max_tokens": max_tokens,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
             model_call_kwargs["seed"] = random.randint(0, 2**32 - 1)
         else:
             model_call_kwargs["temperature"] = 0.0
@@ -533,17 +545,22 @@ class HfChat(ChatModel, LLMModel):
         stop: str | None,
         max_tokens: int | None,
         use_randomness: bool,
+        temperature: float | None = None,
         **kwargs,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
             "text_inputs": messages,
             "stop": stop,
-            "max_length": max_tokens if max_tokens is not None else 16384,
+            # "max_length": max_tokens if max_tokens is not None else 16384,
+            "max_new_tokens": 1000,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
             model_call_kwargs["do_sample"] = True
         else:
             model_call_kwargs["do_sample"] = False
@@ -633,6 +650,7 @@ class OpenAIChat(ChatModel, LLMModel):
         stop: str | None,
         max_tokens: int | None,
         use_randomness: bool,
+        temperature: float | None = None,
     ):
         # setup kwargs for model call
         model_call_kwargs = {
@@ -642,8 +660,11 @@ class OpenAIChat(ChatModel, LLMModel):
             "max_completion_tokens": max_tokens if max_tokens is not None else 1024,
         }
         if use_randomness:
-            # same temperature as in evoprompt paper reference implementation
-            model_call_kwargs["temperature"] = 0.5
+            if temperature is None:
+                # same temperature as in evoprompt paper reference implementation
+                model_call_kwargs["temperature"] = 0.5
+            else:
+                model_call_kwargs["temperature"] = temperature
         else:
             model_call_kwargs["temperature"] = 0.0

--
GitLab
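
Review note (not part of the patch): all four backends (Llama, LlamaChat, HfChat, OpenAIChat) now repeat the same temperature-resolution rule. The sketch below restates that rule as a standalone function so the behavior is easy to check in isolation; the function name is illustrative and does not exist in evoprompt. It mirrors the llama.cpp backends; HfChat toggles "do_sample" instead of pinning temperature to 0.0, and OpenAIChat sets no seed.

    import random

    def resolve_sampling_kwargs(use_randomness: bool, temperature: float | None) -> dict:
        # Illustrative restatement of the logic this patch adds to each backend.
        kwargs: dict = {}
        if use_randomness:
            # A caller-supplied temperature wins; otherwise fall back to 0.5,
            # the value used in the evoprompt paper reference implementation.
            kwargs["temperature"] = 0.5 if temperature is None else temperature
            # Fresh seed per call so repeated completions can differ.
            kwargs["seed"] = random.randint(0, 2**32 - 1)
        else:
            # Greedy decoding; a caller-supplied temperature is ignored here.
            kwargs["temperature"] = 0.0
        return kwargs

    assert resolve_sampling_kwargs(True, 0.9)["temperature"] == 0.9
    assert resolve_sampling_kwargs(True, None)["temperature"] == 0.5
    assert resolve_sampling_kwargs(False, 0.9)["temperature"] == 0.0

The last assert documents a subtlety of the patch worth flagging in review: passing temperature together with use_randomness=False has no effect.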