
Commit 03929c6

Merge pull request #435 from smart-on-fhir/mikix/nlp-dxreports
feat: have NLP tasks read in DxReports as well as DocRefs
2 parents: 3cd9947 + 5160616

File tree

13 files changed (+149 / -104 lines)

.github/workflows/ci.yaml

Lines changed: 1 addition & 2 deletions
@@ -102,11 +102,10 @@ jobs:
 docker compose run --rm \
   --volume $DATADIR:/in \
   cumulus-etl \
+  nlp \
   /in/input \
   /in/run-output \
   /in/phi \
-  --export-group nlp-test \
-  --export-timestamp 2024-08-29 \
   --output-format=ndjson \
   --task covid_symptom__nlp_results

cumulus_etl/etl/nlp/cli.py

Lines changed: 3 additions & 4 deletions
@@ -112,10 +112,9 @@ def get_cohort_filter(args: argparse.Namespace) -> Callable[[deid.Codebook, dict
 
     def res_filter(codebook: deid.Codebook, resource: dict) -> bool:
         match resource["resourceType"]:
-            # TODO: uncomment once we support DxReport NLP (coming soon)
-            # case "DiagnosticReport":
-            # id_pool = dxreport_ids
-            # patient_ref = resource.get("subject", {}).get("reference")
+            case "DiagnosticReport":
+                id_pool = dxreport_ids
+                patient_ref = resource.get("subject", {}).get("reference")
             case "DocumentReference":
                 id_pool = docref_ids
                 patient_ref = resource.get("subject", {}).get("reference")

cumulus_etl/etl/pipeline.py

Lines changed: 16 additions & 11 deletions
@@ -120,6 +120,7 @@ async def check_available_resources(
     requested_resources: set[str],
     args: argparse.Namespace,
     is_default_tasks: bool,
+    nlp: bool,
 ) -> set[str]:
     # Here we try to reconcile which resources the user requested and which resources are actually
     # available in the input root.

@@ -138,25 +139,28 @@
     if detected is None:
         return requested_resources  # likely we haven't run bulk export yet
 
-    if missing_resources := requested_resources - detected:
+    missing_resources = requested_resources - detected
+    available_resources = requested_resources & detected
+
+    if nlp and available_resources:
+        # As long as there is any resource for NLP to read from, we'll take it
+        return available_resources
+
+    if missing_resources:
         for resource in sorted(missing_resources):
             # Log the same message we would print if in common.py if we ran tasks anyway
             logging.warning("No %s files found in %s", resource, loader.root.path)
 
         if is_default_tasks:
-            requested_resources -= missing_resources  # scope down to detected resources
-            if not requested_resources:
-                errors.fatal(
-                    "No supported resources found.",
-                    errors.MISSING_REQUESTED_RESOURCES,
-                )
+            if not available_resources:
+                errors.fatal("No supported resources found.", errors.MISSING_REQUESTED_RESOURCES)
         else:
             msg = "Required resources not found.\n"
             if has_allow_missing:
                 msg += "Add --allow-missing-resources to run related tasks anyway with no input."
             errors.fatal(msg, errors.MISSING_REQUESTED_RESOURCES)
 
-    return requested_resources
+    return available_resources
 
 
 async def run_pipeline(

@@ -191,8 +195,8 @@ async def run_pipeline(
     for task in selected_tasks:
         await task.init_check()
 
-    # Grab a list of all required resource types for the tasks we are running
-    required_resources = set(t.resource for t in selected_tasks)
+    # Combine all task resource sets into one big set of required resources
+    required_resources = set().union(*(t.get_resource_types() for t in selected_tasks))
 
     # Create a client to talk to a FHIR server.
     # This is useful even if we aren't doing a bulk export, because some resources like

@@ -214,9 +218,10 @@
         args=args,
         is_default_tasks=is_default_tasks,
         requested_resources=required_resources,
+        nlp=nlp,
     )
     # Drop any tasks that we didn't find resources for
-    selected_tasks = [t for t in selected_tasks if t.resource in required_resources]
+    selected_tasks = [t for t in selected_tasks if t.get_resource_types() & required_resources]
 
     # Load resources from a remote location (like s3), convert from i2b2, or do a bulk export
     loader_results = await config_loader.load_resources(required_resources)

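A quick standalone sketch of the new set-based resource handling in run_pipeline(), simplified and using made-up task resource sets rather than the real task classes:

    # Each task now exposes a set of resource types; union them into one list to load/export.
    task_resources = [{"Patient"}, {"DiagnosticReport", "DocumentReference"}]
    required_resources = set().union(*task_resources)
    # {"Patient", "DiagnosticReport", "DocumentReference"}

    # A task is kept only if it shares at least one resource type with what was actually found.
    detected = {"DocumentReference"}
    kept = [resources for resources in task_resources if resources & detected]
    # [{"DiagnosticReport", "DocumentReference"}]
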
cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py

Lines changed: 6 additions & 3 deletions
@@ -32,7 +32,7 @@ async def covid_symptoms_extract(
     :return: list of NLP results encoded as FHIR observations
     """
     try:
-        docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
+        note_ref, encounter_id, subject_id = nlp.get_note_info(docref)
     except KeyError as exc:
         logging.warning(exc)
         return None

@@ -62,7 +62,7 @@
         )
     except Exception as exc:
         logging.warning(
-            "Could not extract symptoms for docref %s (%s): %s", docref_id, type(exc).__name__, exc
+            "Could not extract symptoms for %s (%s): %s", note_ref, type(exc).__name__, exc
         )
         return None
 

@@ -95,10 +95,13 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText):
         )
     except Exception as exc:
         logging.warning(
-            "Could not check polarity for docref %s (%s): %s", docref_id, type(exc).__name__, exc
+            "Could not check polarity for %s (%s): %s", note_ref, type(exc).__name__, exc
         )
         return None
 
+    # We only look at docrefs - get just the ID for use in the symptom fields
+    docref_id = note_ref.removeprefix("DocumentReference/")
+
     # Helper to make a single row (match_value is None if there were no found symptoms at all)
     def _make_covid_symptom_row(row_id: str, match: dict | None) -> dict:
         return {

cumulus_etl/etl/studies/covid_symptom/covid_tasks.py

Lines changed: 4 additions & 1 deletion
@@ -65,8 +65,11 @@ def is_ed_coding(coding):
     return coding.get("code") in ED_CODES.get(coding.get("system"), {})
 
 
-def is_ed_docref(docref):
+def is_ed_docref(docref) -> bool:
     """Returns true if this is a coding for an emergency department note"""
+    if docref["resourceType"] != "DocumentReference":
+        return False
+
     # We check both type and category for safety -- we aren't sure yet how EHRs are using these fields.
     codings = list(
         itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])])

cumulus_etl/etl/tasks/base.py

Lines changed: 11 additions & 3 deletions
@@ -89,7 +89,7 @@ class EtlTask:
     # Properties:
     name: ClassVar[str] = None  # task & table name
     # incoming resource that this task operates on (will be included in bulk exports etc)
-    resource: ClassVar[str] = None
+    resource: ClassVar[str | set[str]] = None
     tags: ClassVar[set[str]] = []
     # whether this task needs bulk MS tool de-id run on its inputs (NLP tasks usually don't)
     needs_bulk_deid: ClassVar[bool] = True

@@ -378,10 +378,11 @@ def read_ndjson(
 
         If `resources` is provided, those resources will be read (in the provided order).
        That is, ["Condition", "Encounter"] will first read all Conditions, then all Encounters.
-        If `resources` is not provided, the task's main resource (self.resource) will be used.
+        If `resources` is not provided, the task's main resources (via self.get_resource_types())
+        will be used.
         """
         input_root = store.Root(self.task_config.dir_input)
-        resources = resources or [self.resource]
+        resources = resources or sorted(self.get_resource_types())
 
         if progress:
             # Make new task to track processing of rows

@@ -472,3 +473,10 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche
         if resource_type:
             return cfs.pyarrow_schema_from_rows(resource_type, rows)
         return None
+
+    @classmethod
+    def get_resource_types(cls) -> set[str]:
+        """Abstracts whether the class's resource field is a str or a set of strings."""
+        if isinstance(cls.resource, str):
+            return {cls.resource}
+        return set(cls.resource)

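A minimal, self-contained sketch of how the new get_resource_types() helper behaves — the stand-in classes below are hypothetical, only the helper logic comes from this commit:

    from typing import ClassVar

    class _TaskSketch:
        # Mirrors EtlTask.resource, which may now be a single type or a set of types.
        resource: ClassVar[str | set[str]] = None

        @classmethod
        def get_resource_types(cls) -> set[str]:
            return {cls.resource} if isinstance(cls.resource, str) else set(cls.resource)

    class _DocRefOnly(_TaskSketch):
        resource = "DocumentReference"

    class _AnyNote(_TaskSketch):
        resource = {"DiagnosticReport", "DocumentReference"}

    assert _DocRefOnly.get_resource_types() == {"DocumentReference"}
    assert _AnyNote.get_resource_types() == {"DiagnosticReport", "DocumentReference"}
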
cumulus_etl/etl/tasks/nlp_task.py

Lines changed: 37 additions & 44 deletions
@@ -11,7 +11,6 @@
 from typing import ClassVar
 
 import cumulus_fhir_support as cfs
-import openai
 import pyarrow
 import pydantic
 import rich.progress

@@ -27,7 +26,7 @@
 class BaseNlpTask(tasks.EtlTask):
     """Base class for any clinical-notes-based NLP task."""
 
-    resource: ClassVar = "DocumentReference"
+    resource: ClassVar = {"DiagnosticReport", "DocumentReference"}
     needs_bulk_deid: ClassVar = False
 
     # You may want to override these in your subclass

@@ -80,44 +79,45 @@ async def read_notes(
         """
         Iterate through clinical notes.
 
-        :returns: a tuple of original-docref, scrubbed-docref, and clinical note
+        :returns: a tuple of original-resource, scrubbed-resource, and note text
         """
         warned_connection_error = False
 
-        note_filter = self.task_config.resource_filter or nlp.is_docref_valid
+        note_filter = self.task_config.resource_filter or nlp.is_note_valid
 
-        for docref in self.read_ndjson(progress=progress):
-            orig_docref = copy.deepcopy(docref)
+        for note in self.read_ndjson(progress=progress):
+            orig_note = copy.deepcopy(note)
             can_process = (
-                note_filter(self.scrubber.codebook, docref)
-                and (doc_check is None or doc_check(docref))
-                and self.scrubber.scrub_resource(docref, scrub_attachments=False, keep_stats=False)
+                note_filter(self.scrubber.codebook, note)
+                and (doc_check is None or doc_check(note))
+                and self.scrubber.scrub_resource(note, scrub_attachments=False, keep_stats=False)
             )
             if not can_process:
                 continue
 
             try:
-                clinical_note = await fhir.get_clinical_note(self.task_config.client, docref)
+                note_text = await fhir.get_clinical_note(self.task_config.client, note)
             except cfs.BadAuthArguments as exc:
                 if not warned_connection_error:
                     # Only warn user about a misconfiguration once per task.
                     # It's not fatal because it might be intentional (partially inlined DocRefs
                     # and the other DocRefs are known failures - BCH hits this with Cerner data).
                     print(exc, file=sys.stderr)
                     warned_connection_error = True
-                self.add_error(orig_docref)
+                self.add_error(orig_note)
                 continue
             except Exception as exc:
-                logging.warning("Error getting text for docref %s: %s", docref["id"], exc)
-                self.add_error(orig_docref)
+                orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"
+                logging.warning("Error getting text for note %s: %s", orig_note_ref, exc)
+                self.add_error(orig_note)
                 continue
 
-            yield orig_docref, docref, clinical_note
+            yield orig_note, note, note_text
 
     @staticmethod
-    def remove_trailing_whitespace(note: str) -> str:
+    def remove_trailing_whitespace(note_text: str) -> str:
         """Sometimes NLP can be mildly confused by trailing whitespace, so this removes it"""
-        return TRAILING_WHITESPACE.sub("", note)
+        return TRAILING_WHITESPACE.sub("", note_text)
 
 
 class BaseOpenAiTask(BaseNlpTask):

@@ -139,59 +139,52 @@ async def init_check(cls) -> None:
     async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
         client = self.client_class()
 
-        async for orig_docref, docref, orig_clinical_note in self.read_notes(progress=progress):
+        async for orig_note, note, orig_note_text in self.read_notes(progress=progress):
             try:
-                docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
+                note_ref, encounter_id, subject_id = nlp.get_note_info(note)
             except KeyError as exc:
                 logging.warning(exc)
-                self.add_error(orig_docref)
+                self.add_error(orig_note)
                 continue
 
-            clinical_note = self.remove_trailing_whitespace(orig_clinical_note)
+            note_text = self.remove_trailing_whitespace(orig_note_text)
+            orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"
 
             try:
                 completion_class = chat.ParsedChatCompletion[self.response_format]
                 response = await nlp.cache_wrapper(
                     self.task_config.dir_phi,
                     f"{self.name}_v{self.task_version}",
-                    clinical_note,
+                    note_text,
                     lambda x: completion_class.model_validate_json(x),  # from file
                     lambda x: x.model_dump_json(  # to file
                         indent=None, round_trip=True, exclude_unset=True, by_alias=True
                     ),
                     client.prompt,
                     self.system_prompt,
-                    self.get_user_prompt(clinical_note),
+                    self.get_user_prompt(note_text),
                     self.response_format,
                 )
-            except openai.APIError as exc:
-                logging.warning(
-                    f"Could not connect to NLP server for DocRef {orig_docref['id']}: {exc}"
-                )
-                self.add_error(orig_docref)
-                continue
-            except pydantic.ValidationError as exc:
-                logging.warning(
-                    f"Could not process answer from NLP server for DocRef {orig_docref['id']}: {exc}"
-                )
-                self.add_error(orig_docref)
+            except Exception as exc:
+                logging.warning(f"NLP failed for {orig_note_ref}: {exc}")
+                self.add_error(orig_note)
                 continue
 
             choice = response.choices[0]
 
             if choice.finish_reason != "stop" or not choice.message.parsed:
                 logging.warning(
-                    f"NLP server response didn't complete for DocRef {orig_docref['id']}: "
+                    f"NLP server response didn't complete for {orig_note_ref}: "
                     f"{choice.finish_reason}"
                 )
-                self.add_error(orig_docref)
+                self.add_error(orig_note)
                 continue
 
             parsed = choice.message.parsed.model_dump(mode="json")
-            self.post_process(parsed, orig_clinical_note, orig_docref)
+            self.post_process(parsed, orig_note_text, orig_note)
 
             yield {
-                "note_ref": f"DocumentReference/{docref_id}",
+                "note_ref": note_ref,
                 "encounter_ref": f"Encounter/{encounter_id}",
                 "subject_ref": f"Patient/{subject_id}",
                 # Since this date is stored as a string, use UTC time for easy comparisons

@@ -202,11 +195,11 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
             }
 
     @classmethod
-    def get_user_prompt(cls, clinical_note: str) -> str:
+    def get_user_prompt(cls, note_text: str) -> str:
         prompt = cls.user_prompt or "%CLINICAL-NOTE%"
-        return prompt.replace("%CLINICAL-NOTE%", clinical_note)
+        return prompt.replace("%CLINICAL-NOTE%", note_text)
 
-    def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None:
+    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
         """Subclasses can fill this out if they like"""
 
     @classmethod

@@ -261,7 +254,7 @@ class BaseOpenAiTaskWithSpans(BaseOpenAiTask):
     It assumes the field is named "spans" in the top level of the pydantic model.
     """
 
-    def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None:
+    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
         new_spans = []
         missed_some = False
 

@@ -278,18 +271,18 @@ def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict)
             span = ESCAPED_WHITESPACE.sub(r"\\s+", span)
 
             found = False
-            for match in re.finditer(span, orig_clinical_note, re.IGNORECASE):
+            for match in re.finditer(span, orig_note_text, re.IGNORECASE):
                 found = True
                 new_spans.append(match.span())
             if not found:
                 missed_some = True
                 logging.warning(
                     "Could not match span received from NLP server for "
-                    f"DocRef {orig_docref['id']}: {orig_span}"
+                    f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}"
                 )
 
         if missed_some:
-            self.add_error(orig_docref)
+            self.add_error(orig_note)
 
         parsed["spans"] = new_spans

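The "note_ref" field emitted by these tasks now carries the source resource type instead of always assuming DocumentReference. A small sketch of the new reference format, with made-up example IDs:

    for resource in (
        {"resourceType": "DocumentReference", "id": "doc1"},
        {"resourceType": "DiagnosticReport", "id": "report1"},
    ):
        note_ref = f"{resource['resourceType']}/{resource['id']}"
        print(note_ref)  # DocumentReference/doc1, then DiagnosticReport/report1
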
cumulus_etl/export/cli.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@ async def export_main(args: argparse.Namespace) -> None:
     store.set_user_fs_options(vars(args))
 
     selected_tasks = task_factory.get_selected_tasks(args.task)
-    required_resources = {t.resource for t in selected_tasks}
+    # Combine all task resource sets into one big set of required resources
+    required_resources = set().union(*(t.get_resource_types() for t in selected_tasks))
     using_default_tasks = not args.task
 
     # Fold in manually specified --type args (very similar to --task, but more familiar to folks

cumulus_etl/nlp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity
 from .openai import Gpt4Model, Gpt4oModel, Gpt5Model, Gpt35Model, GptOss120bModel, Llama4ScoutModel
-from .utils import cache_wrapper, get_docref_info, is_docref_valid
+from .utils import cache_wrapper, get_note_info, is_note_valid
 from .watcher import (
     check_ctakes,
     check_negation_cnlpt,
