3 changes: 1 addition & 2 deletions .github/workflows/ci.yaml
@@ -102,11 +102,10 @@ jobs:
docker compose run --rm \
--volume $DATADIR:/in \
cumulus-etl \
nlp \
Contributor Author
The fact that this change was needed here but not before tells me we are regression-testing the public release instead of the current branch. We have code to build the branch's docker image... not sure why it's not being used here. I can dig into it later, but just an FYI - it's on my to-do list.

/in/input \
/in/run-output \
/in/phi \
--export-group nlp-test \
--export-timestamp 2024-08-29 \
--output-format=ndjson \
--task covid_symptom__nlp_results

7 changes: 3 additions & 4 deletions cumulus_etl/etl/nlp/cli.py
@@ -112,10 +112,9 @@ def get_cohort_filter(args: argparse.Namespace) -> Callable[[deid.Codebook, dict

def res_filter(codebook: deid.Codebook, resource: dict) -> bool:
match resource["resourceType"]:
# TODO: uncomment once we support DxReport NLP (coming soon)
# case "DiagnosticReport":
# id_pool = dxreport_ids
# patient_ref = resource.get("subject", {}).get("reference")
case "DiagnosticReport":
id_pool = dxreport_ids
patient_ref = resource.get("subject", {}).get("reference")
case "DocumentReference":
id_pool = docref_ids
patient_ref = resource.get("subject", {}).get("reference")
27 changes: 16 additions & 11 deletions cumulus_etl/etl/pipeline.py
@@ -120,6 +120,7 @@ async def check_available_resources(
requested_resources: set[str],
args: argparse.Namespace,
is_default_tasks: bool,
nlp: bool,
) -> set[str]:
# Here we try to reconcile which resources the user requested and which resources are actually
# available in the input root.
@@ -138,25 +139,28 @@
if detected is None:
return requested_resources # likely we haven't run bulk export yet

if missing_resources := requested_resources - detected:
missing_resources = requested_resources - detected
available_resources = requested_resources & detected

if nlp and available_resources:
# As long as there is any resource for NLP to read from, we'll take it
return available_resources

if missing_resources:
for resource in sorted(missing_resources):
# Log the same message we would print in common.py if we ran tasks anyway
logging.warning("No %s files found in %s", resource, loader.root.path)

if is_default_tasks:
requested_resources -= missing_resources # scope down to detected resources
if not requested_resources:
errors.fatal(
"No supported resources found.",
errors.MISSING_REQUESTED_RESOURCES,
)
if not available_resources:
errors.fatal("No supported resources found.", errors.MISSING_REQUESTED_RESOURCES)
else:
msg = "Required resources not found.\n"
if has_allow_missing:
msg += "Add --allow-missing-resources to run related tasks anyway with no input."
errors.fatal(msg, errors.MISSING_REQUESTED_RESOURCES)

return requested_resources
return available_resources


async def run_pipeline(
@@ -191,8 +195,8 @@ async def run_pipeline(
for task in selected_tasks:
await task.init_check()

# Grab a list of all required resource types for the tasks we are running
required_resources = set(t.resource for t in selected_tasks)
# Combine all task resource sets into one big set of required resources
required_resources = set().union(*(t.get_resource_types() for t in selected_tasks))

# Create a client to talk to a FHIR server.
# This is useful even if we aren't doing a bulk export, because some resources like
@@ -214,9 +218,10 @@
args=args,
is_default_tasks=is_default_tasks,
requested_resources=required_resources,
nlp=nlp,
)
# Drop any tasks that we didn't find resources for
selected_tasks = [t for t in selected_tasks if t.resource in required_resources]
selected_tasks = [t for t in selected_tasks if t.get_resource_types() & required_resources]

# Load resources from a remote location (like s3), convert from i2b2, or do a bulk export
loader_results = await config_loader.load_resources(required_resources)
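
Taken together, the reconciliation above boils down to roughly the following (a condensed sketch for review purposes only; `reconcile` and the bare `sys.exit` calls are stand-ins, not the real helpers):

    import sys

    def reconcile(requested: set[str], detected: set[str], *, nlp: bool, is_default_tasks: bool) -> set[str]:
        available = requested & detected
        missing = requested - detected
        if nlp and available:
            return available  # NLP runs on whatever note resources are present
        if missing:
            if is_default_tasks:
                if not available:
                    sys.exit("No supported resources found.")
            else:
                sys.exit("Required resources not found.")
        return available  # default tasks quietly scope down to what was detected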
9 changes: 6 additions & 3 deletions cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py
@@ -32,7 +32,7 @@ async def covid_symptoms_extract(
:return: list of NLP results encoded as FHIR observations
"""
try:
docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
note_ref, encounter_id, subject_id = nlp.get_note_info(docref)
except KeyError as exc:
logging.warning(exc)
return None
@@ -62,7 +62,7 @@
)
except Exception as exc:
logging.warning(
"Could not extract symptoms for docref %s (%s): %s", docref_id, type(exc).__name__, exc
"Could not extract symptoms for %s (%s): %s", note_ref, type(exc).__name__, exc
)
return None

Expand Down Expand Up @@ -95,10 +95,13 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText):
)
except Exception as exc:
logging.warning(
"Could not check polarity for docref %s (%s): %s", docref_id, type(exc).__name__, exc
"Could not check polarity for %s (%s): %s", note_ref, type(exc).__name__, exc
)
return None

# We only look at docrefs - get just the ID for use in the symptom fields
docref_id = note_ref.removeprefix("DocumentReference/")

# Helper to make a single row (match_value is None if there were no found symptoms at all)
def _make_covid_symptom_row(row_id: str, match: dict | None) -> dict:
return {
5 changes: 4 additions & 1 deletion cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
@@ -65,8 +65,11 @@ def is_ed_coding(coding):
return coding.get("code") in ED_CODES.get(coding.get("system"), {})


def is_ed_docref(docref):
def is_ed_docref(docref) -> bool:
"""Returns true if this is a coding for an emergency department note"""
if docref["resourceType"] != "DocumentReference":
return False

# We check both type and category for safety -- we aren't sure yet how EHRs are using these fields.
codings = list(
itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])])
14 changes: 11 additions & 3 deletions cumulus_etl/etl/tasks/base.py
@@ -89,7 +89,7 @@ class EtlTask:
# Properties:
name: ClassVar[str] = None # task & table name
# incoming resource that this task operates on (will be included in bulk exports etc)
resource: ClassVar[str] = None
resource: ClassVar[str | set[str]] = None
tags: ClassVar[set[str]] = []
# whether this task needs bulk MS tool de-id run on its inputs (NLP tasks usually don't)
needs_bulk_deid: ClassVar[bool] = True
@@ -378,10 +378,11 @@ def read_ndjson(

If `resources` is provided, those resources will be read (in the provided order).
That is, ["Condition", "Encounter"] will first read all Conditions, then all Encounters.
If `resources` is not provided, the task's main resource (self.resource) will be used.
If `resources` is not provided, the task's main resources (via self.get_resource_types())
will be used.
"""
input_root = store.Root(self.task_config.dir_input)
resources = resources or [self.resource]
resources = resources or sorted(self.get_resource_types())

if progress:
# Make new task to track processing of rows
@@ -472,3 +473,10 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche
if resource_type:
return cfs.pyarrow_schema_from_rows(resource_type, rows)
return None

@classmethod
def get_resource_types(cls) -> set[str]:
"""Abstracts whether the class's resource field is a str or a set of strings."""
if isinstance(cls.resource, str):
return {cls.resource}
return set(cls.resource)
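
A quick illustration of the str-or-set handling above (the two task classes here are hypothetical, purely for the example):

    class ConditionTask(EtlTask):
        resource = "Condition"  # classic single-resource task

    class NoteNlpTask(EtlTask):
        resource = {"DiagnosticReport", "DocumentReference"}  # new multi-resource style

    assert ConditionTask.get_resource_types() == {"Condition"}
    assert NoteNlpTask.get_resource_types() == {"DiagnosticReport", "DocumentReference"}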
81 changes: 37 additions & 44 deletions cumulus_etl/etl/tasks/nlp_task.py
@@ -11,7 +11,6 @@
from typing import ClassVar

import cumulus_fhir_support as cfs
import openai
import pyarrow
import pydantic
import rich.progress
@@ -27,7 +26,7 @@
class BaseNlpTask(tasks.EtlTask):
"""Base class for any clinical-notes-based NLP task."""

resource: ClassVar = "DocumentReference"
resource: ClassVar = {"DiagnosticReport", "DocumentReference"}
needs_bulk_deid: ClassVar = False

# You may want to override these in your subclass
@@ -80,44 +79,45 @@ async def read_notes(
"""
Iterate through clinical notes.

:returns: a tuple of original-docref, scrubbed-docref, and clinical note
:returns: a tuple of original-resource, scrubbed-resource, and note text
"""
warned_connection_error = False

note_filter = self.task_config.resource_filter or nlp.is_docref_valid
note_filter = self.task_config.resource_filter or nlp.is_note_valid

for docref in self.read_ndjson(progress=progress):
orig_docref = copy.deepcopy(docref)
for note in self.read_ndjson(progress=progress):
orig_note = copy.deepcopy(note)
can_process = (
note_filter(self.scrubber.codebook, docref)
and (doc_check is None or doc_check(docref))
and self.scrubber.scrub_resource(docref, scrub_attachments=False, keep_stats=False)
note_filter(self.scrubber.codebook, note)
and (doc_check is None or doc_check(note))
and self.scrubber.scrub_resource(note, scrub_attachments=False, keep_stats=False)
)
if not can_process:
continue

try:
clinical_note = await fhir.get_clinical_note(self.task_config.client, docref)
note_text = await fhir.get_clinical_note(self.task_config.client, note)
except cfs.BadAuthArguments as exc:
if not warned_connection_error:
# Only warn user about a misconfiguration once per task.
# It's not fatal because it might be intentional (partially inlined DocRefs
# and the other DocRefs are known failures - BCH hits this with Cerner data).
print(exc, file=sys.stderr)
warned_connection_error = True
self.add_error(orig_docref)
self.add_error(orig_note)
continue
except Exception as exc:
logging.warning("Error getting text for docref %s: %s", docref["id"], exc)
self.add_error(orig_docref)
orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"
logging.warning("Error getting text for note %s: %s", orig_note_ref, exc)
self.add_error(orig_note)
continue

yield orig_docref, docref, clinical_note
yield orig_note, note, note_text

@staticmethod
def remove_trailing_whitespace(note: str) -> str:
def remove_trailing_whitespace(note_text: str) -> str:
"""Sometimes NLP can be mildly confused by trailing whitespace, so this removes it"""
return TRAILING_WHITESPACE.sub("", note)
return TRAILING_WHITESPACE.sub("", note_text)


class BaseOpenAiTask(BaseNlpTask):
@@ -139,59 +139,52 @@ async def init_check(cls) -> None:
async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
client = self.client_class()

async for orig_docref, docref, orig_clinical_note in self.read_notes(progress=progress):
async for orig_note, note, orig_note_text in self.read_notes(progress=progress):
try:
docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
note_ref, encounter_id, subject_id = nlp.get_note_info(note)
except KeyError as exc:
logging.warning(exc)
self.add_error(orig_docref)
self.add_error(orig_note)
continue

clinical_note = self.remove_trailing_whitespace(orig_clinical_note)
note_text = self.remove_trailing_whitespace(orig_note_text)
orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"

try:
completion_class = chat.ParsedChatCompletion[self.response_format]
response = await nlp.cache_wrapper(
self.task_config.dir_phi,
f"{self.name}_v{self.task_version}",
clinical_note,
note_text,
lambda x: completion_class.model_validate_json(x), # from file
lambda x: x.model_dump_json( # to file
indent=None, round_trip=True, exclude_unset=True, by_alias=True
),
client.prompt,
self.system_prompt,
self.get_user_prompt(clinical_note),
self.get_user_prompt(note_text),
self.response_format,
)
except openai.APIError as exc:
logging.warning(
f"Could not connect to NLP server for DocRef {orig_docref['id']}: {exc}"
)
self.add_error(orig_docref)
continue
except pydantic.ValidationError as exc:
logging.warning(
f"Could not process answer from NLP server for DocRef {orig_docref['id']}: {exc}"
)
self.add_error(orig_docref)
except Exception as exc:
Contributor
I'm surprised ruff didn't flag this - do you have the generic exception rule turned off? I think it's fine, I was just expecting a manual comment to disable it.

Contributor Author
I don't think our ruff config flags generic exceptions - I just confirmed it will flag bare excepts, though (except:). I like that balance - I personally feel like generic exceptions have an undeserved reputation as a problem.

Contributor
Ah, OK - yeah, I guess I buy that as "I am explicitly stating my intent here".

Contributor Author
Bare excepts also have the problem of catching BaseExceptions like SystemExit and KeyboardInterrupt, which usually you don't want to intercept.

But often linting tools don't like the generic version either for some programming purity reasons like "you should know what you're catching". But in the real world, you so often just want to say "hey I don't care what happened, I want to handle it and log it".
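
A tiny sketch of the distinction being discussed (illustrative only, not code from this PR; `handle` is a made-up stand-in for the per-note NLP call): `except Exception` still lets KeyboardInterrupt and SystemExit propagate, which is exactly what a bare `except:` would swallow.

    import logging

    def run_notes(notes, handle):
        for note in notes:
            try:
                handle(note)
            except Exception as exc:  # broad on purpose: log the failure and move on
                logging.warning("NLP failed for %s: %s", note.get("id"), exc)
            # A bare `except:` here would also trap KeyboardInterrupt and SystemExit,
            # so Ctrl-C couldn't break out of the loop - that's what ruff's E722 flags.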

logging.warning(f"NLP failed for {orig_note_ref}: {exc}")
self.add_error(orig_note)
continue

choice = response.choices[0]

if choice.finish_reason != "stop" or not choice.message.parsed:
logging.warning(
f"NLP server response didn't complete for DocRef {orig_docref['id']}: "
f"NLP server response didn't complete for {orig_note_ref}: "
f"{choice.finish_reason}"
)
self.add_error(orig_docref)
self.add_error(orig_note)
continue

parsed = choice.message.parsed.model_dump(mode="json")
self.post_process(parsed, orig_clinical_note, orig_docref)
self.post_process(parsed, orig_note_text, orig_note)

yield {
"note_ref": f"DocumentReference/{docref_id}",
"note_ref": note_ref,
"encounter_ref": f"Encounter/{encounter_id}",
"subject_ref": f"Patient/{subject_id}",
# Since this date is stored as a string, use UTC time for easy comparisons
@@ -202,11 +195,11 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
}

@classmethod
def get_user_prompt(cls, clinical_note: str) -> str:
def get_user_prompt(cls, note_text: str) -> str:
prompt = cls.user_prompt or "%CLINICAL-NOTE%"
return prompt.replace("%CLINICAL-NOTE%", clinical_note)
return prompt.replace("%CLINICAL-NOTE%", note_text)

def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None:
def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
"""Subclasses can fill this out if they like"""

@classmethod
@@ -261,7 +254,7 @@ class BaseOpenAiTaskWithSpans(BaseOpenAiTask):
It assumes the field is named "spans" in the top level of the pydantic model.
"""

def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict) -> None:
def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
new_spans = []
missed_some = False

@@ -278,18 +271,18 @@ def post_process(self, parsed: dict, orig_clinical_note: str, orig_docref: dict)
span = ESCAPED_WHITESPACE.sub(r"\\s+", span)

found = False
for match in re.finditer(span, orig_clinical_note, re.IGNORECASE):
for match in re.finditer(span, orig_note_text, re.IGNORECASE):
found = True
new_spans.append(match.span())
if not found:
missed_some = True
logging.warning(
"Could not match span received from NLP server for "
f"DocRef {orig_docref['id']}: {orig_span}"
f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}"
)

if missed_some:
self.add_error(orig_docref)
self.add_error(orig_note)

parsed["spans"] = new_spans

3 changes: 2 additions & 1 deletion cumulus_etl/export/cli.py
@@ -32,7 +32,8 @@ async def export_main(args: argparse.Namespace) -> None:
store.set_user_fs_options(vars(args))

selected_tasks = task_factory.get_selected_tasks(args.task)
required_resources = {t.resource for t in selected_tasks}
# Combine all task resource sets into one big set of required resources
required_resources = set().union(*(t.get_resource_types() for t in selected_tasks))
using_default_tasks = not args.task

# Fold in manually specified --type args (very similar to --task, but more familiar to folks
2 changes: 1 addition & 1 deletion cumulus_etl/nlp/__init__.py
@@ -2,7 +2,7 @@

from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity
from .openai import Gpt4Model, Gpt4oModel, Gpt5Model, Gpt35Model, GptOss120bModel, Llama4ScoutModel
from .utils import cache_wrapper, get_docref_info, is_docref_valid
from .utils import cache_wrapper, get_note_info, is_note_valid
from .watcher import (
check_ctakes,
check_negation_cnlpt,