From d0cec5a47cffffa072b59558a23ff822d3e19058 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Mon, 7 Jul 2025 11:47:05 +0900 Subject: [PATCH 1/3] Add AI2ARC dataset --- dspy/datasets/__init__.py | 2 + dspy/datasets/ai2_arc.py | 151 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 dspy/datasets/ai2_arc.py diff --git a/dspy/datasets/__init__.py b/dspy/datasets/__init__.py index 731284aa5f..543a807b0c 100644 --- a/dspy/datasets/__init__.py +++ b/dspy/datasets/__init__.py @@ -1,3 +1,4 @@ +from dspy.datasets.ai2_arc import AI2ARC from dspy.datasets.alfworld import AlfWorld from dspy.datasets.colors import Colors from dspy.datasets.dataloader import DataLoader @@ -6,6 +7,7 @@ from dspy.datasets.math import MATH __all__ = [ + "AI2ARC", "Colors", "DataLoader", "Dataset", diff --git a/dspy/datasets/ai2_arc.py b/dspy/datasets/ai2_arc.py new file mode 100644 index 0000000000..b0d522bbc1 --- /dev/null +++ b/dspy/datasets/ai2_arc.py @@ -0,0 +1,151 @@ +import re +from typing import Literal + +import tqdm + +import dspy + + +class AI2ARC: + """AI2 Reasoning Challenge (ARC) Dataset. + + The ARC dataset contains multiple-choice science questions at a grade-school level. + It consists of two subsets: ARC-Challenge (harder questions) and ARC-Easy (easier questions). + + Args: + subset: Either "challenge" or "easy" to specify which subset to load + """ + + def __init__(self, subset: Literal["challenge", "easy"] = "challenge"): + if subset not in ["challenge", "easy"]: + raise ValueError("subset must be either 'challenge' or 'easy'") + + self.subset = subset + self.do_shuffle = False + + from datasets import load_dataset + + dataset_name = "allenai/ai2_arc" + config_name = f"ARC-{subset.title()}" + + try: + hf_dataset = load_dataset(dataset_name, config_name) + except Exception as e: + raise RuntimeError(f"Failed to load {config_name} from {dataset_name}: {e}") + + official_train = self._process_split(hf_dataset["train"]) if "train" in hf_dataset else [] + official_dev = self._process_split(hf_dataset["validation"]) if "validation" in hf_dataset else [] + official_test = self._process_split(hf_dataset["test"]) if "test" in hf_dataset else [] + + self.train = [dspy.Example(**x).with_inputs("question", "choices") for x in official_train] + self.dev = [dspy.Example(**x).with_inputs("question", "choices") for x in official_dev] + self.test = [dspy.Example(**x).with_inputs("question", "choices") for x in official_test] + + def _process_split(self, split_data): + """Process a data split and convert to DSPy format.""" + processed_data = [] + + for example in tqdm.tqdm(split_data, desc=f"Processing {self.subset} split"): + choices_text = [] + choice_labels = example["choices"]["label"] + choice_texts = example["choices"]["text"] + + for label, text in zip(choice_labels, choice_texts, strict=False): + choices_text.append(f"({label}) {text}") + + processed_example = { + "id": example["id"], + "question": example["question"], + "choices": "\n".join(choices_text), + "choices_list": choice_texts, + "choice_labels": choice_labels, + "answer": example["answerKey"], + "answer_text": self._get_answer_text(example) + } + + processed_data.append(processed_example) + + return processed_data + + def _get_answer_text(self, example): + """Extract the answer text corresponding to the correct answer key.""" + answer_key = example["answerKey"] + choice_labels = example["choices"]["label"] + choice_texts = example["choices"]["text"] + + try: + answer_index = choice_labels.index(answer_key) + return choice_texts[answer_index] + except ValueError: + # If answer key not found in labels, return the key itself + return answer_key + + +def ai2_arc_metric(gold, pred, trace=None): + """Metric function for AI2 ARC dataset. + + Args: + gold: Gold example with 'answer' field + pred: Predicted example with 'answer' field + trace: Optional trace (unused) + + Returns: + bool: True if prediction matches gold answer + """ + pred_answer = str(pred.answer).strip().upper() + gold_answer = str(gold.answer).strip().upper() + + if len(pred_answer) == 1 and pred_answer in ["A", "B", "C", "D"]: + return pred_answer == gold_answer + + for letter in ["A", "B", "C", "D"]: + if f"({letter})" in pred_answer or f"{letter})" in pred_answer: + return letter == gold_answer + + if hasattr(gold, "answer_text"): + return pred_answer.lower() in gold.answer_text.lower() + + return False + + +def parse_arc_answer(answer_text): + """Parse answer from model output to extract the letter choice. + + Args: + answer_text: Raw model output + + Returns: + str: Extracted answer letter (A, B, C, or D), or the original text if not found + """ + answer_text = str(answer_text).strip() + + patterns = [ + r"\(([ABCD])\)", # (A), (B), etc. + r"([ABCD])\)", # A), B), etc. + r"^([ABCD])$", # Just A, B, C, D at start of line + r"answer is ([ABCD])", # "answer is A" + r"choice ([ABCD])", # "choice A" + ] + + for pattern in patterns: + match = re.search(pattern, answer_text.upper()) + if match: + return match.group(1) + + # If no pattern found, return the first letter if it's A, B, C, or D + first_char = answer_text.upper()[0] if answer_text else "" + if first_char in ["A", "B", "C", "D"]: + return first_char + + return answer_text + + +# Convenience functions for loading specific subsets +def ARC_Challenge(): + """Load the ARC-Challenge subset.""" + return AI2ARC(subset="challenge") + + +def ARC_Easy(): + """Load the ARC-Easy subset.""" + return AI2ARC(subset="easy") From c44de0c63f69b0fecf35d0f2390c2874b6a91bde Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Mon, 7 Jul 2025 13:23:02 +0900 Subject: [PATCH 2/3] address comment --- dspy/datasets/ai2_arc.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dspy/datasets/ai2_arc.py b/dspy/datasets/ai2_arc.py index b0d522bbc1..b4f1d6bf58 100644 --- a/dspy/datasets/ai2_arc.py +++ b/dspy/datasets/ai2_arc.py @@ -21,7 +21,6 @@ def __init__(self, subset: Literal["challenge", "easy"] = "challenge"): raise ValueError("subset must be either 'challenge' or 'easy'") self.subset = subset - self.do_shuffle = False from datasets import load_dataset @@ -102,9 +101,6 @@ def ai2_arc_metric(gold, pred, trace=None): if f"({letter})" in pred_answer or f"{letter})" in pred_answer: return letter == gold_answer - if hasattr(gold, "answer_text"): - return pred_answer.lower() in gold.answer_text.lower() - return False From 19cb652e2087b63f814fcbfc781b56284bf7f8c9 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Tue, 8 Jul 2025 15:52:59 +0900 Subject: [PATCH 3/3] address comments --- dspy/datasets/ai2_arc.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/dspy/datasets/ai2_arc.py b/dspy/datasets/ai2_arc.py index b4f1d6bf58..39fb347d73 100644 --- a/dspy/datasets/ai2_arc.py +++ b/dspy/datasets/ai2_arc.py @@ -32,9 +32,9 @@ def __init__(self, subset: Literal["challenge", "easy"] = "challenge"): except Exception as e: raise RuntimeError(f"Failed to load {config_name} from {dataset_name}: {e}") - official_train = self._process_split(hf_dataset["train"]) if "train" in hf_dataset else [] - official_dev = self._process_split(hf_dataset["validation"]) if "validation" in hf_dataset else [] - official_test = self._process_split(hf_dataset["test"]) if "test" in hf_dataset else [] + official_train = self._process_split(hf_dataset["train"]) + official_dev = self._process_split(hf_dataset["validation"]) + official_test = self._process_split(hf_dataset["test"]) self.train = [dspy.Example(**x).with_inputs("question", "choices") for x in official_train] self.dev = [dspy.Example(**x).with_inputs("question", "choices") for x in official_dev] @@ -91,8 +91,8 @@ def ai2_arc_metric(gold, pred, trace=None): Returns: bool: True if prediction matches gold answer """ - pred_answer = str(pred.answer).strip().upper() - gold_answer = str(gold.answer).strip().upper() + pred_answer = _parse_arc_answer(pred.answer) + gold_answer = _parse_arc_answer(gold.answer) if len(pred_answer) == 1 and pred_answer in ["A", "B", "C", "D"]: return pred_answer == gold_answer @@ -104,7 +104,7 @@ def ai2_arc_metric(gold, pred, trace=None): return False -def parse_arc_answer(answer_text): +def _parse_arc_answer(answer_text): """Parse answer from model output to extract the letter choice. Args: @@ -113,7 +113,7 @@ def parse_arc_answer(answer_text): Returns: str: Extracted answer letter (A, B, C, or D), or the original text if not found """ - answer_text = str(answer_text).strip() + answer_text = str(answer_text).strip().upper() patterns = [ r"\(([ABCD])\)", # (A), (B), etc. @@ -134,14 +134,3 @@ def parse_arc_answer(answer_text): return first_char return answer_text - - -# Convenience functions for loading specific subsets -def ARC_Challenge(): - """Load the ARC-Challenge subset.""" - return AI2ARC(subset="challenge") - - -def ARC_Easy(): - """Load the ARC-Easy subset.""" - return AI2ARC(subset="easy")