|
1 | 1 | """Define tasks for the irae study""" |
2 | 2 |
|
| 3 | +import datetime |
| 4 | +import logging |
| 5 | +from collections.abc import Generator, Iterator |
3 | 6 | from enum import StrEnum |
4 | 7 |
|
| 8 | +import cumulus_fhir_support as cfs |
5 | 9 | from pydantic import BaseModel, Field |
6 | 10 |
|
7 | | -from cumulus_etl import nlp |
| 11 | +from cumulus_etl import common, nlp, store |
8 | 12 | from cumulus_etl.etl import tasks |
9 | 13 |
|
10 | 14 |
|
@@ -453,61 +457,124 @@ class BaseIraeTask(tasks.BaseModelTaskWithSpans): |
453 | 457 | ) |
454 | 458 |
|
455 | 459 |
|
456 | | -class IraeDonorGpt4oTask(BaseIraeTask): |
class BaseDonorIraeTask(BaseIraeTask):
    """Base for donor-group tasks: responses parse as KidneyTransplantDonorGroupAnnotation."""

    # Pydantic model the NLP client is asked to emit for donor-group annotations
    response_format = KidneyTransplantDonorGroupAnnotation
| 463 | + |
class BaseLongitudinalIraeTask(BaseIraeTask):
    """Base for longitudinal tasks: processes notes oldest-first and stops early per patient.

    Responses parse as KidneyTransplantLongitudinalAnnotation. Once a patient's
    chart asserts a confirmed graft failure or death, that patient's remaining
    notes are skipped to avoid pointless NLP requests.
    """

    response_format = KidneyTransplantLongitudinalAnnotation

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Subject references whose remaining notes we no longer need to process
        self.subject_refs_to_skip = set()

    @staticmethod
    def ndjson_in_order(input_root: store.Root, resource: str) -> Generator[dict]:
        """Yield every note row for *resource*, sorted by note date (undated rows last).

        Memory-friendly two-pass approach: the first pass records only a
        (date, file index, byte offset) triple per row, then the second pass
        re-reads each row individually from disk in sorted order.
        """
        # Every file we'll be drawing notes from
        paths = common.ls_resources(input_root, {resource})

        # Pass 1: build a sortable index without retaining any row contents.
        # Undated notes get datetime.max so they sort to the end.
        index = []
        for path_idx, path in enumerate(paths):
            for entry in cfs.read_multiline_json_with_details(path, fsspec_fs=input_root.fs):
                when = nlp.get_note_date(entry["json"]) or datetime.datetime.max
                index.append((when, path_idx, entry["byte_offset"]))

        # Pass 2: re-read each row at its recorded offset, oldest first.
        for _when, path_idx, byte_offset in sorted(index):
            reader = cfs.read_multiline_json_with_details(
                paths[path_idx],
                offset=byte_offset,
                fsspec_fs=input_root.fs,
            )
            # We indexed these rows ourselves just above, so an exhausted reader
            # means the file changed underneath us — warn and move on.
            try:
                yield next(reader)["json"]
            except StopIteration:  # pragma: no cover
                logging.warning(
                    f"File '{paths[path_idx]}' changed while reading, skipping some notes."
                )

    def read_ndjson_from_disk(self, input_root: store.Root, resource: str) -> Iterator[dict]:
        # Override the read-from-disk portion, so we can order notes oldest-to-newest
        yield from self.ndjson_in_order(input_root, resource)

    def should_skip(self, orig_note: dict) -> bool:
        """Skip notes for patients already known to have a failed graft or be deceased."""
        if nlp.get_note_subject_ref(orig_note) in self.subject_refs_to_skip:
            return True
        return super().should_skip(orig_note)

    def post_process(self, parsed: dict, orig_note_text: str, orig_note: dict) -> None:
        super().post_process(parsed, orig_note_text, orig_note)

        # If an annotation asserts graft failure or death, later charts can't
        # change that outcome — flag the patient so we stop sending their notes.
        failure = parsed.get("graft_failure_mention", {})
        confirmed_failure = bool(failure.get("has_mention")) and (
            failure.get("graft_failure") == GraftFailurePresent.CONFIRMED
        )

        death = parsed.get("deceased_mention", {})
        confirmed_death = bool(death.get("has_mention")) and bool(death.get("deceased"))

        if confirmed_failure or confirmed_death:
            subject_ref = nlp.get_note_subject_ref(orig_note)
            if subject_ref:
                self.subject_refs_to_skip.add(subject_ref)
