Skip to content

Commit 593cf0f

Browse files
authored
Merge pull request #467 from smart-on-fhir/mikix/note-ordering
irae: add an early-exit condition for NLP if patient is in an end state
2 parents e56cc69 + d9d207a commit 593cf0f

File tree

13 files changed

+321
-72
lines changed

13 files changed

+321
-72
lines changed

cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ async def covid_symptoms_extract(
3232
:return: list of NLP results encoded as FHIR observations
3333
"""
3434
try:
35-
note_ref, encounter_id, subject_id = nlp.get_note_info(docref)
35+
note_ref, encounter_id, subject_ref = nlp.get_note_info(docref)
3636
except KeyError as exc:
3737
logging.warning(exc)
3838
return None
@@ -108,7 +108,7 @@ def _make_covid_symptom_row(row_id: str, match: dict | None) -> dict:
108108
"id": row_id,
109109
"docref_id": docref_id,
110110
"encounter_id": encounter_id,
111-
"subject_id": subject_id,
111+
"subject_id": subject_ref.split("/")[-1],
112112
"generated_on": timestamp,
113113
"task_version": task_version,
114114
"match": match,

cumulus_etl/etl/studies/irae/irae_tasks.py

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Define tasks for the irae study"""
22

3+
import datetime
4+
import logging
5+
from collections.abc import Generator, Iterator
36
from enum import StrEnum
47

8+
import cumulus_fhir_support as cfs
59
from pydantic import BaseModel, Field
610

7-
from cumulus_etl import nlp
11+
from cumulus_etl import common, nlp, store
812
from cumulus_etl.etl import tasks
913

1014

@@ -453,61 +457,124 @@ class BaseIraeTask(tasks.BaseModelTaskWithSpans):
453457
)
454458

455459

456-
class IraeDonorGpt4oTask(BaseIraeTask):
460+
# Shared base for the donor-focused irae tasks: pins the response format in one place, so each
# per-model subclass only has to declare its own name and client class.
class BaseDonorIraeTask(BaseIraeTask):
461+
response_format = KidneyTransplantDonorGroupAnnotation
462+
463+
464+
class BaseLongitudinalIraeTask(BaseIraeTask):
    """
    Shared base for the longitudinal irae tasks.

    Notes are read from disk oldest-to-newest. Once a note's annotation asserts a confirmed
    graft failure or a deceased patient, every later note for that same subject is skipped.
    """

    response_format = KidneyTransplantLongitudinalAnnotation

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Subject refs (as returned by nlp.get_note_subject_ref) whose remaining notes we skip
        self.subject_refs_to_skip = set()

    @staticmethod
    def _sortable_note_date(date: datetime.datetime | None) -> datetime.datetime:
        """
        Returns a timezone-aware stand-in for `date` that can be ordered against any other.

        Python refuses to order naive datetimes against aware ones (TypeError), and our
        "no date" fallback of datetime.max is naive. So we treat every naive value as UTC.
        NOTE(review): assumes nlp.get_note_date may return either naive or aware values,
        depending on whether the source FHIR string carried a timezone - confirm.

        :param date: the parsed note date, or None if the note has no usable date
        :return: an aware datetime; undated notes map to the latest possible value
        """
        if date is None:
            date = datetime.datetime.max  # undated notes sort last
        if date.tzinfo is None or date.utcoffset() is None:
            date = date.replace(tzinfo=datetime.timezone.utc)
        return date

    @staticmethod
    def ndjson_in_order(input_root: store.Root, resource: str) -> Generator[dict]:
        """
        Yields every row of the given resource type, ordered from oldest note to newest.

        :param input_root: where to find the input files
        :param resource: the resource type to read
        :return: a generator of parsed rows; rows without a date come last
        """
        # To avoid loading all the notes into memory, we'll first go through each note, and keep
        # track of their byte offset on disk and their date. Then we'll grab each from disk in
        # order.

        # Get a list of all files we're going to be working with here
        filenames = common.ls_resources(input_root, {resource})

        # Go through all files, keeping a record of each line's dates and offsets.
        note_info = []
        for file_index, path in enumerate(filenames):
            for row in cfs.read_multiline_json_with_details(path, fsspec_fs=input_root.fs):
                date = BaseLongitudinalIraeTask._sortable_note_date(
                    nlp.get_note_date(row["json"])
                )
                note_info.append((date, file_index, row["byte_offset"]))

        # Now yield each note again in order, reading each from disk.
        # Ties on date fall back to file order and then position within the file.
        note_info.sort()
        for _date, file_index, offset in note_info:
            rows = cfs.read_multiline_json_with_details(
                filenames[file_index],
                offset=offset,
                fsspec_fs=input_root.fs,
            )
            # StopIteration errors shouldn't happen here, because we just went through these
            # files above, but just to be safe, we'll gracefully intercept it.
            try:
                yield next(rows)["json"]
            except StopIteration:  # pragma: no cover
                logging.warning(
                    f"File '{filenames[file_index]}' changed while reading, skipping some notes."
                )
                continue

    # Override the read-from-disk portion, so we can order notes in oldest-to-newest order
    def read_ndjson_from_disk(self, input_root: store.Root, resource: str) -> Iterator[dict]:
        """Yields rows of the given resource type in oldest-to-newest order."""
        yield from self.ndjson_in_order(input_root, resource)

    def should_skip(self, orig_note: dict) -> bool:
        """Returns True if this note's subject already hit an end state (or the parent says so)."""
        subject_ref = nlp.get_note_subject_ref(orig_note)
        return subject_ref in self.subject_refs_to_skip or super().should_skip(orig_note)

    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
        """
        Runs the parent's post-processing, then records whether this subject hit an end state.

        :param parsed: the parsed NLP response for this note
        :param orig_note_text: the original note text
        :param orig_note: the original note resource
        """
        super().post_process(parsed, orig_note_text, orig_note)

        # If we have an annotation that asserts a graft failure or deceased,
        # we can stop processing charts for that patient, to avoid pointless NLP requests.

        graft_failure = parsed.get("graft_failure_mention", {})
        is_failed = (
            graft_failure.get("has_mention")
            and graft_failure.get("graft_failure") == GraftFailurePresent.CONFIRMED
        )

        deceased = parsed.get("deceased_mention", {})
        is_deceased = deceased.get("has_mention") and deceased.get("deceased")

        if is_failed or is_deceased:
            # Notes without a usable subject ref can't be matched across notes, so ignore them
            if subject_ref := nlp.get_note_subject_ref(orig_note):
                self.subject_refs_to_skip.add(subject_ref)
