11"""Abstraction layer for Hugging Face's inference API"""
 
-import abc
 import os
+from collections.abc import Iterable
 
 import openai
 from openai.types import chat
 from pydantic import BaseModel
 
 from cumulus_etl import errors
 
 
-class OpenAIModel(abc.ABC):
-    USER_ID = None  # name in compose file or brand name
-    MODEL_NAME = None  # which model to request via the API
+class OpenAIModel:
+    AZURE_ID = None  # model name in MS Azure
+    BEDROCK_ID = None  # model name in AWS Bedrock
+    COMPOSE_ID = None  # docker service name in compose.yaml
+    VLLM_INFO = None  # tuple of (vLLM model name, env var stem used for the URL, default port)
 
-    @abc.abstractmethod
-    def make_client(self) -> openai.AsyncOpenAI:
-        """Creates an NLP client"""
+    AZURE_ENV = ("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")
+    BEDROCK_ENV = ("BEDROCK_OPENAI_API_KEY", "BEDROCK_OPENAI_ENDPOINT")
+
+    @staticmethod
+    def _env_defined(env_keys: Iterable[str]) -> bool:
+        return all(os.environ.get(key) for key in env_keys)
 
     def __init__(self):
-        self.client = self.make_client()
+        self.is_vllm = False
 
-    # override to add your own checks
-    @classmethod
-    async def pre_init_check(cls) -> None:
-        pass
+        if self.AZURE_ID and self._env_defined(self.AZURE_ENV):
+            self.model_name = self.AZURE_ID
+            self.client = openai.AsyncAzureOpenAI(api_version="2024-10-21")
+
+        elif self.BEDROCK_ID and self._env_defined(self.BEDROCK_ENV):
+            self.model_name = self.BEDROCK_ID
+            self.client = openai.AsyncOpenAI(
+                base_url=os.environ["BEDROCK_OPENAI_ENDPOINT"],
+                api_key=os.environ["BEDROCK_OPENAI_API_KEY"],
+            )
+
+        elif self.COMPOSE_ID:
+            self.model_name = self.VLLM_INFO[0]
+            url = os.environ.get(f"CUMULUS_{self.VLLM_INFO[1]}_URL")  # set by compose.yaml
+            url = url or f"http://localhost:{self.VLLM_INFO[2]}/v1"  # offer non-docker fallback
+            self.client = openai.AsyncOpenAI(base_url=url, api_key="")
+            self.is_vllm = True
+
+        else:
+            errors.fatal(
+                "Missing Azure or Bedrock environment variables. "
+                "Set AZURE_OPENAI_API_KEY & AZURE_OPENAI_ENDPOINT or "
+                "BEDROCK_OPENAI_API_KEY & BEDROCK_OPENAI_ENDPOINT.",
+                errors.ARGS_INVALID,
+            )
 
     # override to add your own checks
     async def post_init_check(self) -> None:
         try:
             models = self.client.models.list()
             names = {model.id async for model in models}
         except openai.APIError as exc:
-            errors.fatal(
-                f"NLP server '{self.USER_ID}' is unreachable: {exc}.\n"
-                f"If it's a local server, try running 'docker compose up {self.USER_ID} --wait'.",
-                errors.SERVICE_MISSING,
-            )
+            message = f"NLP server is unreachable: {exc}."
+            if self.is_vllm:
+                message += f"\nTry running 'docker compose up {self.COMPOSE_ID} --wait'."
+            errors.fatal(message, errors.SERVICE_MISSING)
 
-        if self.MODEL_NAME not in names:
+        if self.model_name not in names:
             errors.fatal(
-                f"NLP server '{self.USER_ID}' is using an unexpected model setup.",
+                f"NLP server does not have model ID '{self.model_name}'.",
                 errors.SERVICE_MISSING,
             )
 
@@ -49,7 +74,7 @@ async def prompt(self, system: str, user: str, schema: BaseModel) -> chat.ParsedChatCompletion:
 
     async def _parse_prompt(self, system: str, user: str, schema) -> chat.ParsedChatCompletion:
         return await self.client.chat.completions.parse(
-            model=self.MODEL_NAME,
+            model=self.model_name,
             messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},
