1 change: 0 additions & 1 deletion cumulus_etl/etl/tasks/nlp_task.py
@@ -134,7 +134,6 @@ class BaseOpenAiTask(BaseNlpTask):
 
     @classmethod
     async def init_check(cls) -> None:
-        await cls.client_class.pre_init_check()
         await cls.client_class().post_init_check()
 
     async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
135 changes: 63 additions & 72 deletions cumulus_etl/nlp/openai.py
@@ -1,7 +1,7 @@
 """Abstraction layer for Hugging Face's inference API"""
 
-import abc
 import os
+from collections.abc import Iterable
 
 import openai
 from openai.types import chat
@@ -10,37 +10,62 @@
 from cumulus_etl import errors
 
 
-class OpenAIModel(abc.ABC):
Author comment (attached to the deleted `class OpenAIModel(abc.ABC):` line):

I'm changing away from an abstract class that specific-use models like AzureModel and LocalModel inherit, and switching to a single class that checks each supported mode (Azure, Bedrock, local) and uses whichever one is configured. This better supports models like llama4, which can run in any of the three modes depending on config.
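As a minimal sketch of the resulting behavior (class name, env var names, and model ID are from this diff; the endpoint and key values are placeholders):

```python
# Minimal usage sketch: the same model class now resolves its backend from the
# environment. Endpoint/key values below are placeholders, not real config.
import os

os.environ["BEDROCK_OPENAI_API_KEY"] = "placeholder-key"
os.environ["BEDROCK_OPENAI_ENDPOINT"] = "https://example.invalid/v1"

model = Llama4ScoutModel()  # picks Bedrock: both BEDROCK_* vars are set, Azure's are not
assert model.model_name == "meta.llama4-scout-17b-instruct-v1:0"
```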

-    USER_ID = None  # name in compose file or brand name
-    MODEL_NAME = None  # which model to request via the API
+class OpenAIModel:
+    AZURE_ID = None  # model name in MS Azure
+    BEDROCK_ID = None  # model name in AWS Bedrock
+    COMPOSE_ID = None  # docker service name in compose.yaml
+    VLLM_INFO = None  # tuple of vLLM model name, env var stem used for the URL, plus default port
 
-    @abc.abstractmethod
-    def make_client(self) -> openai.AsyncOpenAI:
-        """Creates an NLP client"""
+    AZURE_ENV = ("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")
+    BEDROCK_ENV = ("BEDROCK_OPENAI_API_KEY", "BEDROCK_OPENAI_ENDPOINT")
+
+    @staticmethod
+    def _env_defined(env_keys: Iterable[str]) -> bool:
+        return all(os.environ.get(key) for key in env_keys)
 
     def __init__(self):
-        self.client = self.make_client()
+        self.is_vllm = False
 
-    # override to add your own checks
-    @classmethod
-    async def pre_init_check(cls) -> None:
-        pass
+        if self.AZURE_ID and self._env_defined(self.AZURE_ENV):
+            self.model_name = self.AZURE_ID
+            self.client = openai.AsyncAzureOpenAI(api_version="2024-10-21")
+
+        elif self.BEDROCK_ID and self._env_defined(self.BEDROCK_ENV):
+            self.model_name = self.BEDROCK_ID
+            self.client = openai.AsyncOpenAI(
+                base_url=os.environ["BEDROCK_OPENAI_ENDPOINT"],
+                api_key=os.environ["BEDROCK_OPENAI_API_KEY"],
+            )
+
+        elif self.COMPOSE_ID:
+            self.model_name = self.VLLM_INFO[0]
+            url = os.environ.get(f"CUMULUS_{self.VLLM_INFO[1]}_URL")  # set by compose.yaml
+            url = url or f"http://localhost:{self.VLLM_INFO[2]}/v1"  # offer non-docker fallback
+            self.client = openai.AsyncOpenAI(base_url=url, api_key="")
+            self.is_vllm = True
+
+        else:
+            errors.fatal(
+                "Missing Azure or Bedrock environment variables. "
+                "Set AZURE_OPENAI_API_KEY & AZURE_OPENAI_ENDPOINT or "
+                "BEDROCK_OPENAI_API_KEY & BEDROCK_OPENAI_ENDPOINT.",
+                errors.ARGS_INVALID,
+            )
 
     # override to add your own checks
     async def post_init_check(self) -> None:
         try:
             models = self.client.models.list()
             names = {model.id async for model in models}
         except openai.APIError as exc:
-            errors.fatal(
-                f"NLP server '{self.USER_ID}' is unreachable: {exc}.\n"
-                f"If it's a local server, try running 'docker compose up {self.USER_ID} --wait'.",
-                errors.SERVICE_MISSING,
-            )
+            message = f"NLP server is unreachable: {exc}."
+            if self.is_vllm:
+                message += f"\nTry running 'docker compose up {self.COMPOSE_ID} --wait'."
+            errors.fatal(message, errors.SERVICE_MISSING)
 
-        if self.MODEL_NAME not in names:
+        if self.model_name not in names:
             errors.fatal(
-                f"NLP server '{self.USER_ID}' is using an unexpected model setup.",
+                f"NLP server does not have model ID '{self.model_name}'.",
                 errors.SERVICE_MISSING,
             )

@@ -49,7 +74,7 @@ async def prompt(self, system: str, user: str, schema: BaseModel) -> chat.Parsed
 
     async def _parse_prompt(self, system: str, user: str, schema) -> chat.ParsedChatCompletion:
         return await self.client.chat.completions.parse(
-            model=self.MODEL_NAME,
+            model=self.model_name,
             messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},
@@ -61,28 +86,8 @@ async def _parse_prompt(self, system: str, user: str, schema) -> chat.ParsedChat
         )
 
 
-class AzureModel(OpenAIModel):
-    USER_ID = "Azure"
-
-    @classmethod
-    async def pre_init_check(cls) -> None:
-        await super().pre_init_check()
-
-        messages = []
-        if not os.environ.get("AZURE_OPENAI_API_KEY"):
-            messages.append("The AZURE_OPENAI_API_KEY environment variable is not set.")
-        if not os.environ.get("AZURE_OPENAI_ENDPOINT"):
-            messages.append("The AZURE_OPENAI_ENDPOINT environment variable is not set.")
-
-        if messages:
-            errors.fatal("\n".join(messages), errors.ARGS_INVALID)
-
-    def make_client(self) -> openai.AsyncOpenAI:
-        return openai.AsyncAzureOpenAI(api_version="2024-10-21")
-
-
-class Gpt35Model(AzureModel):  # deprecated, do not use in new code (doesn't support JSON schemas)
-    MODEL_NAME = "gpt-35-turbo-0125"
+class Gpt35Model(OpenAIModel):  # deprecated, do not use in new code (doesn't support JSON schemas)
+    AZURE_ID = "gpt-35-turbo-0125"
 
     # 3.5 doesn't support a pydantic JSON schema, so we do some work to keep it using the same API
     # as the rest of our code.
@@ -93,41 +98,27 @@ async def prompt(self, system: str, user: str, schema: BaseModel) -> chat.Parsed
         return response
 
 
-class Gpt4Model(AzureModel):
-    MODEL_NAME = "gpt-4"
-
-
-class Gpt4oModel(AzureModel):
-    MODEL_NAME = "gpt-4o"
-
-
-class Gpt5Model(AzureModel):
-    MODEL_NAME = "gpt-5"
-
-
-class LocalModel(OpenAIModel, abc.ABC):
-    @property
-    @abc.abstractmethod
-    def url(self) -> str:
-        """The OpenAI compatible URL to talk to (where's the server?)"""
-
-    def make_client(self) -> openai.AsyncOpenAI:
-        return openai.AsyncOpenAI(base_url=self.url, api_key="")
-
-
-class GptOss120bModel(LocalModel):
-    USER_ID = "gpt-oss-120b"
-    MODEL_NAME = "openai/gpt-oss-120b"
-
-    @property
-    def url(self) -> str:
-        return os.environ.get("CUMULUS_GPT_OSS_120B_URL") or "http://localhost:8086/v1"
-
-
-class Llama4ScoutModel(LocalModel):
-    USER_ID = "llama4-scout"
-    MODEL_NAME = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
-
-    @property
-    def url(self) -> str:
-        return os.environ.get("CUMULUS_LLAMA4_SCOUT_URL") or "http://localhost:8087/v1"
+class Gpt4Model(OpenAIModel):
+    AZURE_ID = "gpt-4"
+
+
+class Gpt4oModel(OpenAIModel):
+    AZURE_ID = "gpt-4o"
+
+
+class Gpt5Model(OpenAIModel):
+    AZURE_ID = "gpt-5"
+
+
+class GptOss120bModel(OpenAIModel):
+    AZURE_ID = "gpt-oss-120b"
+    BEDROCK_ID = "openai.gpt-oss-120b-1:0"
+    COMPOSE_ID = "gpt-oss-120b"
+    VLLM_INFO = ("openai/gpt-oss-120b", "GPT_OSS_120B", 8086)
+
+
+class Llama4ScoutModel(OpenAIModel):
+    AZURE_ID = "Llama-4-Scout-17B-16E-Instruct"
+    BEDROCK_ID = "meta.llama4-scout-17b-instruct-v1:0"
+    COMPOSE_ID = "llama4-scout"
+    VLLM_INFO = ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", "LLAMA4_SCOUT", 8087)
2 changes: 2 additions & 0 deletions tests/covid_symptom/test_covid_gpt.py
@@ -42,6 +42,7 @@ async def test_gpt4_changes(self):
         """
         self.make_json("DocumentReference", "doc", **i2b2_mock_data.documentreference())
         self.mock_response()
+        self.mock_azure()
 
         task = covid_symptom.CovidSymptomNlpResultsGpt4Task(self.job_config, self.scrubber)
         await task.run()
@@ -54,6 +55,7 @@ async def test_happy_path(self):
         self.make_json("DocumentReference", "1", **i2b2_mock_data.documentreference("foo"))
         self.mock_response(parsed=False)
+        self.mock_azure()
 
         task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber)
         await task.run()
1 change: 1 addition & 0 deletions tests/nlp/test_example.py
@@ -23,6 +23,7 @@ def default_content(self) -> pydantic.BaseModel:
         "example_nlp__nlp_llama4_scout",
     )
     async def test_basic_etl(self, task_name):
+        self.mock_azure()
         for _ in range(8):
             self.mock_response()
         await self.run_etl(tasks=[task_name], input_path="%EXAMPLE-NLP%")
5 changes: 3 additions & 2 deletions tests/nlp/test_irae.py
@@ -16,13 +16,14 @@ class TestIraeTask(OpenAITestCase, BaseEtlSimple):
     DATA_ROOT = "irae"
 
     @ddt.data(
-        ("irae__nlp_gpt_oss_120b", "openai/gpt-oss-120b"),
+        ("irae__nlp_gpt_oss_120b", "gpt-oss-120b"),
         ("irae__nlp_gpt4o", "gpt-4o"),
         ("irae__nlp_gpt5", "gpt-5"),
-        ("irae__nlp_llama4_scout", "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"),
+        ("irae__nlp_llama4_scout", "Llama-4-Scout-17B-16E-Instruct"),
     )
     @ddt.unpack
     async def test_basic_etl(self, task_name, model_id):
+        self.mock_azure()
         self.mock_response(
             content=DSAMention(
                 spans=["note"],
44 changes: 38 additions & 6 deletions tests/nlp/test_openai.py
@@ -67,18 +67,26 @@ async def assert_failed_doc(self, msg: str):
     async def test_gpt_oss_120_env_url_override(self):
         """Verify we can override the default URL."""
         self.patch_dict(os.environ, {"CUMULUS_GPT_OSS_120B_URL": ""})
-        self.assertEqual(nlp.GptOss120bModel().url, "http://localhost:8086/v1")
+        nlp.GptOss120bModel()
+        self.assertEqual(
+            self.mock_client_factory.call_args[1]["base_url"], "http://localhost:8086/v1"
+        )
 
         self.patch_dict(os.environ, {"CUMULUS_GPT_OSS_120B_URL": "https://blarg/"})
-        self.assertEqual(nlp.GptOss120bModel().url, "https://blarg/")
+        nlp.GptOss120bModel()
+        self.assertEqual(self.mock_client_factory.call_args[1]["base_url"], "https://blarg/")
 
     async def test_llama4_scout_env_url_override(self):
         """Verify we can override the default URL."""
         self.patch_dict(os.environ, {"CUMULUS_LLAMA4_SCOUT_URL": ""})
-        self.assertEqual(nlp.Llama4ScoutModel().url, "http://localhost:8087/v1")
+        nlp.Llama4ScoutModel()
+        self.assertEqual(
+            self.mock_client_factory.call_args[1]["base_url"], "http://localhost:8087/v1"
+        )
 
         self.patch_dict(os.environ, {"CUMULUS_LLAMA4_SCOUT_URL": "https://blarg/"})
-        self.assertEqual(nlp.Llama4ScoutModel().url, "https://blarg/")
+        nlp.Llama4ScoutModel()
+        self.assertEqual(self.mock_client_factory.call_args[1]["base_url"], "https://blarg/")
 
     async def test_caching(self):
         """Verify we cache results"""
@@ -141,9 +149,8 @@ async def test_init_check_config(self):
 
         # Bad model ID
         self.mock_client.models.list = self.mock_model_list("bogus-model")
-        with self.assertRaises(SystemExit) as cm:
+        with self.assert_fatal_exit(errors.SERVICE_MISSING):
             await irae.IraeGptOss120bTask.init_check()
-        self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)
 
     async def test_output_fields(self):
         self.make_json("DocumentReference", "1", **i2b2_mock_data.documentreference("foo"))
@@ -271,6 +278,7 @@ class TestAzureNLPTasks(OpenAITestCase):
     )
     @ddt.unpack
     async def test_requires_env(self, names, success):
+        self.mock_azure()
         task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber)
         env = {name: "content" for name in names}
         self.patch_dict(os.environ, env, clear=True)
@@ -279,3 +287,27 @@ async def test_requires_env(self, names, success):
         else:
             with self.assertRaises(SystemExit):
                 await task.init_check()
+
+
+@ddt.ddt
+class TestBedrockNLPTasks(OpenAITestCase):
+    """Tests the Bedrock specific code"""
+
+    MODEL_ID = "meta.llama4-scout-17b-instruct-v1:0"
+
+    @ddt.data(
+        # env vars to set, success
+        (["BEDROCK_OPENAI_API_KEY", "BEDROCK_OPENAI_ENDPOINT"], True),
+        (["BEDROCK_OPENAI_API_KEY"], False),
+        (["BEDROCK_OPENAI_ENDPOINT"], False),
+    )
+    @ddt.unpack
+    async def test_requires_env(self, names, success):
+        task = irae.IraeLlama4ScoutTask(self.job_config, self.scrubber)
+        env = {name: "content" for name in names}
+        self.patch_dict(os.environ, env, clear=True)
+        if success:
+            await task.init_check()
+        else:
+            with self.assertRaises(SystemExit):
+                await task.init_check()
15 changes: 9 additions & 6 deletions tests/nlp/utils.py
@@ -20,15 +20,18 @@ def setUp(self):
         self.mock_create = mock.AsyncMock()
         self.mock_client.chat.completions.parse = self.mock_create
         self.mock_client.models.list = self.mock_model_list(self.MODEL_ID)
-        mock_client_factory = self.patch("openai.AsyncOpenAI")
-        mock_client_factory.return_value = self.mock_client
+        self.mock_client_factory = self.patch("openai.AsyncOpenAI")
+        self.mock_client_factory.return_value = self.mock_client
 
-        # Also set up azure mocks, which have a different entry point
+        self.responses = []
+
+    def mock_azure(self):
         self.patch_dict(os.environ, {"AZURE_OPENAI_API_KEY": "?", "AZURE_OPENAI_ENDPOINT": "?"})
-        mock_azure_factory = self.patch("openai.AsyncAzureOpenAI")
-        mock_azure_factory.return_value = self.mock_client
+        self.mock_azure_factory = self.patch("openai.AsyncAzureOpenAI")
+        self.mock_azure_factory.return_value = self.mock_client
 
-        self.responses = []
+    def mock_bedrock(self):
+        self.patch_dict(os.environ, {"BEDROCK_OPENAI_API_KEY": "?", "BEDROCK_OPENAI_ENDPOINT": "?"})
 
     @staticmethod
     def mock_model_list(models: str | list[str] = "", *, error: bool = False):
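For completeness, a sketch of how a test might use the new `mock_bedrock()` helper (a hypothetical test, not part of this diff; it assumes the `nlp` module import used elsewhere in these tests):

```python
# Hypothetical test sketch: mock_bedrock() sets the BEDROCK_* env vars, so
# OpenAIModel.__init__ should take the Bedrock branch (Azure vars are unset).
async def test_bedrock_model_selection(self):
    self.mock_bedrock()
    model = nlp.Llama4ScoutModel()
    self.assertEqual(model.model_name, "meta.llama4-scout-17b-instruct-v1:0")
```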