Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ jobs:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -37,7 +37,7 @@ jobs:
pip install .[tests]

- name: Check out MS tool
uses: actions/checkout@v4
uses: actions/checkout@v5
with:
repository: microsoft/Tools-for-Health-Data-Anonymization
path: mstool
Expand Down Expand Up @@ -77,7 +77,7 @@ jobs:
env:
UMLS_API_KEY: ${{ secrets.UMLS_API_KEY }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5

- name: Install Docker
uses: docker/setup-buildx-action@v3
Expand Down Expand Up @@ -122,7 +122,7 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5

- name: Install linters
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docker-hub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
name: Build and push image
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pages.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Send workflow dispatch
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
# This token is set to expire in May 2024.
# You can make a new one with write access to Actions on the cumulus repo.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
id-token: write # this permission is required for PyPI "trusted publishing"

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.0 # keep in rough sync with pyproject.toml
rev: v0.14.1 # keep in rough sync with pyproject.toml
hooks:
- name: Ruff formatting
id: ruff-format
Expand Down
2 changes: 1 addition & 1 deletion cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
# But the current approach instead focuses purely on accuracy and makes sure that we zero-out any dangling
# entries for groups that we do process.
# Downstream SQL can ignore the above cases itself, as needed.
self.seen_docrefs.add(docref["id"])
self.seen_groups.add(docref["id"])

# Yield the whole set of symptoms at once, to allow for more easily replacing previous a set of symptoms.
# This way we don't need to worry about symptoms from the same note crossing batch boundaries.
Expand Down
6 changes: 3 additions & 3 deletions cumulus_etl/etl/studies/irae/irae_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ def should_skip(self, orig_note: dict) -> bool:
subject_ref = nlp.get_note_subject_ref(orig_note)
return subject_ref in self.subject_refs_to_skip or super().should_skip(orig_note)

def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
super().post_process(parsed, orig_note_text, orig_note)
def post_process(self, parsed: dict, details: tasks.NoteDetails) -> None:
super().post_process(parsed, details)

# If we have an annotation that asserts a graft failure or deceased,
# we can stop processing charts for that patient, to avoid pointless NLP requests.
Expand All @@ -526,7 +526,7 @@ def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> No
is_deceased = deceased.get("has_mention") and deceased.get("deceased")

if is_failed or is_deceased:
if subject_ref := nlp.get_note_subject_ref(orig_note):
if subject_ref := nlp.get_note_subject_ref(details.orig_note):
self.subject_refs_to_skip.add(subject_ref)


Expand Down
4 changes: 2 additions & 2 deletions cumulus_etl/etl/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Task support for the ETL workflow"""

from .base import EntryIterator, EtlTask, OutputTable
from .nlp_task import BaseModelTask, BaseModelTaskWithSpans, BaseNlpTask
from .base import EntryBundle, EntryIterator, EtlTask, OutputTable
from .nlp_task import BaseModelTask, BaseModelTaskWithSpans, BaseNlpTask, NoteDetails
3 changes: 2 additions & 1 deletion cumulus_etl/etl/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

# Defined here, as syntactic sugar for when you subclass your own task and re-define read_entries()
EntryAtom = dict | list[dict]
EntryIterator = AsyncIterator[EntryAtom | tuple[EntryAtom, ...]]
EntryBundle = EntryAtom | tuple[EntryAtom, ...]
EntryIterator = AsyncIterator[EntryBundle]


@dataclasses.dataclass(kw_only=True)
Expand Down
99 changes: 66 additions & 33 deletions cumulus_etl/etl/tasks/nlp_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base NLP task support"""

import copy
import dataclasses
import json
import logging
import os
Expand Down Expand Up @@ -54,11 +55,11 @@ class BaseNlpTask(tasks.EtlTask):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.seen_docrefs = set()
self.seen_groups = set()

def pop_current_group_values(self, table_index: int) -> set[str]:
values = self.seen_docrefs
self.seen_docrefs = set()
values = self.seen_groups
self.seen_groups = set()
return values

def add_error(self, docref: dict) -> None:
Expand Down Expand Up @@ -121,10 +122,28 @@ def remove_trailing_whitespace(note_text: str) -> str:
return TRAILING_WHITESPACE.sub("", note_text)


@dataclasses.dataclass(kw_only=True)
class NoteDetails:
note_ref: str
encounter_id: str
subject_ref: str

note_text: str
note: dict

orig_note_ref: str
orig_note_text: str
orig_note: dict


class BaseModelTask(BaseNlpTask):
"""Base class for any NLP task talking to LLM models."""

outputs: ClassVar = [tasks.OutputTable(resource_type=None, uniqueness_fields={"note_ref"})]
outputs: ClassVar = [
tasks.OutputTable(
resource_type=None, uniqueness_fields={"note_ref"}, group_field="note_ref"
)
]

# If you change these prompts, consider updating task_version.
system_prompt: str = None
Expand Down Expand Up @@ -155,33 +174,47 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
note_text = self.remove_trailing_whitespace(orig_note_text)
orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"

details = NoteDetails(
note_ref=note_ref,
encounter_id=encounter_id,
subject_ref=subject_ref,
note_text=note_text,
note=note,
orig_note_ref=orig_note_ref,
orig_note_text=orig_note_text,
orig_note=orig_note,
)

try:
response = await self.model.prompt(
self.get_system_prompt(),
self.get_user_prompt(note_text),
schema=self.response_format,
cache_dir=self.task_config.dir_phi,
cache_namespace=f"{self.name}_v{self.task_version}",
note_text=note_text,
)
if result := await self.process_note(details):
yield result
except Exception as exc:
logging.warning(f"NLP failed for {orig_note_ref}: {exc}")
self.add_error(orig_note)
continue

parsed = response.answer.model_dump(mode="json")
self.post_process(parsed, orig_note_text, orig_note)
async def process_note(self, details: NoteDetails) -> tasks.EntryBundle | None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just some logic moved from one function into a helper function, to make it easier for subclassess to override it and add more logic if they want.

response = await self.model.prompt(
self.get_system_prompt(),
self.get_user_prompt(details.note_text),
schema=self.response_format,
cache_dir=self.task_config.dir_phi,
cache_namespace=f"{self.name}_v{self.task_version}",
note_text=details.note_text,
)

parsed = response.answer.model_dump(mode="json")
self.post_process(parsed, details)

yield {
"note_ref": note_ref,
"encounter_ref": f"Encounter/{encounter_id}",
"subject_ref": subject_ref,
# Since this date is stored as a string, use UTC time for easy comparisons
"generated_on": common.datetime_now().isoformat(),
"task_version": self.task_version,
"system_fingerprint": response.fingerprint,
"result": parsed,
}
return {
"note_ref": details.note_ref,
"encounter_ref": f"Encounter/{details.encounter_id}",
"subject_ref": details.subject_ref,
# Since this date is stored as a string, use UTC time for easy comparisons
"generated_on": common.datetime_now().isoformat(),
"task_version": self.task_version,
"system_fingerprint": response.fingerprint,
"result": parsed,
}

def finish_task(self) -> None:
stats = self.model.stats
Expand Down Expand Up @@ -225,7 +258,7 @@ def should_skip(self, orig_note: dict) -> bool:
"""Subclasses can fill this out if they like, to skip notes"""
return False

def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
def post_process(self, parsed: dict, details: NoteDetails) -> None:
"""Subclasses can fill this out if they like"""

@classmethod
Expand Down Expand Up @@ -289,18 +322,18 @@ class BaseModelTaskWithSpans(BaseModelTask):
It assumes any field named "spans" in the hierarchy of the pydantic model should be converted.
"""

def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
if not self._process_dict(parsed, orig_note_text, orig_note):
self.add_error(orig_note)
def post_process(self, parsed: dict, details: NoteDetails) -> None:
if not self._process_dict(parsed, details):
self.add_error(details.orig_note)

def _process_dict(self, parsed: dict, orig_note_text: str, orig_note: dict) -> bool:
def _process_dict(self, parsed: dict, details: NoteDetails) -> bool:
"""Returns False if any span couldn't be matched"""
all_found = True

for key, value in parsed.items():
if key != "spans":
if isinstance(value, dict):
all_found &= self._process_dict(value, orig_note_text, orig_note) # descend
all_found &= self._process_dict(value, details) # descend
continue

new_spans = []
Expand All @@ -318,14 +351,14 @@ def _process_dict(self, parsed: dict, orig_note_text: str, orig_note: dict) -> b
span = ESCAPED_WHITESPACE.sub(r"\\s+", span)

found = False
for match in re.finditer(span, orig_note_text, re.IGNORECASE):
for match in re.finditer(span, details.orig_note_text, re.IGNORECASE):
found = True
new_spans.append(match.span())
if not found:
all_found = False
logging.warning(
"Could not match span received from NLP server for "
f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}"
f"{details.orig_note_ref}: {orig_span}"
)

