diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx
index 69eb52a06..7daea73da 100644
--- a/docs/source/quicktour.mdx
+++ b/docs/source/quicktour.mdx
@@ -30,7 +30,7 @@ lighteval accelerate \
     "leaderboard|truthfulqa:mc|0|0"
 ```
 
-Here, we first choose a backend (either `accelerate`, `nanotron`, or `vllm`), and then specify the model and task(s) to run.
+Here, we first choose a backend (either `accelerate`, `nanotron`, `endpoint`, or `vllm`), and then specify the model and task(s) to run.
 
 The syntax for the model arguments is `key1=value1,key2=value2,etc`.
 Valid key-value pairs correspond with the backend configuration, and are detailed [below](#Model Arguments).
@@ -104,13 +104,32 @@ GPUs.
 
 ## Backend configuration
 
+#### General information
+
 The `model-args` argument takes a string representing a list of model
 argument. The arguments allowed vary depending on the backend you use and
 correspond to the fields of the model configs.
 
-The model config can be found [here](./package_reference/models).
+The model configurations can be found [here](./package_reference/models).
+
+Every backend lets you post-process reasoning model predictions to remove the
+thinking tokens from the trace used to compute the metrics: pass
+`--remove-reasoning-tags` to enable the removal, and `--reasoning-tags` to
+specify which tag pairs to strip (defaults to `<think>` and `</think>`).
+
+Here's an example with `mistralai/Magistral-Small-2507`, which emits custom
+`[THINK]`/`[/THINK]` tokens.
+
+```bash
+lighteval vllm \
+    "model_name=mistralai/Magistral-Small-2507,dtype=float16,data_parallel_size=4" \
+    "lighteval|aime24|0|0" \
+    --remove-reasoning-tags \
+    --reasoning-tags="[('[THINK]','[/THINK]')]"
+```
+
-## Nanotron
+#### Nanotron
 
 To evaluate a model trained with nanotron on a single gpu.
 
diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py
index e48648d23..446006aec 100644
--- a/src/lighteval/logging/info_loggers.py
+++ b/src/lighteval/logging/info_loggers.py
@@ -170,27 +170,10 @@ class Detail:
     """Experiment details of one single example of one task.
 
     Attributes:
-        example (str): Current task example query
-        instruction (str): Instruction prepended to the example and few shots.
-            For example "In this task, you are given information of type x. You need to predict y."
-        full_prompt (str): Expanded full prompt (instruction if present, then prompt)
-        num_effective_few_shots (int): Number of actual few shots used for the example.
-            This depends on the model context length and few-shots samples size: when using effective few-shots,
-            only `num_effective_few_shots` few-shot samples are kept, allowing
-            1) each of the used few-shot examples and the prompt to not be truncated
-            2) this context still allows the model to predict up to the requested max numbers of tokens within its remaining context size.
-        num_asked_few_shots (int): Initially asked number of few-shot samples.
- predictions (list): List of the actual model predictions - input_tokens (list): List of the input tokens given to the model - cont_tokens (list): List of the continuation tokens predicted by the model - truncated (list): Size of the truncations (if it was needed to fit the prompt in the model context length) - padded (list): Size of the padding (if it was needed for the current example) - gold (list): Example gold targets (for generative evaluations) - pred_logits (list): List of the actual model predicted logits - choices (list): List of the possible choices (for multichoice/loglikelihood evaluations) - gold_index (list): Indices of the gold targets among the [`choices`] - metrics (dict): Metric name to current example score - + doc (Doc): The [`Doc`] object containing the current example information. + model_response (ModelResponse): The [`ModelResponse`] object containing the model response for the current example. + metric (dict): The metric scores for the current example. + Example: {"accuracy": 0.5, "f1": 0.7, "exact_match": 0.6} """ doc: Doc diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index b1e28b9ce..a3cd4c1b2 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -60,6 +60,16 @@ def accelerate( # noqa C901 load_responses_from_details_date_id: Annotated[ Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -131,6 +141,8 @@ def accelerate( # noqa C901 custom_tasks_directory=custom_tasks, num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, load_responses_from_details_date_id=load_responses_from_details_date_id, ) diff --git a/src/lighteval/main_custom.py b/src/lighteval/main_custom.py index 6cf4f2ae8..3c9e660d2 100644 --- a/src/lighteval/main_custom.py +++ b/src/lighteval/main_custom.py @@ -31,10 +31,10 @@ app = typer.Typer() -HELP_PANNEL_NAME_1 = "Common Parameters" -HELP_PANNEL_NAME_2 = "Logging Parameters" -HELP_PANNEL_NAME_3 = "Debug Parameters" -HELP_PANNEL_NAME_4 = "Modeling Parameters" +HELP_PANEL_NAME_1 = "Common Parameters" +HELP_PANEL_NAME_2 = "Logging Parameters" +HELP_PANEL_NAME_3 = "Debug Parameters" +HELP_PANEL_NAME_4 = "Modeling Parameters" @app.command(rich_help_panel="Evaluation Backends") @@ -45,46 +45,56 @@ def custom( tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")], # === Common parameters === dataset_loading_processes: Annotated[ - int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANNEL_NAME_1) + int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, custom_tasks: Annotated[ - Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANNEL_NAME_1) + Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, num_fewshot_seeds: Annotated[ - int, Option(help="Number of 
seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANNEL_NAME_1) + int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) ] = 1, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ - str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANNEL_NAME_2) + str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) ] = "results", results_path_template: Annotated[ str | None, Option( help="Template path for where to save the results, you have access to 3 variables, `output_dir`, `org` and `model`. for example a template can be `'{output_dir}/1234/{org}+{model}'`", - rich_help_panel=HELP_PANNEL_NAME_2, + rich_help_panel=HELP_PANEL_NAME_2, ), ] = None, push_to_hub: Annotated[ - bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANNEL_NAME_2) + bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANEL_NAME_2) ] = False, push_to_tensorboard: Annotated[ - bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANNEL_NAME_2) + bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANEL_NAME_2) ] = False, public_run: Annotated[ - bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANNEL_NAME_2) + bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANEL_NAME_2) ] = False, results_org: Annotated[ - Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANNEL_NAME_2) + Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANEL_NAME_2) ] = None, save_details: Annotated[ - bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANNEL_NAME_2) + bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANEL_NAME_2) ] = False, # === debug === max_samples: Annotated[ - Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANNEL_NAME_3) + Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3) ] = None, job_id: Annotated[ - int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3) + int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANEL_NAME_3) ] = 0, ): """ @@ -113,6 +123,8 @@ def custom( custom_tasks_directory=custom_tasks, num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) pipeline = Pipeline( tasks=tasks, diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index ec5be08c9..3dad8917b 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -62,6 +62,16 @@ def inference_endpoint( load_responses_from_details_date_id: Annotated[ Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + 
reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -136,6 +146,8 @@ def inference_endpoint( num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, load_responses_from_details_date_id=load_responses_from_details_date_id, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) pipeline = Pipeline( tasks=tasks, @@ -175,6 +187,16 @@ def tgi( load_responses_from_details_date_id: Annotated[ Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -253,6 +275,8 @@ def tgi( num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, load_responses_from_details_date_id=load_responses_from_details_date_id, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) pipeline = Pipeline( tasks=tasks, @@ -295,6 +319,16 @@ def litellm( load_responses_from_details_date_id: Annotated[ Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -376,6 +410,8 @@ def litellm( num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, load_responses_from_details_date_id=load_responses_from_details_date_id, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) pipeline = Pipeline( tasks=tasks, @@ -449,6 +485,16 @@ def inference_providers( rich_help_panel=HELP_PANEL_NAME_2, ), ] = False, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. 
Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === debug === max_samples: Annotated[ Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3) @@ -493,6 +539,8 @@ def inference_providers( num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, load_responses_from_details_date_id=None, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) pipeline = Pipeline( tasks=tasks, diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index a64bfcdb9..e925862fa 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -43,6 +43,16 @@ def nanotron( str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.") ], lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")], + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, ): """ Evaluate models using nanotron as backend. @@ -101,6 +111,8 @@ def nanotron( custom_tasks_directory=lighteval_config.tasks.custom_tasks, num_fewshot_seeds=1, max_samples=lighteval_config.tasks.max_samples, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) pipeline = Pipeline( diff --git a/src/lighteval/main_sglang.py b/src/lighteval/main_sglang.py index 13fe647ad..a10964ed5 100644 --- a/src/lighteval/main_sglang.py +++ b/src/lighteval/main_sglang.py @@ -53,6 +53,16 @@ def sglang( load_responses_from_details_date_id: Annotated[ Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = True, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -122,6 +132,8 @@ def sglang( num_fewshot_seeds=num_fewshot_seeds, max_samples=max_samples, load_responses_from_details_date_id=load_responses_from_details_date_id, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) if model_args.endswith(".yaml"): diff --git a/src/lighteval/main_vllm.py b/src/lighteval/main_vllm.py index 907c4ace9..ba30777d2 100644 --- a/src/lighteval/main_vllm.py +++ b/src/lighteval/main_vllm.py @@ -56,6 +56,16 @@ def vllm( load_responses_from_details_date_id: Annotated[ Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1) ] = None, + remove_reasoning_tags: Annotated[ + bool, Option(help="Remove reasoning tags from responses.", rich_help_panel=HELP_PANEL_NAME_1) + ] = False, + reasoning_tags: Annotated[ + str | None, + Option( + help="List of reasoning tags (provided as pairs) to remove from responses. 
Default is [('', '')].", + rich_help_panel=HELP_PANEL_NAME_1, + ), + ] = None, # === saving === output_dir: Annotated[ str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) @@ -126,6 +136,8 @@ def vllm( max_samples=max_samples, cot_prompt=cot_prompt, load_responses_from_details_date_id=load_responses_from_details_date_id, + remove_reasoning_tags=remove_reasoning_tags, + reasoning_tags=reasoning_tags, ) if model_args.endswith(".yaml"): diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 29659ae20..745d606ec 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -236,7 +236,7 @@ def add_to_specifics_with_timeout( def sample_level_fn(doc: Doc, model_response: ModelResponse) -> float: golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text gold_extraction_regexes = get_extraction_regexes(doc, gold_extraction_target, language) pred_extraction_regexes = get_extraction_regexes(doc, pred_extraction_target, language) diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 23d97a076..903295240 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -44,6 +44,7 @@ BLEURT, MRR, ROUGE, + AvgAtK, BertScore, ExactMatches, Extractiveness, @@ -349,6 +350,22 @@ class Metrics(Enum): corpus_level_fn=np.mean, higher_is_better=True, ) + math_avg_at_64 = SampleLevelMetric( + metric_name="math_avg@64", + sample_level_fn=AvgAtK( + k=64, + sample_scoring_function=lambda doc, model_response: multilingual_extractive_match_metric( + language=Language.ENGLISH, + gold_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()], + pred_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()], + precision=6, + ).sample_level_fn(doc, model_response), + ).compute, + category=SamplingMethod.GENERATIVE, + corpus_level_fn=np.mean, + higher_is_better=True, + ) + math_pass_at_1_1n = SampleLevelMetric( metric_name="math_pass@1:1_samples", sample_level_fn=PassAtK( @@ -475,6 +492,13 @@ class Metrics(Enum): corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute, higher_is_better=True, ) + avg_at_64 = SampleLevelMetric( + metric_name="avg@64", + sample_level_fn=PassAtK(k=64, n=64, strip_strings=True).compute, + category=SamplingMethod.GENERATIVE, + corpus_level_fn=np.mean, + higher_is_better=True, + ) pass_at_1 = SampleLevelMetric( metric_name="pass@1:32_samples", sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute, diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 872897290..706a7664a 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -110,7 +110,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: # We might need to flatten golds if they are a list of lists golds = doc.get_golds() for gold in golds: - for pred in model_response.text: + for pred in model_response.final_text: results.append(self.compute_one_item(gold=gold, pred=pred)) return self.aggregation_function(results) @@ -186,7 +186,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: """ results = [] golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text # We might need to flatten golds if they are a list of lists for gold in golds: for pred in predictions: @@ -528,7 +528,7 
@@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float | from rouge_score import rouge_scorer golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text if self.scorer is None: self.scorer = rouge_scorer.RougeScorer(self.methods, tokenizer=self.tokenizer) @@ -619,7 +619,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str dict: Scores over the current sample's items. """ golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text if self.bert_scorer is None: logger.warning("The first metric computation step might be a bit longer as we need to download the model.") @@ -680,7 +680,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str self.stats_metric = DataStatsMetric() inp = doc.specific[self.input_column] - prediction = model_response.text[0] + prediction = model_response.final_text[0] if self.normalize_input: inp = self.normalize_input(inp) if self.normalize_pred: @@ -734,7 +734,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str granularity="sentence", model_name="vitc", imager_load_cache=False ) # , device=device) inp = doc.specific[self.input_column] - predictions = model_response.text + predictions = model_response.final_text prediction = predictions[0] if self.normalize_input: inp = self.normalize_input(inp) @@ -774,7 +774,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: Returns: float: Score over the current sample's items. """ - predictions = model_response.text + predictions = model_response.final_text golds = doc.get_golds() if len(predictions) == 1: predictions = predictions * len(golds) @@ -803,7 +803,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs): float: Score over the current sample's items. """ golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text return np.mean([self._bleu_score(golds, p) for p in predictions]) def _bleu_score(self, gold: list[str], pred: str): @@ -852,7 +852,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs): Returns: dict: The different scores computed """ - predictions = model_response.text + predictions = model_response.final_text golds = doc.get_golds() if len(golds) > 1: logger.warning( @@ -1075,6 +1075,72 @@ def compute(self, model_responses: list[ModelResponse], docs: list[Doc], **kwarg return metrics +class AvgAtK: + def __init__( + self, + k: int, + sample_scoring_function: Callable[[Doc, ModelResponse], float] | str | None = None, + ): + """Sample score averages all the individual k predictions scores. + + Args: + normalize_gold (callable, optional): Function to use to normalize the reference strings. + Defaults to None if no normalization is applied. + normalize_pred (callable, optional): Function to use to normalize the predicted strings. + Defaults to None if no normalization is applied. + strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False. + sample_scoring_function (callable | str, optional): Function to use to compute the score for each sample. + If None, uses the default scoring function which is a simple exact match. 
+ """ + self.k = k + # Managed the logic of the per prediction of sample scoring + if callable(sample_scoring_function): + self.compute_score = sample_scoring_function + else: + if isinstance(sample_scoring_function, str): + if sample_scoring_function not in ["prefix", "suffix", "full"]: + raise ValueError( + f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead." + ) + type_exact_match = sample_scoring_function + else: + type_exact_match = "full" + self.compute_score = self.default_sample_scoring(type_exact_match) + + def compute(self, model_response: ModelResponse, doc: Doc, **kwargs): + """Computes the metric over a list of golds and predictions for one single sample. + It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones, + then compares it to the gold. + + Args: + golds (list[str]): Reference targets + predictions (list[str]): k predicted strings + + Returns: + float: Aggregated score over the current sample's items. + """ + all_scores = [] + for i in range(self.k): + all_scores.append(self.compute_score(doc, model_response[i])) + + avg_score = np.mean(all_scores) + return avg_score + + def default_sample_scoring(self, type_exact_match: str) -> callable: + def sample_scoring_function(doc: Doc, model_response: ModelResponse) -> int: + """Default sample scoring function that checks if the prediction is equal to the gold.""" + pred = model_response.final_text[0] + gold = doc.get_golds()[0] + + if type_exact_match == "prefix": + return 1 if pred.startswith(gold) else 0 + if type_exact_match == "suffix": + return 1 if pred.endswith(gold) else 0 + return 1 if gold == pred else 0 + + return sample_scoring_function + + class MajAtK: def __init__( self, @@ -1123,7 +1189,7 @@ def compute(self, model_response: ModelResponse, docs: Doc, **kwargs): float: Aggregated score over the current sample's items. """ golds = docs.get_golds() - predictions = model_response.text + predictions = model_response.final_text if len(golds) > 1: raise Exception("Cannot compute maj@k with several golds") @@ -1224,7 +1290,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: float: Aggregated score over the current sample's items. """ golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text if len(golds) > 1: raise Exception("Cannot compute pass@k with several golds") @@ -1273,7 +1339,7 @@ def get_processed_pred(self, pred: str) -> str: return pred def default_sample_scoring(self, doc, model_response) -> int: - pred = model_response.text[0] + pred = model_response.final_text[0] gold = doc.get_golds()[0] if self.type_exact_match == "prefix": @@ -1355,7 +1421,7 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float: float: Aggregated score over the current sample's items. 
""" golds = doc.get_golds() - predictions = model_response.text + predictions = model_response.final_text if len(golds) > 1: raise Exception("Cannot compute G-Pass@k with several golds") @@ -1408,7 +1474,7 @@ def get_processed_pred(self, pred: str) -> str: def default_sample_scoring(self, doc: Doc, model_response: ModelResponse) -> int: gold = doc.get_golds()[0] - pred = model_response.text[0] + pred = model_response.final_text[0] if self.type_exact_match == "prefix": return 1 if pred.startswith(gold) else 0 if self.type_exact_match == "suffix": diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index 830326fc2..2b99483b7 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -73,7 +73,7 @@ def prepare(doc: Doc, model_response: ModelResponse, **kwargs): GenerativeCorpusMetricInput: Stores the golds and predictions as such """ golds = as_list(doc.get_golds()) - predictions = model_response.text + predictions = model_response.final_text return GenerativeCorpusMetricInput(golds=golds, preds=predictions) diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py index 6f0b9884e..3663d88bd 100644 --- a/src/lighteval/models/model_output.py +++ b/src/lighteval/models/model_output.py @@ -21,7 +21,6 @@ # SOFTWARE. from dataclasses import dataclass, field -from typing import Optional import torch @@ -40,11 +39,23 @@ class ModelResponse: The original input prompt or context that was fed to the model. Used for debugging and analysis purposes. + input_tokens (list[int]): + The tokenized representation of the input prompt. + Useful for understanding how the model processes the input. + text (list[str]): The generated text responses from the model. Each element represents one generation (useful when num_samples > 1). **Required for**: Generative metrics, exact match, llm as a judge, etc. + text_post_processed (Optional[list[str]]): + The generated text responses from the model, but post processed. + Atm, post processing removes thinking/reasoning steps. + + Careful! This is not computed by default, but in a separate step by calling + `post_process` on the ModelResponse object. + **Required for**: Generative metrics that require direct answers. + logprobs (list[float]): Log probabilities of the generated tokens or sequences. **Required for**: loglikelihood and perplexity metrics. @@ -54,6 +65,7 @@ class ModelResponse: Used for accuracy calculations in multiple choice and classification tasks. **Required for**: certain loglikelihood metrics. + unconditioned_logprobs (Optional[list[float]]): Log probabilities from an unconditioned model (e.g., without context). Used for PMI (Pointwise Mutual Information) normalization. 
@@ -105,26 +117,46 @@ class ModelResponse: - For most evaluation tasks, only a subset of attributes is required - The `text` attribute is the most commonly used for generative tasks - `logprobs` are essential for probability-based metrics like perplexity - - `argmax_logits_eq_gold` is specifically for certain multiple choice/classification tasks - - Token-level attributes (`input_tokens`, `output_tokens`) are useful for debugging - - Truncation and padding counts help understand model behavior with long inputs """ + # Model inputs input: str | list | None = None + input_tokens: list[int] = field(default_factory=list) + + # Model text outputs text: list[str] = field(default_factory=list) # The text of the response + output_tokens: list[list[int]] = field(default_factory=list) # Model generations + text_post_processed: list[str] | None = None # The text of the response postprocessed + + # Model logprob outputs logprobs: list[float] = field(default_factory=list) # Log probabilities of the response argmax_logits_eq_gold: list[bool] = field(default_factory=list) # Whether the argmax logits match the gold text logits: list[list[float]] | None = None # Logits of the response, if applicable + unconditioned_logprobs: list[float] | None = None # Log probabilities of the unconditioned model (if applicable) + # Other metadata truncated_tokens_count: int = 0 # How many tokens truncated padded_tokens_count: int = 0 # How many tokens of padding - input_tokens: list[int] = field(default_factory=list) # model inputs - output_tokens: list[list[int]] = field(default_factory=list) # model generations - - unconditioned_logprobs: Optional[list[float]] = ( - None # Log probabilities of the unconditioned model (if applicable) - ) + @property + def final_text(self) -> list[str]: + if self.text_post_processed is not None: + return self.text_post_processed + return self.text + + def __getitem__(self, index: int) -> "ModelResponse": + return ModelResponse( + input=self.input, + input_tokens=self.input_tokens, + text=[self.text[index]], + output_tokens=[self.output_tokens[index]], + logprobs=[self.logprobs[index]] if self.logprobs else [], + argmax_logits_eq_gold=[self.argmax_logits_eq_gold[index]] if self.argmax_logits_eq_gold else [], + logits=[self.logits[index]] if self.logits else None, + unconditioned_logprobs=[self.unconditioned_logprobs[index]] if self.unconditioned_logprobs else None, + truncated_tokens_count=self.truncated_tokens_count, + padded_tokens_count=self.padded_tokens_count, + ) @dataclass diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 53aba0c37..5f7a7e9bc 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -179,7 +179,9 @@ def __init__( self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False self._tokenizer = self._create_auto_tokenizer(config) - self._max_length = config.max_model_length if config.max_model_length is not None else None + self._max_length = ( + config.max_model_length + ) # will be None if the config is None, then defined in _create_auto_model # If model_parallel is not set we compare the number of processes with the number of GPUs self.model = self._create_auto_model(config) @@ -258,6 +260,15 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]: if config.data_parallel_size > 1: self.model_args["distributed_executor_backend"] = "ray" self._batch_size = "auto" + + if self._max_length is None: + # Todo: we will 
want to manage this automatically - atm this arg must be set at least 2 times (in gen params + model args) for + # vllm models, which is an issue. + logger.warning( + "The model max_length was not set in the model arguments. Since the model is using data parallelism, it is created later " + " with `ray`, so we can't infer the max_length automatically atm. It might raise issues later on: if it does, relaunch your " + "run, but set `max_model_length` explicitely in the model args." + ) return None model = LLM(**self.model_args) @@ -328,7 +339,11 @@ def greedy_until( context_size = len(inputs[0]) # left truncate the inputs to the maximum length - if max_new_tokens is not None: + if self.max_length is None: + logger.warning( + "The model max_length was not set in the model arguments, so we cannot check if we need to truncate the context." + ) + elif max_new_tokens is not None: if context_size + max_new_tokens > self.max_length: logger.warning( f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens." diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 49e0e3a5f..508a2a1c7 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import ast import asyncio import collections import os @@ -56,7 +57,7 @@ is_vllm_available, ) from lighteval.utils.parallelism import test_all_gather -from lighteval.utils.utils import make_results_table +from lighteval.utils.utils import make_results_table, remove_reasoning_tags if is_accelerate_available(): @@ -102,10 +103,13 @@ class PipelineParameters: num_fewshot_seeds: int = 1 max_samples: int | None = None cot_prompt: str | None = None + remove_reasoning_tags: bool = True + reasoning_tags: str | list[tuple[str, str]] | None = None load_responses_from_details_date_id: str | None = None bootstrap_iters: int = 1000 def __post_init__(self): # noqa C901 + # Import testing if self.launcher_type == ParallelismManager.ACCELERATE: if not is_accelerate_available(): raise ImportError(NO_ACCELERATE_ERROR_MSG) @@ -124,6 +128,25 @@ def __post_init__(self): # noqa C901 elif self.launcher_type == ParallelismManager.OPENAI: if not is_openai_available(): raise ImportError(NO_OPENAI_ERROR_MSG) + if self.reasoning_tags is None: + self.reasoning_tags = [("", "")] + else: + # Convert reasoning tags to list if needed + if not isinstance(self.reasoning_tags, list): + try: + self.reasoning_tags = ast.literal_eval(self.reasoning_tags) + except ValueError as e: + raise ValueError( + "reasoning_tags must be a list of pair tuples, e.g. [('start_tag', 'end_tag'), ...]. " + f"Got {self.reasoning_tags} instead, which caused parsing error {e}." + ) + + # Make sure format is correct + if not all(isinstance(tag, tuple) and len(tag) == 2 for tag in self.reasoning_tags): + raise ValueError( + "reasoning_tags must be a list of pair tuples, e.g. [('start_tag', 'end_tag'), ...]. " + f"Got {self.reasoning_tags} instead." 
+ ) class Pipeline: @@ -284,9 +307,10 @@ def evaluate(self): else: outputs = self._run_model() - self._compute_metrics(outputs) - if self.is_main_process(): + self._post_process_outputs(outputs) + self._compute_metrics(outputs) + self.evaluation_tracker.general_config_logger.log_end_time() self.evaluation_tracker.metrics_logger.aggregate( task_dict=self.tasks_dict, bootstrap_iters=self.pipeline_parameters.bootstrap_iters @@ -341,6 +365,21 @@ def _run_model(self): return outputs + def _post_process_outputs(self, sampling_method_responses: dict[str, list[ModelResponse]]): + # Removes reasoning tags if needed + logger.info("--- POST-PROCESSING MODEL RESPONSES ---") + + if self.pipeline_parameters.remove_reasoning_tags: + for _, responses in sampling_method_responses.items(): + for response in responses: + response.text_post_processed = [ + remove_reasoning_tags( + text=text, + tag_pairs=self.pipeline_parameters.reasoning_tags, + ) + for text in response.text + ] + def _compute_metrics(self, sampling_method_responses: dict[str, list[ModelResponse]]): # To compute the metrics we first group the samples and task and then by metrics. # This way we can batch the metrics computation for each task and metric category diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 815f08289..3d988acf5 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -383,6 +383,22 @@ ], version=2, ) +aime24_avg = LightevalTaskConfig( + name="aime24_avg", + suite=["lighteval"], + prompt_function=prompt.aime_prompt_fn, + hf_repo="HuggingFaceH4/aime_2024", + hf_subset="default", + hf_avail_splits=["train"], + evaluation_splits=["train"], + few_shots_split=None, + few_shots_select=None, + generation_size=None, + metrics=[ + Metrics.math_avg_at_64, + ], + version=2, +) aime24_gpassk = LightevalTaskConfig( name="aime24_gpassk", suite=["lighteval"], diff --git a/src/lighteval/tasks/extended/ifeval/main.py b/src/lighteval/tasks/extended/ifeval/main.py index 79b283dbf..78bbe6e3f 100644 --- a/src/lighteval/tasks/extended/ifeval/main.py +++ b/src/lighteval/tasks/extended/ifeval/main.py @@ -32,7 +32,6 @@ from lighteval.models.model_output import ModelResponse from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc, SamplingMethod -from lighteval.utils.utils import remove_reasoning_tags # Very specific task where there are no precise outputs but instead we test if the format obeys rules @@ -60,9 +59,7 @@ def ifeval_prompt(line, task_name: str = ""): def ifeval_metric(doc: Doc, model_response: ModelResponse, **kwargs) -> dict: - response = model_response.text[0] - # Remove the reasoning block to avoid false negatives: https://github.com/huggingface/lighteval/issues/790 - response = remove_reasoning_tags(response, REASONING_TAG_PAIRS) + response = model_response.final_text[0] # Strict instructions instruction_list = doc.specific["instructions_id_list"] diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index f5e63c351..42ba3408e 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -26,7 +26,7 @@ from dataclasses import asdict, dataclass, field from typing import Callable -from datasets import DatasetDict +from datasets import DatasetDict, load_dataset from huggingface_hub import TextGenerationInputGrammarType from multiprocess import Pool from pytablewriter import MarkdownTableWriter @@ -36,7 +36,7 @@ from 
lighteval.tasks.requests import ( Doc, ) -from lighteval.utils.utils import ListLike, as_list, download_dataset_worker +from lighteval.utils.utils import ListLike, as_list logger = logging.getLogger(__name__) @@ -241,13 +241,7 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]: list[Doc]: List of documents. """ if self.dataset is None: - self.dataset = download_dataset_worker( - self.dataset_path, - self.dataset_config_name, - self.trust_dataset, - self.dataset_filter, - self.dataset_revision, - ) + self.dataset = self.download_dataset_worker(self) assert self.dataset is not None, f"Dataset {self.dataset_path} not found." @@ -356,35 +350,43 @@ def load_datasets(tasks: dict[str, "LightevalTask"], dataset_loading_processes: """ if dataset_loading_processes <= 1: - datasets = [ - download_dataset_worker( - task.dataset_path, - task.dataset_config_name, - task.trust_dataset, - task.dataset_filter, - task.dataset_revision, - ) - for task in tasks.values() - ] + # Useful for the test suite: we can mock loading tasks by overwriting the + # individual download_dataset_worker functions + datasets = [task.download_dataset_worker(task) for task in tasks.values()] else: with Pool(processes=dataset_loading_processes) as pool: datasets = pool.starmap( - download_dataset_worker, - [ - ( - task.dataset_path, - task.dataset_config_name, - task.trust_dataset, - task.dataset_filter, - task.dataset_revision, - ) - for task in tasks.values() - ], + LightevalTask.download_dataset_worker, + [tasks.values()], ) for task, dataset in zip(tasks, datasets): tasks[task].dataset = dataset + @staticmethod + def download_dataset_worker( + task: "LightevalTask", + ) -> DatasetDict: + """ + Worker function to download a dataset from the HuggingFace Hub. + Used for parallel dataset loading. + """ + dataset = load_dataset( + path=task.dataset_path, + name=task.dataset_config_name, + data_dir=None, + cache_dir=None, + download_mode=None, + trust_remote_code=task.trust_dataset, + revision=task.dataset_revision, + ) + + if task.dataset_filter is not None: + dataset = dataset.filter(task.dataset_filter) + + # It returns DatasetDict because we don't specify a split + return dataset # type: ignore + def extract_num_samples(metric_name: str) -> int: """Gets the number of samples to generate from the metric name. diff --git a/src/lighteval/utils/utils.py b/src/lighteval/utils/utils.py index 28e0ac4a4..115987ba7 100644 --- a/src/lighteval/utils/utils.py +++ b/src/lighteval/utils/utils.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict, is_dataclass -from typing import Callable, TypeVar, Union +from typing import TypeVar, Union import numpy as np -from datasets import DatasetDict, load_dataset from pytablewriter import MarkdownTableWriter @@ -199,34 +198,6 @@ def boolstring_to_bool(x: Union[str, bool, int]) -> Union[bool, None]: raise ValueError(f"You tried to convert {x} to a boolean but it's not possible.") -def download_dataset_worker( - dataset_path: str, - dataset_config_name: str, - trust_dataset: bool, - dataset_filter: Callable[[dict], bool] | None = None, - revision: str | None = None, -) -> DatasetDict: - """ - Worker function to download a dataset from the HuggingFace Hub. - Used for parallel dataset loading. 
- """ - dataset = load_dataset( - path=dataset_path, - name=dataset_config_name, - data_dir=None, - cache_dir=None, - download_mode=None, - trust_remote_code=trust_dataset, - revision=revision, - ) - - if dataset_filter is not None: - dataset = dataset.filter(dataset_filter) - - # It returns DatasetDict because we don't specify a split - return dataset # type: ignore - - def safe_divide(numerator: np.ndarray, denominator: float, default_value: float = 0.0) -> np.ndarray: return np.where(denominator != 0, numerator / denominator, default_value) diff --git a/tests/pipeline/test_reasoning_tags.py b/tests/pipeline/test_reasoning_tags.py new file mode 100644 index 000000000..cadddd0d9 --- /dev/null +++ b/tests/pipeline/test_reasoning_tags.py @@ -0,0 +1,416 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import tempfile +import unittest +from pathlib import Path +from types import ModuleType +from typing import Optional, Union +from unittest.mock import patch + +from lighteval.logging.evaluation_tracker import EvaluationTracker +from lighteval.metrics.metrics import Metrics +from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig +from lighteval.models.model_output import ModelResponse +from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters +from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig +from lighteval.tasks.registry import Registry +from lighteval.tasks.requests import Doc, SamplingMethod +from lighteval.utils.imports import is_accelerate_available + + +class TestPipelineReasoningTags(unittest.TestCase): + """Test suite for pipeline reasoning tags functionality using DummyModel.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + + # Create a simple test task + self.task_config = LightevalTaskConfig( + name="test_reasoning_task", + suite=["test"], + prompt_function=lambda x: x, + hf_repo="test_repo", + hf_subset="default", + metrics=[Metrics.exact_match], + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=10, + stop_sequence=["\n"], + trust_dataset=True, + num_fewshots=0, + ) + + # Create test documents with reasoning tags in expected responses + self.test_docs = [ + Doc( + task_name="test|test_reasoning_task|0", + query="What is 2+2?", + choices=["4"], + gold_index=[0], + instruction="", + sampling_methods=[SamplingMethod.GENERATIVE], + ), + ] + + # Mock dataset + self.mock_dataset = {"test": self.test_docs} + + def _mock_task_registry(self, task_config, task_docs, responses_with_reasoning_tags): + """Create a fake registry for testing.""" + + class FakeTask(LightevalTask): + def __post_init__(self): + self._docs = task_docs + + def get_docs(self, max_samples=None): + return task_docs + + @staticmethod + def download_dataset_worker(task) -> None: + # Mock dataset loading + return task._docs + + class FakeRegistry(Registry): + def __init__(self, custom_tasks: Optional[Union[str, Path, ModuleType]] = None): + super().__init__(custom_tasks=custom_tasks) + + def get_tasks_configs(self, task: str): + return [task_config] + + def get_tasks_from_configs(self, tasks_configs): + return {f"{task_config.suite[0]}|{task_config.full_name}": FakeTask(task_config)} + + # Create a DummyModel that returns responses with reasoning tags + class TestDummyModel(DummyModel): + def __init__(self, config): + super().__init__(config) + + def greedy_until(self, docs): + # Return responses with reasoning tags + return responses_with_reasoning_tags + + return FakeRegistry, TestDummyModel + + def test_remove_reasoning_tags_enabled(self): + """Test that reasoning tags are removed when remove_reasoning_tags=True.""" + + # Responses with reasoning tags + responses_with_reasoning = [ + ModelResponse(text=["Let me think about this... 
2+2=4The answer is 4"]) + ] + + FakeRegistry, TestDummyModel = self._mock_task_registry( + self.task_config, self.test_docs, responses_with_reasoning + ) + + # Initialize accelerator if available + if is_accelerate_available(): + from accelerate import Accelerator + + Accelerator() + + with patch("lighteval.pipeline.Registry", FakeRegistry): + # Create pipeline with reasoning tag removal enabled + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + remove_reasoning_tags=True, + reasoning_tags=[("", "")], + max_samples=1, + ) + + evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir) + model = TestDummyModel(DummyModelConfig(seed=42)) + + pipeline = Pipeline( + tasks="test|test_reasoning_task|0|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=model, + ) + + # Run the pipeline + pipeline.evaluate() + + # Check that reasoning tags were removed from post-processed text + details = pipeline.evaluation_tracker.details + self.assertEqual( + details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["The answer is 4"] + ) + + def test_remove_reasoning_tags_enabled_tags_as_string(self): + """Test that reasoning tags are removed when remove_reasoning_tags=True.""" + + # Responses with reasoning tags + responses_with_reasoning = [ + ModelResponse(text=["Let me think about this... 2+2=4The answer is 4"]) + ] + + FakeRegistry, TestDummyModel = self._mock_task_registry( + self.task_config, self.test_docs, responses_with_reasoning + ) + + # Initialize accelerator if available + if is_accelerate_available(): + from accelerate import Accelerator + + Accelerator() + + with patch("lighteval.pipeline.Registry", FakeRegistry): + # Create pipeline with reasoning tag removal enabled + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + remove_reasoning_tags=True, + reasoning_tags='[("", "")]', + max_samples=1, + ) + + evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir) + model = TestDummyModel(DummyModelConfig(seed=42)) + + pipeline = Pipeline( + tasks="test|test_reasoning_task|0|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=model, + ) + + # Run the pipeline + pipeline.evaluate() + + # Check that reasoning tags were removed from post-processed text + details = pipeline.evaluation_tracker.details + self.assertEqual( + details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["The answer is 4"] + ) + + def test_remove_reasoning_tags_enabled_default_tags(self): + """Test that reasoning tags are removed when remove_reasoning_tags=True.""" + + # Responses with reasoning tags + responses_with_reasoning = [ + ModelResponse(text=["Let me think about this... 
2+2=4The answer is 4"]) + ] + + FakeRegistry, TestDummyModel = self._mock_task_registry( + self.task_config, self.test_docs, responses_with_reasoning + ) + + # Initialize accelerator if available + if is_accelerate_available(): + from accelerate import Accelerator + + Accelerator() + + with patch("lighteval.pipeline.Registry", FakeRegistry): + # Create pipeline with reasoning tag removal enabled + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, remove_reasoning_tags=True, max_samples=1 + ) + + evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir) + model = TestDummyModel(DummyModelConfig(seed=42)) + + pipeline = Pipeline( + tasks="test|test_reasoning_task|0|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=model, + ) + + # Run the pipeline + pipeline.evaluate() + + # Check that reasoning tags were removed from post-processed text + details = pipeline.evaluation_tracker.details + self.assertEqual( + details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["The answer is 4"] + ) + + def test_remove_reasoning_tags_disabled(self): + """Test that reasoning tags are preserved when remove_reasoning_tags=False.""" + + # Responses with reasoning tags + responses_with_reasoning = [ + ModelResponse(text=["Let me think about this... 2+2=4The answer is 4"]) + ] + + FakeRegistry, TestDummyModel = self._mock_task_registry( + self.task_config, self.test_docs, responses_with_reasoning + ) + + # Initialize accelerator if available + if is_accelerate_available(): + from accelerate import Accelerator + + Accelerator() + + with patch("lighteval.pipeline.Registry", FakeRegistry): + # Create pipeline with reasoning tag removal disabled + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + remove_reasoning_tags=False, + reasoning_tags=[("", "")], + max_samples=1, + ) + + evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir) + model = TestDummyModel(DummyModelConfig(seed=42)) + + pipeline = Pipeline( + tasks="test|test_reasoning_task|0|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=model, + ) + + # Run the pipeline + pipeline.evaluate() + + # Check that post-processed text is None (= no post processing happened) + details = pipeline.evaluation_tracker.details + self.assertIsNone( + details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], + ) + + def test_custom_reasoning_tags(self): + """Test that custom reasoning tags are correctly applied.""" + + # Responses with custom reasoning tags + responses_with_reasoning = [ + ModelResponse(text=["[reasoning]This is my thought process[/reasoning]Final answer: 4"]) + ] + + FakeRegistry, TestDummyModel = self._mock_task_registry( + self.task_config, self.test_docs, responses_with_reasoning + ) + + # Initialize accelerator if available + if is_accelerate_available(): + from accelerate import Accelerator + + Accelerator() + + with patch("lighteval.pipeline.Registry", FakeRegistry): + # Create pipeline with custom reasoning tags + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + remove_reasoning_tags=True, + reasoning_tags=[("[reasoning]", "[/reasoning]")], + max_samples=1, + ) + + evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir) + model = TestDummyModel(DummyModelConfig(seed=42)) + + pipeline = Pipeline( + tasks="test|test_reasoning_task|0|0", + pipeline_parameters=pipeline_params, + 
evaluation_tracker=evaluation_tracker, + model=model, + ) + + # Run the pipeline + pipeline.evaluate() + + # Check that reasoning tags were removed from post-processed text + details = pipeline.evaluation_tracker.details + self.assertEqual( + details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], ["Final answer: 4"] + ) + + def test_multiple_reasoning_tags(self): + """Test that multiple reasoning tag pairs are correctly handled.""" + + # Responses with multiple reasoning tag types + responses_with_reasoning = [ + ModelResponse(text=["First thoughtSome textSecond thoughtFinal: 4"]) + ] + + FakeRegistry, TestDummyModel = self._mock_task_registry( + self.task_config, self.test_docs, responses_with_reasoning + ) + + # Initialize accelerator if available + if is_accelerate_available(): + from accelerate import Accelerator + + Accelerator() + + with patch("lighteval.pipeline.Registry", FakeRegistry): + # Create pipeline with multiple reasoning tag pairs + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + remove_reasoning_tags=True, + reasoning_tags='[("", ""), ("", "")]', + max_samples=1, + ) + + evaluation_tracker = EvaluationTracker(output_dir=self.temp_dir) + model = TestDummyModel(DummyModelConfig(seed=42)) + + pipeline = Pipeline( + tasks="test|test|test_reasoning_task|0|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=model, + ) + + # Run the pipeline + pipeline.evaluate() + + # Check that reasoning tags were removed from post-processed text + details = pipeline.evaluation_tracker.details + self.assertEqual( + details["test|test_reasoning_task|0"][0]["model_response"]["text_post_processed"], + ["Some textFinal: 4"], + ) + + def test_reasoning_tags_validation(self): + """Test that invalid reasoning_tags parameter raises appropriate error.""" + + for test_string in ["['incorrect_format']", "invalid_format"]: + with self.assertRaises(ValueError) as context: + PipelineParameters( + launcher_type=ParallelismManager.NONE, + reasoning_tags=test_string, # Should be a list of tuples + ) + + # Check that the error message mentions the expected format + print(context.__dict__) + self.assertIn("reasoning_tags must be a list of pair tuples", str(context.exception)) + + def test_default_reasoning_tags(self): + """Test that default reasoning tags are correctly set.""" + + pipeline_params = PipelineParameters(launcher_type=ParallelismManager.NONE) + + # Check that default reasoning tags are set + self.assertEqual(pipeline_params.reasoning_tags, [("", "")]) + self.assertTrue(pipeline_params.remove_reasoning_tags) + + +if __name__ == "__main__": + unittest.main()
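
For reference, the post-processing exercised by the patch and tests above removes each configured `(start, end)` reasoning-tag pair, together with everything enclosed between them, before metrics are computed. Below is a minimal standalone sketch of that behavior; the helper name `strip_reasoning_tags` is made up for illustration and is not the library's actual `remove_reasoning_tags` implementation.

```python
import re


def strip_reasoning_tags(text: str, tag_pairs: list[tuple[str, str]]) -> str:
    """Drop every (start, end) tag pair and the reasoning trace enclosed between them.

    Illustrative sketch only: mirrors the expectations encoded in the tests above,
    not the exact implementation used by lighteval.
    """
    for start, end in tag_pairs:
        # re.DOTALL so reasoning traces spanning several lines are removed as well;
        # non-greedy match keeps separate tag pairs independent of each other.
        pattern = re.escape(start) + r".*?" + re.escape(end)
        text = re.sub(pattern, "", text, flags=re.DOTALL)
    return text


# Matches the Magistral example from the docs change above:
print(strip_reasoning_tags("[THINK]2+2=4[/THINK]The answer is 4", [("[THINK]", "[/THINK]")]))
# -> "The answer is 4"
```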