diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx
index 69eb52a06..7daea73da 100644
--- a/docs/source/quicktour.mdx
+++ b/docs/source/quicktour.mdx
@@ -30,7 +30,7 @@ lighteval accelerate \
"leaderboard|truthfulqa:mc|0|0"
```
-Here, we first choose a backend (either `accelerate`, `nanotron`, or `vllm`), and then specify the model and task(s) to run.
+Here, we first choose a backend (either `accelerate`, `nanotron`, `endpoint`, or `vllm`), and then specify the model and task(s) to run.
The syntax for the model arguments is `key1=value1,key2=value2,etc`.
Valid key-value pairs correspond with the backend configuration, and are detailed [below](#Model Arguments).
@@ -104,13 +104,32 @@ GPUs.
## Backend configuration
+#### General information
+
The `model-args` argument takes a string representing a list of model
argument. The arguments allowed vary depending on the backend you use and
correspond to the fields of the model configs.
-The model config can be found [here](./package_reference/models).
+The model configurations can be found [here](./package_reference/models).
+
+All backends let you post-process reasoning model predictions to remove the
+thinking tokens from the trace used to compute the metrics: use
+`--remove-reasoning-tags` to toggle the removal, and `--reasoning-tags` to
+specify which tag pairs to strip (defaults to `<think>` and `</think>`).
+
+Here's an example with `mistralai/Magistral-Small-2507`, which outputs custom
+think tokens.
+
+```bash
+lighteval vllm \
+ "model_name=mistralai/Magistral-Small-2507,dtype=float16,data_parallel_size=4" \
+ "lighteval|aime24|0|0" \
+ --remove-reasoning-tags \
+ --reasoning-tags="[('[THINK]','[/THINK]')]"
+```
+
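+The tag pairs passed to `--reasoning-tags` are parsed as a Python literal list of
+`(open_tag, close_tag)` tuples, and every matching block is stripped from the
+generations before the metrics are computed. Below is a minimal sketch of that
+post-processing step, assuming the default `<think>` tags and the
+`remove_reasoning_tags` helper from `lighteval.utils.utils`:
+
+```python
+from lighteval.utils.utils import remove_reasoning_tags
+
+trace = "<think>2 + 2 = 4, easy.</think>The answer is 4."
+
+# Strip every (open_tag, close_tag) block from the generation before scoring
+cleaned = remove_reasoning_tags(text=trace, tag_pairs=[("<think>", "</think>")])
+print(cleaned)  # -> "The answer is 4."
+```
+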
-## Nanotron
+#### Nanotron
To evaluate a model trained with nanotron on a single gpu.
diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py
index e48648d23..446006aec 100644
--- a/src/lighteval/logging/info_loggers.py
+++ b/src/lighteval/logging/info_loggers.py
@@ -170,27 +170,10 @@ class Detail:
"""Experiment details of one single example of one task.
Attributes:
- example (str): Current task example query
- instruction (str): Instruction prepended to the example and few shots.
- For example "In this task, you are given information of type x. You need to predict y."
- full_prompt (str): Expanded full prompt (instruction if present, then prompt)
- num_effective_few_shots (int): Number of actual few shots used for the example.
- This depends on the model context length and few-shots samples size: when using effective few-shots,
- only `num_effective_few_shots` few-shot samples are kept, allowing
- 1) each of the used few-shot examples and the prompt to not be truncated
- 2) this context still allows the model to predict up to the requested max numbers of tokens within its remaining context size.
- num_asked_few_shots (int): Initially asked number of few-shot samples.
- predictions (list): List of the actual model predictions
- input_tokens (list): List of the input tokens given to the model
- cont_tokens (list): List of the continuation tokens predicted by the model
- truncated (list): Size of the truncations (if it was needed to fit the prompt in the model context length)
- padded (list): Size of the padding (if it was needed for the current example)
- gold (list): Example gold targets (for generative evaluations)
- pred_logits (list): List of the actual model predicted logits
- choices (list): List of the possible choices (for multichoice/loglikelihood evaluations)
- gold_index (list): Indices of the gold targets among the [`choices`]
- metrics (dict): Metric name to current example score
-
+ doc (Doc): The [`Doc`] object containing the current example information.
+ model_response (ModelResponse): The [`ModelResponse`] object containing the model response for the current example.
+ metric (dict): The metric scores for the current example.
+ Example: {"accuracy": 0.5, "f1": 0.7, "exact_match": 0.6}
"""
doc: Doc
diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py
index b1e28b9ce..a3cd4c1b2 100644
--- a/src/lighteval/main_accelerate.py
+++ b/src/lighteval/main_accelerate.py
@@ -60,6 +60,16 @@ def accelerate( # noqa C901
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -131,6 +141,8 @@ def accelerate( # noqa C901
custom_tasks_directory=custom_tasks,
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
load_responses_from_details_date_id=load_responses_from_details_date_id,
)
diff --git a/src/lighteval/main_custom.py b/src/lighteval/main_custom.py
index 6cf4f2ae8..3c9e660d2 100644
--- a/src/lighteval/main_custom.py
+++ b/src/lighteval/main_custom.py
@@ -31,10 +31,10 @@
app = typer.Typer()
-HELP_PANNEL_NAME_1 = "Common Parameters"
-HELP_PANNEL_NAME_2 = "Logging Parameters"
-HELP_PANNEL_NAME_3 = "Debug Parameters"
-HELP_PANNEL_NAME_4 = "Modeling Parameters"
+HELP_PANEL_NAME_1 = "Common Parameters"
+HELP_PANEL_NAME_2 = "Logging Parameters"
+HELP_PANEL_NAME_3 = "Debug Parameters"
+HELP_PANEL_NAME_4 = "Modeling Parameters"
@app.command(rich_help_panel="Evaluation Backends")
@@ -45,46 +45,56 @@ def custom(
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
# === Common parameters ===
dataset_loading_processes: Annotated[
- int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANNEL_NAME_1)
+ int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
custom_tasks: Annotated[
- Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANNEL_NAME_1)
+ Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
num_fewshot_seeds: Annotated[
- int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANNEL_NAME_1)
+ int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
- str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANNEL_NAME_2)
+ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
] = "results",
results_path_template: Annotated[
str | None,
Option(
help="Template path for where to save the results, you have access to 3 variables, `output_dir`, `org` and `model`. for example a template can be `'{output_dir}/1234/{org}+{model}'`",
- rich_help_panel=HELP_PANNEL_NAME_2,
+ rich_help_panel=HELP_PANEL_NAME_2,
),
] = None,
push_to_hub: Annotated[
- bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANNEL_NAME_2)
+ bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
push_to_tensorboard: Annotated[
- bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANNEL_NAME_2)
+ bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
public_run: Annotated[
- bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANNEL_NAME_2)
+ bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
results_org: Annotated[
- Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANNEL_NAME_2)
+ Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANEL_NAME_2)
] = None,
save_details: Annotated[
- bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANNEL_NAME_2)
+ bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
# === debug ===
max_samples: Annotated[
- Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANNEL_NAME_3)
+ Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3)
] = None,
job_id: Annotated[
- int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3)
+ int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANEL_NAME_3)
] = 0,
):
"""
@@ -113,6 +123,8 @@ def custom(
custom_tasks_directory=custom_tasks,
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
pipeline = Pipeline(
tasks=tasks,
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index ec5be08c9..3dad8917b 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -62,6 +62,16 @@ def inference_endpoint(
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -136,6 +146,8 @@ def inference_endpoint(
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
load_responses_from_details_date_id=load_responses_from_details_date_id,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
pipeline = Pipeline(
tasks=tasks,
@@ -175,6 +187,16 @@ def tgi(
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -253,6 +275,8 @@ def tgi(
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
load_responses_from_details_date_id=load_responses_from_details_date_id,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
pipeline = Pipeline(
tasks=tasks,
@@ -295,6 +319,16 @@ def litellm(
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -376,6 +410,8 @@ def litellm(
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
load_responses_from_details_date_id=load_responses_from_details_date_id,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
pipeline = Pipeline(
tasks=tasks,
@@ -449,6 +485,16 @@ def inference_providers(
rich_help_panel=HELP_PANEL_NAME_2,
),
] = False,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === debug ===
max_samples: Annotated[
Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3)
@@ -493,6 +539,8 @@ def inference_providers(
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
load_responses_from_details_date_id=None,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
pipeline = Pipeline(
tasks=tasks,
diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py
index a64bfcdb9..e925862fa 100644
--- a/src/lighteval/main_nanotron.py
+++ b/src/lighteval/main_nanotron.py
@@ -43,6 +43,16 @@ def nanotron(
str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
],
lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")],
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
):
"""
Evaluate models using nanotron as backend.
@@ -101,6 +111,8 @@ def nanotron(
custom_tasks_directory=lighteval_config.tasks.custom_tasks,
num_fewshot_seeds=1,
max_samples=lighteval_config.tasks.max_samples,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
pipeline = Pipeline(
diff --git a/src/lighteval/main_sglang.py b/src/lighteval/main_sglang.py
index 13fe647ad..a10964ed5 100644
--- a/src/lighteval/main_sglang.py
+++ b/src/lighteval/main_sglang.py
@@ -53,6 +53,16 @@ def sglang(
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+ ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -122,6 +132,8 @@ def sglang(
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
load_responses_from_details_date_id=load_responses_from_details_date_id,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
if model_args.endswith(".yaml"):
diff --git a/src/lighteval/main_vllm.py b/src/lighteval/main_vllm.py
index 907c4ace9..ba30777d2 100644
--- a/src/lighteval/main_vllm.py
+++ b/src/lighteval/main_vllm.py
@@ -56,6 +56,16 @@ def vllm(
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
+ remove_reasoning_tags: Annotated[
+ bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1)
+    ] = True,
+ reasoning_tags: Annotated[
+ str | None,
+ Option(
+ help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].",
+ rich_help_panel=HELP_PANEL_NAME_1,
+ ),
+ ] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -126,6 +136,8 @@ def vllm(
max_samples=max_samples,
cot_prompt=cot_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
+ remove_reasoning_tags=remove_reasoning_tags,
+ reasoning_tags=reasoning_tags,
)
if model_args.endswith(".yaml"):
diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py
index 29659ae20..745d606ec 100644
--- a/src/lighteval/metrics/dynamic_metrics.py
+++ b/src/lighteval/metrics/dynamic_metrics.py
@@ -236,7 +236,7 @@ def add_to_specifics_with_timeout(
def sample_level_fn(doc: Doc, model_response: ModelResponse) -> float:
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
gold_extraction_regexes = get_extraction_regexes(doc, gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes(doc, pred_extraction_target, language)
diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py
index 23d97a076..903295240 100644
--- a/src/lighteval/metrics/metrics.py
+++ b/src/lighteval/metrics/metrics.py
@@ -44,6 +44,7 @@
BLEURT,
MRR,
ROUGE,
+ AvgAtK,
BertScore,
ExactMatches,
Extractiveness,
@@ -349,6 +350,22 @@ class Metrics(Enum):
corpus_level_fn=np.mean,
higher_is_better=True,
)
+ math_avg_at_64 = SampleLevelMetric(
+ metric_name="math_avg@64",
+ sample_level_fn=AvgAtK(
+ k=64,
+ sample_scoring_function=lambda doc, model_response: multilingual_extractive_match_metric(
+ language=Language.ENGLISH,
+ gold_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()],
+ pred_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()],
+ precision=6,
+ ).sample_level_fn(doc, model_response),
+ ).compute,
+ category=SamplingMethod.GENERATIVE,
+ corpus_level_fn=np.mean,
+ higher_is_better=True,
+ )
+
math_pass_at_1_1n = SampleLevelMetric(
metric_name="math_pass@1:1_samples",
sample_level_fn=PassAtK(
@@ -475,6 +492,13 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
higher_is_better=True,
)
+ avg_at_64 = SampleLevelMetric(
+ metric_name="avg@64",
+        sample_level_fn=AvgAtK(k=64).compute,
+ category=SamplingMethod.GENERATIVE,
+ corpus_level_fn=np.mean,
+ higher_is_better=True,
+ )
pass_at_1 = SampleLevelMetric(
metric_name="pass@1:32_samples",
sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 872897290..706a7664a 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -110,7 +110,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
# We might need to flatten golds if they are a list of lists
golds = doc.get_golds()
for gold in golds:
- for pred in model_response.text:
+ for pred in model_response.final_text:
results.append(self.compute_one_item(gold=gold, pred=pred))
return self.aggregation_function(results)
@@ -186,7 +186,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
"""
results = []
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
# We might need to flatten golds if they are a list of lists
for gold in golds:
for pred in predictions:
@@ -528,7 +528,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float |
from rouge_score import rouge_scorer
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
if self.scorer is None:
self.scorer = rouge_scorer.RougeScorer(self.methods, tokenizer=self.tokenizer)
@@ -619,7 +619,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str
dict: Scores over the current sample's items.
"""
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
if self.bert_scorer is None:
logger.warning("The first metric computation step might be a bit longer as we need to download the model.")
@@ -680,7 +680,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str
self.stats_metric = DataStatsMetric()
inp = doc.specific[self.input_column]
- prediction = model_response.text[0]
+ prediction = model_response.final_text[0]
if self.normalize_input:
inp = self.normalize_input(inp)
if self.normalize_pred:
@@ -734,7 +734,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str
granularity="sentence", model_name="vitc", imager_load_cache=False
) # , device=device)
inp = doc.specific[self.input_column]
- predictions = model_response.text
+ predictions = model_response.final_text
prediction = predictions[0]
if self.normalize_input:
inp = self.normalize_input(inp)
@@ -774,7 +774,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
Returns:
float: Score over the current sample's items.
"""
- predictions = model_response.text
+ predictions = model_response.final_text
golds = doc.get_golds()
if len(predictions) == 1:
predictions = predictions * len(golds)
@@ -803,7 +803,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
float: Score over the current sample's items.
"""
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
return np.mean([self._bleu_score(golds, p) for p in predictions])
def _bleu_score(self, gold: list[str], pred: str):
@@ -852,7 +852,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
Returns:
dict: The different scores computed
"""
- predictions = model_response.text
+ predictions = model_response.final_text
golds = doc.get_golds()
if len(golds) > 1:
logger.warning(
@@ -1075,6 +1075,72 @@ def compute(self, model_responses: list[ModelResponse], docs: list[Doc], **kwarg
return metrics
+class AvgAtK:
+ def __init__(
+ self,
+ k: int,
+ sample_scoring_function: Callable[[Doc, ModelResponse], float] | str | None = None,
+ ):
+ """Sample score averages all the individual k predictions scores.
+
+ Args:
+ normalize_gold (callable, optional): Function to use to normalize the reference strings.
+ Defaults to None if no normalization is applied.
+ normalize_pred (callable, optional): Function to use to normalize the predicted strings.
+ Defaults to None if no normalization is applied.
+ strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
+ sample_scoring_function (callable | str, optional): Function to use to compute the score for each sample.
+ If None, uses the default scoring function which is a simple exact match.
+ """
+ self.k = k
+        # Manage the logic of per-prediction sample scoring
+ if callable(sample_scoring_function):
+ self.compute_score = sample_scoring_function
+ else:
+ if isinstance(sample_scoring_function, str):
+ if sample_scoring_function not in ["prefix", "suffix", "full"]:
+ raise ValueError(
+ f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
+ )
+ type_exact_match = sample_scoring_function
+ else:
+ type_exact_match = "full"
+ self.compute_score = self.default_sample_scoring(type_exact_match)
+
+ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
+ """Computes the metric over a list of golds and predictions for one single sample.
+ It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
+ then compares it to the gold.
+
+ Args:
+ golds (list[str]): Reference targets
+ predictions (list[str]): k predicted strings
+
+ Returns:
+ float: Aggregated score over the current sample's items.
+ """
+ all_scores = []
+ for i in range(self.k):
+ all_scores.append(self.compute_score(doc, model_response[i]))
+
+ avg_score = np.mean(all_scores)
+ return avg_score
+
+ def default_sample_scoring(self, type_exact_match: str) -> callable:
+ def sample_scoring_function(doc: Doc, model_response: ModelResponse) -> int:
+ """Default sample scoring function that checks if the prediction is equal to the gold."""
+ pred = model_response.final_text[0]
+ gold = doc.get_golds()[0]
+
+ if type_exact_match == "prefix":
+ return 1 if pred.startswith(gold) else 0
+ if type_exact_match == "suffix":
+ return 1 if pred.endswith(gold) else 0
+ return 1 if gold == pred else 0
+
+ return sample_scoring_function
+
+
class MajAtK:
def __init__(
self,
@@ -1123,7 +1189,7 @@ def compute(self, model_response: ModelResponse, docs: Doc, **kwargs):
float: Aggregated score over the current sample's items.
"""
golds = docs.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
if len(golds) > 1:
raise Exception("Cannot compute maj@k with several golds")
@@ -1224,7 +1290,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
float: Aggregated score over the current sample's items.
"""
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
if len(golds) > 1:
raise Exception("Cannot compute pass@k with several golds")
@@ -1273,7 +1339,7 @@ def get_processed_pred(self, pred: str) -> str:
return pred
def default_sample_scoring(self, doc, model_response) -> int:
- pred = model_response.text[0]
+ pred = model_response.final_text[0]
gold = doc.get_golds()[0]
if self.type_exact_match == "prefix":
@@ -1355,7 +1421,7 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
float: Aggregated score over the current sample's items.
"""
golds = doc.get_golds()
- predictions = model_response.text
+ predictions = model_response.final_text
if len(golds) > 1:
raise Exception("Cannot compute G-Pass@k with several golds")
@@ -1408,7 +1474,7 @@ def get_processed_pred(self, pred: str) -> str:
def default_sample_scoring(self, doc: Doc, model_response: ModelResponse) -> int:
gold = doc.get_golds()[0]
- pred = model_response.text[0]
+ pred = model_response.final_text[0]
if self.type_exact_match == "prefix":
return 1 if pred.startswith(gold) else 0
if self.type_exact_match == "suffix":
diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py
index 830326fc2..2b99483b7 100644
--- a/src/lighteval/metrics/sample_preparator.py
+++ b/src/lighteval/metrics/sample_preparator.py
@@ -73,7 +73,7 @@ def prepare(doc: Doc, model_response: ModelResponse, **kwargs):
GenerativeCorpusMetricInput: Stores the golds and predictions as such
"""
golds = as_list(doc.get_golds())
- predictions = model_response.text
+ predictions = model_response.final_text
return GenerativeCorpusMetricInput(golds=golds, preds=predictions)
diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py
index 6f0b9884e..3663d88bd 100644
--- a/src/lighteval/models/model_output.py
+++ b/src/lighteval/models/model_output.py
@@ -21,7 +21,6 @@
# SOFTWARE.
from dataclasses import dataclass, field
-from typing import Optional
import torch
@@ -40,11 +39,23 @@ class ModelResponse:
The original input prompt or context that was fed to the model.
Used for debugging and analysis purposes.
+ input_tokens (list[int]):
+ The tokenized representation of the input prompt.
+ Useful for understanding how the model processes the input.
+
text (list[str]):
The generated text responses from the model. Each element represents
one generation (useful when num_samples > 1).
**Required for**: Generative metrics, exact match, llm as a judge, etc.
+        text_post_processed (Optional[list[str]]):
+            The generated text responses from the model, after post-processing.
+            At the moment, post-processing removes thinking/reasoning traces.
+
+            Careful: this is not computed by default, but in a separate
+            post-processing step run by the pipeline before metrics are computed.
+            **Required for**: Generative metrics that require direct answers.
+
logprobs (list[float]):
Log probabilities of the generated tokens or sequences.
**Required for**: loglikelihood and perplexity metrics.
@@ -54,6 +65,7 @@ class ModelResponse:
Used for accuracy calculations in multiple choice and classification tasks.
**Required for**: certain loglikelihood metrics.
+
unconditioned_logprobs (Optional[list[float]]):
Log probabilities from an unconditioned model (e.g., without context).
Used for PMI (Pointwise Mutual Information) normalization.
@@ -105,26 +117,46 @@ class ModelResponse:
- For most evaluation tasks, only a subset of attributes is required
- The `text` attribute is the most commonly used for generative tasks
- `logprobs` are essential for probability-based metrics like perplexity
- - `argmax_logits_eq_gold` is specifically for certain multiple choice/classification tasks
- - Token-level attributes (`input_tokens`, `output_tokens`) are useful for debugging
- - Truncation and padding counts help understand model behavior with long inputs
"""
+ # Model inputs
input: str | list | None = None
+ input_tokens: list[int] = field(default_factory=list)
+
+ # Model text outputs
text: list[str] = field(default_factory=list) # The text of the response
+ output_tokens: list[list[int]] = field(default_factory=list) # Model generations
+ text_post_processed: list[str] | None = None # The text of the response postprocessed
+
+ # Model logprob outputs
logprobs: list[float] = field(default_factory=list) # Log probabilities of the response
argmax_logits_eq_gold: list[bool] = field(default_factory=list) # Whether the argmax logits match the gold text
logits: list[list[float]] | None = None # Logits of the response, if applicable
+ unconditioned_logprobs: list[float] | None = None # Log probabilities of the unconditioned model (if applicable)
+ # Other metadata
truncated_tokens_count: int = 0 # How many tokens truncated
padded_tokens_count: int = 0 # How many tokens of padding
- input_tokens: list[int] = field(default_factory=list) # model inputs
- output_tokens: list[list[int]] = field(default_factory=list) # model generations
-
- unconditioned_logprobs: Optional[list[float]] = (
- None # Log probabilities of the unconditioned model (if applicable)
- )
+ @property
+ def final_text(self) -> list[str]:
+ if self.text_post_processed is not None:
+ return self.text_post_processed
+ return self.text
+
+ def __getitem__(self, index: int) -> "ModelResponse":
+ return ModelResponse(
+ input=self.input,
+ input_tokens=self.input_tokens,
+ text=[self.text[index]],
+ output_tokens=[self.output_tokens[index]],
+ logprobs=[self.logprobs[index]] if self.logprobs else [],
+ argmax_logits_eq_gold=[self.argmax_logits_eq_gold[index]] if self.argmax_logits_eq_gold else [],
+ logits=[self.logits[index]] if self.logits else None,
+ unconditioned_logprobs=[self.unconditioned_logprobs[index]] if self.unconditioned_logprobs else None,
+ truncated_tokens_count=self.truncated_tokens_count,
+ padded_tokens_count=self.padded_tokens_count,
+ )
@dataclass
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 53aba0c37..5f7a7e9bc 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -179,7 +179,9 @@ def __init__(
self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
self._tokenizer = self._create_auto_tokenizer(config)
- self._max_length = config.max_model_length if config.max_model_length is not None else None
+ self._max_length = (
+ config.max_model_length
+        )  # None if max_model_length is not set in the config; then defined in _create_auto_model
# If model_parallel is not set we compare the number of processes with the number of GPUs
self.model = self._create_auto_model(config)
@@ -258,6 +260,15 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]:
if config.data_parallel_size > 1:
self.model_args["distributed_executor_backend"] = "ray"
self._batch_size = "auto"
+
+            if self._max_length is None:
+                # Todo: we will want to manage this automatically - at the moment this arg must be set at least
+                # twice (in gen params + model args) for vllm models, which is an issue.
+                logger.warning(
+                    "The model max_length was not set in the model arguments. Since the model is using data parallelism, it is created later "
+                    "with `ray`, so we can't infer the max_length automatically at the moment. It might raise issues later on: if it does, "
+                    "relaunch your run, but set `max_model_length` explicitly in the model args."
+                )
return None
model = LLM(**self.model_args)
@@ -328,7 +339,11 @@ def greedy_until(
context_size = len(inputs[0])
# left truncate the inputs to the maximum length
- if max_new_tokens is not None:
+ if self.max_length is None:
+ logger.warning(
+ "The model max_length was not set in the model arguments, so we cannot check if we need to truncate the context."
+ )
+ elif max_new_tokens is not None:
if context_size + max_new_tokens > self.max_length:
logger.warning(
f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py
index 49e0e3a5f..508a2a1c7 100644
--- a/src/lighteval/pipeline.py
+++ b/src/lighteval/pipeline.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import ast
import asyncio
import collections
import os
@@ -56,7 +57,7 @@
is_vllm_available,
)
from lighteval.utils.parallelism import test_all_gather
-from lighteval.utils.utils import make_results_table
+from lighteval.utils.utils import make_results_table, remove_reasoning_tags
if is_accelerate_available():
@@ -102,10 +103,13 @@ class PipelineParameters:
num_fewshot_seeds: int = 1
max_samples: int | None = None
cot_prompt: str | None = None
+ remove_reasoning_tags: bool = True
+ reasoning_tags: str | list[tuple[str, str]] | None = None
load_responses_from_details_date_id: str | None = None
bootstrap_iters: int = 1000
def __post_init__(self): # noqa C901
+        # Check that the dependencies of the requested backend are installed
if self.launcher_type == ParallelismManager.ACCELERATE:
if not is_accelerate_available():
raise ImportError(NO_ACCELERATE_ERROR_MSG)
@@ -124,6 +128,25 @@ def __post_init__(self): # noqa C901
elif self.launcher_type == ParallelismManager.OPENAI:
if not is_openai_available():
raise ImportError(NO_OPENAI_ERROR_MSG)
+ if self.reasoning_tags is None:
+            self.reasoning_tags = [("<think>", "</think>")]
+ else:
+ # Convert reasoning tags to list if needed
+ if not isinstance(self.reasoning_tags, list):
+ try:
+ self.reasoning_tags = ast.literal_eval(self.reasoning_tags)
+ except ValueError as e:
+ raise ValueError(
+ "reasoning_tags must be a list of pair tuples, e.g. [('start_tag', 'end_tag'), ...]. "
+ f"Got {self.reasoning_tags} instead, which caused parsing error {e}."
+ )
+
+ # Make sure format is correct
+ if not all(isinstance(tag, tuple) and len(tag) == 2 for tag in self.reasoning_tags):
+ raise ValueError(
+ "reasoning_tags must be a list of pair tuples, e.g. [('start_tag', 'end_tag'), ...]. "
+ f"Got {self.reasoning_tags} instead."
+ )
class Pipeline:
@@ -284,9 +307,10 @@ def evaluate(self):
else:
outputs = self._run_model()
- self._compute_metrics(outputs)
-
if self.is_main_process():
+ self._post_process_outputs(outputs)
+ self._compute_metrics(outputs)
+
self.evaluation_tracker.general_config_logger.log_end_time()
self.evaluation_tracker.metrics_logger.aggregate(
task_dict=self.tasks_dict, bootstrap_iters=self.pipeline_parameters.bootstrap_iters
@@ -341,6 +365,21 @@ def _run_model(self):
return outputs
+ def _post_process_outputs(self, sampling_method_responses: dict[str, list[ModelResponse]]):
+ # Removes reasoning tags if needed
+ logger.info("--- POST-PROCESSING MODEL RESPONSES ---")
+
+ if self.pipeline_parameters.remove_reasoning_tags:
+ for _, responses in sampling_method_responses.items():
+ for response in responses:
+ response.text_post_processed = [
+ remove_reasoning_tags(
+ text=text,
+ tag_pairs=self.pipeline_parameters.reasoning_tags,
+ )
+ for text in response.text
+ ]
+
def _compute_metrics(self, sampling_method_responses: dict[str, list[ModelResponse]]):
# To compute the metrics we first group the samples and task and then by metrics.
# This way we can batch the metrics computation for each task and metric category
diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py
index 815f08289..3d988acf5 100644
--- a/src/lighteval/tasks/default_tasks.py
+++ b/src/lighteval/tasks/default_tasks.py
@@ -383,6 +383,22 @@
],
version=2,
)
+aime24_avg = LightevalTaskConfig(
+ name="aime24_avg",
+ suite=["lighteval"],
+ prompt_function=prompt.aime_prompt_fn,
+ hf_repo="HuggingFaceH4/aime_2024",
+ hf_subset="default",
+ hf_avail_splits=["train"],
+ evaluation_splits=["train"],
+ few_shots_split=None,
+ few_shots_select=None,
+ generation_size=None,
+ metrics=[
+ Metrics.math_avg_at_64,
+ ],
+ version=2,
+)
aime24_gpassk = LightevalTaskConfig(
name="aime24_gpassk",
suite=["lighteval"],
diff --git a/src/lighteval/tasks/extended/ifeval/main.py b/src/lighteval/tasks/extended/ifeval/main.py
index 79b283dbf..78bbe6e3f 100644
--- a/src/lighteval/tasks/extended/ifeval/main.py
+++ b/src/lighteval/tasks/extended/ifeval/main.py
@@ -32,7 +32,6 @@
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.utils import remove_reasoning_tags
# Very specific task where there are no precise outputs but instead we test if the format obeys rules
@@ -60,9 +59,7 @@ def ifeval_prompt(line, task_name: str = ""):
def ifeval_metric(doc: Doc, model_response: ModelResponse, **kwargs) -> dict:
- response = model_response.text[0]
- # Remove the reasoning block to avoid false negatives: https://github.com/huggingface/lighteval/issues/790
- response = remove_reasoning_tags(response, REASONING_TAG_PAIRS)
+ response = model_response.final_text[0]
# Strict instructions
instruction_list = doc.specific["instructions_id_list"]
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index f5e63c351..42ba3408e 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -26,7 +26,7 @@
from dataclasses import asdict, dataclass, field
from typing import Callable
-from datasets import DatasetDict
+from datasets import DatasetDict, load_dataset
from huggingface_hub import TextGenerationInputGrammarType
from multiprocess import Pool
from pytablewriter import MarkdownTableWriter
@@ -36,7 +36,7 @@
from lighteval.tasks.requests import (
Doc,
)
-from lighteval.utils.utils import ListLike, as_list, download_dataset_worker
+from lighteval.utils.utils import ListLike, as_list
logger = logging.getLogger(__name__)
@@ -241,13 +241,7 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
list[Doc]: List of documents.
"""
if self.dataset is None:
- self.dataset = download_dataset_worker(
- self.dataset_path,
- self.dataset_config_name,
- self.trust_dataset,
- self.dataset_filter,
- self.dataset_revision,
- )
+ self.dataset = self.download_dataset_worker(self)
assert self.dataset is not None, f"Dataset {self.dataset_path} not found."
@@ -356,35 +350,43 @@ def load_datasets(tasks: dict[str, "LightevalTask"], dataset_loading_processes:
"""
if dataset_loading_processes <= 1:
- datasets = [
- download_dataset_worker(
- task.dataset_path,
- task.dataset_config_name,
- task.trust_dataset,
- task.dataset_filter,
- task.dataset_revision,
- )
- for task in tasks.values()
- ]
+ # Useful for the test suite: we can mock loading tasks by overwriting the
+ # individual download_dataset_worker functions
+ datasets = [task.download_dataset_worker(task) for task in tasks.values()]
else:
with Pool(processes=dataset_loading_processes) as pool:
datasets = pool.starmap(
- download_dataset_worker,
- [
- (
- task.dataset_path,
- task.dataset_config_name,
- task.trust_dataset,
- task.dataset_filter,
- task.dataset_revision,
- )
- for task in tasks.values()
- ],
+ LightevalTask.download_dataset_worker,
+                [(task,) for task in tasks.values()],
)
for task, dataset in zip(tasks, datasets):
tasks[task].dataset = dataset
+ @staticmethod
+ def download_dataset_worker(
+ task: "LightevalTask",
+ ) -> DatasetDict:
+ """
+ Worker function to download a dataset from the HuggingFace Hub.
+ Used for parallel dataset loading.
+ """
+ dataset = load_dataset(
+ path=task.dataset_path,
+ name=task.dataset_config_name,
+ data_dir=None,
+ cache_dir=None,
+ download_mode=None,
+ trust_remote_code=task.trust_dataset,
+ revision=task.dataset_revision,
+ )
+
+ if task.dataset_filter is not None:
+ dataset = dataset.filter(task.dataset_filter)
+
+ # It returns DatasetDict because we don't specify a split
+ return dataset # type: ignore
+
def extract_num_samples(metric_name: str) -> int:
"""Gets the number of samples to generate from the metric name.
diff --git a/src/lighteval/utils/utils.py b/src/lighteval/utils/utils.py
index 28e0ac4a4..115987ba7 100644
--- a/src/lighteval/utils/utils.py
+++ b/src/lighteval/utils/utils.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, is_dataclass
-from typing import Callable, TypeVar, Union
+from typing import TypeVar, Union
import numpy as np
-from datasets import DatasetDict, load_dataset
from pytablewriter import MarkdownTableWriter
@@ -199,34 +198,6 @@ def boolstring_to_bool(x: Union[str, bool, int]) -> Union[bool, None]:
raise ValueError(f"You tried to convert {x} to a boolean but it's not possible.")
-def download_dataset_worker(
- dataset_path: str,
- dataset_config_name: str,
- trust_dataset: bool,
- dataset_filter: Callable[[dict], bool] | None = None,
- revision: str | None = None,
-) -> DatasetDict:
- """
- Worker function to download a dataset from the HuggingFace Hub.
- Used for parallel dataset loading.
- """
- dataset = load_dataset(
- path=dataset_path,
- name=dataset_config_name,
- data_dir=None,
- cache_dir=None,
- download_mode=None,
- trust_remote_code=trust_dataset,
- revision=revision,
- )
-
- if dataset_filter is not None:
- dataset = dataset.filter(dataset_filter)
-
- # It returns DatasetDict because we don't specify a split
- return dataset # type: ignore
-
-
def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray:
return np.where(denominator != 0, numerator / denominator, default_value)
diff --git a/tests/pipeline/test_reasoning_tags.py b/tests/pipeline/test_reasoning_tags.py
new file mode 100644
index 000000000..cadddd0d9
--- /dev/null
+++ b/tests/pipeline/test_reasoning_tags.py
@@ -0,0 +1,416 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import tempfile
+import unittest
+from pathlib import Path
+from types import ModuleType
+from typing import Optional, Union
+from unittest.mock import patch
+
+from lighteval.logging.evaluation_tracker import EvaluationTracker
+from lighteval.metrics.metrics import Metrics
+from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig
+from lighteval.models.model_output import ModelResponse
+from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
+from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
+from lighteval.tasks.registry import Registry
+from lighteval.tasks.requests import Doc, SamplingMethod
+from lighteval.utils.imports import is_accelerate_available
+
+
+class TestPipelineReasoningTags(unittest.TestCase):
+ """Test suite for pipeline reasoning tags functionality using DummyModel."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.temp_dir = tempfile.mkdtemp()
+
+ # Create a simple test task
+ self.task_config = LightevalTaskConfig(
+ name="test_reasoning_task",
+ suite=["test"],
+ prompt_function=lambda x: x,
+ hf_repo="test_repo",
+ hf_subset="default",
+ metrics=[Metrics.exact_match],
+ hf_avail_splits=["test"],
+ evaluation_splits=["test"],
+ few_shots_split=None,
+ few_shots_select=None,
+ generation_size=10,
+ stop_sequence=["\n"],
+ trust_dataset=True,
+ num_fewshots=0,
+ )
+
+ # Create test documents with reasoning tags in expected responses
+ self.test_docs = [
+ Doc(
+ task_name="test|test_reasoning_task|0",
+ query="What is 2+2?",
+ choices=["4"],
+ gold_index=[0],
+ instruction="",
+ sampling_methods=[SamplingMethod.GENERATIVE],
+ ),
+ ]
+
+ # Mock dataset
+ self.mock_dataset = {"test": self.test_docs}
+
+ def _mock_task_registry(self, task_config, task_docs, responses_with_reasoning_tags):
+ """Create a fake registry for testing."""
+
+ class FakeTask(LightevalTask):
+ def __post_init__(self):
+ self._docs = task_docs
+
+ def get_docs(self, max_samples=None):
+ return task_docs
+
+ @staticmethod
+            def download_dataset_worker(task):
+ # Mock dataset loading
+ return task._docs
+
+ class FakeRegistry(Registry):
+ def __init__(self, custom_tasks: Optional[Union[str, Path, ModuleType]] = None):
+ super().__init__(custom_tasks=custom_tasks)
+
+ def get_tasks_configs(self, task: str):
+ return [task_config]
+
+ def get_tasks_from_configs(self, tasks_configs):
+ return {f"{task_config.suite[0]}|{task_config.full_name}": FakeTask(task_config)}
+
+ # Create a DummyModel that returns responses with reasoning tags
+ class TestDummyModel(DummyModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def greedy_until(self, docs):
+ # Return responses with reasoning tags
+ return responses_with_reasoning_tags
+
+ return FakeRegistry, TestDummyModel
+
+ def test_remove_reasoning_tags_enabled(self):
+ """Test that reasoning tags are removed when remove_reasoning_tags=True."""
+
+ # Responses with reasoning tags
+ responses_with_reasoning = [
+ ModelResponse(text=["Let me think about this... 2+2=4The answer is 4"])
+ ]
+
+ FakeRegistry, TestDummyModel = self._mock_task_registry(
+ self.task_config, self.test_docs, responses_with_reasoning
+ )
+
+ # Initialize accelerator if available
+ if is_accelerate_available():
+ from accelerate import Accelerator
+
+ Accelerator()
+
+ with patch("lighteval.pipeline.Registry", FakeRegistry):
+ # Create pipeline with reasoning tag removal enabled
+ pipeline_params = PipelineParameters(
+ launcher_type=ParallelismManager.NONE,
+ remove_reasoning_tags=True,
+                reasoning_tags=[("<think>", "</think>")],
+ max_samples=1,
+ )
+
+ evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir)
+ model = TestDummyModel(DummyModelConfig(seed=42))
+
+ pipeline = Pipeline(
+ tasks="test|test_reasoning_task|0|0",
+ pipeline_parameters=pipeline_params,
+ evaluation_tracker=evaluation_tracker,
+ model=model,
+ )
+
+ # Run the pipeline
+ pipeline.evaluate()
+
+ # Check that reasoning tags were removed from post-processed text
+ details = pipeline.evaluation_tracker.details
+ self.assertEqual(
+ details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["The answer is 4"]
+ )
+
+ def test_remove_reasoning_tags_enabled_tags_as_string(self):
+ """Test that reasoning tags are removed when remove_reasoning_tags=True."""
+
+ # Responses with reasoning tags
+ responses_with_reasoning = [
+ ModelResponse(text=["Let me think about this... 2+2=4The answer is 4"])
+ ]
+
+ FakeRegistry, TestDummyModel = self._mock_task_registry(
+ self.task_config, self.test_docs, responses_with_reasoning
+ )
+
+ # Initialize accelerator if available
+ if is_accelerate_available():
+ from accelerate import Accelerator
+
+ Accelerator()
+
+ with patch("lighteval.pipeline.Registry", FakeRegistry):
+ # Create pipeline with reasoning tag removal enabled
+ pipeline_params = PipelineParameters(
+ launcher_type=ParallelismManager.NONE,
+ remove_reasoning_tags=True,
+                reasoning_tags='[("<think>", "</think>")]',
+ max_samples=1,
+ )
+
+ evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir)
+ model = TestDummyModel(DummyModelConfig(seed=42))
+
+ pipeline = Pipeline(
+ tasks="test|test_reasoning_task|0|0",
+ pipeline_parameters=pipeline_params,
+ evaluation_tracker=evaluation_tracker,
+ model=model,
+ )
+
+ # Run the pipeline
+ pipeline.evaluate()
+
+ # Check that reasoning tags were removed from post-processed text
+ details = pipeline.evaluation_tracker.details
+ self.assertEqual(
+ details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["The answer is 4"]
+ )
+
+ def test_remove_reasoning_tags_enabled_default_tags(self):
+ """Test that reasoning tags are removed when remove_reasoning_tags=True."""
+
+ # Responses with reasoning tags
+ responses_with_reasoning = [
+ ModelResponse(text=["Let me think about this... 2+2=4The answer is 4"])
+ ]
+
+ FakeRegistry, TestDummyModel = self._mock_task_registry(
+ self.task_config, self.test_docs, responses_with_reasoning
+ )
+
+ # Initialize accelerator if available
+ if is_accelerate_available():
+ from accelerate import Accelerator
+
+ Accelerator()
+
+ with patch("lighteval.pipeline.Registry", FakeRegistry):
+ # Create pipeline with reasoning tag removal enabled
+ pipeline_params = PipelineParameters(
+ launcher_type=ParallelismManager.NONE, remove_reasoning_tags=True, max_samples=1
+ )
+
+ evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir)
+ model = TestDummyModel(DummyModelConfig(seed=42))
+
+ pipeline = Pipeline(
+ tasks="test|test_reasoning_task|0|0",
+ pipeline_parameters=pipeline_params,
+ evaluation_tracker=evaluation_tracker,
+ model=model,
+ )
+
+ # Run the pipeline
+ pipeline.evaluate()
+
+ # Check that reasoning tags were removed from post-processed text
+ details = pipeline.evaluation_tracker.details
+ self.assertEqual(
+ details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["The answer is 4"]
+ )
+
+ def test_remove_reasoning_tags_disabled(self):
+ """Test that reasoning tags are preserved when remove_reasoning_tags=False."""
+
+ # Responses with reasoning tags
+ responses_with_reasoning = [
+ ModelResponse(text=["Let me think about this... 2+2=4The answer is 4"])
+ ]
+
+ FakeRegistry, TestDummyModel = self._mock_task_registry(
+ self.task_config, self.test_docs, responses_with_reasoning
+ )
+
+ # Initialize accelerator if available
+ if is_accelerate_available():
+ from accelerate import Accelerator
+
+ Accelerator()
+
+ with patch("lighteval.pipeline.Registry", FakeRegistry):
+ # Create pipeline with reasoning tag removal disabled
+ pipeline_params = PipelineParameters(
+ launcher_type=ParallelismManager.NONE,
+ remove_reasoning_tags=False,
+                reasoning_tags=[("<think>", "</think>")],
+ max_samples=1,
+ )
+
+ evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir)
+ model = TestDummyModel(DummyModelConfig(seed=42))
+
+ pipeline = Pipeline(
+ tasks="test|test_reasoning_task|0|0",
+ pipeline_parameters=pipeline_params,
+ evaluation_tracker=evaluation_tracker,
+ model=model,
+ )
+
+ # Run the pipeline
+ pipeline.evaluate()
+
+ # Check that post-processed text is None (= no post processing happened)
+ details = pipeline.evaluation_tracker.details
+ self.assertIsNone(
+ details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"],
+ )
+
+ def test_custom_reasoning_tags(self):
+ """Test that custom reasoning tags are correctly applied."""
+
+ # Responses with custom reasoning tags
+ responses_with_reasoning = [
+ ModelResponse(text=["[reasoning]This is my thought process[/reasoning]Final answer: 4"])
+ ]
+
+ FakeRegistry, TestDummyModel = self._mock_task_registry(
+ self.task_config, self.test_docs, responses_with_reasoning
+ )
+
+ # Initialize accelerator if available
+ if is_accelerate_available():
+ from accelerate import Accelerator
+
+ Accelerator()
+
+ with patch("lighteval.pipeline.Registry", FakeRegistry):
+ # Create pipeline with custom reasoning tags
+ pipeline_params = PipelineParameters(
+ launcher_type=ParallelismManager.NONE,
+ remove_reasoning_tags=True,
+ reasoning_tags=[("[reasoning]", "[/reasoning]")],
+ max_samples=1,
+ )
+
+ evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir)
+ model = TestDummyModel(DummyModelConfig(seed=42))
+
+ pipeline = Pipeline(
+ tasks="test|test_reasoning_task|0|0",
+ pipeline_parameters=pipeline_params,
+ evaluation_tracker=evaluation_tracker,
+ model=model,
+ )
+
+ # Run the pipeline
+ pipeline.evaluate()
+
+ # Check that reasoning tags were removed from post-processed text
+ details = pipeline.evaluation_tracker.details
+ self.assertEqual(
+ details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["Final answer: 4"]
+ )
+
+ def test_multiple_reasoning_tags(self):
+ """Test that multiple reasoning tag pairs are correctly handled."""
+
+ # Responses with multiple reasoning tag types
+ responses_with_reasoning = [
+ ModelResponse(text=["First thoughtSome textSecond thoughtFinal: 4"])
+ ]
+
+ FakeRegistry, TestDummyModel = self._mock_task_registry(
+ self.task_config, self.test_docs, responses_with_reasoning
+ )
+
+ # Initialize accelerator if available
+ if is_accelerate_available():
+ from accelerate import Accelerator
+
+ Accelerator()
+
+ with patch("lighteval.pipeline.Registry", FakeRegistry):
+ # Create pipeline with multiple reasoning tag pairs
+ pipeline_params = PipelineParameters(
+ launcher_type=ParallelismManager.NONE,
+ remove_reasoning_tags=True,
+                reasoning_tags='[("<think>", "</think>"), ("<reasoning>", "</reasoning>")]',
+ max_samples=1,
+ )
+
+ evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir)
+ model = TestDummyModel(DummyModelConfig(seed=42))
+
+ pipeline = Pipeline(
+ tasks="test|test|test_reasoning_task|0|0",
+ pipeline_parameters=pipeline_params,
+ evaluation_tracker=evaluation_tracker,
+ model=model,
+ )
+
+ # Run the pipeline
+ pipeline.evaluate()
+
+ # Check that reasoning tags were removed from post-processed text
+ details = pipeline.evaluation_tracker.details
+ self.assertEqual(
+ details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"],
+ ["Some textFinal: 4"],
+ )
+
+ def test_reasoning_tags_validation(self):
+ """Test that invalid reasoning_tags parameter raises appropriate error."""
+
+ for test_string in ["['incorrect_format']", "invalid_format"]:
+ with self.assertRaises(ValueError) as context:
+ PipelineParameters(
+ launcher_type=ParallelismManager.NONE,
+ reasoning_tags=test_string, # Should be a list of tuples
+ )
+
+ # Check that the error message mentions the expected format
+ self.assertIn("reasoning_tags must be a list of pair tuples", str(context.exception))
+
+ def test_default_reasoning_tags(self):
+ """Test that default reasoning tags are correctly set."""
+
+ pipeline_params = PipelineParameters(launcher_type=ParallelismManager.NONE)
+
+ # Check that default reasoning tags are set
+        self.assertEqual(pipeline_params.reasoning_tags, [("<think>", "</think>")])
+ self.assertTrue(pipeline_params.remove_reasoning_tags)
+
+
+if __name__ == "__main__":
+ unittest.main()