Skip to content

Commit 2a2834d

Browse files
authored
Merge pull request #445 from smart-on-fhir/mikix/sublabels
upload-notes: support sublabels in --label-by-* sources
2 parents 38c05ac + 9d07835 commit 2a2834d

File tree

8 files changed

+254
-86
lines changed

8 files changed

+254
-86
lines changed

.github/workflows/ci.yaml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,9 +46,10 @@ jobs:
4646
run: |
4747
sudo apt-get update
4848
sudo apt-get install dotnet8
49+
sed -i 's/;net9.0</</' mstool/Directory.Build.props # disable net9.0, it confuses SDK 8.0
4950
dotnet publish \
51+
--framework=net8.0 \
5052
--runtime=linux-x64 \
51-
--configuration=Release \
5253
-p:PublishSingleFile=true \
5354
--output=$HOME/.local/bin \
5455
mstool/FHIR/src/Microsoft.Health.Fhir.Anonymizer.R4.CommandLineTool

Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,12 @@ FROM mcr.microsoft.com/dotnet/sdk:8.0 AS ms-tool
77
COPY --from=ms-tool-src /app /app
88
# This will force builds to fail if the environment piping breaks for some reason
99
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
10+
RUN sed -i 's/;net9.0</</' /app/Directory.Build.props # disable net9.0, it confuses SDK 8.0
1011
RUN arch=$(arch | sed s/aarch64/arm64/ | sed s/x86_64/x64/) && \
1112
dotnet publish \
13+
--framework=net8.0 \
1214
--runtime=linux-${arch} \
1315
--self-contained=true \
14-
--configuration=Release \
1516
-p:InvariantGlobalization=true \
1617
-p:PublishSingleFile=true \
1718
--output=/bin \

cumulus_etl/errors.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,8 @@
4747
MULTIPLE_COHORT_ARGS = 47
4848
COHORT_NOT_FOUND = 48
4949
MULTIPLE_LABELING_ARGS = 49
50+
LABEL_UNKNOWN = 50
51+
LABEL_CONFIG_TYPE_UNKNOWN = 51
5052

5153

5254
class FatalError(Exception):

