Skip to content

Commit a5d25bc

Browse files
1
1 parent 5a43b69 commit a5d25bc

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

validator/gliner_recognizer.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import torch
12
from presidio_analyzer import EntityRecognizer, RecognizerResult
23
from gliner import GLiNER
34
from .constants import PRESIDIO_TO_GLINER, GLINER_TO_PRESIDIO
45

56

67
class GLiNERRecognizer(EntityRecognizer):
7-
def __init__(self, supported_entities, model_name):
8+
def __init__(self, supported_entities, model_name, use_gpu=True):
89
self.model_name = model_name
910
self.supported_entities = supported_entities
11+
self.use_gpu = use_gpu
12+
self.device = self._get_device()
1013

1114
gliner_entities = set()
1215

@@ -18,12 +21,28 @@ def __init__(self, supported_entities, model_name):
1821

1922
super().__init__(supported_entities=supported_entities)
2023

24+
def _get_device(self):
    """Pick the torch device used for inference.

    Returns the CUDA device only when ``self.use_gpu`` is truthy and
    ``torch.cuda.is_available()`` reports a usable GPU; otherwise the
    CPU device.
    """
    # Short-circuit keeps the CUDA probe from running when GPU use is off.
    wants_cuda = self.use_gpu and torch.cuda.is_available()
    return torch.device("cuda" if wants_cuda else "cpu")
29+
2130
def load(self) -> None:
    """Load the GLiNER model, move it to the selected device, set eval mode.

    Reads ``self.model_name`` and ``self.device`` and stores the loaded
    model on ``self.model``.
    """
    self.model = GLiNER.from_pretrained(self.model_name)
    self.model = self.model.to(self.device)
    # The model is only used for inference, so switch to eval mode on
    # every device, not just CUDA. Disabling gradients with
    # torch.no_grad() does not turn off training-time layers such as
    # dropout; only eval() does.
    self.model.eval()
2436

2537
def analyze(self, text, entities=None, nlp_artifacts=None):
26-
results = self.model.predict_entities(text, self.gliner_entities)
38+
"""Analyze text using GPU-accelerated GLiNER model"""
39+
# Ensure model is on correct device
40+
if hasattr(self.model, 'device') and self.model.device != self.device:
41+
self.model = self.model.to(self.device)
42+
43+
# Run inference with gradient disabled for efficiency
44+
with torch.no_grad():
45+
results = self.model.predict_entities(text, self.gliner_entities)
2746
return [
2847
RecognizerResult(
2948
entity_type=GLINER_TO_PRESIDIO[entity["label"]],

validator/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
get_entity_threshold: Callable = get_entity_threshold,
8080
on_fail: Optional[Callable] = None,
8181
use_local: bool = True,
82+
use_gpu: bool = True,
8283
**kwargs,
8384
):
8485
"""Validates that the LLM-generated text does not contain Personally Identifiable Information (PII).
@@ -109,6 +110,7 @@ def __init__(
109110
entities=entities,
110111
get_entity_threshold=get_entity_threshold,
111112
use_local=use_local,
113+
use_gpu=use_gpu,
112114
**kwargs,
113115
)
114116

@@ -119,11 +121,13 @@ def __init__(
119121
self.entities = entities
120122
self.model_name = model_name
121123
self.get_entity_threshold = get_entity_threshold
124+
self.use_gpu = use_gpu
122125

123126
if self.use_local:
124127
self.gliner_recognizer = GLiNERRecognizer(
125128
supported_entities=self.entities,
126129
model_name=model_name,
130+
use_gpu=use_gpu,
127131
)
128132
registry = RecognizerRegistry()
129133
registry.load_predefined_recognizers()

0 commit comments

Comments
 (0)