1111from  typing  import  ClassVar 
1212
1313import  cumulus_fhir_support  as  cfs 
14- import  openai 
1514import  pyarrow 
1615import  pydantic 
1716import  rich .progress 
2726class  BaseNlpTask (tasks .EtlTask ):
2827    """Base class for any clinical-notes-based NLP task.""" 
2928
30-     resource : ClassVar  =  " DocumentReference"
29+     resource : ClassVar  =  { "DiagnosticReport" ,  " DocumentReference"} 
3130    needs_bulk_deid : ClassVar  =  False 
3231
3332    # You may want to override these in your subclass 
@@ -80,44 +79,45 @@ async def read_notes(
8079        """ 
8180        Iterate through clinical notes. 
8281
83-         :returns: a tuple of original-docref , scrubbed-docref , and clinical  note 
82+         :returns: a tuple of original-resource , scrubbed-resource , and note text  
8483        """ 
8584        warned_connection_error  =  False 
8685
87-         note_filter  =  self .task_config .resource_filter  or  nlp .is_docref_valid 
86+         note_filter  =  self .task_config .resource_filter  or  nlp .is_note_valid 
8887
89-         for  docref  in  self .read_ndjson (progress = progress ):
90-             orig_docref  =  copy .deepcopy (docref )
88+         for  note  in  self .read_ndjson (progress = progress ):
89+             orig_note  =  copy .deepcopy (note )
9190            can_process  =  (
92-                 note_filter (self .scrubber .codebook , docref )
93-                 and  (doc_check  is  None  or  doc_check (docref ))
94-                 and  self .scrubber .scrub_resource (docref , scrub_attachments = False , keep_stats = False )
91+                 note_filter (self .scrubber .codebook , note )
92+                 and  (doc_check  is  None  or  doc_check (note ))
93+                 and  self .scrubber .scrub_resource (note , scrub_attachments = False , keep_stats = False )
9594            )
9695            if  not  can_process :
9796                continue 
9897
9998            try :
100-                 clinical_note  =  await  fhir .get_clinical_note (self .task_config .client , docref )
99+                 note_text  =  await  fhir .get_clinical_note (self .task_config .client , note )
101100            except  cfs .BadAuthArguments  as  exc :
102101                if  not  warned_connection_error :
103102                    # Only warn user about a misconfiguration once per task. 
104103                    # It's not fatal because it might be intentional (partially inlined DocRefs 
105104                    # and the other DocRefs are known failures - BCH hits this with Cerner data). 
106105                    print (exc , file = sys .stderr )
107106                    warned_connection_error  =  True 
108-                 self .add_error (orig_docref )
107+                 self .add_error (orig_note )
109108                continue 
110109            except  Exception  as  exc :
111-                 logging .warning ("Error getting text for docref %s: %s" , docref ["id" ], exc )
112-                 self .add_error (orig_docref )
110+                 orig_note_ref  =  f"{ orig_note ['resourceType' ]}  /{ orig_note ['id' ]}  " 
111+                 logging .warning ("Error getting text for note %s: %s" , orig_note_ref , exc )
112+                 self .add_error (orig_note )
113113                continue 
114114
115-             yield  orig_docref ,  docref ,  clinical_note 
115+             yield  orig_note ,  note ,  note_text 
116116
117117    @staticmethod  
118-     def  remove_trailing_whitespace (note : str ) ->  str :
118+     def  remove_trailing_whitespace (note_text : str ) ->  str :
119119        """Sometimes NLP can be mildly confused by trailing whitespace, so this removes it""" 
120-         return  TRAILING_WHITESPACE .sub ("" , note )
120+         return  TRAILING_WHITESPACE .sub ("" , note_text )
121121
122122
123123class  BaseOpenAiTask (BaseNlpTask ):
@@ -139,59 +139,52 @@ async def init_check(cls) -> None:
139139    async  def  read_entries (self , * , progress : rich .progress .Progress  =  None ) ->  tasks .EntryIterator :
140140        client  =  self .client_class ()
141141
142-         async  for  orig_docref ,  docref ,  orig_clinical_note  in  self .read_notes (progress = progress ):
142+         async  for  orig_note ,  note ,  orig_note_text  in  self .read_notes (progress = progress ):
143143            try :
144-                 docref_id , encounter_id , subject_id  =  nlp .get_docref_info ( docref )
144+                 note_ref , encounter_id , subject_id  =  nlp .get_note_info ( note )
145145            except  KeyError  as  exc :
146146                logging .warning (exc )
147-                 self .add_error (orig_docref )
147+                 self .add_error (orig_note )
148148                continue 
149149
150-             clinical_note  =  self .remove_trailing_whitespace (orig_clinical_note )
150+             note_text  =  self .remove_trailing_whitespace (orig_note_text )
151+             orig_note_ref  =  f"{ orig_note ['resourceType' ]}  /{ orig_note ['id' ]}  " 
151152
152153            try :
153154                completion_class  =  chat .ParsedChatCompletion [self .response_format ]
154155                response  =  await  nlp .cache_wrapper (
155156                    self .task_config .dir_phi ,
156157                    f"{ self .name }  _v{ self .task_version }  " ,
157-                     clinical_note ,
158+                     note_text ,
158159                    lambda  x : completion_class .model_validate_json (x ),  # from file 
159160                    lambda  x : x .model_dump_json (  # to file 
160161                        indent = None , round_trip = True , exclude_unset = True , by_alias = True 
161162                    ),
162163                    client .prompt ,
163164                    self .system_prompt ,
164-                     self .get_user_prompt (clinical_note ),
165+                     self .get_user_prompt (note_text ),
165166                    self .response_format ,
166167                )
167-             except  openai .APIError  as  exc :
168-                 logging .warning (
169-                     f"Could not connect to NLP server for DocRef { orig_docref ['id' ]}  : { exc }  " 
170-                 )
171-                 self .add_error (orig_docref )
172-                 continue 
173-             except  pydantic .ValidationError  as  exc :
174-                 logging .warning (
175-                     f"Could not process answer from NLP server for DocRef { orig_docref ['id' ]}  : { exc }  " 
176-                 )
177-                 self .add_error (orig_docref )
168+             except  Exception  as  exc :
169+                 logging .warning (f"NLP failed for { orig_note_ref }  : { exc }  " )
170+                 self .add_error (orig_note )
178171                continue 
179172
180173            choice  =  response .choices [0 ]
181174
182175            if  choice .finish_reason  !=  "stop"  or  not  choice .message .parsed :
183176                logging .warning (
184-                     f"NLP server response didn't complete for DocRef  { orig_docref [ 'id' ] }  : " 
177+                     f"NLP server response didn't complete for { orig_note_ref }  : " 
185178                    f"{ choice .finish_reason }  " 
186179                )
187-                 self .add_error (orig_docref )
180+                 self .add_error (orig_note )
188181                continue 
189182
190183            parsed  =  choice .message .parsed .model_dump (mode = "json" )
191-             self .post_process (parsed , orig_clinical_note ,  orig_docref )
184+             self .post_process (parsed , orig_note_text ,  orig_note )
192185
193186            yield  {
194-                 "note_ref" : f"DocumentReference/ { docref_id } "  ,
187+                 "note_ref" : note_ref ,
195188                "encounter_ref" : f"Encounter/{ encounter_id }  " ,
196189                "subject_ref" : f"Patient/{ subject_id }  " ,
197190                # Since this date is stored as a string, use UTC time for easy comparisons 
@@ -202,11 +195,11 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
202195            }
203196
204197    @classmethod  
205-     def  get_user_prompt (cls , clinical_note : str ) ->  str :
198+     def  get_user_prompt (cls , note_text : str ) ->  str :
206199        prompt  =  cls .user_prompt  or  "%CLINICAL-NOTE%" 
207-         return  prompt .replace ("%CLINICAL-NOTE%" , clinical_note )
200+         return  prompt .replace ("%CLINICAL-NOTE%" , note_text )
208201
209-     def  post_process (self , parsed : dict , orig_clinical_note : str , orig_docref : dict ) ->  None :
202+     def  post_process (self , parsed : dict , orig_note_text : str , orig_note : dict ) ->  None :
210203        """Subclasses can fill this out if they like""" 
211204
212205    @classmethod  
@@ -261,7 +254,7 @@ class BaseOpenAiTaskWithSpans(BaseOpenAiTask):
261254    It assumes the field is named "spans" in the top level of the pydantic model. 
262255    """ 
263256
264-     def  post_process (self , parsed : dict , orig_clinical_note : str , orig_docref : dict ) ->  None :
257+     def  post_process (self , parsed : dict , orig_note_text : str , orig_note : dict ) ->  None :
265258        new_spans  =  []
266259        missed_some  =  False 
267260
@@ -278,18 +271,18 @@ def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict)
278271            span  =  ESCAPED_WHITESPACE .sub (r"\\s+" , span )
279272
280273            found  =  False 
281-             for  match  in  re .finditer (span , orig_clinical_note , re .IGNORECASE ):
274+             for  match  in  re .finditer (span , orig_note_text , re .IGNORECASE ):
282275                found  =  True 
283276                new_spans .append (match .span ())
284277            if  not  found :
285278                missed_some  =  True 
286279                logging .warning (
287280                    "Could not match span received from NLP server for " 
288-                     f"DocRef  { orig_docref ['id' ]}  : { orig_span }  " 
281+                     f"{ orig_note [ 'resourceType' ] } / { orig_note ['id' ]}  : { orig_span }  " 
289282                )
290283
291284        if  missed_some :
292-             self .add_error (orig_docref )
285+             self .add_error (orig_note )
293286
294287        parsed ["spans" ] =  new_spans 
295288
0 commit comments