Commit 7a26b11

feat: have NLP tasks read in DxReports as well as DocRefs
This commit adds support for dual-resource tasks and then adds DiagnosticReport to the NLP base task class. This required some vocabulary alignment, as we used "docref" a lot in places that can now take a docref or a dxreport:

- "note" or "note resource": a DocRef or DxReport resource (dict)
- "note text" or "text": the clinical text inside the note
1 parent 3cd9947 commit 7a26b11
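
For orientation, here is a minimal, hand-written sketch of the two resource shapes that now both count as a "note resource". This is illustrative only: the field layout follows FHIR R4, but these dicts are invented samples, not fixtures from this repo.

# Illustrative only: minimal "note resource" dicts. In FHIR R4, a DocRef
# carries its text in content[].attachment, a DxReport in presentedForm[].
docref = {
    "resourceType": "DocumentReference",
    "id": "doc1",
    "subject": {"reference": "Patient/p1"},
    "content": [{"attachment": {"contentType": "text/plain", "data": "Tm8gY291Z2gu"}}],
}
dxreport = {
    "resourceType": "DiagnosticReport",
    "id": "dx1",
    "subject": {"reference": "Patient/p1"},
    "presentedForm": [{"contentType": "text/plain", "data": "Tm8gY291Z2gu"}],
}
# Either dict is a "note"; the decoded attachment ("No cough.") is the "note text".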

11 files changed: +145, -93 lines


cumulus_etl/etl/nlp/cli.py

Lines changed: 3 additions & 4 deletions
@@ -112,10 +112,9 @@ def get_cohort_filter(args: argparse.Namespace) -> Callable[[deid.Codebook, dict

     def res_filter(codebook: deid.Codebook, resource: dict) -> bool:
         match resource["resourceType"]:
-            # TODO: uncomment once we support DxReport NLP (coming soon)
-            # case "DiagnosticReport":
-            #     id_pool = dxreport_ids
-            #     patient_ref = resource.get("subject", {}).get("reference")
+            case "DiagnosticReport":
+                id_pool = dxreport_ids
+                patient_ref = resource.get("subject", {}).get("reference")
             case "DocumentReference":
                 id_pool = docref_ids
                 patient_ref = resource.get("subject", {}).get("reference")
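
For readers newer to structural pattern matching, here is a standalone sketch of the dispatch above. The id pools and the final membership check are hypothetical stand-ins for what get_cohort_filter() actually builds; the real filter also resolves patient references through the de-id codebook.

# Hypothetical standalone version of the dispatch (requires Python 3.10+).
dxreport_ids = {"dx1"}  # stand-in cohort pools
docref_ids = {"doc1"}

def res_filter(resource: dict) -> bool:
    match resource["resourceType"]:
        case "DiagnosticReport":
            id_pool = dxreport_ids
        case "DocumentReference":
            id_pool = docref_ids
        case _:
            return False
    return resource["id"] in id_pool

print(res_filter({"resourceType": "DiagnosticReport", "id": "dx1"}))  # True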

cumulus_etl/etl/pipeline.py

Lines changed: 15 additions & 10 deletions
@@ -120,6 +120,7 @@ async def check_available_resources(
     requested_resources: set[str],
     args: argparse.Namespace,
     is_default_tasks: bool,
+    nlp: bool,
 ) -> set[str]:
     # Here we try to reconcile which resources the user requested and which resources are actually
     # available in the input root.
@@ -138,25 +139,28 @@ async def check_available_resources(
     if detected is None:
         return requested_resources  # likely we haven't run bulk export yet

-    if missing_resources := requested_resources - detected:
+    missing_resources = requested_resources - detected
+    available_resources = requested_resources & detected
+
+    if nlp and available_resources:
+        # As long as there is any resource for NLP to read from, we'll take it
+        return available_resources
+
+    if missing_resources:
         for resource in sorted(missing_resources):
             # Log the same message we would print if in common.py if we ran tasks anyway
             logging.warning("No %s files found in %s", resource, loader.root.path)

         if is_default_tasks:
-            requested_resources -= missing_resources  # scope down to detected resources
-            if not requested_resources:
-                errors.fatal(
-                    "No supported resources found.",
-                    errors.MISSING_REQUESTED_RESOURCES,
-                )
+            if not available_resources:
+                errors.fatal("No supported resources found.", errors.MISSING_REQUESTED_RESOURCES)
         else:
             msg = "Required resources not found.\n"
             if has_allow_missing:
                 msg += "Add --allow-missing-resources to run related tasks anyway with no input."
             errors.fatal(msg, errors.MISSING_REQUESTED_RESOURCES)

-    return requested_resources
+    return available_resources


 async def run_pipeline(
@@ -192,7 +196,7 @@ async def run_pipeline(
         await task.init_check()

     # Grab a list of all required resource types for the tasks we are running
-    required_resources = set(t.resource for t in selected_tasks)
+    required_resources = set().union(*(t.get_resource_types() for t in selected_tasks))

     # Create a client to talk to a FHIR server.
     # This is useful even if we aren't doing a bulk export, because some resources like
@@ -214,9 +218,10 @@ async def run_pipeline(
         args=args,
         is_default_tasks=is_default_tasks,
         requested_resources=required_resources,
+        nlp=nlp,
     )
     # Drop any tasks that we didn't find resources for
-    selected_tasks = [t for t in selected_tasks if t.resource in required_resources]
+    selected_tasks = [t for t in selected_tasks if t.get_resource_types() & required_resources]

     # Load resources from a remote location (like s3), convert from i2b2, or do a bulk export
     loader_results = await config_loader.load_resources(required_resources)
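
A toy walkthrough of the new set arithmetic (the values here are invented for illustration):

# Toy values, not from a real run:
requested = {"DiagnosticReport", "DocumentReference"}
detected = {"DocumentReference", "Patient"}

missing = requested - detected    # {'DiagnosticReport'}
available = requested & detected  # {'DocumentReference'}
# With nlp=True, a non-empty `available` is returned immediately, so an NLP
# task can still run on DocRefs alone when DxReports were never exported.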

cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py

Lines changed: 6 additions & 3 deletions
@@ -32,7 +32,7 @@ async def covid_symptoms_extract(
     :return: list of NLP results encoded as FHIR observations
     """
     try:
-        docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
+        note_ref, encounter_id, subject_id = nlp.get_note_info(docref)
     except KeyError as exc:
         logging.warning(exc)
         return None
@@ -62,7 +62,7 @@ async def covid_symptoms_extract(
         )
     except Exception as exc:
         logging.warning(
-            "Could not extract symptoms for docref %s (%s): %s", docref_id, type(exc).__name__, exc
+            "Could not extract symptoms for %s (%s): %s", note_ref, type(exc).__name__, exc
         )
         return None

@@ -95,10 +95,13 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText):
         )
     except Exception as exc:
         logging.warning(
-            "Could not check polarity for docref %s (%s): %s", docref_id, type(exc).__name__, exc
+            "Could not check polarity for %s (%s): %s", note_ref, type(exc).__name__, exc
         )
         return None

+    # We only look at docrefs - get just the ID for use in the symptom fields
+    docref_id = note_ref.removeprefix("DocumentReference/")
+
     # Helper to make a single row (match_value is None if there were no found symptoms at all)
     def _make_covid_symptom_row(row_id: str, match: dict | None) -> dict:
         return {
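
Note that str.removeprefix() (Python 3.9+) only strips the prefix when it is present, so the derived docref_id stays well-formed even if a non-DocRef reference ever slipped through:

print("DocumentReference/abc".removeprefix("DocumentReference/"))  # abc
print("DiagnosticReport/xyz".removeprefix("DocumentReference/"))   # unchanged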

cumulus_etl/etl/studies/covid_symptom/covid_tasks.py

Lines changed: 4 additions & 1 deletion
@@ -65,8 +65,11 @@ def is_ed_coding(coding):
     return coding.get("code") in ED_CODES.get(coding.get("system"), {})


-def is_ed_docref(docref):
+def is_ed_docref(docref) -> bool:
     """Returns true if this is a coding for an emergency department note"""
+    if docref["resourceType"] != "DocumentReference":
+        return False
+
     # We check both type and category for safety -- we aren't sure yet how EHRs are using these fields.
     codings = list(
         itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])])
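
With this guard, a DxReport flowing through the now dual-resource NLP plumbing is simply excluded from this DocRef-only covid cohort. A toy check (the import path is assumed to mirror the file layout):

# Toy resource; is_ed_docref() now short-circuits on resourceType.
from cumulus_etl.etl.studies.covid_symptom.covid_tasks import is_ed_docref

print(is_ed_docref({"resourceType": "DiagnosticReport", "id": "dx1"}))  # False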

cumulus_etl/etl/tasks/base.py

Lines changed: 11 additions & 3 deletions
@@ -89,7 +89,7 @@ class EtlTask:
     # Properties:
     name: ClassVar[str] = None  # task & table name
     # incoming resource that this task operates on (will be included in bulk exports etc)
-    resource: ClassVar[str] = None
+    resource: ClassVar[str | set[str]] = None
     tags: ClassVar[set[str]] = []
     # whether this task needs bulk MS tool de-id run on its inputs (NLP tasks usually don't)
     needs_bulk_deid: ClassVar[bool] = True
@@ -378,10 +378,11 @@ def read_ndjson(

         If `resources` is provided, those resources will be read (in the provided order).
         That is, ["Condition", "Encounter"] will first read all Conditions, then all Encounters.
-        If `resources` is not provided, the task's main resource (self.resource) will be used.
+        If `resources` is not provided, the task's main resources (via self.get_resource_types())
+        will be used.
         """
         input_root = store.Root(self.task_config.dir_input)
-        resources = resources or [self.resource]
+        resources = resources or sorted(self.get_resource_types())

         if progress:
             # Make new task to track processing of rows
@@ -472,3 +473,10 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche
         if resource_type:
             return cfs.pyarrow_schema_from_rows(resource_type, rows)
         return None
+
+    @classmethod
+    def get_resource_types(cls) -> set[str]:
+        """Abstracts whether the class's resource field is a str or a set of strings."""
+        if isinstance(cls.resource, str):
+            return {cls.resource}
+        return set(cls.resource)
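
A quick sketch of how the helper normalizes both declaration styles. The two task classes here are hypothetical, and the import path is assumed from the file layout; only the `resource` ClassVar matters for this demo.

from cumulus_etl.etl.tasks.base import EtlTask  # assumed path, mirrors this file

class SingleTask(EtlTask):          # hypothetical single-resource task
    resource = "DocumentReference"

class DualTask(EtlTask):            # hypothetical dual-resource task
    resource = {"DiagnosticReport", "DocumentReference"}

print(sorted(SingleTask.get_resource_types()))  # ['DocumentReference']
print(sorted(DualTask.get_resource_types()))    # ['DiagnosticReport', 'DocumentReference']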

cumulus_etl/etl/tasks/nlp_task.py

Lines changed: 37 additions & 43 deletions
@@ -27,7 +27,7 @@
 class BaseNlpTask(tasks.EtlTask):
     """Base class for any clinical-notes-based NLP task."""

-    resource: ClassVar = "DocumentReference"
+    resource: ClassVar = {"DiagnosticReport", "DocumentReference"}
     needs_bulk_deid: ClassVar = False

     # You may want to override these in your subclass
@@ -80,44 +80,45 @@ async def read_notes(
         """
         Iterate through clinical notes.

-        :returns: a tuple of original-docref, scrubbed-docref, and clinical note
+        :returns: a tuple of original-resource, scrubbed-resource, and note text
         """
         warned_connection_error = False

-        note_filter = self.task_config.resource_filter or nlp.is_docref_valid
+        note_filter = self.task_config.resource_filter or nlp.is_note_valid

-        for docref in self.read_ndjson(progress=progress):
-            orig_docref = copy.deepcopy(docref)
+        for note in self.read_ndjson(progress=progress):
+            orig_note = copy.deepcopy(note)
             can_process = (
-                note_filter(self.scrubber.codebook, docref)
-                and (doc_check is None or doc_check(docref))
-                and self.scrubber.scrub_resource(docref, scrub_attachments=False, keep_stats=False)
+                note_filter(self.scrubber.codebook, note)
+                and (doc_check is None or doc_check(note))
+                and self.scrubber.scrub_resource(note, scrub_attachments=False, keep_stats=False)
             )
             if not can_process:
                 continue

             try:
-                clinical_note = await fhir.get_clinical_note(self.task_config.client, docref)
+                note_text = await fhir.get_clinical_note(self.task_config.client, note)
             except cfs.BadAuthArguments as exc:
                 if not warned_connection_error:
                     # Only warn user about a misconfiguration once per task.
                     # It's not fatal because it might be intentional (partially inlined DocRefs
                     # and the other DocRefs are known failures - BCH hits this with Cerner data).
                     print(exc, file=sys.stderr)
                     warned_connection_error = True
-                self.add_error(orig_docref)
+                self.add_error(orig_note)
                 continue
             except Exception as exc:
-                logging.warning("Error getting text for docref %s: %s", docref["id"], exc)
-                self.add_error(orig_docref)
+                orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"
+                logging.warning("Error getting text for note %s: %s", orig_note_ref, exc)
+                self.add_error(orig_note)
                 continue

-            yield orig_docref, docref, clinical_note
+            yield orig_note, note, note_text

     @staticmethod
-    def remove_trailing_whitespace(note: str) -> str:
+    def remove_trailing_whitespace(note_text: str) -> str:
         """Sometimes NLP can be mildly confused by trailing whitespace, so this removes it"""
-        return TRAILING_WHITESPACE.sub("", note)
+        return TRAILING_WHITESPACE.sub("", note_text)


 class BaseOpenAiTask(BaseNlpTask):
@@ -139,59 +140,52 @@ async def init_check(cls) -> None:
     async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
         client = self.client_class()

-        async for orig_docref, docref, orig_clinical_note in self.read_notes(progress=progress):
+        async for orig_note, note, orig_note_text in self.read_notes(progress=progress):
             try:
-                docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
+                note_ref, encounter_id, subject_id = nlp.get_note_info(note)
             except KeyError as exc:
                 logging.warning(exc)
-                self.add_error(orig_docref)
+                self.add_error(orig_note)
                 continue

-            clinical_note = self.remove_trailing_whitespace(orig_clinical_note)
+            note_text = self.remove_trailing_whitespace(orig_note_text)
+            orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"

             try:
                 completion_class = chat.ParsedChatCompletion[self.response_format]
                 response = await nlp.cache_wrapper(
                     self.task_config.dir_phi,
                     f"{self.name}_v{self.task_version}",
-                    clinical_note,
+                    note_text,
                     lambda x: completion_class.model_validate_json(x),  # from file
                     lambda x: x.model_dump_json(  # to file
                         indent=None, round_trip=True, exclude_unset=True, by_alias=True
                     ),
                     client.prompt,
                     self.system_prompt,
-                    self.get_user_prompt(clinical_note),
+                    self.get_user_prompt(note_text),
                     self.response_format,
                 )
-            except openai.APIError as exc:
-                logging.warning(
-                    f"Could not connect to NLP server for DocRef {orig_docref['id']}: {exc}"
-                )
-                self.add_error(orig_docref)
-                continue
-            except pydantic.ValidationError as exc:
-                logging.warning(
-                    f"Could not process answer from NLP server for DocRef {orig_docref['id']}: {exc}"
-                )
-                self.add_error(orig_docref)
+            except Exception as exc:
+                logging.warning(f"NLP failed for {orig_note_ref}: {exc}")
+                self.add_error(orig_note)
                 continue

             choice = response.choices[0]

             if choice.finish_reason != "stop" or not choice.message.parsed:
                 logging.warning(
-                    f"NLP server response didn't complete for DocRef {orig_docref['id']}: "
+                    f"NLP server response didn't complete for {orig_note_ref}: "
                     f"{choice.finish_reason}"
                 )
-                self.add_error(orig_docref)
+                self.add_error(orig_note)
                 continue

             parsed = choice.message.parsed.model_dump(mode="json")
-            self.post_process(parsed, orig_clinical_note, orig_docref)
+            self.post_process(parsed, orig_note_text, orig_note)

             yield {
-                "note_ref": f"DocumentReference/{docref_id}",
+                "note_ref": note_ref,
                 "encounter_ref": f"Encounter/{encounter_id}",
                 "subject_ref": f"Patient/{subject_id}",
                 # Since this date is stored as a string, use UTC time for easy comparisons
@@ -202,11 +196,11 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
             }

     @classmethod
-    def get_user_prompt(cls, clinical_note: str) -> str:
+    def get_user_prompt(cls, note_text: str) -> str:
         prompt = cls.user_prompt or "%CLINICAL-NOTE%"
-        return prompt.replace("%CLINICAL-NOTE%", clinical_note)
+        return prompt.replace("%CLINICAL-NOTE%", note_text)

-    def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None:
+    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
         """Subclasses can fill this out if they like"""

     @classmethod
@@ -261,7 +255,7 @@ class BaseOpenAiTaskWithSpans(BaseOpenAiTask):
     It assumes the field is named "spans" in the top level of the pydantic model.
     """

-    def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None:
+    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
         new_spans = []
         missed_some = False

@@ -278,18 +272,18 @@ def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict)
             span = ESCAPED_WHITESPACE.sub(r"\\s+", span)

             found = False
-            for match in re.finditer(span, orig_clinical_note, re.IGNORECASE):
+            for match in re.finditer(span, orig_note_text, re.IGNORECASE):
                 found = True
                 new_spans.append(match.span())
             if not found:
                 missed_some = True
                 logging.warning(
                     "Could not match span received from NLP server for "
-                    f"DocRef {orig_docref['id']}: {orig_span}"
+                    f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}"
                 )

         if missed_some:
-            self.add_error(orig_docref)
+            self.add_error(orig_note)

         parsed["spans"] = new_spans
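
The renamed nlp.get_note_info() itself isn't part of this diff (it lives in cumulus_etl/nlp/utils.py). From the call sites above we can infer its shape: it returns a full "Type/id" note_ref plus encounter and subject IDs, and raises KeyError when a field is missing. A speculative sketch, not the real implementation:

def get_note_info(note: dict) -> tuple[str, str, str]:
    # Speculative sketch based only on call sites; the real helper may differ.
    note_ref = f"{note['resourceType']}/{note['id']}"  # KeyError if id is missing
    subject_id = note["subject"]["reference"].removeprefix("Patient/")
    # Guesswork: FHIR R4 DocRefs hold encounters in context.encounter (a list),
    # while DxReports have a single `encounter` field.
    if note["resourceType"] == "DocumentReference":
        encounter_ref = note["context"]["encounter"][0]["reference"]
    else:
        encounter_ref = note["encounter"]["reference"]
    encounter_id = encounter_ref.removeprefix("Encounter/")
    return note_ref, encounter_id, subject_id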

cumulus_etl/export/cli.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 """Do a standalone bulk export from an EHR"""

 import argparse
+import itertools
 import sys

 from cumulus_etl import cli_utils, common, errors, fhir, loaders, store
@@ -32,7 +33,7 @@ async def export_main(args: argparse.Namespace) -> None:
     store.set_user_fs_options(vars(args))

     selected_tasks = task_factory.get_selected_tasks(args.task)
-    required_resources = {t.resource for t in selected_tasks}
+    required_resources = set().union(*(t.get_resource_types() for t in selected_tasks))
     using_default_tasks = not args.task

     # Fold in manually specified --type args (very similar to --task, but more familiar to folks
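
A small note on the set().union(*(...)) idiom used here and in pipeline.py: starting from an empty set makes the expression safe even when no tasks are selected, whereas calling set.union with zero arguments would raise.

print(set().union(*({"A"}, {"A", "B"})))  # {'A', 'B'} (set order may vary)
print(set().union(*()))                   # set() -- an empty task list is fine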

cumulus_etl/nlp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity
 from .openai import Gpt4Model, Gpt4oModel, Gpt5Model, Gpt35Model, GptOss120bModel, Llama4ScoutModel
-from .utils import cache_wrapper, get_docref_info, is_docref_valid
+from .utils import cache_wrapper, get_note_info, is_note_valid
 from .watcher import (
     check_ctakes,
     check_negation_cnlpt,
