11"""Base NLP task support"""
22
33import copy
4+ import dataclasses
45import json
56import logging
67import os
@@ -54,11 +55,11 @@ class BaseNlpTask(tasks.EtlTask):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.seen_docrefs = set()
+        self.seen_groups = set()
 
     def pop_current_group_values(self, table_index: int) -> set[str]:
-        values = self.seen_docrefs
-        self.seen_docrefs = set()
+        values = self.seen_groups
+        self.seen_groups = set()
         return values
 
     def add_error(self, docref: dict) -> None:
@@ -121,10 +122,28 @@ def remove_trailing_whitespace(note_text: str) -> str:
         return TRAILING_WHITESPACE.sub("", note_text)
 
 
+@dataclasses.dataclass(kw_only=True)
+class NoteDetails:
+    note_ref: str
+    encounter_id: str
+    subject_ref: str
+
+    note_text: str
+    note: dict
+
+    orig_note_ref: str
+    orig_note_text: str
+    orig_note: dict
+
+
 class BaseModelTask(BaseNlpTask):
     """Base class for any NLP task talking to LLM models."""
 
-    outputs: ClassVar = [tasks.OutputTable(resource_type=None, uniqueness_fields={"note_ref"})]
+    outputs: ClassVar = [
+        tasks.OutputTable(
+            resource_type=None, uniqueness_fields={"note_ref"}, group_field="note_ref"
+        )
+    ]
 
     # If you change these prompts, consider updating task_version.
     system_prompt: str = None
@@ -155,33 +174,47 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
             note_text = self.remove_trailing_whitespace(orig_note_text)
             orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"
 
+            details = NoteDetails(
+                note_ref=note_ref,
+                encounter_id=encounter_id,
+                subject_ref=subject_ref,
+                note_text=note_text,
+                note=note,
+                orig_note_ref=orig_note_ref,
+                orig_note_text=orig_note_text,
+                orig_note=orig_note,
+            )
+
             try:
-                response = await self.model.prompt(
-                    self.get_system_prompt(),
-                    self.get_user_prompt(note_text),
-                    schema=self.response_format,
-                    cache_dir=self.task_config.dir_phi,
-                    cache_namespace=f"{self.name}_v{self.task_version}",
-                    note_text=note_text,
-                )
+                if result := await self.process_note(details):
+                    yield result
             except Exception as exc:
                 logging.warning(f"NLP failed for {orig_note_ref}: {exc}")
                 self.add_error(orig_note)
-                continue
 
-            parsed = response.answer.model_dump(mode="json")
-            self.post_process(parsed, orig_note_text, orig_note)
+    async def process_note(self, details: NoteDetails) -> tasks.EntryBundle | None:
+        response = await self.model.prompt(
+            self.get_system_prompt(),
+            self.get_user_prompt(details.note_text),
+            schema=self.response_format,
+            cache_dir=self.task_config.dir_phi,
+            cache_namespace=f"{self.name}_v{self.task_version}",
+            note_text=details.note_text,
+        )
+
+        parsed = response.answer.model_dump(mode="json")
+        self.post_process(parsed, details)
 
-            yield {
-                "note_ref": note_ref,
-                "encounter_ref": f"Encounter/{encounter_id}",
-                "subject_ref": subject_ref,
-                # Since this date is stored as a string, use UTC time for easy comparisons
-                "generated_on": common.datetime_now().isoformat(),
-                "task_version": self.task_version,
-                "system_fingerprint": response.fingerprint,
-                "result": parsed,
-            }
+        return {
+            "note_ref": details.note_ref,
+            "encounter_ref": f"Encounter/{details.encounter_id}",
+            "subject_ref": details.subject_ref,
+            # Since this date is stored as a string, use UTC time for easy comparisons
+            "generated_on": common.datetime_now().isoformat(),
+            "task_version": self.task_version,
+            "system_fingerprint": response.fingerprint,
+            "result": parsed,
+        }
 
     def finish_task(self) -> None:
         stats = self.model.stats
@@ -225,7 +258,7 @@ def should_skip(self, orig_note: dict) -> bool:
225258 """Subclasses can fill this out if they like, to skip notes"""
226259 return False
227260
228- def post_process (self , parsed : dict , orig_note_text : str , orig_note : dict ) -> None :
261+ def post_process (self , parsed : dict , details : NoteDetails ) -> None :
229262 """Subclasses can fill this out if they like"""
230263
231264 @classmethod
@@ -289,18 +322,18 @@ class BaseModelTaskWithSpans(BaseModelTask):
     It assumes any field named "spans" in the hierarchy of the pydantic model should be converted.
     """
 
-    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
-        if not self._process_dict(parsed, orig_note_text, orig_note):
-            self.add_error(orig_note)
+    def post_process(self, parsed: dict, details: NoteDetails) -> None:
+        if not self._process_dict(parsed, details):
+            self.add_error(details.orig_note)
 
-    def _process_dict(self, parsed: dict, orig_note_text: str, orig_note: dict) -> bool:
+    def _process_dict(self, parsed: dict, details: NoteDetails) -> bool:
         """Returns False if any span couldn't be matched"""
         all_found = True
 
         for key, value in parsed.items():
             if key != "spans":
                 if isinstance(value, dict):
-                    all_found &= self._process_dict(value, orig_note_text, orig_note)  # descend
+                    all_found &= self._process_dict(value, details)  # descend
                 continue
 
             new_spans = []
@@ -318,14 +351,14 @@ def _process_dict(self, parsed: dict, orig_note_text: str, orig_note: dict) -> b
                 span = ESCAPED_WHITESPACE.sub(r"\\s+", span)
 
                 found = False
-                for match in re.finditer(span, orig_note_text, re.IGNORECASE):
+                for match in re.finditer(span, details.orig_note_text, re.IGNORECASE):
                     found = True
                     new_spans.append(match.span())
                 if not found:
                     all_found = False
                     logging.warning(
                         "Could not match span received from NLP server for "
-                        f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}"
+                        f"{details.orig_note_ref}: {orig_span}"
                     )
 
             parsed[key] = new_spans
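
For context, here is a minimal sketch of what a subclass looks like against the new hooks. It is not part of this change: the class name, the "symptoms" field, and the skip condition are hypothetical, and the sketch assumes it sits in the same module so that BaseModelTask, NoteDetails, tasks, and logging are already in scope.

class ExampleSymptomTask(BaseModelTask):
    """Hypothetical task, only to show the new NoteDetails-based hook shapes."""

    async def process_note(self, details: NoteDetails) -> tasks.EntryBundle | None:
        # Per-note control can now live here instead of a full read_entries()
        # override; returning None means no row is yielded for this note.
        if not details.note_text.strip():  # hypothetical skip condition
            return None
        return await super().process_note(details)

    def post_process(self, parsed: dict, details: NoteDetails) -> None:
        # The original (pre-scrubbing) note text and resource now travel
        # inside `details` rather than as separate arguments.
        if not parsed.get("symptoms"):  # hypothetical response field
            logging.info(f"No symptoms found for {details.orig_note_ref}")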