Skip to content

Commit f1c80ff

Browse files
committed
nlp: add bedrock support (and allow Azure for gpt-oss and llama4)
1 parent 6a712e3 commit f1c80ff

File tree

7 files changed

+116
-87
lines changed

7 files changed

+116
-87
lines changed

cumulus_etl/etl/tasks/nlp_task.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ class BaseOpenAiTask(BaseNlpTask):
134134

135135
@classmethod
136136
async def init_check(cls) -> None:
137-
await cls.client_class.pre_init_check()
138137
await cls.client_class().post_init_check()
139138

140139
async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:

cumulus_etl/nlp/openai.py

Lines changed: 63 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Abstraction layer for Hugging Face's inference API"""
22

3-
import abc
43
import os
4+
from collections.abc import Iterable
55

66
import openai
77
from openai.types import chat
@@ -10,37 +10,62 @@
1010
from cumulus_etl import errors
1111

1212

13-
class OpenAIModel(abc.ABC):
14-
USER_ID = None # name in compose file or brand name
15-
MODEL_NAME = None # which model to request via the API
13+
class OpenAIModel:
14+
AZURE_ID = None # model name in MS Azure
15+
BEDROCK_ID = None # model name in AWS Bedrock
16+
COMPOSE_ID = None # docker service name in compose.yaml
17+
VLLM_INFO = None  # tuple of vLLM model name, env var stem used for the URL, plus default port
1618

17-
@abc.abstractmethod
18-
def make_client(self) -> openai.AsyncOpenAI:
19-
"""Creates an NLP client"""
19+
AZURE_ENV = ("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")
20+
BEDROCK_ENV = ("BEDROCK_OPENAI_API_KEY", "BEDROCK_OPENAI_ENDPOINT")
21+
22+
@staticmethod
23+
def _env_defined(env_keys: Iterable[str]) -> bool:
24+
return all(os.environ.get(key) for key in env_keys)
2025

2126
def __init__(self):
22-
self.client = self.make_client()
27+
self.is_vllm = False
2328

24-
# override to add your own checks
25-
@classmethod
26-
async def pre_init_check(cls) -> None:
27-
pass
29+
if self.AZURE_ID and self._env_defined(self.AZURE_ENV):
30+
self.model_name = self.AZURE_ID
31+
self.client = openai.AsyncAzureOpenAI(api_version="2024-10-21")
32+
33+
elif self.BEDROCK_ID and self._env_defined(self.BEDROCK_ENV):
34+
self.model_name = self.BEDROCK_ID
35+
self.client = openai.AsyncOpenAI(
36+
base_url=os.environ["BEDROCK_OPENAI_ENDPOINT"],
37+
api_key=os.environ["BEDROCK_OPENAI_API_KEY"],
38+
)
39+
40+
elif self.COMPOSE_ID:
41+
self.model_name = self.VLLM_INFO[0]
42+
url = os.environ.get(f"CUMULUS_{self.VLLM_INFO[1]}_URL") # set by compose.yaml
43+
url = url or f"http://localhost:{self.VLLM_INFO[2]}/v1" # offer non-docker fallback
44+
self.client = openai.AsyncOpenAI(base_url=url, api_key="")
45+
self.is_vllm = True
46+
47+
else:
48+
errors.fatal(
49+
"Missing Azure or Bedrock environment variables. "
50+
"Set AZURE_OPENAI_API_KEY & AZURE_OPENAI_ENDPOINT or "
51+
"BEDROCK_OPENAI_API_KEY & BEDROCK_OPENAI_ENDPOINT.",
52+
errors.ARGS_INVALID,
53+
)
2854

2955
# override to add your own checks
3056
async def post_init_check(self) -> None:
3157
try:
3258
models = self.client.models.list()
3359
names = {model.id async for model in models}
3460
except openai.APIError as exc:
35-
errors.fatal(
36-
f"NLP server '{self.USER_ID}' is unreachable: {exc}.\n"
37-
f"If it's a local server, try running 'docker compose up {self.USER_ID} --wait'.",
38-
errors.SERVICE_MISSING,
39-
)
61+
message = f"NLP server is unreachable: {exc}."
62+
if self.is_vllm:
63+
message += f"\nTry running 'docker compose up {self.COMPOSE_ID} --wait'."
64+
errors.fatal(message, errors.SERVICE_MISSING)
4065

