diff --git a/eval_prompt.py b/eval_prompt.py
index f6f13b4ac397372b316b1210150ae1c8901ece05..dc54636209b15cd8a320090842433c0b1ed9788c 100644
--- a/eval_prompt.py
+++ b/eval_prompt.py
@@ -14,19 +14,21 @@ logger = logging.getLogger(__name__)
 
 def evaluate_prompt(prompt: str, task_name: str, args: Namespace):
     logger.info(f'Evaluating prompt "{prompt}"')
-    # evaluation_model = LLMModel.get_model(name="llamachat", options=args)
     evaluation_model = LLMModel.get_model(
         name="hfllamachat",
+        # name="llamachat",
         options=args,
-        # model="PKU-Alignment/alpaca-7b-reproduced",
-        model="chavinlo/alpaca-native",
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        # use less randomness, i.e., more certain outputs
-        temperature=0.0,
+        model="PKU-Alignment/alpaca-7b-reproduced",
+        # model="chavinlo/alpaca-native",
+        model_kwargs=dict(
+            load_in_8bit=True,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            # use less randomness, i.e., more certain outputs
+            temperature=0.0,
+        ),
         # torch_dtype is not JSON serializable therefore we ignore it
-        ignore_cache_kwargs=["torch_dtype"],
+        ignore_cache_kwargs=["model_kwargs.torch_dtype"],
     )
 
     task = get_task(task_name, evaluation_model, **vars(args))
diff --git a/evoprompt/models.py b/evoprompt/models.py
index b1f086faada17208538a2de247f3720e150f5eef..cb6e1b3e6595f73b1c397c7204d2d9f1ca45ce9a 100644
--- a/evoprompt/models.py
+++ b/evoprompt/models.py
@@ -119,16 +119,27 @@ class LLMModel(ABC):
     def get_options_kwargs_hash(options: Namespace, kwargs):
         # sometimes we want to ignore certain kwargs from the hash, e.g., when they are not relevant for the model or if they are not serializable
         kwargs = kwargs.copy()
+
+        def iter_dict(d: dict, prefix: str = ""):
+            for k, v in d.items():
+                k = f"{prefix}{k}"
+                if isinstance(v, dict):
+                    yield from iter_dict(v, prefix=f"{k}.")
+                else:
+                    yield k, v
+
         ignore_cache_kwargs: list[str] | None = kwargs.pop("ignore_cache_kwargs", None)
         if ignore_cache_kwargs is not None:
             options = Namespace(
                 **{
                     k: v
-                    for k, v in vars(options).items()
+                    for k, v in iter_dict(vars(options))
                     if k not in ignore_cache_kwargs
                 }
             )
-            kwargs = {k: v for k, v in kwargs.items() if k not in ignore_cache_kwargs}
+            kwargs = {
+                k: v for k, v in iter_dict(kwargs) if k not in ignore_cache_kwargs
+            }
 
         unique_options_key = json.dumps(
             (vars(options), kwargs),
@@ -342,7 +353,7 @@ class HfLlamaChat(ChatModel, Llama):
     def __init__(
         self,
         options: Namespace,
-        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
         **kwargs,
     ) -> None:
         import torch
@@ -350,12 +361,23 @@ class HfLlamaChat(ChatModel, Llama):
 
         super().__init__(options, **kwargs)
 
+        # set some default values
+        model_kwargs = kwargs.pop("model_kwargs", {})
+        if "torch_dtype" not in model_kwargs:
+            model_kwargs["torch_dtype"] = torch.bfloat16
+        if "device_map" not in model_kwargs:
+            model_kwargs["device_map"] = "auto"
+
         # initialize model
         self.pipeline = transformers.pipeline(
             "text-generation",
-            model=model_id,
-            model_kwargs={"torch_dtype": torch.bfloat16},
-            device_map="auto",
+            model=model,
+            model_kwargs=model_kwargs,
+            **kwargs,
+        )
+        # Setting the pad token to the eos token to avoid stdout prints
+        self.pipeline.model.generation_config.pad_token_id = (
+            self.pipeline.model.generation_config.eos_token_id
         )
 
     def _create_completion(
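
Note (not part of the patch): a minimal, self-contained sketch of how the iter_dict flattening added above is expected to interact with the dotted ignore_cache_kwargs keys such as "model_kwargs.torch_dtype". The iter_dict body is taken from the hunk; the sample kwargs values are placeholders, not the real model configuration.

import json

def iter_dict(d: dict, prefix: str = ""):
    # recursively flatten nested dicts into dotted keys,
    # e.g. {"model_kwargs": {"torch_dtype": ...}} -> ("model_kwargs.torch_dtype", ...)
    for k, v in d.items():
        k = f"{prefix}{k}"
        if isinstance(v, dict):
            yield from iter_dict(v, prefix=f"{k}.")
        else:
            yield k, v

kwargs = {
    "model": "PKU-Alignment/alpaca-7b-reproduced",
    "model_kwargs": {"load_in_8bit": True, "torch_dtype": "float16", "temperature": 0.0},
}
ignore_cache_kwargs = ["model_kwargs.torch_dtype"]
# drop the ignored (non-serializable) entry before hashing, keep everything else
filtered = {k: v for k, v in iter_dict(kwargs) if k not in ignore_cache_kwargs}
print(json.dumps(filtered, sort_keys=True))
# {"model": "PKU-Alignment/alpaca-7b-reproduced", "model_kwargs.load_in_8bit": true, "model_kwargs.temperature": 0.0}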