| 531 | + |
| 532 | + |
class IraeDonorGpt4oTask(BaseDonorIraeTask):
    """Donor-group iRAE annotation task backed by the GPT-4o model."""

    name = "irae__nlp_donor_gpt4o"  # task identifier
    client_class = nlp.Gpt4oModel
460 | 536 |
|
461 | 537 |
|
class IraeLongitudinalGpt4oTask(BaseLongitudinalIraeTask):
    """Longitudinal iRAE annotation task backed by the GPT-4o model."""

    name = "irae__nlp_gpt4o"  # task identifier
    client_class = nlp.Gpt4oModel
466 | 541 |
|
467 | 542 |
|
class IraeDonorGpt5Task(BaseDonorIraeTask):
    """Donor-group iRAE annotation task backed by the GPT-5 model."""

    name = "irae__nlp_donor_gpt5"  # task identifier
    client_class = nlp.Gpt5Model
472 | 546 |
|
473 | 547 |
|
class IraeLongitudinalGpt5Task(BaseLongitudinalIraeTask):
    """Longitudinal iRAE annotation task backed by the GPT-5 model."""

    name = "irae__nlp_gpt5"  # task identifier
    client_class = nlp.Gpt5Model
478 | 551 |
|
479 | 552 |
|
class IraeDonorGptOss120bTask(BaseDonorIraeTask):
    """Donor-group iRAE annotation task backed by the GPT-OSS 120B model."""

    name = "irae__nlp_donor_gpt_oss_120b"  # task identifier
    client_class = nlp.GptOss120bModel
484 | 556 |
|
485 | 557 |
|
class IraeLongitudinalGptOss120bTask(BaseLongitudinalIraeTask):
    """Longitudinal iRAE annotation task backed by the GPT-OSS 120B model."""

    name = "irae__nlp_gpt_oss_120b"  # task identifier
    client_class = nlp.GptOss120bModel
490 | 561 |
|
491 | 562 |
|
class IraeDonorLlama4ScoutTask(BaseDonorIraeTask):
    """Donor-group iRAE annotation task backed by the Llama 4 Scout model."""

    name = "irae__nlp_donor_llama4_scout"  # task identifier
    client_class = nlp.Llama4ScoutModel
496 | 566 |
|
497 | 567 |
|
class IraeLongitudinalLlama4ScoutTask(BaseLongitudinalIraeTask):
    """Longitudinal iRAE annotation task backed by the Llama 4 Scout model."""

    name = "irae__nlp_llama4_scout"  # task identifier
    client_class = nlp.Llama4ScoutModel
502 | 571 |
|
503 | 572 |
|
class IraeDonorClaudeSonnet45Task(BaseDonorIraeTask):
    """Donor-group iRAE annotation task backed by the Claude Sonnet 4.5 model."""

    name = "irae__nlp_donor_claude_sonnet45"  # task identifier
    client_class = nlp.ClaudeSonnet45Model
508 | 576 |
|
509 | 577 |
|
class IraeLongitudinalClaudeSonnet45Task(BaseLongitudinalIraeTask):
    """Longitudinal iRAE annotation task backed by the Claude Sonnet 4.5 model."""

    name = "irae__nlp_claude_sonnet45"  # task identifier
    client_class = nlp.ClaudeSonnet45Model
|
0 commit comments