531+
532+
533+
class IraeDonorGpt4oTask(BaseDonorIraeTask):
457534
name = "irae__nlp_donor_gpt4o"
458535
client_class = nlp.Gpt4oModel
459-
response_format = KidneyTransplantDonorGroupAnnotation
460536

461537

462-
class IraeLongitudinalGpt4oTask(BaseIraeTask):
538+
class IraeLongitudinalGpt4oTask(BaseLongitudinalIraeTask):
463539
name = "irae__nlp_gpt4o"
464540
client_class = nlp.Gpt4oModel
465-
response_format = KidneyTransplantLongitudinalAnnotation
466541

467542

468-
class IraeDonorGpt5Task(BaseIraeTask):
543+
class IraeDonorGpt5Task(BaseDonorIraeTask):
469544
name = "irae__nlp_donor_gpt5"
470545
client_class = nlp.Gpt5Model
471-
response_format = KidneyTransplantDonorGroupAnnotation
472546

473547

474-
class IraeLongitudinalGpt5Task(BaseIraeTask):
548+
class IraeLongitudinalGpt5Task(BaseLongitudinalIraeTask):
475549
name = "irae__nlp_gpt5"
476550
client_class = nlp.Gpt5Model
477-
response_format = KidneyTransplantLongitudinalAnnotation
478551

479552

480-
class IraeDonorGptOss120bTask(BaseIraeTask):
553+
class IraeDonorGptOss120bTask(BaseDonorIraeTask):
481554
name = "irae__nlp_donor_gpt_oss_120b"
482555
client_class = nlp.GptOss120bModel
483-
response_format = KidneyTransplantDonorGroupAnnotation
484556

485557

486-
class IraeLongitudinalGptOss120bTask(BaseIraeTask):
558+
class IraeLongitudinalGptOss120bTask(BaseLongitudinalIraeTask):
487559
name = "irae__nlp_gpt_oss_120b"
488560
client_class = nlp.GptOss120bModel
489-
response_format = KidneyTransplantLongitudinalAnnotation
490561

491562

492-
class IraeDonorLlama4ScoutTask(BaseIraeTask):
563+
class IraeDonorLlama4ScoutTask(BaseDonorIraeTask):
493564
name = "irae__nlp_donor_llama4_scout"
494565
client_class = nlp.Llama4ScoutModel
495-
response_format = KidneyTransplantDonorGroupAnnotation
496566

497567

498-
class IraeLongitudinalLlama4ScoutTask(BaseIraeTask):
568+
class IraeLongitudinalLlama4ScoutTask(BaseLongitudinalIraeTask):
499569
name = "irae__nlp_llama4_scout"
500570
client_class = nlp.Llama4ScoutModel
501-
response_format = KidneyTransplantLongitudinalAnnotation
502571

