Skip to content

Change validator to respect local/remote inference. #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/publish_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ name: Publish to Guardrails Hub
on:
workflow_dispatch:
push:
branches:
- main
# Publish when new releases are tagged.
tags:
- '*'

jobs:
setup:
Expand Down
39 changes: 7 additions & 32 deletions app_inference_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Forked from spec:
# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
import os
from typing import Optional
from logging import getLogger
from typing import List

from fastapi import HTTPException
from pydantic import BaseModel
from models_host.base_inference_spec import BaseInferenceSpec

Expand All @@ -22,14 +21,11 @@


class InputRequest(BaseModel):
message: str
threshold: Optional[float] = None
prompts: List[str]


class OutputResponse(BaseModel):
classification: str
score: float
is_jailbreak: bool
scores: List[float]


# Using same nomenclature as in Sagemaker classes
Expand Down Expand Up @@ -64,35 +60,14 @@ def load(self):
self.model = DetectJailbreak(**kwargs)

def process_request(self, input_request: InputRequest):
message = input_request.message
prompts = input_request.prompts
# If needed, sanity check.
# raise HTTPException(status_code=400, detail="Invalid input format")
args = (message,)
args = (prompts,)
kwargs = {}
if input_request.threshold is not None:
kwargs["threshold"] = input_request.threshold
if not 0.0 <= input_request.threshold <= 1.0:
raise HTTPException(
status_code=400,
detail=f"Threshold must be between 0.0 and 1.0. "
f"Got {input_request.threshold}"
)
return args, kwargs

def infer(self, message: str, threshold: Optional[float] = None) -> OutputResponse:
if threshold is None:
threshold = 0.81

score = self.model.predict_jailbreak([message,])[0]
if score > threshold:
classification = "jailbreak"
is_jailbreak = True
else:
classification = "safe"
is_jailbreak = False

def infer(self, prompts: List[str]) -> OutputResponse:
return OutputResponse(
classification=classification,
score=score,
is_jailbreak=is_jailbreak,
scores=self.model.predict_jailbreak(prompts),
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "detect-jailbreak"
version = "0.1.3"
version = "0.1.4"
description = "A prompt-injection and jailbreak detector for LLMs."
authors = [
{name = "Guardrails AI", email = "contact@guardrailsai.com"},
Expand Down
127 changes: 79 additions & 48 deletions validator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import math
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Union, Any

import torch
from torch.nn import functional as F
Expand Down Expand Up @@ -65,57 +66,64 @@ def __init__(
device: str = "cpu",
on_fail: Optional[Callable] = None,
model_path_override: str = "",
**kwargs,
):
super().__init__(on_fail=on_fail)
super().__init__(on_fail=on_fail, **kwargs)
self.device = device
self.threshold = threshold
self.saturation_attack_detector = None
self.text_classifier = None
self.embedding_tokenizer = None
self.embedding_model = None
self.known_malicious_embeddings = []

if not model_path_override:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
)
self.text_classifier = pipeline(
"text-classification",
DetectJailbreak.TEXT_CLASSIFIER_NAME,
max_length=512, # HACK: Fix classifier size.
truncation=True,
device=device,
)
# There are a large number of fairly low-effort prompts people will use.
# The embedding detectors do checks to roughly match those.
self.embedding_tokenizer = AutoTokenizer.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
)
self.embedding_model = AutoModel.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
).to(device)
else:
# Saturation:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
model_path_override=model_path_override
)
# Known attacks:
embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
model_path_override,
"embedding",
AutoTokenizer,
AutoModel
)
self.embedding_tokenizer = embedding_tokenizer
self.embedding_model = embedding_model.to(device)
# Other text attacks:
self.text_classifier = get_pipeline_by_path(
model_path_override,
"text-classifier",
"text-classification",
max_length=512,
truncation=True,
device=device
)
if self.use_local:
if not model_path_override:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
)
self.text_classifier = pipeline(
"text-classification",
DetectJailbreak.TEXT_CLASSIFIER_NAME,
max_length=512, # HACK: Fix classifier size.
truncation=True,
device=device,
)
# There are a large number of fairly low-effort prompts people will use.
# The embedding detectors do checks to roughly match those.
self.embedding_tokenizer = AutoTokenizer.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
)
self.embedding_model = AutoModel.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
).to(device)
else:
# Saturation:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
model_path_override=model_path_override
)
# Known attacks:
embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
model_path_override,
"embedding",
AutoTokenizer,
AutoModel
)
self.embedding_tokenizer = embedding_tokenizer
self.embedding_model = embedding_model.to(device)
# Other text attacks:
self.text_classifier = get_pipeline_by_path(
model_path_override,
"text-classifier",
"text-classification",
max_length=512,
truncation=True,
device=device
)

# Quick compute on startup:
self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
# Quick compute on startup:
self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)

# These _are_ modifyable, but not explicitly advertised.
self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
Expand Down Expand Up @@ -233,6 +241,9 @@ def predict_jailbreak(
prompts: List[str],
reduction_function: Optional[Callable] = max,
) -> Union[List[float], List[dict]]:
"""predict_jailbreak will return an array of floats by default, one per prompt.
If reduction_function is set to 'none' it will return a dict with the different
sub-validators and their scores. Useful for debugging and tuning."""
if isinstance(prompts, str):
print("WARN: predict_jailbreak should be called with a list of strings.")
prompts = [prompts, ]
Expand Down Expand Up @@ -271,7 +282,9 @@ def validate(
if isinstance(value, str):
value = [value, ]

scores = self.predict_jailbreak(value)
# _inference is to support local/remote. It is equivalent to this:
# scores = self.predict_jailbreak(value)
scores = self._inference(value)

failed_prompts = list()
failed_scores = list() # To help people calibrate their thresholds.
Expand All @@ -289,3 +302,21 @@ def validate(
error_message=failure_message
)
return PassResult()

# The rest of these methods are made for validator compatibility and may have some
# strange properties,

def _inference_local(self, model_input: List[str]) -> Any:
return self.predict_jailbreak(model_input)

def _inference_remote(self, model_input: List[str]) -> Any:
# This needs to be kept in-sync with app_inference_spec.
request_body = {"prompts": model_input}
response = self._hub_inference_request(
json.dumps(request_body),
self.validation_endpoint
)
if not response or "scores" not in response:
raise ValueError("Invalid response from remote inference", response)

return response["scores"]
Loading