
Commit 90061b5

Merge pull request #5 from guardrails-ai/jc/change_behavior
Fix 'OnFix' behavior.
2 parents 7b504ed + 680a8f8 commit 90061b5
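
In short, with on_fail="fix" a failing passage is now split into sentences and only the biased sentences are dropped, rather than discarding the whole passage from the fix value. A minimal sketch of that path, mirroring the new test_sentence_fix below (which sentences survive depends on the d4data/bias-detection-model scores):

from guardrails.validator_base import FailResult
from validator import BiasCheck

validator = BiasCheck(on_fail="fix", threshold=0.9)
result = validator.validate(
    "Men these days don't care about my arbitrary and deletarious standards of gender. "
    "They only care about emotional honesty and participating in a productive, healthy society. smh"
)
if isinstance(result, FailResult):
    print(result.fix_value) # Only the sentences scoring below the threshold survive.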

File tree: 6 files changed, +127 -44 lines changed


.github/workflows/publish_pypi.yml

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+name: Publish to Guardrails Hub
+
+on:
+  workflow_dispatch:
+  push:
+    # Publish when new releases are tagged.
+    tags:
+      - '*'
+
+jobs:
+  setup:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Build & Deploy
+        uses: guardrails-ai/guardrails/.github/actions/validator_pypi_publish@main
+        with:
+          guardrails_token: ${{ secrets.GR_GUARDRAILS_TOKEN }}
+          validator_id: guardrails/bias_check

README.md

Lines changed: 7 additions & 3 deletions

@@ -17,7 +17,11 @@ This validator can be used to ensure fairness of model output across various dem
 
 * Dependencies:
     - guardrails-ai>=0.5.0
-    - transformers>=4.40
+    - transformers>=4.37
+    - tf-keras
+    - sentencepiece
+    - tensorflow>=2.16
+    - sentence-splitter
 
 * Dev Dependencies:
     - pytest
@@ -50,5 +54,5 @@ guard = Guard().use(
     BiasCheck(threshold=0.9, on_fail="exception")
 )
 
-guard.validate("The movie was great!") # Validator passes
-guard.validate("Why do men always think the movie was great?") # Validator fails
+guard.validate("The movie was great.") # Validator passes
+guard.validate("Why do men think that movie was great?") # Validator fails

pyproject.toml

Lines changed: 8 additions & 5 deletions

@@ -1,17 +1,20 @@
 [project]
 name = "bias_check"
-version = "0.2.0"
+version = "0.3.0"
 description = "Wrapping a bias check model into a validator."
 authors = [
     {name = "Guardrails AI", email = "contact@guardrailsai.com"}
 ]
 license = {file = "LICENSE"}
 readme = "README.md"
-requires-python = ">= 3.8.1"
+requires-python = ">= 3.9"
 dependencies = [
-    "guardrails-ai>=0.4.0",
-    "transformers>=4.40.2",
-    "tensorflow>=2.18.0"
+    "guardrails-ai>=0.5.15",
+    "transformers>=4.37.0",
+    "tf-keras",
+    "sentencepiece",
+    "tensorflow>=2.16.0", # Required for the dbias model, but not as a direct dependency.
+    "sentence-splitter>=1.4"
 ]
 
 [project.optional-dependencies]

tests/test_validator.py

Lines changed: 20 additions & 14 deletions

@@ -3,20 +3,26 @@
 
 from validator import BiasCheck
 
-# We use 'exception' as the validator's fail action,
-# so we expect failures to always raise an Exception
-# Learn more about corrective actions here:
-# https://www.guardrailsai.com/docs/concepts/output/#%EF%B8%8F-specifying-corrective-actions
+
 def test_success_case():
-  validator = BiasCheck()
-  input_text = "The sun rises in the morning."
-  result = validator.validate(input_text, {})
-  assert isinstance(result, PassResult)
+    validator = BiasCheck()
+    input_text = "The sun rises in the morning."
+    result = validator.validate(input_text, {})
+    assert isinstance(result, PassResult)
+
 
 def test_failure_case():
-  validator = BiasCheck(threshold=0.5)
-  input_text = "The sun only rises for Humanists."
-  result = validator.validate(input_text, {})
-  assert isinstance(result, FailResult)
-  assert result.error_message == "The original response contains potentially biased messages:\nThe sun only rises for Humanists.\n (Message scores: 0.6546186208724976)"
-  assert result.fix_value == ""
+    validator = BiasCheck(threshold=0.0)
+    input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh"
+    result = validator.validate(input_text, {})
+    assert isinstance(result, FailResult)
+    assert result.error_message == "The original response contains potentially biased messages:\nMen these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh\n (Message scores: 0.9477301836013794)"
+    assert result.fix_value == ""
+
+
+def test_sentence_fix():
+    v = BiasCheck(on_fail='fix', threshold=0.9)
+    input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh"
+    out = v.validate(input_text)
+    assert isinstance(out, FailResult)
+    assert out.fix_value == "Men these days don't care about my arbitrary and deletarious standards of gender."

validator/main.py

Lines changed: 69 additions & 21 deletions

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Optional
+from typing import Callable, Dict, List, Optional, Union
 
 from guardrails.validator_base import (
     FailResult,
@@ -7,9 +7,11 @@
     Validator,
     register_validator,
 )
-
+from guardrails.types import OnFailAction
+from sentence_splitter import split_text_into_sentences
 from transformers import pipeline
 
+
 @register_validator(name="guardrails/bias_check", data_type="string")
 class BiasCheck(Validator):
     """Validates that the text is free from biases related to age, gender, sex, ethnicity, religion, etc.
@@ -23,61 +25,107 @@ class BiasCheck(Validator):
     | Programmatic fix | The debiased text if bias is detected |
 
     Args:
-    threshold (float): Higher is more likely to allow bias. Lower is more sensitive and more likely to flag biased messages.
-    on_fail (Callable): The policy to enact when a validator fails. If `str`, must be one of `filter`, `noop`, or `exception`. Otherwise, must be a function that is called when the validator fails.
+        threshold (float): Higher is more likely to allow bias. Lower is more sensitive and more likely to flag biased messages.
+        on_fail (Callable): The policy to enact when a validator fails. If `str`, must be one of `noop`, `fix`, or `exception`. Otherwise, must be a function that is called when the validator fails.
     """ # noqa
 
     def __init__(
         self,
         threshold: float = 0.9,
-        on_fail: Optional[Callable] = None,
+        on_fail: Optional[Union[str, Callable]] = None,
     ):
-        super().__init__(on_fail=on_fail)
-        valid_on_fail_operations = {"filter", "noop", "exception"}
+        super().__init__(on_fail=on_fail) # type: ignore
+        valid_on_fail_operations = {"fix", "noop", "exception"}
         if isinstance(on_fail, str) and on_fail not in valid_on_fail_operations:
             raise Exception(
                 f"on_fail value ({on_fail}) not in list of allowable operations: {valid_on_fail_operations}"
            )
         self.threshold = threshold
-        self.model = pipeline(
+
+        # There are some spurious loading complaints with TFDistilBert models.
+        # See https://discuss.huggingface.co/t/message-some-layers-from-the-model-were-not-used/1972/7
+        self.classification_model = pipeline(
             'text-classification',
             model="d4data/bias-detection-model",
+            tokenizer="d4data/bias-detection-model",
         )
-        # There are some spurious loading complaints with TFDistilBert models.
-        # See https://discuss.huggingface.co/t/message-some-layers-from-the-model-were-not-used/1972/7
 
-    def validate(self, value: Any, metadata: Optional[Dict] = None) -> ValidationResult:
+    def validate(
+        self,
+        value: Union[str, List[str]],
+        metadata: Optional[Dict] = None
+    ) -> ValidationResult:
        """Validates that the text is free from biases related to age, gender, sex, ethnicity, religion, etc."""
+        single_sentence_passed = False
        if isinstance(value, str):
+            single_sentence_passed = True
            value = [value,] # Ensure we're always passing lists of strings into the classifier.
 
-        classified_examples = self.model(value)
+        scores = self._inference(value)
        passing_outputs = list()
        passing_scores = list()
        failing_outputs = list()
        failing_scores = list()
-        for text, prediction in zip(value, classified_examples):
-            if prediction['label'] == 'Biased':
-                score = prediction['score']
-            elif prediction['label'] == 'Non-biased':
-                score = -prediction['score'] # Note the negation!
-            else:
-                raise Exception(f"Got unexpected prediction label: {prediction['label']}")
+        all_outputs = list() # A tuple of (fix/ignore, sentence)
+        for text, score in zip(value, scores):
            if score > self.threshold:
                failing_outputs.append(text)
                failing_scores.append(score)
            else:
                passing_outputs.append(text)
                passing_scores.append(score)
+            all_outputs.append((score > self.threshold, text))
 
        if failing_outputs:
            failure_message = "The original response contains potentially biased messages:\n"
            failure_message += "\n - ".join(failing_outputs)
            message_scores = [str(s) for s in failing_scores]
            failure_message += "\n (Message scores: {})".format(", ".join(message_scores))
-            # Do we need to call the on_fail_method here?
+            # Three paths: noop, exception, fix.
+            # on_fail == NOOP, return only passing passages.
+            # on_fail == FIX, split passages into sentences and drop sentences.
+            # EXCEPTION is handled farther up the stack.
+            if self.on_fail_descriptor != OnFailAction.FIX:
+                fix_value = passing_outputs
+            else:
+                fix_value = list()
+                for needs_fix, text in all_outputs:
+                    if not needs_fix:
+                        fix_value.append(text)
+                    else:
+                        # The 'text' is a full document, passage, or paragraph.
+                        fix_value.append(self.fix_passage(text))
            return FailResult(
                error_message=failure_message,
-                fix_value=" ".join(passing_outputs),
+                fix_value=" ".join(fix_value) if single_sentence_passed else fix_value,
            )
        return PassResult()
+
+    def fix_passage(self, text: str) -> str:
+        """Given a passage of text, split it into sentences, evaluate each for bias,
+        then recombine them and return a new paragraph. May not preserve whitespace
+        between sentences."""
+        sentences = split_text_into_sentences(text, language='en')
+        scores = self._inference(sentences)
+        unbiased_sentences = list()
+        for score, sentence in zip(scores, sentences):
+            if score < self.threshold:
+                unbiased_sentences.append(sentence)
+        return " ".join(unbiased_sentences)
+
+    # This normally will be called by _inference.
+    # Remote inference is unsupported for this model on account of the NER.
+    def _inference_local(self, sentences: List[str]) -> List[float]: # type: ignore
+        scores = list()
+        predictions = self.classification_model(sentences)
+        for pred in predictions:
+            label = pred['label'] # type: ignore
+            score = pred['score'] # type: ignore
+            if label == 'Biased':
+                scores.append(score)
+            elif label == 'Non-biased':
+                scores.append(-score)
+            else:
+                # This should never happen:
+                raise Exception("Unexpected prediction label: {}".format(label))
+        return scores
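
To make the new fix path concrete, here is a standalone sketch of what fix_passage does, with the classifier stubbed out. fake_scores stands in for self._inference, which returns one signed score per sentence (positive for 'Biased', negated for 'Non-biased'); only split_text_into_sentences is the real dependency here.

from sentence_splitter import split_text_into_sentences

text = "This sentence is fine. This one, we pretend, was scored as biased."
sentences = split_text_into_sentences(text, language='en')
fake_scores = [0.12, 0.95] # stand-in for self._inference(sentences)
threshold = 0.9
# Keep only the sentences scoring below the bias threshold, as fix_passage does.
kept = [s for score, s in zip(fake_scores, sentences) if score < threshold]
print(" ".join(kept)) # -> "This sentence is fine."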

validator/post-install.py

Lines changed: 5 additions & 1 deletion

@@ -1,4 +1,8 @@
 from transformers import pipeline
 print("post-install starting...")
-_ = pipeline("text-classification", "d4data/bias-detection-model")
+_ = pipeline(
+    'text-classification',
+    model="d4data/bias-detection-model",
+    tokenizer="d4data/bias-detection-model",
+)
 print("post-install complete!")