cumulus_etl/upload_notes/cli.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import asyncio
5+
import dataclasses
56
import datetime
67
import sys
78
from collections.abc import Callable, Collection
@@ -247,7 +248,7 @@ def group_notes_by_unique_id(notes: Collection[LabelStudioNote]) -> list[LabelSt
247248
for unique_id, group_notes in by_unique_id.items():
248249
grouped_text = ""
249250
grouped_ctakes_matches = []
250-
grouped_highlights = {}
251+
grouped_highlights = []
251252
grouped_philter_map = {}
252253
grouped_doc_mappings = {}
253254
grouped_doc_spans = {}
@@ -283,14 +284,10 @@ def group_notes_by_unique_id(notes: Collection[LabelStudioNote]) -> list[LabelSt
283284
match.end += offset
284285
grouped_ctakes_matches.append(match)
285286

286-
for source, labels in note.highlights.items():
287-
grouped_labels = grouped_highlights.setdefault(source, {})
288-
for label, spans in labels.items():
289-
for span in spans:
290-
new_span = ctakesclient.typesystem.Span(
291-
span.begin + offset, span.end + offset
292-
)
293-
grouped_labels.setdefault(label, []).append(new_span)
287+
for highlight in note.highlights:
288+
span = highlight.span
289+
new_span = (span[0] + offset, span[1] + offset)
290+
grouped_highlights.append(dataclasses.replace(highlight, span=new_span))
294291

295292
for start, stop in note.philter_map.items():
296293
grouped_philter_map[start + offset] = stop + offset

cumulus_etl/upload_notes/labeling.py

Lines changed: 32 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@
33
import argparse
44
from collections.abc import Collection
55

6-
import ctakesclient
7-
86
from cumulus_etl import cli_utils, common, deid, errors, nlp
97
from cumulus_etl.upload_notes import labelstudio
108

@@ -56,7 +54,17 @@ def _label_by_csv(
5654
*,
5755
is_anon: bool,
5856
) -> None:
59-
matcher = nlp.CsvMatcher(csv_file, is_anon=is_anon, extra_fields=["label", "span", "origin"])
57+
matcher = nlp.CsvMatcher(
58+
csv_file,
59+
is_anon=is_anon,
60+
extra_fields=[
61+
"label",
62+
"span",
63+
"sublabel_name",
64+
"sublabel_value",
65+
"origin",
66+
],
67+
)
6068

6169
for note in notes:
6270
for ref, doc_span in note.doc_spans.items():
@@ -65,15 +73,23 @@ def _label_by_csv(
6573
for match in sorted(matches):
6674
label = match[0]
6775
span = match[1]
68-
origin = match[2] or DEFAULT_ORIGIN
76+
sublabel_name = match[2] or None
77+
sublabel_value = match[3] or None
78+
origin = match[4] or DEFAULT_ORIGIN
6979
if "__" in origin: # if it looks like a table name, chop it down
7080
origin = origin.split("__", 1)[-1].removeprefix("nlp_")
7181
if label and span and ":" in span:
7282
begin, end = span.split(":", 1)
73-
span = ctakesclient.typesystem.Span(
74-
int(begin) + doc_span[0], int(end) + doc_span[0]
83+
span = (int(begin) + doc_span[0], int(end) + doc_span[0])
84+
note.highlights.append(
85+
labelstudio.Highlight(
86+
label,
87+
span,
88+
origin=origin,
89+
sublabel_name=sublabel_name,
90+
sublabel_value=sublabel_value,
91+
)
7592
)
76-
note.highlights.setdefault(origin, {}).setdefault(label, []).append(span)
7793

7894

7995
def _highlight_words(
@@ -91,8 +107,12 @@ def _highlight_words(
91107
for note in notes:
92108
for pattern in patterns:
93109
for match in pattern.finditer(note.text):
94-
# Look at group 2 (the middle term group, ignoring the edge groups)
95-
span = ctakesclient.typesystem.Span(match.start(2), match.end(2))
96-
labels = note.highlights.setdefault(DEFAULT_ORIGIN, {})
97-
# We use a generic default label to cause Label Studio to highlight it
98-
labels.setdefault(DEFAULT_LABEL, []).append(span)
110+
note.highlights.append(
111+
labelstudio.Highlight(
112+
# We use a generic default label to cause Label Studio to highlight it
113+
label=DEFAULT_LABEL,
114+
# Look at group 2 (the middle term group, ignoring the edge groups)
115+
span=(match.start(2), match.end(2)),
116+
origin=DEFAULT_ORIGIN,
117+
)
118+
)

cumulus_etl/upload_notes/labelstudio.py

Lines changed: 88 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
import datetime
5+
import hashlib
56
import math
67
from collections.abc import AsyncIterator, Collection, Iterable
78

@@ -18,6 +19,17 @@
1819
###############################################################################
1920

2021

22+
@dataclasses.dataclass
23+
class Highlight:
24+
"""Describes a label, a span, and some extra metadata"""
25+
26+
label: str
27+
span: tuple[int, int]
28+
origin: str
29+
sublabel_name: str | None = None
30+
sublabel_value: str | None = None
31+
32+
2133
@dataclasses.dataclass
2234
class LabelStudioNote:
2335
"""Holds all the data that Label Studio will need for a single note (or a single grouped encounter note)"""
@@ -45,10 +57,8 @@ class LabelStudioNote:
4557
default_factory=list
4658
)
4759

48-
# Matches found by word search or csv, as a dict of origins -> labels -> found spans
49-
highlights: dict[str, dict[str | None, list[ctakesclient.typesystem.Span]]] = dataclasses.field(
50-
default_factory=dict
51-
)
60+
# Matches found by word search or csv
61+
highlights: list[Highlight] = dataclasses.field(default_factory=list)
5262

5363
# Matches found by Philter
5464
philter_map: dict[int, int] = dataclasses.field(default_factory=dict)
@@ -167,19 +177,49 @@ def _format_task_for_note(self, note: LabelStudioNote) -> dict:
167177

168178
return task
169179

170-
def _format_match(self, begin: int, end: int, text: str, labels: Iterable[str]) -> dict:
171-
return {
172-
"from_name": self._labels_name,
173-
"to_name": self._labels_config["to_name"][0],
174-
"type": "labels",
180+
def _format_match(
181+
self,
182+
begin: int,
183+
end: int,
184+
text: str,
185+
labels: Iterable[str],
186+
from_name: str | None = None,
187+
label_id: str | None = None,
188+
) -> dict:
189+
from_name = from_name or self._labels_name
190+
config = self._project.parsed_label_config.get(from_name)
191+
if not config:
192+
errors.fatal(f"Unrecognized label name '{from_name}'.", errors.LABEL_UNKNOWN)
193+
194+
match = {
195+
"from_name": from_name,
196+
"to_name": config["to_name"][0],
197+
"type": config["type"].casefold(),
175198
"value": {
176199
"start": begin,
177200
"end": end,
178201
"score": 1.0,
179202
"text": text,
180-
"labels": list(labels),
181203
},
182204
}
205+
if label_id:
206+
match["id"] = label_id
207+
208+
match config["type"].casefold():
209+
case "labels":
210+
field = "labels"
211+
case "choices":
212+
field = "choices"
213+
case "textarea":
214+
field = "text"
215+
case _:
216+
errors.fatal(
217+
f"Unrecognized Label Studio config type '{config['type']}'.",
218+
errors.LABEL_CONFIG_TYPE_UNKNOWN,
219+
)
220+
221+
match["value"][field] = list(labels)
222+
return match
183223

184224
def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None:
185225
if not note.ctakes_matches:
@@ -208,20 +248,46 @@ def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None:
208248
self._update_used_labels(task, used_labels)
209249

210250
def _format_highlights_predictions(self, task: dict, note: LabelStudioNote) -> None:
211-
for source, labels in note.highlights.items():
212-
prediction = {"model_version": source}
213-
results = []
214-
for label, spans in labels.items():
215-
for span in spans:
216-
results.append(
217-
self._format_match(
218-
span.begin, span.end, note.text[span.begin : span.end], [label]
219-
)
251+
# Group up the highlights by parent label.
252+
# Then we'll see how many sublabels it has.
253+
grouped_highlights = {} # key-tuple -> sublabel name -> sublabel value list
254+
for highlight in note.highlights:
255+
key = (highlight.label, highlight.span, highlight.origin)
256+
sublabels = grouped_highlights.setdefault(key, {})
257+
sublabels.setdefault(highlight.sublabel_name, []).append(highlight.sublabel_value)
258+
259+
predictions = {} # dict of origin -> prediction dict
260+
for key, sublabels in grouped_highlights.items():
261+
label, span, origin = key
262+
default_prediction = {"model_version": origin, "result": []}
263+
prediction = predictions.setdefault(origin, default_prediction)
264+
265+
label_id = "__".join(str(k) for k in key)
266+
label_id = hashlib.md5(label_id.encode(), usedforsecurity=False).hexdigest()
267+
text = note.text[span[0] : span[1]]
268+
269+
# First, add the parent label
270+
prediction["result"].append(
271+
self._format_match(span[0], span[1], text, [label], label_id=label_id)
272+
)
273+
274+
# Now add sublabels
275+
for sublabel_name, sublabel_values in sublabels.items():
276+
if not sublabel_name:
277+
continue
278+
prediction["result"].append(
279+
self._format_match(
280+
span[0],
281+
span[1],
282+
text,
283+
sublabel_values,
284+
label_id=label_id,
285+
from_name=sublabel_name,
220286
)
221-
prediction["result"] = results
222-
task["predictions"].append(prediction)
287+
)
223288

224-
self._update_used_labels(task, labels.keys())
289+
task["predictions"].extend(predictions.values())
290+
self._update_used_labels(task, {x.label for x in note.highlights})
225291

226292
def _format_philter_predictions(self, task: dict, note: LabelStudioNote) -> None:
227293
"""

0 commit comments

Comments
 (0)