diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py
index c714686d..8a28ddc1 100644
--- a/integrations/langchain/src/databricks_langchain/chat_models.py
+++ b/integrations/langchain/src/databricks_langchain/chat_models.py
@@ -214,7 +214,7 @@ class GetPopulation(BaseModel):
     """Name of Databricks Model Serving endpoint to query."""
     target_uri: str = "databricks"
     """The target URI to use. Defaults to ``databricks``."""
-    temperature: float = 0.0
+    temperature: Optional[float] = 0.0
     """Sampling temperature. Higher values make the model more creative."""
     n: int = 1
     """The number of completion choices to generate."""
@@ -252,20 +252,32 @@ def endpoint(self, value: str) -> None:
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
+        if "temperature" not in kwargs:
+            warnings.warn(
+                "Currently, temperature defaults to 0.0 if not specified. "
+                "In the next release, temperature will need to be explicitly set. "
+                "Please update your code to specify a temperature value.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
         self.client = get_deployment_client(self.target_uri)
         self.extra_params = self.extra_params or {}
 
     @property
     def _default_params(self) -> Dict[str, Any]:
-        params: Dict[str, Any] = {
-            "target_uri": self.target_uri,
-            "model": self.model,
+        exclude_if_none = {
             "temperature": self.temperature,
             "n": self.n,
             "stop": self.stop,
             "max_tokens": self.max_tokens,
             "extra_params": self.extra_params,
         }
+
+        params = {
+            "model": self.model,
+            "target_uri": self.target_uri,
+            **{k: v for k, v in exclude_if_none.items() if v is not None},
+        }
         return params
 
     def _generate(
@@ -287,11 +299,12 @@ def _prepare_inputs(
    ) -> Dict[str, Any]:
         data: Dict[str, Any] = {
             "messages": [_convert_message_to_dict(msg) for msg in messages],
-            "temperature": self.temperature,
             "n": self.n,
             **self.extra_params,  # type: ignore
             **kwargs,
         }
+        if self.temperature is not None:
+            data["temperature"] = self.temperature
         if stop := self.stop or stop:
             data["stop"] = stop
         if self.max_tokens is not None:
@@ -620,7 +633,7 @@ class AnswerWithJustification(BaseModel):
         if method == "function_calling":
             if schema is None:
                 raise ValueError(
-                    "schema must be specified when method is 'function_calling'. " "Received None."
+                    "schema must be specified when method is 'function_calling'. Received None."
                 )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
             llm = self.bind_tools([schema], tool_choice=tool_name)
@@ -641,7 +654,7 @@ class AnswerWithJustification(BaseModel):
         elif method == "json_schema":
             if schema is None:
                 raise ValueError(
-                    "schema must be specified when method is 'json_schema'. " "Received None."
+                    "schema must be specified when method is 'json_schema'. Received None."
                 )
             response_format = {
                 "type": "json_schema",
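
A minimal sketch of the caller-facing behavior this diff introduces (not part of the change itself). It assumes the public import path `from databricks_langchain import ChatDatabricks`, configured Databricks credentials, and a placeholder endpoint name; adjust both to your environment.

```python
import warnings

from databricks_langchain import ChatDatabricks  # import path assumed from this diff

ENDPOINT = "my-serving-endpoint"  # placeholder, not a real endpoint name

# Omitting temperature still defaults to 0.0 for now, but emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    llm = ChatDatabricks(endpoint=ENDPOINT)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Passing temperature explicitly (even 0.0) keeps today's behavior and silences the warning.
llm = ChatDatabricks(endpoint=ENDPOINT, temperature=0.0)

# temperature=None counts as explicitly set: no warning, and _prepare_inputs now
# omits the key from the request payload so the endpoint applies its own default.
llm = ChatDatabricks(endpoint=ENDPOINT, temperature=None)
```

Filtering `None` values out of `_default_params` and the request payload means the optional fields (`temperature`, `stop`, `max_tokens`) are only sent when the caller actually set them, rather than being forced to a client-side default.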