1
+ import torch
1
2
from presidio_analyzer import EntityRecognizer , RecognizerResult
2
3
from gliner import GLiNER
3
4
from .constants import PRESIDIO_TO_GLINER , GLINER_TO_PRESIDIO
4
5
5
6
6
7
class GLiNERRecognizer (EntityRecognizer ):
7
- def __init__ (self , supported_entities , model_name ):
8
+ def __init__ (self , supported_entities , model_name , use_gpu = True ):
8
9
self .model_name = model_name
9
10
self .supported_entities = supported_entities
11
+ self .use_gpu = use_gpu
12
+ self .device = self ._get_device ()
10
13
11
14
gliner_entities = set ()
12
15
@@ -18,12 +21,28 @@ def __init__(self, supported_entities, model_name):
18
21
19
22
super ().__init__ (supported_entities = supported_entities )
20
23
24
+ def _get_device (self ):
25
+ """Determine the device to use for inference"""
26
+ if self .use_gpu and torch .cuda .is_available ():
27
+ return torch .device ("cuda" )
28
+ return torch .device ("cpu" )
29
+
21
30
def load(self) -> None:
    """Load the GLiNER model and move it to the selected device.

    Called by the Presidio framework before analysis. The model is
    always switched to eval mode: inference-time layers such as dropout
    and batch-norm must be frozen regardless of whether we run on CPU
    or GPU. (The previous version only called ``eval()`` on the CUDA
    path, leaving CPU inference in training mode.)
    """
    self.model = GLiNER.from_pretrained(self.model_name)
    # self.device was chosen in __init__ via _get_device(); no need to
    # re-check use_gpu/cuda availability here.
    self.model = self.model.to(self.device)
    self.model.eval()
24
36
25
37
def analyze (self , text , entities = None , nlp_artifacts = None ):
26
- results = self .model .predict_entities (text , self .gliner_entities )
38
+ """Analyze text using GPU-accelerated GLiNER model"""
39
+ # Ensure model is on correct device
40
+ if hasattr (self .model , 'device' ) and self .model .device != self .device :
41
+ self .model = self .model .to (self .device )
42
+
43
+ # Run inference with gradient disabled for efficiency
44
+ with torch .no_grad ():
45
+ results = self .model .predict_entities (text , self .gliner_entities )
27
46
return [
28
47
RecognizerResult (
29
48
entity_type = GLINER_TO_PRESIDIO [entity ["label" ]],
0 commit comments