@@ -61,28 +86,8 @@ async def _parse_prompt(self, system: str, user: str, schema) -> chat.ParsedChatCompletion:
         )
 
 
-class AzureModel(OpenAIModel):
-    USER_ID = "Azure"
-
-    @classmethod
-    async def pre_init_check(cls) -> None:
-        await super().pre_init_check()
-
-        messages = []
-        if not os.environ.get("AZURE_OPENAI_API_KEY"):
-            messages.append("The AZURE_OPENAI_API_KEY environment variable is not set.")
-        if not os.environ.get("AZURE_OPENAI_ENDPOINT"):
-            messages.append("The AZURE_OPENAI_ENDPOINT environment variable is not set.")
-
-        if messages:
-            errors.fatal("\n".join(messages), errors.ARGS_INVALID)
-
-    def make_client(self) -> openai.AsyncOpenAI:
-        return openai.AsyncAzureOpenAI(api_version="2024-10-21")
-
-
-class Gpt35Model(AzureModel):  # deprecated, do not use in new code (doesn't support JSON schemas)
-    MODEL_NAME = "gpt-35-turbo-0125"
+class Gpt35Model(OpenAIModel):  # deprecated, do not use in new code (doesn't support JSON schemas)
+    AZURE_ID = "gpt-35-turbo-0125"
 
     # 3.5 doesn't support a pydantic JSON schema, so we do some work to keep it using the same API
     # as the rest of our code.
@@ -93,41 +98,27 @@ async def prompt(self, system: str, user: str, schema: BaseModel) -> chat.ParsedChatCompletion:
         return response
 
 
-class Gpt4Model(AzureModel):
-    MODEL_NAME = "gpt-4"
-
-
-class Gpt4oModel(AzureModel):
-    MODEL_NAME = "gpt-4o"
-
-
-class Gpt5Model(AzureModel):
-    MODEL_NAME = "gpt-5"
-
+class Gpt4Model(OpenAIModel):
+    AZURE_ID = "gpt-4"
 
-class LocalModel(OpenAIModel, abc.ABC):
-    @property
-    @abc.abstractmethod
-    def url(self) -> str:
-        """The OpenAI compatible URL to talk to (where's the server?)"""
 
-    def make_client(self) -> openai.AsyncOpenAI:
-        return openai.AsyncOpenAI(base_url=self.url, api_key="")
+class Gpt4oModel(OpenAIModel):
+    AZURE_ID = "gpt-4o"
 
 
-class GptOss120bModel(LocalModel):
-    USER_ID = "gpt-oss-120b"
-    MODEL_NAME = "openai/gpt-oss-120b"
+class Gpt5Model(OpenAIModel):
+    AZURE_ID = "gpt-5"
 
-    @property
-    def url(self) -> str:
-        return os.environ.get("CUMULUS_GPT_OSS_120B_URL") or "http://localhost:8086/v1"
 
+class GptOss120bModel(OpenAIModel):
+    AZURE_ID = "gpt-oss-120b"
+    BEDROCK_ID = "openai.gpt-oss-120b-1:0"
+    COMPOSE_ID = "gpt-oss-120b"
+    VLLM_INFO = ("openai/gpt-oss-120b", "GPT_OSS_120B", 8086)
 
-class Llama4ScoutModel(LocalModel):
-    USER_ID = "llama4-scout"
-    MODEL_NAME = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
 
-    @property
-    def url(self) -> str:
-        return os.environ.get("CUMULUS_LLAMA4_SCOUT_URL") or "http://localhost:8087/v1"
+class Llama4ScoutModel(OpenAIModel):
+    AZURE_ID = "Llama-4-Scout-17B-16E-Instruct"
+    BEDROCK_ID = "meta.llama4-scout-17b-instruct-v1:0"
+    COMPOSE_ID = "llama4-scout"
+    VLLM_INFO = ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", "LLAMA4_SCOUT", 8087)