-import os
-from pathlib import Path
 from typing import Optional, Union

 import numpy
 import torch
 import torch.nn as nn
-from cached_path import cached_path

 from .resources import get_tokenizer_and_model_by_path


 def string_to_one_hot_tensor(
-    text: Union[str, list[str]],
+    text: Union[str, list[str], tuple[str]],
     max_length: int = 2048,
     left_truncate: bool = True,
 ) -> torch.Tensor:
@@ -32,10 +29,14 @@ def string_to_one_hot_tensor(
         for idx, t in enumerate(text):
             if left_truncate:
                 t = t[-max_length:]
-                out[idx, -len(t):, :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :]
+                out[idx, -len(t):, :] = string_to_one_hot_tensor(
+                    t, max_length, left_truncate
+                )[0, :, :]
             else:
                 t = t[:max_length]
-                out[idx, :len(t), :] = string_to_one_hot_tensor(t, max_length, left_truncate)[0, :, :]
+                out[idx, :len(t), :] = string_to_one_hot_tensor(
+                    t, max_length, left_truncate
+                )[0, :, :]
     else:
         raise Exception("Input was neither a string nor a list of strings.")
     return out
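For reference, a minimal standalone sketch of the batching logic the wrapped lines implement: each string is truncated to max_length, then written right-aligned (left_truncate=True) or left-aligned into a zero-filled output tensor. The 256-symbol byte alphabet and the helper name are illustrative assumptions, not the module's actual encoding.

import torch

def one_hot_batch_sketch(texts, max_length=8, left_truncate=True):
    # Illustrative alphabet size; the real encoder's vocabulary is not shown in this diff.
    alphabet = 256
    out = torch.zeros(len(texts), max_length, alphabet)
    for idx, t in enumerate(texts):
        if left_truncate:
            t = t[-max_length:]           # keep the tail of an over-long string
            rows = out[idx, -len(t):, :]  # right-align a short string
        else:
            t = t[:max_length]            # keep the head of an over-long string
            rows = out[idx, :len(t), :]   # left-align a short string
        for pos, ch in enumerate(t):
            rows[pos, ord(ch) % alphabet] = 1.0
    return out

print(one_hot_batch_sketch(["hi", "a longer prompt"]).shape)  # torch.Size([2, 8, 256])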
@@ -80,7 +81,7 @@ def forward(
         x = self.fan_in(x)
         x = self.lstm1(x)[0]
         x = self.output_head(x)
-        x = x[:,-1,0]
+        x = x[:, -1, 0]
         x = self.output_activation(x)
         return x

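The reformatted indexing above keeps only the last time step and channel 0 of the output head before the final activation, i.e. one score per sequence. A minimal sketch of an equivalent forward pass; the layer types and widths are assumptions for illustration, since the module's constructor is not part of this diff.

import torch
import torch.nn as nn

class TinyLSTMScorer(nn.Module):
    # Hypothetical stand-in using the attribute names seen above; sizes are illustrative.
    def __init__(self, vocab_size=256, hidden=64):
        super().__init__()
        self.fan_in = nn.Linear(vocab_size, hidden)
        self.lstm1 = nn.LSTM(hidden, hidden, batch_first=True)
        self.output_head = nn.Linear(hidden, 1)
        self.output_activation = nn.Sigmoid()

    def forward(self, x):
        x = self.fan_in(x)
        x = self.lstm1(x)[0]   # keep the per-step outputs, drop (h_n, c_n)
        x = self.output_head(x)
        x = x[:, -1, 0]        # score at the last time step, channel 0
        return self.output_activation(x)

print(TinyLSTMScorer()(torch.zeros(2, 8, 256)).shape)  # torch.Size([2])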
@@ -124,9 +125,14 @@ def forward(
             longest_sequence = len(x[0])
             x = torch.LongTensor(x).to(self.get_current_device())
             # TODO: is 1 masked or unmasked?
-            attention_mask = torch.LongTensor([1] * longest_sequence).to(self.get_current_device())
+            attention_mask = torch.LongTensor(
+                [1] * longest_sequence
+            ).to(self.get_current_device())
         elif isinstance(x, list) or isinstance(x, tuple):
-            sequences = [self.tokenizer.encode(text, add_special_tokens=True)[-max_size:] for text in x]
+            sequences = [
+                self.tokenizer.encode(text, add_special_tokens=True)[-max_size:]
+                for text in x
+            ]
             for token_list in sequences:
                 longest_sequence = max(longest_sequence, len(token_list))
             x = list()
@@ -135,16 +141,28 @@ def forward(
                 x.append(
                     ([self.pad_token] * (longest_sequence - len(sequence))) + sequence
                 )
-                attention_mask.append([0] * (longest_sequence - len(sequence)) + [1] * len(sequence))
+                attention_mask.append(
+                    [0] * (longest_sequence - len(sequence)) + [1] * len(sequence)
+                )
             x = torch.LongTensor(x).to(self.get_current_device())
             attention_mask = torch.tensor(attention_mask).to(self.get_current_device())

-        #segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
-        segments_tensors = torch.zeros(x.shape, dtype=torch.int).to(self.get_current_device())
+        # segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        segments_tensors = torch.zeros(x.shape, dtype=torch.int) \
+            .to(self.get_current_device())
         if y is not None:
-            return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask, labels=y)
+            return self.transformer(
+                x,
+                token_type_ids=segments_tensors,
+                attention_mask=attention_mask,
+                labels=y
+            )
         else:
-            return self.transformer(x, token_type_ids=segments_tensors, attention_mask=attention_mask).logits
+            return self.transformer(
+                x,
+                token_type_ids=segments_tensors,
+                attention_mask=attention_mask
+            ).logits


 class PromptSaturationDetectorV3:  # Note: Not nn.Module.
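The wrapped lines above left-pad each tokenized sequence to the batch's longest length and build a matching attention mask (0 over padding, 1 over real tokens) before a single transformer call with all-zero token_type_ids. A standalone sketch of that padding step; the pad token id of 0 and the example token ids are assumptions for illustration.

import torch

def left_pad_with_mask(sequences, pad_token=0):
    # Pad every sequence on the left to the longest length in the batch,
    # and mark padding positions with 0 in the attention mask.
    longest = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        pad = longest - len(seq)
        padded.append([pad_token] * pad + seq)
        mask.append([0] * pad + [1] * len(seq))
    return torch.LongTensor(padded), torch.LongTensor(mask)

ids, attention_mask = left_pad_with_mask([[101, 7592, 102], [101, 102]])
# ids[1] is left-padded with pad_token; attention_mask[1] starts with 0 over that padding.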
@@ -155,7 +173,9 @@ def __init__(
         device: torch.device = torch.device('cpu'),
         model_path_override: str = ""
     ):
-        from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+        from transformers import (
+            pipeline, AutoTokenizer, AutoModelForSequenceClassification
+        )
         if not model_path_override:
             self.model = AutoModelForSequenceClassification.from_pretrained(
                 "GuardrailsAI/prompt-saturation-attack-detector",