diff --git a/libs/ibm/langchain_ibm/chat_models.py b/libs/ibm/langchain_ibm/chat_models.py
index 783b3c3..bfad13c 100644
--- a/libs/ibm/langchain_ibm/chat_models.py
+++ b/libs/ibm/langchain_ibm/chat_models.py
@@ -2,8 +2,6 @@
import json
import logging
-import warnings
-from datetime import datetime
from operator import itemgetter
from typing import (
Any,
@@ -24,7 +22,10 @@
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import ModelInference # type: ignore
-from langchain_core._api import LangChainDeprecationWarning
+from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore
+ BaseSchema,
+ TextChatParameters,
+)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
@@ -49,22 +50,25 @@
ToolCall,
ToolMessage,
ToolMessageChunk,
- convert_to_messages,
)
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
+ make_invalid_tool_call,
+ parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
+from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
@@ -85,48 +89,53 @@ def _convert_dict_to_message(_dict: Mapping[str, Any], call_id: str) -> BaseMess
The LangChain message.
"""
role = _dict.get("role")
+ name = _dict.get("name")
+ id_ = call_id
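+ # The dict follows an OpenAI-style chat message schema (role, content, name, tool_calls, ...).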
if role == "user":
- return HumanMessage(content=_dict.get("generated_text", ""))
- else:
+ return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
+ elif role == "assistant":
+ content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
+ if function_call := _dict.get("function_call"):
+ additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
- invalid_tool_calls: List[InvalidToolCall] = []
- content = ""
-
- raw_tool_calls = _dict.get("generated_text", "")
-
- if "json" in raw_tool_calls:
- try:
- split_raw_tool_calls = raw_tool_calls.split("\n\n")
- for raw_tool_call in split_raw_tool_calls:
- if "json" in raw_tool_call:
- json_parts = JsonOutputParser().parse(raw_tool_call)
-
- if json_parts["function"]["name"] == "Final Answer":
- content = json_parts["function"]["arguments"]["output"]
- break
-
- additional_kwargs["tool_calls"] = json_parts
-
- parsed = {
- "name": json_parts["function"]["name"] or "",
- "args": json_parts["function"]["arguments"] or {},
- "id": call_id,
- }
- tool_calls.append(parsed)
-
- except: # noqa: E722
- content = _dict.get("generated_text", "") or ""
-
- else:
- content = _dict.get("generated_text", "") or ""
-
+ invalid_tool_calls = []
+ if raw_tool_calls := _dict.get("tool_calls"):
+ additional_kwargs["tool_calls"] = raw_tool_calls
+ for raw_tool_call in raw_tool_calls:
+ try:
+ tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
+ except Exception as e:
+ invalid_tool_calls.append(
+ make_invalid_tool_call(raw_tool_call, str(e))
+ )
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
+ name=name,
+ id=id_,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
+ elif role == "system":
+ return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
+ elif role == "function":
+ return FunctionMessage(
+ content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
+ )
+ elif role == "tool":
+ additional_kwargs = {}
+ if "name" in _dict:
+ additional_kwargs["name"] = _dict["name"]
+ return ToolMessage(
+ content=_dict.get("content", ""),
+ tool_call_id=cast(str, _dict.get("tool_call_id")),
+ additional_kwargs=additional_kwargs,
+ name=name,
+ id=id_,
+ )
+ else:
+ return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type]
def _format_message_content(content: Any) -> Any:
@@ -149,30 +158,6 @@ def _format_message_content(content: Any) -> Any:
return formatted_content
-def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
- return {
- "type": "function",
- "id": tool_call["id"],
- "function": {
- "name": tool_call["name"],
- "arguments": json.dumps(tool_call["args"]),
- },
- }
-
-
-def _lc_invalid_tool_call_to_openai_tool_call(
- invalid_tool_call: InvalidToolCall,
-) -> dict:
- return {
- "type": "function",
- "id": invalid_tool_call["id"],
- "function": {
- "name": invalid_tool_call["name"],
- "arguments": invalid_tool_call["args"],
- },
- }
-
-
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
@@ -197,9 +182,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict["function_call"] = message.additional_kwargs["function_call"]
if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
- _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+ _lc_tool_call_to_watsonx_tool_call(tc) for tc in message.tool_calls
] + [
- _lc_invalid_tool_call_to_openai_tool_call(tc)
+ _lc_invalid_tool_call_to_watsonx_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
@@ -213,14 +198,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
pass
# If tool calls present, content null value should be None not empty string.
if "function_call" in message_dict or "tool_calls" in message_dict:
- message_dict["content"] = message_dict["content"] or ""
- message_dict["tool_calls"][0]["name"] = message_dict["tool_calls"][0][
- "function"
- ]["name"]
- message_dict["tool_calls"][0]["args"] = json.loads(
- message_dict["tool_calls"][0]["function"]["arguments"]
- )
-
+ message_dict["content"] = message_dict["content"] or None
elif isinstance(message, SystemMessage):
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
@@ -237,11 +215,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_delta_to_message_chunk(
- _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+ _dict: Mapping[str, Any],
+ default_class: Type[BaseMessageChunk],
+ call_id: str,
+ finish_reason: str,
) -> BaseMessageChunk:
- id_ = "sample_id"
+ id_ = call_id
role = cast(str, _dict.get("role"))
- content = cast(str, _dict.get("generated_text") or "")
+ content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
@@ -253,12 +234,14 @@ def _convert_delta_to_message_chunk(
additional_kwargs["tool_calls"] = raw_tool_calls
try:
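+ # The tool name and call id are attached only on the chunk that carries a
+ # finish_reason; earlier chunks stream argument fragments only.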
tool_call_chunks = [
- {
- "name": rtc["function"].get("name"),
- "args": rtc["function"].get("arguments"),
- "id": rtc.get("id"),
- "index": rtc["index"],
- }
+ tool_call_chunk(
+ name=rtc["function"].get("name")
+ if finish_reason is not None
+ else None,
+ args=rtc["function"].get("arguments"),
+ id=call_id if finish_reason is not None else None,
+ index=rtc["index"],
+ )
for rtc in raw_tool_calls
]
except KeyError:
@@ -287,6 +270,58 @@ def _convert_delta_to_message_chunk(
return default_class(content=content, id=id_) # type: ignore
+def _convert_chunk_to_generation_chunk(
+ chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+) -> Optional[ChatGenerationChunk]:
+ token_usage = chunk.get("usage")
+ choices = chunk.get("choices", [])
+
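+ # Map token usage reported on the chunk (prompt/completion/total tokens) to
+ # LangChain UsageMetadata.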
+ usage_metadata: Optional[UsageMetadata] = (
+ UsageMetadata(
+ input_tokens=token_usage.get("prompt_tokens", 0),
+ output_tokens=token_usage.get("completion_tokens", 0),
+ total_tokens=token_usage.get("total_tokens", 0),
+ )
+ if token_usage
+ else None
+ )
+
+ if len(choices) == 0:
+ # logprobs is implicitly None
+ generation_chunk = ChatGenerationChunk(
+ message=default_chunk_class(content="", usage_metadata=usage_metadata)
+ )
+ return generation_chunk
+
+ choice = choices[0]
+ if choice["delta"] is None:
+ return None
+
+ message_chunk = _convert_delta_to_message_chunk(
+ choice["delta"], default_chunk_class, chunk["id"], choice["finish_reason"]
+ )
+ generation_info = {**base_generation_info} if base_generation_info else {}
+
+ if finish_reason := choice.get("finish_reason"):
+ generation_info["finish_reason"] = finish_reason
+ if model_name := chunk.get("model"):
+ generation_info["model_name"] = model_name
+ if system_fingerprint := chunk.get("system_fingerprint"):
+ generation_info["system_fingerprint"] = system_fingerprint
+
+ logprobs = choice.get("logprobs")
+ if logprobs:
+ generation_info["logprobs"] = logprobs
+
+ if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+ message_chunk.usage_metadata = usage_metadata
+
+ generation_chunk = ChatGenerationChunk(
+ message=message_chunk, generation_info=generation_info or None
+ )
+ return generation_chunk
+
+
class _FunctionCall(TypedDict):
name: str
@@ -378,7 +413,7 @@ class ChatWatsonx(BaseChatModel):
version: Optional[SecretStr] = None
"""Version of the CPD instance."""
- params: Optional[dict] = None
+ params: Optional[Union[dict, TextChatParameters]] = None
"""Model parameters to use during request generation."""
verify: Union[str, bool, None] = None
@@ -523,77 +558,9 @@ def _generate(
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
- if message_dicts[-1].get("role") == "tool":
- chat_prompt = (
- "User: Please summarize given sentences into "
- "JSON containing Final Answer: '"
- )
- for message in message_dicts:
- if message["content"]:
- chat_prompt += message["content"] + "\n"
- chat_prompt += "'"
- else:
- chat_prompt = self._create_chat_prompt(message_dicts)
-
- tools = kwargs.get("tools")
-
- if tools:
- chat_prompt = f"""
-You are Mixtral Chat function calling, an AI language model developed by Mistral AI.
-You are a cautious assistant. You carefully follow instructions. You are helpful and
-harmless and you follow ethical guidelines and promote positive behavior. Here are a
-few of the tools available to you:
-[AVAILABLE_TOOLS]
-{json.dumps(tools[0], indent=2)}
-[/AVAILABLE_TOOLS]
-To use these tools you must always respond in JSON format containing `"type"` and
-`"function"` key-value pairs. Also `"function"` key-value pair always containing
-`"name"` and `"arguments"` key-value pairs. For example, to answer the question,
-"What is a length of word think?" you must use the get_word_length tool like so:
-
-```json
-{{
- "type": "function",
- "function": {{
- "name": "get_word_length",
- "arguments": {{
- "word": "think"
- }}
- }}
-}}
-```
-
-
-Remember, even when answering to the user, you must still use this JSON format!
-If you'd like to ask how the user is doing you must write:
-
-```json
-{{
- "type": "function",
- "function": {{
- "name": "Final Answer",
- "arguments": {{
- "output": "How are you today?"
- }}
- }}
-}}
-```
-
-
-Remember to end your response with '</s>'
-
-{chat_prompt}
-(reminder to respond in a JSON blob no matter what and use tools only if necessary)"""
-
- params = params | {"stop_sequences": ["</s>"]}
-
- if "tools" in kwargs:
- del kwargs["tools"]
- if "tool_choice" in kwargs:
- del kwargs["tool_choice"]
-
- response = self.watsonx_model.generate(
- prompt=chat_prompt, **(kwargs | {"params": params})
+
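+ # Send the OpenAI-style message dicts straight to the watsonx chat endpoint
+ # (replaces the removed prompt-building path above).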
+ response = self.watsonx_model.chat(
+ messages=message_dicts, **(kwargs | {"params": params})
)
return self._create_chat_result(response)
@@ -605,140 +572,54 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
- if message_dicts[-1].get("role") == "tool":
- chat_prompt = (
- "User: Please summarize given sentences into JSON "
- "containing Final Answer: '"
- )
- for message in message_dicts:
- if message["content"]:
- chat_prompt += message["content"] + "\n"
- chat_prompt += "'"
- else:
- chat_prompt = self._create_chat_prompt(message_dicts)
-
- tools = kwargs.get("tools")
-
- if tools:
- chat_prompt = f"""
-You are Mixtral Chat function calling, an AI language model developed by Mistral AI.
-You are a cautious assistant. You carefully follow instructions. You are helpful and
-harmless and you follow ethical guidelines and promote positive behavior. Here are a
-few of the tools available to you:
-[AVAILABLE_TOOLS]
-{json.dumps(tools[0], indent=2)}
-[/AVAILABLE_TOOLS]
-To use these tools you must always respond in JSON format containing `"type"` and
-`"function"` key-value pairs. Also `"function"` key-value pair always containing
-`"name"` and `"arguments"` key-value pairs. For example, to answer the question,
-"What is a length of word think?" you must use the get_word_length tool like so:
-
-```json
-{{
- "type": "function",
- "function": {{
- "name": "get_word_length",
- "arguments": {{
- "word": "think"
- }}
- }}
-}}
-```
-
-
-Remember, even when answering to the user, you must still use this JSON format!
-If you'd like to ask how the user is doing you must write:
-
-```json
-{{
- "type": "function",
- "function": {{
- "name": "Final Answer",
- "arguments": {{
- "output": "How are you today?"
- }}
- }}
-}}
-```
-
-
-Remember to end your response with '</s>'
-
-{chat_prompt[:-5]}
-(reminder to respond in a JSON blob no matter what and use tools only if necessary)"""
-
- params = params | {"stop_sequences": ["</s>"]}
-
- if "tools" in kwargs:
- del kwargs["tools"]
- if "tool_choice" in kwargs:
- del kwargs["tool_choice"]
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+ base_generation_info: dict = {}
+
+ is_first_chunk = True
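+ # chat_stream may yield dicts or pydantic objects; normalize each chunk and
+ # convert it into a ChatGenerationChunk.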
- for chunk in self.watsonx_model.generate_text_stream(
- prompt=chat_prompt, raw_response=True, **(kwargs | {"params": params})
+ for chunk in self.watsonx_model.chat_stream(
+ messages=message_dicts, **(kwargs | {"params": params})
):
if not isinstance(chunk, dict):
- chunk = chunk.dict()
- if len(chunk["results"]) == 0:
- continue
- choice = chunk["results"][0]
-
- message_chunk = _convert_delta_to_message_chunk(choice, default_chunk_class)
- generation_info = {}
- if (finish_reason := choice.get("stop_reason")) != "not_finished":
- generation_info["finish_reason"] = finish_reason
- chunk = ChatGenerationChunk(
- message=message_chunk, generation_info=generation_info or None
+ chunk = chunk.model_dump()
+ generation_chunk = _convert_chunk_to_generation_chunk(
+ chunk,
+ default_chunk_class,
+ base_generation_info if is_first_chunk else {},
)
+ if generation_chunk is None:
+ continue
+ default_chunk_class = generation_chunk.message.__class__
+ logprobs = (generation_chunk.generation_info or {}).get("logprobs")
if run_manager:
- run_manager.on_llm_new_token(chunk.text, chunk=chunk)
-
- yield chunk
-
- def _create_chat_prompt(self, messages: List[Dict[str, Any]]) -> str:
- prompt = ""
-
- if self.model_id in ["ibm/granite-13b-chat-v1", "ibm/granite-13b-chat-v2"]:
- for message in messages:
- if message["role"] == "system":
- prompt += "<|system|>\n" + message["content"] + "\n\n"
- elif message["role"] == "assistant":
- prompt += "<|assistant|>\n" + message["content"] + "\n\n"
- elif message["role"] == "function":
- prompt += "<|function|>\n" + message["content"] + "\n\n"
- elif message["role"] == "tool":
- prompt += "<|tool|>\n" + message["content"] + "\n\n"
- else:
- prompt += "<|user|>:\n" + message["content"] + "\n\n"
-
- prompt += "<|assistant|>\n"
-
- elif self.model_id in [
- "meta-llama/llama-2-13b-chat",
- "meta-llama/llama-2-70b-chat",
- ]:
- for message in messages:
- if message["role"] == "system":
- prompt += "[INST] <<SYS>>\n" + message["content"] + "<</SYS>>\n\n"
- elif message["role"] == "assistant":
- prompt += message["content"] + "\n[INST]\n\n"
- else:
- prompt += message["content"] + "\n[/INST]\n"
-
- else:
- prompt = ChatPromptValue(
- messages=convert_to_messages(messages) + [AIMessage(content="")]
- ).to_string()
-
- return prompt
+ run_manager.on_llm_new_token(
+ generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
+ )
+ is_first_chunk = False
+ yield generation_chunk
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
- params = {**self.params} if self.params else {}
- params = params | {**kwargs.get("params", {})}
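+ # `params` may be a plain dict or an ibm_watsonx_ai BaseSchema instance
+ # (e.g. TextChatParameters); normalize both forms to a dict before merging.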
+ params = (
+ {
+ **(
+ self.params.to_dict()
+ if isinstance(self.params, BaseSchema)
+ else self.params
+ )
+ }
+ if self.params
+ else {}
+ )
+ params = params | {
+ **(
+ kwargs.get("params", {}).to_dict()
+ if isinstance(kwargs.get("params", {}), BaseSchema)
+ else kwargs.get("params", {})
+ )
+ }
if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
@@ -748,47 +629,41 @@ def _create_message_dicts(
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
- def _create_chat_result(self, response: Union[dict]) -> ChatResult:
+ def _create_chat_result(
+ self, response: dict, generation_info: Optional[Dict] = None
+ ) -> ChatResult:
generations = []
- sum_of_total_generated_tokens = 0
- sum_of_total_input_tokens = 0
- call_id = ""
- date_string = response.get("created_at")
- if date_string:
- date_object = datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ")
- call_id = str(date_object.timestamp())
if response.get("error"):
raise ValueError(response.get("error"))
- for res in response["results"]:
- message = _convert_dict_to_message(res, call_id)
- generation_info = dict(finish_reason=res.get("stop_reason"))
- if "generated_token_count" in res:
- sum_of_total_generated_tokens += res["generated_token_count"]
- if "input_token_count" in res:
- sum_of_total_input_tokens += res["input_token_count"]
- total_token = sum_of_total_generated_tokens + sum_of_total_input_tokens
- if total_token and isinstance(message, AIMessage):
+ token_usage = response.get("usage", {})
+
+ for res in response["choices"]:
+ message = _convert_dict_to_message(res["message"], response["id"])
+
+ if token_usage and isinstance(message, AIMessage):
message.usage_metadata = {
- "input_tokens": sum_of_total_input_tokens,
- "output_tokens": sum_of_total_generated_tokens,
- "total_tokens": total_token,
+ "input_tokens": token_usage.get("prompt_tokens", 0),
+ "output_tokens": token_usage.get("completion_tokens", 0),
+ "total_tokens": token_usage.get("total_tokens", 0),
}
- gen = ChatGeneration(
- message=message,
- generation_info=generation_info,
+ generation_info = generation_info or {}
+ generation_info["finish_reason"] = (
+ res.get("finish_reason")
+ if res.get("finish_reason") is not None
+ else generation_info.get("finish_reason")
)
+ if "logprobs" in res:
+ generation_info["logprobs"] = res["logprobs"]
+ gen = ChatGeneration(message=message, generation_info=generation_info)
generations.append(gen)
- token_usage = {
- "generated_token_count": sum_of_total_generated_tokens,
- "input_token_count": sum_of_total_input_tokens,
- }
llm_output = {
"token_usage": token_usage,
- "model_name": self.model_id,
+ "model_name": response.get("model_id", self.model_id),
"system_fingerprint": response.get("system_fingerprint", ""),
}
+
return ChatResult(generations=generations, llm_output=llm_output)
def bind_functions(
@@ -845,7 +720,11 @@ def bind_functions(
def bind_tools(
self,
- tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+ tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+ *,
+ tool_choice: Optional[
+ Union[dict, str, Literal["auto", "none", "required", "any"], bool]
+ ] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
@@ -855,30 +734,63 @@ def bind_tools(
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
- **kwargs: Any additional parameters to pass to the
- :class:`~langchain.runnable.Runnable` constructor.
- """
- bind_tools_supported_models = ["mistralai/mixtral-8x7b-instruct-v01"]
- if self.model_id not in bind_tools_supported_models:
- raise Warning(
- f"bind_tools() method for ChatWatsonx support only "
- f"following models: {bind_tools_supported_models}"
- )
- else:
- warnings.warn(
- "The `mistralai/mixtral-8x7b-instruct-v01` model, which "
- "supports the `bind_tools()` method, is deprecated and will be "
- "removed in version 0.3.x of `langchain_ibm`.",
- LangChainDeprecationWarning,
- )
-
+ tool_choice: Which tool to require the model to call.
+ Options are:
+ - str of the form ``"<<tool_name>>"``: calls <<tool_name>> tool.
+ - ``"auto"``: automatically selects a tool (including no tool).
+ - ``"none"``: does not call a tool.
+ - ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called.
+ - dict of the form ``{"type": "function", "function": {"name": <<tool_name>>}}``: calls <<tool_name>> tool.
+ - ``False`` or ``None``: no effect, default OpenAI behavior.
+
+ kwargs: Any additional parameters are passed directly to
+ ``self.bind(**kwargs)``.
+ """ # noqa: E501
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+ if tool_choice:
+ if isinstance(tool_choice, str):
+ # tool_choice is a tool/function name
+ if tool_choice not in ("auto", "none", "any", "required"):
+ tool_choice = {
+ "type": "function",
+ "function": {"name": tool_choice},
+ }
+ # We support 'any' since other models use this instead of 'required'.
+ if tool_choice == "any":
+ tool_choice = "required"
+ elif isinstance(tool_choice, bool):
+ tool_choice = "required"
+ elif isinstance(tool_choice, dict):
+ tool_names = [
+ formatted_tool["function"]["name"]
+ for formatted_tool in formatted_tools
+ ]
+ if not any(
+ tool_name == tool_choice["function"]["name"]
+ for tool_name in tool_names
+ ):
+ raise ValueError(
+ f"Tool choice {tool_choice} was specified, but the only "
+ f"provided tools were {tool_names}."
+ )
+ else:
+ raise ValueError(
+ f"Unrecognized tool_choice type. Expected str, bool or dict. "
+ f"Received: {tool_choice}"
+ )
+
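+ # A string option is forwarded as `tool_choice_option`; a specific tool dict
+ # is forwarded as `tool_choice`.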
+ if isinstance(tool_choice, str):
+ kwargs["tool_choice_option"] = tool_choice
+ else:
+ kwargs["tool_choice"] = tool_choice
+ else:
+ kwargs["tool_choice_option"] = "auto"
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
- schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+ schema: Optional[Union[Dict, Type]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False,
@@ -1037,14 +949,15 @@ class AnswerWithJustification(BaseModel):
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
- is_pydantic_schema = _is_pydantic_class(schema)
+ is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
- llm = self.bind_tools([schema], tool_choice=True)
+ # specifying a tool.
+ llm = self.bind_tools([schema], tool_choice="auto")
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
@@ -1062,12 +975,6 @@ class AnswerWithJustification(BaseModel):
if is_pydantic_schema
else JsonOutputParser()
)
- else:
- raise ValueError(
- f"Unrecognized method argument. Expected one of 'function_calling' or "
- f"'json_format'. Received: '{method}'"
- )
-
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
@@ -1083,3 +990,27 @@ class AnswerWithJustification(BaseModel):
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
+def _lc_tool_call_to_watsonx_tool_call(tool_call: ToolCall) -> dict:
+ return {
+ "type": "function",
+ "id": tool_call["id"],
+ "function": {
+ "name": tool_call["name"],
+ "arguments": json.dumps(tool_call["args"]),
+ },
+ }
+
+
+def _lc_invalid_tool_call_to_watsonx_tool_call(
+ invalid_tool_call: InvalidToolCall,
+) -> dict:
+ return {
+ "type": "function",
+ "id": invalid_tool_call["id"],
+ "function": {
+ "name": invalid_tool_call["name"],
+ "arguments": invalid_tool_call["args"],
+ },
+ }
diff --git a/libs/ibm/poetry.lock b/libs/ibm/poetry.lock
index 6e1eaf1..6ef4ff9 100644
--- a/libs/ibm/poetry.lock
+++ b/libs/ibm/poetry.lock
@@ -345,57 +345,57 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "ibm-cos-sdk"
-version = "2.13.6"
+version = "2.13.5"
description = "IBM SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "ibm-cos-sdk-2.13.6.tar.gz", hash = "sha256:171cf2ae4ab662a4b8ab58dcf4ac994b0577d6c92d78490295fd7704a83978f6"},
+ {file = "ibm-cos-sdk-2.13.5.tar.gz", hash = "sha256:1aff7f9863ac9072a3db2f0053bec99478b26f3fb5fa797ce96a15bbb13cd40e"},
]
[package.dependencies]
-ibm-cos-sdk-core = "2.13.6"
-ibm-cos-sdk-s3transfer = "2.13.6"
+ibm-cos-sdk-core = "2.13.5"
+ibm-cos-sdk-s3transfer = "2.13.5"
jmespath = ">=0.10.0,<=1.0.1"
[[package]]
name = "ibm-cos-sdk-core"
-version = "2.13.6"
+version = "2.13.5"
description = "Low-level, data-driven core of IBM SDK for Python"
optional = false
python-versions = ">=3.6"
files = [
- {file = "ibm-cos-sdk-core-2.13.6.tar.gz", hash = "sha256:dd41fb789eeb65546501afabcd50e78846ab4513b6ad4042e410b6a14ff88413"},
+ {file = "ibm-cos-sdk-core-2.13.5.tar.gz", hash = "sha256:d3a99d8b06b3f8c00b1a9501f85538d592463e63ddf8cec32672ab5a0b107b83"},
]
[package.dependencies]
jmespath = ">=0.10.0,<=1.0.1"
python-dateutil = ">=2.9.0,<3.0.0"
-requests = ">=2.32.0,<2.32.3"
-urllib3 = ">=1.26.18,<3"
+requests = ">=2.32.3,<3.0"
+urllib3 = {version = ">=1.26.18,<2.2", markers = "python_version >= \"3.10\""}
[[package]]
name = "ibm-cos-sdk-s3transfer"
-version = "2.13.6"
+version = "2.13.5"
description = "IBM S3 Transfer Manager"
optional = false
python-versions = ">=3.8"
files = [
- {file = "ibm-cos-sdk-s3transfer-2.13.6.tar.gz", hash = "sha256:e0acce6f380c47d11e07c6765b684b4ababbf5c66cc0503bc246469a1e2b9790"},
+ {file = "ibm-cos-sdk-s3transfer-2.13.5.tar.gz", hash = "sha256:9649b1f2201c6de96ff5a6b5a3686de3a809e6ef3b8b12c7c4f2f7ce72da7749"},
]
[package.dependencies]
-ibm-cos-sdk-core = "2.13.6"
+ibm-cos-sdk-core = "2.13.5"
[[package]]
name = "ibm-watsonx-ai"
-version = "1.1.11"
+version = "1.1.14"
description = "IBM watsonx.ai API Client"
optional = false
python-versions = ">=3.10"
files = [
- {file = "ibm_watsonx_ai-1.1.11-py3-none-any.whl", hash = "sha256:0b2c8b9abbe18acba3f987e2cb27cf0efcf0a7ba2373310afad6e3955b967a74"},
- {file = "ibm_watsonx_ai-1.1.11.tar.gz", hash = "sha256:47b25c927acacdcceb148cf0a2ebc75a965805bf89c818e9460dc9e67b895da8"},
+ {file = "ibm_watsonx_ai-1.1.14-py3-none-any.whl", hash = "sha256:3b711dd4eb96a67ebfa406d5de115bf51e74b8bd5b58e0d80bd7cb7aebc11155"},
+ {file = "ibm_watsonx_ai-1.1.14.tar.gz", hash = "sha256:746d838370b5c07e2591082530ff8ed232a2ef01a530b214da32367483deafa8"},
]
[package.dependencies]
@@ -415,7 +415,7 @@ fl-crypto = ["pyhelayers (==1.5.0.3)"]
fl-crypto-rt24-1 = ["pyhelayers (==1.5.3.1)"]
fl-rt23-1-py3-10 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.1.1)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.23.5)", "pandas (==1.5.3)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.1.1)", "scipy (==1.10.1)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.12.0)", "torch (==2.0.1)", "websockets (==10.1)"]
fl-rt24-1-py3-11 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.3.2)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.26.4)", "pandas (==2.1.4)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.3.0)", "scipy (==1.11.4)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.14.1)", "torch (==2.1.2)", "websockets (==10.1)"]
-rag = ["beautifulsoup4 (==4.12.3)", "grpcio (>=1.60.0)", "langchain (==0.2.15)", "langchain-chroma (==0.1.1)", "langchain-core (==0.2.37)", "langchain-elasticsearch (==0.2.2)", "langchain-ibm", "langchain-milvus (==0.1.1)", "pypdf (==4.2.0)", "python-docx (==1.1.2)"]
+rag = ["beautifulsoup4 (==4.12.3)", "grpcio (>=1.60.0)", "langchain (>=0.2.15,<0.3)", "langchain-chroma (==0.1.1)", "langchain-community (>=0.2.4,<0.3)", "langchain-core (>=0.2.37,<0.3)", "langchain-elasticsearch (==0.2.2)", "langchain-ibm", "langchain-milvus (==0.1.1)", "pypdf (==4.2.0)", "python-docx (==1.1.2)"]
[[package]]
name = "idna"
@@ -526,18 +526,40 @@ typing-extensions = ">=4.7"
type = "git"
url = "https://github.com/langchain-ai/langchain.git"
reference = "HEAD"
-resolved_reference = "7a07196df683582c783edf164bfb6fe813135169"
+resolved_reference = "16f5fdb38b307ee3f3b5065f671b28c90661b698"
subdirectory = "libs/core"
+[[package]]
+name = "langchain-standard-tests"
+version = "0.1.1"
+description = "Standard tests for LangChain implementations"
+optional = false
+python-versions = ">=3.9,<4.0"
+files = []
+develop = false
+
+[package.dependencies]
+httpx = "^0.27.0"
+langchain-core = "^0.3.0"
+pytest = ">=7,<9"
+syrupy = "^4"
+
+[package.source]
+type = "git"
+url = "https://github.com/langchain-ai/langchain.git"
+reference = "HEAD"
+resolved_reference = "16f5fdb38b307ee3f3b5065f671b28c90661b698"
+subdirectory = "libs/standard-tests"
+
[[package]]
name = "langsmith"
-version = "0.1.131"
+version = "0.1.132"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
- {file = "langsmith-0.1.131-py3-none-any.whl", hash = "sha256:80c106b1c42307195cc0bb3a596472c41ef91b79d15bcee9938307800336c563"},
- {file = "langsmith-0.1.131.tar.gz", hash = "sha256:626101a3bf3ca481e5110d5155ace8aa066e4e9cc2fa7d96c8290ade0fbff797"},
+ {file = "langsmith-0.1.132-py3-none-any.whl", hash = "sha256:2320894203675c1c292b818cbecf68b69e47a9f7814d4e950237d1faaafd5dee"},
+ {file = "langsmith-0.1.132.tar.gz", hash = "sha256:007b8fac469138abdba89db931900a26c5d316640e27ff4660d28c92a766aae1"},
]
[package.dependencies]
@@ -1130,13 +1152,13 @@ files = [
[[package]]
name = "requests"
-version = "2.32.2"
+version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
files = [
- {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
- {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
+ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+ {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
@@ -1304,18 +1326,17 @@ files = [
[[package]]
name = "urllib3"
-version = "2.2.3"
+version = "2.1.0"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=3.8"
files = [
- {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"},
- {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"},
+ {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"},
+ {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
@@ -1383,4 +1404,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
-content-hash = "84b71aacace41099670cb831f589df0a2a9b27aef7dc6a88792873f5abe6265e"
+content-hash = "33a6cea354dc8afa5d40c16c29e339a844adf6649610266d09340f46c696fe8b"
diff --git a/libs/ibm/pyproject.toml b/libs/ibm/pyproject.toml
index c41c51f..ca92174 100644
--- a/libs/ibm/pyproject.toml
+++ b/libs/ibm/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-ibm"
-version = "0.2.2"
+version = "0.3.0"
description = "An integration package connecting IBM watsonx.ai and LangChain"
authors = ["IBM"]
readme = "README.md"
@@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
langchain-core = ">=0.3.0,<0.4"
-ibm-watsonx-ai = "^1.1.9"
+ibm-watsonx-ai = "^1.1.14"
[tool.poetry.group.test]
optional = true
@@ -27,6 +27,7 @@ pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
pytest-cov = "^4.1.0"
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
+langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests" }
[tool.poetry.group.codespell]
optional = true
diff --git a/libs/ibm/tests/integration_tests/test_chat_models.py b/libs/ibm/tests/integration_tests/test_chat_models.py
index d8e284b..4372fbe 100644
--- a/libs/ibm/tests/integration_tests/test_chat_models.py
+++ b/libs/ibm/tests/integration_tests/test_chat_models.py
@@ -1,17 +1,20 @@
import json
import os
-from typing import Any
+from typing import Any, Optional
import pytest
+from ibm_watsonx_ai.foundation_models.schema import TextChatParameters # type: ignore
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames # type: ignore
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
+ BaseMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.tools import tool
from pydantic import BaseModel
from langchain_ibm import ChatWatsonx
@@ -20,13 +23,15 @@
WX_PROJECT_ID = os.environ.get("WATSONX_PROJECT_ID", "")
URL = "https://us-south.ml.cloud.ibm.com"
-MODEL_ID = "mistralai/mixtral-8x7b-instruct-v01"
+
+MODEL_ID = "ibm/granite-34b-code-instruct"
+MODEL_ID_TOOL = "mistralai/mistral-large"
def test_01_generate_chat() -> None:
chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type]
messages = [
- ("system", "You are a helpful assistant that translates English to French."),
+ ("user", "You are a helpful assistant that translates English to French."),
(
"human",
"Translate this sentence from English to French. I love programming.",
@@ -37,13 +42,24 @@ def test_01_generate_chat() -> None:
assert response.content
-def test_01a_generate_chat_with_invoke_params() -> None:
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
+def test_01a_generate_chat_with_params_as_dict_in_invoke() -> None:
+ params = {"max_tokens": 10}
+ chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type]
+ messages = [
+ ("system", "You are a helpful assistant that translates English to French."),
+ (
+ "human",
+ "Translate this sentence from English to French. I love programming.",
+ ),
+ ]
+ response = chat.invoke(messages, params=params)
+ assert response
+ assert response.content
+ print(response.content)
- params = {
- GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
- GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
- }
+
+def test_01a_generate_chat_with_params_as_object_in_invoke() -> None:
+ params = TextChatParameters(max_tokens=10)
chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type]
messages = [
("system", "You are a helpful assistant that translates English to French."),
@@ -55,11 +71,31 @@ def test_01a_generate_chat_with_invoke_params() -> None:
response = chat.invoke(messages, params=params)
assert response
assert response.content
+ print(response.content)
-def test_01b_generate_chat_with_invoke_params() -> None:
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
+def test_01a_generate_chat_with_params_as_object_in_constructor() -> None:
+ params = TextChatParameters(max_tokens=10)
+ chat = ChatWatsonx(
+ model_id=MODEL_ID,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ params=params,
+ )
+ messages = [
+ ("system", "You are a helpful assistant that translates English to French."),
+ (
+ "human",
+ "Translate this sentence from English to French. I love programming.",
+ ),
+ ]
+ response = chat.invoke(messages)
+ assert response
+ assert response.content
+ print(response.content)
+
+def test_01b_generate_chat_with_invoke_params() -> None:
parameters_1 = {
GenTextParamsMetaNames.DECODING_METHOD: "sample",
GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
@@ -124,12 +160,7 @@ def test_05a_invoke_chat_with_streaming() -> None:
def test_05_generate_chat_with_stream_with_param() -> None:
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
-
- params = {
- GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
- GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
- }
+ params = TextChatParameters(max_tokens=10)
chat = ChatWatsonx(
model_id=MODEL_ID,
url=URL, # type: ignore[arg-type]
@@ -142,12 +173,7 @@ def test_05_generate_chat_with_stream_with_param() -> None:
def test_05_generate_chat_with_stream_with_param_v2() -> None:
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
-
- params = {
- GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
- GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
- }
+ params = TextChatParameters(max_tokens=10)
chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type]
response = chat.stream("What's the weather in san francisco", params=params)
for chunk in response:
@@ -232,16 +258,150 @@ def test_11_chaining_with_params() -> None:
assert response.content
-def test_20_tool_choice() -> None:
- """Test that tool choice is respected."""
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
+def test_20_bind_tools() -> None:
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
- params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500}
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather report for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ },
+ ]
+
+ llm_with_tools = chat.bind_tools(tools=tools)
+
+ response = llm_with_tools.invoke("what's the weather in san francisco, ca")
+ assert isinstance(response, AIMessage)
+ assert not response.content
+ assert isinstance(response.tool_calls, list)
+ assert len(response.tool_calls) == 1
+ tool_call = response.tool_calls[0]
+ assert tool_call["name"] == "get_weather"
+ assert isinstance(tool_call["args"], dict)
+ assert "location" in tool_call["args"]
+
+
+def test_21a_bind_tools_tool_choice_auto() -> None:
chat = ChatWatsonx(
- model_id=MODEL_ID,
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather report for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ },
+ ]
+
+ llm_with_tools = chat.bind_tools(tools=tools, tool_choice="auto")
+
+ response = llm_with_tools.invoke("what's the weather in san francisco, ca")
+ assert isinstance(response, AIMessage)
+ assert not response.content
+ assert isinstance(response.tool_calls, list)
+ assert len(response.tool_calls) == 1
+ tool_call = response.tool_calls[0]
+ assert tool_call["name"] == "get_weather"
+ assert isinstance(tool_call["args"], dict)
+ assert "location" in tool_call["args"]
+
+
+@pytest.mark.skip(reason="Not supported yet")
+def test_21b_bind_tools_tool_choice_none() -> None:
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather report for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ },
+ ]
+
+ llm_with_tools = chat.bind_tools(tools=tools, tool_choice="none")
+
+ response = llm_with_tools.invoke("what's the weather in san francisco, ca")
+ assert isinstance(response, AIMessage)
+ assert not response.content
+ assert isinstance(response.tool_calls, list)
+ assert len(response.tool_calls) == 1
+ tool_call = response.tool_calls[0]
+ assert tool_call["name"] == "get_weather"
+ assert isinstance(tool_call["args"], dict)
+ assert "location" in tool_call["args"]
+
+
+@pytest.mark.skip(reason="Not supported yet")
+def test_21c_bind_tools_tool_choice_required() -> None:
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather report for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ },
+ ]
+
+ llm_with_tools = chat.bind_tools(tools=tools, tool_choice="required")
+
+ response = llm_with_tools.invoke("what's the weather in san francisco, ca")
+ assert isinstance(response, AIMessage)
+ assert not response.content
+ assert isinstance(response.tool_calls, list)
+ assert len(response.tool_calls) == 1
+ tool_call = response.tool_calls[0]
+ assert tool_call["name"] == "get_weather"
+ assert isinstance(tool_call["args"], dict)
+ assert "location" in tool_call["args"]
+
+
+def test_22a_bind_tools_tool_choice_as_class() -> None:
+ """Test that tool choice is respected."""
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
- params=params,
)
class Person(BaseModel):
@@ -262,23 +422,21 @@ class Person(BaseModel):
}
-def test_21_tool_choice_bool() -> None:
+def test_22b_bind_tools_tool_choice_as_dict() -> None:
"""Test that tool choice is respected just passing in True."""
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
-
- params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500}
chat = ChatWatsonx(
- model_id=MODEL_ID,
+ model_id=MODEL_ID_TOOL,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
- params=params,
)
class Person(BaseModel):
name: str
age: int
- with_tool = chat.bind_tools([Person], tool_choice=True)
+ tool_choice = {"type": "function", "function": {"name": "Person"}}
+
+ with_tool = chat.bind_tools([Person], tool_choice=tool_choice)
result = with_tool.invoke("Erick, 27 years old")
assert isinstance(result, AIMessage)
@@ -291,18 +449,13 @@ class Person(BaseModel):
}
-def test_22_tool_invoke() -> None:
+def test_23a_bind_tools_list_tool_choice_dict() -> None:
"""Test that tool choice is respected just passing in True."""
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
-
- params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500}
chat = ChatWatsonx(
- model_id=MODEL_ID,
+ model_id=MODEL_ID_TOOL,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
- params=params,
)
- from langchain_core.tools import tool
@tool
def add(a: int, b: int) -> int:
@@ -321,42 +474,137 @@ def get_word_length(word: str) -> int:
tools = [add, multiply, get_word_length]
- chat_with_tools = chat.bind_tools(tools)
+ tool_choice = {
+ "type": "function",
+ "function": {
+ "name": "add",
+ },
+ }
+
+ chat_with_tools = chat.bind_tools(tools, tool_choice=tool_choice)
- query = "What is 3 + 12? What is 3 + 10?"
+ query = "What is 3 + 12? "
resp = chat_with_tools.invoke(query)
assert resp.content == ""
- query = "Who was the famous painter from Italy?"
+
+def test_23_bind_tools_list_tool_choice_auto() -> None:
+ """Test that tool_choice="auto" lets the model pick the appropriate tool (or no tool)."""
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+
+ @tool
+ def add(a: int, b: int) -> int:
+ """Adds a and b."""
+ return a + b
+
+ @tool
+ def multiply(a: int, b: int) -> int:
+ """Multiplies a and b."""
+ return a * b
+
+ @tool
+ def get_word_length(word: str) -> int:
+ """Get word length."""
+ return len(word)
+
+ tools = [add, multiply, get_word_length]
+ chat_with_tools = chat.bind_tools(tools, tool_choice="auto")
+
+ query = "What is 3 + 12? "
resp = chat_with_tools.invoke(query)
+ assert resp.content == ""
+ assert len(resp.tool_calls) == 1 # type: ignore
+ tool_call = resp.tool_calls[0] # type: ignore
+ assert tool_call["name"] == "add"
+
+ query = "What is 3 * 12? "
+ resp = chat_with_tools.invoke(query)
+ assert resp.content == ""
+ assert len(resp.tool_calls) == 1 # type: ignore
+ tool_call = resp.tool_calls[0] # type: ignore
+ assert tool_call["name"] == "multiply"
+ query = "Who was the famous painter from Italy?"
+ resp = chat_with_tools.invoke(query)
assert resp.content
+ assert len(resp.tool_calls) == 0 # type: ignore
+
+
+def test_json_mode() -> None:
+ llm = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+ response = llm.invoke(
+ "Return this as json: {'a': 1}",
+ params={"response_format": {"type": "json_object"}},
+ )
+ assert isinstance(response.content, str)
+ assert json.loads(response.content) == {"a": 1}
+
+ # Test streaming
+ full: Optional[BaseMessageChunk] = None
+ for chunk in llm.stream(
+ "Return this as json: {'a': 1}",
+ params={"response_format": {"type": "json_object"}},
+ ):
+ full = chunk if full is None else full + chunk
+ assert isinstance(full, AIMessageChunk)
+ assert isinstance(full.content, str)
+ assert json.loads(full.content) == {"a": 1}
+
+
+async def test_json_mode_async() -> None:
+ llm = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+ response = await llm.ainvoke(
+ "Return this as json: {'a': 1}",
+ params={"response_format": {"type": "json_object"}},
+ )
+ assert isinstance(response.content, str)
+ assert json.loads(response.content) == {"a": 1}
+
+ # Test streaming
+ full: Optional[BaseMessageChunk] = None
+ async for chunk in llm.astream(
+ "Return this as json: {'a': 1}",
+ params={"response_format": {"type": "json_object"}},
+ ):
+ full = chunk if full is None else full + chunk
+ assert isinstance(full, AIMessageChunk)
+ assert isinstance(full.content, str)
+ assert json.loads(full.content) == {"a": 1}
@pytest.mark.skip(reason="Not implemented")
def test_streaming_tool_call() -> None:
- from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
-
- params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500}
chat = ChatWatsonx(
- model_id=MODEL_ID,
+ model_id=MODEL_ID_TOOL,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
- params=params,
)
class Person(BaseModel):
name: str
age: int
- tool_llm = chat.bind_tools([Person])
+ tool_choice = {"type": "function", "function": {"name": "Person"}}
+
+ tool_llm = chat.bind_tools([Person], tool_choice=tool_choice)
- # where it calls the tool
- strm = tool_llm.stream("Erick, 27 years old")
+ stream_response = tool_llm.stream("Erick, 27 years old")
additional_kwargs = None
- for chunk in strm:
+ for chunk in stream_response:
assert isinstance(chunk, AIMessageChunk)
assert chunk.content == ""
additional_kwargs = chunk.additional_kwargs
@@ -384,3 +632,51 @@ class Person(BaseModel):
acc = chunk if acc is None else acc + chunk
assert acc.content != ""
assert "tool_calls" not in acc.additional_kwargs
+
+
+def test_structured_output() -> None:
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+ schema = {
+ "title": "AnswerWithJustification",
+ "description": (
+ "An answer to the user question along with justification for the answer."
+ ),
+ "type": "object",
+ "properties": {
+ "answer": {"title": "Answer", "type": "string"},
+ "justification": {"title": "Justification", "type": "string"},
+ },
+ "required": ["answer", "justification"],
+ }
+ structured_llm = chat.with_structured_output(schema)
+ result = structured_llm.invoke(
+ "What weighs more a pound of bricks or a pound of feathers"
+ )
+ assert isinstance(result, dict)
+
+
+@pytest.mark.skip(reason="Not implemented")
+def test_streaming_structured_output() -> None:
+ chat = ChatWatsonx(
+ model_id=MODEL_ID_TOOL,
+ url=URL, # type: ignore[arg-type]
+ project_id=WX_PROJECT_ID,
+ )
+
+ class Person(BaseModel):
+ name: str
+ age: int
+
+ structured_llm = chat.with_structured_output(Person)
+ strm_response = structured_llm.stream("Erick, 27 years old")
+ chunk_num = 0
+ for chunk in strm_response:
+ assert chunk_num == 0, "should only have one chunk with model"
+ assert isinstance(chunk, Person)
+ assert chunk.name == "Erick"
+ assert chunk.age == 27
+ chunk_num += 1