aauss committed on
Commit
3d8a7dc
·
1 Parent(s): 2aa1857

Add more directed type casting for arithmetic subset.

Browse files
test_of_time_accuracy.py CHANGED
@@ -120,11 +120,41 @@ class TestOfTimeAccuracy(evaluate.Metric):
120
  except (ValueError, SyntaxError):
121
  return None
122
 
123
- def _sort_unordered_list(self, d):
 
124
  if isinstance(d, dict) and "unordered_list" in d:
125
  return sorted(d["unordered_list"])
126
  return d
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def _compute(
129
  self,
130
  predictions,
@@ -135,22 +165,28 @@ class TestOfTimeAccuracy(evaluate.Metric):
135
  """Returns the scores"""
136
  predictions = [self._extract_first_json_object(p) for p in predictions]
137
  if subset == "semantic":
 
 
138
  predictions = [self._get_answer(p) for p in predictions]
139
  elif subset == "arithmetic":
 
 
140
  predictions = [self._pop_explanation(p) for p in predictions]
141
  references = [self._parse_label(r) for r in references]
142
  else:
143
  raise ValueError(f"Invalid subset: {subset}")
144
  accuracy = []
145
- for i, j in zip(predictions, references):
146
- if subset == "arithmetic" and "unordered_list" in j:
147
- i = self._sort_unordered_list(i)
148
- j = self._sort_unordered_list(j)
 
 
149
  if subset == "semantic":
150
- i = str(i)
151
- j = str(j)
152
  accuracy.append(
153
- i == j
154
  ) # Semantic subset answer JSON sometimes has int as value. Label is string.
155
  if return_average:
156
  return {"accuracy": sum(accuracy) / len(accuracy)}
 
120
  except (ValueError, SyntaxError):
121
  return None
122
 
123
+ @staticmethod
124
+ def _sort_unordered_list(d):
125
  if isinstance(d, dict) and "unordered_list" in d:
126
  return sorted(d["unordered_list"])
127
  return d
128
 
129
+ @staticmethod
130
+ def _cast_prediction(reference: dict, prediction: dict) -> None | dict:
131
+ """
132
+ Casts the values in the prediction dictionary to match the types
133
+ of the values in the reference dictionary.
134
+ """
135
+ casted_prediction = {}
136
+
137
+ try:
138
+ for ref_key, ref_value in reference.items():
139
+ if ref_key not in prediction:
140
+ return None
141
+
142
+ reference_type = type(ref_value)
143
+ pred_value = prediction[ref_key]
144
+
145
+ # Special safeguard: Python allows list("abc") -> ['a', 'b', 'c'].
146
+ # We don't want to turn strings into character lists.
147
+ if reference_type == list and not isinstance(pred_value, list):
148
+ return None
149
+
150
+ # This handles int("123") -> 123, float(12) -> 12.0, str(100) -> "100"
151
+ casted_prediction[ref_key] = reference_type(pred_value)
152
+
153
+ return casted_prediction
154
+
155
+ except (ValueError, TypeError):
156
+ return None
157
+
158
  def _compute(
159
  self,
160
  predictions,
 
165
  """Returns the scores"""
166
  predictions = [self._extract_first_json_object(p) for p in predictions]
167
  if subset == "semantic":
168
+ # Semantic subset's answers are not JSON objects.
169
+ # Expected answers are always in "answer" field.
170
  predictions = [self._get_answer(p) for p in predictions]
171
  elif subset == "arithmetic":
172
+ # Arithmetic subset's answers are JSON objects.
173
+ # Answer fields vary. Thus, remove explanation field.
174
  predictions = [self._pop_explanation(p) for p in predictions]
175
  references = [self._parse_label(r) for r in references]
176
  else:
177
  raise ValueError(f"Invalid subset: {subset}")
178
  accuracy = []
179
+ for pred, ref in zip(predictions, references):
180
+ if subset == "arithmetic":
181
+ pred = self._cast_prediction(ref, pred)
182
+ if "unordered_list" in ref:
183
+ pred = self._sort_unordered_list(pred)
184
+ ref = self._sort_unordered_list(ref)
185
  if subset == "semantic":
186
+ pred = str(pred)
187
+ ref = str(ref)
188
  accuracy.append(
189
+ pred == ref
190
  ) # Semantic subset answer JSON sometimes has int as value. Label is string.
191
  if return_average:
192
  return {"accuracy": sum(accuracy) / len(accuracy)}
tests.py → tests/test_arithmetic_scoring.py RENAMED
@@ -8,30 +8,19 @@ arithmetic_test_cases = {
8
  '```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "unordered_list": ["Berlin","London"]\n}\n```',
9
  '```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "malformed_unordered_list": ["Berlin","London"]\n}\n```',
10
  ' "To find the date of the second most important game, we need to subtract 7 days from the date of the most important game. We can do this by counting back 7 days from April 14, 2005. April 14 - 7 days = April 7, 2005", "answer": "2005-04-07"}',
11
- '\n```json\n{\n "explanation": "Step 1: Determine the time it takes the robot to carry a single box. The robot takes 4 hours, 34 minutes, and 30 seconds to carry 2 boxes. We divide this time by 2 to find the time per box.\\n- Hours: 4 / 2 = 2 hours\\n- Minutes: 34 / 2 = 17 minutes\\n- Seconds: 30 / 2 = 15 seconds\\nSo, it takes the robot 2 hours, 17 minutes, and 15 seconds to carry one box.\\n\\nStep 2: Calculate the total time to carry 25 boxes. We multiply the time per box by the total number of boxes (25).\\n- Total Hours: 2 hours/box * 25 boxes = 50 hours\\n- Total Minutes: 17 minutes/box * 25 boxes = 425 minutes\\n- Total Seconds: 15 seconds/box * 25 boxes = 375 seconds\\n\\nStep 3: Convert the calculated time into the standard H:M:S format by carrying over excess seconds and minutes.\\n- Convert seconds to minutes: 375 seconds is equal to 6 minutes and 15 seconds (since 375 / 60 = 6 with a remainder of 15). We add the 6 minutes to our minutes total.\\n- New total: 50 hours, (425 + 6) minutes, 15 seconds -> 50 hours, 431 minutes, 15 seconds.\\n- Convert minutes to hours: 431 minutes is equal to 7 hours and 11 minutes (since 431 / 60 = 7 with a remainder of 11). We add the 7 hours to our hours total.\\n- New total: (50 + 7) hours, 11 minutes, 15 seconds -> 57 hours, 11 minutes, 15 seconds.\\n\\nThe final time is 57 hours, 11 minutes, and 15 seconds.",\n "H": 57,\n "M": 11,\n "S": 15\n}\n```'
12
  ],
13
  "references": [
14
  '{"answer": "352 BC"}',
15
  '{"unordered_list": ["London", "Berlin"]}',
16
  '{"unordered_list": ["London", "Berlin"]}',
17
- "{'answer': '2005-04-07'}",
18
- '{"H": 57.0, "M": 11.0, "S": 15.0}'
19
  ],
20
- "result": {"accuracy": 3/5},
21
  "per_item_accuracy": [True, True, False, False, True],
22
  }
23
 
24
- semantic_test_cases = {
25
- "predictions": [
26
- '{"explanation": First, we need to find the third occurrence of E33 being the R53 of E22. We can see that it happened from 1959 to 1962, then from 1967 to 1968, and then from 1982 to 1984. The third occurrence happened from 1982 to 1984. We can then compute the duration by subtracting the start time from the end time.", "answer": 2}',
27
- ' "To find the duration, we need to find the start and end time when E97 was the R71 of E67. From the given facts, we can see that E97 was the R71 of E67 from 1961 to 1961, and also from 1964 to 1964. We need to find the first occurrence, which is from 1961 to 1961.", "answer": 1}',
28
- '{"explanation": "To find when E92 stopped being the R88 of E11, we need to look at the temporal facts where E92 was the R88 of E11 and find the end time. We see that E92 was the R88 of E11 from 1982 to 1985, and there is no other fact that indicates E92 stopped being the R88 of E11 before 1985. However, we also see that E92 was the R17 of E42 from 1986 to 1992, and E92 was the R88 of E42 from 1977 to 1979, but this is irrelevant to the question. Therefore, E92 stopped being the R88 of E11 in 1985.", "answer": 1985}',
29
- ],
30
- "references": ["2", "0", "1985"],
31
- "result": {"accuracy": 1 / 3},
32
- "per_item_accuracy": [False, False, True],
33
- }
34
-
35
 
36
  def test_arithmetic_accuracy():
37
  metric = TestOfTimeAccuracy()
@@ -43,16 +32,6 @@ def test_arithmetic_accuracy():
43
  assert results == arithmetic_test_cases["result"]
44
 
45
 
46
- def test_semantic_accuracy():
47
- metric = TestOfTimeAccuracy()
48
- results = metric.compute(
49
- predictions=semantic_test_cases["predictions"],
50
- references=semantic_test_cases["references"],
51
- subset="semantic",
52
- )
53
- assert results == semantic_test_cases["result"]
54
-
55
-
56
  def test_per_item_arithmetic_accuracy():
57
  metric = TestOfTimeAccuracy()
58
  results = metric.compute(
@@ -64,17 +43,6 @@ def test_per_item_arithmetic_accuracy():
64
  assert results["accuracy"] == arithmetic_test_cases["per_item_accuracy"]
65
 
66
 
67
- def test_per_item_semantic_accuracy():
68
- metric = TestOfTimeAccuracy()
69
- results = metric.compute(
70
- predictions=semantic_test_cases["predictions"],
71
- references=semantic_test_cases["references"],
72
- subset="semantic",
73
- return_average=False,
74
- )
75
- assert results["accuracy"] == semantic_test_cases["per_item_accuracy"]
76
-
77
-
78
  def test_invalid_subset():
79
  metric = TestOfTimeAccuracy()
80
  with pytest.raises(ValueError):
 
8
  '```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "unordered_list": ["Berlin","London"]\n}\n```',
9
  '```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "malformed_unordered_list": ["Berlin","London"]\n}\n```',
10
  ' "To find the date of the second most important game, we need to subtract 7 days from the date of the most important game. We can do this by counting back 7 days from April 14, 2005. April 14 - 7 days = April 7, 2005", "answer": "2005-04-07"}',
11
+ '\n```json\n{\n "explanation": "Step 1: Determine the time it takes the robot to carry a single box. The robot takes 4 hours, 34 minutes, and 30 seconds to carry 2 boxes. We divide this time by 2 to find the time per box.\\n- Hours: 4 / 2 = 2 hours\\n- Minutes: 34 / 2 = 17 minutes\\n- Seconds: 30 / 2 = 15 seconds\\nSo, it takes the robot 2 hours, 17 minutes, and 15 seconds to carry one box.\\n\\nStep 2: Calculate the total time to carry 25 boxes. We multiply the time per box by the total number of boxes (25).\\n- Total Hours: 2 hours/box * 25 boxes = 50 hours\\n- Total Minutes: 17 minutes/box * 25 boxes = 425 minutes\\n- Total Seconds: 15 seconds/box * 25 boxes = 375 seconds\\n\\nStep 3: Convert the calculated time into the standard H:M:S format by carrying over excess seconds and minutes.\\n- Convert seconds to minutes: 375 seconds is equal to 6 minutes and 15 seconds (since 375 / 60 = 6 with a remainder of 15). We add the 6 minutes to our minutes total.\\n- New total: 50 hours, (425 + 6) minutes, 15 seconds -> 50 hours, 431 minutes, 15 seconds.\\n- Convert minutes to hours: 431 minutes is equal to 7 hours and 11 minutes (since 431 / 60 = 7 with a remainder of 11). We add the 7 hours to our hours total.\\n- New total: (50 + 7) hours, 11 minutes, 15 seconds -> 57 hours, 11 minutes, 15 seconds.\\n\\nThe final time is 57 hours, 11 minutes, and 15 seconds.",\n "H": 57,\n "M": 11,\n "S": 15\n}\n```',
12
  ],
13
  "references": [
14
  '{"answer": "352 BC"}',
15
  '{"unordered_list": ["London", "Berlin"]}',
16
  '{"unordered_list": ["London", "Berlin"]}',
17
+ '{"answer": "2005-04-07"}',
18
+ '{"H": 57.0, "M": 11.0, "S": 15.0}',
19
  ],
20
+ "result": {"accuracy": 3 / 5},
21
  "per_item_accuracy": [True, True, False, False, True],
22
  }
