- 
                Notifications
    
You must be signed in to change notification settings  - Fork 4
 
feat: have NLP tasks read in DxReports as well as DocRefs #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -11,7 +11,6 @@ | |
| from typing import ClassVar | ||
| 
     | 
||
| import cumulus_fhir_support as cfs | ||
| import openai | ||
| import pyarrow | ||
| import pydantic | ||
| import rich.progress | ||
| 
        
          
        
         | 
    @@ -27,7 +26,7 @@ | |
| class BaseNlpTask(tasks.EtlTask): | ||
| """Base class for any clinical-notes-based NLP task.""" | ||
| 
     | 
||
| resource: ClassVar = "DocumentReference" | ||
| resource: ClassVar = {"DiagnosticReport", "DocumentReference"} | ||
| needs_bulk_deid: ClassVar = False | ||
| 
     | 
||
| # You may want to override these in your subclass | ||
| 
          
            
          
           | 
    @@ -80,44 +79,45 @@ async def read_notes( | |
| """ | ||
| Iterate through clinical notes. | ||
| 
     | 
||
| :returns: a tuple of original-docref, scrubbed-docref, and clinical note | ||
| :returns: a tuple of original-resource, scrubbed-resource, and note text | ||
| """ | ||
| warned_connection_error = False | ||
| 
     | 
||
| note_filter = self.task_config.resource_filter or nlp.is_docref_valid | ||
| note_filter = self.task_config.resource_filter or nlp.is_note_valid | ||
| 
     | 
||
| for docref in self.read_ndjson(progress=progress): | ||
| orig_docref = copy.deepcopy(docref) | ||
| for note in self.read_ndjson(progress=progress): | ||
| orig_note = copy.deepcopy(note) | ||
| can_process = ( | ||
| note_filter(self.scrubber.codebook, docref) | ||
| and (doc_check is None or doc_check(docref)) | ||
| and self.scrubber.scrub_resource(docref, scrub_attachments=False, keep_stats=False) | ||
| note_filter(self.scrubber.codebook, note) | ||
| and (doc_check is None or doc_check(note)) | ||
| and self.scrubber.scrub_resource(note, scrub_attachments=False, keep_stats=False) | ||
| ) | ||
| if not can_process: | ||
| continue | ||
| 
     | 
||
| try: | ||
| clinical_note = await fhir.get_clinical_note(self.task_config.client, docref) | ||
| note_text = await fhir.get_clinical_note(self.task_config.client, note) | ||
| except cfs.BadAuthArguments as exc: | ||
| if not warned_connection_error: | ||
| # Only warn user about a misconfiguration once per task. | ||
| # It's not fatal because it might be intentional (partially inlined DocRefs | ||
| # and the other DocRefs are known failures - BCH hits this with Cerner data). | ||
| print(exc, file=sys.stderr) | ||
| warned_connection_error = True | ||
| self.add_error(orig_docref) | ||
| self.add_error(orig_note) | ||
| continue | ||
| except Exception as exc: | ||
| logging.warning("Error getting text for docref %s: %s", docref["id"], exc) | ||
| self.add_error(orig_docref) | ||
| orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}" | ||
| logging.warning("Error getting text for note %s: %s", orig_note_ref, exc) | ||
| self.add_error(orig_note) | ||
| continue | ||
| 
     | 
||
| yield orig_docref, docref, clinical_note | ||
| yield orig_note, note, note_text | ||
| 
     | 
||
| @staticmethod | ||
| def remove_trailing_whitespace(note: str) -> str: | ||
| def remove_trailing_whitespace(note_text: str) -> str: | ||
| """Sometimes NLP can be mildly confused by trailing whitespace, so this removes it""" | ||
| return TRAILING_WHITESPACE.sub("", note) | ||
| return TRAILING_WHITESPACE.sub("", note_text) | ||
| 
     | 
||
| 
     | 
||
| class BaseOpenAiTask(BaseNlpTask): | ||
| 
        
          
        
         | 
    @@ -139,59 +139,52 @@ async def init_check(cls) -> None: | |
| async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator: | ||
| client = self.client_class() | ||
| 
     | 
||
| async for orig_docref, docref, orig_clinical_note in self.read_notes(progress=progress): | ||
| async for orig_note, note, orig_note_text in self.read_notes(progress=progress): | ||
| try: | ||
| docref_id, encounter_id, subject_id = nlp.get_docref_info(docref) | ||
| note_ref, encounter_id, subject_id = nlp.get_note_info(note) | ||
| except KeyError as exc: | ||
| logging.warning(exc) | ||
| self.add_error(orig_docref) | ||
| self.add_error(orig_note) | ||
| continue | ||
| 
     | 
||
| clinical_note = self.remove_trailing_whitespace(orig_clinical_note) | ||
| note_text = self.remove_trailing_whitespace(orig_note_text) | ||
| orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}" | ||
| 
     | 
