# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric for the Test of Time benchmark by Bahar et al. (2025)."""

import ast
import json
from typing import Any, Literal

import datasets
import evaluate

_CITATION = """\
@InProceedings{huggingface:module,
  title = {Test of Time Accuracy},
  author = {Auss Abbood},
  year = {2025}
}
"""

_DESCRIPTION = """\
The Test of Time (ToT) benchmark expects models to format their answers as a JSON object with an explanation field and an answer field that follows a predefined format. The metric extracts JSON objects from the model's output, retains only the first one, drops the explanation field, and compares the result with the reference answer.
"""

_KWARGS_DESCRIPTION = """
Compares the extracted answer from the model's output with the reference answer.

Args:
    predictions: list of predictions to score. Each prediction should be a string that contains a JSON object (e.g., generated by an LLM).
    references: list of reference answers.
    subset: The subset of the benchmark being evaluated. Must be one of "arithmetic" or "semantic".
    return_average: If True, returns the average accuracy. If False, returns a list of boolean scores (correct/incorrect) for each sample. Defaults to True.

Returns:
    accuracy: The accuracy score (0.0 to 1.0) if return_average=True, or a list of booleans indicating correctness per sample if return_average=False.

Examples:
    >>> import evaluate
    >>> metric = evaluate.load("aauss/test_of_time_accuracy")
    >>> predictions = [
    ...     '{"explanation": "Some explanation...", "unordered_list": ["London"]}',
    ...     ' "Response without opening curly brackets...", "answer": "2005-04-07"}',
    ... ]
    >>> references = [
    ...     '{"unordered_list": ["London"]}',
    ...     "{'answer': '2005-04-07'}",
    ... ]
    >>> results = metric.compute(predictions=predictions, references=references, subset="arithmetic")
    >>> print(results)
    {'accuracy': 0.5}
"""


class TestOfTimeAccuracy(evaluate.Metric):
    """Accuracy metric for the Test of Time benchmark by Bahar et al. (2025)."""

    __test__ = False

    def _info(self) -> evaluate.MetricInfo:
        """Returns metadata about this metric."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            # Homepage of the module for documentation
            # homepage="http://module.homepage",
            # Additional links to the codebase or references
            # codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            # reference_urls=["http://path.to.reference.url/new_module"],
        )

    @staticmethod
    def _extract_first_json_object(text: str) -> dict | None:
        """
        Extract the first valid JSON object from text.

        Handles common LLM output issues like unescaped newlines in string
        values (LLMs produce human-readable output, not strict JSON).

        Args:
            text: String that may contain JSON objects

        Returns:
            The first JSON dictionary found, or None if no valid JSON exists
        """
        # Fix unescaped control chars in strings (common LLM issue)
        text = TestOfTimeAccuracy._escape_control_chars_in_strings(text)
        decoder = json.JSONDecoder()
        idx = 0
        while idx < len(text):
            if text[idx] == '{':
                try:
                    obj, _ = decoder.raw_decode(text, idx)
                    if isinstance(obj, dict):
                        return obj
                except json.JSONDecodeError:
                    pass
            idx += 1
        return None
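
    # Worked example (added for illustration, not part of the original docstring): the
    # scanner ignores prose around the JSON, and a literal newline inside a string value
    # survives because the text is escaped first, e.g.
    #   TestOfTimeAccuracy._extract_first_json_object(
    #       'Sure! {"explanation": "line one\nline two", "answer": 3} Hope that helps.'
    #   )
    # returns {'explanation': 'line one\nline two', 'answer': 3}.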

    @staticmethod
    def _escape_control_chars_in_strings(text: str) -> str:
        """
        Escape literal control characters inside JSON string values.

        LLMs produce newlines/tabs for readability, but JSON requires them
        to be escaped within strings.
        """
        result = []
        in_string = False
        i = 0
        while i < len(text):
            char = text[i]
            if char == '\\' and in_string and i + 1 < len(text):
                # Preserve existing escape sequences
                result.append(char)
                result.append(text[i + 1])
                i += 2
                continue
            if char == '"':
                in_string = not in_string
            if in_string and char == '\n':
                result.append('\\n')
            elif in_string and char == '\r':
                result.append('\\r')
            elif in_string and char == '\t':
                result.append('\\t')
            else:
                result.append(char)
            i += 1
        return ''.join(result)
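
    # For illustration (my example, assuming raw LLM text as input): a literal newline
    # inside a quoted value becomes the two-character escape, while whitespace outside
    # strings is left alone:
    #   TestOfTimeAccuracy._escape_control_chars_in_strings('{"a": "x\ny"}\n')
    #   -> '{"a": "x\\ny"}\n'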

    @staticmethod
    def _parse_reference_label(label_str: str) -> dict | None:
        """
        Parses a reference label string into a dictionary.

        Handles Python dict strings (e.g., "{'key': 'value'}") by
        evaluating them as literals.

        Args:
            label_str: String representation of a dictionary

        Returns:
            Parsed dictionary, or None if parsing fails
        """
        try:
            return ast.literal_eval(label_str)
        except (ValueError, SyntaxError):
            return None
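
    # Worked example (mine): arithmetic reference labels are Python dict literals rather
    # than JSON, so single quotes are expected here:
    #   TestOfTimeAccuracy._parse_reference_label("{'answer': '2005-04-07'}")
    #   -> {'answer': '2005-04-07'}
    #   TestOfTimeAccuracy._parse_reference_label("not a dict")  -> None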

    @staticmethod
    def _remove_explanation_field(data: Any) -> Any:
        """
        Removes the 'explanation' field from a dictionary.

        Args:
            data: Dictionary or other data type

        Returns:
            The data with explanation field removed (if it was a dict),
            or the original data unchanged
        """
        if isinstance(data, dict):
            data.pop("explanation", None)
        return data

    @staticmethod
    def _extract_answer_field(data: Any) -> Any:
        """
        Extracts the 'answer' field from a dictionary.

        Args:
            data: Dictionary or other data type

        Returns:
            The value of the 'answer' field if data is a dict,
            otherwise returns the data unchanged
        """
        if isinstance(data, dict):
            return data.get("answer", None)
        return data

    @staticmethod
    def _sort_unordered_list_field(data: Any) -> Any:
        """
        Sorts the 'unordered_list' field in a dictionary.

        This enables comparison of unordered lists by converting them to
        a canonical sorted form.

        Args:
            data: Dictionary potentially containing an 'unordered_list' field

        Returns:
            Sorted list if data is a dict with 'unordered_list',
            otherwise returns data unchanged
        """
        if isinstance(data, dict) and "unordered_list" in data:
            return sorted([item for item in data["unordered_list"] if isinstance(item, str)])
        return data
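
    # Worked example (mine): note that the helper returns the bare sorted list, not the
    # wrapping dict, so it is applied to prediction and reference alike before comparing:
    #   TestOfTimeAccuracy._sort_unordered_list_field({"unordered_list": ["b", "a"]})
    #   -> ['a', 'b']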

    @staticmethod
    def _cast_prediction_to_reference_types(
        reference: dict, prediction: dict
    ) -> dict | None:
        """
        Casts prediction values to match reference types.

        Ensures that predictions can be compared with references even when
        the types differ (e.g., string "123" vs int 123, int 5 vs float 5.0).

        Args:
            reference: Reference dictionary with expected types
            prediction: Prediction dictionary to cast

        Returns:
            Dictionary with casted values, or None if casting fails or
            prediction is missing required keys
        """
        if not isinstance(prediction, dict) or not isinstance(reference, dict):
            return None
        casted_prediction = {}
        try:
            for ref_key, ref_value in reference.items():
                if ref_key not in prediction:
                    return None
                reference_type = type(ref_value)
                pred_value = prediction[ref_key]
                # Safeguard: Python allows list("abc") -> ['a', 'b', 'c']
                # We don't want to turn strings into character lists
                if reference_type is list and not isinstance(pred_value, list):
                    return None
                # Cast to reference type: int("123") -> 123, float(12) -> 12.0, etc.
                casted_prediction[ref_key] = reference_type(pred_value)
            return casted_prediction
        except (ValueError, TypeError):
            return None
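
    # Worked examples (mine): prediction values are coerced to the reference's types so
    # "2" can match 2, while a scalar is never exploded into a character list:
    #   TestOfTimeAccuracy._cast_prediction_to_reference_types({"answer": 2}, {"answer": "2"})
    #   -> {'answer': 2}
    #   TestOfTimeAccuracy._cast_prediction_to_reference_types({"a": [1]}, {"a": "12"})
    #   -> None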

    @staticmethod
    def _normalise_list_field_casing(data: dict | None) -> dict | None:
        """
        Converts all list items to lowercase for case-insensitive comparison.

        Applied to 'ordered_list' and 'unordered_list' fields to handle variations
        in capitalization (e.g., "Skating" vs "skating").

        Args:
            data: Dictionary potentially containing list fields

        Returns:
            Dictionary with lowercased list items, or None if data is None
        """
        if data is None or not isinstance(data, dict):
            return data
        # Process list fields regardless of key order
        for key in ["ordered_list", "unordered_list"]:
            if key in data and isinstance(data[key], list):
                data[key] = [item.lower() for item in data[key] if isinstance(item, str)]
        return data
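
    # Worked example (mine): list items are lowercased in place and non-string items are
    # dropped so that "Skating" and "skating" compare equal later on:
    #   TestOfTimeAccuracy._normalise_list_field_casing({"unordered_list": ["Skating", "Chess"]})
    #   -> {'unordered_list': ['skating', 'chess']}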

    @staticmethod
    def _fix_age_field_conflict(prediction: dict | None) -> dict | None:
        """
        Fixes a known conflict in the dataset regarding the 'age' field.

        In some dataset samples, the instruction asks for an 'age' field but
        the reference uses 'answer'. This method normalises the prediction
        to match the expected format.

        Args:
            prediction: Prediction dictionary potentially with 'age' field

        Returns:
            Dictionary with 'age' converted to 'answer', or unchanged if
            'age' field not present
        """
        if prediction is not None and isinstance(prediction, dict):
            if "age" in prediction:
                prediction = {"answer": prediction["age"]}
        return prediction
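
    # Worked example (mine): the 'age' value is moved under 'answer'; any other keys are
    # dropped because the reference only carries 'answer':
    #   TestOfTimeAccuracy._fix_age_field_conflict({"age": 42, "explanation": "..."})
    #   -> {'answer': 42}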

    def _process_arithmetic_prediction(
        self, prediction: dict | None, reference: dict | None
    ) -> tuple[Any, Any]:
        """
        Processes a prediction-reference pair for the arithmetic subset.

        Applies arithmetic-specific transformations:
        1. Fixes age field conflicts
        2. Normalises list casing
        3. Casts prediction types to match reference
        4. Sorts unordered lists for comparison

        Args:
            prediction: Raw prediction dictionary
            reference: Raw reference dictionary

        Returns:
            Tuple of (processed_prediction, processed_reference)
        """
        prediction = self._fix_age_field_conflict(prediction)
        prediction = self._normalise_list_field_casing(prediction)
        reference = self._normalise_list_field_casing(reference)
        prediction = self._cast_prediction_to_reference_types(reference, prediction)
        # Sort unordered lists for order-independent comparison
        if reference and "unordered_list" in reference:
            prediction = self._sort_unordered_list_field(prediction)
            reference = self._sort_unordered_list_field(reference)
        return prediction, reference
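
    # Illustration (my example): for the pair
    #   prediction = {"unordered_list": ["London", "Paris"]}
    #   reference  = {"unordered_list": ["paris", "london"]}
    # the pipeline lowercases both lists, casts the prediction to the reference's types,
    # and sorts both sides, returning (['london', 'paris'], ['london', 'paris']), which
    # then compares equal in _compare_pair.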

    def _process_semantic_prediction(
        self, prediction: Any, reference: Any
    ) -> tuple[str, str]:
        """
        Processes a prediction-reference pair for the semantic subset.

        Converts both to strings for comparison since semantic answers
        may have type mismatches (e.g., int in JSON vs string in reference).

        Args:
            prediction: Raw prediction value
            reference: Raw reference value

        Returns:
            Tuple of (str(prediction), str(reference))
        """
        return str(prediction), str(reference)

    def _extract_predictions(
        self, raw_predictions: list[str], subset: str
    ) -> list[Any]:
        """
        Extracts and preprocesses predictions based on subset type.

        Args:
            raw_predictions: List of raw prediction strings (e.g., from LLM output)
            subset: Either 'arithmetic' or 'semantic'

        Returns:
            List of extracted prediction values
        """
        predictions = [self._extract_first_json_object(p) for p in raw_predictions]
        if subset == "semantic":
            # Semantic labels are not dicts, so extract the value of the LLM's answer field.
            predictions = [self._extract_answer_field(p) for p in predictions]
        elif subset == "arithmetic":
            # Labels and LLM answers differ only by the explanation field, so drop it.
            predictions = [self._remove_explanation_field(p) for p in predictions]
        return predictions
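
    # Illustration (my example): for '{"explanation": "...", "answer": 7}' the semantic
    # subset keeps only 7, while the arithmetic subset keeps {'answer': 7} (the full dict
    # minus the explanation field).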

    def _extract_references(self, raw_references: list[str], subset: str) -> list[Any]:
        """
        Extracts and preprocesses references based on subset type.

        Args:
            raw_references: List of raw reference strings
            subset: Either 'arithmetic' or 'semantic'

        Returns:
            List of extracted reference values
        """
        if subset == "arithmetic":
            # Arithmetic references are Python dict strings that need parsing
            return [self._parse_reference_label(r) for r in raw_references]
        else:
            # Semantic references are used as-is
            return raw_references

    def _compare_pair(self, prediction: Any, reference: Any, subset: str) -> bool:
        """
        Compares a single prediction-reference pair.

        Args:
            prediction: Processed prediction value
            reference: Processed reference value
            subset: Either 'arithmetic' or 'semantic'

        Returns:
            True if prediction matches reference, False otherwise
        """
        if subset == "arithmetic":
            prediction, reference = self._process_arithmetic_prediction(
                prediction, reference
            )
        elif subset == "semantic":
            prediction, reference = self._process_semantic_prediction(
                prediction, reference
            )
        return prediction == reference

    def _compute(
        self,
        predictions: list[str],
        references: list[str],
        subset: Literal["arithmetic", "semantic"],
        return_average: bool = True,
    ) -> dict[str, float | list[bool]]:
        """
        Computes accuracy scores for the Test of Time benchmark.

        Args:
            predictions: List of prediction strings (LLM outputs)
            references: List of reference answer strings
            subset: Benchmark subset - either 'arithmetic' or 'semantic'
            return_average: If True, returns average accuracy; if False,
                returns per-sample correctness

        Returns:
            Dictionary with 'accuracy' key containing either:
            - float: average accuracy (if return_average=True)
            - list[bool]: per-sample correctness (if return_average=False)

        Raises:
            ValueError: If subset is not 'arithmetic' or 'semantic'
        """
        # Validate subset
        if subset not in ["arithmetic", "semantic"]:
            raise ValueError(
                f"Invalid subset: {subset}. Must be 'arithmetic' or 'semantic'."
            )
        # Extract and preprocess predictions and references
        predictions = self._extract_predictions(predictions, subset)
        references = self._extract_references(references, subset)
        # Compare each prediction-reference pair
        accuracy_scores = [
            self._compare_pair(pred, ref, subset)
            for pred, ref in zip(predictions, references)
        ]
        # Return average or per-sample scores
        if return_average:
            return {"accuracy": sum(accuracy_scores) / len(accuracy_scores)}
        return {"accuracy": accuracy_scores}