From d9de61d96beeb83b4f2ba0eed1fe553104bbeea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine?= Date: Fri, 25 Jul 2025 10:10:24 +0200 Subject: [PATCH 01/21] init --- src/lighteval/tasks/default_prompts.py | 24 +++++++ src/lighteval/tasks/default_tasks.py | 96 ++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index 7f1544e98..261bcd668 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -1862,6 +1862,30 @@ def mmlu_helm(line, task_name: str = None): ) +def mmlu_redux_2(line, topic, task_name: str = None): + """ + MMLU-Redux-2 prompt function. + The dataset uses integer indices for answers and has additional metadata fields. + """ + query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + query += line["question"] + "\n" + query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + query += "Answer: " + + # Handle answer format - MMLU-Redux-2 uses integer indices directly + gold_ix = line["answer"] if isinstance(line["answer"], int) else int(line["answer"]) + is_few_shots = line.get("__few_shots", False) + + return Doc( + task_name=task_name, + query=query, + choices=LETTER_INDICES[:len(line["choices"])], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + target_for_fewshot_sorting=LETTER_INDICES[gold_ix] if not is_few_shots else None, + ) + + def mmlu_qa_abstract_algebra(line, task_name: str = None): return mmlu_qa(line, "abstract_algebra", task_name) diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 815f08289..ef15bd140 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -21905,3 +21905,99 @@ trust_dataset=True, version=0, ) + +# MMLU-Redux-2 Tasks +_MMLU_REDUX_2_SUBSETS = [ + "abstract_algebra", "anatomy", "astronomy", "business_ethics", "clinical_knowledge", + "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", + "college_medicine", "college_physics", "computer_security", "conceptual_physics", + "econometrics", "electrical_engineering", "elementary_mathematics", "formal_logic", + "global_facts", "high_school_biology", "high_school_chemistry", "high_school_computer_science", + "high_school_european_history", "high_school_geography", "high_school_government_and_politics", + "high_school_macroeconomics", "high_school_mathematics", "high_school_microeconomics", + "high_school_physics", "high_school_psychology", "high_school_statistics", + "high_school_us_history", "high_school_world_history", "human_aging", "human_sexuality", + "international_law", "jurisprudence", "logical_fallacies", "machine_learning", + "management", "marketing", "medical_genetics", "miscellaneous", "moral_disputes", + "moral_scenarios", "nutrition", "philosophy", "prehistory", "professional_accounting", + "professional_law", "professional_medicine", "professional_psychology", "public_relations", + "security_studies", "sociology", "us_foreign_policy", "virology", "world_religions" +] + +_mmlu_redux_2_tasks = { + subset: LightevalTaskConfig( + name=f"mmlu_redux_2:{subset}", + suite=["lighteval"], + prompt_function=lambda line, task_name=None, s=subset: prompt.mmlu_redux_2(line, s, task_name), + hf_repo="edinburgh-dawg/mmlu-redux-2.0", + hf_subset=subset, + 
hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=1, + metrics=[Metrics.loglikelihood_acc], + stop_sequence=["\n"], + trust_dataset=True, + version=0, + ) + for subset in _MMLU_REDUX_2_SUBSETS +} + +mmlu_redux_2_abstract_algebra = _mmlu_redux_2_tasks["abstract_algebra"] +mmlu_redux_2_anatomy = _mmlu_redux_2_tasks["anatomy"] +mmlu_redux_2_astronomy = _mmlu_redux_2_tasks["astronomy"] +mmlu_redux_2_business_ethics = _mmlu_redux_2_tasks["business_ethics"] +mmlu_redux_2_clinical_knowledge = _mmlu_redux_2_tasks["clinical_knowledge"] +mmlu_redux_2_college_biology = _mmlu_redux_2_tasks["college_biology"] +mmlu_redux_2_college_chemistry = _mmlu_redux_2_tasks["college_chemistry"] +mmlu_redux_2_college_computer_science = _mmlu_redux_2_tasks["college_computer_science"] +mmlu_redux_2_college_mathematics = _mmlu_redux_2_tasks["college_mathematics"] +mmlu_redux_2_college_medicine = _mmlu_redux_2_tasks["college_medicine"] +mmlu_redux_2_college_physics = _mmlu_redux_2_tasks["college_physics"] +mmlu_redux_2_computer_security = _mmlu_redux_2_tasks["computer_security"] +mmlu_redux_2_conceptual_physics = _mmlu_redux_2_tasks["conceptual_physics"] +mmlu_redux_2_econometrics = _mmlu_redux_2_tasks["econometrics"] +mmlu_redux_2_electrical_engineering = _mmlu_redux_2_tasks["electrical_engineering"] +mmlu_redux_2_elementary_mathematics = _mmlu_redux_2_tasks["elementary_mathematics"] +mmlu_redux_2_formal_logic = _mmlu_redux_2_tasks["formal_logic"] +mmlu_redux_2_global_facts = _mmlu_redux_2_tasks["global_facts"] +mmlu_redux_2_high_school_biology = _mmlu_redux_2_tasks["high_school_biology"] +mmlu_redux_2_high_school_chemistry = _mmlu_redux_2_tasks["high_school_chemistry"] +mmlu_redux_2_high_school_computer_science = _mmlu_redux_2_tasks["high_school_computer_science"] +mmlu_redux_2_high_school_european_history = _mmlu_redux_2_tasks["high_school_european_history"] +mmlu_redux_2_high_school_geography = _mmlu_redux_2_tasks["high_school_geography"] +mmlu_redux_2_high_school_government_and_politics = _mmlu_redux_2_tasks["high_school_government_and_politics"] +mmlu_redux_2_high_school_macroeconomics = _mmlu_redux_2_tasks["high_school_macroeconomics"] +mmlu_redux_2_high_school_mathematics = _mmlu_redux_2_tasks["high_school_mathematics"] +mmlu_redux_2_high_school_microeconomics = _mmlu_redux_2_tasks["high_school_microeconomics"] +mmlu_redux_2_high_school_physics = _mmlu_redux_2_tasks["high_school_physics"] +mmlu_redux_2_high_school_psychology = _mmlu_redux_2_tasks["high_school_psychology"] +mmlu_redux_2_high_school_statistics = _mmlu_redux_2_tasks["high_school_statistics"] +mmlu_redux_2_high_school_us_history = _mmlu_redux_2_tasks["high_school_us_history"] +mmlu_redux_2_high_school_world_history = _mmlu_redux_2_tasks["high_school_world_history"] +mmlu_redux_2_human_aging = _mmlu_redux_2_tasks["human_aging"] +mmlu_redux_2_human_sexuality = _mmlu_redux_2_tasks["human_sexuality"] +mmlu_redux_2_international_law = _mmlu_redux_2_tasks["international_law"] +mmlu_redux_2_jurisprudence = _mmlu_redux_2_tasks["jurisprudence"] +mmlu_redux_2_logical_fallacies = _mmlu_redux_2_tasks["logical_fallacies"] +mmlu_redux_2_machine_learning = _mmlu_redux_2_tasks["machine_learning"] +mmlu_redux_2_management = _mmlu_redux_2_tasks["management"] +mmlu_redux_2_marketing = _mmlu_redux_2_tasks["marketing"] +mmlu_redux_2_medical_genetics = _mmlu_redux_2_tasks["medical_genetics"] +mmlu_redux_2_miscellaneous = _mmlu_redux_2_tasks["miscellaneous"] +mmlu_redux_2_moral_disputes = 
_mmlu_redux_2_tasks["moral_disputes"] +mmlu_redux_2_moral_scenarios = _mmlu_redux_2_tasks["moral_scenarios"] +mmlu_redux_2_nutrition = _mmlu_redux_2_tasks["nutrition"] +mmlu_redux_2_philosophy = _mmlu_redux_2_tasks["philosophy"] +mmlu_redux_2_prehistory = _mmlu_redux_2_tasks["prehistory"] +mmlu_redux_2_professional_accounting = _mmlu_redux_2_tasks["professional_accounting"] +mmlu_redux_2_professional_law = _mmlu_redux_2_tasks["professional_law"] +mmlu_redux_2_professional_medicine = _mmlu_redux_2_tasks["professional_medicine"] +mmlu_redux_2_professional_psychology = _mmlu_redux_2_tasks["professional_psychology"] +mmlu_redux_2_public_relations = _mmlu_redux_2_tasks["public_relations"] +mmlu_redux_2_security_studies = _mmlu_redux_2_tasks["security_studies"] +mmlu_redux_2_sociology = _mmlu_redux_2_tasks["sociology"] +mmlu_redux_2_us_foreign_policy = _mmlu_redux_2_tasks["us_foreign_policy"] +mmlu_redux_2_virology = _mmlu_redux_2_tasks["virology"] +mmlu_redux_2_world_religions = _mmlu_redux_2_tasks["world_religions"] From a2b1bad263f05727a4f0432edbd336ddea1ae0a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Fri, 25 Jul 2025 11:29:05 +0200 Subject: [PATCH 02/21] Update src/lighteval/tasks/default_prompts.py --- src/lighteval/tasks/default_prompts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index 261bcd668..e0bbbd6f1 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -1864,8 +1864,7 @@ def mmlu_helm(line, task_name: str = None): def mmlu_redux_2(line, topic, task_name: str = None): """ - MMLU-Redux-2 prompt function. - The dataset uses integer indices for answers and has additional metadata fields. 
+ Ref: https://arxiv.org/abs/2406.04127 """ query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" query += line["question"] + "\n" From 8429f5fbf7332388f41bf68b995535ee9a9f11de Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Fri, 1 Aug 2025 09:55:15 +0000 Subject: [PATCH 03/21] small fixes --- src/lighteval/models/vllm/vllm_model.py | 3 +- src/lighteval/tasks/default_prompts.py | 4 +- src/lighteval/tasks/default_tasks.py | 71 ++++++++++++++++++++----- 3 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index f35eff8d9..ff9fc9001 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -461,7 +461,8 @@ def _loglikelihood_tokens( tokenized_contexts_batch.append(tokenized_context) # Left truncate the inputs to the maximum length - inputs = [input[-self.max_length :] for input in inputs] + if self.max_length: # can be None if the model is initialized with ray + inputs = [input[-self.max_length :] for input in inputs] outputs = self._generate(inputs, generate=False) flat_index = 0 diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index e0bbbd6f1..385a5a407 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -1873,15 +1873,13 @@ def mmlu_redux_2(line, topic, task_name: str = None): # Handle answer format - MMLU-Redux-2 uses integer indices directly gold_ix = line["answer"] if isinstance(line["answer"], int) else int(line["answer"]) - is_few_shots = line.get("__few_shots", False) return Doc( task_name=task_name, query=query, - choices=LETTER_INDICES[:len(line["choices"])], + choices=LETTER_INDICES[: len(line["choices"])], gold_index=gold_ix, instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - target_for_fewshot_sorting=LETTER_INDICES[gold_ix] if not is_few_shots else None, ) diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index ef15bd140..42d868538 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -21908,20 +21908,63 @@ # MMLU-Redux-2 Tasks _MMLU_REDUX_2_SUBSETS = [ - "abstract_algebra", "anatomy", "astronomy", "business_ethics", "clinical_knowledge", - "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", - "college_medicine", "college_physics", "computer_security", "conceptual_physics", - "econometrics", "electrical_engineering", "elementary_mathematics", "formal_logic", - "global_facts", "high_school_biology", "high_school_chemistry", "high_school_computer_science", - "high_school_european_history", "high_school_geography", "high_school_government_and_politics", - "high_school_macroeconomics", "high_school_mathematics", "high_school_microeconomics", - "high_school_physics", "high_school_psychology", "high_school_statistics", - "high_school_us_history", "high_school_world_history", "human_aging", "human_sexuality", - "international_law", "jurisprudence", "logical_fallacies", "machine_learning", - "management", "marketing", "medical_genetics", "miscellaneous", "moral_disputes", - "moral_scenarios", "nutrition", "philosophy", "prehistory", "professional_accounting", - "professional_law", "professional_medicine", "professional_psychology", "public_relations", - "security_studies", "sociology", "us_foreign_policy", 
"virology", "world_religions" + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + "computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", ] _mmlu_redux_2_tasks = { From 1e139ab9f580b1b4ba829f73af76c8b44136def7 Mon Sep 17 00:00:00 2001 From: Nathan Habib <30601243+NathanHB@users.noreply.github.com> Date: Mon, 25 Aug 2025 14:29:40 +0200 Subject: [PATCH 04/21] Apply suggestion from @NathanHB --- src/lighteval/tasks/default_tasks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 1b28963ec..43bb30e21 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -22803,7 +22803,6 @@ generation_size=1, metrics=[Metrics.loglikelihood_acc], stop_sequence=["\n"], - trust_dataset=True, version=0, ) for subset in _MMLU_REDUX_2_SUBSETS From b5975462c52e517eeb011f7fd7035d01f2183cde Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Thu, 4 Sep 2025 12:42:32 +0000 Subject: [PATCH 05/21] fix metrics kwargs passing --- src/lighteval/metrics/metrics_sample.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index ce2005c1b..56d7dd696 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -1161,7 +1161,7 @@ def __init__(self, k: int | None = None, **kwargs): sample_scoring_function (callable | str, optional): Function to use to compute the score for each sample. If None, uses the default scoring function which is a simple exact match. """ - super().__init__(kwargs) + super().__init__(**kwargs) self.k = k self.attribute_must_be_set = ["k"] @@ -1191,7 +1191,7 @@ def num_samples(self): class MajAtK(SamplingMetric, SampleLevelComputation): def __init__(self, k: int = None, **kwargs): """An exact match class.""" - super().__init__(kwargs) + super().__init__(**kwargs) self.k = k self.attribute_must_be_set = ["k"] @@ -1241,7 +1241,7 @@ def __init__(self, k: int | None = None, n: int | None = None, **kwargs): k (int): Threshold for the number of successful attempts. 
n (int): Number of samples to generate """ - super().__init__(kwargs) + super().__init__(**kwargs) self.k = k self.n = n self.attribute_must_be_set = ["k"] @@ -1269,7 +1269,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: elif len(predictions) < self.n: logger.warning(f"Number of predictions is less than {self.n} for pass@k.") - processed_choices = [self.preprocess(gold=g) for g in doc.choices] + processed_choices = [self.preprocess(g) for g in doc.choices] new_doc = Doc( choices=processed_choices, query=doc.query, @@ -1278,7 +1278,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: all_scores = [] for pred in predictions[: self.n]: - cur_pred = self.preprocess(pred=pred) + cur_pred = self.preprocess(pred) new_model_response = ModelResponse( text=[cur_pred], ) @@ -1314,7 +1314,7 @@ def __init__( n (int): Number of samples to generate. thresholds (list): Thresholds to control successful attempts in k generate. """ - super().__init__(kwargs) + super().__init__(**kwargs) self._k = k self.n = n self.attribute_must_be_set = ["k"] From b0e55846358908767fd1f6bc6cce2f05eb9109c1 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Thu, 4 Sep 2025 12:42:56 +0000 Subject: [PATCH 06/21] add default metric for mmlu_redux --- src/lighteval/tasks/default_tasks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 43bb30e21..7092264ad 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -22789,6 +22789,7 @@ "world_religions", ] + _mmlu_redux_2_tasks = { subset: LightevalTaskConfig( name=f"mmlu_redux_2:{subset}", @@ -22801,7 +22802,10 @@ few_shots_split=None, few_shots_select=None, generation_size=1, - metrics=[Metrics.loglikelihood_acc], + metrics=[ + Metrics.loglikelihood_acc, + Metrics.pass_at_k_letters(sample_params={"k": 1}), + ], stop_sequence=["\n"], version=0, ) From 951cbc0cc794f13097d9e9712b67799cbf84a414 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Thu, 4 Sep 2025 13:05:20 +0000 Subject: [PATCH 07/21] fix --- src/lighteval/metrics/metrics_sample.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 56d7dd696..79748bf50 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -1110,18 +1110,20 @@ def __init__( self.strip_strings = strip_strings if callable(sample_scoring_function): - self.score_sample = sample_scoring_function + self.compute_score = sample_scoring_function self.type_exact_match = None - else: - if isinstance(sample_scoring_function, str): - if sample_scoring_function not in ["prefix", "suffix", "full"]: - raise ValueError( - f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead." - ) - self.type_exact_match = sample_scoring_function - else: - self.type_exact_match = "full" + elif isinstance(sample_scoring_function, str) or sample_scoring_function is None: + if sample_scoring_function is None: + sample_scoring_function = "full" + if sample_scoring_function not in ["prefix", "suffix", "full"]: + raise ValueError( + f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead." 
+ ) + self.type_exact_match = sample_scoring_function self.compute_score = self.default_sample_scoring + else: # class + self.type_exact_match = None + self.compute_score = sample_scoring_function.compute def preprocess(self, text: str) -> str: if not text: From 96df0e7f6bd062af31d1eff6b00a18d0d2cfa601 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Mon, 8 Sep 2025 14:59:34 +0000 Subject: [PATCH 08/21] update caching" --- docs/source/evaluating-a-custom-model.mdx | 10 +- src/lighteval/models/dummy/dummy_model.py | 8 +- .../models/endpoints/endpoint_model.py | 8 +- .../endpoints/inference_providers_model.py | 8 +- .../models/endpoints/litellm_model.py | 8 +- .../models/nanotron/nanotron_model.py | 7 +- src/lighteval/models/sglang/sglang_model.py | 8 +- .../models/transformers/transformers_model.py | 8 +- .../transformers/vlm_transformers_model.py | 8 +- src/lighteval/models/vllm/vllm_model.py | 13 +- src/lighteval/pipeline.py | 8 +- src/lighteval/utils/cache_management.py | 112 ++++++++++++------ 12 files changed, 127 insertions(+), 79 deletions(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index 1b055dedd..10e8d855a 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx @@ -14,7 +14,7 @@ Here's a basic example: from typing import List from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_output import ModelResponse -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached class MyCustomModel(LightevalModel): @@ -25,17 +25,17 @@ class MyCustomModel(LightevalModel): # Enable caching (recommended) self._cache = SampleCache(config) - @cached("predictions") # Enable caching for better performance + @cached("predictions", SamplingMethod.GENERATIVE) # Enable caching for better performance def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]: # Implement generation logic pass - @cached("predictions") # Enable caching for better performance + @cached("loglikelihood", SamplingMethod.LOGPROBS) # Enable caching for better performance def loglikelihood(self, docs: List[Doc]) -> List[ModelResponse]: # Implement loglikelihood computation pass - @cached("predictions") # Enable caching for better performance + @cached("loglikelihood", SamplingMethod.LOGPROBS) # Enable caching for better performance def loglikelihood_rolling(self, docs: List[Doc]) -> List[ModelResponse]: # Implement rolling loglikelihood computation pass @@ -130,7 +130,7 @@ To enable caching in your custom model: 3. **Add cache decorators** to your prediction methods: ```python - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]: # Your implementation... 
``` diff --git a/src/lighteval/models/dummy/dummy_model.py b/src/lighteval/models/dummy/dummy_model.py index 6971b6daa..8856cdcbd 100644 --- a/src/lighteval/models/dummy/dummy_model.py +++ b/src/lighteval/models/dummy/dummy_model.py @@ -28,7 +28,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached @@ -88,11 +88,11 @@ def add_special_tokens(self): def max_length(self) -> int: return 2048 - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]: return [ModelResponse(text=["random baseline"]) for _ in range(len(docs))] - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: model_responses = [] for doc in docs: @@ -105,7 +105,7 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return model_responses - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: model_responses = [] for doc in docs: diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index f2dc2c03b..e783424d5 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -48,7 +48,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached @@ -545,7 +545,7 @@ def _process_batch_logprob(self, docs: list[Doc], rolling: bool = False) -> list for context, doc in zip(contexts, docs) ] - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, docs: List[Doc], @@ -589,11 +589,11 @@ def _greedy_until(self, docs: List[Doc]) -> list[ModelResponse]: return dataset.get_original_order(results) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=False) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc], override_bs=None) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=True) diff --git a/src/lighteval/models/endpoints/inference_providers_model.py b/src/lighteval/models/endpoints/inference_providers_model.py index 3c0d025fe..5c8a90029 100644 --- a/src/lighteval/models/endpoints/inference_providers_model.py +++ b/src/lighteval/models/endpoints/inference_providers_model.py @@ -35,7 +35,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached @@ -191,7 +191,7 @@ async def bounded_api_call(prompt, num_samples): return results - @cached("predictions") + @cached("predictions", 
SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -250,14 +250,14 @@ def max_length(self) -> int: logger.warning("Tokenizer was not correctly loaded. Max model context length is assumed to be 30K tokens") return 30000 - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 20fb87d04..f20bbc269 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -30,7 +30,7 @@ from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_litellm_available @@ -254,7 +254,7 @@ def __call_api_parallel( return results - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -321,14 +321,14 @@ def max_length(self) -> int: """Return the maximum sequence length of the model.""" return 4096 - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 7c80aebf9..1ac9c26e2 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -48,6 +48,7 @@ from lighteval.models.transformers.transformers_model import LightevalModel from lighteval.tasks.requests import ( Doc, + SamplingMethod, ) from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_nanotron_available @@ -473,7 +474,7 @@ def _check_continuations_start_space(self, continuation: str) -> str: continuation = continuation.lstrip() return continuation - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. 
@@ -496,7 +497,7 @@ def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), ) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, requests: List[Doc]) -> List[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" for request in tqdm( @@ -931,7 +932,7 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) @torch.inference_mode() - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, requests: List[Doc], diff --git a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index 227200de8..d96bdf90d 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -33,7 +33,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.models.utils import _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_sglang_available @@ -216,7 +216,7 @@ def _create_auto_tokenizer(self, config: SGLangModelConfig): tokenizer.pad_token = tokenizer.eos_token return tokenizer - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -345,7 +345,7 @@ def _generate( ) return outputs - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood_tokens(docs) @@ -414,6 +414,6 @@ def _loglikelihood_tokens( res.append(answer) return dataset.get_original_order(res) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index db2b68bd1..5ce0dcba0 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -52,7 +52,7 @@ ) from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import ( is_accelerate_available, @@ -740,7 +740,7 @@ def _padded_greedy_until( return dataset.get_original_order(results) - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -867,7 +867,7 @@ def _generate( else: return self._generate_padded(**kwargs) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood( self, docs: list[Doc], @@ -883,7 +883,7 @@ def loglikelihood( """ return self._loglikelihood_tokens(docs) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling( self, docs: list[Doc], diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py index 
f640d15d9..9fb48767b 100644 --- a/src/lighteval/models/transformers/vlm_transformers_model.py +++ b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -44,7 +44,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.models.utils import _get_dtype, _get_model_sha, _simplify_name from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import ( is_accelerate_available, @@ -333,7 +333,7 @@ def _init_max_length(self) -> int: return 2048 - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -424,14 +424,14 @@ def _greedy_until( return dataset.get_original_order(results) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood( self, docs: list[Doc], ) -> list[ModelResponse]: raise NotImplementedError() - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling( self, docs: list[Doc], diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 9ab1e5c76..a35167630 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -36,7 +36,7 @@ from lighteval.models.model_output import ModelResponse from lighteval.models.utils import _simplify_name, uses_chat_template from lighteval.tasks.prompt_manager import PromptManager -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import is_vllm_available @@ -259,6 +259,7 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]: "seed": int(config.seed), "max_num_seqs": int(config.max_num_seqs), "max_num_batched_tokens": int(config.max_num_batched_tokens), + "enforce_eager": True, } if config.quantization is not None: @@ -300,7 +301,7 @@ def _create_auto_tokenizer(self, config: VLLMModelConfig): tokenizer.pad_token = tokenizer.eos_token return tokenizer - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -459,7 +460,7 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r return outputs - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood_tokens(docs) @@ -528,7 +529,7 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() @@ -624,7 +625,7 @@ async def _async_batch(self, docs: list[Doc], generative: bool) -> list: results = await asyncio.gather(*processed_requests) return results - @cached("predictions") + @cached("predictions", SamplingMethod.GENERATIVE) async def greedy_until( self, docs: list[Doc], @@ -659,7 +660,7 @@ async def greedy_until( return results - @cached("predictions") + @cached("predictions", SamplingMethod.LOGPROBS) async def loglikelihood( self, docs: list[Doc], diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 91c1b590d..43aef341c 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -178,6 +178,8 @@ 
def __init__( self.model_config = model_config self.accelerator, self.parallel_context = self._init_parallelism_manager() self.model = self._init_model(model_config, model) + # Must occur after model and task init + self.model._cache._init_registry(self.registry) # Must occur after model init self._init_accelerator_seeds() @@ -243,13 +245,13 @@ def _init_tasks_and_requests(self, tasks: str): logger.info("--- LOADING TASKS ---") # The registry contains all the potential tasks - registry = Registry( + self.registry = Registry( custom_tasks=self.pipeline_parameters.custom_tasks_directory, ) # load the tasks fro the configs and their datasets - task_configs: list[LightevalTaskConfig] = registry.get_tasks_configs(tasks) - self.tasks_dict: dict[str, LightevalTask] = registry.get_tasks_from_configs(task_configs) + task_configs: list[LightevalTaskConfig] = self.registry.get_tasks_configs(tasks) + self.tasks_dict: dict[str, LightevalTask] = self.registry.get_tasks_from_configs(task_configs) LightevalTask.load_datasets(self.tasks_dict, self.pipeline_parameters.dataset_loading_processes) self.documents_dict = { task.full_name: task.get_docs(self.pipeline_parameters.max_samples) for _, task in self.tasks_dict.items() diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index c1c144cdf..bb6ce6f15 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -35,7 +35,9 @@ from lighteval.models.abstract_model import ModelConfig from lighteval.models.model_output import ModelResponse -from lighteval.tasks.requests import Doc +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.registry import Registry +from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.utils import as_list @@ -58,7 +60,8 @@ class SampleCache: - {sample_type}/ - {model_name}/ - {model_hash}/ - - {task_name}.parquet + - {task_name}/ + - {task_hash}.parquet """ def __init__(self, model_config: ModelConfig): @@ -73,6 +76,8 @@ def __init__(self, model_config: ModelConfig): self.model_config = model_config self.model_hash = self.get_model_hash(model_config) + self.registry = None + # Create cache directory structure and load cached indices if present self.all_cache_dirs = {} self.existing_indices = {} @@ -81,8 +86,12 @@ def __init__(self, model_config: ModelConfig): self.cache_dir / sample_type.name.lower() / self.model_config.model_name / self.model_hash ) self.all_cache_dirs[sample_type].mkdir(parents=True, exist_ok=True) + # sample type, (task_name, task_hash), sampling_method self.existing_indices[sample_type] = self._get_cached_indices(sample_type) + def _init_registry(self, registry: Registry): + self.registry = registry + def _get_cached_indices(self, sample_type: SampleType) -> dict: """Loads all indices for samples which are properly cached @@ -95,20 +104,23 @@ def _get_cached_indices(self, sample_type: SampleType) -> dict: if not cache_dir.exists(): return cached_indices - for cache_file in cache_dir.glob("*.parquet"): - task_name = cache_file.stem + for cache_file in cache_dir.rglob("*.parquet"): try: + task_name = cache_file.parent.split("/")[-1] + task_hash = cache_file.stem dataset = load_dataset("parquet", data_files=str(cache_file), split="train") - sample_ids = [] + sample_ids = {SamplingMethod.GENERATIVE: [], SamplingMethod.LOGPROBS: []} for row in dataset: try: # We only save indices of correctly formatted samples, though this means we need to load each at least once 
self._load_sample(row, sample_type=sample_type) - sample_ids.append(row["sample_id"]) + cur_sample = row["sample_id"] + sampling_method = self.get_sampling_method(cur_sample) + sample_ids[sampling_method].append(cur_sample) except Exception: continue - cached_indices[task_name] = sample_ids + cached_indices[(task_name, task_hash)] = sample_ids logger.debug(f"Loaded {len(sample_ids)} cached indices for task '{task_name}' from {cache_file}") except Exception as e: logger.warning(f"Error loading cached indices for task '{task_name}' from {cache_file}: {e}") @@ -122,7 +134,18 @@ def get_model_hash(self, model_config: ModelConfig) -> str: config_str = json.dumps(config_dict, sort_keys=True, default=str) return hashlib.sha256(config_str.encode()).hexdigest()[:16] - def get_cache_path(self, task_name: str, sample_type: SampleType) -> Path: + def get_task_hash(self, task_name: str) -> str: + if self.registry is None: + logger.warning( + "The task registry was not provided to the cache config. We can't test if the current task has the same hash as the saved tasks." + ) + return "NO_HASH" + task_config: LightevalTaskConfig = self.registry.get_tasks_configs(task_name) + config_dict = task_config.model_dump() + config_str = json.dumps(config_dict, sort_keys=True, default=str) + return hashlib.sha256(config_str.encode()).hexdigest()[:16] + + def get_cache_path(self, task_name: str, task_hash: str, sample_type: SampleType) -> Path: """Get the file path for a specific task's cache file. Args: @@ -132,7 +155,14 @@ def get_cache_path(self, task_name: str, sample_type: SampleType) -> Path: Returns: Path: Path to the cache file for the given task and sample type """ - return self.all_cache_dirs[sample_type] / f"{task_name}.parquet" + return self.all_cache_dirs[sample_type] / task_name / f"{task_hash}.parquet" + + def get_sampling_method(self, sample: dict) -> str: + if "logprobs" in sample: + return SamplingMethod.LOGPROBS + if "text" in sample: + return SamplingMethod.GENERATIVE + return None def _load_sample( self, sample: pd.core.series.Series | dict, sample_type: SampleType @@ -169,7 +199,9 @@ def _dump_sample(self, result: Union[dict, ModelResponse], sample_type: SampleTy elif sample_type == SampleType.PREDICTIONS: return asdict(result) - def get_notcached_samples(self, docs: List[Doc], sample_type: SampleType) -> Tuple[List[Doc], Set]: + def get_notcached_samples( + self, docs: List[Doc], sample_type: SampleType, sampling_method: SamplingMethod + ) -> Tuple[List[Doc], Set]: """ Identify which docs need processing based on cached indices. @@ -185,15 +217,17 @@ def get_notcached_samples(self, docs: List[Doc], sample_type: SampleType) -> Tup for doc in docs: task_name = doc.task_name - if task_name in cached_indices and doc.id in cached_indices[task_name]: - tasks_with_cached_samples.add(task_name) + task_hash = self.get_task_hash(task_name) + task_id = (task_name, task_hash) + if task_id in cached_indices and doc.id in cached_indices[task_id][sampling_method]: + tasks_with_cached_samples.add((task_name, task_hash)) else: docs_not_cached.append(doc) return docs_not_cached, set(tasks_with_cached_samples) def get_samples_from_cache( - self, docs: List[Doc], task_names: list | set, sample_type: SampleType + self, docs: List[Doc], task_ids: list | set, sample_type: SampleType ) -> List[dict | ModelResponse]: """ Get cached samples for the given docs. 
@@ -205,12 +239,12 @@ def get_samples_from_cache( # Load datasets for tasks that have cached docs task_datasets = {} - for task_name in task_names: - cache_file = self.get_cache_path(task_name=task_name, sample_type=sample_type) + for task_name, task_hash in task_ids: + cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash, sample_type=sample_type) try: dataset = load_dataset("parquet", data_files=str(cache_file), split="train") dataset_df = dataset.to_pandas().set_index("sample_id") - task_datasets[task_name] = dataset_df + task_datasets[(task_name, task_hash)] = dataset_df except Exception as e: logger.warning(f"Error loading {sample_type.name.lower()} cache for {task_name}: {e}") @@ -218,7 +252,9 @@ def get_samples_from_cache( results = [] for doc in docs: - row = task_datasets[doc.task_name].loc[doc.id] + task_name = doc.task_name + task_hash = self.get_task_hash(task_name) + row = task_datasets[(task_name, task_hash)].loc[doc.id] results.append(self._load_sample(row, sample_type)) return results @@ -227,7 +263,7 @@ def store_samples( self, docs: List[Doc], results: List[dict] | List[ModelResponse], - task_names: list[str], + task_ids: list[tuple[str, str]], sample_type: SampleType, ): """Store new results for samples in docs""" @@ -235,16 +271,17 @@ def store_samples( return # Prepare newly processed data for dataset - processed_data = {task_name: [] for task_name in task_names} + processed_data = {task_id: [] for task_id in task_ids} for doc, result in zip(docs, results): - processed_data[doc.task_name].append( - {"sample_id": doc.id, "sample": self._dump_sample(result, sample_type)} - ) - processed_data = {task_name: task_data for task_name, task_data in processed_data.items() if task_data} + task_name = doc.task_name + task_hash = self.get_task_hash(task_name) + task_id = (task_name, task_hash) + processed_data[task_id].append({"sample_id": doc.id, "sample": self._dump_sample(result, sample_type)}) + processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data} # Concatenate it with existing data and save to file - for task_name, task_data in processed_data.items(): - cache_file = self.get_cache_path(task_name=task_name, sample_type=sample_type) + for (task_name, task_hash), task_data in processed_data.items(): + cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash, sample_type=sample_type) # Load existing data if present existing_data = [] @@ -253,7 +290,9 @@ def store_samples( existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") existing_data = existing_dataset.to_list() except Exception as e: - logger.error(f"Error loading existing {sample_type.name.lower()} cache for {task_name}: {e}") + logger.error( + f"Error loading existing {sample_type.name.lower()} cache for {task_name} ({task_hash}): {e}" + ) # Merge with new data (new data overwrites existing) existing_ids = {row["sample_id"] for row in existing_data} @@ -273,10 +312,12 @@ def store_samples( ) # Refresh cached indices after storing new samples - self.existing_indices[sample_type][task_name] = [sample["sample_id"] for sample in all_samples] + self.existing_indices[sample_type][(task_name, task_hash)] = [ + sample["sample_id"] for sample in all_samples + ] -def cached(cache_type_name: str): # noqa C901 +def cached(cache_type_name: str, sampling_method: SamplingMethod = None): # noqa C901 """ Decorator to cache method results based on Doc inputs. 
@@ -288,7 +329,7 @@ def cached(cache_type_name: str): # noqa C901 def tok_encode_pair(self, docs: List[Doc], ...): # method implementation - @cached("predictions") + @cached("predictions", "greedy") def greedy_until(self, docs: List[Doc], ...): # method implementation """ @@ -306,10 +347,10 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 cache: SampleCache = self._cache # Extract task names - task_names = {doc.task_name for doc in docs} + task_ids = {(doc.task_name, cache.get_task_hash(doc.task_name)) for doc in docs} # 1) Identify which samples must be processed because they are not cached - docs_not_cached, tasks_with_cached_samples = cache.get_notcached_samples(docs, cache_type) + docs_not_cached, tasks_with_cached_samples = cache.get_notcached_samples(docs, cache_type, sampling_method) # Log cache statistics cached_count = len(docs) - len(docs_not_cached) @@ -321,7 +362,7 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 # 2) Process not cached docs and save to file new_results = [] if docs_not_cached: - notcached_task_names = {doc.task_name for doc in docs_not_cached} + notcached_task_names = {(doc.task_name, cache.get_task_hash(doc.task_name)) for doc in docs_not_cached} logger.info( f"Cache: Processing {len(docs_not_cached)}/{len(docs)} {cache_type.name.lower()} samples for tasks {', '.join(notcached_task_names)}" ) @@ -329,11 +370,14 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 # Store new results in file cache cache.store_samples( - docs=docs_not_cached, results=new_results, task_names=task_names, sample_type=cache_type + docs=docs_not_cached, results=new_results, task_ids=task_ids, sample_type=cache_type ) # 3) Create final results by pulling from newly saved file cache - final_results = cache.get_samples_from_cache(docs, task_names, cache_type) + final_cached_results = cache.get_samples_from_cache(docs, task_ids, cache_type) + + # 4) We only keep samples with the correct sampling method + final_results = [s for s in final_cached_results if cache.get_sampling_method(s) == sampling_method] if any(r is None for r in final_results): raise ValueError("Problem while loading and aggregating items from cache.") From 8b28aba136018b5b686ef2cde88f61eb54dc187c Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Mon, 8 Sep 2025 17:59:28 +0000 Subject: [PATCH 09/21] better str for classes, which allows correct hashing --- src/lighteval/metrics/metrics_corpus.py | 11 +++++ src/lighteval/metrics/metrics_sample.py | 11 +++++ src/lighteval/metrics/sample_preparator.py | 11 +++++ src/lighteval/metrics/utils/metric_utils.py | 4 ++ src/lighteval/tasks/lighteval_task.py | 24 ++++++--- src/lighteval/utils/cache_management.py | 54 ++++++++++++++------- 6 files changed, 91 insertions(+), 24 deletions(-) diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 28a4b90c2..b87a83a9f 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -50,6 +50,17 @@ class CorpusLevelComputation(ABC): def compute_corpus(self): raise NotImplementedError + def __str__(self): + attrs = vars(self) + attr_strs = [] + for k, v in attrs.items(): + if callable(v): + val_str = v.__name__ + else: + val_str = str(v) + attr_strs.append(f"{k}={val_str}") + return f"{self.__class__.__name__}({', '.join(attr_strs)})" + # General aggregations class MatthewsCorrCoef(CorpusLevelComputation): diff --git 
a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index f26785ba9..38a0e2f52 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -66,6 +66,17 @@ class SampleLevelComputation(ABC): def compute(self, model_response: ModelResponse, doc: Doc, **kwargs): raise NotImplementedError + def __str__(self): + attrs = vars(self) + attr_strs = [] + for k, v in attrs.items(): + if callable(v): + val_str = v.__name__ + else: + val_str = str(v) + attr_strs.append(f"{k}={val_str}") + return f"{self.__class__.__name__}({', '.join(attr_strs)})" + class ExactMatches(SampleLevelComputation): def __init__( diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index 725ba6fc6..8d79cb448 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -81,6 +81,17 @@ def prepare(doc: Doc, model_response: ModelResponse, **kwargs): predictions = model_response.final_text return GenerativeCorpusMetricInput(golds=golds, preds=predictions) + def __str__(self): + attrs = vars(self) + attr_strs = [] + for k, v in attrs.items(): + if callable(v): + val_str = v.__name__ + else: + val_str = str(v) + attr_strs.append(f"{k}={val_str}") + return f"{self.__class__.__name__}({', '.join(attr_strs)})" + class LoglikelihoodPreparator(Preparator): def __init__(self, is_single_token: bool = False): diff --git a/src/lighteval/metrics/utils/metric_utils.py b/src/lighteval/metrics/utils/metric_utils.py index c23a7854c..e57e56724 100644 --- a/src/lighteval/metrics/utils/metric_utils.py +++ b/src/lighteval/metrics/utils/metric_utils.py @@ -95,6 +95,10 @@ def __call__(self, sample_params: dict | None): self.metric_name = f"{self.metric_name}_with_{sample_params_name}" return self + @staticmethod + def get_allowed_types_for_metrics(): + return (SampleLevelComputation, Preparator, CorpusLevelComputation, Callable) + @dataclass class MetricGrouping(Metric): diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 73956c835..c54afb5fe 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
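The deterministic `__str__` implementations added above are what let configs containing these objects hash stably: the default object repr embeds a memory address, which typically differs from run to run, so the same logical config would produce a different digest each time. A small sketch of the difference, with illustrative classes rather than the real metric objects:

```python
# Why a stable __str__ matters for config hashing: the default repr includes a
# memory address, so identical configs hash differently across runs.
# Illustrative classes only.

import hashlib

class DefaultRepr:
    def __init__(self, k):
        self.k = k

class StableStr:
    def __init__(self, k):
        self.k = k

    def __str__(self):
        attrs = ", ".join(f"{name}={value}" for name, value in vars(self).items())
        return f"{self.__class__.__name__}({attrs})"

def digest(obj) -> str:
    return hashlib.sha256(str(obj).encode()).hexdigest()[:16]

print(digest(DefaultRepr(1)))  # based on "<__main__.DefaultRepr object at 0x...>", varies per run
print(digest(StableStr(1)))    # based on "StableStr(k=1)", stable across runs
```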
-import inspect import logging import random from dataclasses import asdict, dataclass, field @@ -154,20 +153,28 @@ def __post_init__(self): self.stop_sequence = self.stop_sequence if self.stop_sequence is not None else () self.full_name = f"{self.name}|{self.num_fewshots}" # todo clefourrier: this is likely incorrect - def print(self): + def __str__(self, lite: bool = False): md_writer = MarkdownTableWriter() md_writer.headers = ["Key", "Value"] + # These keys change through time + to_ignore = ["original_num_docs", "effective_num_docs"] + values = [] for k, v in asdict(self).items(): - if k == "metric": + if lite and k in to_ignore: + continue + if k == "metrics": for ix, metrics in enumerate(v): for metric_k, metric_v in metrics.items(): - if inspect.ismethod(metric_v): - values.append([f"{k} {ix}: {metric_k}", metric_v.__qualname__]) + if isinstance(metric_v, Callable): + repr_v = metric_v.__name__ + elif isinstance(metric_v, Metric.get_allowed_types_for_metrics()): + repr_v = str(metric_v) else: - values.append([f"{k} {ix}: {metric_k}", repr(metric_v)]) + repr_v = repr(metric_v) + values.append([f"{k} {ix}: {metric_k}", repr_v]) else: if isinstance(v, Callable): @@ -177,7 +184,10 @@ def print(self): md_writer.value_matrix = values - print(md_writer.dumps()) + return md_writer.dumps() + + def print(self, lite: bool = False): + print(str(self, lite)) class LightevalTask: diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 0cb70d72b..149a8a032 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -103,9 +103,9 @@ def _get_cached_indices(self, sample_type: SampleType) -> dict: return cached_indices for cache_file in cache_dir.rglob("*.parquet"): + task_name = str(cache_file.parent).split("/")[-1] + task_hash = cache_file.stem try: - task_name = cache_file.parent.split("/")[-1] - task_hash = cache_file.stem dataset = load_dataset("parquet", data_files=str(cache_file), split="train") sample_ids = {SamplingMethod.GENERATIVE: [], SamplingMethod.LOGPROBS: []} for row in dataset: @@ -113,7 +113,7 @@ def _get_cached_indices(self, sample_type: SampleType) -> dict: # We only save indices of correctly formatted samples, though this means we need to load each at least once self._load_sample(row, sample_type=sample_type) cur_sample = row["sample_id"] - sampling_method = self.get_sampling_method(cur_sample) + sampling_method = self.get_sampling_method(row["sample"]) sample_ids[sampling_method].append(cur_sample) except Exception: continue @@ -136,15 +136,15 @@ def get_model_hash(self, model_config: ModelConfig) -> str: config_str = json.dumps(config_dict, sort_keys=True, default=str) return hashlib.sha256(config_str.encode()).hexdigest()[:16] - def get_task_hash(self, task_name: str) -> str: + def get_task_hash(self, full_task_name: str) -> str: if self.registry is None: logger.warning( "The task registry was not provided to the cache config. We can't test if the current task has the same hash as the saved tasks." 
) return "NO_HASH" - task_config: LightevalTaskConfig = self.registry.get_tasks_configs(task_name) - config_dict = task_config.model_dump() - config_str = json.dumps(config_dict, sort_keys=True, default=str) + task_suite, task_name, _ = full_task_name.split("|") + task_configs: list[LightevalTaskConfig] = sorted(self.registry.task_to_configs[f"{task_suite}|{task_name}"]) + config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs]) return hashlib.sha256(config_str.encode()).hexdigest()[:16] def get_cache_path(self, task_name: str, task_hash: str, sample_type: SampleType) -> Path: @@ -222,9 +222,12 @@ def get_notcached_samples( task_name = doc.task_name task_hash = self.get_task_hash(task_name) task_id = (task_name, task_hash) - if task_id in cached_indices and doc.id in cached_indices[task_id][sampling_method]: - tasks_with_cached_samples.add((task_name, task_hash)) - else: + try: + if doc.id in cached_indices[task_id][sampling_method]: + tasks_with_cached_samples.add((task_name, task_hash)) + else: + docs_not_cached.append(doc) + except KeyError: # task id or sampling method not yet there docs_not_cached.append(doc) return docs_not_cached, set(tasks_with_cached_samples) @@ -267,6 +270,7 @@ def store_samples( results: List[dict] | List[ModelResponse], task_ids: list[tuple[str, str]], sample_type: SampleType, + sampling_method: SamplingMethod, ): """Store new results for samples in docs""" if not results: @@ -283,6 +287,9 @@ def store_samples( # Concatenate it with existing data and save to file for (task_name, task_hash), task_data in processed_data.items(): + if (task_name, task_hash) not in self.existing_indices.keys(): + self.existing_indices[sample_type][(task_name, task_hash)] = {} + cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash, sample_type=sample_type) # Load existing data if present @@ -297,13 +304,19 @@ def store_samples( ) # Merge with new data (new data overwrites existing) - existing_ids = {row["sample_id"] for row in existing_data} - - if any(row["sample_id"] in existing_ids for row in task_data): + # We look at id + sampling method + existing_samples = {(row["sample_id"], self.get_sampling_method(row["sample"])) for row in existing_data} + if any( + (row["sample_id"], self.get_sampling_method(row["sample"])) in existing_samples for row in task_data + ): logger.warning( "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." 
) - all_samples = existing_data + [row for row in task_data if row["sample_id"] not in existing_ids] + all_samples = existing_data + [ + row + for row in task_data + if (row["sample_id"], self.get_sampling_method(row["sample"])) not in existing_samples + ] # Save updated dataset dataset = Dataset.from_list(all_samples) @@ -314,7 +327,7 @@ def store_samples( ) # Refresh cached indices after storing new samples - self.existing_indices[sample_type][(task_name, task_hash)] = [ + self.existing_indices[sample_type][(task_name, task_hash)][sampling_method] = [ sample["sample_id"] for sample in all_samples ] @@ -368,14 +381,21 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 new_results = [] if docs_not_cached: notcached_task_names = {(doc.task_name, cache.get_task_hash(doc.task_name)) for doc in docs_not_cached} + notcached_task_names_str = ", ".join( + f"{task_name} ({task_hash})" for task_name, task_hash in notcached_task_names + ) logger.info( - f"Cache: Processing {len(docs_not_cached)}/{len(docs)} {cache_type.name.lower()} samples for tasks {', '.join(notcached_task_names)}" + f"Cache: Processing {len(docs_not_cached)}/{len(docs)} {cache_type.name.lower()} samples for tasks {notcached_task_names_str}" ) new_results = func(self, docs_not_cached, *args, **kwargs) # Store new results in file cache cache.store_samples( - docs=docs_not_cached, results=new_results, task_ids=task_ids, sample_type=cache_type + docs=docs_not_cached, + results=new_results, + task_ids=task_ids, + sample_type=cache_type, + sampling_method=sampling_method, ) # 3) Create final results by pulling from newly saved file cache From a3eeebd588862acb829bb4923807e84c9acc2381 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Mon, 8 Sep 2025 18:35:51 +0000 Subject: [PATCH 10/21] last fix is to possibly push to configs --- src/lighteval/utils/cache_management.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 149a8a032..00d3f9328 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -161,9 +161,9 @@ def get_cache_path(self, task_name: str, task_hash: str, sample_type: SampleType return self.all_cache_dirs[sample_type] / task_name / f"{task_hash}.parquet" def get_sampling_method(self, sample: dict) -> str: - if "logprobs" in sample: + if len(sample.get("logprobs", [])) > 0: return SamplingMethod.LOGPROBS - if "text" in sample: + if len(sample.get("text", [])) > 0: return SamplingMethod.GENERATIVE return None @@ -312,11 +312,12 @@ def store_samples( logger.warning( "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." 
) - all_samples = existing_data + [ + new_data = [ row for row in task_data if (row["sample_id"], self.get_sampling_method(row["sample"])) not in existing_samples ] + all_samples = existing_data + new_data # Save updated dataset dataset = Dataset.from_list(all_samples) @@ -374,7 +375,7 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 cached_count = len(docs) - len(docs_not_cached) if cached_count > 0: logger.info( - f"Cache: {cached_count}/{len(docs)} {cache_type.name.lower()} samples are cached for tasks {', '.join(tasks_with_cached_samples)}" + f"Cache: {cached_count}/{len(docs)} {cache_type.name.lower()} samples are cached for tasks {', '.join(t[0] for t in tasks_with_cached_samples)}" ) # 2) Process not cached docs and save to file @@ -402,7 +403,11 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 final_cached_results = cache.get_samples_from_cache(docs, task_ids, cache_type) # 4) We only keep samples with the correct sampling method - final_results = [s for s in final_cached_results if cache.get_sampling_method(s) == sampling_method] + final_results = [ + s + for s in final_cached_results + if cache.get_sampling_method(cache._dump_sample(s, cache_type)) == sampling_method + ] if any(r is None for r in final_results): raise ValueError("Problem while loading and aggregating items from cache.") From c7d1eb05eefbaba9cf42cc234790a1b1d2c0684d Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Tue, 9 Sep 2025 13:58:56 +0000 Subject: [PATCH 11/21] removed token system + added an actual separation between tasks with different hashs + added a separation for different metrics in the same task + has --- docs/source/evaluating-a-custom-model.mdx | 2 +- src/lighteval/models/dummy/dummy_model.py | 6 +- .../models/endpoints/endpoint_model.py | 6 +- .../endpoints/inference_providers_model.py | 6 +- .../models/endpoints/litellm_model.py | 6 +- .../models/nanotron/nanotron_model.py | 6 +- src/lighteval/models/sglang/sglang_model.py | 6 +- .../models/transformers/transformers_model.py | 6 +- .../transformers/vlm_transformers_model.py | 6 +- src/lighteval/models/vllm/vllm_model.py | 10 +- src/lighteval/utils/cache_management.py | 207 +++++++++--------- 11 files changed, 134 insertions(+), 133 deletions(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index 5d022e25b..c8acbd8c9 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx @@ -26,7 +26,7 @@ class MyCustomModel(LightevalModel): # Enable caching (recommended) self._cache = SampleCache(config) - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]: # Implement generation logic pass diff --git a/src/lighteval/models/dummy/dummy_model.py b/src/lighteval/models/dummy/dummy_model.py index afb5aee7a..3ee5f1c03 100644 --- a/src/lighteval/models/dummy/dummy_model.py +++ b/src/lighteval/models/dummy/dummy_model.py @@ -87,11 +87,11 @@ def add_special_tokens(self): def max_length(self) -> int: return 2048 - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]: return [ModelResponse(text=["random baseline"]) for _ in range(len(docs))] - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: model_responses 
= [] for doc in docs: @@ -104,7 +104,7 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return model_responses - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: model_responses = [] for doc in docs: diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 0fd227b2a..bed1de706 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -555,7 +555,7 @@ def _process_batch_logprob(self, docs: list[Doc], rolling: bool = False) -> list for context, doc in zip(contexts, docs) ] - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: List[Doc], @@ -599,11 +599,11 @@ def _greedy_until(self, docs: List[Doc]) -> list[ModelResponse]: return dataset.get_original_order(results) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=False) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc], override_bs=None) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=True) diff --git a/src/lighteval/models/endpoints/inference_providers_model.py b/src/lighteval/models/endpoints/inference_providers_model.py index fc2d18046..c5b03fcf8 100644 --- a/src/lighteval/models/endpoints/inference_providers_model.py +++ b/src/lighteval/models/endpoints/inference_providers_model.py @@ -196,7 +196,7 @@ async def bounded_api_call(prompt, num_samples): return results - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -253,14 +253,14 @@ def max_length(self) -> int: logger.warning("Tokenizer was not correctly loaded. Max model context length is assumed to be 30K tokens") return 30000 - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. """ raise NotImplementedError - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 01278f2c6..74568bd98 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -258,7 +258,7 @@ def __call_api_parallel( return results - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -323,14 +323,14 @@ def max_length(self) -> int: """Return the maximum sequence length of the model.""" return 4096 - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. 
""" raise NotImplementedError - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 8383815de..22028f31e 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -484,7 +484,7 @@ def _check_continuations_start_space(self, continuation: str) -> str: continuation = continuation.lstrip() return continuation - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. @@ -507,7 +507,7 @@ def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), ) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, requests: List[Doc]) -> List[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" for request in tqdm( @@ -942,7 +942,7 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) @torch.inference_mode() - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, requests: List[Doc], diff --git a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index f7ee426a9..70806f632 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -221,7 +221,7 @@ def _create_auto_tokenizer(self, config: SGLangModelConfig): tokenizer.pad_token = tokenizer.eos_token return tokenizer - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -347,7 +347,7 @@ def _generate( ) return outputs - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood_tokens(docs) @@ -416,6 +416,6 @@ def _loglikelihood_tokens( res.append(answer) return dataset.get_original_order(res) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index 24a17fd11..31a86229b 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -738,7 +738,7 @@ def _padded_greedy_until( return dataset.get_original_order(results) - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -867,7 +867,7 @@ def _generate( else: return self._generate_padded(**kwargs) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood( self, docs: list[Doc], @@ -883,7 +883,7 @@ def loglikelihood( """ return self._loglikelihood_tokens(docs) - @cached("predictions", SamplingMethod.LOGPROBS) + 
@cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling( self, docs: list[Doc], diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py index 553180b57..3de3385ef 100644 --- a/src/lighteval/models/transformers/vlm_transformers_model.py +++ b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -338,7 +338,7 @@ def _init_max_length(self) -> int: return 2048 - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -428,14 +428,14 @@ def _greedy_until( return dataset.get_original_order(results) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood( self, docs: list[Doc], ) -> list[ModelResponse]: raise NotImplementedError() - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling( self, docs: list[Doc], diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 934b22801..f033cad21 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -297,7 +297,7 @@ def _create_auto_tokenizer(self, config: VLLMModelConfig): tokenizer.pad_token = tokenizer.eos_token return tokenizer - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until( self, docs: list[Doc], @@ -454,7 +454,7 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r return outputs - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood_tokens(docs) @@ -523,7 +523,7 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() @@ -618,7 +618,7 @@ async def _async_batch(self, docs: list[Doc], generative: bool) -> list: results = await asyncio.gather(*processed_requests) return results - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) async def greedy_until( self, docs: list[Doc], @@ -652,7 +652,7 @@ async def greedy_until( return results - @cached("predictions", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) async def loglikelihood( self, docs: list[Doc], diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index 00d3f9328..e4adc30a1 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -25,13 +25,13 @@ import json import logging import os +import shutil from dataclasses import asdict -from enum import Enum from pathlib import Path from typing import Callable, List, Set, Tuple, Union import pandas as pd -from datasets import Dataset, load_dataset +from datasets import Dataset, DatasetDict from lighteval.models.abstract_model import ModelConfig from lighteval.models.model_output import ModelResponse @@ -44,11 +44,6 @@ logger = logging.getLogger(__name__) -class SampleType(Enum): - PREDICTIONS = 1 - TOKENIZED_INPUTS = 2 # Not implemented yet - - class SampleCache: """Disk-based cache for sample evaluation results using HuggingFace datasets. 
The model hash is a hash of the model config, to make sure we rerun the eval if any parameter changes @@ -56,11 +51,10 @@ class SampleCache: Cache Structure: - {cache_dir}/ - - {sample_type}/ - {model_name}/ - - {model_hash}/ - - {task_name}/ - - {task_hash}.parquet + - {model_hash}/ + - {task_name}/ + - {task_hash}/ dataset dict, where splits are SamplingMethod """ def __init__(self, model_config: ModelConfig): @@ -79,25 +73,22 @@ def __init__(self, model_config: ModelConfig): # Create cache directory structure and load cached indices if present self.all_cache_dirs = {} self.existing_indices = {} - for sample_type in SampleType: - self.all_cache_dirs[sample_type] = ( - self.cache_dir / sample_type.name.lower() / self.model_config.model_name / self.model_hash - ) - self.all_cache_dirs[sample_type].mkdir(parents=True, exist_ok=True) - # sample type, (task_name, task_hash), sampling_method - self.existing_indices[sample_type] = self._get_cached_indices(sample_type) + self.all_cache_dirs = self.cache_dir / self.model_config.model_name / self.model_hash + self.all_cache_dirs.mkdir(parents=True, exist_ok=True) + # (task_name, task_hash, sampling_method) + self.existing_indices = self._get_cached_indices() def _init_registry(self, registry: Registry): self.registry = registry - def _get_cached_indices(self, sample_type: SampleType) -> dict: + def _get_cached_indices(self) -> dict: """Loads all indices for samples which are properly cached Returns: dict: Dictionary mapping task names to lists of cached sample indices """ cached_indices = {} - cache_dir = self.all_cache_dirs[sample_type] + cache_dir = self.all_cache_dirs if not cache_dir.exists(): return cached_indices @@ -106,20 +97,22 @@ def _get_cached_indices(self, sample_type: SampleType) -> dict: task_name = str(cache_file.parent).split("/")[-1] task_hash = cache_file.stem try: - dataset = load_dataset("parquet", data_files=str(cache_file), split="train") - sample_ids = {SamplingMethod.GENERATIVE: [], SamplingMethod.LOGPROBS: []} - for row in dataset: - try: - # We only save indices of correctly formatted samples, though this means we need to load each at least once - self._load_sample(row, sample_type=sample_type) - cur_sample = row["sample_id"] - sampling_method = self.get_sampling_method(row["sample"]) - sample_ids[sampling_method].append(cur_sample) - except Exception: - continue - - cached_indices[(task_name, task_hash)] = sample_ids - logger.debug(f"Loaded {len(sample_ids)} cached indices for task '{task_name}' from {cache_file}") + full_dataset = DatasetDict.load_from_disk(str(cache_file)) + for sampling_method in [SamplingMethod.GENERATIVE, SamplingMethod.LOGPROBS]: + sample_ids = [] + for row in full_dataset[str(sampling_method)]: + try: + # We only save indices of correctly formatted samples, though this means we need to load each at least once + self._load_sample(row) + cur_sample = row["sample_id"] + sample_ids.append(cur_sample) + except Exception: + continue + + cached_indices[(task_name, task_hash, sampling_method)] = sample_ids + logger.debug( + f"Loaded {len(sample_ids)} cached indices for task '{task_name}', {str(sampling_method)} from {cache_file}" + ) except Exception as e: logger.warning(f"Error loading cached indices for task '{task_name}' from {cache_file}: {e}") @@ -147,7 +140,7 @@ def get_task_hash(self, full_task_name: str) -> str: config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs]) return hashlib.sha256(config_str.encode()).hexdigest()[:16] - def get_cache_path(self, task_name: 
str, task_hash: str, sample_type: SampleType) -> Path: + def get_cache_path(self, task_name: str, task_hash: str) -> Path: """Get the file path for a specific task's cache file. Args: @@ -158,7 +151,7 @@ def get_cache_path(self, task_name: str, task_hash: str, sample_type: SampleType Returns: Path: Path to the cache file for the given task and sample type """ - return self.all_cache_dirs[sample_type] / task_name / f"{task_hash}.parquet" + return self.all_cache_dirs / task_name / task_hash def get_sampling_method(self, sample: dict) -> str: if len(sample.get("logprobs", [])) > 0: @@ -167,9 +160,7 @@ def get_sampling_method(self, sample: dict) -> str: return SamplingMethod.GENERATIVE return None - def _load_sample( - self, sample: pd.core.series.Series | dict, sample_type: SampleType - ) -> Union[dict, ModelResponse]: + def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]: """Load a sample from cached data based on sample type. Args: @@ -182,29 +173,20 @@ def _load_sample( # If we just use the pandas dict, lists are converted to np arrays which we don't want if isinstance(sample, pd.core.series.Series): sample = json.loads(sample.to_json()) - if sample_type == SampleType.TOKENIZED_INPUTS: - return sample["sample"] - elif sample_type == SampleType.PREDICTIONS: - return ModelResponse(**sample["sample"]) + return ModelResponse(**sample["sample"]) - def _dump_sample(self, result: Union[dict, ModelResponse], sample_type: SampleType) -> dict: + def _dump_sample(self, result: Union[dict, ModelResponse]) -> dict: """Dumps the sample in the correct format for file saving Args: result (Union[dict, ModelResponse]): Processed sample to save - sample_type (SampleType): Type of sample Returns: dict """ - if sample_type == SampleType.TOKENIZED_INPUTS: - return result - elif sample_type == SampleType.PREDICTIONS: - return asdict(result) - - def get_notcached_samples( - self, docs: List[Doc], sample_type: SampleType, sampling_method: SamplingMethod - ) -> Tuple[List[Doc], Set]: + return asdict(result) + + def get_notcached_samples(self, docs: List[Doc], sampling_method: SamplingMethod) -> Tuple[List[Doc], Set]: """ Identify which docs need processing based on cached indices. @@ -213,7 +195,7 @@ def get_notcached_samples( - docs_not_cached contains docs that need processing - tasks_with_cached_samples are the tasks that have some cached samples """ - cached_indices = self.existing_indices[sample_type] + cached_indices = self.existing_indices docs_not_cached = [] tasks_with_cached_samples = set() @@ -221,10 +203,10 @@ def get_notcached_samples( for doc in docs: task_name = doc.task_name task_hash = self.get_task_hash(task_name) - task_id = (task_name, task_hash) + task_id = (task_name, task_hash, sampling_method) try: if doc.id in cached_indices[task_id][sampling_method]: - tasks_with_cached_samples.add((task_name, task_hash)) + tasks_with_cached_samples.add(task_id) else: docs_not_cached.append(doc) except KeyError: # task id or sampling method not yet there @@ -233,7 +215,7 @@ def get_notcached_samples( return docs_not_cached, set(tasks_with_cached_samples) def get_samples_from_cache( - self, docs: List[Doc], task_ids: list | set, sample_type: SampleType + self, docs: List[Doc], task_ids: list | set, sampling_method: SamplingMethod ) -> List[dict | ModelResponse]: """Get cached samples for the given docs. Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise. 
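[Editorial aside, not part of the patch] For reviewers, a self-contained sketch of the sampling-method detection rule that the hunks above rely on: a cached sample counts as LOGPROBS only if its "logprobs" list is non-empty, and as GENERATIVE only if its "text" list is non-empty, so empty placeholder fields no longer shadow the real method. The standalone enum below is a stand-in for `lighteval.tasks.requests.SamplingMethod`; the helper mirrors `SampleCache.get_sampling_method` but is illustrative only.

```python
from enum import Enum, auto


class SamplingMethod(Enum):  # stand-in for lighteval.tasks.requests.SamplingMethod
    GENERATIVE = auto()
    LOGPROBS = auto()


def get_sampling_method(sample: dict):
    # Non-empty logprobs wins over text; both empty means the row is unusable.
    if len(sample.get("logprobs", [])) > 0:
        return SamplingMethod.LOGPROBS
    if len(sample.get("text", [])) > 0:
        return SamplingMethod.GENERATIVE
    return None


assert get_sampling_method({"text": ["hello"], "logprobs": []}) is SamplingMethod.GENERATIVE
assert get_sampling_method({"text": [], "logprobs": [-0.2]}) is SamplingMethod.LOGPROBS
assert get_sampling_method({"text": [], "logprobs": []}) is None
```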
@@ -244,14 +226,16 @@ def get_samples_from_cache( # Load datasets for tasks that have cached docs task_datasets = {} - for task_name, task_hash in task_ids: - cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash, sample_type=sample_type) + for task_name, task_hash, task_sampling_method in task_ids: + if task_sampling_method != sampling_method: + continue + cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash) try: - dataset = load_dataset("parquet", data_files=str(cache_file), split="train") + dataset = DatasetDict.load_from_disk(str(cache_file))[str(sampling_method)] dataset_df = dataset.to_pandas().set_index("sample_id") - task_datasets[(task_name, task_hash)] = dataset_df + task_datasets[(task_name, task_hash, sampling_method)] = dataset_df except Exception as e: - logger.warning(f"Error loading {sample_type.name.lower()} cache for {task_name}: {e}") + logger.warning(f"Error loading prediction cache for {task_name}: {e}") # Build results list results = [] @@ -259,17 +243,17 @@ def get_samples_from_cache( for doc in docs: task_name = doc.task_name task_hash = self.get_task_hash(task_name) - row = task_datasets[(task_name, task_hash)].loc[doc.id] - results.append(self._load_sample(row, sample_type)) + task_id = (task_name, task_hash, sampling_method) + row = task_datasets[task_id].loc[doc.id] + results.append(self._load_sample(row)) return results - def store_samples( + def store_samples( # noqa C901 self, docs: List[Doc], results: List[dict] | List[ModelResponse], task_ids: list[tuple[str, str]], - sample_type: SampleType, sampling_method: SamplingMethod, ): """Store new results for samples in docs""" @@ -281,37 +265,49 @@ def store_samples( for doc, result in zip(docs, results): task_name = doc.task_name task_hash = self.get_task_hash(task_name) - task_id = (task_name, task_hash) - processed_data[task_id].append({"sample_id": doc.id, "sample": self._dump_sample(result, sample_type)}) + task_id = (task_name, task_hash, sampling_method) + sample = self._dump_sample(result) + + if self.get_sampling_method(sample) != sampling_method: + logger.warning("The sample which was returned by the model is not of the expected type ") + + processed_data[task_id].append({"sample_id": doc.id, "sample": sample}) processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data} # Concatenate it with existing data and save to file - for (task_name, task_hash), task_data in processed_data.items(): - if (task_name, task_hash) not in self.existing_indices.keys(): - self.existing_indices[sample_type][(task_name, task_hash)] = {} + for (task_name, task_hash, sampling_method), task_data in processed_data.items(): + if (task_name, task_hash, sampling_method) not in self.existing_indices.keys(): + self.existing_indices[(task_name, task_hash, sampling_method)] = {} - cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash, sample_type=sample_type) + cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash) # Load existing data if present existing_data = [] + existing_samples = {} if cache_file.exists(): try: - existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") + existing_dataset = DatasetDict.load_from_disk(str(cache_file))[str(sampling_method)] existing_data = existing_dataset.to_list() + except KeyError: + logger.info(f"No data was cached for {task_name} ({task_hash}, {str(sampling_method)}") except Exception as e: logger.error( - f"Error loading existing 
{sample_type.name.lower()} cache for {task_name} ({task_hash}): {e}" + f"Error loading existing prediction cache for {task_name} ({task_hash}, {str(sampling_method)}): {e}" + ) + + existing_samples = { + (row["sample_id"], self.get_sampling_method(row["sample"])) for row in existing_data + } + if any( + (row["sample_id"], self.get_sampling_method(row["sample"])) in existing_samples + for row in task_data + ): + logger.warning( + "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." ) # Merge with new data (new data overwrites existing) # We look at id + sampling method - existing_samples = {(row["sample_id"], self.get_sampling_method(row["sample"])) for row in existing_data} - if any( - (row["sample_id"], self.get_sampling_method(row["sample"])) in existing_samples for row in task_data - ): - logger.warning( - "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." - ) new_data = [ row for row in task_data @@ -319,21 +315,34 @@ def store_samples( ] all_samples = existing_data + new_data - # Save updated dataset + # Check if file exists and has other configs we need to preserve + dataset_dict = {} + if cache_file.exists(): + try: + # We load in memory to overwrite the written file + dataset_dict = DatasetDict.load_from_disk(str(cache_file), keep_in_memory=True) + except Exception as e: + logger.debug(f"Could not load existing configs from {cache_file}: {e}") + + # Add our current config, we overwrite the existing dataset = Dataset.from_list(all_samples) - dataset.to_parquet(str(cache_file)) + dataset_dict[str(sampling_method)] = dataset - logger.info( - f"Cached {len(all_samples)} {sample_type.name.lower()} samples of {task_name} at {str(cache_file)}." - ) + # Save as DatasetDict to preserve all configs + full_dataset = DatasetDict(dataset_dict) + if cache_file.exists(): + shutil.rmtree(cache_file) + full_dataset.save_to_disk(str(cache_file)) + + logger.info(f"Cached {len(all_samples)} samples of {task_name} at {str(cache_file)}.") # Refresh cached indices after storing new samples - self.existing_indices[sample_type][(task_name, task_hash)][sampling_method] = [ + self.existing_indices[(task_name, task_hash, sampling_method)] = [ sample["sample_id"] for sample in all_samples ] -def cached(cache_type_name: str, sampling_method: SamplingMethod = None): # noqa C901 +def cached(sampling_method: SamplingMethod = None): # noqa C901 """ Decorator to cache method results based on Doc inputs. 
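[Editorial aside, not part of the patch] A minimal sketch of the merge rule `store_samples` applies above: rows are keyed by (sample_id, sampling method inferred from the stored payload), previously cached rows win, and reprocessed duplicates are dropped with a warning. The helper is illustrative, not the actual method; it assumes each row is a dict with `sample_id` and a `sample` payload, and that `get_sampling_method` is the detector shown earlier.

```python
import logging

logger = logging.getLogger(__name__)


def merge_rows(existing_data: list[dict], task_data: list[dict], get_sampling_method) -> list[dict]:
    # Key cached rows by (sample_id, sampling method derived from the stored sample).
    existing_keys = {(row["sample_id"], get_sampling_method(row["sample"])) for row in existing_data}
    if any((row["sample_id"], get_sampling_method(row["sample"])) in existing_keys for row in task_data):
        # Cached rows take precedence; the new duplicates are ignored, as in the hunk above.
        logger.warning("Reprocessed already cached items - keeping the cached version.")
    new_rows = [
        row
        for row in task_data
        if (row["sample_id"], get_sampling_method(row["sample"])) not in existing_keys
    ]
    return existing_data + new_rows
```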
@@ -341,11 +350,7 @@ def cached(cache_type_name: str, sampling_method: SamplingMethod = None): # noq cache_type_name: Type of cache ("tokenization" or "predictions") Usage: - @cached("tokenization") - def tok_encode_pair(self, docs: List[Doc], ...): - # method implementation - - @cached("predictions", "greedy") + @cached("greedy") def greedy_until(self, docs: List[Doc], ...): # method implementation @@ -356,7 +361,6 @@ def greedy_until(self, docs: List[Doc], ...): def decorator(func: Callable): # noqa C901 @functools.wraps(func) def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 - cache_type = SampleType[cache_type_name.upper()] docs = as_list(docs) # Check if caching is enabled for the model @@ -366,16 +370,16 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 cache: SampleCache = self._cache # Extract task names - task_ids = {(doc.task_name, cache.get_task_hash(doc.task_name)) for doc in docs} + task_ids = {(doc.task_name, cache.get_task_hash(doc.task_name), sampling_method) for doc in docs} # 1) Identify which samples must be processed because they are not cached - docs_not_cached, tasks_with_cached_samples = cache.get_notcached_samples(docs, cache_type, sampling_method) + docs_not_cached, tasks_with_cached_samples = cache.get_notcached_samples(docs, sampling_method) # Log cache statistics cached_count = len(docs) - len(docs_not_cached) if cached_count > 0: logger.info( - f"Cache: {cached_count}/{len(docs)} {cache_type.name.lower()} samples are cached for tasks {', '.join(t[0] for t in tasks_with_cached_samples)}" + f"Cache: {cached_count}/{len(docs)} samples are cached for tasks {', '.join(t[0] for t in tasks_with_cached_samples)}" ) # 2) Process not cached docs and save to file @@ -386,7 +390,7 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 f"{task_name} ({task_hash})" for task_name, task_hash in notcached_task_names ) logger.info( - f"Cache: Processing {len(docs_not_cached)}/{len(docs)} {cache_type.name.lower()} samples for tasks {notcached_task_names_str}" + f"Cache: Processing {len(docs_not_cached)}/{len(docs)} samples for tasks {notcached_task_names_str}" ) new_results = func(self, docs_not_cached, *args, **kwargs) @@ -395,18 +399,15 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 docs=docs_not_cached, results=new_results, task_ids=task_ids, - sample_type=cache_type, sampling_method=sampling_method, ) # 3) Create final results by pulling from newly saved file cache - final_cached_results = cache.get_samples_from_cache(docs, task_ids, cache_type) + final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method) # 4) We only keep samples with the correct sampling method final_results = [ - s - for s in final_cached_results - if cache.get_sampling_method(cache._dump_sample(s, cache_type)) == sampling_method + s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method ] if any(r is None for r in final_results): From 42ec1ce5a4e6734446dbe86a0634c1b8b5cbda78 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Tue, 9 Sep 2025 13:59:55 +0000 Subject: [PATCH 12/21] fix --- docs/source/evaluating-a-custom-model.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index c8acbd8c9..540e30b35 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx 
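[Editorial aside, not part of the patch] Before the docs fix below, a hedged paraphrase of the control flow the rewritten `@cached` wrapper implements at this point in the series, using the hunk's own method names; `compute_fn` stands in for the wrapped model method and this helper is not actual lighteval code.

```python
def cached_call(cache, docs, sampling_method, compute_fn):
    # Key every doc by (task_name, task_hash, sampling_method), as in the wrapper.
    task_ids = {(doc.task_name, cache.get_task_hash(doc.task_name), sampling_method) for doc in docs}
    # 1) Split docs into already-cached vs. still-to-process.
    docs_not_cached, _ = cache.get_notcached_samples(docs, sampling_method)
    # 2) Run the wrapped model method only on the uncached docs and persist the results.
    if docs_not_cached:
        new_results = compute_fn(docs_not_cached)
        cache.store_samples(
            docs=docs_not_cached,
            results=new_results,
            task_ids=task_ids,
            sampling_method=sampling_method,
        )
    # 3) Serve every result back from the on-disk cache so both paths share one format.
    results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
    # 4) Keep only samples whose payload matches the requested sampling method.
    return [s for s in results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method]
```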
@@ -31,12 +31,12 @@ class MyCustomModel(LightevalModel): # Implement generation logic pass - @cached("loglikelihood", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: List[Doc]) -> List[ModelResponse]: # Implement loglikelihood computation pass - @cached("loglikelihood", SamplingMethod.LOGPROBS) + @cached(SamplingMethod.LOGPROBS) def loglikelihood_rolling(self, docs: List[Doc]) -> List[ModelResponse]: # Implement rolling loglikelihood computation pass From 68688b43819c3ec1008ff907894554d684813822 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Tue, 9 Sep 2025 14:10:28 +0000 Subject: [PATCH 13/21] update caching tests --- tests/utils/test_caching.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 1d8f6060d..0bce98eb2 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -31,7 +31,7 @@ from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_output import ModelResponse from lighteval.tasks.requests import Doc -from lighteval.utils.cache_management import SampleCache, SampleType +from lighteval.utils.cache_management import SampleCache class TestCaching(unittest.TestCase): @@ -152,16 +152,12 @@ def _test_cache(self, model: LightevalModel): cache: SampleCache = model._cache # Verify cache files were created - cache_file = cache.get_cache_path(task_name=self.task_name, sample_type=SampleType.PREDICTIONS) + cache_file = cache.get_cache_path(task_name=self.task_name) self.assertTrue(cache_file.exists(), "Cache file not created") # Test retrieving from cache - self.assertEqual( - cache._get_cached_indices(SampleType.PREDICTIONS), {self.task_name: [doc.id for doc in self.docs]} - ) - uncached_docs, tasks_with_cached_samples = cache.get_notcached_samples( - docs=self.docs, sample_type=SampleType.PREDICTIONS - ) + self.assertEqual(cache._get_cached_indices(), {self.task_name: [doc.id for doc in self.docs]}) + uncached_docs, tasks_with_cached_samples = cache.get_notcached_samples(docs=self.docs) self.assertEqual(tasks_with_cached_samples, {self.task_name}) self.assertEqual( @@ -169,9 +165,7 @@ def _test_cache(self, model: LightevalModel): ) # Verify cached results match original - cached_responses = cache.get_samples_from_cache( - docs=self.docs, task_names=[self.task_name], sample_type=SampleType.PREDICTIONS - ) + cached_responses = cache.get_samples_from_cache(docs=self.docs, task_names=[self.task_name]) for cached_response, response in zip(cached_responses, self.model_responses): self.assertEqual(asdict(cached_response), asdict(response)) From b463efff8763368ce74f172944c74eff1d06f08e Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Wed, 10 Sep 2025 14:20:39 +0000 Subject: [PATCH 14/21] simplified system with cleaner task_id --- src/lighteval/models/abstract_model.py | 2 + src/lighteval/utils/cache_management.py | 220 ++++++++++++------------ 2 files changed, 115 insertions(+), 107 deletions(-) diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py index 46487656e..44399fff2 100644 --- a/src/lighteval/models/abstract_model.py +++ b/src/lighteval/models/abstract_model.py @@ -80,6 +80,8 @@ class ModelConfig(BaseModel, extra="forbid"): ``` """ + model_name: str = None + generation_parameters: GenerationParameters = GenerationParameters() system_prompt: str | None = None cache_dir: str = "~/.cache/huggingface/lighteval" diff --git 
a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index e4adc30a1..75cddcb90 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -25,13 +25,12 @@ import json import logging import os -import shutil -from dataclasses import asdict +from dataclasses import asdict, dataclass from pathlib import Path from typing import Callable, List, Set, Tuple, Union import pandas as pd -from datasets import Dataset, DatasetDict +from datasets import Dataset, load_dataset from lighteval.models.abstract_model import ModelConfig from lighteval.models.model_output import ModelResponse @@ -44,6 +43,24 @@ logger = logging.getLogger(__name__) +@dataclass +class TaskID: + """A unique ID for a grouping of task samples. It relies on the task name, + the task config (which gives the tash_hash), and the sampling method (linked to + the metric type) + """ + + task_name: str + task_hash: str + sampling_method: SamplingMethod + + def __str__(self): + return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})" + + def __hash__(self): + return int.from_bytes(hashlib.sha256(str(self).encode()).digest()) + + class SampleCache: """Disk-based cache for sample evaluation results using HuggingFace datasets. The model hash is a hash of the model config, to make sure we rerun the eval if any parameter changes @@ -75,18 +92,18 @@ def __init__(self, model_config: ModelConfig): self.existing_indices = {} self.all_cache_dirs = self.cache_dir / self.model_config.model_name / self.model_hash self.all_cache_dirs.mkdir(parents=True, exist_ok=True) - # (task_name, task_hash, sampling_method) - self.existing_indices = self._get_cached_indices() + self.existing_indices = self._load_cached_indices() def _init_registry(self, registry: Registry): self.registry = registry - def _get_cached_indices(self) -> dict: - """Loads all indices for samples which are properly cached + def _load_cached_indices(self) -> dict: + """Loads all indices for samples which are properly cached. We recursively search for all available tasks and files. 
Returns: dict: Dictionary mapping task names to lists of cached sample indices """ + logger.info("[CACHING] Initializing data cache") cached_indices = {} cache_dir = self.all_cache_dirs @@ -94,27 +111,28 @@ def _get_cached_indices(self) -> dict: return cached_indices for cache_file in cache_dir.rglob("*.parquet"): - task_name = str(cache_file.parent).split("/")[-1] - task_hash = cache_file.stem try: - full_dataset = DatasetDict.load_from_disk(str(cache_file)) - for sampling_method in [SamplingMethod.GENERATIVE, SamplingMethod.LOGPROBS]: - sample_ids = [] - for row in full_dataset[str(sampling_method)]: - try: - # We only save indices of correctly formatted samples, though this means we need to load each at least once - self._load_sample(row) - cur_sample = row["sample_id"] - sample_ids.append(cur_sample) - except Exception: - continue - - cached_indices[(task_name, task_hash, sampling_method)] = sample_ids - logger.debug( - f"Loaded {len(sample_ids)} cached indices for task '{task_name}', {str(sampling_method)} from {cache_file}" - ) + task_name, task_hash = cache_file.parts[-3:-1] + sampling_method = SamplingMethod[cache_file.stem] # removes the file extension + task_id = TaskID(task_name, task_hash, sampling_method) + + full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") + sample_ids = [] + for row in full_dataset: + try: + # We only save indices of correctly formatted samples, though this means we need to load each at least once + self._load_sample(row) + cur_sample = row["sample_id"] + sample_ids.append(cur_sample) + except Exception: + continue + + cached_indices[task_id] = sample_ids + logger.info( + f"[CACHING] Loaded {len(sample_ids)} cached indices for task '{str(task_id)} from {cache_file}" + ) except Exception as e: - logger.warning(f"Error loading cached indices for task '{task_name}' from {cache_file}: {e}") + logger.warning(f"Error loading cached indices from {cache_file}: {e}") return cached_indices @@ -129,7 +147,16 @@ def get_model_hash(self, model_config: ModelConfig) -> str: config_str = json.dumps(config_dict, sort_keys=True, default=str) return hashlib.sha256(config_str.encode()).hexdigest()[:16] - def get_task_hash(self, full_task_name: str) -> str: + def _get_task_hash(self, full_task_name: str) -> str: + """Builds a task_hash from the LightevalTaskConfig loaded from the task name and the registry. + + Args: + full_task_name (str): task_name as provided to the registry (with suite|task|few_shot) + + Returns: + str: a hash of the task config in its current state in the registry, or the NO_HASH string if the + registry has not been preloaded + """ if self.registry is None: logger.warning( "The task registry was not provided to the cache config. We can't test if the current task has the same hash as the saved tasks." @@ -140,18 +167,31 @@ def get_task_hash(self, full_task_name: str) -> str: config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs]) return hashlib.sha256(config_str.encode()).hexdigest()[:16] - def get_cache_path(self, task_name: str, task_hash: str) -> Path: + def get_cache_path(self, task_id: TaskID) -> Path: """Get the file path for a specific task's cache file. 
Args: - task_name: Name of the task - task_hash: Hash of the task config, obtainable with self.get_task_hash - sample_type: Type of samples being cached + task_id: TaskID of the task Returns: Path: Path to the cache file for the given task and sample type """ - return self.all_cache_dirs / task_name / task_hash + return self.all_cache_dirs / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet" + + def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID: + """Returns a unique task indentifier. Depends on the task name, + task version and parameters (from which a hash is derived), and + current sampling method (current metric we look at). + + Args: + task_name (str): Name of the task + sampling_method (SamplingMethod): Sampling used for the current metric + + Returns: + TaskID: A unique identifier for the task + """ + task_hash = self._get_task_hash(task_name) + return TaskID(task_name, task_hash, sampling_method) def get_sampling_method(self, sample: dict) -> str: if len(sample.get("logprobs", [])) > 0: @@ -186,9 +226,11 @@ def _dump_sample(self, result: Union[dict, ModelResponse]) -> dict: """ return asdict(result) - def get_notcached_samples(self, docs: List[Doc], sampling_method: SamplingMethod) -> Tuple[List[Doc], Set]: + def get_samples_to_process_and_cache( + self, docs: List[Doc], sampling_method: SamplingMethod + ) -> Tuple[List[Doc], Set[TaskID]]: """ - Identify which docs need processing based on cached indices. + Identify which docs need processing because they are not cached yet, based on cached doc and task indices. Returns: Tuple of (docs_not_cached, tasks_with_cached_samples) where @@ -201,11 +243,9 @@ def get_notcached_samples(self, docs: List[Doc], sampling_method: SamplingMethod tasks_with_cached_samples = set() for doc in docs: - task_name = doc.task_name - task_hash = self.get_task_hash(task_name) - task_id = (task_name, task_hash, sampling_method) + task_id = self.get_task_id(doc.task_name, sampling_method) try: - if doc.id in cached_indices[task_id][sampling_method]: + if doc.id in cached_indices[task_id]: tasks_with_cached_samples.add(task_id) else: docs_not_cached.append(doc) @@ -215,7 +255,7 @@ def get_notcached_samples(self, docs: List[Doc], sampling_method: SamplingMethod return docs_not_cached, set(tasks_with_cached_samples) def get_samples_from_cache( - self, docs: List[Doc], task_ids: list | set, sampling_method: SamplingMethod + self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod ) -> List[dict | ModelResponse]: """Get cached samples for the given docs. Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise. 
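[Editorial aside, not part of the patch] To make the new layout concrete, here is an illustrative sketch of how a `TaskID` maps to a cache file under `{cache_dir}/{model_name}/{model_hash}/{task_name}/{task_hash}/{SAMPLING_METHOD}.parquet`. The dataclass mirrors the one added in this patch (the patch defines `__hash__` explicitly rather than freezing the dataclass); the enum, task name, model name, and hash input are made up for the example.

```python
import hashlib
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path


class SamplingMethod(Enum):  # stand-in for lighteval.tasks.requests.SamplingMethod
    GENERATIVE = auto()
    LOGPROBS = auto()


@dataclass(frozen=True)  # frozen=True gives hashability; the patch defines __hash__ by hand instead
class TaskID:
    task_name: str
    task_hash: str
    sampling_method: SamplingMethod

    def __str__(self) -> str:
        return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"


def cache_path(cache_root: Path, task_id: TaskID) -> Path:
    # cache_root already contains {model_name}/{model_hash}, as in SampleCache.cache_dir.
    return cache_root / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"


# Hypothetical values, same 16-char truncation as the patch.
task_hash = hashlib.sha256("lite repr of every config for this task".encode()).hexdigest()[:16]
tid = TaskID("my_suite|my_task|0", task_hash, SamplingMethod.LOGPROBS)
root = Path("~/.cache/huggingface/lighteval").expanduser() / "my-model" / "abcd1234"
print(cache_path(root, tid))
```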
@@ -226,34 +266,32 @@ def get_samples_from_cache( # Load datasets for tasks that have cached docs task_datasets = {} - for task_name, task_hash, task_sampling_method in task_ids: - if task_sampling_method != sampling_method: + for task_id in task_ids: + if task_id.sampling_method != sampling_method: continue - cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash) + cache_file = self.get_cache_path(task_id) try: - dataset = DatasetDict.load_from_disk(str(cache_file))[str(sampling_method)] + dataset = load_dataset("parquet", data_files=str(cache_file), split="train") dataset_df = dataset.to_pandas().set_index("sample_id") - task_datasets[(task_name, task_hash, sampling_method)] = dataset_df + task_datasets[task_id] = dataset_df except Exception as e: - logger.warning(f"Error loading prediction cache for {task_name}: {e}") + logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}") # Build results list results = [] for doc in docs: - task_name = doc.task_name - task_hash = self.get_task_hash(task_name) - task_id = (task_name, task_hash, sampling_method) + task_id = self.get_task_id(doc.task_name, sampling_method) row = task_datasets[task_id].loc[doc.id] results.append(self._load_sample(row)) return results - def store_samples( # noqa C901 + def cache_samples( # noqa C901 self, docs: List[Doc], results: List[dict] | List[ModelResponse], - task_ids: list[tuple[str, str]], + task_ids: list[TaskID], sampling_method: SamplingMethod, ): """Store new results for samples in docs""" @@ -263,83 +301,50 @@ def store_samples( # noqa C901 # Prepare newly processed data for dataset processed_data = {task_id: [] for task_id in task_ids} for doc, result in zip(docs, results): - task_name = doc.task_name - task_hash = self.get_task_hash(task_name) - task_id = (task_name, task_hash, sampling_method) + task_id = self.get_task_id(doc.task_name, sampling_method) sample = self._dump_sample(result) - if self.get_sampling_method(sample) != sampling_method: - logger.warning("The sample which was returned by the model is not of the expected type ") - processed_data[task_id].append({"sample_id": doc.id, "sample": sample}) processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data} # Concatenate it with existing data and save to file - for (task_name, task_hash, sampling_method), task_data in processed_data.items(): - if (task_name, task_hash, sampling_method) not in self.existing_indices.keys(): - self.existing_indices[(task_name, task_hash, sampling_method)] = {} + for task_id, task_data in processed_data.items(): + if task_id not in self.existing_indices.keys(): + self.existing_indices[task_id] = {} - cache_file = self.get_cache_path(task_name=task_name, task_hash=task_hash) + cache_file = self.get_cache_path(task_id) # Load existing data if present existing_data = [] existing_samples = {} if cache_file.exists(): try: - existing_dataset = DatasetDict.load_from_disk(str(cache_file))[str(sampling_method)] + existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train") existing_data = existing_dataset.to_list() except KeyError: - logger.info(f"No data was cached for {task_name} ({task_hash}, {str(sampling_method)}") + logger.info(f"No data was cached for {str(task_id)}") except Exception as e: - logger.error( - f"Error loading existing prediction cache for {task_name} ({task_hash}, {str(sampling_method)}): {e}" - ) + logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}") - existing_samples = { - 
(row["sample_id"], self.get_sampling_method(row["sample"])) for row in existing_data - } - if any( - (row["sample_id"], self.get_sampling_method(row["sample"])) in existing_samples - for row in task_data - ): + existing_samples = {(row["sample_id"], sampling_method) for row in existing_data} + if any((row["sample_id"], sampling_method) in existing_samples for row in task_data): logger.warning( "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version." ) # Merge with new data (new data overwrites existing) # We look at id + sampling method - new_data = [ - row - for row in task_data - if (row["sample_id"], self.get_sampling_method(row["sample"])) not in existing_samples - ] + new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples] all_samples = existing_data + new_data - # Check if file exists and has other configs we need to preserve - dataset_dict = {} - if cache_file.exists(): - try: - # We load in memory to overwrite the written file - dataset_dict = DatasetDict.load_from_disk(str(cache_file), keep_in_memory=True) - except Exception as e: - logger.debug(f"Could not load existing configs from {cache_file}: {e}") - - # Add our current config, we overwrite the existing + # Save updated dataset dataset = Dataset.from_list(all_samples) - dataset_dict[str(sampling_method)] = dataset - - # Save as DatasetDict to preserve all configs - full_dataset = DatasetDict(dataset_dict) - if cache_file.exists(): - shutil.rmtree(cache_file) - full_dataset.save_to_disk(str(cache_file)) + dataset.to_parquet(str(cache_file)) - logger.info(f"Cached {len(all_samples)} samples of {task_name} at {str(cache_file)}.") + logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.") # Refresh cached indices after storing new samples - self.existing_indices[(task_name, task_hash, sampling_method)] = [ - sample["sample_id"] for sample in all_samples - ] + self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples] def cached(sampling_method: SamplingMethod = None): # noqa C901 @@ -370,32 +375,33 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901 cache: SampleCache = self._cache # Extract task names - task_ids = {(doc.task_name, cache.get_task_hash(doc.task_name), sampling_method) for doc in docs} + task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs} # 1) Identify which samples must be processed because they are not cached - docs_not_cached, tasks_with_cached_samples = cache.get_notcached_samples(docs, sampling_method) + docs_not_cached: List[Doc] + tasks_with_cached_samples: Set[TaskID] + docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method) # Log cache statistics cached_count = len(docs) - len(docs_not_cached) if cached_count > 0: logger.info( - f"Cache: {cached_count}/{len(docs)} samples are cached for tasks {', '.join(t[0] for t in tasks_with_cached_samples)}" + f"Cache: {cached_count}/{len(docs)} samples are cached for tasks {', '.join(t_id.task_name for t_id in tasks_with_cached_samples)}" ) # 2) Process not cached docs and save to file new_results = [] if docs_not_cached: - notcached_task_names = {(doc.task_name, cache.get_task_hash(doc.task_name)) for doc in docs_not_cached} - notcached_task_names_str = ", ".join( - f"{task_name} ({task_hash})" for task_name, task_hash in notcached_task_names - ) + tasks_needing_sample_processing = { + cache.get_task_id(doc.task_name, 
sampling_method) for doc in docs_not_cached + } logger.info( - f"Cache: Processing {len(docs_not_cached)}/{len(docs)} samples for tasks {notcached_task_names_str}" + f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}" ) new_results = func(self, docs_not_cached, *args, **kwargs) # Store new results in file cache - cache.store_samples( + cache.cache_samples( docs=docs_not_cached, results=new_results, task_ids=task_ids, From 27be046c996509d7360d02b9d0ee51a82103aa7f Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Wed, 10 Sep 2025 15:22:21 +0000 Subject: [PATCH 15/21] adapted to new functions --- docs/source/evaluating-a-custom-model.mdx | 4 +- src/lighteval/models/dummy/dummy_model.py | 2 +- .../models/endpoints/endpoint_model.py | 2 +- .../endpoints/inference_providers_model.py | 2 +- .../models/endpoints/litellm_model.py | 2 +- .../models/nanotron/nanotron_model.py | 2 +- src/lighteval/models/sglang/sglang_model.py | 2 +- .../models/transformers/transformers_model.py | 2 +- .../transformers/vlm_transformers_model.py | 2 +- src/lighteval/models/vllm/vllm_model.py | 2 +- src/lighteval/utils/cache_management.py | 17 ++-- tests/utils.py | 2 + tests/utils/test_caching.py | 87 ++++++++++++++----- 13 files changed, 87 insertions(+), 41 deletions(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index 540e30b35..d5e7c8651 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx @@ -36,7 +36,7 @@ class MyCustomModel(LightevalModel): # Implement loglikelihood computation pass - @cached(SamplingMethod.LOGPROBS) + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: List[Doc]) -> List[ModelResponse]: # Implement rolling loglikelihood computation pass @@ -181,7 +181,7 @@ def __init__(self, config): 3. Add cache decorators to your prediction methods: ```python - @cached("predictions", SamplingMethod.GENERATIVE) + @cached(SamplingMethod.GENERATIVE) def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]: # Your implementation... 
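        # (Editorial sketch kept as comments so it stays inside this docs snippet; not upstream docs.)
        # A hedged example of what the body might do — `self.backend.generate` is a placeholder
        # for whatever inference call your model wraps, not a lighteval API. The point is only
        # that the @cached decorator above transparently skips docs whose (task, sampling method)
        # pair is already present in the on-disk cache.
        #
        #     responses = []
        #     for doc in docs:
        #         completion = self.backend.generate(doc.query)  # hypothetical backend call
        #         responses.append(ModelResponse(text=[completion]))
        #     return responses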
``` diff --git a/src/lighteval/models/dummy/dummy_model.py b/src/lighteval/models/dummy/dummy_model.py index 3ee5f1c03..e0a13b589 100644 --- a/src/lighteval/models/dummy/dummy_model.py +++ b/src/lighteval/models/dummy/dummy_model.py @@ -104,7 +104,7 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return model_responses - @cached(SamplingMethod.LOGPROBS) + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: model_responses = [] for doc in docs: diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index bed1de706..6b08be575 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -603,7 +603,7 @@ def _greedy_until(self, docs: List[Doc]) -> list[ModelResponse]: def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=False) - @cached(SamplingMethod.LOGPROBS) + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc], override_bs=None) -> list[ModelResponse]: return self._loglikelihood(docs, rolling=True) diff --git a/src/lighteval/models/endpoints/inference_providers_model.py b/src/lighteval/models/endpoints/inference_providers_model.py index c5b03fcf8..bda9a517f 100644 --- a/src/lighteval/models/endpoints/inference_providers_model.py +++ b/src/lighteval/models/endpoints/inference_providers_model.py @@ -260,7 +260,7 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """ raise NotImplementedError - @cached(SamplingMethod.LOGPROBS) + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 74568bd98..00f7c8779 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -330,7 +330,7 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: """ raise NotImplementedError - @cached(SamplingMethod.LOGPROBS) + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" raise NotImplementedError diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 22028f31e..310843d32 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -507,7 +507,7 @@ def loglikelihood(self, requests: List[Doc]) -> List[ModelResponse]: disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0), ) - @cached(SamplingMethod.LOGPROBS) + @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, requests: List[Doc]) -> List[ModelResponse]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" for request in tqdm( diff --git a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index 70806f632..fe37c64f9 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -416,6 +416,6 @@ def _loglikelihood_tokens( res.append(answer) return dataset.get_original_order(res) - 
diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py
index 31a86229b..0fb6df464 100644
--- a/src/lighteval/models/transformers/transformers_model.py
+++ b/src/lighteval/models/transformers/transformers_model.py
@@ -883,7 +883,7 @@ def loglikelihood(
         """
         return self._loglikelihood_tokens(docs)

-    @cached(SamplingMethod.LOGPROBS)
+    @cached(SamplingMethod.PERPLEXITY)
     def loglikelihood_rolling(
         self,
         docs: list[Doc],
diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py
index 3de3385ef..3da1290be 100644
--- a/src/lighteval/models/transformers/vlm_transformers_model.py
+++ b/src/lighteval/models/transformers/vlm_transformers_model.py
@@ -435,7 +435,7 @@ def loglikelihood(
     ) -> list[ModelResponse]:
         raise NotImplementedError()

-    @cached(SamplingMethod.LOGPROBS)
+    @cached(SamplingMethod.PERPLEXITY)
     def loglikelihood_rolling(
         self,
         docs: list[Doc],
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index f033cad21..31ec8b5a3 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -523,7 +523,7 @@ def _loglikelihood_tokens(

         return dataset.get_original_order(res)

-    @cached(SamplingMethod.LOGPROBS)
+    @cached(SamplingMethod.PERPLEXITY)
     def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]:
         raise NotImplementedError()

diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index 75cddcb90..ec5265972 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -81,17 +81,16 @@ def __init__(self, model_config: ModelConfig):
             model_config: Configuration for the model being cached
             cache_dir: Directory to store cache files
         """
-        self.cache_dir = Path(os.path.expanduser(model_config.cache_dir))
         self.model_config = model_config
         self.model_hash = self.get_model_hash(model_config)
+        self.cache_dir = (
+            Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash
+        )
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+
         self.registry = None

-        # Create cache directory structure and load cached indices if present
-        self.all_cache_dirs = {}
-        self.existing_indices = {}
-        self.all_cache_dirs = self.cache_dir / self.model_config.model_name / self.model_hash
-        self.all_cache_dirs.mkdir(parents=True, exist_ok=True)
         self.existing_indices = self._load_cached_indices()

     def _init_registry(self, registry: Registry):
@@ -105,7 +104,7 @@ def _load_cached_indices(self) -> dict:
         """
         logger.info("[CACHING] Initializing data cache")
         cached_indices = {}
-        cache_dir = self.all_cache_dirs
+        cache_dir = self.cache_dir

         if not cache_dir.exists():
             return cached_indices
@@ -176,7 +175,7 @@ def get_cache_path(self, task_id: TaskID) -> Path:
         Returns:
             Path: Path to the cache file for the given task and sample type
         """
-        return self.all_cache_dirs / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"
+        return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"

     def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID:
         """Returns a unique task indentifier. Depends on the task name,
@@ -355,7 +354,7 @@ def cached(sampling_method: SamplingMethod = None):  # noqa C901
         cache_type_name: Type of cache ("tokenization" or "predictions")

     Usage:
-        @cached("greedy")
+        @cached(SamplingMethod.GENERATIVE)
         def greedy_until(self, docs: List[Doc], ...):
             # method implementation
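The cache_management.py hunks above collapse the per-type `all_cache_dirs` mapping into a single `cache_dir` rooted at `<cache_dir>/<model_name>/<model_hash>`, with `get_cache_path` appending `<task_name>/<task_hash>/<SAMPLING_METHOD>.parquet`. The sketch below only illustrates the resulting path composition; the cache root, model name and hash values are invented placeholders, not taken from the patch.

```python
# Illustrative only: compose a cache file path the same way get_cache_path does.
from pathlib import Path

cache_root = Path("/tmp/lighteval-cache")                # stands in for model_config.cache_dir
cache_dir = cache_root / "Qwen/Qwen3-0.6B" / "1a2b3c"    # cache_dir = root / model_name / model_hash
cache_file = cache_dir / "my_task" / "9f8e7d" / "GENERATIVE.parquet"
# -> /tmp/lighteval-cache/Qwen/Qwen3-0.6B/1a2b3c/my_task/9f8e7d/GENERATIVE.parquet
print(cache_file)
```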
diff --git a/tests/utils.py b/tests/utils.py
index b44d27551..7954b3531 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -34,6 +34,7 @@
 from lighteval.tasks.lighteval_task import LightevalTask
 from lighteval.tasks.registry import Registry
 from lighteval.tasks.requests import Doc
+from lighteval.utils.cache_management import SampleCache
 from lighteval.utils.imports import is_accelerate_available

@@ -55,6 +56,7 @@ def __init__(
         self.greedy_until_responses = greedy_until_responses
         self.loglikelihood_responses = loglikelihood_responses
         self.loglikelihood_rolling_responses = loglikelihood_rolling_responses
+        self._cache = SampleCache(self.config)

     @property
     def tokenizer(self):
diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index 0bce98eb2..8282af050 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -30,7 +30,7 @@
 from lighteval.models.abstract_model import LightevalModel
 from lighteval.models.model_output import ModelResponse
-from lighteval.tasks.requests import Doc
+from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.cache_management import SampleCache

@@ -95,12 +95,10 @@ def test_cache_directory_structure(self):
             cache = SampleCache(config)

             # Check directory structure
-            cache_dirs = list(cache.all_cache_dirs.values())
-            self.assertEqual(len(cache_dirs), 2)
-            for folder in cache_dirs:
-                self.assertTrue(folder.exists())
-                self.assertIn(str(temp_dir), str(folder))
-                self.assertIn(model_name, str(folder))
+            folder = cache.cache_dir
+            self.assertTrue(folder.exists())
+            self.assertIn(str(temp_dir), str(folder))
+            self.assertIn(model_name, str(folder))

     def test_cache_decorator_presence(self):
         """Test that @cached decorators are present on the right methods."""
@@ -142,30 +140,38 @@ def test_cache_decorator_presence(self):
                     hasattr(method, "__wrapped__"), f"{method_name} missing @cached decorator for {model_class}"
                 )

-    def _test_cache(self, model: LightevalModel):
+    def _test_cache(self, model: LightevalModel, test_cases):
         """Test that the @cached decorator logic works correctly - called by all model specific functions below."""
-        for function_name in ["greedy_until", "loglikelihood", "loglikelihood_rolling"]:
+        for function_name, sampling_method in test_cases:
             with self.subTest(function_name=function_name):
                 process_inputs = getattr(model, function_name)
                 process_inputs(self.docs)

                 cache: SampleCache = model._cache

+                # Check task_id
+                task_id = cache.get_task_id(self.task_name, sampling_method)
+                self.assertEqual(task_id.task_name, self.task_name)
+                self.assertEqual(task_id.sampling_method, sampling_method)
+
                 # Verify cache files were created
-                cache_file = cache.get_cache_path(task_name=self.task_name)
+                cache_file = cache.get_cache_path(task_id)
                 self.assertTrue(cache_file.exists(), "Cache file not created")

                 # Test retrieving from cache
-                self.assertEqual(cache._get_cached_indices(), {self.task_name: [doc.id for doc in self.docs]})
-                uncached_docs, tasks_with_cached_samples = cache.get_notcached_samples(docs=self.docs)
-
-                self.assertEqual(tasks_with_cached_samples, {self.task_name})
+                self.assertEqual(cache._load_cached_indices()[task_id], [doc.id for doc in self.docs])
                 uncached_docs, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(
+                    docs=self.docs, sampling_method=sampling_method
+                )
+                self.assertEqual(tasks_with_cached_samples, {task_id})
                 self.assertEqual(
                     len(uncached_docs), 0, f"{len(uncached_docs)} documents not found in cache when it should be 0"
                 )

                 # Verify cached results match original
-                cached_responses = cache.get_samples_from_cache(docs=self.docs, task_names=[self.task_name])
+                cached_responses = cache.get_samples_from_cache(
+                    docs=self.docs, task_ids=[task_id], sampling_method=sampling_method
+                )
                 for cached_response, response in zip(cached_responses, self.model_responses):
                     self.assertEqual(asdict(cached_response), asdict(response))
@@ -194,7 +200,14 @@ def test_cache_transformers(
             config = TransformersModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
             model = TransformersModel(config)
-            self._test_cache(model)
+            self._test_cache(
+                model,
+                [
+                    ("greedy_until", SamplingMethod.GENERATIVE),
+                    ("loglikelihood", SamplingMethod.LOGPROBS),
+                    ("loglikelihood_rolling", SamplingMethod.PERPLEXITY),
+                ],
+            )

     @patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model")
     @patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until")
     @patch("lighteval.models.vllm.vllm_model.VLLMModel._loglikelihood_tokens")
@@ -211,7 +224,14 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikeliho
             config = VLLMModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
             model = VLLMModel(config)
-            self._test_cache(model)
+            self._test_cache(
+                model,
+                [
+                    ("greedy_until", SamplingMethod.GENERATIVE),
+                    ("loglikelihood", SamplingMethod.LOGPROBS),
+                    ("loglikelihood_rolling", SamplingMethod.PERPLEXITY),
+                ],
+            )

     @patch("requests.get")
     @patch("lighteval.models.endpoints.tgi_model.ModelClient._greedy_until")
@@ -236,7 +256,14 @@ def test_cache_tgi(self, mock_loglikelihood, mock_greedy_until, mock_requests_ge
             )
             model = ModelClient(config)
-            self._test_cache(model)
+            self._test_cache(
+                model,
+                [
+                    ("greedy_until", SamplingMethod.GENERATIVE),
+                    ("loglikelihood", SamplingMethod.LOGPROBS),
+                    ("loglikelihood_rolling", SamplingMethod.PERPLEXITY),
+                ],
+            )

     @patch("lighteval.models.endpoints.endpoint_model.InferenceEndpointModel._loglikelihood")
     @patch("lighteval.models.endpoints.endpoint_model.InferenceEndpointModel._greedy_until")
@@ -257,7 +284,14 @@ def test_cache_endpoint(self, mock_init, mock_greedy_until, mock_loglikelihood):
             config = InferenceEndpointModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
             model = InferenceEndpointModel(config)
-            self._test_cache(model)
+            self._test_cache(
+                model,
+                [
+                    ("greedy_until", SamplingMethod.GENERATIVE),
+                    ("loglikelihood", SamplingMethod.LOGPROBS),
+                    ("loglikelihood_rolling", SamplingMethod.PERPLEXITY),
+                ],
+            )

     @patch("lighteval.models.sglang.sglang_model.SGLangModel._loglikelihood_tokens")
     @patch("lighteval.models.sglang.sglang_model.SGLangModel._greedy_until")
@@ -278,7 +312,13 @@ def test_cache_sglang(
             config = SGLangModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
             model = SGLangModel(config)
-            self._test_cache(model)
+            self._test_cache(
+                model,
+                [
+                    ("greedy_until", SamplingMethod.GENERATIVE),
+                    ("loglikelihood", SamplingMethod.LOGPROBS),
+                ],
+            )

     @patch("lighteval.models.transformers.vlm_transformers_model.VLMTransformersModel._greedy_until")
     @patch("lighteval.utils.imports.is_accelerate_available")
@@ -306,4 +346,9 @@ def test_cache_vlm_transformers(
             config = VLMTransformersModelConfig(model_name="HuggingFaceTB/SmolVLM-256M-Instruct", cache_dir=temp_dir)
             model = VLMTransformersModel(config)
-            self._test_cache(model)
+            self._test_cache(
+                model,
+                [
+                    ("greedy_until", SamplingMethod.GENERATIVE),
+                ],
+            )
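The updated tests above detect the decorator through `hasattr(method, "__wrapped__")` and drive `_test_cache` with explicit `(method name, SamplingMethod)` pairs per backend. The `__wrapped__` check works because a decorator built with `functools.wraps` records the original function; the snippet below is a small standalone illustration of that mechanism in plain Python, not lighteval code.

```python
# Standalone illustration of the __wrapped__ check used in test_cache_decorator_presence.
import functools


def cached_like(func):
    @functools.wraps(func)  # copies metadata and sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        # a real caching decorator would consult the sample cache here
        return func(*args, **kwargs)

    return wrapper


@cached_like
def greedy_until(docs):
    return docs


assert hasattr(greedy_until, "__wrapped__")        # what the test asserts
assert greedy_until.__wrapped__([1, 2]) == [1, 2]  # the undecorated function stays reachable
```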

From f7c62bbbde07e85c14be0ba74f14d55a5484a89f Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Thu, 11 Sep 2025 09:23:21 +0000
Subject: [PATCH 16/21] update vllm test

---
 tests/utils/test_caching.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index 8282af050..47ea599ae 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -209,9 +209,9 @@ def test_cache_transformers(
             ],
         )

-    @patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model")
-    @patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until")
     @patch("lighteval.models.vllm.vllm_model.VLLMModel._loglikelihood_tokens")
+    @patch("lighteval.models.vllm.vllm_model.VLLMModel._greedy_until")
+    @patch("lighteval.models.vllm.vllm_model.VLLMModel._create_auto_model")
     def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikelihood):
         from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
@@ -229,7 +229,6 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikeliho
             config = VLLMModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
             model = VLLMModel(config)
             self._test_cache(
                 model,
                 [
                     ("greedy_until", SamplingMethod.GENERATIVE),
                     ("loglikelihood", SamplingMethod.LOGPROBS),
-                    ("loglikelihood_rolling", SamplingMethod.PERPLEXITY),
                 ],
             )

From 0c4a429c93797be2009930557e9baf1119dc5f3e Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Thu, 11 Sep 2025 10:18:46 +0000
Subject: [PATCH 17/21] fixing the metric changed res by 1 point

---
 tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json b/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json
index 66ab85090..7ee2bd54f 100644
--- a/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json
+++ b/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d1302090702deaf018f21f1dc5ffd2a2a2b93e19b50aa459508146f130aa9ecf
+oid sha256:e56e83c63f8d2a066f7e1199a018583c7b304315c044412f7c7be5de62301f67
 size 50565

From dbac859f2fbb974324e3f7b28e81452924360bcf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine?=
Date: Thu, 11 Sep 2025 12:41:28 +0200
Subject: [PATCH 18/21] byteorder arg

---
 src/lighteval/utils/cache_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index ec5265972..89406a055 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -58,7 +58,7 @@ def __str__(self):
         return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"

     def __hash__(self):
-        return int.from_bytes(hashlib.sha256(str(self).encode()).digest())
+        return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big")
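PATCH 18 above adds the explicit `byteorder` argument to `int.from_bytes` in `TaskID.__hash__`. That argument only gained a default value of `"big"` in Python 3.11, so passing it explicitly keeps the hash working on older interpreters. A standalone illustration of the same call, with a made-up TaskID string:

```python
# Illustration of the byteorder fix; the TaskID string below is made up.
import hashlib

digest = hashlib.sha256("my_task (9f8e7d, GENERATIVE)".encode()).digest()
task_id_hash = int.from_bytes(digest, byteorder="big")  # required keyword on Python < 3.11
print(task_id_hash)
```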

From 668e2f4498f30e8403cec2aaa758012f67ebdb5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine?=
Date: Thu, 11 Sep 2025 13:16:46 +0200
Subject: [PATCH 19/21] this makes little sense

---
 tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json b/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json
index 7ee2bd54f..66ab85090 100644
--- a/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json
+++ b/tests/reference_scores/SmolLM2-1.7B-Instruct-results-vllm.json
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e56e83c63f8d2a066f7e1199a018583c7b304315c044412f7c7be5de62301f67
+oid sha256:d1302090702deaf018f21f1dc5ffd2a2a2b93e19b50aa459508146f130aa9ecf
 size 50565

From ef5ffe748cc23f9cb0cb92bd2becac91d7901e75 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com>
Date: Thu, 11 Sep 2025 17:53:10 +0200
Subject: [PATCH 20/21] Update src/lighteval/utils/cache_management.py

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
---
 src/lighteval/utils/cache_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index 89406a055..937bd1caa 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -46,7 +46,7 @@
 @dataclass
 class TaskID:
     """A unique ID for a grouping of task samples. It relies on the task name,
-    the task config (which gives the tash_hash), and the sampling method (linked to
+    the task config (which gives the task_hash), and the sampling method (linked to
     the metric type)
     """

From 085f59cf5c3ec7361122e937b42634b0a8e2e178 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine?=
Date: Thu, 11 Sep 2025 18:32:16 +0200
Subject: [PATCH 21/21] comments

---
 src/lighteval/models/abstract_model.py  | 2 ++
 src/lighteval/utils/cache_management.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py
index 44399fff2..81d725e6a 100644
--- a/src/lighteval/models/abstract_model.py
+++ b/src/lighteval/models/abstract_model.py
@@ -46,6 +46,8 @@ class ModelConfig(BaseModel, extra="forbid"):
     as well as shared attributes that are used by all models like generation parameters and system prompts.

     Attributes:
+        model_name (str):
+            The model name or unique id
         generation_parameters (GenerationParameters):
             Configuration parameters that control text generation behavior, including
             temperature, top_p, max_new_tokens, etc. Defaults to empty GenerationParameters.
diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index 937bd1caa..2059d2843 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -111,6 +111,8 @@ def _load_cached_indices(self) -> dict:

         for cache_file in cache_dir.rglob("*.parquet"):
             try:
+                # cache_file.parts gives all the subfolders of the url, up to the file name
+                # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2
                 task_name, task_hash = cache_file.parts[-3:-1]
                 sampling_method = SamplingMethod[cache_file.stem]  # removes the file extension
                 task_id = TaskID(task_name, task_hash, sampling_method)
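The comments added in PATCH 21 document the `cache_file.parts[-3:-1]` indexing used when rebuilding cached indices. The snippet below only replays that slicing on a made-up cache path to show what the two indices and `stem` yield; it is not lighteval code.

```python
# Illustration of the indexing described by the new comments; the path is made up.
from pathlib import Path

cache_file = Path("/cache/Qwen3-0.6B/1a2b3c/my_task/9f8e7d/GENERATIVE.parquet")

task_name, task_hash = cache_file.parts[-3:-1]  # ("my_task", "9f8e7d")
sampling_name = cache_file.stem                 # "GENERATIVE", the file name without extension

assert (task_name, task_hash, sampling_name) == ("my_task", "9f8e7d", "GENERATIVE")
```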