@@ -7,8 +7,6 @@
 import uuid
 import warnings
 from pydantic import BaseModel
-from langchain_community.chat_models import ErnieBotChat
-from langchain_nvidia_ai_endpoints import ChatNVIDIA
 from langchain.chat_models import init_chat_model
 from ..helpers import models_tokens
 from ..models import (
@@ -147,16 +145,17 @@ def handle_model(model_name, provider, token_key, default_token=8192):
                 warnings.simplefilter("ignore")
                 return init_chat_model(**llm_params)
 
-        known_models = ["chatgpt","gpt","openai", "azure_openai", "google_genai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]
+        known_models = ["chatgpt","gpt","openai", "azure_openai", "google_genai",
+                        "ollama", "oneapi", "nvidia", "groq", "google_vertexai",
+                        "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]
 
         if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
             raise ValueError(f"Model '{llm_params['model']}' is not supported")
 
         try:
             if "azure" in llm_params["model"]:
                 model_name = llm_params["model"].split("/")[-1]
-                return handle_model(model_name, "azure_openai", model_name)
-
+                return handle_model(model_name, "azure_openai", model_name)
             if "fireworks" in llm_params["model"]:
                 model_name = "/".join(llm_params["model"].split("/")[1:])
                 token_key = llm_params["model"].split("/")[-1]
@@ -188,7 +187,6 @@ def handle_model(model_name, provider, token_key, default_token=8192):
                 model_name = llm_params["model"].split("/")[-1]
                 return handle_model(model_name, "mistralai", model_name)
 
-            # Instantiate the language model based on the model name (models that do not use the common interface)
             elif "deepseek" in llm_params["model"]:
                 try:
                     self.model_token = models_tokens["deepseek"][llm_params["model"]]
@@ -198,6 +196,8 @@ def handle_model(model_name, provider, token_key, default_token=8192):
                 return DeepSeek(llm_params)
 
             elif "ernie" in llm_params["model"]:
+                from langchain_community.chat_models import ErnieBotChat
+
                 try:
                     self.model_token = models_tokens["ernie"][llm_params["model"]]
                 except KeyError:
@@ -215,6 +215,8 @@ def handle_model(model_name, provider, token_key, default_token=8192):
                 return OneApi(llm_params)
 
             elif "nvidia" in llm_params["model"]:
+                from langchain_nvidia_ai_endpoints import ChatNVIDIA
+
                 try:
                     self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
                     llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
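The substantive change in this diff is deferring the optional provider imports: ErnieBotChat and ChatNVIDIA are now imported only inside the branch that uses them, so importing the package no longer fails when those optional extras are not installed. A minimal sketch of the deferred-import pattern, not the repository's exact code (the helper name _nvidia_chat and the error message are illustrative):

    def _nvidia_chat(model: str):
        # Deferred import: the langchain_nvidia_ai_endpoints extra is only
        # required once an NVIDIA model is actually selected.
        try:
            from langchain_nvidia_ai_endpoints import ChatNVIDIA
        except ImportError as exc:
            raise ImportError(
                "Install langchain_nvidia_ai_endpoints to use NVIDIA models"
            ) from exc
        return ChatNVIDIA(model=model)

The same trade-off applies to the ernie branch: module import stays cheap and dependency-free, at the cost of the ImportError surfacing at request time rather than at startup.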