41-
if self.MODEL_NAME not in names:
66+
if self.model_name not in names:
4267
errors.fatal(
43-
f"NLP server '{self.USER_ID}' is using an unexpected model setup.",
68+
f"NLP server does not have model ID '{self.model_name}'.",
4469
errors.SERVICE_MISSING,
4570
)
4671

@@ -49,7 +74,7 @@ async def prompt(self, system: str, user: str, schema: BaseModel) -> chat.Parsed
4974

5075
async def _parse_prompt(self, system: str, user: str, schema) -> chat.ParsedChatCompletion:
5176
return await self.client.chat.completions.parse(
52-
model=self.MODEL_NAME,
77+
model=self.model_name,
5378
messages=[
5479
{"role": "system", "content": system},
5580
{"role": "user", "content": user},
@@ -61,28 +86,8 @@ async def _parse_prompt(self, system: str, user: str, schema) -> chat.ParsedChat
6186
)
6287

6388

64-
class AzureModel(OpenAIModel):
65-
USER_ID = "Azure"
66-
67-
@classmethod
68-
async def pre_init_check(cls) -> None:
69-
await super().pre_init_check()
70-
71-
messages = []
72-
if not os.environ.get("AZURE_OPENAI_API_KEY"):
73-
messages.append("The AZURE_OPENAI_API_KEY environment variable is not set.")
74-
if not os.environ.get("AZURE_OPENAI_ENDPOINT"):
75-
messages.append("The AZURE_OPENAI_ENDPOINT environment variable is not set.")
76-
77-
if messages:
78-
errors.fatal("\n".join(messages), errors.ARGS_INVALID)
79-
80-
def make_client(self) -> openai.AsyncOpenAI:
81-
return openai.AsyncAzureOpenAI(api_version="2024-10-21")
82-
83-
84-
class Gpt35Model(AzureModel): # deprecated, do not use in new code (doesn't support JSON schemas)
85-
MODEL_NAME = "gpt-35-turbo-0125"
89+
class Gpt35Model(OpenAIModel): # deprecated, do not use in new code (doesn't support JSON schemas)
90+
AZURE_ID = "gpt-35-turbo-0125"
8691

8792
# 3.5 doesn't support a pydantic JSON schema, so we do some work to keep it using the same API
8893
# as the rest of our code.
@@ -93,41 +98,27 @@ async def prompt(self, system: str, user: str, schema: BaseModel) -> chat.Parsed
9398
return response
9499

95100

96-
class Gpt4Model(AzureModel):
97-
MODEL_NAME = "gpt-4"
98-
99-
100-
class Gpt4oModel(AzureModel):
101-
MODEL_NAME = "gpt-4o"
102-
103-
104-
class Gpt5Model(AzureModel):
105-
MODEL_NAME = "gpt-5"
106-
101+
class Gpt4Model(OpenAIModel):
102+
AZURE_ID = "gpt-4"
107103

108-
class LocalModel(OpenAIModel, abc.ABC):
109-
@property
110-
@abc.abstractmethod
111-
def url(self) -> str:
112-
"""The OpenAI compatible URL to talk to (where's the server?)"""
113104

114-
def make_client(self) -> openai.AsyncOpenAI:
115-
return openai.AsyncOpenAI(base_url=self.url, api_key="")
105+
class Gpt4oModel(OpenAIModel):
106+
AZURE_ID = "gpt-4o"
116107

117108

118-
class GptOss120bModel(LocalModel):
119-
USER_ID = "gpt-oss-120b"
120-
MODEL_NAME = "openai/gpt-oss-120b"
109+
class Gpt5Model(OpenAIModel):
110+
AZURE_ID = "gpt-5"
121111

122-
@property
123-
def url(self) -> str:
124-
return os.environ.get("CUMULUS_GPT_OSS_120B_URL") or "http://localhost:8086/v1"
125112

113+
class GptOss120bModel(OpenAIModel):
114+
AZURE_ID = "gpt-oss-120b"
115+
BEDROCK_ID = "openai.gpt-oss-120b-1:0"
116+
COMPOSE_ID = "gpt-oss-120b"
117+
VLLM_INFO = ("openai/gpt-oss-120b", "GPT_OSS_120B", 8086)
126118

127-
class Llama4ScoutModel(LocalModel):
128-
USER_ID = "llama4-scout"
129-
MODEL_NAME = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
130119

131-
@property
132-
def url(self) -> str:
133-
return os.environ.get("CUMULUS_LLAMA4_SCOUT_URL") or "http://localhost:8087/v1"
120+
class Llama4ScoutModel(OpenAIModel):
121+
AZURE_ID = "Llama-4-Scout-17B-16E-Instruct"
122+
BEDROCK_ID = "meta.llama4-scout-17b-instruct-v1:0"
123+
COMPOSE_ID = "llama4-scout"
124+
VLLM_INFO = ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", "LLAMA4_SCOUT", 8087)

tests/covid_symptom/test_covid_gpt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ async def test_gpt4_changes(self):
4242
"""
4343
self.make_json("DocumentReference", "doc", **i2b2_mock_data.documentreference())
4444
self.mock_response()
45+
self.mock_azure()
4546

4647
task = covid_symptom.CovidSymptomNlpResultsGpt4Task(self.job_config, self.scrubber)
4748
await task.run()
@@ -54,6 +55,7 @@ async def test_gpt4_changes(self):
5455
async def test_happy_path(self):
5556
self.make_json("DocumentReference", "1", **i2b2_mock_data.documentreference("foo"))
5657
self.mock_response(parsed=False)
58+
self.mock_azure()
5759

5860
task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber)
5961
await task.run()

tests/nlp/test_example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def default_content(self) -> pydantic.BaseModel:
2323
"example_nlp__nlp_llama4_scout",
2424
)
2525
async def test_basic_etl(self, task_name):
26+
self.mock_azure()
2627
for _ in range(8):
2728
self.mock_response()
2829
await self.run_etl(tasks=[task_name], input_path="%EXAMPLE-NLP%")

tests/nlp/test_irae.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ class TestIraeTask(OpenAITestCase, BaseEtlSimple):
1616
DATA_ROOT = "irae"
1717

1818
@ddt.data(
19-
("irae__nlp_gpt_oss_120b", "openai/gpt-oss-120b"),
19+
("irae__nlp_gpt_oss_120b", "gpt-oss-120b"),
2020
("irae__nlp_gpt4o", "gpt-4o"),
2121
("irae__nlp_gpt5", "gpt-5"),
22-
("irae__nlp_llama4_scout", "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"),
22+
("irae__nlp_llama4_scout", "Llama-4-Scout-17B-16E-Instruct"),
2323
)
2424
@ddt.unpack
2525
async def test_basic_etl(self, task_name, model_id):
26+
self.mock_azure()
2627
self.mock_response(
2728
content=DSAMention(
2829
spans=["note"],

tests/nlp/test_openai.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,26 @@ async def assert_failed_doc(self, msg: str):
6767
async def test_gpt_oss_120_env_url_override(self):
6868
"""Verify we can override the default URL."""
6969
self.patch_dict(os.environ, {"CUMULUS_GPT_OSS_120B_URL": ""})
70-
self.assertEqual(nlp.GptOss120bModel().url, "http://localhost:8086/v1")
70+
nlp.GptOss120bModel()
71+
self.assertEqual(
72+
self.mock_client_factory.call_args[1]["base_url"], "http://localhost:8086/v1"
73+
)
7174

7275
self.patch_dict(os.environ, {"CUMULUS_GPT_OSS_120B_URL": "https://blarg/"})
73-
self.assertEqual(nlp.GptOss120bModel().url, "https://blarg/")
76+
nlp.GptOss120bModel()
77+
self.assertEqual(self.mock_client_factory.call_args[1]["base_url"], "https://blarg/")
7478

7579
async def test_llama4_scout_env_url_override(self):
7680
"""Verify we can override the default URL."""
7781
self.patch_dict(os.environ, {"CUMULUS_LLAMA4_SCOUT_URL": ""})
78-
self.assertEqual(nlp.Llama4ScoutModel().url, "http://localhost:8087/v1")
82+
nlp.Llama4ScoutModel()
83+
self.assertEqual(
84+
self.mock_client_factory.call_args[1]["base_url"], "http://localhost:8087/v1"
85+
)
7986

8087
self.patch_dict(os.environ, {"CUMULUS_LLAMA4_SCOUT_URL": "https://blarg/"})
81-
self.assertEqual(nlp.Llama4ScoutModel().url, "https://blarg/")
88+
nlp.Llama4ScoutModel()
89+
self.assertEqual(self.mock_client_factory.call_args[1]["base_url"], "https://blarg/")
8290

8391
async def test_caching(self):
8492
"""Verify we cache results"""
@@ -141,9 +149,8 @@ async def test_init_check_config(self):
141149

142150
# Bad model ID
143151
self.mock_client.models.list = self.mock_model_list("bogus-model")
144-
with self.assertRaises(SystemExit) as cm:
152+
with self.assert_fatal_exit(errors.SERVICE_MISSING):
145153
await irae.IraeGptOss120bTask.init_check()
146-
self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)
147154

148155
async def test_output_fields(self):
149156
self.make_json("DocumentReference", "1", **i2b2_mock_data.documentreference("foo"))
@@ -271,6 +278,7 @@ class TestAzureNLPTasks(OpenAITestCase):
271278
)
272279
@ddt.unpack
273280
async def test_requires_env(self, names, success):
281+
self.mock_azure()
274282
task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber)
275283
env = {name: "content" for name in names}
276284
self.patch_dict(os.environ, env, clear=True)
@@ -279,3 +287,27 @@ async def test_requires_env(self, names, success):
279287
else:
280288
with self.assertRaises(SystemExit):
281289
await task.init_check()
290+
291+
292+
@ddt.ddt
293+
class TestBedrockNLPTasks(OpenAITestCase):
294+
"""Tests the Bedrock specific code"""
295+
296+
MODEL_ID = "meta.llama4-scout-17b-instruct-v1:0"
297+
298+
@ddt.data(
299+
# env vars to set, success
300+
(["BEDROCK_OPENAI_API_KEY", "BEDROCK_OPENAI_ENDPOINT"], True),
301+
(["BEDROCK_OPENAI_API_KEY"], False),
302+
(["BEDROCK_OPENAI_ENDPOINT"], False),
303+
)
304+
@ddt.unpack
305+
async def test_requires_env(self, names, success):
306+
task = irae.IraeLlama4ScoutTask(self.job_config, self.scrubber)
307+
env = {name: "content" for name in names}
308+
self.patch_dict(os.environ, env, clear=True)
309+
if success:
310+
await task.init_check()
311+
else:
312+
with self.assertRaises(SystemExit):
313+
await task.init_check()

tests/nlp/utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ def setUp(self):
2020
self.mock_create = mock.AsyncMock()
2121
self.mock_client.chat.completions.parse = self.mock_create
2222
self.mock_client.models.list = self.mock_model_list(self.MODEL_ID)
23-
mock_client_factory = self.patch("openai.AsyncOpenAI")
24-
mock_client_factory.return_value = self.mock_client
23+
self.mock_client_factory = self.patch("openai.AsyncOpenAI")
24+
self.mock_client_factory.return_value = self.mock_client
2525

26-
# Also set up azure mocks, which have a different entry point
26+
self.responses = []
27+
28+
def mock_azure(self):
2729
self.patch_dict(os.environ, {"AZURE_OPENAI_API_KEY": "?", "AZURE_OPENAI_ENDPOINT": "?"})
28-
mock_azure_factory = self.patch("openai.AsyncAzureOpenAI")
29-
mock_azure_factory.return_value = self.mock_client
30+
self.mock_azure_factory = self.patch("openai.AsyncAzureOpenAI")
31+
self.mock_azure_factory.return_value = self.mock_client
3032

31-
self.responses = []
33+
def mock_bedrock(self):
34+
self.patch_dict(os.environ, {"BEDROCK_OPENAI_API_KEY": "?", "BEDROCK_OPENAI_ENDPOINT": "?"})
3235

3336
@staticmethod
3437
def mock_model_list(models: str | list[str] = "", *, error: bool = False):

0 commit comments

Comments (0)