parsed[key] = new_spans
Expand Down
2 changes: 1 addition & 1 deletion cumulus_etl/inliner/inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,5 @@ async def _inline_one_attachment(
# Overwrite other associated metadata with latest info (existing metadata might now be stale)
attachment["contentType"] = f"{mimetype}; charset={response.encoding}"
attachment["size"] = len(response.content)
sha1_hash = hashlib.sha1(response.content).digest() # noqa: S324
sha1_hash = hashlib.sha1(response.content, usedforsecurity=False).digest()
attachment["hash"] = base64.standard_b64encode(sha1_hash).decode("ascii")
8 changes: 6 additions & 2 deletions cumulus_etl/nlp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,12 @@ async def prompt(self, system: str, user: str, schema: type[BaseModel]) -> Promp
if "toolUse" in content:
raw_json = content["toolUse"]["input"]
# Sometimes (e.g. with claude sonnet 4.5) we get a wrapper field of "parameter"
if len(raw_json) == 1 and "parameter" in raw_json:
raw_json = raw_json["parameter"]
# or "$PARAMETER_NAME" :shrug: - for now, just look for those names in particular.
# If we see a wider variety, we can try to skip any single wrapper field, but I
# want to keep the option of studies that have only one top level field for now.
top_keys = set(raw_json)
if len(top_keys) == 1 and {"parameter", "$PARAMETER_NAME"} & top_keys:
raw_json = raw_json.popitem()[1]
Comment on lines +147 to +152
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the real fix in this PR - everything else was just some light refactoring/updating.

answer = schema.model_validate(raw_json)
break
if "text" in content:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dev = [
"pre-commit",
# Ruff is using minor versions for breaking changes until their 1.0 release.
# See https://docs.astral.sh/ruff/versioning/
"ruff < 0.14", # keep in rough sync with .pre-commit-config.yaml
"ruff < 0.15", # keep in rough sync with .pre-commit-config.yaml
]

[project.urls]
Expand Down