MMLU Pro Evaluator #41860

Open · wants to merge 1 commit into base: main
@@ -13,6 +13,7 @@
    ViolenceEvaluator,
)
from ._evaluators._f1_score import F1ScoreEvaluator
from ._evaluators._mmlu import MMLUEvaluator
from ._evaluators._fluency import FluencyEvaluator
from ._evaluators._gleu import GleuScoreEvaluator
from ._evaluators._groundedness import GroundednessEvaluator
@@ -82,6 +83,7 @@
"ContentSafetyEvaluator",
"IndirectAttackEvaluator",
"BleuScoreEvaluator",
"MMLUEvaluator",
"GleuScoreEvaluator",
"MeteorScoreEvaluator",
"RetrievalEvaluator",
@@ -13,6 +13,7 @@
from azure.ai.evaluation._evaluators._eci._eci import ECIEvaluator
from azure.ai.evaluation import (
    BleuScoreEvaluator,
    MMLUEvaluator,
    CodeVulnerabilityEvaluator,
    CoherenceEvaluator,
    ContentSafetyEvaluator,
@@ -43,6 +44,7 @@

EVAL_CLASS_MAP = {
    BleuScoreEvaluator: "bleu_score",
    MMLUEvaluator: "mmlu_score",
    CodeVulnerabilityEvaluator: "code_vulnerability",
    CoherenceEvaluator: "coherence",
    ContentSafetyEvaluator: "content_safety",
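
As a side note on the registration above: the new entry ties the evaluator class to the metric name used for it elsewhere in the codebase. How EVAL_CLASS_MAP is consumed is not shown in this diff, but the mapping itself reduces to the lookup below (illustrative snippet, not part of the change).

# Not part of the diff; only illustrates the EVAL_CLASS_MAP entry added above.
assert EVAL_CLASS_MAP[MMLUEvaluator] == "mmlu_score"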
@@ -6,10 +6,12 @@
from ._base_prompty_eval import PromptyEvaluatorBase
from ._base_rai_svc_eval import RaiServiceEvaluatorBase
from ._base_multi_eval import MultiEvaluatorBase
from ._base_regex_eval import RegexEvaluatorBase

__all__ = [
"EvaluatorBase",
"PromptyEvaluatorBase",
"RaiServiceEvaluatorBase",
"MultiEvaluatorBase",
"RegexEvaluatorBase",
]
@@ -0,0 +1,183 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Any, Dict, List, Optional
from typing_extensions import override
from abc import abstractmethod
import re
import numpy as np

from azure.ai.evaluation._evaluators._common import EvaluatorBase


class RegexEvaluatorBase(EvaluatorBase):
    """Base class for all evaluators that are regex-based and use pattern matching to extract answers.

    This class provides a framework for evaluators that need to extract structured answers from text
    using regular expressions. It handles the common pattern of:
    1. Using regex patterns to extract answers from predictions
    2. Comparing extracted answers against expected answers
    3. Computing accuracy and instruction-following metrics

    Child classes must implement the abstract methods:
    - extract_expected_answer: Extract the expected answer from the label
    - extract_regex: Extract a match object from the prediction using regex patterns
    - compute_match: Compare actual and extracted answers to determine correctness

    :param regex_patterns: A list of regex patterns to use for extracting answers from predictions.
        Each pattern should have a single capture group to extract the answer. If None, the child class
        must implement the get_regex_patterns method.
    :type regex_patterns: Optional[List[str]]
    """

    def __init__(self, *, regex_patterns: Optional[List[str]] = None, threshold: float = 0.5) -> None:
        super().__init__(threshold=threshold, _higher_is_better=True)
        self.regex_patterns = regex_patterns
        self.is_missing_regex_patterns = regex_patterns is None
        self.follow_instructions = []
        self.scores = []
        self.chain_of_thought_lengths = []
Comment on lines +36 to +40
Contributor:
Do you want to expose any of them to external customers? If not, please rename them with a leading underscore, for example self.is_missing_regex_patterns -> self._is_missing_regex_patterns.


    @override
    async def _do_eval(self, eval_input: Dict) -> Dict[str, Any]:
        """Produce a score evaluation result.

        :param eval_input: The input to the evaluation function.
        :type eval_input: Dict
        :return: The evaluation result.
        :rtype: Dict
        """
        response = eval_input["response"]
        ground_truth = eval_input["ground_truth"]

        result = self.update(
            prediction=response,
            label=ground_truth,
            json_data={}
        )

        return result

    def update(self, prediction: str, label: str, json_data: dict) -> Dict[str, Any]:
Contributor:
Please start the method name with an underscore if it is not to be exposed to customers.

        if self.is_missing_regex_patterns:
            self.regex_patterns = self.get_regex_patterns(prediction, label, json_data)
        expected_answer = self.extract_expected_answer(label, json_data)
        regex_match = self.extract_regex(prediction, label, json_data)

        if regex_match:
            extracted_answer = regex_match.group(1).strip()
            follow_instruction = 1
            chain_of_thought_length = self._get_chain_of_thought_length(
                prediction, regex_match.start()
            )
            self.chain_of_thought_lengths.append(chain_of_thought_length)
        else:
            extracted_answer = ""
            follow_instruction = 0
            chain_of_thought_length = -1

        score = self.compute_match(expected_answer, extracted_answer, label, json_data)
        self.scores.append(score)
        self.follow_instructions.append(follow_instruction)

        return {
            "score": score,
            "follow_instruction": follow_instruction,
            "extracted_answer": extracted_answer,
            "chain_of_thought_length": chain_of_thought_length,
        }

    def __aggregate__(self, line_results: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Base aggregation method for the RegexEvaluatorBase.
        This method aggregates the results of the metric across multiple lines of input data.
        Raises a ValueError if the input list is empty; if no line followed the instructions,
        chain_of_thought_length is reported as -1.
        """
        if not line_results:
            raise ValueError("Empty line_results passed to __aggregate__.")

        # collect individual metric values
        scores = [r["score"] for r in line_results]
        follow_flags = [r["follow_instruction"] for r in line_results]

        # only include chain lengths where the instruction was followed (non-negative)
        chain_lengths = [
            r.get("chain_of_thought_length", -1)
            for r in line_results
            if r.get("chain_of_thought_length", -1) >= 0
        ]

        # compute aggregate metrics
        accuracy = np.mean(scores)
        follow_instruction_rate = np.mean(follow_flags)
        chain_of_thought_length = np.mean(chain_lengths) if chain_lengths else -1

        return {
            "accuracy": accuracy,
            "follow_instruction_rate": follow_instruction_rate,
            "chain_of_thought_length": chain_of_thought_length,
        }
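
For example (illustrative numbers), given three per-line results with scores 1, 0, and 1 where every prediction matched a pattern at character offsets 12, 40, and 8, this aggregation would report accuracy 2/3 ≈ 0.67, follow_instruction_rate 1.0, and chain_of_thought_length (12 + 40 + 8) / 3 = 20.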

    def get_regex_patterns(
        self, prediction: str, label: str, json_data: dict
    ) -> List[str]:
        """
        Implement this method to return the regex patterns if you do not set them in the constructor.
        Regex patterns must have a single capture group to extract the answer.
        """
        raise NotImplementedError(
            "Regex patterns should be set in the constructor or implemented in this method."
        )

Contributor:
Please start the method name with an underscore if it is not to be exposed to customers.

    @abstractmethod
    def extract_expected_answer(self, label: str, json_data: dict) -> str:
        """
        Abstract method to extract the expected answer from the label.

        Returns:
            str: The expected answer.
        """
        pass

    @abstractmethod
    def extract_regex(
        self, prediction: str, label: str, json_data: Dict
    ) -> Optional[re.Match]:
        """
        Abstract method to extract a match object from the prediction string based on the provided regex patterns.

        Returns:
            Optional[re.Match]: The extracted match object or None.
        """
        pass

    @abstractmethod
    def compute_match(
        self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict
    ) -> int:
        """
        Abstract method to compare the actual answer to the extracted answer.

        Returns:
            int: 1 if the answers match, 0 otherwise.
        """
        pass

    def _extract_regex(
        self, prediction: str, label: str, json_data: Dict
    ) -> Optional[re.Match]:
        if self.regex_patterns:
            for regex_pattern in self.regex_patterns:
                match = re.search(regex_pattern, prediction)
                if match:
                    return match
        return None

    def _compute_match(
        self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict
    ) -> int:
        return 1 if actual_answer == extracted_answer else 0

    def _get_chain_of_thought_length(self, prediction: str, match_index: int) -> int:
        return len(prediction[:match_index])
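
To make the subclassing contract above concrete, here is a minimal sketch (not part of this PR) of a hypothetical letter-choice evaluator built on RegexEvaluatorBase. The class name, regex pattern, and label format are assumptions for illustration, and the default _extract_regex and _compute_match helpers are simply reused.

# Illustrative sketch only, not part of this diff. Assumes the internal import path
# exposed above and a hypothetical "Answer: <letter>" output format.
import re
from typing import Dict, Optional

from azure.ai.evaluation._evaluators._common import RegexEvaluatorBase


class LetterChoiceEvaluator(RegexEvaluatorBase):
    def __init__(self) -> None:
        # Single capture group holds the chosen letter, e.g. "Answer: C".
        super().__init__(regex_patterns=[r"Answer:\s*([A-J])"])

    def extract_expected_answer(self, label: str, json_data: dict) -> str:
        # The label is assumed to already be the expected letter.
        return label.strip()

    def extract_regex(self, prediction: str, label: str, json_data: Dict) -> Optional[re.Match]:
        # Reuse the default helper, which tries each configured pattern in order.
        return self._extract_regex(prediction, label, json_data)

    def compute_match(self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict) -> int:
        # Exact string comparison between expected and extracted answers.
        return self._compute_match(actual_answer, extracted_answer, label, json_data)


evaluator = LetterChoiceEvaluator()
result = evaluator.update("Let me think... Answer: C", "C", {})
# result: score=1, follow_instruction=1, extracted_answer="C", chain_of_thought_length=16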
@@ -0,0 +1,9 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from ._mmlu import MMLUEvaluator

__all__ = [
"MMLUEvaluator",
]
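
The _mmlu evaluator implementation itself is not included in this view, so its exact regex patterns and constructor arguments are unknown here. Assuming MMLUEvaluator follows the calling convention of the other built-in evaluators and reads the response and ground_truth inputs used by RegexEvaluatorBase._do_eval, usage would plausibly look like the sketch below; the keyword names and output fields are assumptions.

# Hedged usage sketch; the MMLUEvaluator constructor signature is not shown in this diff.
from azure.ai.evaluation import MMLUEvaluator

mmlu = MMLUEvaluator()
result = mmlu(
    response="The mitral valve has two cusps. Answer: B",
    ground_truth="B",
)
print(result)  # expected to contain fields such as score, extracted_answer, follow_instruction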