Merged
3 changes: 3 additions & 0 deletions compose.yaml
@@ -33,8 +33,11 @@ services:
environment:
      # Environment variables to pull in from the host
- AWS_ACCESS_KEY_ID
- AWS_ATHENA_WORK_GROUP
- AWS_DEFAULT_PROFILE
- AWS_DEFAULT_REGION
- AWS_PROFILE
- AWS_REGION
Contributor comment on lines +36 to +40:
just noting here that the library uses different names for these same fields - we might want to natter about harmonizing on one (probably this set)

- AWS_SECRET_ACCESS_KEY
- AWS_SESSION_TOKEN
- AZURE_OPENAI_API_KEY
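These variables are passed through from the host so AWS clients inside the container can authenticate. A minimal sketch of how such values are typically consumed, assuming boto3-style environment resolution (this wiring is not part of the diff itself):

```python
# Illustrative only: the AWS SDKs in use here (boto3 / pyathena) typically read the
# standard variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN,
# AWS_PROFILE, AWS_DEFAULT_REGION / AWS_REGION) from the environment on their own.
# An app-specific variable such as AWS_ATHENA_WORK_GROUP needs an explicit lookup.
import os

workgroup = os.environ.get("AWS_ATHENA_WORK_GROUP", "primary")
region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
```
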
42 changes: 25 additions & 17 deletions cumulus_etl/cli_utils.py
@@ -39,7 +39,7 @@ def add_auth(parser: argparse.ArgumentParser, *, use_fhir_url: bool = True):
group.add_argument("--smart-jwks", metavar="PATH", help=argparse.SUPPRESS)


-def add_aws(parser: argparse.ArgumentParser) -> None:
+def add_aws(parser: argparse.ArgumentParser, athena: bool = False) -> None:
group = parser.add_argument_group("AWS")
group.add_argument(
"--s3-region",
@@ -51,6 +51,22 @@ def add_aws(parser: argparse.ArgumentParser) -> None:
metavar="KEY",
help="if using S3 paths (s3://...), this is the KMS key ID to use",
)
if athena:
group.add_argument(
"--athena-region",
metavar="REGION",
help="the region of your Athena workgroup (default is us-east-1)",
)
group.add_argument(
"--athena-workgroup",
metavar="GROUP",
help="the name of your Athena workgroup",
)
group.add_argument(
"--athena-database",
metavar="DB",
help="the name of your Athena database",
)
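
A minimal usage sketch of the new Athena flags, assuming a bare parser (values are illustrative; only the option names come from this change):

```python
import argparse

from cumulus_etl import cli_utils

parser = argparse.ArgumentParser()
cli_utils.add_aws(parser, athena=True)
args = parser.parse_args(
    [
        "--athena-region", "us-east-1",
        "--athena-workgroup", "my-workgroup",
        "--athena-database", "my_db",
    ]
)
print(args.athena_workgroup)  # "my-workgroup"
```
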


def add_bulk_export(parser: argparse.ArgumentParser, *, as_subgroup: bool = True):
@@ -85,15 +101,13 @@ def add_bulk_export(parser: argparse.ArgumentParser, *, as_subgroup: bool = True
return parser


-def add_nlp(parser: argparse.ArgumentParser):
-    group = parser.add_argument_group("NLP")
-    group.add_argument(
+def add_ctakes_override(parser: argparse.ArgumentParser):
+    parser.add_argument(
"--ctakes-overrides",
metavar="DIR",
default="/ctakes-overrides",
help="path to cTAKES overrides dir (default is /ctakes-overrides)",
)
-    return group


def add_output_format(parser: argparse.ArgumentParser) -> None:
@@ -105,20 +119,14 @@ def add_output_format(parser: argparse.ArgumentParser) -> None:
)


-def add_task_selection(parser: argparse.ArgumentParser):
-    task = parser.add_argument_group("task selection")
-    task.add_argument(
+def add_task_selection(parser: argparse.ArgumentParser, *, etl_mode: bool):
+    default = ", default is all supported FHIR resources" if etl_mode else ""
+    required = not etl_mode
+    parser.add_argument(
"--task",
action="append",
help="only consider these tasks (comma separated, "
"default is all supported FHIR resources, "
"use '--task help' to see full list)",
)
task.add_argument(
"--task-filter",
action="append",
choices=["covid_symptom", "irae", "cpu", "gpu"],
help="restrict tasks to only the given sets (comma separated)",
help=f"only run these tasks (comma separated{default}, use '--task help' to see full list)",
required=required,
)
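
A short sketch of what the new etl_mode switch does: ETL mode keeps --task optional, defaulting to all supported FHIR resources, while NLP mode makes argparse require it (which is why TASK_NOT_PROVIDED is retired below):

```python
import argparse

from cumulus_etl import cli_utils

etl_parser = argparse.ArgumentParser()
cli_utils.add_task_selection(etl_parser, etl_mode=True)
etl_args = etl_parser.parse_args([])  # fine: etl_args.task is None, meaning all supported resources

nlp_parser = argparse.ArgumentParser()
cli_utils.add_task_selection(nlp_parser, etl_mode=False)
nlp_parser.parse_args([])  # exits: "the following arguments are required: --task"
```
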


11 changes: 8 additions & 3 deletions cumulus_etl/errors.py
@@ -17,8 +17,8 @@
TASK_UNKNOWN = 17
CNLPT_MISSING = 18
TASK_FAILED = 19
-TASK_FILTERED_OUT = 20
-TASK_SET_EMPTY = 21
+# TASK_FILTERED_OUT = 20 # Obsolete, we don't have --task-filter anymore
+# TASK_SET_EMPTY = 21 # Obsolete, we don't have --task-filter anymore
ARGS_CONFLICT = 22
ARGS_INVALID = 23
# FHIR_URL_MISSING = 24 # Obsolete, it's no longer fatal
@@ -39,8 +39,13 @@
INLINE_TASK_FAILED = 39
INLINE_WITHOUT_FOLDER = 40
WRONG_PHI_FOLDER = 41
-TASK_NOT_PROVIDED = 42
+# TASK_NOT_PROVIDED = 42 # checked now by argparse
TASK_MISMATCH = 43
ATHENA_TABLE_TOO_BIG = 44
ATHENA_TABLE_NAME_INVALID = 45
ATHENA_DATABASE_MISSING = 46
MULTIPLE_COHORT_ARGS = 47
COHORT_NOT_FOUND = 48


class FatalError(Exception):
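A minimal sketch of how the new exit codes are meant to be used with errors.fatal(), mirroring the calls added in cumulus_etl/etl/nlp/cli.py below (the surrounding condition is illustrative):

```python
from cumulus_etl import errors

ids_found = False  # illustrative condition
if not ids_found:
    errors.fatal("No patient or note IDs found in cohort.", errors.COHORT_NOT_FOUND)
```
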
8 changes: 8 additions & 0 deletions cumulus_etl/etl/cli.py
@@ -48,6 +48,14 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--philter", action="store_true", help="run philter on all freeform text fields"
)
parser.add_argument(
"--allow-missing-resources",
action="store_true",
help="run tasks even if their resources are not present",
)
cli_utils.add_task_selection(parser, etl_mode=True)

cli_utils.add_aws(parser)

export = cli_utils.add_bulk_export(parser)
export.add_argument(
5 changes: 4 additions & 1 deletion cumulus_etl/etl/config.py
@@ -2,11 +2,12 @@

import datetime
import os
from collections.abc import Callable
from socket import gethostname

import cumulus_fhir_support as cfs

-from cumulus_etl import common, errors, formats, store
+from cumulus_etl import common, deid, errors, formats, store


class JobConfig:
@@ -39,6 +40,7 @@ def __init__(
export_datetime: datetime.datetime | None = None,
export_url: str | None = None,
deleted_ids: dict[str, set[str]] | None = None,
resource_filter: Callable[[deid.Codebook, dict], bool] | None = None,
):
self._dir_input_orig = dir_input_orig
self.dir_input = dir_input_deid
@@ -59,6 +61,7 @@ def __init__(
self.export_datetime = export_datetime
self.export_url = export_url
self.deleted_ids = deleted_ids or {}
self.resource_filter = resource_filter

# initialize format class
self._output_root = store.Root(self._dir_output, create=True)
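A hedged sketch of supplying the new resource_filter hook; the parameter name and its Callable[[deid.Codebook, dict], bool] shape come from this change, while the filter body and the elided JobConfig arguments are placeholders:

```python
from cumulus_etl import deid


def final_docs_only(codebook: deid.Codebook, resource: dict) -> bool:
    """Keep DocumentReferences only when their status is 'final'; pass everything else through."""
    if resource.get("resourceType") != "DocumentReference":
        return True
    return resource.get("status") == "final"


# job_config = JobConfig(..., resource_filter=final_docs_only)  # other arguments omitted here
```
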
129 changes: 126 additions & 3 deletions cumulus_etl/etl/nlp/cli.py
@@ -9,8 +9,12 @@
"""

import argparse
+import string
+from collections.abc import Callable

-from cumulus_etl import cli_utils, deid, loaders
+import pyathena
+
+from cumulus_etl import cli_utils, deid, errors, id_handling, loaders
from cumulus_etl.etl import pipeline


@@ -19,12 +23,131 @@ def define_nlp_parser(parser: argparse.ArgumentParser) -> None:
parser.usage = "%(prog)s [OPTION]... INPUT OUTPUT PHI"

pipeline.add_common_etl_args(parser)
-    cli_utils.add_nlp(parser)
+    cli_utils.add_ctakes_override(parser)
+    cli_utils.add_task_selection(parser, etl_mode=False)

cli_utils.add_aws(parser, athena=True)

group = parser.add_argument_group("cohort selection")
group.add_argument(
"--cohort-csv",
metavar="FILE",
help="path to a .csv file with original patient and/or note IDs",
)
group.add_argument(
"--cohort-anon-csv",
metavar="FILE",
help="path to a .csv file with anonymized patient and/or note IDs",
)
group.add_argument(
"--cohort-athena-table",
metavar="DB.TABLE",
help="name of an Athena table with patient and/or note IDs",
)
group.add_argument(
"--allow-large-cohort",
action="store_true",
help="allow a larger-than-normal cohort",
)
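
In practice, exactly one of --cohort-csv, --cohort-anon-csv, or --cohort-athena-table may be given; get_cohort_filter() below enforces that itself (rather than via argparse) so that --allow-large-cohort can sit in the same group. A hypothetical invocation might look like: cumulus-etl nlp INPUT OUTPUT PHI --task example_task --cohort-athena-table my_db.my_cohort --athena-workgroup my-workgroup (command name, task, and table names here are illustrative only).
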


def get_cohort_filter(args: argparse.Namespace) -> Callable[[deid.Codebook, dict], bool] | None:
"""Returns (patient refs to match, resource refs to match)"""
# Poor man's add_mutually_exclusive_group(), which we don't use because we have additional
# flags for the group, like "--allow-large-cohort".
has_csv = bool(args.cohort_csv)
has_anon_csv = bool(args.cohort_anon_csv)
has_athena_table = bool(args.cohort_athena_table)
arg_count = int(has_csv) + int(has_anon_csv) + int(has_athena_table)
if not arg_count:
return None
elif arg_count > 1:
errors.fatal(
"Multiple cohort arguments provided. Please specify just one.",
errors.MULTIPLE_COHORT_ARGS,
)

if has_athena_table:
if "." in args.cohort_athena_table:
parts = args.cohort_athena_table.split(".", 1)
database = parts[0]
table = parts[-1]
else:
database = args.athena_database
table = args.cohort_athena_table
if not database:
errors.fatal(
"You must provide an Athena database with --athena-database.",
errors.ATHENA_DATABASE_MISSING,
)
if set(table) - set(string.ascii_letters + string.digits + "-_"):
errors.fatal(
f"Athena table name '{table}' has invalid characters.",
errors.ATHENA_TABLE_NAME_INVALID,
)
cursor = pyathena.connect(
region_name=args.athena_region,
work_group=args.athena_workgroup,
schema_name=database,
).cursor()
count = cursor.execute(f'SELECT count(*) FROM "{table}"').fetchone()[0] # noqa: S608
if int(count) > 20_000 and not args.allow_large_cohort:
errors.fatal(
f"Athena cohort in '{table}' is very large ({int(count):,} rows).\n"
"If you want to use it anyway, pass --allow-large-cohort",
errors.ATHENA_TABLE_TOO_BIG,
)
csv_file = cursor.execute(f'SELECT * FROM "{table}"').output_location # noqa: S608
else:
csv_file = args.cohort_anon_csv or args.cohort_csv

is_anon = has_anon_csv or has_athena_table

dxreport_ids = id_handling.get_ids_from_csv(csv_file, "DiagnosticReport", is_anon=is_anon)
docref_ids = id_handling.get_ids_from_csv(csv_file, "DocumentReference", is_anon=is_anon)
patient_ids = id_handling.get_ids_from_csv(csv_file, "Patient", is_anon=is_anon)

if not dxreport_ids and not docref_ids and not patient_ids:
errors.fatal("No patient or note IDs found in cohort.", errors.COHORT_NOT_FOUND)

def res_filter(codebook: deid.Codebook, resource: dict) -> bool:
match resource["resourceType"]:
# TODO: uncomment once we support DxReport NLP (coming soon)
# case "DiagnosticReport":
# id_pool = dxreport_ids
# patient_ref = resource.get("subject", {}).get("reference")
case "DocumentReference":
id_pool = docref_ids
patient_ref = resource.get("subject", {}).get("reference")
case _: # pragma: no cover
# shouldn't happen
return False # pragma: no cover

# Check if we have an exact resource ID match (if the user defined exact IDs, we only use
# them, and don't do any patient matching)
if id_pool:
res_id = resource["id"]
if is_anon:
res_id = codebook.fake_id(resource["resourceType"], res_id, caching_allowed=False)
return res_id in id_pool

# Else match on patients if no resource IDs were defined
if not patient_ref:
return False
patient_id = patient_ref.removeprefix("Patient/")
if is_anon:
patient_id = codebook.fake_id("Patient", patient_id, caching_allowed=False)
return patient_id in patient_ids

return res_filter
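
A hedged usage sketch of the filter returned above; the Namespace fields and the get_cohort_filter() call come from this change, while the CSV path and the sample resource are made up:

```python
import argparse

# Flags as the NLP CLI would populate them for a plain (non-anonymized) cohort CSV
args = argparse.Namespace(
    cohort_csv="cohort.csv",
    cohort_anon_csv=None,
    cohort_athena_table=None,
    allow_large_cohort=False,
    athena_region=None,
    athena_workgroup=None,
    athena_database=None,
)
res_filter = get_cohort_filter(args)

docref = {
    "resourceType": "DocumentReference",
    "id": "doc-123",
    "subject": {"reference": "Patient/pat-456"},
}
# With a deid.Codebook in hand (the tasks pass scrubber.codebook, as seen in base.py below),
# the note is kept when:
#   res_filter is None or res_filter(codebook, docref)
```
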


async def nlp_main(args: argparse.Namespace) -> None:
res_filter = get_cohort_filter(args)

async def prep_scrubber(_results: loaders.LoaderResults) -> tuple[deid.Scrubber, dict]:
-        return deid.Scrubber(args.dir_phi), {"ctakes_overrides": args.ctakes_overrides}
+        config_args = {"ctakes_overrides": args.ctakes_overrides, "resource_filter": res_filter}
+        return deid.Scrubber(args.dir_phi), config_args

await pipeline.run_pipeline(args, prep_scrubber=prep_scrubber, nlp=True)

17 changes: 6 additions & 11 deletions cumulus_etl/etl/pipeline.py
@@ -63,15 +63,8 @@ def add_common_etl_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--errors-to", metavar="DIR", help="where to put resources that could not be processed"
)
-    parser.add_argument(
-        "--allow-missing-resources",
-        action="store_true",
-        help="run tasks even if their resources are not present",
-    )

-    cli_utils.add_aws(parser)
cli_utils.add_auth(parser)
-    cli_utils.add_task_selection(parser)
cli_utils.add_debugging(parser)


@@ -137,7 +130,8 @@ async def check_available_resources(
# Reconciling is helpful for performance reasons (don't need to finalize untouched tables),
# UX reasons (can tell user if they made a CLI mistake), and completion tracking (don't
# mark a resource as complete if we didn't even export it)
-    if args.allow_missing_resources:
+    has_allow_missing = hasattr(args, "allow_missing_resources")
+    if has_allow_missing and args.allow_missing_resources:
return requested_resources

detected = await loader.detect_resources()
@@ -158,7 +152,8 @@
)
else:
msg = "Required resources not found.\n"
msg += "Add --allow-missing-resources to run related tasks anyway with no input."
if has_allow_missing:
msg += "Add --allow-missing-resources to run related tasks anyway with no input."
errors.fatal(msg, errors.MISSING_REQUESTED_RESOURCES)

return requested_resources
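
Since --allow-missing-resources now exists only on the ETL subcommand, the hasattr() probe above keeps this shared code path safe when the NLP parser is in play. For comparison, an equivalent spelling (not what the PR uses) would be:

```python
# getattr() with a default collapses the two-step hasattr() check into one line;
# a missing attribute simply reads as False.
allow_missing = getattr(args, "allow_missing_resources", False)
```
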
@@ -182,8 +177,8 @@ async def run_pipeline(
job_context = context.JobContext(root_phi.joinpath("context.json"))
job_datetime = common.datetime_now() # grab timestamp before we do anything

-    selected_tasks = task_factory.get_selected_tasks(args.task, args.task_filter, nlp=nlp)
-    is_default_tasks = not args.task and not args.task_filter
+    selected_tasks = task_factory.get_selected_tasks(args.task, nlp=nlp)
+    is_default_tasks = not args.task

# Print configuration
print_config(args, job_datetime, selected_tasks)
4 changes: 3 additions & 1 deletion cumulus_etl/etl/tasks/base.py
@@ -416,8 +416,10 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> Entr
Something like "yield x, y" or "yield x, [y, z]" - these streams of entries will be kept
separated into two different DataFrames.
"""
+        resource_filter = self.task_config.resource_filter
for x in filter(self.scrubber.scrub_resource, self.read_ndjson(progress=progress)):
-            yield x
+            if not resource_filter or resource_filter(self.scrubber.codebook, x):
+                yield x

def table_batch_cleanup(self, table_index: int, batch_index: int) -> None:
"""Override to add any necessary cleanup from writing a batch out (releasing memory etc)"""
4 changes: 3 additions & 1 deletion cumulus_etl/etl/tasks/nlp_task.py
@@ -84,10 +84,12 @@ async def read_notes(
"""
warned_connection_error = False

+        note_filter = self.task_config.resource_filter or nlp.is_docref_valid

for docref in self.read_ndjson(progress=progress):
orig_docref = copy.deepcopy(docref)
can_process = (
-                nlp.is_docref_valid(docref)
+                note_filter(self.scrubber.codebook, docref)
and (doc_check is None or doc_check(docref))
and self.scrubber.scrub_resource(docref, scrub_attachments=False, keep_stats=False)
)