Skip to content

Commit 63c94ef

Browse files
authored
Merge pull request #468 from smart-on-fhir/mikix/claude-fix
nlp: be more flexible with the answers accepted from claude
2 parents 593cf0f + 1ac203f commit 63c94ef

File tree

13 files changed

+91
-53
lines changed

13 files changed

+91
-53
lines changed

.github/workflows/ci.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ jobs:
2424
python-version: ["3.11"]
2525

2626
steps:
27-
- uses: actions/checkout@v4
27+
- uses: actions/checkout@v5
2828

2929
- name: Set up Python ${{ matrix.python-version }}
30-
uses: actions/setup-python@v5
30+
uses: actions/setup-python@v6
3131
with:
3232
python-version: ${{ matrix.python-version }}
3333

@@ -37,7 +37,7 @@ jobs:
3737
pip install .[tests]
3838
3939
- name: Check out MS tool
40-
uses: actions/checkout@v4
40+
uses: actions/checkout@v5
4141
with:
4242
repository: microsoft/Tools-for-Health-Data-Anonymization
4343
path: mstool
@@ -77,7 +77,7 @@ jobs:
7777
env:
7878
UMLS_API_KEY: ${{ secrets.UMLS_API_KEY }}
7979
steps:
80-
- uses: actions/checkout@v4
80+
- uses: actions/checkout@v5
8181

8282
- name: Install Docker
8383
uses: docker/setup-buildx-action@v3
@@ -122,7 +122,7 @@ jobs:
122122
lint:
123123
runs-on: ubuntu-latest
124124
steps:
125-
- uses: actions/checkout@v4
125+
- uses: actions/checkout@v5
126126

127127
- name: Install linters
128128
run: |

.github/workflows/docker-hub.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
name: Build and push image
1010
runs-on: ubuntu-latest
1111
steps:
12-
- uses: actions/checkout@v4
12+
- uses: actions/checkout@v5
1313

1414
- name: Set up Docker Buildx
1515
uses: docker/setup-buildx-action@v3

.github/workflows/pages.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
runs-on: ubuntu-latest
1111
steps:
1212
- name: Send workflow dispatch
13-
uses: actions/github-script@v6
13+
uses: actions/github-script@v8
1414
with:
1515
# This token is set to expire in May 2024.
1616
# You can make a new one with write access to Actions on the cumulus repo.

.github/workflows/pypi.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
id-token: write # this permission is required for PyPI "trusted publishing"
1414

1515
steps:
16-
- uses: actions/checkout@v4
16+
- uses: actions/checkout@v5
1717

1818
- name: Install dependencies
1919
run: |

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.13.0 # keep in rough sync with pyproject.toml
3+
rev: v0.14.1 # keep in rough sync with pyproject.toml
44
hooks:
55
- name: Ruff formatting
66
id: ruff-format

cumulus_etl/etl/studies/covid_symptom/covid_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
155155
# But the current approach instead focuses purely on accuracy and makes sure that we zero-out any dangling
156156
# entries for groups that we do process.
157157
# Downstream SQL can ignore the above cases itself, as needed.
158-
self.seen_docrefs.add(docref["id"])
158+
self.seen_groups.add(docref["id"])
159159

160160
# Yield the whole set of symptoms at once, to allow for more easily replacing a previous set of symptoms.
161161
# This way we don't need to worry about symptoms from the same note crossing batch boundaries.

cumulus_etl/etl/studies/irae/irae_tasks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ def should_skip(self, orig_note: dict) -> bool:
510510
subject_ref = nlp.get_note_subject_ref(orig_note)
511511
return subject_ref in self.subject_refs_to_skip or super().should_skip(orig_note)
512512

513-
def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
514-
super().post_process(parsed, orig_note_text, orig_note)
513+
def post_process(self, parsed: dict, details: tasks.NoteDetails) -> None:
514+
super().post_process(parsed, details)
515515

516516
# If we have an annotation that asserts a graft failure or deceased,
517517
# we can stop processing charts for that patient, to avoid pointless NLP requests.
@@ -526,7 +526,7 @@ def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> No
526526
is_deceased = deceased.get("has_mention") and deceased.get("deceased")
527527

