Skip to content

Commit 23026d9

Browse files
Merge pull request #3 from guardrails-ai/jc/add_3_9_compatibility
Add some imports and remove some style elements to make the validator _maybe_ compatible with Python 3.9.
2 parents e48835e + fb8a6b4 commit 23026d9

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

validator/main.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Callable, Optional, Union
2+
from typing import Callable, List, Optional, Union
33

44
import torch
55
from torch.nn import functional as F
@@ -140,7 +140,7 @@ def _mean_pool(model_output, attention_mask):
140140
input_mask_expanded.sum(1), min=1e-9
141141
)
142142

143-
def _embed(self, prompts: list[str]):
143+
def _embed(self, prompts: List[str]):
144144
"""Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
145145
We use the long-form to avoid a dependency on sentence transformers.
146146
This method returns the maximum of the matches against all known attacks.
@@ -160,8 +160,8 @@ def _embed(self, prompts: list[str]):
160160

161161
def _match_known_malicious_prompts(
162162
self,
163-
prompts: list[str] | torch.Tensor,
164-
) -> list[float]:
163+
prompts: Union[List[str], torch.Tensor],
164+
) -> List[float]:
165165
"""Returns an array of floats, one per prompt, with the max match to known
166166
attacks. If prompts is a list of strings, embeddings will be generated. If
167167
embeddings are passed, they will be used."""
@@ -179,7 +179,7 @@ def _match_known_malicious_prompts(
179179
def _predict_and_remap(
180180
self,
181181
model,
182-
prompts: list[str],
182+
prompts: List[str],
183183
label_field: str,
184184
score_field: str,
185185
safe_case: str,
@@ -199,7 +199,7 @@ def _predict_and_remap(
199199
scores.append(new_score)
200200
return scores
201201

202-
def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
202+
def _predict_jailbreak(self, prompts: List[str]) -> List[float]:
203203
return [
204204
DetectJailbreak._rescale(s, *self.text_attack_scales)
205205
for s in self._predict_and_remap(
@@ -212,7 +212,7 @@ def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
212212
)
213213
]
214214

215-
def _predict_saturation(self, prompts: list[str]) -> list[float]:
215+
def _predict_saturation(self, prompts: List[str]) -> List[float]:
216216
return [
217217
DetectJailbreak._rescale(
218218
s,
@@ -230,9 +230,9 @@ def _predict_saturation(self, prompts: list[str]) -> list[float]:
230230

231231
def predict_jailbreak(
232232
self,
233-
prompts: list[str],
233+
prompts: List[str],
234234
reduction_function: Optional[Callable] = max,
235-
) -> Union[list[float], list[dict]]:
235+
) -> Union[List[float], List[dict]]:
236236
if isinstance(prompts, str):
237237
print("WARN: predict_jailbreak should be called with a list of strings.")
238238
prompts = [prompts, ]
@@ -256,7 +256,7 @@ def predict_jailbreak(
256256

257257
def validate(
258258
self,
259-
value: Union[str, list[str]],
259+
value: Union[str, List[str]],
260260
metadata: Optional[dict] = None,
261261
) -> ValidationResult:
262262
"""Validates that will return a failure if the value is a jailbreak attempt.

validator/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union
1+
from typing import List, Tuple, Optional, Union
22

33
import numpy
44
import torch
@@ -8,7 +8,7 @@
88

99

1010
def string_to_one_hot_tensor(
11-
text: Union[str, list[str], tuple[str]],
11+
text: Union[str, List[str], Tuple[str]],
1212
max_length: int = 2048,
1313
left_truncate: bool = True,
1414
) -> torch.Tensor:
@@ -71,7 +71,7 @@ def get_current_device(self):
7171

7272
def forward(
7373
self,
74-
x: Union[str, list[str], numpy.ndarray, torch.Tensor]
74+
x: Union[str, List[str], numpy.ndarray, torch.Tensor]
7575
) -> torch.Tensor:
7676
if isinstance(x, str) or isinstance(x, list) or isinstance(x, tuple):
7777
x = string_to_one_hot_tensor(x).to(self.get_current_device())
@@ -113,7 +113,7 @@ def get_current_device(self):
113113

114114
def forward(
115115
self,
116-
x: Union[str, list[str], numpy.ndarray, torch.Tensor],
116+
x: Union[str, List[str], numpy.ndarray, torch.Tensor],
117117
y: Optional[torch.Tensor] = None,
118118
attention_mask: Optional[torch.Tensor] = None,
119119
) -> torch.Tensor:
@@ -209,5 +209,5 @@ def __init__(
209209
device=device,
210210
)
211211

212-
def __call__(self, text: Union[str, list[str]]) -> list[dict]:
212+
def __call__(self, text: Union[str, List[str]]) -> List[dict]:
213213
return self.pipe(text)

0 commit comments

Comments
 (0)