||
| try: | ||
| completion_class = chat.ParsedChatCompletion[self.response_format] | ||
| response = await nlp.cache_wrapper( | ||
| self.task_config.dir_phi, | ||
| f"{self.name}_v{self.task_version}", | ||
| clinical_note, | ||
| note_text, | ||
| lambda x: completion_class.model_validate_json(x), # from file | ||
| lambda x: x.model_dump_json( # to file | ||
| indent=None, round_trip=True, exclude_unset=True, by_alias=True | ||
| ), | ||
| client.prompt, | ||
| self.system_prompt, | ||
| self.get_user_prompt(clinical_note), | ||
| self.get_user_prompt(note_text), | ||
| self.response_format, | ||
| ) | ||
| except openai.APIError as exc: | ||
| logging.warning( | ||
| f"Could not connect to NLP server for DocRef {orig_docref['id']}: {exc}" | ||
| ) | ||
| self.add_error(orig_docref) | ||
| continue | ||
| except pydantic.ValidationError as exc: | ||
| logging.warning( | ||
| f"Could not process answer from NLP server for DocRef {orig_docref['id']}: {exc}" | ||
| ) | ||
| self.add_error(orig_docref) | ||
| except Exception as exc: | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm surprised ruff didn't flag this - do you have the generic exception rule turned off? I think it's fine, i was just expecting a manual comment to disable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think our ruff config flags generic exceptions - I just confirmed it will flag bare excepts though ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah ok - yeah i guess i buy that as 'I am explicitly stating my intent here'. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bare excepts also have the problem of catching  But often linting tools don't like the generic version either for some programming purity reasons like "you should know what you're catching". But in the real world, you so often just want to say "hey I don't care what happened, I want to handle it and log it".  | 
||
| logging.warning(f"NLP failed for {orig_note_ref}: {exc}") | ||
| self.add_error(orig_note) | ||
| continue | ||
| 
     | 
||
| choice = response.choices[0] | ||
| 
     | 
||
| if choice.finish_reason != "stop" or not choice.message.parsed: | ||
| logging.warning( | ||
| f"NLP server response didn't complete for DocRef {orig_docref['id']}: " | ||
| f"NLP server response didn't complete for {orig_note_ref}: " | ||
| f"{choice.finish_reason}" | ||
| ) | ||
| self.add_error(orig_docref) | ||
| self.add_error(orig_note) | ||
| continue | ||
| 
     | 
||
| parsed = choice.message.parsed.model_dump(mode="json") | ||
| self.post_process(parsed, orig_clinical_note, orig_docref) | ||
| self.post_process(parsed, orig_note_text, orig_note) | ||
| 
     | 
||
| yield { | ||
| "note_ref": f"DocumentReference/{docref_id}", | ||
| "note_ref": note_ref, | ||
| "encounter_ref": f"Encounter/{encounter_id}", | ||
| "subject_ref": f"Patient/{subject_id}", | ||
| # Since this date is stored as a string, use UTC time for easy comparisons | ||
| 
        
          
        
         | 
    @@ -202,11 +195,11 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task | |
| } | ||
| 
     | 
||
| @classmethod | ||
| def get_user_prompt(cls, clinical_note: str) -> str: | ||
| def get_user_prompt(cls, note_text: str) -> str: | ||
| prompt = cls.user_prompt or "%CLINICAL-NOTE%" | ||
| return prompt.replace("%CLINICAL-NOTE%", clinical_note) | ||
| return prompt.replace("%CLINICAL-NOTE%", note_text) | ||
| 
     | 
||
| def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None: | ||
| def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None: | ||
| """Subclasses can fill this out if they like""" | ||
| 
     | 
||
| @classmethod | ||
| 
          
            
          
           | 
    @@ -261,7 +254,7 @@ class BaseOpenAiTaskWithSpans(BaseOpenAiTask): | |
| It assumes the field is named "spans" in the top level of the pydantic model. | ||
| """ | ||
| 
     | 
||
| def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None: | ||
| def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None: | ||
| new_spans = [] | ||
| missed_some = False | ||
| 
     | 
||
| 
        
          
        
         | 
    @@ -278,18 +271,18 @@ def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) | |
| span = ESCAPED_WHITESPACE.sub(r"\\s+", span) | ||
| 
     | 
||
| found = False | ||
| for match in re.finditer(span, orig_clinical_note, re.IGNORECASE): | ||
| for match in re.finditer(span, orig_note_text, re.IGNORECASE): | ||
| found = True | ||
| new_spans.append(match.span()) | ||
| if not found: | ||
| missed_some = True | ||
| logging.warning( | ||
| "Could not match span received from NLP server for " | ||
| f"DocRef {orig_docref['id']}: {orig_span}" | ||
| f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}" | ||
| ) | ||
| 
     | 
||
| if missed_some: | ||
| self.add_error(orig_docref) | ||
| self.add_error(orig_note) | ||
| 
     | 
||
| parsed["spans"] = new_spans | ||
| 
     | 
||
| 
          
            
          
           | 
    ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fact that this change was needed here but not before, tells me we are regression testing the public release instead of the current branch. We have code to build the branch's docker image... not sure why it's not being used here. I can dig into it later, but just an FYI - it's on my todo list.