528528
if is_failed or is_deceased:
529-
if subject_ref := nlp.get_note_subject_ref(orig_note):
529+
if subject_ref := nlp.get_note_subject_ref(details.orig_note):
530530
self.subject_refs_to_skip.add(subject_ref)
531531

532532

cumulus_etl/etl/tasks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Task support for the ETL workflow"""
22

3-
from .base import EntryIterator, EtlTask, OutputTable
4-
from .nlp_task import BaseModelTask, BaseModelTaskWithSpans, BaseNlpTask
3+
from .base import EntryBundle, EntryIterator, EtlTask, OutputTable
4+
from .nlp_task import BaseModelTask, BaseModelTaskWithSpans, BaseNlpTask, NoteDetails

cumulus_etl/etl/tasks/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
# Defined here, as syntactic sugar for when you subclass your own task and re-define read_entries()
2121
EntryAtom = dict | list[dict]
22-
EntryIterator = AsyncIterator[EntryAtom | tuple[EntryAtom, ...]]
22+
EntryBundle = EntryAtom | tuple[EntryAtom, ...]
23+
EntryIterator = AsyncIterator[EntryBundle]
2324

2425

2526
@dataclasses.dataclass(kw_only=True)

cumulus_etl/etl/tasks/nlp_task.py

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Base NLP task support"""
22

33
import copy
4+
import dataclasses
45
import json
56
import logging
67
import os
@@ -54,11 +55,11 @@ class BaseNlpTask(tasks.EtlTask):
5455

5556
def __init__(self, *args, **kwargs):
5657
super().__init__(*args, **kwargs)
57-
self.seen_docrefs = set()
58+
self.seen_groups = set()
5859

5960
def pop_current_group_values(self, table_index: int) -> set[str]:
60-
values = self.seen_docrefs
61-
self.seen_docrefs = set()
61+
values = self.seen_groups
62+
self.seen_groups = set()
6263
return values
6364

6465
def add_error(self, docref: dict) -> None:
@@ -121,10 +122,28 @@ def remove_trailing_whitespace(note_text: str) -> str:
121122
return TRAILING_WHITESPACE.sub("", note_text)
122123

123124

125+
@dataclasses.dataclass(kw_only=True)
126+
class NoteDetails:
127+
note_ref: str
128+
encounter_id: str
129+
subject_ref: str
130+
131+
note_text: str
132+
note: dict
133+
134+
orig_note_ref: str
135+
orig_note_text: str
136+
orig_note: dict
137+
138+
124139
class BaseModelTask(BaseNlpTask):
125140
"""Base class for any NLP task talking to LLM models."""
126141

127-
outputs: ClassVar = [tasks.OutputTable(resource_type=None, uniqueness_fields={"note_ref"})]
142+
outputs: ClassVar = [
143+
tasks.OutputTable(
144+
resource_type=None, uniqueness_fields={"note_ref"}, group_field="note_ref"
145+
)
146+
]
128147

129148
# If you change these prompts, consider updating task_version.
130149
system_prompt: str = None
@@ -155,33 +174,47 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
155174
note_text = self.remove_trailing_whitespace(orig_note_text)
156175
orig_note_ref = f"{orig_note['resourceType']}/{orig_note['id']}"
157176

177+
details = NoteDetails(
178+
note_ref=note_ref,
179+
encounter_id=encounter_id,
180+
subject_ref=subject_ref,
181+
note_text=note_text,
182+
note=note,
183+
orig_note_ref=orig_note_ref,
184+
orig_note_text=orig_note_text,
185+
orig_note=orig_note,
186+
)
187+
158188
try:
159-
response = await self.model.prompt(
160-
self.get_system_prompt(),
161-
self.get_user_prompt(note_text),
162-
schema=self.response_format,
163-
cache_dir=self.task_config.dir_phi,
164-
cache_namespace=f"{self.name}_v{self.task_version}",
165-
note_text=note_text,
166-
)
189+
if result := await self.process_note(details):
190+
yield result
167191
except Exception as exc:
168192
logging.warning(f"NLP failed for {orig_note_ref}: {exc}")
169193
self.add_error(orig_note)
170-
continue
171194

