diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py index ee88bff25a39..e387ff733950 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py @@ -13,6 +13,7 @@ ViolenceEvaluator, ) from ._evaluators._f1_score import F1ScoreEvaluator +from ._evaluators._mmlu import MMLUEvaluator from ._evaluators._fluency import FluencyEvaluator from ._evaluators._gleu import GleuScoreEvaluator from ._evaluators._groundedness import GroundednessEvaluator @@ -82,6 +83,7 @@ "ContentSafetyEvaluator", "IndirectAttackEvaluator", "BleuScoreEvaluator", + "MMLUEvaluator", "GleuScoreEvaluator", "MeteorScoreEvaluator", "RetrievalEvaluator", diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_eval_mapping.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_eval_mapping.py index 6a4690ccf4eb..0987814148a6 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_eval_mapping.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_eval_mapping.py @@ -13,6 +13,7 @@ from azure.ai.evaluation._evaluators._eci._eci import ECIEvaluator from azure.ai.evaluation import ( BleuScoreEvaluator, + MMLUEvaluator, CodeVulnerabilityEvaluator, CoherenceEvaluator, ContentSafetyEvaluator, @@ -43,6 +44,7 @@ EVAL_CLASS_MAP = { BleuScoreEvaluator: "bleu_score", + MMLUEvaluator: "mmlu_score", CodeVulnerabilityEvaluator: "code_vulnerability", CoherenceEvaluator: "coherence", ContentSafetyEvaluator: "content_safety", diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/__init__.py index e883883e21d4..9b4d41b0870e 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/__init__.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/__init__.py @@ -6,10 +6,12 @@ from ._base_prompty_eval import PromptyEvaluatorBase from ._base_rai_svc_eval import RaiServiceEvaluatorBase from ._base_multi_eval import MultiEvaluatorBase +from ._base_regex_eval import RegexEvaluatorBase __all__ = [ "EvaluatorBase", "PromptyEvaluatorBase", "RaiServiceEvaluatorBase", "MultiEvaluatorBase", + "RegexEvaluatorBase", ] diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_regex_eval.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_regex_eval.py new file mode 100644 index 000000000000..9c623877fac5 --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_regex_eval.py @@ -0,0 +1,183 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Any, Dict, List, Optional +from typing_extensions import override +from abc import abstractmethod +import re +import numpy as np + +from azure.ai.evaluation._evaluators._common import EvaluatorBase + + +class RegexEvaluatorBase(EvaluatorBase): + """Base class for all evaluators that are regex-based and use pattern matching to extract answers. + + This class provides a framework for evaluators that need to extract structured answers from text + using regular expressions. It handles the common pattern of: + 1. 
Using regex patterns to extract answers from predictions + 2. Comparing extracted answers against expected answers + 3. Computing accuracy and instruction-following metrics + + Child classes must implement the abstract methods: + - extract_expected_answer: Extract the expected answer from the label + - extract_regex: Extract a match object from the prediction using regex patterns + - compute_match: Compare actual and extracted answers to determine correctness + + :param regex_patterns: A list of regex patterns to use for extracting answers from predictions. + Each pattern should have a single capture group to extract the answer. If None, the child class + must implement get_regex_patterns method. + :type regex_patterns: Optional[List[str]] + """ + + def __init__(self, *, regex_patterns: Optional[list[str]] = None, threshold=0.5) -> None: + super().__init__(threshold=threshold, _higher_is_better=True) + self.regex_patterns = regex_patterns + self.is_missing_regex_patterns = regex_patterns is None + self.follow_instructions = [] + self.scores = [] + self.chain_of_thought_lengths = [] + + @override + async def _do_eval(self, eval_input: Dict) -> Dict[str, Any]: + """Produce a score evaluation result. + + :param eval_input: The input to the evaluation function. + :type eval_input: Dict + :return: The evaluation result. + :rtype: Dict + """ + response = eval_input["response"] + ground_truth = eval_input["ground_truth"] + + result = self.update( + prediction=response, + label=ground_truth, + json_data={} + ) + + return result + + def update(self, prediction: str, label: str, json_data: dict) -> Dict[str, Any]: + if self.is_missing_regex_patterns: + self.regex_patterns = self.get_regex_patterns(prediction, label, json_data) + expected_answer = self.extract_expected_answer(label, json_data) + regex_match = self.extract_regex(prediction, label, json_data) + + if regex_match: + extracted_answer = regex_match.group(1).strip() + follow_instruction = 1 + chain_of_thought_length = self._get_chain_of_thought_length( + prediction, regex_match.start() + ) + self.chain_of_thought_lengths.append(chain_of_thought_length) + else: + extracted_answer = "" + follow_instruction = 0 + chain_of_thought_length = -1 + + score = self.compute_match(expected_answer, extracted_answer, label, json_data) + self.scores.append(score) + self.follow_instructions.append(follow_instruction) + + return { + "score": score, + "follow_instruction": follow_instruction, + "extracted_answer": extracted_answer, + "chain_of_thought_length": chain_of_thought_length, + } + + def __aggregate__(self, line_results: List[str]) -> Dict[str, float]: + """ + Base aggregation method for the RegexEvaluatorBase. + This method aggregates the results of the metric across multiple lines of input data. + Throws an exception if the input list or chain lengths is empty. 
+ """ + if not line_results: + raise ValueError("line_results is empty passed to __aggregate__") + + # collect individual metric values + scores = [r["score"] for r in line_results] + follow_flags = [r["follow_instruction"] for r in line_results] + + # only include chain lengths where instruction was followed (non-negative) + chain_lengths = [ + r.get("chain_of_thought_length", -1) + for r in line_results + if r.get("chain_of_thought_length", -1) >= 0 + ] + + # compute aggregate metrics + accuracy = np.mean(scores) + follow_instruction_rate = np.mean(follow_flags) + chain_of_thought_length = np.mean(chain_lengths) if chain_lengths else -1 + + return { + "accuracy": accuracy, + "follow_instruction_rate": follow_instruction_rate, + "chain_of_thought_length": chain_of_thought_length, + } + + def get_regex_patterns( + self, prediction: str, label: str, json_data: dict + ) -> List[str]: + """ + Implement this method to get the regex patterns if you do not set them in the constructor. + Regex patterns must have a single group to extract the answer. + """ + raise NotImplementedError( + "Regex pattern should be set in the constructor or implemented in this method." + ) + + @abstractmethod + def extract_expected_answer(self, label: str, json_data: dict) -> str: + """ + Abstract method to extract the expected answer from the label. + + Returns: + str: The expected answer. + """ + pass + + @abstractmethod + def extract_regex( + self, prediction: str, label: str, json_data: Dict + ) -> Optional[re.Match]: + """ + Abstract method to extract a match object from the prediction string based on the provided regex patterns. + + Returns: + Optional[re.Match]: The extracted match object or None. + """ + pass + + @abstractmethod + def compute_match( + self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict + ) -> int: + """ + Abstract method to compare the actual answer to the extracted answer. + + Returns: + int: 1 if the answers match, 0 otherwise. + """ + pass + + def _extract_regex( + self, prediction: str, label: str, json_data: Dict + ) -> Optional[re.Match]: + if self.regex_patterns: + for regex_pattern in self.regex_patterns: + match = re.search(regex_pattern, prediction) + if match: + return match + return None + + def _compute_match( + self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict + ) -> int: + return 1 if actual_answer == extracted_answer else 0 + + def _get_chain_of_thought_length(self, prediction: str, match_index: int) -> int: + return len(prediction[:match_index]) \ No newline at end of file diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_mmlu/__init__.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_mmlu/__init__.py new file mode 100644 index 000000000000..aca7aee6e78b --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_mmlu/__init__.py @@ -0,0 +1,9 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from ._mmlu import MMLUEvaluator + +__all__ = [ + "MMLUEvaluator", +] diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_mmlu/_mmlu.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_mmlu/_mmlu.py new file mode 100644 index 000000000000..29ff8cb5f757 --- /dev/null +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_mmlu/_mmlu.py @@ -0,0 +1,178 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +from collections import defaultdict +from typing import Any, Dict, List + +import numpy as np +from azure.ai.evaluation._evaluators._common import RegexEvaluatorBase +from typing_extensions import overload, override + +ANSWER_PATTERNS = [r"(?i)ANSWER\s*:\s*\$?([A-J])\$?"] + + +class MMLUEvaluator(RegexEvaluatorBase): + """ + Evaluates model performance on the MMLU (Massive Multitask Language Understanding) benchmark. + + MMLU is a comprehensive benchmark that tests a model's understanding across 57 academic subjects + ranging from mathematics and physics to history and law. It measures both breadth and depth of + knowledge through multiple-choice questions. + + The evaluator expects answers in the format "ANSWER: A" (or B, C, D, etc.) and computes accuracy + metrics both overall and grouped by subject and category. Use this evaluator when you want to + assess a model's general knowledge and reasoning abilities across diverse academic domains. + + The MMLU score value is either 0 or 1, with higher scores indicating better performance. + :param threshold: The threshold for the evaluation. Default is 0.5. + :type threshold: float + + .. admonition:: Example: + + .. literalinclude:: ../samples/evaluation_samples_evaluate.py + :start-after: [START mmlu_score_evaluator] + :end-before: [END mmlu_score_evaluator] + :language: python + :dedent: 8 + :caption: Initialize and call an MMLUEvaluator. + + .. admonition:: Example using Azure AI Project URL: + + .. literalinclude:: ../samples/evaluation_samples_evaluate_fdp.py + :start-after: [START mmlu_score_evaluator] + :end-before: [END mmlu_score_evaluator] + :language: python + :dedent: 8 + :caption: Initialize and call an MMLUEvaluator using Azure AI Project URL in following format + https://{resource_name}.services.ai.azure.com/api/projects/{project_name} + + .. admonition:: Example with Threshold: + + .. literalinclude:: ../samples/evaluation_samples_threshold.py + :start-after: [START threshold_mmlu_score_evaluator] + :end-before: [END threshold_mmlu_score_evaluator] + :language: python + :dedent: 8 + :caption: Initialize with threshold and call an MMLUEvaluator. 
+ """ + + def __init__(self, *, threshold=0.5): + super().__init__(regex_patterns=ANSWER_PATTERNS, threshold=threshold) + self.subject2scores = defaultdict(list) + self.category2scores = defaultdict(list) + + def update(self, prediction: str, label: str, json_data: dict) -> Dict[str, Any]: + sample_metrics = super().update(prediction, label, json_data) + try: + label_dict = json.loads(label) + except json.JSONDecodeError: + raise ValueError("The label/ground_truth must be a valid JSON string.") + + subject = label_dict.get("subject") + if subject is None: + raise ValueError("The label/ground_truth JSON must contain a 'subject' key.") + + category = label_dict.get("category") + if category is None: + raise ValueError("The label/ground_truth JSON must contain a 'category' key.") + + if label_dict.get("answer") is None: + raise ValueError("The label/ground_truth JSON must contain an 'answer' key.") + + self.subject2scores[subject].append(sample_metrics["score"]) + self.category2scores[category].append(sample_metrics["score"]) + + sample_metrics.update( + { + "mmlu_score": sample_metrics["score"], # Just needed for _real_call + "accuracy": sample_metrics["score"], + "subject": subject, + "category": category, + } + ) + + return sample_metrics + + def __aggregate__(self, line_results: List[str]) -> Dict[str, float]: + """Aggregate the results from the line results.""" + base_metrics = super().__aggregate__(line_results) + + # compute macro accuracy by subject + subject2scores: Dict[Any, List[float]] = defaultdict(list) + for r in line_results: + if "subject" not in r: + raise KeyError(f"Missing key 'subject' in line result: {r}") + subject2scores[r["subject"]].append(r["score"]) + accuracy_macro_by_subject = ( + np.mean([np.mean(v) for v in subject2scores.values()]) + if subject2scores + else float("nan") + ) + + # compute macro accuracy by category + category2scores: Dict[Any, List[float]] = defaultdict(list) + for r in line_results: + if "category" not in r: + raise KeyError(f"Missing key 'category' in line result: {r}") + category2scores[r["category"]].append(r["score"]) + accuracy_macro_by_category = ( + np.mean([np.mean(v) for v in category2scores.values()]) + if category2scores + else float("nan") + ) + base_metrics.update( + { + "accuracy_macro_by_subject": accuracy_macro_by_subject, + "accuracy_macro_by_category": accuracy_macro_by_category, + } + ) + + return base_metrics + + def extract_expected_answer(self, label: str, json_data: dict) -> str: + return json.loads(label).get("answer") + + def extract_regex(self, prediction: str, label: str, json_data: Dict) -> Any: + return self._extract_regex(prediction, label, json_data) + + def compute_match( + self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict + ) -> int: + return self._compute_match(actual_answer, extracted_answer, label, json_data) + + @overload # type: ignore + def __call__(self, *, response: str, ground_truth: str): + """ + Evaluate the MMLU score between the response and the ground truth. + + :keyword response: The response to be evaluated. + :paramtype response: str + The ground truth must be in serialized JSON string format containing the following fields: + - subject: The subject area of the question + - category: The category classification + - answer: The correct answer + :keyword ground_truth: The ground truth to be compared against in serialized JSON string format. + :paramtype ground_truth: str + :return: The MMLU score. 
+ :rtype: Dict[str, Any] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, + **kwargs, + ): + """ + Evaluate the MMLU score between the response and the ground truth. + + :keyword response: The response to be evaluated. + :paramtype response: str + :keyword ground_truth: The ground truth to be compared against. + :paramtype ground_truth: str + :return: The MMLU score. + :rtype: Dict[str, Any] + """ + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/migration_guide.md b/sdk/evaluation/azure-ai-evaluation/migration_guide.md index 90e9ca2877ae..794b338c17f7 100644 --- a/sdk/evaluation/azure-ai-evaluation/migration_guide.md +++ b/sdk/evaluation/azure-ai-evaluation/migration_guide.md @@ -15,7 +15,7 @@ Following Built-in evaluators are provided in new Azure AI Evaluation SDK ([azur | Category | Evaluator class| |-----------------------------|------------------------------------------| | [Performance and quality][performance_and_quality_evaluators] (AI-assisted) | `GroundednessEvaluator`, `RelevanceEvaluator`, `CoherenceEvaluator`, `FluencyEvaluator`, `SimilarityEvaluator`, `RetrievalEvaluator` | -| [Performance and quality][performance_and_quality_evaluators] (NLP) | `F1ScoreEvaluator`, `RougeScoreEvaluator`, `GleuScoreEvaluator`, `BleuScoreEvaluator`, `MeteorScoreEvaluator`| +| [Performance and quality][performance_and_quality_evaluators] (NLP) | `F1ScoreEvaluator`, `RougeScoreEvaluator`, `GleuScoreEvaluator`, `BleuScoreEvaluator`, `MMLUEvaluator`, `MeteorScoreEvaluator`| | [Risk and safety][risk_and_safety_evaluators] (AI-assisted) | `ViolenceEvaluator`, `SexualEvaluator`, `SelfHarmEvaluator`, `HateUnfairnessEvaluator`, `IndirectAttackEvaluator`, `ProtectedMaterialEvaluator` | | [Composite][composite_evaluators] | `QAEvaluator`, `ContentSafetyEvaluator` diff --git a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate.py b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate.py index 368c0dbb8fab..d5b7f4d8165e 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate.py @@ -73,6 +73,13 @@ def evaluation_evaluate_classes_methods(self): bleu_evaluator(response="Lyon is the capital of France.", ground_truth="Paris is the capital of France.") # [END bleu_score_evaluator] + # [START mmlu_score_evaluator] + from azure.ai.evaluation import MMLUEvaluator + + mmlu_evaluator = MMLUEvaluator() + mmlu_evaluator(response="ANSWER: F", ground_truth="{\"answer\": \"F\", \"subject\": \"business\", \"category\": \"other\"}") + # [END mmlu_score_evaluator] + # [START coherence_evaluator] import os from azure.ai.evaluation import CoherenceEvaluator diff --git a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate_fdp.py b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate_fdp.py index e1134f1b1609..783985800e97 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate_fdp.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_evaluate_fdp.py @@ -78,6 +78,13 @@ def evaluation_evaluate_classes_methods(self): bleu_evaluator(response="Lyon is the capital of France.", ground_truth="Paris is the capital of France.") # [END bleu_score_evaluator] + # [START mmlu_score_evaluator] + from azure.ai.evaluation import MMLUEvaluator + + mmlu_evaluator = MMLUEvaluator() + mmlu_evaluator(response="ANSWER: F", 
ground_truth="{\"answer\": \"F\", \"subject\": \"business\", \"category\": \"other\"}") + # [END mmlu_score_evaluator] + # [START coherence_evaluator] import os from azure.ai.evaluation import CoherenceEvaluator diff --git a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_threshold.py b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_threshold.py index 1b660b0a771c..411211dd540d 100644 --- a/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_threshold.py +++ b/sdk/evaluation/azure-ai-evaluation/samples/evaluation_samples_threshold.py @@ -198,6 +198,13 @@ def evaluation_classes_methods_with_thresholds(self): f1_evaluator(response="Lyon is the capital of France.", ground_truth="Paris is the capital of France.") # [END threshold_f1_score_evaluator] + # [START threshold_mmlu_score_evaluator] + from azure.ai.evaluation import MMLUEvaluator + + mmlu_evaluator = MMLUEvaluator(threshold=0.6) + mmlu_evaluator(response="ANSWER: F", ground_truth="{\"answer\": \"F\", \"subject\": \"business\", \"category\": \"other\"}") + # [END threshold_mmlu_score_evaluator] + # [START threshold_fluency_evaluator] import os from azure.ai.evaluation import FluencyEvaluator diff --git a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py index e05a7a5eeb94..ccca3ee34a1c 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_builtin_evaluators.py @@ -23,6 +23,7 @@ from azure.ai.evaluation._http_utils import AsyncHttpPipeline from azure.ai.evaluation import ( BleuScoreEvaluator, + MMLUEvaluator, CoherenceEvaluator, ContentSafetyEvaluator, F1ScoreEvaluator, @@ -63,6 +64,15 @@ def test_math_evaluator_bleu_score(self): assert score is not None and "bleu_score" in score assert 0 <= score["bleu_score"] <= 1 + def test_math_evaluator_mmlu_score(self): + eval_fn = MMLUEvaluator() + score = eval_fn( + ground_truth="{\"answer\": \"F\", \"subject\": \"business\", \"category\": \"other\"}", + response="ANSWER: F", + ) + assert score is not None and "mmlu_score" in score + assert 0 <= score["mmlu_score"] <= 1 + def test_math_evaluator_gleu_score(self): eval_fn = GleuScoreEvaluator() score = eval_fn( diff --git a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluators/test_threshold_behavior.py b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluators/test_threshold_behavior.py index cc42192c0f54..de4991edb97c 100644 --- a/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluators/test_threshold_behavior.py +++ b/sdk/evaluation/azure-ai-evaluation/tests/unittests/test_evaluators/test_threshold_behavior.py @@ -7,6 +7,7 @@ from azure.ai.evaluation._constants import EVALUATION_PASS_FAIL_MAPPING from azure.ai.evaluation._evaluators._bleu._bleu import BleuScoreEvaluator +from azure.ai.evaluation._evaluators._mmlu._mmlu import MMLUEvaluator from azure.ai.evaluation._evaluators._rouge._rouge import RougeScoreEvaluator, RougeType from azure.ai.evaluation._evaluators._gleu._gleu import GleuScoreEvaluator from azure.ai.evaluation._evaluators._meteor._meteor import MeteorScoreEvaluator @@ -43,6 +44,32 @@ def test_bleu_threshold(self, mock_call, threshold, score, should_pass): # Verify pass/fail based on threshold comparison assert mock_result["bleu_result"] == EVALUATION_PASS_FAIL_MAPPING[should_pass] + @pytest.mark.parametrize("threshold,score,should_pass", [(0.5, 0.8, True), 
(0.7, 0.5, False)]) + @patch("azure.ai.evaluation._evaluators._mmlu._mmlu.MMLUEvaluator.__call__") + def test_mmlu_threshold(self, mock_call, threshold, score, should_pass): + """Test threshold behavior in MMLUEvaluator.""" + # Create the evaluator + evaluator = MMLUEvaluator(threshold=threshold) + + # Create a result dictionary with the expected values + result = { + "mmlu_score": score, + "mmlu_result": EVALUATION_PASS_FAIL_MAPPING[should_pass], + "mmlu_threshold": threshold, + } + + # Configure mock to return our result + mock_call.return_value = result + + # Because we're mocking the __call__ method directly, we need to call the mock + mock_result = mock_call(ground_truth="reference", response="candidate") + + # Verify threshold is correctly included in output + assert mock_result["mmlu_threshold"] == threshold + + # Verify pass/fail based on threshold comparison + assert mock_result["mmlu_result"] == EVALUATION_PASS_FAIL_MAPPING[should_pass] + @pytest.mark.parametrize("threshold,score,should_pass", [(0.5, 0.7, True), (0.7, 0.6, False)]) @patch("azure.ai.evaluation._evaluators._gleu._gleu.GleuScoreEvaluator.__call__") def test_gleu_threshold(self, mock_call, threshold, score, should_pass): @@ -234,6 +261,20 @@ def test_bleu_score_decimal_threshold_behavior(self): assert result["bleu_result"] == "pass" assert result["bleu_threshold"] == 0.1 + def test_mmlu_score_decimal_threshold_behavior(self): + """Test that MMLUEvaluator correctly handles decimal scores for threshold comparison.""" + evaluator = MMLUEvaluator(threshold=0.1) + + result = evaluator( + response="ANSWER: F", + ground_truth="{\"answer\": \"F\", \"subject\": \"business\", \"category\": \"other\"}" + ) + + # The score should be > 0.1 and result should be "pass" + assert result["mmlu_score"] > 0.1 + assert result["mmlu_result"] == "pass" + assert result["mmlu_threshold"] == 0.1 + def test_meteor_score_threshold_boundary_cases(self): """Test MeteorScoreEvaluator threshold boundary cases.""" # Test where score should be just above threshold
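
A minimal sketch of the subclassing contract that RegexEvaluatorBase (added in _base_regex_eval.py above) expects and that MMLUEvaluator follows: the subclass supplies one or more regex patterns, each with a single capture group, and implements extract_expected_answer, extract_regex, and compute_match, typically delegating to the _extract_regex and _compute_match helpers. TrueFalseEvaluator, its answer format, and its ground-truth schema are hypothetical illustrations and are not part of this change; the sketch assumes a build of azure-ai-evaluation that includes the base class introduced here.

import json
import re
from typing import Dict, Optional

from azure.ai.evaluation._evaluators._common import RegexEvaluatorBase


class TrueFalseEvaluator(RegexEvaluatorBase):
    """Hypothetical evaluator for responses of the form 'ANSWER: TRUE' / 'ANSWER: FALSE'."""

    def __init__(self, *, threshold=0.5):
        # The pattern needs exactly one capture group; the base update() reads match.group(1).
        super().__init__(regex_patterns=[r"(?i)ANSWER\s*:\s*(TRUE|FALSE)"], threshold=threshold)

    def extract_expected_answer(self, label: str, json_data: dict) -> str:
        # Like MMLUEvaluator, assume the ground truth is a serialized JSON string
        # with an "answer" field.
        return json.loads(label).get("answer")

    def extract_regex(self, prediction: str, label: str, json_data: Dict) -> Optional[re.Match]:
        # Delegate to the base helper, which tries each configured pattern in order and
        # returns the first match (or None, which is scored as not following instructions).
        return self._extract_regex(prediction, label, json_data)

    def compute_match(self, actual_answer: str, extracted_answer: str, label: str, json_data: Dict) -> int:
        # Case-insensitive comparison instead of the strict equality used by _compute_match.
        return 1 if extracted_answer.upper() == (actual_answer or "").upper() else 0

As in MMLUEvaluator, a concrete subclass would normally also declare a __call__ overload with response and ground_truth keywords (mirroring the _mmlu.py code above) so the base class can map row inputs; that boilerplate is omitted here for brevity.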