From 0e613c879cac9a956ac62a1a2c36cf215a5c48ba Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Wed, 22 Oct 2025 08:00:19 -0400 Subject: [PATCH] irae: split one task into two Adds irae__nlp_donor_* tasks, which check for donor characteristics. The old irae__nlp_* task names remain, but now only check for longitudinal characteristics (i.e. non-donor-characteristic ones). --- cumulus_etl/etl/studies/irae/__init__.py | 15 ++-- cumulus_etl/etl/studies/irae/irae_tasks.py | 90 ++++++++++++++++++---- cumulus_etl/etl/tasks/nlp_task.py | 17 ++-- cumulus_etl/etl/tasks/task_factory.py | 22 ++---- cumulus_etl/nlp/__init__.py | 1 + docs/setup/cumulus-aws-template.yaml | 9 +++ tests/data/irae/donor-output.ndjson | 1 + tests/data/irae/longitudinal-output.ndjson | 1 + tests/data/irae/output.ndjson | 1 - tests/nlp/test_irae.py | 81 +++++++++++++++---- tests/nlp/test_models.py | 45 +++++------ 11 files changed, 204 insertions(+), 79 deletions(-) create mode 100644 tests/data/irae/donor-output.ndjson create mode 100644 tests/data/irae/longitudinal-output.ndjson delete mode 100644 tests/data/irae/output.ndjson diff --git a/cumulus_etl/etl/studies/irae/__init__.py b/cumulus_etl/etl/studies/irae/__init__.py index f772e15a..c0a5d66f 100644 --- a/cumulus_etl/etl/studies/irae/__init__.py +++ b/cumulus_etl/etl/studies/irae/__init__.py @@ -1,7 +1,12 @@ """The irae study""" -from .irae_tasks import IraeClaudeSonnet45Task as IraeClaudeSonnet45Task -from .irae_tasks import IraeGpt4oTask as IraeGpt4oTask -from .irae_tasks import IraeGpt5Task as IraeGpt5Task -from .irae_tasks import IraeGptOss120bTask as IraeGptOss120bTask -from .irae_tasks import IraeLlama4ScoutTask as IraeLlama4ScoutTask +from .irae_tasks import IraeDonorClaudeSonnet45Task as IraeDonorClaudeSonnet45Task +from .irae_tasks import IraeDonorGpt4oTask as IraeDonorGpt4oTask +from .irae_tasks import IraeDonorGpt5Task as IraeDonorGpt5Task +from .irae_tasks import IraeDonorGptOss120bTask as IraeDonorGptOss120bTask +from .irae_tasks import IraeDonorLlama4ScoutTask as IraeDonorLlama4ScoutTask +from .irae_tasks import IraeLongitudinalClaudeSonnet45Task as IraeLongitudinalClaudeSonnet45Task +from .irae_tasks import IraeLongitudinalGpt4oTask as IraeLongitudinalGpt4oTask +from .irae_tasks import IraeLongitudinalGpt5Task as IraeLongitudinalGpt5Task +from .irae_tasks import IraeLongitudinalGptOss120bTask as IraeLongitudinalGptOss120bTask +from .irae_tasks import IraeLongitudinalLlama4ScoutTask as IraeLongitudinalLlama4ScoutTask diff --git a/cumulus_etl/etl/studies/irae/irae_tasks.py b/cumulus_etl/etl/studies/irae/irae_tasks.py index bfd3211e..ae5f2ff1 100644 --- a/cumulus_etl/etl/studies/irae/irae_tasks.py +++ b/cumulus_etl/etl/studies/irae/irae_tasks.py @@ -1,6 +1,5 @@ """Define tasks for the irae study""" -import json from enum import StrEnum from pydantic import BaseModel, Field @@ -10,7 +9,7 @@ class SpanAugmentedMention(BaseModel): - has_mention: bool | None # True, False, or None + has_mention: bool # True, False spans: list[str] @@ -23,7 +22,10 @@ class SpanAugmentedMention(BaseModel): # Dates are treated as strings - no enum needed class DonorTransplantDateMention(SpanAugmentedMention): - donor_transplant_date: str | None = Field(None, description="Date of renal transplant") + donor_transplant_date: str | None = Field( + None, + description="Exact date of renal transplant; use YYYY-MM-DD format in your response. Only highlight date mentions with an explicit day, month, and year (e.g. 2020-01-15). All other date mentions, or an absence of a date mention, should be indicated with None.", + ) class DonorType(StrEnum): @@ -51,9 +53,11 @@ class DonorRelationshipMention(SpanAugmentedMention): class DonorHlaMatchQuality(StrEnum): - WELL = "Well matched (0-1 mismatches)" - MODERATE = "Moderately matched (2-4 mismatches)" - POOR = "Poorly matched (5-6 mismatches)" + WELL = "Well matched (0-1 mismatches) OR recipient explicitly documented as not sensitized" + MODERATE = ( + "Moderately matched (2-4 mismatches) OR recipient explicitly documented as sensitized" + ) + POOR = "Poorly matched (5-6 mismatches) OR recipient explicitly documented as highly sensitized" NOT_MENTIONED = "HLA match quality not mentioned" @@ -354,7 +358,7 @@ class DeceasedMention(SpanAugmentedMention): deceased_date: str | None = Field( None, description=( - "If the patient is deceased, include the date the patient became deceased. " + "If the patient is deceased, include the date the patient became deceased. Use YYYY-MM-DD format if possible. " "Use None if there is no date recorded or if the patient is not observed as deceased." ), ) @@ -365,7 +369,7 @@ class DeceasedMention(SpanAugmentedMention): ############################################################################### -class KidneyTransplantAnnotation(BaseModel): +class KidneyTransplantDonorGroupAnnotation(BaseModel): """ An object-model for annotations of immune related adverse event (IRAE) observations found in a patient's chart, relating specifically to kidney @@ -381,6 +385,24 @@ class KidneyTransplantAnnotation(BaseModel): donor_relationship_mention: DonorRelationshipMention donor_hla_match_quality_mention: DonorHlaMatchQualityMention donor_hla_mismatch_count_mention: DonorHlaMismatchCountMention + + +class KidneyTransplantLongitudinalAnnotation(BaseModel): + """ + An object-model for annotations of immune related adverse event (IRAE) + observations found in a patient's chart, relating specifically to kidney + transplants. + + This class only includes longitudinally variable mentions, i.e. those + that can change over time, such as therapeutic status, compliance, infections, + graft rejection/failure, DSA, PTLD, cancer, and deceased status. + + Take care to avoid false positives, like confusing information that only + appears in family history for patient history. Annotations should indicate + the relevant details of the finding, as well as some additional evidence + metadata to validate findings post-hoc. + """ + rx_therapeutic_status_mention: RxTherapeuticStatusMention rx_compliance_mention: RxComplianceMention dsa_mention: DSAMention @@ -396,8 +418,9 @@ class KidneyTransplantAnnotation(BaseModel): class BaseIraeTask(tasks.BaseModelTaskWithSpans): - task_version = 3 + task_version = 4 # Task Version History: + # ** 4 (2025-10): Split into donor & longitudinal models ** # ** 3 (2025-10): New serialized format ** # ** 2 (2025-09): Updated prompt and pydantic models ** # ** 1 (2025-08): Updated prompt ** @@ -419,7 +442,8 @@ class BaseIraeTask(tasks.BaseModelTaskWithSpans): " BIOPSY_PROVEN > CONFIRMED > SUSPECTED > NONE_OF_THE_ABOVE.\n" "5. Always produce structured JSON that conforms to the Pydantic schema provided below.\n" "\n" - "Pydantic Schema:\n" + json.dumps(KidneyTransplantAnnotation.model_json_schema()) + "Pydantic Schema:\n" + "%JSON-SCHEMA%" ) user_prompt = ( "Evaluate the following clinical document for kidney transplant variables and outcomes.\n" @@ -427,29 +451,63 @@ class BaseIraeTask(tasks.BaseModelTaskWithSpans): "\n" "%CLINICAL-NOTE%" ) - response_format = KidneyTransplantAnnotation -class IraeGpt4oTask(BaseIraeTask): +class IraeDonorGpt4oTask(BaseIraeTask): + name = "irae__nlp_donor_gpt4o" + client_class = nlp.Gpt4oModel + response_format = KidneyTransplantDonorGroupAnnotation + + +class IraeLongitudinalGpt4oTask(BaseIraeTask): name = "irae__nlp_gpt4o" client_class = nlp.Gpt4oModel + response_format = KidneyTransplantLongitudinalAnnotation -class IraeGpt5Task(BaseIraeTask): +class IraeDonorGpt5Task(BaseIraeTask): + name = "irae__nlp_donor_gpt5" + client_class = nlp.Gpt5Model + response_format = KidneyTransplantDonorGroupAnnotation + + +class IraeLongitudinalGpt5Task(BaseIraeTask): name = "irae__nlp_gpt5" client_class = nlp.Gpt5Model + response_format = KidneyTransplantLongitudinalAnnotation + + +class IraeDonorGptOss120bTask(BaseIraeTask): + name = "irae__nlp_donor_gpt_oss_120b" + client_class = nlp.GptOss120bModel + response_format = KidneyTransplantDonorGroupAnnotation -class IraeGptOss120bTask(BaseIraeTask): +class IraeLongitudinalGptOss120bTask(BaseIraeTask): name = "irae__nlp_gpt_oss_120b" client_class = nlp.GptOss120bModel + response_format = KidneyTransplantLongitudinalAnnotation + + +class IraeDonorLlama4ScoutTask(BaseIraeTask): + name = "irae__nlp_donor_llama4_scout" + client_class = nlp.Llama4ScoutModel + response_format = KidneyTransplantDonorGroupAnnotation -class IraeLlama4ScoutTask(BaseIraeTask): +class IraeLongitudinalLlama4ScoutTask(BaseIraeTask): name = "irae__nlp_llama4_scout" client_class = nlp.Llama4ScoutModel + response_format = KidneyTransplantLongitudinalAnnotation + + +class IraeDonorClaudeSonnet45Task(BaseIraeTask): + name = "irae__nlp_donor_claude_sonnet45" + client_class = nlp.ClaudeSonnet45Model + response_format = KidneyTransplantDonorGroupAnnotation -class IraeClaudeSonnet45Task(BaseIraeTask): +class IraeLongitudinalClaudeSonnet45Task(BaseIraeTask): name = "irae__nlp_claude_sonnet45" client_class = nlp.ClaudeSonnet45Model + response_format = KidneyTransplantLongitudinalAnnotation diff --git a/cumulus_etl/etl/tasks/nlp_task.py b/cumulus_etl/etl/tasks/nlp_task.py index aa1fdd1e..35d59629 100644 --- a/cumulus_etl/etl/tasks/nlp_task.py +++ b/cumulus_etl/etl/tasks/nlp_task.py @@ -1,6 +1,7 @@ """Base NLP task support""" import copy +import json import logging import os import re @@ -126,10 +127,10 @@ class BaseModelTask(BaseNlpTask): outputs: ClassVar = [tasks.OutputTable(resource_type=None, uniqueness_fields={"note_ref"})] # If you change these prompts, consider updating task_version. - system_prompt: ClassVar = None - user_prompt: ClassVar = None - client_class: ClassVar = None - response_format: ClassVar = None + system_prompt: str = None + user_prompt: str = None + client_class: type[nlp.Model] = None + response_format: type[pydantic.BaseModel] = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -153,7 +154,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task try: response = await self.model.prompt( - self.system_prompt, + self.get_system_prompt(), self.get_user_prompt(note_text), schema=self.response_format, cache_dir=self.task_config.dir_phi, @@ -206,6 +207,12 @@ def finish_task(self) -> None: rich.get_console().print(table) + @classmethod + def get_system_prompt(cls) -> str: + return cls.system_prompt.replace( + "%JSON-SCHEMA%", json.dumps(cls.response_format.model_json_schema()) + ) + @classmethod def get_user_prompt(cls, note_text: str) -> str: prompt = cls.user_prompt or "%CLINICAL-NOTE%" diff --git a/cumulus_etl/etl/tasks/task_factory.py b/cumulus_etl/etl/tasks/task_factory.py index b236acf7..b86cb351 100644 --- a/cumulus_etl/etl/tasks/task_factory.py +++ b/cumulus_etl/etl/tasks/task_factory.py @@ -1,5 +1,6 @@ """Finds and creates ETL tasks""" +import inspect import sys from collections.abc import Iterable from typing import TypeVar @@ -25,22 +26,15 @@ def get_all_tasks() -> list[type[AnyTask]]: ] +def get_classes_from_module(module) -> list[type[AnyTask]]: + return [x[1] for x in inspect.getmembers(module, inspect.isclass)] + + def get_nlp_tasks() -> list[type[AnyTask]]: return [ - covid_symptom.CovidSymptomNlpResultsGpt35Task, - covid_symptom.CovidSymptomNlpResultsGpt4Task, - covid_symptom.CovidSymptomNlpResultsTask, - covid_symptom.CovidSymptomNlpResultsTermExistsTask, - example.ExampleGpt4Task, - example.ExampleGpt4oTask, - example.ExampleGpt5Task, - example.ExampleGptOss120bTask, - example.ExampleLlama4ScoutTask, - irae.IraeClaudeSonnet45Task, - irae.IraeGptOss120bTask, - irae.IraeGpt4oTask, - irae.IraeGpt5Task, - irae.IraeLlama4ScoutTask, + *get_classes_from_module(covid_symptom), + *get_classes_from_module(example), + *get_classes_from_module(irae), ] diff --git a/cumulus_etl/nlp/__init__.py b/cumulus_etl/nlp/__init__.py index 8d49748c..011f750d 100644 --- a/cumulus_etl/nlp/__init__.py +++ b/cumulus_etl/nlp/__init__.py @@ -9,6 +9,7 @@ Gpt35Model, GptOss120bModel, Llama4ScoutModel, + Model, TokenStats, set_nlp_provider, ) diff --git a/docs/setup/cumulus-aws-template.yaml b/docs/setup/cumulus-aws-template.yaml index a8dff908..046ffc42 100644 --- a/docs/setup/cumulus-aws-template.yaml +++ b/docs/setup/cumulus-aws-template.yaml @@ -209,6 +209,15 @@ Resources: - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_gpt35" - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_gpt4" - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_term_exists" + CreateNativeDeltaTable: True + WriteManifest: False + - DeltaTables: + - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_claude_sonnet45" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_donor_claude_sonnet45" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_donor_gpt4o" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_donor_gpt5" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_donor_gpt_oss_120b" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_donor_llama4_scout" - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_gpt4o" - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_gpt5" - !Sub "s3://${S3Bucket}/${EtlSubdir}/irae__nlp_gpt_oss_120b" diff --git a/tests/data/irae/donor-output.ndjson b/tests/data/irae/donor-output.ndjson new file mode 100644 index 00000000..faadb11e --- /dev/null +++ b/tests/data/irae/donor-output.ndjson @@ -0,0 +1 @@ +{"note_ref": "DocumentReference/c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "encounter_ref": "Encounter/b3d0707624491d8b71a808bd20b63625981af48f526b95214146de2a15f7dd43", "subject_ref": "Patient/00680c7c0e2e1712e9c4a01eb5c6dfb8949871faef6337c5db204d19e1d9ca58", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 4, "system_fingerprint": "test-fp", "result": {"donor_transplant_date_mention": {"has_mention": false, "spans": []}, "donor_type_mention": {"has_mention": false, "spans": [], "donor_type": "Donor was not mentioned as living or deceased"}, "donor_relationship_mention": {"has_mention": false, "spans": [], "donor_relationship": "Donor relationship status was not mentioned"}, "donor_hla_match_quality_mention": {"has_mention": false, "spans": [], "donor_hla_match_quality": "HLA match quality not mentioned"}, "donor_hla_mismatch_count_mention": {"has_mention": false, "spans": [], "donor_hla_mismatch_count": "HLA mismatch count not mentioned"}}} \ No newline at end of file diff --git a/tests/data/irae/longitudinal-output.ndjson b/tests/data/irae/longitudinal-output.ndjson new file mode 100644 index 00000000..21313286 --- /dev/null +++ b/tests/data/irae/longitudinal-output.ndjson @@ -0,0 +1 @@ +{"note_ref": "DocumentReference/c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "encounter_ref": "Encounter/b3d0707624491d8b71a808bd20b63625981af48f526b95214146de2a15f7dd43", "subject_ref": "Patient/00680c7c0e2e1712e9c4a01eb5c6dfb8949871faef6337c5db204d19e1d9ca58", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 4, "system_fingerprint": "test-fp", "result": {"rx_therapeutic_status_mention": {"has_mention": false, "spans": [], "rx_therapeutic_status": "None of the above"}, "rx_compliance_mention": {"has_mention": false, "spans": [], "rx_compliance": "None of the above"}, "dsa_mention": {"has_mention": false, "spans": [], "dsa_history": false, "dsa": "None of the above"}, "infection_mention": {"has_mention": false, "spans": [], "infection_history": false, "infection": "None of the above"}, "viral_infection_mention": {"has_mention": false, "spans": [], "viral_infection_history": false, "viral_infection": "None of the above"}, "bacterial_infection_mention": {"has_mention": false, "spans": [], "bacterial_infection_history": false, "bacterial_infection": "None of the above"}, "fungal_infection_mention": {"has_mention": false, "spans": [], "fungal_infection_history": false, "fungal_infection": "None of the above"}, "graft_rejection_mention": {"has_mention": false, "spans": [], "graft_rejection_history": false, "graft_rejection": "None of the above"}, "graft_failure_mention": {"has_mention": false, "spans": [], "graft_failure_history": false, "graft_failure": "None of the above"}, "ptld_mention": {"has_mention": false, "spans": [], "ptld_history": false, "ptld": "None of the above"}, "cancer_mention": {"has_mention": false, "spans": [], "cancer_history": false, "cancer": "None of the above"}, "deceased_mention": {"has_mention": true, "spans": [[5, 9]], "deceased": true, "deceased_date": "2025-10-10"}}} \ No newline at end of file diff --git a/tests/data/irae/output.ndjson b/tests/data/irae/output.ndjson deleted file mode 100644 index 0a3ac68f..00000000 --- a/tests/data/irae/output.ndjson +++ /dev/null @@ -1 +0,0 @@ -{"note_ref": "DocumentReference/c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "encounter_ref": "Encounter/b3d0707624491d8b71a808bd20b63625981af48f526b95214146de2a15f7dd43", "subject_ref": "Patient/00680c7c0e2e1712e9c4a01eb5c6dfb8949871faef6337c5db204d19e1d9ca58", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 3, "system_fingerprint": "test-fp", "result": {"donor_transplant_date_mention": {"has_mention": false, "spans": []}, "donor_type_mention": {"has_mention": false, "spans": [], "donor_type": "Donor was not mentioned as living or deceased"}, "donor_relationship_mention": {"has_mention": false, "spans": [], "donor_relationship": "Donor relationship status was not mentioned"}, "donor_hla_match_quality_mention": {"has_mention": false, "spans": [], "donor_hla_match_quality": "HLA match quality not mentioned"}, "donor_hla_mismatch_count_mention": {"has_mention": false, "spans": [], "donor_hla_mismatch_count": "HLA mismatch count not mentioned"}, "rx_therapeutic_status_mention": {"has_mention": false, "spans": [], "rx_therapeutic_status": "None of the above"}, "rx_compliance_mention": {"has_mention": false, "spans": [], "rx_compliance": "None of the above"}, "dsa_mention": {"has_mention": false, "spans": [], "dsa_history": false, "dsa": "None of the above"}, "infection_mention": {"has_mention": false, "spans": [], "infection_history": false, "infection": "None of the above"}, "viral_infection_mention": {"has_mention": false, "spans": [], "viral_infection_history": false, "viral_infection": "None of the above"}, "bacterial_infection_mention": {"has_mention": false, "spans": [], "bacterial_infection_history": false, "bacterial_infection": "None of the above"}, "fungal_infection_mention": {"has_mention": false, "spans": [], "fungal_infection_history": false, "fungal_infection": "None of the above"}, "graft_rejection_mention": {"has_mention": false, "spans": [], "graft_rejection_history": false, "graft_rejection": "None of the above"}, "graft_failure_mention": {"has_mention": false, "spans": [], "graft_failure_history": false, "graft_failure": "None of the above"}, "ptld_mention": {"has_mention": false, "spans": [], "ptld_history": false, "ptld": "None of the above"}, "cancer_mention": {"has_mention": false, "spans": [], "cancer_history": false, "cancer": "None of the above"}, "deceased_mention": {"has_mention": true, "spans": [[5, 9]], "deceased": true, "deceased_date": "2025-10-10"}}} \ No newline at end of file diff --git a/tests/nlp/test_irae.py b/tests/nlp/test_irae.py index a2d9f23e..1c679681 100644 --- a/tests/nlp/test_irae.py +++ b/tests/nlp/test_irae.py @@ -4,7 +4,10 @@ import ddt -from cumulus_etl.etl.studies.irae.irae_tasks import KidneyTransplantAnnotation +from cumulus_etl.etl.studies.irae.irae_tasks import ( + KidneyTransplantDonorGroupAnnotation, + KidneyTransplantLongitudinalAnnotation, +) from tests.etl import BaseEtlSimple from tests.nlp.utils import NlpModelTestCase @@ -16,22 +19,28 @@ class TestIraeTask(NlpModelTestCase, BaseEtlSimple): DATA_ROOT = "irae" @ddt.data( - ("irae__nlp_gpt_oss_120b", "gpt-oss-120b"), - ("irae__nlp_gpt4o", "gpt-4o"), - ("irae__nlp_gpt5", "gpt-5"), - ("irae__nlp_llama4_scout", "Llama-4-Scout-17B-16E-Instruct"), + ("gpt_oss_120b", "gpt-oss-120b"), + ("gpt4o", "gpt-4o"), + ("gpt5", "gpt-5"), + ("llama4_scout", "Llama-4-Scout-17B-16E-Instruct"), ) @ddt.unpack - async def test_basic_etl(self, task_name, model_id): + async def test_basic_etl(self, model_slug, model_id): self.mock_azure(model_id) self.mock_response( - content=KidneyTransplantAnnotation.model_validate( + content=KidneyTransplantDonorGroupAnnotation.model_validate( { "donor_transplant_date_mention": {"has_mention": False, "spans": []}, "donor_type_mention": {"has_mention": False, "spans": []}, "donor_relationship_mention": {"has_mention": False, "spans": []}, "donor_hla_match_quality_mention": {"has_mention": False, "spans": []}, "donor_hla_mismatch_count_mention": {"has_mention": False, "spans": []}, + } + ) + ) + self.mock_response( + content=KidneyTransplantLongitudinalAnnotation.model_validate( + { "rx_therapeutic_status_mention": {"has_mention": False, "spans": []}, "rx_compliance_mention": {"has_mention": False, "spans": []}, "dsa_mention": {"has_mention": False, "spans": []}, @@ -54,14 +63,21 @@ async def test_basic_etl(self, task_name, model_id): ) ) - await self.run_etl("--provider=azure", tasks=[task_name]) + donor_task_name = f"irae__nlp_donor_{model_slug}" + longitudinal_task_name = f"irae__nlp_{model_slug}" + + await self.run_etl("--provider=azure", tasks=[donor_task_name, longitudinal_task_name]) self.assert_files_equal( - f"{self.root_path}/output.ndjson", - f"{self.output_path}/{task_name}/{task_name}.000.ndjson", + f"{self.root_path}/donor-output.ndjson", + f"{self.output_path}/{donor_task_name}/{donor_task_name}.000.ndjson", + ) + self.assert_files_equal( + f"{self.root_path}/longitudinal-output.ndjson", + f"{self.output_path}/{longitudinal_task_name}/{longitudinal_task_name}.000.ndjson", ) - self.assertEqual(self.mock_create.call_count, 1) + self.assertEqual(self.mock_create.call_count, 2) self.assertEqual( { "messages": [ @@ -83,7 +99,7 @@ async def test_basic_etl(self, task_name, model_id): "5. Always produce structured JSON that conforms to the Pydantic schema provided below.\n" "\n" "Pydantic Schema:\n" - + json.dumps(KidneyTransplantAnnotation.model_json_schema()), + + json.dumps(KidneyTransplantDonorGroupAnnotation.model_json_schema()), }, { "role": "user", @@ -97,7 +113,46 @@ async def test_basic_etl(self, task_name, model_id): "seed": 12345, "temperature": 0, "timeout": 120, - "response_format": KidneyTransplantAnnotation, + "response_format": KidneyTransplantDonorGroupAnnotation, }, self.mock_create.call_args_list[0][1], ) + self.assertEqual( + { + "messages": [ + { + "role": "system", + "content": "You are a clinical chart reviewer for a kidney transplant outcomes study.\n" + "Your task is to extract patient-specific information from an unstructured clinical " + "document and map it into a predefined Pydantic schema.\n" + "\n" + "Core Rules:\n" + "1. Base all assertions ONLY on patient-specific information in the clinical document.\n" + " - Never negate or exclude information just because it is not mentioned.\n" + " - Never conflate family history or population-level risk with patient findings.\n" + " - Do not count past medical history, prior episodes, or family history.\n" + "2. Do not invent or infer facts beyond what is documented.\n" + "3. Maintain high fidelity to the clinical document language when citing spans.\n" + "4. Answer patient outcomes with strongest available documented evidence:\n" + " BIOPSY_PROVEN > CONFIRMED > SUSPECTED > NONE_OF_THE_ABOVE.\n" + "5. Always produce structured JSON that conforms to the Pydantic schema provided below.\n" + "\n" + "Pydantic Schema:\n" + + json.dumps(KidneyTransplantLongitudinalAnnotation.model_json_schema()), + }, + { + "role": "user", + "content": "Evaluate the following clinical document for kidney " + "transplant variables and outcomes.\n" + "Here is the clinical document for you to analyze:\n\n" + "Test note 1", + }, + ], + "model": model_id, + "seed": 12345, + "temperature": 0, + "timeout": 120, + "response_format": KidneyTransplantLongitudinalAnnotation, + }, + self.mock_create.call_args_list[1][1], + ) diff --git a/tests/nlp/test_models.py b/tests/nlp/test_models.py index 2e72712a..71e2777b 100644 --- a/tests/nlp/test_models.py +++ b/tests/nlp/test_models.py @@ -16,7 +16,7 @@ from cumulus_etl import common, errors, nlp from cumulus_etl.etl.studies import covid_symptom, irae -from cumulus_etl.etl.studies.irae.irae_tasks import KidneyTransplantAnnotation +from cumulus_etl.etl.studies.irae.irae_tasks import KidneyTransplantLongitudinalAnnotation from tests import i2b2_mock_data from tests.nlp.utils import NlpModelTestCase @@ -27,13 +27,8 @@ class TestWithSpansNLPTasks(NlpModelTestCase): MODEL_ID = "openai/gpt-oss-120b" - def default_kidney(self, **kwargs) -> KidneyTransplantAnnotation: + def default_kidney(self, **kwargs) -> KidneyTransplantLongitudinalAnnotation: model_dict = { - "donor_transplant_date_mention": {"has_mention": False, "spans": []}, - "donor_type_mention": {"has_mention": False, "spans": []}, - "donor_relationship_mention": {"has_mention": False, "spans": []}, - "donor_hla_match_quality_mention": {"has_mention": False, "spans": []}, - "donor_hla_mismatch_count_mention": {"has_mention": False, "spans": []}, "rx_therapeutic_status_mention": {"has_mention": False, "spans": []}, "rx_compliance_mention": {"has_mention": False, "spans": []}, "dsa_mention": {"has_mention": False, "spans": []}, @@ -48,7 +43,7 @@ def default_kidney(self, **kwargs) -> KidneyTransplantAnnotation: "deceased_mention": {"has_mention": False, "spans": []}, } model_dict.update(kwargs) - return KidneyTransplantAnnotation.model_validate(model_dict) + return KidneyTransplantLongitudinalAnnotation.model_validate(model_dict) def default_content(self) -> pydantic.BaseModel: return self.default_kidney() @@ -60,7 +55,7 @@ def prep_docs(self, docref: dict | None = None): self.make_json("DocumentReference", "2", **i2b2_mock_data.documentreference("bar")) async def assert_failed_doc(self, msg: str): - task = irae.IraeGptOss120bTask(self.job_config, self.scrubber) + task = irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber) with self.assertLogs(level="WARN") as cm: await task.run() @@ -115,10 +110,10 @@ async def test_caching(self): self.assertFalse(os.path.exists(f"{self.phi_dir}/nlp-cache")) self.mock_response() - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.mock_create.call_count, 1) - cache_dir = f"{self.phi_dir}/nlp-cache/irae__nlp_gpt_oss_120b_v3/06ee" + cache_dir = f"{self.phi_dir}/nlp-cache/irae__nlp_gpt_oss_120b_v4/06ee" cache_file = f"{cache_dir}/sha256-06ee538c626fbf4bdcec2199b7225c8034f26e2b46a7b5cb7ab385c8e8c00efa.cache" self.assertEqual( common.read_json(cache_file), @@ -128,13 +123,13 @@ async def test_caching(self): }, ) - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.mock_create.call_count, 1) # Confirm that if we remove the cache file, we call the endpoint again self.mock_response() os.remove(cache_file) - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.mock_create.call_count, 2) async def test_init_check_unreachable(self): @@ -142,28 +137,28 @@ async def test_init_check_unreachable(self): self.mock_client.models.list = self.mock_model_list(error=True) with self.assertRaises(SystemExit) as cm: - await irae.IraeGptOss120bTask.init_check() + await irae.IraeLongitudinalGptOss120bTask.init_check() self.assertEqual(errors.SERVICE_MISSING, cm.exception.code) async def test_init_check_config(self): """Verify we check the server properties""" # Happy path - await irae.IraeGptOss120bTask.init_check() + await irae.IraeLongitudinalGptOss120bTask.init_check() # Random error bubbles up self.mock_client.models.list = mock.MagicMock(side_effect=SystemExit) with self.assertRaises(SystemExit): - await irae.IraeGptOss120bTask.init_check() + await irae.IraeLongitudinalGptOss120bTask.init_check() # Bad model ID self.mock_client.models.list = self.mock_model_list("bogus-model") with self.assert_fatal_exit(errors.SERVICE_MISSING): - await irae.IraeGptOss120bTask.init_check() + await irae.IraeLongitudinalGptOss120bTask.init_check() async def test_output_fields(self): self.make_json("DocumentReference", "1", **i2b2_mock_data.documentreference("foo")) self.mock_response() - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.format.write_records.call_count, 1) batch = self.format.write_records.call_args[0][0] @@ -180,7 +175,7 @@ async def test_output_fields(self): "6beb306dc5b91513f353ecdb6aaedee8a9864b3a2f20d91f0d5b27510152acf2", "generated_on": "2021-09-14T21:23:45+00:00", "system_fingerprint": "test-fp", - "task_version": 3, + "task_version": 4, }, ) @@ -191,7 +186,7 @@ async def test_trailing_whitespace_removed(self): **i2b2_mock_data.documentreference("Test \n lines "), ) self.mock_response() - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.mock_create.call_count, 1) kwargs = self.mock_create.call_args.kwargs @@ -219,7 +214,7 @@ async def test_span_conversion(self): } ) ) - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.format.write_records.call_count, 1) batch = self.format.write_records.call_args[0][0] @@ -229,7 +224,7 @@ async def test_span_conversion(self): ) async def test_span_conversion_in_schema(self): - schema = irae.IraeGptOss120bTask.get_schema(None, []) + schema = irae.IraeLongitudinalGptOss120bTask.get_schema(None, []) result_index = schema.get_field_index("result") result_type = schema.field(result_index).type dsa_index = result_type.get_field_index("dsa_mention") # spot check one of the structs @@ -277,7 +272,7 @@ async def test_bedrock_text_parsing(self, text): content_text = content.model_dump_json() self.mock_response(content=text.replace("%CONTENT%", content_text)) - await irae.IraeClaudeSonnet45Task(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalClaudeSonnet45Task(self.job_config, self.scrubber).run() self.assertEqual(self.format.write_records.call_count, 1) batch = self.format.write_records.call_args[0][0] @@ -299,7 +294,7 @@ async def test_bedrock_tool_use_parsing(self): # Also test that we handle Claude's injected parameter parent self.mock_response(content={"parameter": self.default_kidney().model_dump()}) - await irae.IraeGptOss120bTask(self.job_config, self.scrubber).run() + await irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber).run() self.assertEqual(self.format.write_records.call_count, 1) batch = self.format.write_records.call_args[0][0] @@ -340,7 +335,7 @@ async def test_usage_recorded(self, provider): self.mock_response(usage=(1, 2, 4, 8)) self.mock_response(usage=(16, 32, 64, 128)) - task = irae.IraeGptOss120bTask(self.job_config, self.scrubber) + task = irae.IraeLongitudinalGptOss120bTask(self.job_config, self.scrubber) await task.run() self.assertEqual(