diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 4348cbb2cec3fa97c4f6e3077722ae22d2ad18e0..5bb2f3cba9758e246c16397d86250508084bfced 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -1,11 +1,9 @@
 import logging
-import shelve
 from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
-from pathlib import Path
 from statistics import mean
-from typing import Iterable, Literal, Union
+from typing import Any, Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -265,6 +263,7 @@ class Task(metaclass=ABCMeta):
         evaluation_strategy: EvaluationStrategyKey,
         validation_split: str | None = None,
         use_evolution_demo: bool = False,
+        n_evaluation_demo: int | None = None,
         test_split: str | None = None,
         debug: bool = False,
         **kwargs,
@@ -272,6 +271,7 @@ class Task(metaclass=ABCMeta):
         self.model = model
         self.debug = debug
         self.use_grammar = use_grammar
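+        # number of demonstration samples used to augment evaluation prompts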
+        self.n_evaluation_demo = n_evaluation_demo
 
         self.evaluation_strategy = get_evaluation_strategy(evaluation_strategy)(self)
         logger.info(
@@ -287,6 +287,10 @@ class Task(metaclass=ABCMeta):
         self.validation_dataset = self.load_validation_set(
             validation_dataset, validation_split
         )
+
+        # draw demonstration samples for few-shot evaluation (they are removed from the validation set)
+        self.demonstration_samples = self.get_demonstration_samples()
+
         if self.debug and len(self.validation_dataset) > 10:
             self.validation_dataset = self.validation_dataset.shuffle(42).select(
                 range(10)
@@ -309,13 +313,38 @@ class Task(metaclass=ABCMeta):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
+    def get_demonstration_samples(self) -> list[DatasetDatum]:
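+        # nothing to sample if no demonstrations were requested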
+        if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
+            return []
+
+        # select which validation samples to use as demonstrations
+        sample_ids = self._get_demonstration_sample_ids(
+            self.validation_dataset, self.n_evaluation_demo
+        )
+        # retrieve the selected demonstration samples
+        demonstration_samples = self.validation_dataset.filter(
+            lambda _, idx: idx in sample_ids, with_indices=True
+        )
+        # remove the demonstration samples from the validation set so they are not evaluated on
+        self.validation_dataset = self.validation_dataset.filter(
+            lambda _, idx: idx not in sample_ids, with_indices=True
+        )
+        # return plain rows so the value matches the annotated return type
+        return list(demonstration_samples)
+
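+    # subclasses decide how the demonstration samples are selected (e.g. balanced across classes)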
+    @abstractmethod
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> list[Any]:
+        pass
+
     def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model(
             system_message=SYSTEM_MESSAGE,
             prompt=prompt,
-            prompt_appendix=self._get_prompt_text_for_datum(datum),
+            prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
+            prompt_appendix="\nResponse: ",
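+            # the "Response: " cue matches the format of the demonstration samples appended during evaluation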
             # grammar can be applied to constrain the model output
             grammar=self._get_grammar(datum) if self.use_grammar else None,
         )
@@ -344,9 +373,10 @@ class Task(metaclass=ABCMeta):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
+    # This method is needed to build the demonstration examples.
+    @abstractmethod
     def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
-        # This method is needed for the demonstration example.
-        return self._get_gold_label_for_datum(datum)
+        pass
 
     @abstractmethod
     def _aggregate_result(self, results: list) -> float:
@@ -366,6 +396,14 @@ class Task(metaclass=ABCMeta):
         evaluation_usage = ModelUsage()
         evaluation_history = []
 
+        # augment prompt with demonstration samples
+        prompt += "".join(
+            [
+                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_prompt_output_for_datum(datum)}"
+                for datum in self.demonstration_samples
+            ]
+        )
+
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index c2c2cfa3429d17ef2daf13d9d5ae3b2a1665cd0e..d1e1016be0d840f64ab59f0f6912536b3d422eac 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -1,8 +1,9 @@
 from abc import abstractmethod
-from functools import lru_cache
 import logging
-from typing import Mapping
+import re
+from typing import Any, Mapping
 
+from datasets import Dataset
 from llama_cpp import LlamaGrammar
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
@@ -16,24 +17,42 @@ class TextClassification(Task):
         gold_label = self._get_gold_label_for_datum(datum)
         class_mapping = self._get_label_mapping()
         response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
         if self.use_grammar:
             # model output is from label space
             answer_label = class_mapping[response]
         else:
-            answer_label = None
-            for label in class_mapping.keys():
-                if label in response:
-                    answer_label = class_mapping[label]
-                    break
+            matches = re.findall(
+                # match any class label that follows a "Response: " cue
+                rf"Response: ({'|'.join(re.escape(label) for label in class_mapping.keys())})",
+                response,
+                flags=re.IGNORECASE,
+            )
+            if matches:
+                # lowercase the match: matching is case-insensitive but the mapping keys are lowercase
+                answer_label = class_mapping[matches[-1].lower()]
             else:
+                # TODO as a fallback, we could check whether a class label appears anywhere in the response
                 logger.warning(f"Invalid answer: {response}")
                 return "failed", usage
 
         classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
 
-    # @lru_cache
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> list[Any]:
+        # we need to return row indices, so we first add them as a column to keep track of them across shuffling and filtering
+        dataset_with_row_indices = dataset.map(
+            lambda _, idx: {"idx": idx}, with_indices=True
+        ).shuffle(42)
+        sample_ids = []
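+        # take the first n_evaluation_demo shuffled samples per class to obtain a class-balanced set of demonstrations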
+        for label in self._get_label_mapping().values():
+            sample_ids_for_label = dataset_with_row_indices.filter(
+                lambda sample: self._get_gold_label_for_datum(sample) == label
+            )[:n_evaluation_demo]["idx"]
+            sample_ids += sample_ids_for_label
+        return sample_ids
+
+    # NOTE cannot be cached since grammar is not picklable
     def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
         return LlamaGrammar.from_string(
             "root ::= ({})".format(
@@ -43,6 +62,7 @@ class TextClassification(Task):
         )
 
     def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
         return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
 
     @abstractmethod
@@ -57,6 +77,10 @@ class TextClassification(Task):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
+    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
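+        # map the gold label id back to its textual label so it can be used as the demonstration response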
+        id_to_label = {v: k for k, v in self._get_label_mapping().items()}
+        return id_to_label[self._get_gold_label_for_datum(datum)]
+
     def _aggregate_result(self, results: list[str]) -> float:
         num_correct_results = sum(1 for result in results if result == "correct")
         accuracy = num_correct_results / len(results)