Skip to content

Commit e48835e

Browse files
Merge pull request #2 from guardrails-ai/jc/add_inference_spec
Add Inference Spec
2 parents b92ac48 + 6a588d0 commit e48835e

File tree

3 files changed

+110
-16
lines changed

3 files changed

+110
-16
lines changed

app_inference_spec.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# app_inference_spec.py
2+
# Forked from spec:
3+
# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
4+
import os
5+
from typing import Optional
6+
7+
from fastapi import HTTPException
8+
from pydantic import BaseModel
9+
from models_host.base_inference_spec import BaseInferenceSpec
10+
11+
from validator import DetectJailbreak
12+
13+
14+
class InputRequest(BaseModel):
15+
message: str
16+
threshold: Optional[float] = None
17+
18+
19+
class OutputResponse(BaseModel):
20+
classification: str
21+
score: float
22+
is_jailbreak: bool
23+
24+
25+
# Using same nomenclature as in Sagemaker classes
26+
class InferenceSpec(BaseInferenceSpec):
27+
def __init__(self):
28+
self.model = None
29+
30+
@property
31+
def device_name(self):
32+
env = os.environ.get("env", "dev")
33+
# JC: Legacy usage of 'env' as a device.
34+
torch_device = "cuda" if env == "prod" else "cpu"
35+
return torch_device
36+
37+
def load(self):
38+
print(f"Loading model DetectJailbreak and moving to {self.device_name}...")
39+
self.model = DetectJailbreak(device=self.device_name)
40+
41+
def process_request(self, input_request: InputRequest):
42+
message = input_request.message
43+
# If needed, sanity check.
44+
# raise HTTPException(status_code=400, detail="Invalid input format")
45+
args = (message,)
46+
kwargs = {}
47+
if input_request.threshold is not None:
48+
kwargs["threshold"] = input_request.threshold
49+
if not 0.0 <= input_request.threshold <= 1.0:
50+
raise HTTPException(
51+
status_code=400,
52+
detail=f"Threshold must be between 0.0 and 1.0. "
53+
f"Got {input_request.threshold}"
54+
)
55+
return args, kwargs
56+
57+
def infer(self, message: str, threshold: Optional[float] = None) -> OutputResponse:
58+
if threshold is None:
59+
threshold = 0.81
60+
61+
score = self.model.predict_jailbreak([message,])[0]
62+
if score > threshold:
63+
classification = "jailbreak"
64+
is_jailbreak = True
65+
else:
66+
classification = "safe"
67+
is_jailbreak = False
68+
69+
return OutputResponse(
70+
classification=classification,
71+
score=score,
72+
is_jailbreak=is_jailbreak,
73+
)

validator/models.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
import os
2-
from pathlib import Path
31
from typing import Optional, Union
42

53
import numpy
64
import torch
75
import torch.nn as nn
8-
from cached_path import cached_path
96

107
from .resources import get_tokenizer_and_model_by_path
118

129

1310
def string_to_one_hot_tensor(
14-
text: Union[str, list[str]],
11+
text: Union[str, list[str], tuple[str]],
1512
max_length: int = 2048,
1613
left_truncate: bool = True,
1714
) -> torch.Tensor:
@@ -32,10 +29,14 @@ def string_to_one_hot_tensor(
3229
for idx, t in enumerate(text):
3330
if left_truncate:
3431
t = t[-max_length:]
35-
out[idx, -len(t):, :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :]
32+
out[idx, -len(t):, :] = string_to_one_hot_tensor(
33+
t, max_length, left_truncate
34+
)[0, :, :]
3635
else:
3736
t = t[:max_length]
38-
out[idx, :len(t), :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :]
37+
out[idx, :len(t), :] = string_to_one_hot_tensor(
38+
t, max_length, left_truncate
39+
)[0, :, :]
3940
else:
4041
raise Exception("Input was neither a string nor a list of strings.")
4142
return out
@@ -80,7 +81,7 @@ def forward(
8081
x = self.fan_in(x)
8182
x = self.lstm1(x)[0]
8283
x = self.output_head(x)
83-
x = x[:,-1,0]
84+
x = x[:, -1, 0]
8485
x = self.output_activation(x)
8586
return x
8687

@@ -124,9 +125,14 @@ def forward(
124125
longest_sequence = len(x[0])
125126
x = torch.LongTensor(x).to(self.get_current_device())
126127
# TODO: is 1 masked or unmasked?
127-
attention_mask = torch.LongTensor([1] * longest_sequence).to(self.get_current_device())
128+
attention_mask = torch.LongTensor(
129+
[1] * longest_sequence
130+
).to(self.get_current_device())
128131
elif isinstance(x, list) or isinstance(x, tuple):
129-
sequences = [self.tokenizer.encode(text, add_special_tokens=True)[-max_size:] for text in x]
132+
sequences = [
133+
self.tokenizer.encode(text, add_special_tokens=True)[-max_size:]
134+
for text in x
135+
]
130136
for token_list in sequences:
131137
longest_sequence = max(longest_sequence, len(token_list))
132138
x = list()
@@ -135,16 +141,28 @@ def forward(
135141
x.append(
136142
([self.pad_token] * (longest_sequence - len(sequence))) + sequence
137143
)
138-
attention_mask.append([0] * (longest_sequence - len(sequence)) + [1] * len(sequence))
144+
attention_mask.append(
145+
[0] * (longest_sequence - len(sequence)) + [1] * len(sequence)
146+
)
139147
x = torch.LongTensor(x).to(self.get_current_device())
140148
attention_mask = torch.tensor(attention_mask).to(self.get_current_device())
141149

142-
#segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
143-
segments_tensors = torch.zeros(x.shape, dtype=torch.int).to(self.get_current_device())
150+
# segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
151+
segments_tensors = torch.zeros(x.shape, dtype=torch.int) \
152+
.to(self.get_current_device())
144153
if y is not None:
145-
return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask, labels=y)
154+
return self.transformer(
155+
x,
156+
token_type_ids=segments_tensors,
157+
attention_mask=attention_mask,
158+
labels=y
159+
)
146160
else:
147-
return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask).logits
161+
return self.transformer(
162+
x,
163+
token_type_ids=segments_tensors,
164+
attention_mask=attention_mask
165+
).logits
148166

149167

150168
class PromptSaturationDetectorV3: # Note: Not nn.Module.
@@ -155,7 +173,9 @@ def __init__(
155173
device: torch.device = torch.device('cpu'),
156174
model_path_override: str = ""
157175
):
158-
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
176+
from transformers import (
177+
pipeline, AutoTokenizer, AutoModelForSequenceClassification
178+
)
159179
if not model_path_override:
160180
self.model = AutoModelForSequenceClassification.from_pretrained(
161181
"GuardrailsAI/prompt-saturation-attack-detector",

validator/post-install.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from transformers import pipeline, AutoTokenizer, AutoModel
2+
13
print("post-install starting...")
24
# TODO: It's not clear if the DetectJailbreak will be on the path yet.
35
# If we can import Detect Jailbreak, it will be safer to read the names of the models
46
# from the composite model as exposed by DetectJailbreak.XYZ.
5-
from transformers import pipeline, AutoTokenizer, AutoModel
67
print("Fetching model 1 of 3 (Saturation)")
78
AutoModel.from_pretrained("GuardrailsAI/prompt-saturation-attack-detector")
89
AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

0 commit comments

Comments
 (0)