diff --git a/evoprompt/task/__init__.py b/evoprompt/task/__init__.py
index 0bfbdf273c1cb6a0a5fe9232efc949b637a4f73e..c83a698b57a52a9f05c5ba7d2ba0dfd8dd43233a 100644
--- a/evoprompt/task/__init__.py
+++ b/evoprompt/task/__init__.py
@@ -11,6 +11,7 @@ from evoprompt.task.text_classification import TextClassification
 from evoprompt.task.sentiment_analysis import SentimentAnalysis
 from evoprompt.task.topic_classification import AGNews, TREC
 from evoprompt.task.subjectivity_classification import Subj
+from evoprompt.task.text_generation import TextGeneration
 from evoprompt.task.summarization import Summarization, SAMSum
 from evoprompt.task.simplification import Simplification, ASSET
 
@@ -33,7 +34,7 @@ def get_task(name: str, evaluation_model: LLMModel, **options):
 argument_parser.add_argument("--debug", "-d", action="store_true", default=None)
 argument_group = argument_parser.add_argument_group("Task arguments")
 argument_group.add_argument(
-    "--task", "-t", type=str, required=True, choices=tasks.keys()
+    "--task", "-t", type=str, required=True, choices=sorted(tasks.keys())
 )
 argument_group.add_argument("--use-grammar", "-g", action="store_true")
 argument_group.add_argument(
diff --git a/evoprompt/task/question_answering.py b/evoprompt/task/question_answering.py
index f85d441957cdaf6281e3541e5bfb3bc431364d90..fbd2d60867e89ee95bb532a7ee2a8d3283de8097 100644
--- a/evoprompt/task/question_answering.py
+++ b/evoprompt/task/question_answering.py
@@ -1,7 +1,8 @@
 import logging
 import re
 from abc import abstractmethod
-from functools import lru_cache
+from functools import cache, lru_cache
+from typing import Iterable
 
 from datasets import Dataset
 from evaluate import load as load_metric
@@ -9,6 +10,7 @@ from llama_cpp import LlamaGrammar
 
 from evoprompt.opt_types import ModelUsage
 from evoprompt.task.task import DatasetDatum, Task
+from evoprompt.utils import get_rng
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +55,7 @@ class QuestionAnswering(Task):
 
         self.metric = load_metric("squad")
 
-    @lru_cache
+    @cache
     def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
         # context-sensitive grammar
         context = self._get_context_from_datum(datum)
@@ -74,6 +76,12 @@ class QuestionAnswering(Task):
             "\nContext: " + '"' + context + '"' + "\nQuestion: " + '"' + question + '"'
         )
 
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
     @abstractmethod
     def _get_id_from_datum(self, datum: DatasetDatum):
         pass
@@ -89,8 +97,23 @@ class QuestionAnswering(Task):
     def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
         _id = self._get_id_from_datum(datum)
         gold_answers = self._get_gold_label_for_datum(datum)
-        answer, usage = self.predict(prompt, datum)
-        # TODO check if answer is lower-cased in metric computation
+        response, usage = self.predict(prompt, datum)
+        response = response.lower()
+        # default to the full response; with a grammar the answer is already
+        # constrained to a span from the context
+        answer = response
+
+        if not self.use_grammar:
+            # without a grammar, extract the answer from the model's response
+            matches = re.findall(
+                # match the answer text after an optional "Response: " prefix
+                rf"(?:Response:\s?)?(.+)",
+                response.splitlines()[-1],
+                flags=re.IGNORECASE,
+            )
+            # use the extracted answer if found, otherwise keep the whole response
+            if matches:
+                answer = matches[-1]
 
         result = self.metric.compute(
             predictions=[{"prediction_text": answer, "id": _id}],
@@ -140,7 +163,9 @@ class QuestionAnswering(Task):
     @property
     def base_prompts(self):
         # TODO find good base prompts
-        return ["""In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""]
+        return [
+            """In this task, you are given a context and a question. The task is to answer the question given the context. Return only the answer without any other text. Make sure that the answer is taken directly from the context."""
+        ]
 
 
 class SQuAD(QuestionAnswering):
@@ -165,3 +190,6 @@ class SQuAD(QuestionAnswering):
 
     def _get_gold_label_for_datum(self, datum: DatasetDatum):
         return datum["answers"]
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)["text"][0]
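
As a sanity check for the new extraction fallback in `QuestionAnswering._evaluate_sample`, here is a minimal standalone sketch (hypothetical model output, not part of the patch) of how the regex treats an unconstrained completion:

```python
import re

# Hypothetical completion produced without a grammar; only the last line is
# inspected and an optional "Response: " prefix is stripped.
response = "Let me check the context.\nResponse: in the 10th and 11th centuries".lower()

matches = re.findall(
    r"(?:Response:\s?)?(.+)", response.splitlines()[-1], flags=re.IGNORECASE
)
answer = matches[-1] if matches else response
print(answer)  # -> "in the 10th and 11th centuries"
```

The squad metric applies its own answer normalization, so the lower-casing above should not affect the reported F1/EM scores.
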
diff --git a/evoprompt/task/sentiment_analysis.py b/evoprompt/task/sentiment_analysis.py
index d3697c20ac2508531afcb39a69ec9da90ca0c71a..6886ed18ce89032b4f19b56f99155abab7f222b0 100644
--- a/evoprompt/task/sentiment_analysis.py
+++ b/evoprompt/task/sentiment_analysis.py
@@ -1,16 +1,11 @@
-import json
 import logging
-from abc import abstractmethod
-from argparse import Namespace
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.task import Task, TextClassification
+from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
@@ -22,13 +17,14 @@ logger = logging.getLogger(__name__)
 
 
 class SentimentAnalysis(TextClassification):
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         return {"negative": 0, "positive": 1}
 
 
 class HfSST2(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-sst2"
+    shorthand = "sst2-hf"
     base_prompts_file = "evoprompt/initial_prompts/sst-2/prompts.json"
 
     def load_validation_set(
@@ -111,14 +107,15 @@ class SST5(BasePromptsFromJsonMixin, SentimentAnalysis):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["terrible", "bad", "okay", "good", "great"]
         return dict(zip(classes, range(len(classes))))
 
 
 class HfMovieReviews(BasePromptsFromJsonMixin, SentimentAnalysis):
-    shorthand = "hf-mr"
+    shorthand = "mr-hf"
     base_prompts_file = "evoprompt/initial_prompts/mr/prompts.json"
 
     def load_validation_set(
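
Switching `_get_label_mapping` from `@lru_cache` on an instance method to a cached `@staticmethod` means the mapping is computed once per class instead of once per task instance, and `self` no longer becomes part of the cache key. A minimal sketch of the pattern (hypothetical class name):

```python
from functools import cache


class LabelMappingSketch:
    @staticmethod
    @cache
    def _get_label_mapping() -> dict[str, int]:
        print("computing mapping")  # executed only on the first call
        return {"negative": 0, "positive": 1}


LabelMappingSketch._get_label_mapping()    # prints "computing mapping"
LabelMappingSketch()._get_label_mapping()  # cache hit: no output
```
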
diff --git a/evoprompt/task/simplification.py b/evoprompt/task/simplification.py
index 8295db713986895859f400051f268cb023a36331..4d58aebf61d4d9d167a6d623839b674b1c0a95e3 100644
--- a/evoprompt/task/simplification.py
+++ b/evoprompt/task/simplification.py
@@ -1,52 +1,27 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Simplification(Task):
+class Simplification(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
         self.metric = load_metric("evaluate-metric/sari")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(
+        return self.metric.compute(
             sources=[self._get_text_for_datum(datum)],
-            predictions=[response],
+            predictions=[prediction],
             references=[gold_label],
-        )
-        return scores["sari"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        )["sari"]
 
     @property
     def metric_name(self):
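
Since `Simplification.compute_metric` now only wraps the SARI call, a standalone sketch of the metric invocation may be useful (toy sentences; assumes the `evaluate` package can resolve the SARI metric, which the task code loads as "evaluate-metric/sari"):

```python
from evaluate import load as load_metric

# SARI scores a simplification against the source sentence and one or more
# reference simplifications; compute() returns a dict with a "sari" key.
sari = load_metric("sari")
score = sari.compute(
    sources=["About 95 species are currently accepted."],
    predictions=["About 95 species are currently known."],
    references=[[
        "About 95 species are currently known.",
        "About 95 species are now accepted.",
    ]],
)
print(score["sari"])
```
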
diff --git a/evoprompt/task/subjectivity_classification.py b/evoprompt/task/subjectivity_classification.py
index 63988c6aafde7edc293149cedfb9e882e8e2c4c4..5ded8dada78f412afa15f7362a7813b9f32fb487 100644
--- a/evoprompt/task/subjectivity_classification.py
+++ b/evoprompt/task/subjectivity_classification.py
@@ -1,6 +1,4 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache
 from typing import Mapping
 
 from datasets import load_dataset
@@ -39,7 +37,8 @@ class Subj(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["subjective", "objective"]
         return dict(zip(classes, range(len(classes))))
diff --git a/evoprompt/task/summarization.py b/evoprompt/task/summarization.py
index d199937f76f4799d77d6a8a32095b99cc674d6dd..3d45aefdca87d03d30025276cc5a391b596dee74 100644
--- a/evoprompt/task/summarization.py
+++ b/evoprompt/task/summarization.py
@@ -1,53 +1,25 @@
-import json
 import logging
-from abc import abstractmethod
-from functools import lru_cache
-from pathlib import Path
-from typing import Mapping
 
 from evaluate import load as load_metric
-from llama_cpp import LlamaGrammar
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
-from evoprompt.task import Task
+from evoprompt.task import TextGeneration
 from evoprompt.task.task import DatasetDatum
 
 logger = logging.getLogger(__name__)
 
 
-class Summarization(Task):
+class Summarization(TextGeneration):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
         self.metric = load_metric("evaluate-metric/rouge")
 
-    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+    def compute_metric(self, datum: DatasetDatum, prediction: str) -> float:
         gold_label = self._get_gold_label_for_datum(datum)
-        response, usage = self.predict(prompt=prompt, datum=datum)
-        response = response.lower()
-
-        scores = self.metric.compute(predictions=[response], references=[gold_label])
-
-        return scores["rougeL"], usage
-
-    @lru_cache
-    def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
-        return None
-
-    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
-        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
-
-    @abstractmethod
-    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    @abstractmethod
-    def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
-        pass
-
-    def _aggregate_result(self, results: list[str]) -> float:
-        return sum(results) / len(results)
+        return self.metric.compute(predictions=[prediction], references=[gold_label])[
+            "rougeL"
+        ]
 
     @property
     def metric_name(self):
diff --git a/evoprompt/task/task.py b/evoprompt/task/task.py
index 8b27328b2139fc50716f132f8b26acbabeed8139..cfc79a208161243ce89ecbf9ad97d95d11091d6a 100644
--- a/evoprompt/task/task.py
+++ b/evoprompt/task/task.py
@@ -1,11 +1,9 @@
 import logging
-import shelve
 from abc import ABCMeta, abstractmethod
 from collections import deque
 from dataclasses import KW_ONLY, dataclass
-from pathlib import Path
 from statistics import mean
-from typing import Iterable, Literal, Union
+from typing import Iterable, Literal
 
 from datasets import Dataset, load_dataset
 from llama_cpp import LlamaGrammar
@@ -265,6 +263,7 @@ class Task(metaclass=ABCMeta):
         evaluation_strategy: EvaluationStrategyKey,
         validation_split: str | None = None,
         use_evolution_demo: bool = False,
+        n_evaluation_demo: int | None = None,
         test_split: str | None = None,
         debug: bool = False,
         **kwargs,
@@ -272,6 +271,7 @@ class Task(metaclass=ABCMeta):
         self.model = model
         self.debug = debug
         self.use_grammar = use_grammar
+        self.n_evaluation_demo = n_evaluation_demo
 
         self.evaluation_strategy = get_evaluation_strategy(evaluation_strategy)(self)
         logger.info(
@@ -287,6 +287,12 @@ class Task(metaclass=ABCMeta):
         self.validation_dataset = self.load_validation_set(
             validation_dataset, validation_split
         )
+
+        # split off demonstration samples from the validation set
+        self.demonstration_samples, self.validation_dataset = (
+            self.get_demonstration_samples(self.validation_dataset)
+        )
+
         if self.debug and len(self.validation_dataset) > 10:
             self.validation_dataset = self.validation_dataset.shuffle(42).select(
                 range(10)
@@ -309,13 +315,39 @@ class Task(metaclass=ABCMeta):
     def load_test_set(self, test_dataset: str, test_split: str | None):
         return load_dataset(test_dataset, split=test_split)
 
+    def get_demonstration_samples(self, dataset: Dataset) -> tuple[Dataset, Dataset]:
+        if self.n_evaluation_demo is None or self.n_evaluation_demo <= 0:
+            return dataset.select([]), dataset
+
+        # select demonstration sample ids from the validation set
+        samples_ids = self._get_demonstration_sample_ids(
+            dataset, self.n_evaluation_demo
+        )
+        # keep the selected samples as demonstrations ...
+        demonstration_samples = dataset.filter(
+            lambda _, idx: idx in samples_ids, with_indices=True
+        )
+        # ... and remove them from the dataset used for evaluation
+        remaining_dataset = dataset.filter(
+            lambda _, idx: idx not in samples_ids, with_indices=True
+        )
+        return demonstration_samples, remaining_dataset
+
+    @abstractmethod
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        pass
+
     def predict(self, prompt: str, datum: DatasetDatum) -> tuple[str, ModelUsage]:
         # run model for inference using grammar to constrain output
         # TODO grammar also depends on prompt and vice-versa -> what are good labels?
         response, _, usage = self.model.create_completion(
             system_message=SYSTEM_MESSAGE,
+            # TODO allow subclasses to modify prompt construction
             prompt=prompt,
-            prompt_appendix=self._get_prompt_text_for_datum(datum),
+            prompt_suffix="\n" + self._get_prompt_text_for_datum(datum),
+            prompt_appendix="\nResponse: ",
             # grammar can be applied to constrain the model output
             grammar=self._get_grammar(datum) if self.use_grammar else None,
             # we use cached completions to speed up the process although we loose the non-deterministic behavior of LMs, but we're ok with a single result
@@ -346,9 +378,10 @@ class Task(metaclass=ABCMeta):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
-    def _get_prompt_output_for_datum(self, datum: DatasetDatum) -> str:
-        # This method is needed for the demonstration example.
-        return self._get_gold_label_for_datum(datum)
+    # Provides the gold output text used for demonstration examples.
+    @abstractmethod
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        pass
 
     @abstractmethod
     def _aggregate_result(self, results: list) -> float:
@@ -368,6 +401,14 @@ class Task(metaclass=ABCMeta):
         evaluation_usage = ModelUsage()
         evaluation_history = []
 
+        # augment prompt with demonstration samples
+        prompt += "".join(
+            [
+                f"\n{self._get_prompt_text_for_datum(datum)}\nResponse: {self._get_gold_label_generation_for_datum(datum)}"
+                for datum in self.demonstration_samples
+            ]
+        )
+
         for datum in dataset_iterator:
             result, usage = self._evaluate_sample(prompt, datum)
             results.append(result)
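
For reference, the prompt seen by the model now looks roughly like the sketch below: the few-shot block is appended in `evaluate()`, and the evaluated sample is added by `predict()` via `prompt_suffix`/`prompt_appendix` (standalone sketch with toy data, not part of the patch):

```python
# Toy demonstration samples standing in for self.demonstration_samples.
demonstrations = [
    {"text": "a warm, funny and touching film", "label": "positive"},
    {"text": "a tedious, joyless slog", "label": "negative"},
]

prompt = "Classify the sentiment of the input."
# few-shot block appended in evaluate(): input text plus the gold generation
prompt += "".join(
    f'\n\nInput: "{d["text"]}"\nResponse: {d["label"]}' for d in demonstrations
)
# the evaluated sample is then appended by predict()
prompt += '\n\nInput: "an utterly forgettable remake"' + "\nResponse: "
print(prompt)
```
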
diff --git a/evoprompt/task/text_classification.py b/evoprompt/task/text_classification.py
index c2c2cfa3429d17ef2daf13d9d5ae3b2a1665cd0e..ea7f3821dd62b1a1190a63a5e22ee7ed70388a82 100644
--- a/evoprompt/task/text_classification.py
+++ b/evoprompt/task/text_classification.py
@@ -1,13 +1,15 @@
+import logging
+import re
 from abc import abstractmethod
 from functools import lru_cache
-import logging
-from typing import Mapping
+from typing import Any, Mapping
 
+from datasets import Dataset
 from llama_cpp import LlamaGrammar
+
 from evoprompt.task import Task
 from evoprompt.task.task import DatasetDatum
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -21,19 +23,47 @@ class TextClassification(Task):
             # model output is from label space
             answer_label = class_mapping[response]
         else:
-            answer_label = None
-            for label in class_mapping.keys():
-                if label in response:
-                    answer_label = class_mapping[label]
-                    break
+            matches = re.findall(
+                # regex that matches class labels after "Response: "
+                rf"Response: ({'|'.join(class_mapping.keys())})",
+                response,
+                flags=re.IGNORECASE,
+            )
+            if matches:
+                answer_label = class_mapping[matches[-1].lower()]
             else:
-                logger.warning(f"Invalid answer: {response}")
-                return "failed", usage
+                # otherwise fall back to a label mentioned anywhere in the response
+                matches = re.findall(
+                    # regex that matches class labels anywhere in the response
+                    rf"({'|'.join(class_mapping.keys())})",
+                    response,
+                    flags=re.IGNORECASE,
+                )
+                if matches:
+                    answer_label = class_mapping[matches[-1].lower()]
+                else:
+                    logger.warning(f"Invalid answer: {response}")
+                    return "failed", usage
 
         classification_result = "incorrect" if answer_label != gold_label else "correct"
         return classification_result, usage
 
-    # @lru_cache
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> list[Any]:
+        # we need to return row indices, so we first attach them as a column that survives the shuffle
+        dataset_with_row_indices = dataset.map(
+            lambda _, idx: {"idx": idx}, with_indices=True
+        ).shuffle(42)
+        sample_ids = []
+        for label in self._get_label_mapping().values():
+            sample_ids_for_label = dataset_with_row_indices.filter(
+                lambda sample: self._get_gold_label_for_datum(sample) == label
+            )[:n_evaluation_demo]["idx"]
+            sample_ids += sample_ids_for_label
+        return sample_ids
+
+    # NOTE cannot be cached since grammar is not picklable
     def _get_grammar(self, datum: DatasetDatum, verbose: bool = False):
         return LlamaGrammar.from_string(
             "root ::= ({})".format(
@@ -43,6 +73,7 @@ class TextClassification(Task):
         )
 
     def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
         return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
 
     @abstractmethod
@@ -57,6 +88,14 @@ class TextClassification(Task):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         pass
 
+    @lru_cache
+    def _get_inverse_label_mapping(self):
+        return {v: k for k, v in self._get_label_mapping().items()}
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        id_to_label = self._get_inverse_label_mapping()
+        return id_to_label[self._get_gold_label_for_datum(datum)]
+
     def _aggregate_result(self, results: list[str]) -> float:
         num_correct_results = sum(1 for result in results if result == "correct")
         accuracy = num_correct_results / len(results)
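
The per-label selection in `_get_demonstration_sample_ids` can be exercised in isolation; a small sketch with a toy `datasets.Dataset` (hypothetical data and label mapping):

```python
from datasets import Dataset

# Row indices are attached as an extra column so they survive the shuffle,
# then the first n_evaluation_demo rows of each class are collected.
dataset = Dataset.from_dict(
    {"text": ["great film", "dull plot", "loved it", "awful"], "label": [1, 0, 1, 0]}
)
dataset_with_row_indices = dataset.map(
    lambda _, idx: {"idx": idx}, with_indices=True
).shuffle(42)

n_evaluation_demo = 1
sample_ids = []
for label in {"negative": 0, "positive": 1}.values():
    sample_ids += dataset_with_row_indices.filter(
        lambda sample: sample["label"] == label
    )[:n_evaluation_demo]["idx"]
print(sample_ids)  # one original row index per class
```
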
diff --git a/evoprompt/task/text_generation.py b/evoprompt/task/text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e20cd67f5f1ade917b5180cb952d1e379720048
--- /dev/null
+++ b/evoprompt/task/text_generation.py
@@ -0,0 +1,43 @@
+import logging
+from abc import abstractmethod
+from typing import Iterable
+
+from datasets import Dataset
+from llama_cpp import LlamaGrammar
+
+from evoprompt.task import Task
+from evoprompt.task.task import DatasetDatum
+from evoprompt.utils import get_rng
+
+logger = logging.getLogger(__name__)
+
+
+class TextGeneration(Task):
+    def _evaluate_sample(self, prompt: str, datum: DatasetDatum):
+        response, usage = self.predict(prompt=prompt, datum=datum)
+        response = response.lower()
+        return self.compute_metric(datum, response), usage
+
+    def _get_grammar(self, datum: DatasetDatum) -> LlamaGrammar:
+        # there is no grammar for open text generation
+        return None
+
+    def _get_demonstration_sample_ids(
+        self, dataset: Dataset, n_evaluation_demo: int
+    ) -> Iterable[int]:
+        # select demonstration samples uniformly at random
+        return get_rng().choice(len(dataset), n_evaluation_demo, replace=False)
+
+    def _get_prompt_text_for_datum(self, datum: DatasetDatum) -> str:
+        # TODO do we need quotes?
+        return "\nInput: " + '"' + self._get_text_for_datum(datum) + '"'
+
+    def _get_gold_label_generation_for_datum(self, datum: DatasetDatum) -> str:
+        return self._get_gold_label_for_datum(datum)
+
+    @abstractmethod
+    def _get_text_for_datum(self, datum: DatasetDatum) -> str:
+        pass
+
+    def _aggregate_result(self, results: list[float]) -> float:
+        return sum(results) / len(results)
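
The new base class reduces `Summarization` and `Simplification` to a single `compute_metric` override. A standalone sketch of that contract (framework pieces stubbed out, names hypothetical):

```python
class TextGenerationSketch:
    """Stands in for TextGeneration: prediction and lower-casing live here."""

    def _evaluate_sample(self, prompt: str, datum: dict) -> float:
        response = self.predict(prompt, datum).lower()
        return self.compute_metric(datum, response)

    def predict(self, prompt: str, datum: dict) -> str:
        # stub for Task.predict(); the real method also returns a ModelUsage
        return "A Short Summary."

    def compute_metric(self, datum: dict, prediction: str) -> float:
        raise NotImplementedError


class ExactMatchGeneration(TextGenerationSketch):
    def compute_metric(self, datum: dict, prediction: str) -> float:
        return float(prediction == datum["reference"].lower())


print(ExactMatchGeneration()._evaluate_sample("Summarize.", {"reference": "A short summary."}))  # 1.0
```
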
diff --git a/evoprompt/task/topic_classification.py b/evoprompt/task/topic_classification.py
index aecec3ee9b544b17b7d990d8feecc010a8aef747..4103a43242c0f9adbf83adb75bf73cac16e19d34 100644
--- a/evoprompt/task/topic_classification.py
+++ b/evoprompt/task/topic_classification.py
@@ -1,12 +1,9 @@
-import json
-from functools import lru_cache
-from pathlib import Path
+from functools import cache, lru_cache
 from typing import Mapping
 
 from datasets import load_dataset
 
 from evoprompt.helpers.prompts import BasePromptsFromJsonMixin
-from evoprompt.models import LLMModel
 from evoprompt.task import TextClassification
 from evoprompt.task.task import DatasetDatum
 
@@ -31,8 +28,9 @@ class AGNews(BasePromptsFromJsonMixin, TextClassification):
     def _get_gold_label_for_datum(self, datum: DatasetDatum) -> str:
         return datum["label"]
 
-    @lru_cache
-    def _get_label_mapping(self) -> Mapping:
+    @staticmethod
+    @cache
+    def _get_label_mapping() -> Mapping:
         classes = ["world", "sports", "business", "tech"]
         return dict(zip(classes, range(len(classes))))