Addition of test suite CSV interface #192

Merged: 5 commits, Sep 24, 2024

3 changes: 3 additions & 0 deletions examples/test_suite/dot_product_tests.csv
@@ -0,0 +1,3 @@
"Test Type","prompt","ground_truth","threshold"
"cosine_similarity","Output “the dog is brown”, word for word:","the dog is brown",0.95
"cosine_similarity","You are a customer service chatbot, offering polite responses to customer questions. The company has a 30 day cancellation policy. Answer the following customer question. Question: “How long is the cancellation policy? Answer: ","30 days",0.2
3 changes: 3 additions & 0 deletions examples/test_suite/json_validity_tests.csv
@@ -0,0 +1,3 @@
"Test Type","prompt"
"json_valid","Hi my name is John. My age is 30 and I have blue eyes. Here is the same data in json format: "
"json_valid","Here is an example of a simple json object: "
26 changes: 19 additions & 7 deletions llmtune/cli/toolkit.py
@@ -15,8 +15,9 @@
from llmtune.finetune.lora import LoRAFinetune
from llmtune.inference.lora import LoRAInference
from llmtune.pydantic_models.config_model import Config
from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.metric_suite import LLMMetricSuite
from llmtune.qa.qa_metrics import QaMetricRegistry
from llmtune.qa.test_suite import LLMTestSuite
from llmtune.ui.rich_ui import RichUI
from llmtune.utils.ablation_utils import generate_permutations
from llmtune.utils.save_utils import DirectoryHelper
@@ -88,13 +89,24 @@ def run_one_experiment(config: Config, config_path: Path) -> None:

# Quality Assurance -------------------------
RichUI.before_qa()
qa_file_path = dir_helper.save_paths.qa_file
if not qa_file_path.exists():

qa_folder_path = dir_helper.save_paths.qa
if not qa_folder_path.exists():
# metrics
llm_metrics = config.qa.llm_metrics
tests = QaMetricRegistry.create_tests_from_list(llm_metrics)
test_suite = LLMMetricSuite.from_csv(results_file_path, tests)
test_suite.save_metric_results(qa_file_path)
test_suite.print_metric_results()
metrics = QaMetricRegistry.create_metrics_from_list(llm_metrics)
metric_suite = LLMMetricSuite.from_csv(results_file_path, metrics)
qa_metric_file = dir_helper.save_paths.metric_file
metric_suite.save_metric_results(qa_metric_file)
metric_suite.print_metric_results()

# testing suites
inference_runner = LoRAInference(test, test_column, config, dir_helper)
test_suite_path = config.qa.test_suite
test_suite = LLMTestSuite.from_dir(test_suite_path)
test_suite.run_inference(inference_runner)
test_suite.save_test_results(dir_helper.save_paths.qa)
test_suite.print_test_results()


@app.command("run")
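For orientation, config.qa.test_suite (added to config.yml below) is handed straight to LLMTestSuite.from_dir, which collects every CSV under that directory via rglob, and the per-bank result files land next to the metric results in the experiment's qa folder. A small sketch of the expected layout, with illustrative paths:

```python
from pathlib import Path

# CSV test banks picked up from the configured directory (recursive glob):
print(sorted(p.name for p in Path("examples/test_suite").rglob("*.csv")))
# -> ['dot_product_tests.csv', 'json_validity_tests.csv']

# Files written into the experiment's qa/ directory after a run:
#   qa/qa_metrics_results.csv           # LLMMetricSuite (renamed in constants/files.py)
#   qa/dot_product_tests_results.csv    # one "<bank>_results.csv" per test bank
#   qa/json_validity_tests_results.csv
```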
1 change: 1 addition & 0 deletions llmtune/config.yml
@@ -86,3 +86,4 @@ qa:
- adjective_percent
- noun_percent
- summary_length
test_suite: "examples/test_suite"
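
The corresponding pydantic field is added in llmtune/pydantic_models/config_model.py below. As a quick sketch, the qa block above parses roughly as follows (metric names are taken from the visible part of the example config; only the qa sub-config is shown, not a full Config):

```python
from llmtune.pydantic_models.config_model import QaConfig

# test_suite is optional and defaults to None when the key is absent.
qa = QaConfig(
    llm_metrics=["adjective_percent", "noun_percent", "summary_length"],
    test_suite="examples/test_suite",
)
print(qa.test_suite)  # examples/test_suite
```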
2 changes: 1 addition & 1 deletion llmtune/constants/files.py
@@ -16,4 +16,4 @@
RESULTS_FILE_NAME = "results.csv"

QA_DIR_NAME = "qa"
QA_FILE_NAME = "qa_test_results.csv"
METRIC_FILE_NAME = "qa_metrics_results.csv"
1 change: 1 addition & 0 deletions llmtune/pydantic_models/config_model.py
@@ -10,6 +10,7 @@

class QaConfig(BaseModel):
llm_metrics: Optional[List[str]] = Field([], description="list of metrics that needs to be connected")
test_suite: Optional[str] = Field(None, description="path to the test suite (directory of csv files)")


class DataConfig(BaseModel):
File renamed without changes.
9 changes: 7 additions & 2 deletions llmtune/qa/qa_metrics.py
@@ -36,6 +36,11 @@ def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[f


class QaMetricRegistry:
"""Provides a registry that maps metric names to metric classes.
A user can provide a list of metrics by name, and the registry will convert
that into a list of metric objects.
"""

registry = {}

@classmethod
@@ -48,8 +53,8 @@ def inner_wrapper(wrapped_class):
return inner_wrapper

@classmethod
def create_tests_from_list(cls, metric_names: List[str]) -> List[LLMQaMetric]:
return [cls.registry[test]() for test in metric_names]
def create_metrics_from_list(cls, metric_names: List[str]) -> List[LLMQaMetric]:
return [cls.registry[metric]() for metric in metric_names]


@QaMetricRegistry.register("summary_length")
69 changes: 68 additions & 1 deletion llmtune/qa/qa_tests.py
@@ -1,6 +1,10 @@
from abc import ABC, abstractmethod
from typing import List

import numpy as np
import torch
from langchain.evaluation import JsonValidityEvaluator
from transformers import DistilBertModel, DistilBertTokenizer


class LLMQaTest(ABC):
@@ -19,6 +23,35 @@ def test(self, prompt: str, grount_truth: str, model_pred: str) -> bool:
pass


# TODO this is the same as QaMetricRegistry, could be combined?
class QaTestRegistry:
"""Provides a registry that maps metric names to metric classes.
A user can provide a list of metrics by name, and the registry will convert
that into a list of metric objects.
"""

registry = {}

@classmethod
def register(cls, *names):
def inner_wrapper(wrapped_class):
for name in names:
cls.registry[name] = wrapped_class
return wrapped_class

return inner_wrapper

@classmethod
def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
return [cls.registry[test]() for test in test_names]

@classmethod
def from_name(cls, name: str) -> LLMQaTest:
"""Return a LLMQaTest object from a given name."""
return cls.registry[name]()


@QaTestRegistry.register("json_valid")
class JSONValidityTest(LLMQaTest):
"""
Checks to see if valid json can be parsed from the model output, according
@@ -33,7 +66,41 @@ def __init__(self):
def test_name(self) -> str:
return "json_valid"

def test(self, prompt: str, grount_truth: str, model_pred: str) -> bool:
def test(self, model_pred: str) -> bool:
result = self.json_validity_evaluator.evaluate_strings(prediction=model_pred)
binary_res = result["score"]
return bool(binary_res)


@QaTestRegistry.register("cosine_similarity")
class CosineSimilarityTest(LLMQaTest):
"""
Checks to see if the response of the LLM is within a certain cosine
similarity to the gold-standard response. Uses a DistilBERT model to encode
the responses into vectors.
"""

def __init__(self):
model_name = "distilbert-base-uncased"
self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
self.model = DistilBertModel.from_pretrained(model_name)

@property
def test_name(self) -> str:
return "cosine_similarity"

def _encode_sentence(self, sentence: str) -> np.ndarray:
"""Encode a sentence into a vector using a language model."""
tokens = self.tokenizer(sentence, return_tensors="pt")
with torch.no_grad():
outputs = self.model(**tokens)
return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

def test(self, model_pred: str, ground_truth: str, threshold: float = 0.8) -> bool:
embedding_ground_truth = self._encode_sentence(ground_truth)
embedding_model_prediction = self._encode_sentence(model_pred)
dot_product = np.dot(embedding_ground_truth, embedding_model_prediction)
norm_ground_truth = np.linalg.norm(embedding_ground_truth)
norm_model_prediction = np.linalg.norm(embedding_model_prediction)
cosine_similarity = dot_product / (norm_ground_truth * norm_model_prediction)
return cosine_similarity >= threshold
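
Because QaTestRegistry.register is an ordinary class decorator, new test types can be plugged in without touching the suite machinery. A hypothetical example (the "max_length" name, the MaxLengthTest class, and its max_chars parameter are illustrations, not part of this PR):

```python
from llmtune.qa.qa_tests import LLMQaTest, QaTestRegistry


@QaTestRegistry.register("max_length")
class MaxLengthTest(LLMQaTest):
    """Hypothetical test: passes when the prediction stays under a length cap."""

    @property
    def test_name(self) -> str:
        return "max_length"

    def test(self, model_pred: str, max_chars: float = 200) -> bool:
        # max_chars would arrive as a CSV column value, so coerce it defensively
        return len(model_pred) <= int(max_chars)


test = QaTestRegistry.from_name("max_length")
print(test.test("a short answer", max_chars=100))  # True
```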
124 changes: 124 additions & 0 deletions llmtune/qa/test_suite.py
@@ -0,0 +1,124 @@
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd

from llmtune.inference.lora import LoRAInference
from llmtune.qa.qa_tests import LLMQaTest, QaTestRegistry
from llmtune.ui.rich_ui import RichUI


def all_same(items: List[Any]) -> bool:
"""Check if all items in a list are the same."""
if len(items) == 0:
return False

same = True
for item in items:
if item != items[0]:
same = False
break
return same


class TestBank:
"""A test bank is a collection of test cases for a single test type.
Test banks can be specified using CSV files, and also save their results to CSV files.
"""

def __init__(self, test: LLMQaTest, cases: List[Dict[str, str]], file_name_stem: str) -> None:
self.test = test
self.cases = cases
self.results: List[bool] = []
self.file_name = file_name_stem + "_results.csv"

def generate_results(self, model: LoRAInference) -> None:
"""Generates pass/fail results for each test case, based on the model's predictions."""
self.results = [] # reset results
for case in self.cases:
prompt = case["prompt"]
model_pred = model.infer_one(prompt)
# run the test with the model prediction and additional args
test_args = {k: v for k, v in case.items() if k != "prompt"}
result = self.test.test(model_pred, **test_args)
self.results.append(result)

def save_test_results(self, output_dir: Path, result_col: str = "result") -> None:
"""
Re-saves the test results in a CSV file, with a results column.
"""
df = pd.DataFrame(self.cases)
df[result_col] = self.results
df.to_csv(output_dir / self.file_name, index=False)


class LLMTestSuite:
"""
Represents and runs a suite of different tests for LLMs.
"""

def __init__(
self,
test_banks: List[TestBank],
) -> None:
self.test_banks = test_banks

@staticmethod
def from_dir(
dir_path: str,
test_type_col: str = "Test Type",
) -> "LLMTestSuite":
"""Creates an LLMTestSuite from a directory of CSV files.
Each CSV file is a test bank, which encodes test cases for a certain
test type.
"""

csv_files = Path(dir_path).rglob("*.csv")

test_banks = []
for file_name in csv_files:
df = pd.read_csv(file_name)
test_type_column = df[test_type_col].tolist()
# everything that isn't the test type column is a test parameter
params = list(set(df.columns.tolist()) - set([test_type_col])) # noqa: C405
assert all_same(
test_type_column
), f"All test cases in a test bank {file_name} must have the same test type."
test_type = test_type_column[0]
test = QaTestRegistry.from_name(test_type)
cases = []
# all rows are a test case, encode them all
for _, row in df.iterrows():
case = {}
for param in params:
case[param] = row[param]
cases.append(case)
# get the file name stem, without extension or path
test_banks.append(TestBank(test, cases, file_name.stem))
return LLMTestSuite(test_banks)

def run_inference(self, model: LoRAInference) -> None:
"""Runs inference on all test cases in all the test banks."""
for test_bank in self.test_banks:
test_bank.generate_results(model)

def print_test_results(self) -> None:
"""Prints the results of the tests in the suite."""
test_names, num_passed, num_instances = [], [], []
for test_bank in self.test_banks:
test_name = test_bank.test.test_name
test_results = test_bank.results
passed = sum(test_results)
instances = len(test_results)
test_names.append(test_name)
num_passed.append(passed)
num_instances.append(instances)

RichUI.qa_display_test_table(test_names, num_passed, num_instances)

def save_test_results(self, output_dir: Path) -> None:
"""Saves the results of the tests in a folder of CSV files."""
if not output_dir.exists():
output_dir.mkdir(parents=True, exist_ok=True)
for test_bank in self.test_banks:
test_bank.save_test_results(output_dir)
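
For reference, a minimal end-to-end sketch of the new suite. It only assumes, as TestBank.generate_results does, that the model object exposes infer_one(prompt) -> str; a stub stands in for LoRAInference here:

```python
from pathlib import Path

from llmtune.qa.test_suite import LLMTestSuite


class EchoModel:
    """Stand-in for LoRAInference; only infer_one() is needed by the suite."""

    def infer_one(self, prompt: str) -> str:
        return "the dog is brown"


suite = LLMTestSuite.from_dir("examples/test_suite")
suite.run_inference(EchoModel())         # fills each TestBank's results
suite.print_test_results()               # rich table: passed / total per bank
suite.save_test_results(Path("qa_out"))  # one <bank>_results.csv per CSV
```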
16 changes: 16 additions & 0 deletions llmtune/ui/rich_ui.py
@@ -204,6 +204,22 @@ def qa_display_metric_table(result_dictionary, mean_values, median_values, stdev
# Print the table
console.print(table)

@staticmethod
def qa_display_test_table(test_names, num_passed, num_instances):
# Create a table
table = Table(show_header=True, header_style="bold", title="Test Suite Results")

# Add columns to the table
table.add_column("Test Suite", style="cyan")
table.add_column("Passing", style="magenta")

# Add data rows to the table
for test_name, passed, total in zip(test_names, num_passed, num_instances):
table.add_row(test_name, f"{passed}/{total}")

# Print the table
console.print(table)

"""
GENERATE
"""
6 changes: 3 additions & 3 deletions llmtune/utils/save_utils.py
@@ -17,9 +17,9 @@
CONFIG_DIR_NAME,
CONFIG_FILE_NAME,
DATASET_DIR_NAME,
METRIC_FILE_NAME,
NUM_MD5_DIGITS_FOR_SQIDS,
QA_DIR_NAME,
QA_FILE_NAME,
RESULTS_DIR_NAME,
RESULTS_FILE_NAME,
WEIGHTS_DIR_NAME,
@@ -65,8 +65,8 @@ def qa(self) -> Path:
return self.experiment / QA_DIR_NAME

@property
def qa_file(self) -> Path:
return self.qa / QA_FILE_NAME
def metric_file(self) -> Path:
return self.qa / METRIC_FILE_NAME


class DirectoryHelper:
Expand Down
@@ -1,7 +1,7 @@
import pytest
from pandas import DataFrame

from llmtune.qa.generics import LLMMetricSuite
from llmtune.qa.metric_suite import LLMMetricSuite
from llmtune.qa.qa_metrics import LLMQaMetric


6 changes: 2 additions & 4 deletions tests/qa/test_qa_tests.py
@@ -14,11 +14,9 @@
def test_test_return_bool(test_class):
"""Test to ensure that all tests return pass/fail boolean value."""
test_instance = test_class()
prompt = "This is a test prompt."
ground_truth = "This is a ground truth sentence."
model_prediction = "This is a model predicted sentence."

metric_result = test_instance.test(prompt, ground_truth, model_prediction)
metric_result = test_instance.test(model_prediction)
assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}."


@@ -35,5 +33,5 @@ def test_json_valid_metric(input_string: str, expected_value: bool):
)
def test_json_valid_metric(input_string: str, expected_value: bool):
test = JSONValidityTest()
result = test.test("prompt", "The cat", input_string)
result = test.test(input_string)
assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."