| 
2 | 2 | 
 
  | 
3 | 3 | import dataclasses  | 
4 | 4 | import datetime  | 
 | 5 | +import hashlib  | 
5 | 6 | import math  | 
6 | 7 | from collections.abc import AsyncIterator, Collection, Iterable  | 
7 | 8 | 
 
  | 
 | 
18 | 19 | ###############################################################################  | 
19 | 20 | 
 
  | 
20 | 21 | 
 
  | 
 | 22 | +@dataclasses.dataclass  | 
 | 23 | +class Highlight:  | 
 | 24 | +    """Describes a label, a span, and some extra metadata"""  | 
 | 25 | + | 
 | 26 | +    label: str  | 
 | 27 | +    span: tuple[int, int]  | 
 | 28 | +    origin: str  | 
 | 29 | +    sublabel_name: str | None = None  | 
 | 30 | +    sublabel_value: str | None = None  | 
 | 31 | + | 
 | 32 | + | 
21 | 33 | @dataclasses.dataclass  | 
22 | 34 | class LabelStudioNote:  | 
23 | 35 |     """Holds all the data that Label Studio will need for a single note (or a single grouped encounter note)"""  | 
@@ -45,10 +57,8 @@ class LabelStudioNote:  | 
45 | 57 |         default_factory=list  | 
46 | 58 |     )  | 
47 | 59 | 
 
  | 
48 |  | -    # Matches found by word search or csv, as a dict of origins -> labels -> found spans  | 
49 |  | -    highlights: dict[str, dict[str | None, list[ctakesclient.typesystem.Span]]] = dataclasses.field(  | 
50 |  | -        default_factory=dict  | 
51 |  | -    )  | 
 | 60 | +    # Matches found by word search or csv  | 
 | 61 | +    highlights: list[Highlight] = dataclasses.field(default_factory=list)  | 
52 | 62 | 
 
  | 
53 | 63 |     # Matches found by Philter  | 
54 | 64 |     philter_map: dict[int, int] = dataclasses.field(default_factory=dict)  | 
@@ -167,19 +177,49 @@ def _format_task_for_note(self, note: LabelStudioNote) -> dict:  | 
167 | 177 | 
 
  | 
168 | 178 |         return task  | 
169 | 179 | 
 
  | 
