3
3
# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
4
4
import os
5
5
from typing import Optional
6
+ from logging import getLogger
6
7
7
8
from fastapi import HTTPException
8
9
from pydantic import BaseModel
11
12
from validator import DetectJailbreak
12
13
13
14
15
# Configuration via environment variables:
#   MODEL_PATH    - "s3" (default): read model from MODEL_S3_PATH;
#                   "hf": read from huggingface.
#   MODEL_S3_PATH - Defaults to
#       s3://guardrails-ai-public-read-only/detect-jailbreak-v0/detect-jailbreak-v0.tar.gz
#   HF_TOKEN      - if set, will read model from HF.

# Module-level logger, named after this module per logging convention.
logger = getLogger(name=__name__)
14
24
class InputRequest(BaseModel):
    """Request body for the jailbreak-detection endpoint."""

    # Text to score for jailbreak likelihood.
    message: str
    # Optional detection-threshold override; may be omitted
    # (presumably the validator supplies its own default — confirm).
    threshold: Optional[float] = None
@@ -35,8 +45,23 @@ def device_name(self):
35
45
return torch_device
36
46
37
47
def load (self ):
48
+ kwargs = {
49
+ "device" : self .device_name
50
+ }
51
+ read_from = os .environ .get ("MODEL_PATH" , "s3" ).lower ()
52
+ if read_from == "s3" :
53
+ print ("Reading model from S3." )
54
+ kwargs ["model_path_override" ] = os .environ .get (
55
+ "MODEL_S3_PATH" ,
56
+ "s3://guardrails-ai-public-read-only/detect-jailbreak-v0/detect-jailbreak-v0.tar.gz"
57
+ )
58
+ elif read_from == "hf" :
59
+ print ("Reading model from HF." )
60
+ pass # Auto read from HF by default.
61
+ else :
62
+ logger .warning (f"MODEL_PATH is not set to 's3' or 'hf': '{ read_from } '" )
38
63
print (f"Loading model DetectJailbreak and moving to { self .device_name } ..." )
39
- self .model = DetectJailbreak (device = self . device_name )
64
+ self .model = DetectJailbreak (** kwargs )
40
65
41
66
def process_request (self , input_request : InputRequest ):
42
67
message = input_request .message
0 commit comments