503572

504-
class IraeDonorClaudeSonnet45Task(BaseIraeTask):
573+
class IraeDonorClaudeSonnet45Task(BaseDonorIraeTask):
505574
name = "irae__nlp_donor_claude_sonnet45"
506575
client_class = nlp.ClaudeSonnet45Model
507-
response_format = KidneyTransplantDonorGroupAnnotation
508576

509577

510-
class IraeLongitudinalClaudeSonnet45Task(BaseIraeTask):
578+
class IraeLongitudinalClaudeSonnet45Task(BaseLongitudinalIraeTask):
511579
name = "irae__nlp_claude_sonnet45"
512580
client_class = nlp.ClaudeSonnet45Model
513-
response_format = KidneyTransplantLongitudinalAnnotation

cumulus_etl/etl/tasks/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,9 @@ def _write_errors(self, batch: formats.Batch, batch_index: int) -> None:
371371
#
372372
##########################################################################################
373373

374+
# Hook for the read-from-disk step of read_ndjson(): subclasses may override it to change how
# (or in what order) the rows of `resource` are produced. The default simply defers to
# common.read_resource_ndjson.
def read_ndjson_from_disk(self, input_root: store.Root, resource: str) -> Iterator[dict]:
375+
yield from common.read_resource_ndjson(input_root, resource)
376+
374377
def read_ndjson(
375378
self, *, progress: rich.progress.Progress | None = None, resources: list[str] | None = None
376379
) -> Iterator[dict]:
@@ -399,7 +402,7 @@ def read_ndjson(
399402
# You may want to process all linked resources first, and only then the "real" resource
400403
# (like we do for Medications and MedicationRequests).
401404
for resource in resources:
402-
for line in common.read_resource_ndjson(input_root, resource):
405+
for line in self.read_ndjson_from_disk(input_root, resource):
403406
yield line
404407
if progress:
405408
progress.advance(row_task)

cumulus_etl/etl/tasks/nlp_task.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,11 @@ async def init_check(cls) -> None:
142142

143143
async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
144144
async for orig_note, note, orig_note_text in self.read_notes(progress=progress):
145+
if self.should_skip(orig_note):
146+
continue
147+
145148
try:
146-
note_ref, encounter_id, subject_id = nlp.get_note_info(note)
149+
note_ref, encounter_id, subject_ref = nlp.get_note_info(note)
147150
except KeyError as exc:
148151
logging.warning(exc)
149152
self.add_error(orig_note)
@@ -172,7 +175,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
172175
yield {
173176
"note_ref": note_ref,
174177
"encounter_ref": f"Encounter/{encounter_id}",
175-
"subject_ref": f"Patient/{subject_id}",
178+
"subject_ref": subject_ref,
176179
# Since this date is stored as a string, use UTC time for easy comparisons
177180
"generated_on": common.datetime_now().isoformat(),
178181
"task_version": self.task_version,
@@ -218,6 +221,10 @@ def get_user_prompt(cls, note_text: str) -> str:
218221
prompt = cls.user_prompt or "%CLINICAL-NOTE%"
219222
return prompt.replace("%CLINICAL-NOTE%", note_text)
220223

224+
def should_skip(self, orig_note: dict) -> bool:
225+
"""
Subclasses can fill this out if they like, to skip notes.

read_entries() calls this for each note before any other handling of it; returning True
drops the note without recording an error. The default never skips.
"""
226+
return False
227+
221228
def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
222229
"""Subclasses can fill this out if they like"""
223230

cumulus_etl/formats/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ def get_format_class(name: str) -> type[Format]:
1616
try:
1717
return classes[name]
1818
except KeyError as exc:
19-
raise ValueError(f"Unknown output format name {name}.") from exc
19+
raise ValueError(f"Unknown output format name '{name}'.") from exc

cumulus_etl/nlp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
set_nlp_provider,
1515
)
1616
from .selection import CsvMatcher, add_note_selection, get_note_filter, query_athena_table
17-
from .utils import cache_wrapper, get_note_info, is_note_valid
17+
from .utils import cache_wrapper, get_note_date, get_note_info, get_note_subject_ref, is_note_valid
1818
from .watcher import (
1919
check_ctakes,
2020
check_negation_cnlpt,

cumulus_etl/nlp/utils.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Misc NLP functions"""
22

3+
import datetime
34
import hashlib
45
import os
56
from collections.abc import Callable
@@ -35,7 +36,7 @@ async def is_note_valid(codebook: deid.Codebook, note: dict) -> bool:
3536

3637
def get_note_info(note: dict) -> tuple[str, str, str]:
3738
"""
38-
Returns note_ref, encounter_id, subject_id for the given DocRef/DxReport.
39+
Returns note_ref, encounter_id, subject_ref for the given DocRef/DxReport.
3940
4041
Raises KeyError if any of them aren't present.
4142
"""
@@ -44,10 +45,43 @@ def get_note_info(note: dict) -> tuple[str, str, str]:
4445
if not encounters: # check for dxreport encounter field
4546
encounters = [note["encounter"]] if "encounter" in note else []
4647
if not encounters:
47-
raise KeyError(f"No encounters for note {note_ref}")
48+
raise KeyError(f"No encounters for note '{note_ref}'")
4849
_, encounter_id = fhir.unref_resource(encounters[0])
49-
_, subject_id = fhir.unref_resource(note["subject"])
50-
return note_ref, encounter_id, subject_id
50+
subject_ref = get_note_subject_ref(note)
51+
if not subject_ref:
52+
raise KeyError(f"No subject for note '{note_ref}'")
53+
return note_ref, encounter_id, subject_ref
54+
55+
56+
def get_note_subject_ref(note: dict) -> str | None:
    """
    Builds a "Type/id" subject reference for a note, for matching subjects across notes.

    Returns None when the subject can't be resolved, or resolves without a resource type.
    """
    try:
        ref_type, ref_id = fhir.unref_resource(note.get("subject"))
    except ValueError:
        return None

    # A missing type means a contained ref or some other oddity that won't match across notes
    return f"{ref_type}/{ref_id}" if ref_type else None
68+
69+
70+
def get_note_date(note: dict) -> datetime.datetime | None:
    """
    Finds the best available date for a note.

    Clinical dates are tried first, then administrative ones; the first that parses wins.
    Returns None for other resource types, or when no field holds a parseable date.
    """
    kind = note.get("resourceType")

    # Each lookup is deferred, so a later field is only touched if the earlier ones didn't parse
    if kind == "DiagnosticReport":
        lookups = (
            lambda: note.get("effectiveDateTime"),
            lambda: note.get("effectivePeriod", {}).get("start"),
            lambda: note.get("issued"),
        )
    elif kind == "DocumentReference":
        lookups = (
            lambda: note.get("context", {}).get("period", {}).get("start"),
            lambda: note.get("date"),
        )
    else:
        lookups = ()

    for lookup in lookups:
        parsed = fhir.parse_datetime(lookup())
        if parsed:
            return parsed
    return None
5185

5286

5387
async def cache_wrapper(

cumulus_etl/upload_notes/cli.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,6 @@ async def gather_resources(
5555
)
5656

5757

58-
def datetime_from_resource(resource: dict) -> datetime.datetime | None:
59-
"""Returns the date of a resource - preferring clinical dates, then administrative ones"""
60-
if resource["resourceType"] == "DiagnosticReport":
61-
if time := fhir.parse_datetime(resource.get("effectiveDateTime")):
62-
return time
63-
if time := fhir.parse_datetime(resource.get("effectivePeriod", {}).get("start")):
64-
return time
65-
if time := fhir.parse_datetime(resource.get("issued")):
66-
return time
67-
elif resource["resourceType"] == "DocumentReference":
68-
if time := fhir.parse_datetime(resource.get("context", {}).get("period", {}).get("start")):
69-
return time
70-
if time := fhir.parse_datetime(resource.get("date")):
71-
return time
72-
return None
73-
74-
7558
def _get_encounter_id(resource: dict) -> str | None:
7659
encounter_ref = None
7760
if resource["resourceType"] == "DiagnosticReport":
@@ -158,7 +141,7 @@ async def read_notes_from_ndjson(
158141
doc_spans=doc_spans,
159142
title=title,
160143
text=text,
161-
date=datetime_from_resource(resource),
144+
date=nlp.get_note_date(resource),
162145
)
163146
)
164147

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies = [
1010
"aiobotocore[boto3] >= 2.14.0",
1111
"boto3 >= 1.34.131",
1212
"ctakesclient >= 5.1",
13-
"cumulus-fhir-support >= 1.6",
13+
"cumulus-fhir-support >= 1.8",
1414
"delta-spark >= 4, < 5",
1515
"fsspec[http,s3]",
1616
"httpx",

tests/etl/test_etl_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class TestJobContext(utils.AsyncTestCase):
1414
def test_missing_file_context(self):
1515
context = JobContext("nope")
1616
self.assertEqual({}, context.as_json())
17+
self.assertIsNone(context.last_successful_datetime)
1718

1819
def test_save_and_load(self):
1920
with tempfile.NamedTemporaryFile(mode="w+") as f:

0 commit comments

Comments
 (0)