23
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def test_arithmetic_accuracy():
26
  metric = TestOfTimeAccuracy()
 
32
  assert results == arithmetic_test_cases["result"]
33
 
34
 
 
 
 
 
 
 
 
 
 
 
35
  def test_per_item_arithmetic_accuracy():
36
  metric = TestOfTimeAccuracy()
37
  results = metric.compute(
 
43
  assert results["accuracy"] == arithmetic_test_cases["per_item_accuracy"]
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
46
  def test_invalid_subset():
47
  metric = TestOfTimeAccuracy()
48
  with pytest.raises(ValueError):
tests/test_arithmetic_type_casting.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Statistics about expected data type per key
2
+ ##################################################################################################
3
+ # label_key label_value_type
4
+ # answer <class 'str'> 600
5
+ # <class 'int'> 300
6
+ # unordered_list <class 'list'> 250
7
+ # date <class 'str'> 150
8
+ # day, time <class 'str'>, <class 'str'> 123
9
+ # H, M, S <class 'float'>, <class 'float'>, <class 'float'> 100
10
+ # ordered_list <class 'list'> 100
11
+ # A, B, C <class 'int'>, <class 'int'>, <class 'int'> 50
12
+ # X, Y, Z <class 'float'>, <class 'float'>, <class 'float'> 50
13
+ # <class 'int'>, <class 'int'>, <class 'int'> 50
14
+ # hours, minutes, seconds <class 'int'>, <class 'int'>, <class 'int'> 50
15
+ # hours, minutes <class 'int'>, <class 'int'> 15
16
+ # days, hours, minutes, seconds <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'> 12
17
+ ##################################################################################################
18
+
19
+ from test_arithmetic_scoring import TestOfTimeAccuracy
20
+
21
+
22
+ def test_answer_type_casting():
23
+ references_answer_key = [
24
+ {"answer": "352 BC"},
25
+ {"answer": 1032},
26
+ ]
27
+ predictions_answer_key = [
28
+ {"answer": "352 BC"},
29
+ {"answer": "1032"},
30
+ ]
31
+ for ref, pred in zip(references_answer_key, predictions_answer_key):
32
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
33
+ assert ref == pred_cast
34
+
35
+
36
+ def test_unordered_list_type_casting():
37
+ references_unordered_list_key = [{"unordered_list": ["Kyle", "Jason", "Joe"]}]
38
+ predictions_unordered_list_key = [{"unordered_list": ["Kyle", "Jason", "Joe"]}]
39
+ for ref, pred in zip(references_unordered_list_key, predictions_unordered_list_key):
40
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
41
+ assert ref == pred_cast
42
+
43
+
44
+ def test_date_type_casting():
45
+ references_date_key = [{"date": "12/11/2011"}]
46
+ predictions_date_key = [{"date": "12/11/2011"}]
47
+ for ref, pred in zip(references_date_key, predictions_date_key):
48
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
49
+ assert ref == pred_cast
50
+
51
+
52
+ def test_day_time_type_casting():
53
+ references_day_time_keys = [{"day": "+2", "time": "21:44:10"}]
54
+ predictions_day_time_keys = [{"day": "+2", "time": "21:44:10"}]
55
+ for ref, pred in zip(references_day_time_keys, predictions_day_time_keys):
56
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
57
+ assert ref == pred_cast
58
+
59
+
60
+ def test_hms_type_casting():
61
+ references_hms_keys = [
62
+ {"H": 2.0, "M": 13.0, "S": 30.0},
63
+ {"H": 2.0, "M": 13.0, "S": 30.0},
64
+ {"H": 2.0, "M": 13.0, "S": 30.0},
65
+ {"H": 2.0, "M": 13.0, "S": 30.0},
66
+ ]
67
+ predictions_hms_keys = [
68
+ {"H": 2, "M": 13, "S": 30},
69
+ {"H": 2.0, "M": 13.0, "S": 30.0},
70
+ {"H": "2", "M": "13", "S": "30"},
71
+ {"H": "2.0", "M": "13.0", "S": "30.0"},
72
+ ]
73
+ for ref, pred in zip(references_hms_keys, predictions_hms_keys):
74
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
75
+ assert ref == pred_cast
76
+
77
+
78
+ def test_ordered_list_type_casting():
79
+ references_ordered_list_key = [
80
+ {"ordered_list": ["Joe", "Jenny", "Jason", "Dan", "Kyle"]},
81
+ ]
82
+ predictions_ordered_list_key = [
83
+ {"ordered_list": ["Joe", "Jenny", "Jason", "Dan", "Kyle"]},
84
+ ]
85
+ for ref, pred in zip(references_ordered_list_key, predictions_ordered_list_key):
86
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
87
+ assert ref == pred_cast
88
+
89
+ # TODO: Check if I should treat float strings differently, e.g., int(float("18.0"))
90
+ def test_abc_type_casting():
91
+ references_abc_keys = [
92
+ {"A": 80, "B": 22, "C": 20},
93
+ {"A": 80, "B": 22, "C": 20},
94
+ {"A": 80, "B": 22, "C": 20},
95
+ # {"A": 80, "B": 22, "C": 20},
96
+ ]
97
+ predictions_abc_keys = [
98
+ {"A": 80, "B": 22, "C": 20},
99
+ {"A": 80.0, "B": 22.0, "C": 20.0},
100
+ {"A": "80", "B": "22", "C": "20"},
101
+ # {"A": "80.0", "B": "22.0", "C": "20.0"},
102
+ ]
103
+ for ref, pred in zip(references_abc_keys, predictions_abc_keys):
104
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
105
+ assert ref == pred_cast
106
+
107
+
108
+ def test_xyz_type_casting():
109
+ references_xyz_keys = [
110
+ {"X": 2.0, "Y": 24.0, "Z": 44.0},
111
+ {"X": 2, "Y": 4, "Z": 36},
112
+ {"X": 2.0, "Y": 24.0, "Z": 44.0},
113
+ {"X": 2, "Y": 4, "Z": 36},
114
+ {"X": 2.0, "Y": 24.0, "Z": 44.0},
115
+ {"X": 2, "Y": 4, "Z": 36},
116
+ ]
117
+ predictions_xyz_keys = [
118
+ {"X": 2.0, "Y": 24.0, "Z": 44.0},
119
+ {"X": 2, "Y": 4, "Z": 36},
120
+ {"X": "2.0", "Y": "24.0", "Z": "44.0"},
121
+ {"X": "2", "Y": "4", "Z": "36"},
122
+ {"X": 2, "Y": 24, "Z": 44},
123
+ {"X": 2.0, "Y": 4.0, "Z": 36.0},
124
+ ]
125
+ for ref, pred in zip(references_xyz_keys, predictions_xyz_keys):
126
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
127
+ assert ref == pred_cast
128
+
129
+
130
+ def test_hours_minutes_seconds_type_casting():
131
+ references_hours_minutes_seconds_keys = [
132
+ {"hours": 17, "minutes": 48, "seconds": 51},
133
+ {"hours": 17, "minutes": 48, "seconds": 51},
134
+ {"hours": 17, "minutes": 48, "seconds": 51},
135
+ # {"hours": 17, "minutes": 48, "seconds": 51},
136
+ ]
137
+ predictions_hours_minutes_seconds_keys = [
138
+ {"hours": 17, "minutes": 48, "seconds": 51},
139
+ {"hours": 17.0, "minutes": 48.0, "seconds": 51.0},
140
+ {"hours": "17", "minutes": "48", "seconds": "51"},
141
+ # {"hours": "17.0", "minutes": "48.0", "seconds": "51.0"},
142
+ ]
143
+ for ref, pred in zip(
144
+ references_hours_minutes_seconds_keys, predictions_hours_minutes_seconds_keys
145
+ ):
146
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
147
+ assert ref == pred_cast
148
+
149
+
150
+ def test_hours_minutes_type_casting():
151
+ references_hours_minutes_keys = [
152
+ {"hours": 5, "minutes": 0},
153
+ {"hours": 5, "minutes": 0},
154
+ {"hours": 5, "minutes": 0},
155
+ # {"hours": 5, "minutes": 0},
156
+ ]
157
+ predictions_hours_minutes_keys = [
158
+ {"hours": 5, "minutes": 0},
159
+ {"hours": 5.0, "minutes": 0.0},
160
+ {"hours": "5", "minutes": "0"},
161
+ # {"hours": "5.0", "minutes": "0.0"},
162
+ ]
163
+ for ref, pred in zip(references_hours_minutes_keys, predictions_hours_minutes_keys):
164
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
165
+ assert ref == pred_cast
166
+
167
+
168
+ def test_days_hours_minutes_seconds_type_casting():
169
+ references_days_hours_minutes_seconds_keys = [
170
+ {"days": 1, "hours": 12, "minutes": 45, "seconds": 0},
171
+ {"days": 1, "hours": 12, "minutes": 45, "seconds": 0},
172
+ {"days": 1, "hours": 12, "minutes": 45, "seconds": 0},
173
+ # {"days": 1, "hours": 12, "minutes": 45, "seconds": 0},
174
+ ]
175
+ predictions_days_hours_minutes_seconds_keys = [
176
+ {"days": 1, "hours": 12, "minutes": 45, "seconds": 0},
177
+ {"days": 1.0, "hours": 12.0, "minutes": 45.0, "seconds": 0.0},
178
+ {"days": "1", "hours": "12", "minutes": "45", "seconds": "0"},
179
+ # {"days": "1.0", "hours": "12.0", "minutes": "45.0", "seconds": "0.0"},
180
+ ]
181
+ for ref, pred in zip(
182
+ references_days_hours_minutes_seconds_keys,
183
+ predictions_days_hours_minutes_seconds_keys,
184
+ ):
185
+ pred_cast = TestOfTimeAccuracy._cast_prediction(ref, pred)
186
+ assert ref == pred_cast
tests/test_semantic_scoring.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from test_of_time_accuracy import TestOfTimeAccuracy
2
+
3
+ semantic_test_cases = {
4
+ "predictions": [
5
+ '{"explanation": First, we need to find the third occurrence of E33 being the R53 of E22. We can see that it happened from 1959 to 1962, then from 1967 to 1968, and then from 1982 to 1984. The third occurrence happened from 1982 to 1984. We can then compute the duration by subtracting the start time from the end time.", "answer": 2}',
6
+ ' "To find the duration, we need to find the start and end time when E97 was the R71 of E67. From the given facts, we can see that E97 was the R71 of E67 from 1961 to 1961, and also from 1964 to 1964. We need to find the first occurrence, which is from 1961 to 1961.", "answer": 1}',
7
+ '{"explanation": "To find when E92 stopped being the R88 of E11, we need to look at the temporal facts where E92 was the R88 of E11 and find the end time. We see that E92 was the R88 of E11 from 1982 to 1985, and there is no other fact that indicates E92 stopped being the R88 of E11 before 1985. However, we also see that E92 was the R17 of E42 from 1986 to 1992, and E92 was the R88 of E42 from 1977 to 1979, but this is irrelevant to the question. Therefore, E92 stopped being the R88 of E11 in 1985.", "answer": 1985}',
8
+ ],
9
+ "references": ["2", "0", "1985"],
10
+ "result": {"accuracy": 1 / 3},
11
+ "per_item_accuracy": [False, False, True],
12
+ }
13
+
14
+
15
+ def test_per_item_semantic_accuracy():
16
+ metric = TestOfTimeAccuracy()
17
+ results = metric.compute(
18
+ predictions=semantic_test_cases["predictions"],
19
+ references=semantic_test_cases["references"],
20
+ subset="semantic",
21
+ return_average=False,
22
+ )
23
+ assert results["accuracy"] == semantic_test_cases["per_item_accuracy"]
24
+
25
+
26
+ def test_semantic_accuracy():
27
+ metric = TestOfTimeAccuracy()
28
+ results = metric.compute(
29
+ predictions=semantic_test_cases["predictions"],
30
+ references=semantic_test_cases["references"],
31
+ subset="semantic",
32
+ )
33
+ assert results == semantic_test_cases["result"]