Commit 5b27f89

feat(nlp): add a variety of --cohort-* args to filter notes
Three main new arguments:
* --cohort-csv: a csv with a column like patient_id or note_ref
* --cohort-anon-csv: same, but with anonymized IDs
* --cohort-athena-table: same, but points at an Athena table

To support the Athena option, we have some other new args:
* --athena-workgroup: specify the workgroup to use
* --athena-database: specify the database to use
* --allow-large-cohort: if the table is gigantic, use it anyway (this exists
  because we have a typo-guard in there - if you accidentally point at the base
  observation table, we're going to stop you from downloading a terabyte of data)

You can specify the database in the table arg with a period, and the workgroup
can be specified via env var or CLI.

If we find a docref/dxreport ID/ref column, we'll use that. Otherwise, we'll use
a patient ID column and grab all notes for those patients.

This cohort filtering replaces (rather than augments) the previous default
filtering for "final" status notes (i.e. skipping draft or superseded notes).
But if the user is specifying the IDs manually for us, they must know what they
want, and we don't need to do the status check for them.
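
For illustration, the new flags might combine like this on the command line. This is
a hypothetical sketch: the paths, table, workgroup, and task names are invented, and
the "cumulus-etl nlp INPUT OUTPUT PHI" shape is inferred from the parser's usage
string rather than spelled out in this commit (--task is required in NLP mode, and
'--task help' lists the real task names).

    # Cohort from a local CSV of original (non-anonymized) IDs
    cumulus-etl nlp ./input ./output ./phi \
        --task example_nlp_task \
        --cohort-csv ./my-cohort.csv

    # Cohort from an Athena table, with the database given inline via a period
    cumulus-etl nlp ./input ./output ./phi \
        --task example_nlp_task \
        --cohort-athena-table my_database.my_cohort_table \
        --athena-workgroup my-workgroup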
1 parent 2291efc commit 5b27f89

File tree

26 files changed: +454, -147 lines

compose.yaml

Lines changed: 3 additions & 0 deletions
@@ -33,8 +33,11 @@ services:
     environment:
       # Environment variobles to pull in from the host
       - AWS_ACCESS_KEY_ID
+      - AWS_ATHENA_WORK_GROUP
       - AWS_DEFAULT_PROFILE
+      - AWS_DEFAULT_REGION
       - AWS_PROFILE
+      - AWS_REGION
       - AWS_SECRET_ACCESS_KEY
       - AWS_SESSION_TOKEN
       - AZURE_OPENAI_API_KEY
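
The "via env var" route mentioned in the commit message presumably rides on this
forwarding. A hypothetical host-side setup (the precedence between the variable and
--athena-workgroup is not shown in this diff):

    export AWS_ATHENA_WORK_GROUP=my-workgroup
    export AWS_REGION=us-east-1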

cumulus_etl/cli_utils.py

Lines changed: 25 additions & 17 deletions
@@ -39,7 +39,7 @@ def add_auth(parser: argparse.ArgumentParser, *, use_fhir_url: bool = True):
     group.add_argument("--smart-jwks", metavar="PATH", help=argparse.SUPPRESS)


-def add_aws(parser: argparse.ArgumentParser) -> None:
+def add_aws(parser: argparse.ArgumentParser, athena: bool = False) -> None:
     group = parser.add_argument_group("AWS")
     group.add_argument(
         "--s3-region",
@@ -51,6 +51,22 @@ def add_aws(parser: argparse.ArgumentParser) -> None:
         metavar="KEY",
         help="if using S3 paths (s3://...), this is the KMS key ID to use",
     )
+    if athena:
+        group.add_argument(
+            "--athena-region",
+            metavar="REGION",
+            help="the region of your Athena workgroup (default is us-east-1)",
+        )
+        group.add_argument(
+            "--athena-workgroup",
+            metavar="GROUP",
+            help="the name of your Athena workgroup",
+        )
+        group.add_argument(
+            "--athena-database",
+            metavar="DB",
+            help="the name of your Athena database",
+        )


 def add_bulk_export(parser: argparse.ArgumentParser, *, as_subgroup: bool = True):
@@ -85,15 +101,13 @@ def add_bulk_export(parser: argparse.ArgumentParser, *, as_subgroup: bool = True
     return parser


-def add_nlp(parser: argparse.ArgumentParser):
-    group = parser.add_argument_group("NLP")
-    group.add_argument(
+def add_ctakes_override(parser: argparse.ArgumentParser):
+    parser.add_argument(
         "--ctakes-overrides",
         metavar="DIR",
         default="/ctakes-overrides",
         help="path to cTAKES overrides dir (default is /ctakes-overrides)",
     )
-    return group


 def add_output_format(parser: argparse.ArgumentParser) -> None:
@@ -105,20 +119,14 @@ def add_output_format(parser: argparse.ArgumentParser) -> None:
     )


-def add_task_selection(parser: argparse.ArgumentParser):
-    task = parser.add_argument_group("task selection")
-    task.add_argument(
+def add_task_selection(parser: argparse.ArgumentParser, *, etl_mode: bool):
+    default = ", default is all supported FHIR resources" if etl_mode else ""
+    required = not etl_mode
+    parser.add_argument(
         "--task",
         action="append",
-        help="only consider these tasks (comma separated, "
-        "default is all supported FHIR resources, "
-        "use '--task help' to see full list)",
-    )
-    task.add_argument(
-        "--task-filter",
-        action="append",
-        choices=["covid_symptom", "irae", "cpu", "gpu"],
-        help="restrict tasks to only the given sets (comma separated)",
+        help=f"only run these tasks (comma separated{default}, use '--task help' to see full list)",
+        required=required,
     )

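
As a quick sanity sketch of the reworked add_aws() helper: a hypothetical snippet that
assumes the cumulus_etl package is importable and uses only what the diff above adds.

    import argparse

    from cumulus_etl import cli_utils

    parser = argparse.ArgumentParser()
    cli_utils.add_aws(parser, athena=True)  # also registers the --athena-* options
    args = parser.parse_args(["--athena-workgroup", "my-group", "--athena-database", "my_db"])
    print(args.athena_workgroup, args.athena_database)  # my-group my_db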

cumulus_etl/errors.py

Lines changed: 8 additions & 3 deletions
@@ -17,8 +17,8 @@
 TASK_UNKNOWN = 17
 CNLPT_MISSING = 18
 TASK_FAILED = 19
-TASK_FILTERED_OUT = 20
-TASK_SET_EMPTY = 21
+# TASK_FILTERED_OUT = 20  # Obsolete, we don't have --task-filter anymore
+# TASK_SET_EMPTY = 21  # Obsolete, we don't have --task-filter anymore
 ARGS_CONFLICT = 22
 ARGS_INVALID = 23
 # FHIR_URL_MISSING = 24  # Obsolete, it's no longer fatal
@@ -39,8 +39,13 @@
 INLINE_TASK_FAILED = 39
 INLINE_WITHOUT_FOLDER = 40
 WRONG_PHI_FOLDER = 41
-TASK_NOT_PROVIDED = 42
+# TASK_NOT_PROVIDED = 42  # checked now by argparse
 TASK_MISMATCH = 43
+ATHENA_TABLE_TOO_BIG = 44
+ATHENA_TABLE_NAME_INVALID = 45
+ATHENA_DATABASE_MISSING = 46
+MULTIPLE_COHORT_ARGS = 47
+COHORT_NOT_FOUND = 48


 class FatalError(Exception):

cumulus_etl/etl/cli.py

Lines changed: 8 additions & 0 deletions
@@ -48,6 +48,14 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--philter", action="store_true", help="run philter on all freeform text fields"
     )
+    parser.add_argument(
+        "--allow-missing-resources",
+        action="store_true",
+        help="run tasks even if their resources are not present",
+    )
+    cli_utils.add_task_selection(parser, etl_mode=True)
+
+    cli_utils.add_aws(parser)

     export = cli_utils.add_bulk_export(parser)
     export.add_argument(

cumulus_etl/etl/config.py

Lines changed: 4 additions & 1 deletion
@@ -2,11 +2,12 @@

 import datetime
 import os
+from collections.abc import Callable
 from socket import gethostname

 import cumulus_fhir_support as cfs

-from cumulus_etl import common, errors, formats, store
+from cumulus_etl import common, deid, errors, formats, store


 class JobConfig:
@@ -39,6 +40,7 @@ def __init__(
         export_datetime: datetime.datetime | None = None,
         export_url: str | None = None,
         deleted_ids: dict[str, set[str]] | None = None,
+        resource_filter: Callable[[deid.Codebook, dict], bool] | None = None,
     ):
         self._dir_input_orig = dir_input_orig
         self.dir_input = dir_input_deid
@@ -59,6 +61,7 @@ def __init__(
         self.export_datetime = export_datetime
         self.export_url = export_url
         self.deleted_ids = deleted_ids or {}
+        self.resource_filter = resource_filter

         # initialize format class
         self._output_root = store.Root(self._dir_output, create=True)

cumulus_etl/etl/nlp/cli.py

Lines changed: 126 additions & 3 deletions
@@ -9,8 +9,12 @@
 """

 import argparse
+import string
+from collections.abc import Callable

-from cumulus_etl import cli_utils, deid, loaders
+import pyathena
+
+from cumulus_etl import cli_utils, deid, errors, id_handling, loaders
 from cumulus_etl.etl import pipeline


@@ -19,12 +23,131 @@ def define_nlp_parser(parser: argparse.ArgumentParser) -> None:
     parser.usage = "%(prog)s [OPTION]... INPUT OUTPUT PHI"

     pipeline.add_common_etl_args(parser)
-    cli_utils.add_nlp(parser)
+    cli_utils.add_ctakes_override(parser)
+    cli_utils.add_task_selection(parser, etl_mode=False)
+
+    cli_utils.add_aws(parser, athena=True)
+
+    group = parser.add_argument_group("cohort selection")
+    group.add_argument(
+        "--cohort-csv",
+        metavar="FILE",
+        help="path to a .csv file with original patient and/or note IDs",
+    )
+    group.add_argument(
+        "--cohort-anon-csv",
+        metavar="FILE",
+        help="path to a .csv file with anonymized patient and/or note IDs",
+    )
+    group.add_argument(
+        "--cohort-athena-table",
+        metavar="DB.TABLE",
+        help="name of an Athena table with patient and/or note IDs",
+    )
+    group.add_argument(
+        "--allow-large-cohort",
+        action="store_true",
+        help="allow a larger-than-normal cohort",
+    )
+
+
+def get_cohort_filter(args: argparse.Namespace) -> Callable[[deid.Codebook, dict], bool] | None:
+    """Returns (patient refs to match, resource refs to match)"""
+    # Poor man's add_mutually_exclusive_group(), which we don't use because we have additional
+    # flags for the group, like "--allow-large-cohort".
+    has_csv = bool(args.cohort_csv)
+    has_anon_csv = bool(args.cohort_anon_csv)
+    has_athena_table = bool(args.cohort_athena_table)
+    arg_count = int(has_csv) + int(has_anon_csv) + int(has_athena_table)
+    if not arg_count:
+        return None
+    elif arg_count > 1:
+        errors.fatal(
+            "Multiple cohort arguments provided. Please specify just one.",
+            errors.MULTIPLE_COHORT_ARGS,
+        )
+
+    if has_athena_table:
+        if "." in args.cohort_athena_table:
+            parts = args.cohort_athena_table.split(".", 1)
+            database = parts[0]
+            table = parts[-1]
+        else:
+            database = args.athena_database
+            table = args.cohort_athena_table
+        if not database:
+            errors.fatal(
+                "You must provide an Athena database with --athena-database.",
+                errors.ATHENA_DATABASE_MISSING,
+            )
+        if set(table) - set(string.ascii_letters + string.digits + "-_"):
+            errors.fatal(
+                f"Athena table name '{table}' has invalid characters.",
+                errors.ATHENA_TABLE_NAME_INVALID,
+            )
+        cursor = pyathena.connect(
+            region_name=args.athena_region,
+            work_group=args.athena_workgroup,
+            schema_name=database,
+        ).cursor()
+        count = cursor.execute(f'SELECT count(*) FROM "{table}"').fetchone()[0]  # noqa: S608
+        if int(count) > 20_000 and not args.allow_large_cohort:
+            errors.fatal(
+                f"Athena cohort in '{table}' is very large ({int(count):,} rows).\n"
+                "If you want to use it anyway, pass --allow-large-cohort",
+                errors.ATHENA_TABLE_TOO_BIG,
+            )
+        csv_file = cursor.execute(f'SELECT * FROM "{table}"').output_location  # noqa: S608
+    else:
+        csv_file = args.cohort_anon_csv or args.cohort_csv
+
+    is_anon = has_anon_csv or has_athena_table
+
+    dxreport_ids = id_handling.get_ids_from_csv(csv_file, "DiagnosticReport", is_anon=is_anon)
+    docref_ids = id_handling.get_ids_from_csv(csv_file, "DocumentReference", is_anon=is_anon)
+    patient_ids = id_handling.get_ids_from_csv(csv_file, "Patient", is_anon=is_anon)
+
+    if not dxreport_ids and not docref_ids and not patient_ids:
+        errors.fatal("No patient or note IDs found in cohort.", errors.COHORT_NOT_FOUND)
+
+    def res_filter(codebook: deid.Codebook, resource: dict) -> bool:
+        match resource["resourceType"]:
+            # TODO: uncomment once we support DxReport NLP (coming soon)
+            # case "DiagnosticReport":
+            #     id_pool = dxreport_ids
+            #     patient_ref = resource.get("subject", {}).get("reference")
+            case "DocumentReference":
+                id_pool = docref_ids
+                patient_ref = resource.get("subject", {}).get("reference")
+            case _:  # pragma: no cover
+                # shouldn't happen
+                return False  # pragma: no cover
+
+        # Check if we have an exact resource ID match (if the user defined exact IDs, we only use
+        # them, and don't do any patient matching)
+        if id_pool:
+            res_id = resource["id"]
+            if is_anon:
+                res_id = codebook.fake_id(resource["resourceType"], res_id, caching_allowed=False)
+            return res_id in id_pool
+
+        # Else match on patients if no resource IDs were defined
+        if not patient_ref:
+            return False
+        patient_id = patient_ref.removeprefix("Patient/")
+        if is_anon:
+            patient_id = codebook.fake_id("Patient", patient_id, caching_allowed=False)
+        return patient_id in patient_ids
+
+    return res_filter


 async def nlp_main(args: argparse.Namespace) -> None:
+    res_filter = get_cohort_filter(args)
+
     async def prep_scrubber(_results: loaders.LoaderResults) -> tuple[deid.Scrubber, dict]:
-        return deid.Scrubber(args.dir_phi), {"ctakes_overrides": args.ctakes_overrides}
+        config_args = {"ctakes_overrides": args.ctakes_overrides, "resource_filter": res_filter}
+        return deid.Scrubber(args.dir_phi), config_args

     await pipeline.run_pipeline(args, prep_scrubber=prep_scrubber, nlp=True)
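
To make the ID handling above concrete, a cohort CSV might look like one of the
sketches below. The column names echo the commit message's examples (patient_id,
note_ref); the exact headers and value formats accepted are decided by
id_handling.get_ids_from_csv, which is not part of this diff, so treat these as
illustrative only.

A note-level cohort (only these exact notes are kept; no patient fallback):

    note_ref
    DocumentReference/doc-id-1
    DocumentReference/doc-id-2

A patient-level cohort (every note for these patients is kept):

    patient_id
    patient-id-1
    patient-id-2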

cumulus_etl/etl/pipeline.py

Lines changed: 6 additions & 11 deletions
@@ -63,15 +63,8 @@ def add_common_etl_args(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--errors-to", metavar="DIR", help="where to put resources that could not be processed"
     )
-    parser.add_argument(
-        "--allow-missing-resources",
-        action="store_true",
-        help="run tasks even if their resources are not present",
-    )

-    cli_utils.add_aws(parser)
     cli_utils.add_auth(parser)
-    cli_utils.add_task_selection(parser)
     cli_utils.add_debugging(parser)


@@ -137,7 +130,8 @@ async def check_available_resources(
     # Reconciling is helpful for performance reasons (don't need to finalize untouched tables),
     # UX reasons (can tell user if they made a CLI mistake), and completion tracking (don't
     # mark a resource as complete if we didn't even export it)
-    if args.allow_missing_resources:
+    has_allow_missing = hasattr(args, "allow_missing_resources")
+    if has_allow_missing and args.allow_missing_resources:
         return requested_resources

     detected = await loader.detect_resources()
@@ -158,7 +152,8 @@ async def check_available_resources(
         )
     else:
         msg = "Required resources not found.\n"
-        msg += "Add --allow-missing-resources to run related tasks anyway with no input."
+        if has_allow_missing:
+            msg += "Add --allow-missing-resources to run related tasks anyway with no input."
     errors.fatal(msg, errors.MISSING_REQUESTED_RESOURCES)

     return requested_resources
@@ -182,8 +177,8 @@ async def run_pipeline(
     job_context = context.JobContext(root_phi.joinpath("context.json"))
     job_datetime = common.datetime_now()  # grab timestamp before we do anything

-    selected_tasks = task_factory.get_selected_tasks(args.task, args.task_filter, nlp=nlp)
-    is_default_tasks = not args.task and not args.task_filter
+    selected_tasks = task_factory.get_selected_tasks(args.task, nlp=nlp)
+    is_default_tasks = not args.task

     # Print configuration
     print_config(args, job_datetime, selected_tasks)

cumulus_etl/etl/tasks/base.py

Lines changed: 3 additions & 1 deletion
@@ -416,8 +416,10 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> Entr
         Something like "yield x, y" or "yield x, [y, z]" - these streams of entries will be kept
         separated into two different DataFrames.
         """
+        resource_filter = self.task_config.resource_filter
         for x in filter(self.scrubber.scrub_resource, self.read_ndjson(progress=progress)):
-            yield x
+            if not resource_filter or resource_filter(self.scrubber.codebook, x):
+                yield x

     def table_batch_cleanup(self, table_index: int, batch_index: int) -> None:
         """Override to add any necessary cleanup from writing a batch out (releasing memory etc)"""

cumulus_etl/etl/tasks/nlp_task.py

Lines changed: 3 additions & 1 deletion
@@ -84,10 +84,12 @@ async def read_notes(
         """
         warned_connection_error = False

+        note_filter = self.task_config.resource_filter or nlp.is_docref_valid
+
         for docref in self.read_ndjson(progress=progress):
             orig_docref = copy.deepcopy(docref)
             can_process = (
-                nlp.is_docref_valid(docref)
+                note_filter(self.scrubber.codebook, docref)
                 and (doc_check is None or doc_check(docref))
                 and self.scrubber.scrub_resource(docref, scrub_attachments=False, keep_stats=False)
             )