170 |  | -    def _format_match(self, begin: int, end: int, text: str, labels: Iterable[str]) -> dict:  | 
171 |  | -        return {  | 
172 |  | -            "from_name": self._labels_name,  | 
173 |  | -            "to_name": self._labels_config["to_name"][0],  | 
174 |  | -            "type": "labels",  | 
 | 180 | +    def _format_match(  | 
 | 181 | +        self,  | 
 | 182 | +        begin: int,  | 
 | 183 | +        end: int,  | 
 | 184 | +        text: str,  | 
 | 185 | +        labels: Iterable[str],  | 
 | 186 | +        from_name: str | None = None,  | 
 | 187 | +        label_id: str | None = None,  | 
 | 188 | +    ) -> dict:  | 
 | 189 | +        from_name = from_name or self._labels_name  | 
 | 190 | +        config = self._project.parsed_label_config.get(from_name)  | 
 | 191 | +        if not config:  | 
 | 192 | +            errors.fatal(f"Unrecognized label name '{from_name}'.", errors.LABEL_UNKNOWN)  | 
 | 193 | + | 
 | 194 | +        match = {  | 
 | 195 | +            "from_name": from_name,  | 
 | 196 | +            "to_name": config["to_name"][0],  | 
 | 197 | +            "type": config["type"].casefold(),  | 
175 | 198 |             "value": {  | 
176 | 199 |                 "start": begin,  | 
177 | 200 |                 "end": end,  | 
178 | 201 |                 "score": 1.0,  | 
179 | 202 |                 "text": text,  | 
180 |  | -                "labels": list(labels),  | 
181 | 203 |             },  | 
182 | 204 |         }  | 
 | 205 | +        if label_id:  | 
 | 206 | +            match["id"] = label_id  | 
 | 207 | + | 
 | 208 | +        match config["type"].casefold():  | 
 | 209 | +            case "labels":  | 
 | 210 | +                field = "labels"  | 
 | 211 | +            case "choices":  | 
 | 212 | +                field = "choices"  | 
 | 213 | +            case "textarea":  | 
 | 214 | +                field = "text"  | 
 | 215 | +            case _:  | 
 | 216 | +                errors.fatal(  | 
 | 217 | +                    f"Unrecognized Label Studio config type '{config['type']}'.",  | 
 | 218 | +                    errors.LABEL_CONFIG_TYPE_UNKNOWN,  | 
 | 219 | +                )  | 
 | 220 | + | 
 | 221 | +        match["value"][field] = list(labels)  | 
 | 222 | +        return match  | 
183 | 223 | 
 
  | 
184 | 224 |     def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None:  | 
185 | 225 |         if not note.ctakes_matches:  | 
@@ -208,20 +248,46 @@ def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None:  | 
208 | 248 |         self._update_used_labels(task, used_labels)  | 
209 | 249 | 
 
  | 
210 | 250 |     def _format_highlights_predictions(self, task: dict, note: LabelStudioNote) -> None:  | 
211 |  | -        for source, labels in note.highlights.items():  | 
212 |  | -            prediction = {"model_version": source}  | 
213 |  | -            results = []  | 
214 |  | -            for label, spans in labels.items():  | 
215 |  | -                for span in spans:  | 
216 |  | -                    results.append(  | 
217 |  | -                        self._format_match(  | 
218 |  | -                            span.begin, span.end, note.text[span.begin : span.end], [label]  | 
219 |  | -                        )  | 
 | 251 | +        # Group up the highlights by parent label.  | 
 | 252 | +        # Then we'll see how many sublabels it has.  | 
 | 253 | +        grouped_highlights = {}  # key-tuple -> sublabel name -> sublabel value list  | 
 | 254 | +        for highlight in note.highlights:  | 
 | 255 | +            key = (highlight.label, highlight.span, highlight.origin)  | 
 | 256 | +            sublabels = grouped_highlights.setdefault(key, {})  | 
 | 257 | +            sublabels.setdefault(highlight.sublabel_name, []).append(highlight.sublabel_value)  | 
 | 258 | + | 
 | 259 | +        predictions = {}  # dict of origin -> prediction dict  | 
 | 260 | +        for key, sublabels in grouped_highlights.items():  | 
 | 261 | +            label, span, origin = key  | 
 | 262 | +            default_prediction = {"model_version": origin, "result": []}  | 
 | 263 | +            prediction = predictions.setdefault(origin, default_prediction)  | 
 | 264 | + | 
 | 265 | +            label_id = "__".join(str(k) for k in key)  | 
 | 266 | +            label_id = hashlib.md5(label_id.encode(), usedforsecurity=False).hexdigest()  | 
 | 267 | +            text = note.text[span[0] : span[1]]  | 
 | 268 | + | 
 | 269 | +            # First, add the parent label  | 
 | 270 | +            prediction["result"].append(  | 
 | 271 | +                self._format_match(span[0], span[1], text, [label], label_id=label_id)  | 
 | 272 | +            )  | 
 | 273 | + | 
 | 274 | +            # Now add sublabels  | 
 | 275 | +            for sublabel_name, sublabel_values in sublabels.items():  | 
 | 276 | +                if not sublabel_name:  | 
 | 277 | +                    continue  | 
 | 278 | +                prediction["result"].append(  | 
 | 279 | +                    self._format_match(  | 
 | 280 | +                        span[0],  | 
 | 281 | +                        span[1],  | 
 | 282 | +                        text,  | 
 | 283 | +                        sublabel_values,  | 
 | 284 | +                        label_id=label_id,  | 
 | 285 | +                        from_name=sublabel_name,  | 
220 | 286 |                     )  | 
221 |  | -            prediction["result"] = results  | 
222 |  | -            task["predictions"].append(prediction)  | 
 | 287 | +                )  | 
223 | 288 | 
 
  | 
224 |  | -            self._update_used_labels(task, labels.keys())  | 
 | 289 | +        task["predictions"].extend(predictions.values())  | 
 | 290 | +        self._update_used_labels(task, {x.label for x in note.highlights})  | 
225 | 291 | 
 
  | 
226 | 292 |     def _format_philter_predictions(self, task: dict, note: LabelStudioNote) -> None:  | 
227 | 293 |         """  | 
 | 
0 commit comments