Skip to content

Commit 1f059ca

Browse files
Read app_inference_spec from S3 by default.
1 parent 51e5a0b commit 1f059ca

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

app_inference_spec.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
44
import os
55
from typing import Optional
6+
from logging import getLogger
67

78
from fastapi import HTTPException
89
from pydantic import BaseModel
@@ -11,6 +12,15 @@
1112
from validator import DetectJailbreak
1213

1314

15+
# Environment variables:
16+
# MODEL_PATH - "s3" (default) read model from MODEL_S3_PATH, "hf" read from huggingface.
17+
# MODEL_S3_PATH - Defaults to
18+
# s3://guardrails-ai-public-read-only/detect-jailbreak-v0/detect-jailbreak-v0.tar.gz
19+
# HF_TOKEN - if set, will read model from HF.
20+
21+
logger = getLogger(__name__)
22+
23+
1424
class InputRequest(BaseModel):
1525
message: str
1626
threshold: Optional[float] = None
@@ -35,8 +45,23 @@ def device_name(self):
3545
return torch_device
3646

3747
def load(self):
48+
kwargs = {
49+
"device": self.device_name
50+
}
51+
read_from = os.environ.get("MODEL_PATH", "s3").lower()
52+
if read_from == "s3":
53+
print("Reading model from S3.")
54+
kwargs["model_path_override"] = os.environ.get(
55+
"MODEL_S3_PATH",
56+
"s3://guardrails-ai-public-read-only/detect-jailbreak-v0/detect-jailbreak-v0.tar.gz"
57+
)
58+
elif read_from == "hf":
59+
print("Reading model from HF.")
60+
pass # Auto read from HF by default.
61+
else:
62+
logger.warning(f"MODEL_PATH is not set to 's3' or 'hf': '{read_from}'")
3863
print(f"Loading model DetectJailbreak and moving to {self.device_name}...")
39-
self.model = DetectJailbreak(device=self.device_name)
64+
self.model = DetectJailbreak(**kwargs)
4065

4166
def process_request(self, input_request: InputRequest):
4267
message = input_request.message

0 commit comments

Comments
 (0)