1
1
import math
2
- from typing import Callable , Optional , Union
2
+ from typing import Callable , List , Optional , Union
3
3
4
4
import torch
5
5
from torch .nn import functional as F
@@ -140,7 +140,7 @@ def _mean_pool(model_output, attention_mask):
140
140
input_mask_expanded .sum (1 ), min = 1e-9
141
141
)
142
142
143
- def _embed (self , prompts : list [str ]):
143
+ def _embed (self , prompts : List [str ]):
144
144
"""Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
145
145
We use the long-form to avoid a dependency on sentence transformers.
146
146
This method returns the maximum of the matches against all known attacks.
@@ -160,8 +160,8 @@ def _embed(self, prompts: list[str]):
160
160
161
161
def _match_known_malicious_prompts (
162
162
self ,
163
- prompts : list [ str ] | torch .Tensor ,
164
- ) -> list [float ]:
163
+ prompts : Union [ List [ str ], torch .Tensor ] ,
164
+ ) -> List [float ]:
165
165
"""Returns an array of floats, one per prompt, with the max match to known
166
166
attacks. If prompts is a list of strings, embeddings will be generated. If
167
167
embeddings are passed, they will be used."""
@@ -179,7 +179,7 @@ def _match_known_malicious_prompts(
179
179
def _predict_and_remap (
180
180
self ,
181
181
model ,
182
- prompts : list [str ],
182
+ prompts : List [str ],
183
183
label_field : str ,
184
184
score_field : str ,
185
185
safe_case : str ,
@@ -199,7 +199,7 @@ def _predict_and_remap(
199
199
scores .append (new_score )
200
200
return scores
201
201
202
- def _predict_jailbreak (self , prompts : list [str ]) -> list [float ]:
202
+ def _predict_jailbreak (self , prompts : List [str ]) -> List [float ]:
203
203
return [
204
204
DetectJailbreak ._rescale (s , * self .text_attack_scales )
205
205
for s in self ._predict_and_remap (
@@ -212,7 +212,7 @@ def _predict_jailbreak(self, prompts: list[str]) -> list[float]:
212
212
)
213
213
]
214
214
215
- def _predict_saturation (self , prompts : list [str ]) -> list [float ]:
215
+ def _predict_saturation (self , prompts : List [str ]) -> List [float ]:
216
216
return [
217
217
DetectJailbreak ._rescale (
218
218
s ,
@@ -230,9 +230,9 @@ def _predict_saturation(self, prompts: list[str]) -> list[float]:
230
230
231
231
def predict_jailbreak (
232
232
self ,
233
- prompts : list [str ],
233
+ prompts : List [str ],
234
234
reduction_function : Optional [Callable ] = max ,
235
- ) -> Union [list [float ], list [dict ]]:
235
+ ) -> Union [List [float ], List [dict ]]:
236
236
if isinstance (prompts , str ):
237
237
print ("WARN: predict_jailbreak should be called with a list of strings." )
238
238
prompts = [prompts , ]
@@ -256,7 +256,7 @@ def predict_jailbreak(
256
256
257
257
def validate (
258
258
self ,
259
- value : Union [str , list [str ]],
259
+ value : Union [str , List [str ]],
260
260
metadata : Optional [dict ] = None ,
261
261
) -> ValidationResult :
262
262
"""Validates that will return a failure if the value is a jailbreak attempt.
0 commit comments