Commit 9249cf32 authored by Max Kimmich's avatar Max Kimmich

Update temperature for more creative generations

parent 2780fdcb
Merge request !7: Refactor tasks and models and fix format for various models
@@ -249,6 +249,7 @@ class GeneticAlgorithm(EvolutionAlgorithm):
             self.evolution_model.create_completion(
                 system_message=SYSTEM_MESSAGE,
                 prompt=filled_prompt,
+                enforce_randomness=True,
             )
         )
@@ -351,6 +352,7 @@ class DifferentialEvolution(EvolutionAlgorithm):
             self.evolution_model.create_completion(
                 system_message=SYSTEM_MESSAGE,
                 prompt=filled_prompt,
+                enforce_randomness=True,
             )
         )
@@ -468,6 +470,7 @@ class DifferentialEvolutionWithCot(DifferentialEvolution):
                 prompt=filled_prompt,
                 history=history,
                 stop="</prompt>" if idx == len(DE_COT_PROMPTS) - 1 else None,
+                enforce_randomness=True,
             )
         )
         logger.debug(
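All three evolution operators now opt into stochastic decoding at their create_completion call sites. Below is a minimal, self-contained sketch of that call pattern; the stub class and SYSTEM_MESSAGE value are illustrative stand-ins, not the repository's actual definitions.

    # Hypothetical sketch of the call-site contract (Python 3.10+);
    # StubEvolutionModel and SYSTEM_MESSAGE are stand-ins for the real classes.
    SYSTEM_MESSAGE = "You are an assistant that rewrites prompts."

    class StubEvolutionModel:
        def create_completion(
            self,
            system_message: str,
            prompt: str,
            enforce_randomness: bool = False,
            stop: str | None = None,
        ) -> str:
            # a real model would decode greedily or sample based on the flag
            mode = "sampled" if enforce_randomness else "greedy"
            return f"[{mode}] completion for {prompt!r}"

    model = StubEvolutionModel()
    # evolution steps sample so repeated runs yield varied candidate prompts
    print(model.create_completion(SYSTEM_MESSAGE, "Merge prompt A with prompt B",
                                  enforce_randomness=True))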
@@ -275,7 +275,8 @@ class Llama(LLMModel):
             "max_tokens": max_tokens,
         }
         if enforce_randomness:
-            model_call_kwargs["temperature"] = 2.0
+            # same temperature as in evoprompt paper reference implementation
+            model_call_kwargs["temperature"] = 0.5
             model_call_kwargs["seed"] = random.randint(0, 2**32 - 1)
         else:
             model_call_kwargs["temperature"] = 0.0
@@ -398,7 +399,8 @@ class LlamaChat(ChatModel, Llama):
             "max_tokens": max_tokens,
         }
         if enforce_randomness:
-            model_call_kwargs["temperature"] = 2.0
+            # same temperature as in evoprompt paper reference implementation
+            model_call_kwargs["temperature"] = 0.5
             model_call_kwargs["seed"] = random.randint(0, 2**32 - 1)
         else:
             model_call_kwargs["temperature"] = 0.0
@@ -478,7 +480,8 @@ class HfChat(ChatModel, LLMModel):
             "max_length": max_tokens if max_tokens is not None else 2048,
         }
         if enforce_randomness:
-            model_call_kwargs["temperature"] = 2.0
+            # same temperature as in evoprompt paper reference implementation
+            model_call_kwargs["temperature"] = 0.5
             model_call_kwargs["do_sample"] = True
         else:
             model_call_kwargs["do_sample"] = False
@@ -575,7 +578,8 @@ class OpenAIChat(ChatModel, LLMModel):
             "max_completion_tokens": max_tokens if max_tokens is not None else 1024,
         }
         if enforce_randomness:
-            model_call_kwargs["temperature"] = 2.0
+            # same temperature as in evoprompt paper reference implementation
+            model_call_kwargs["temperature"] = 0.5
         else:
             model_call_kwargs["temperature"] = 0.0
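Each backend exposes different sampling knobs, so the same flag fans out differently: llama.cpp-style backends take a temperature plus a per-call seed, Hugging Face generation needs do_sample toggled alongside the temperature, and the OpenAI API takes the temperature alone. A runnable sketch of that mapping follows; the backend labels are illustrative, not identifiers from the repository.

    import random

    def decoding_kwargs(backend: str, enforce_randomness: bool) -> dict:
        # hypothetical helper mirroring the four hunks above
        if enforce_randomness:
            # 0.5 follows the evoprompt paper's reference implementation
            kwargs = {"temperature": 0.5}
            if backend in ("llama", "llama-chat"):
                # fresh seed so repeated calls do not repeat outputs
                kwargs["seed"] = random.randint(0, 2**32 - 1)
            elif backend == "hf-chat":
                # HF generate() ignores temperature unless sampling is enabled
                kwargs["do_sample"] = True
            return kwargs
        # deterministic decoding for evaluation
        return {"do_sample": False} if backend == "hf-chat" else {"temperature": 0.0}

    for backend in ("llama", "hf-chat", "openai-chat"):
        print(backend, decoding_kwargs(backend, enforce_randomness=True))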
@@ -415,7 +415,7 @@ class Task(metaclass=ABCMeta):
                 # we use cached completions to speed up the process; we lose the non-deterministic behavior of LMs, but we're ok with a single result
                 use_cache=True,
                 # use less randomness, i.e., more certain outputs
-                temperature=0.0,
+                enforce_randomness=False,
             )
         if not self.use_grammar:
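The task evaluation path now expresses determinism through the same flag instead of a hard-coded temperature, which also keeps the completion cache sound: a cache keyed on the prompt only returns valid hits when decoding is deterministic. A small illustration of that invariant, using a stand-in completion function rather than the real model call:

    import functools

    def fake_completion(prompt: str, enforce_randomness: bool) -> str:
        # stand-in for a model call; greedy decoding means the same
        # prompt always maps to the same output
        assert not enforce_randomness, "cached path must stay deterministic"
        return prompt.upper()

    @functools.lru_cache(maxsize=None)
    def cached_completion(prompt: str) -> str:
        return fake_completion(prompt, enforce_randomness=False)

    # identical prompts hit the cache and agree, so caching is safe
    assert cached_completion("label this example") == cached_completion("label this example")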