|
2 | 2 |
|
3 | 3 | import dataclasses |
4 | 4 | import datetime |
| 5 | +import hashlib |
5 | 6 | import math |
6 | 7 | from collections.abc import AsyncIterator, Collection, Iterable |
7 | 8 |
|
|
18 | 19 | ############################################################################### |
19 | 20 |
|
20 | 21 |
|
| 22 | +@dataclasses.dataclass |
| 23 | +class Highlight: |
| 24 | + """Describes a label, a span, and some extra metadata""" |
| 25 | + |
| 26 | + label: str |
| 27 | + span: tuple[int, int] |
| 28 | + origin: str |
| 29 | + sublabel_name: str | None = None |
| 30 | + sublabel_value: str | None = None |
| 31 | + |
| 32 | + |
21 | 33 | @dataclasses.dataclass |
22 | 34 | class LabelStudioNote: |
23 | 35 | """Holds all the data that Label Studio will need for a single note (or a single grouped encounter note)""" |
@@ -45,10 +57,8 @@ class LabelStudioNote: |
45 | 57 | default_factory=list |
46 | 58 | ) |
47 | 59 |
|
48 | | - # Matches found by word search or csv, as a dict of origins -> labels -> found spans |
49 | | - highlights: dict[str, dict[str | None, list[ctakesclient.typesystem.Span]]] = dataclasses.field( |
50 | | - default_factory=dict |
51 | | - ) |
| 60 | + # Matches found by word search or csv |
| 61 | + highlights: list[Highlight] = dataclasses.field(default_factory=list) |
52 | 62 |
|
53 | 63 | # Matches found by Philter |
54 | 64 | philter_map: dict[int, int] = dataclasses.field(default_factory=dict) |
@@ -167,19 +177,49 @@ def _format_task_for_note(self, note: LabelStudioNote) -> dict: |
167 | 177 |
|
168 | 178 | return task |
169 | 179 |
|
170 | | - def _format_match(self, begin: int, end: int, text: str, labels: Iterable[str]) -> dict: |
171 | | - return { |
172 | | - "from_name": self._labels_name, |
173 | | - "to_name": self._labels_config["to_name"][0], |
174 | | - "type": "labels", |
| 180 | + def _format_match( |
| 181 | + self, |
| 182 | + begin: int, |
| 183 | + end: int, |
| 184 | + text: str, |
| 185 | + labels: Iterable[str], |
| 186 | + from_name: str | None = None, |
| 187 | + label_id: str | None = None, |
| 188 | + ) -> dict: |
| 189 | + from_name = from_name or self._labels_name |
| 190 | + config = self._project.parsed_label_config.get(from_name) |
| 191 | + if not config: |
| 192 | + errors.fatal(f"Unrecognized label name '{from_name}'.", errors.LABEL_UNKNOWN) |
| 193 | + |
| 194 | + match = { |
| 195 | + "from_name": from_name, |
| 196 | + "to_name": config["to_name"][0], |
| 197 | + "type": config["type"].casefold(), |
175 | 198 | "value": { |
176 | 199 | "start": begin, |
177 | 200 | "end": end, |
178 | 201 | "score": 1.0, |
179 | 202 | "text": text, |
180 | | - "labels": list(labels), |
181 | 203 | }, |
182 | 204 | } |
| 205 | + if label_id: |
| 206 | + match["id"] = label_id |
| 207 | + |
| 208 | + match config["type"].casefold(): |
| 209 | + case "labels": |
| 210 | + field = "labels" |
| 211 | + case "choices": |
| 212 | + field = "choices" |
| 213 | + case "textarea": |
| 214 | + field = "text" |
| 215 | + case _: |
| 216 | + errors.fatal( |
| 217 | + f"Unrecognized Label Studio config type '{config['type']}'.", |
| 218 | + errors.LABEL_CONFIG_TYPE_UNKNOWN, |
| 219 | + ) |
| 220 | + |
| 221 | + match["value"][field] = list(labels) |
| 222 | + return match |
183 | 223 |
|
184 | 224 | def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None: |
185 | 225 | if not note.ctakes_matches: |
@@ -208,20 +248,46 @@ def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None: |
208 | 248 | self._update_used_labels(task, used_labels) |
209 | 249 |
|
def _format_highlights_predictions(self, task: dict, note: LabelStudioNote) -> None:
    """
    Adds one prediction per highlight origin to the task, built from the note's highlights.

    Highlights that share a label, span, and origin become a single region: one result for
    the parent label, plus one result per sublabel name holding all values seen for it.
    """
    # Collapse highlights into regions, collecting sublabel values by sublabel name.
    regions = {}  # (label, span, origin) -> sublabel name -> list of sublabel values
    for entry in note.highlights:
        region_key = (entry.label, entry.span, entry.origin)
        by_name = regions.setdefault(region_key, {})
        by_name.setdefault(entry.sublabel_name, []).append(entry.sublabel_value)

    by_origin = {}  # origin -> prediction dict
    for region_key, by_name in regions.items():
        label, span, origin = region_key
        begin, end = span[0], span[1]

        if origin not in by_origin:
            by_origin[origin] = {"model_version": origin, "result": []}
        results = by_origin[origin]["result"]

        # The parent result and its sublabel results all get the same ID, derived from the
        # region key (presumably so Label Studio treats them as one region -- confirm).
        raw_id = "__".join(map(str, region_key))
        region_id = hashlib.md5(raw_id.encode(), usedforsecurity=False).hexdigest()
        snippet = note.text[begin:end]

        # Parent label first...
        results.append(self._format_match(begin, end, snippet, [label], label_id=region_id))

        # ...then each named sublabel (highlights without a sublabel name add nothing here)
        results.extend(
            self._format_match(
                begin, end, snippet, values, label_id=region_id, from_name=name
            )
            for name, values in by_name.items()
            if name
        )

    task["predictions"].extend(by_origin.values())
    self._update_used_labels(task, {entry.label for entry in note.highlights})
225 | 291 |
|
226 | 292 | def _format_philter_predictions(self, task: dict, note: LabelStudioNote) -> None: |
227 | 293 | """ |
|
0 commit comments