172-
parsed = response.answer.model_dump(mode="json")
173-
self.post_process(parsed, orig_note_text, orig_note)
195+
async def process_note(self, details: NoteDetails) -> tasks.EntryBundle | None:
196+
response = await self.model.prompt(
197+
self.get_system_prompt(),
198+
self.get_user_prompt(details.note_text),
199+
schema=self.response_format,
200+
cache_dir=self.task_config.dir_phi,
201+
cache_namespace=f"{self.name}_v{self.task_version}",
202+
note_text=details.note_text,
203+
)
204+
205+
parsed = response.answer.model_dump(mode="json")
206+
self.post_process(parsed, details)
174207

175-
yield {
176-
"note_ref": note_ref,
177-
"encounter_ref": f"Encounter/{encounter_id}",
178-
"subject_ref": subject_ref,
179-
# Since this date is stored as a string, use UTC time for easy comparisons
180-
"generated_on": common.datetime_now().isoformat(),
181-
"task_version": self.task_version,
182-
"system_fingerprint": response.fingerprint,
183-
"result": parsed,
184-
}
208+
return {
209+
"note_ref": details.note_ref,
210+
"encounter_ref": f"Encounter/{details.encounter_id}",
211+
"subject_ref": details.subject_ref,
212+
# Since this date is stored as a string, use UTC time for easy comparisons
213+
"generated_on": common.datetime_now().isoformat(),
214+
"task_version": self.task_version,
215+
"system_fingerprint": response.fingerprint,
216+
"result": parsed,
217+
}
185218

186219
def finish_task(self) -> None:
187220
stats = self.model.stats
@@ -225,7 +258,7 @@ def should_skip(self, orig_note: dict) -> bool:
225258
"""Subclasses can fill this out if they like, to skip notes"""
226259
return False
227260

228-
def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
261+
def post_process(self, parsed: dict, details: NoteDetails) -> None:
229262
"""Subclasses can fill this out if they like"""
230263

231264
@classmethod
@@ -289,18 +322,18 @@ class BaseModelTaskWithSpans(BaseModelTask):
289322
It assumes any field named "spans" in the hierarchy of the pydantic model should be converted.
290323
"""
291324

292-
def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
293-
if not self._process_dict(parsed, orig_note_text, orig_note):
294-
self.add_error(orig_note)
325+
def post_process(self, parsed: dict, details: NoteDetails) -> None:
326+
if not self._process_dict(parsed, details):
327+
self.add_error(details.orig_note)
295328

296-
def _process_dict(self, parsed: dict, orig_note_text: str, orig_note: dict) -> bool:
329+
def _process_dict(self, parsed: dict, details: NoteDetails) -> bool:
297330
"""Returns False if any span couldn't be matched"""
298331
all_found = True
299332

300333
for key, value in parsed.items():
301334
if key != "spans":
302335
if isinstance(value, dict):
303-
all_found &= self._process_dict(value, orig_note_text, orig_note) # descend
336+
all_found &= self._process_dict(value, details) # descend
304337
continue
305338

306339
new_spans = []
@@ -318,14 +351,14 @@ def _process_dict(self, parsed: dict, orig_note_text: str, orig_note: dict) -> b
318351
span = ESCAPED_WHITESPACE.sub(r"\\s+", span)
319352

320353
found = False
321-
for match in re.finditer(span, orig_note_text, re.IGNORECASE):
354+
for match in re.finditer(span, details.orig_note_text, re.IGNORECASE):
322355
found = True
323356
new_spans.append(match.span())
324357
if not found:
325358
all_found = False
326359
logging.warning(
327360
"Could not match span received from NLP server for "
328-
f"{orig_note['resourceType']}/{orig_note['id']}: {orig_span}"
361+
f"{details.orig_note_ref}: {orig_span}"
329362
)
330363

331364
parsed[key] = new_spans

0 commit comments

Comments
 (0)