@@ -88,13 +88,17 @@ class LangChainModelClient(ModelClient):
88
88
def __init__ (self , config : dict , ** kwargs ) -> None :
89
89
super ().__init__ ()
90
90
logger .info ("LangChain model client config: %s" , str (config ))
91
- # Make a copy of the config since we are popping the keys
91
+ # Make a copy of the config since we are popping some keys
92
92
config = copy .deepcopy (config )
93
+ # model_client_cls will always be LangChainModelClient
93
94
self .client_class = config .pop ("model_client_cls" )
94
95
95
- self .function_call_params = config .pop ("function_call_params" , {})
96
+ # model_name is used in constructing the response.
97
+ self .model_name = config .get ("model" , "" )
96
98
97
- self .model_name = config .get ("model" )
99
+ # If the config specified function_call_params,
100
+ # Pop the params and use them only for tool calling.
101
+ self .function_call_params = config .pop ("function_call_params" , {})
98
102
99
103
# Import the LangChain class
100
104
if "langchain_cls" not in config :
@@ -104,8 +108,13 @@ def __init__(self, config: dict, **kwargs) -> None:
104
108
langchain_module = importlib .import_module (module_name )
105
109
langchain_cls = getattr (langchain_module , cls_name )
106
110
111
+ # If the config specified client_params,
112
+ # Only use the client_params to initialize the LangChain model.
113
+ # Otherwise, use the config
114
+ self .client_params = config .get ("client_params" , config )
115
+
107
116
# Initialize the LangChain client
108
- self .model = langchain_cls (** config )
117
+ self .model = langchain_cls (** self . client_params )
109
118
110
119
def create (self , params ) -> ModelClient .ModelClientResponseProtocol :
111
120
"""Creates a LLM completion for a given config.
0 commit comments