diff --git a/libs/ibm/langchain_ibm/chat_models.py b/libs/ibm/langchain_ibm/chat_models.py index 783b3c3..bfad13c 100644 --- a/libs/ibm/langchain_ibm/chat_models.py +++ b/libs/ibm/langchain_ibm/chat_models.py @@ -2,8 +2,6 @@ import json import logging -import warnings -from datetime import datetime from operator import itemgetter from typing import ( Any, @@ -24,7 +22,10 @@ from ibm_watsonx_ai import APIClient, Credentials # type: ignore from ibm_watsonx_ai.foundation_models import ModelInference # type: ignore -from langchain_core._api import LangChainDeprecationWarning +from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore + BaseSchema, + TextChatParameters, +) from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( @@ -49,22 +50,25 @@ ToolCall, ToolMessage, ToolMessageChunk, - convert_to_messages, ) +from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.prompt_values import ChatPromptValue from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import ( convert_to_openai_function, convert_to_openai_tool, ) +from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.utils import secret_from_env from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self @@ -85,48 +89,53 @@ def _convert_dict_to_message(_dict: Mapping[str, Any], call_id: str) -> BaseMess The LangChain message. 
""" role = _dict.get("role") + name = _dict.get("name") + id_ = call_id if role == "user": - return HumanMessage(content=_dict.get("generated_text", "")) - else: + return HumanMessage(content=_dict.get("content", ""), id=id_, name=name) + elif role == "assistant": + content = _dict.get("content", "") or "" additional_kwargs: Dict = {} + if function_call := _dict.get("function_call"): + additional_kwargs["function_call"] = dict(function_call) tool_calls = [] - invalid_tool_calls: List[InvalidToolCall] = [] - content = "" - - raw_tool_calls = _dict.get("generated_text", "") - - if "json" in raw_tool_calls: - try: - split_raw_tool_calls = raw_tool_calls.split("\n\n") - for raw_tool_call in split_raw_tool_calls: - if "json" in raw_tool_call: - json_parts = JsonOutputParser().parse(raw_tool_call) - - if json_parts["function"]["name"] == "Final Answer": - content = json_parts["function"]["arguments"]["output"] - break - - additional_kwargs["tool_calls"] = json_parts - - parsed = { - "name": json_parts["function"]["name"] or "", - "args": json_parts["function"]["arguments"] or {}, - "id": call_id, - } - tool_calls.append(parsed) - - except: # noqa: E722 - content = _dict.get("generated_text", "") or "" - - else: - content = _dict.get("generated_text", "") or "" - + invalid_tool_calls = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append( + make_invalid_tool_call(raw_tool_call, str(e)) + ) return AIMessage( content=content, additional_kwargs=additional_kwargs, + name=name, + id=id_, tool_calls=tool_calls, invalid_tool_calls=invalid_tool_calls, ) + elif role == "system": + return SystemMessage(content=_dict.get("content", ""), name=name, id=id_) + elif role == "function": + return FunctionMessage( + content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_ + ) + elif role == "tool": + additional_kwargs = {} + if "name" in _dict: + additional_kwargs["name"] = _dict["name"] + return ToolMessage( + content=_dict.get("content", ""), + tool_call_id=cast(str, _dict.get("tool_call_id")), + additional_kwargs=additional_kwargs, + name=name, + id=id_, + ) + else: + return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) # type: ignore[arg-type] def _format_message_content(content: Any) -> Any: @@ -149,30 +158,6 @@ def _format_message_content(content: Any) -> Any: return formatted_content -def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: - return { - "type": "function", - "id": tool_call["id"], - "function": { - "name": tool_call["name"], - "arguments": json.dumps(tool_call["args"]), - }, - } - - -def _lc_invalid_tool_call_to_openai_tool_call( - invalid_tool_call: InvalidToolCall, -) -> dict: - return { - "type": "function", - "id": invalid_tool_call["id"], - "function": { - "name": invalid_tool_call["name"], - "arguments": invalid_tool_call["args"], - }, - } - - def _convert_message_to_dict(message: BaseMessage) -> dict: """Convert a LangChain message to a dictionary. 
@@ -197,9 +182,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict["function_call"] = message.additional_kwargs["function_call"] if message.tool_calls or message.invalid_tool_calls: message_dict["tool_calls"] = [ - _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls + _lc_tool_call_to_watsonx_tool_call(tc) for tc in message.tool_calls ] + [ - _lc_invalid_tool_call_to_openai_tool_call(tc) + _lc_invalid_tool_call_to_watsonx_tool_call(tc) for tc in message.invalid_tool_calls ] elif "tool_calls" in message.additional_kwargs: @@ -213,14 +198,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: pass # If tool calls present, content null value should be None not empty string. if "function_call" in message_dict or "tool_calls" in message_dict: - message_dict["content"] = message_dict["content"] or "" - message_dict["tool_calls"][0]["name"] = message_dict["tool_calls"][0][ - "function" - ]["name"] - message_dict["tool_calls"][0]["args"] = json.loads( - message_dict["tool_calls"][0]["function"]["arguments"] - ) - + message_dict["content"] = message_dict["content"] or None elif isinstance(message, SystemMessage): message_dict["role"] = "system" elif isinstance(message, FunctionMessage): @@ -237,11 +215,14 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_delta_to_message_chunk( - _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] + _dict: Mapping[str, Any], + default_class: Type[BaseMessageChunk], + call_id: str, + finish_reason: str, ) -> BaseMessageChunk: - id_ = "sample_id" + id_ = call_id role = cast(str, _dict.get("role")) - content = cast(str, _dict.get("generated_text") or "") + content = cast(str, _dict.get("content") or "") additional_kwargs: Dict = {} if _dict.get("function_call"): function_call = dict(_dict["function_call"]) @@ -253,12 +234,14 @@ def _convert_delta_to_message_chunk( additional_kwargs["tool_calls"] = raw_tool_calls try: tool_call_chunks = [ - { - "name": rtc["function"].get("name"), - "args": rtc["function"].get("arguments"), - "id": rtc.get("id"), - "index": rtc["index"], - } + tool_call_chunk( + name=rtc["function"].get("name") + if finish_reason is not None + else None, + args=rtc["function"].get("arguments"), + id=call_id if finish_reason is not None else None, + index=rtc["index"], + ) for rtc in raw_tool_calls ] except KeyError: @@ -287,6 +270,58 @@ def _convert_delta_to_message_chunk( return default_class(content=content, id=id_) # type: ignore +def _convert_chunk_to_generation_chunk( + chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict] +) -> Optional[ChatGenerationChunk]: + token_usage = chunk.get("usage") + choices = chunk.get("choices", []) + + usage_metadata: Optional[UsageMetadata] = ( + UsageMetadata( + input_tokens=token_usage.get("prompt_tokens", 0), + output_tokens=token_usage.get("completion_tokens", 0), + total_tokens=token_usage.get("total_tokens", 0), + ) + if token_usage + else None + ) + + if len(choices) == 0: + # logprobs is implicitly None + generation_chunk = ChatGenerationChunk( + message=default_chunk_class(content="", usage_metadata=usage_metadata) + ) + return generation_chunk + + choice = choices[0] + if choice["delta"] is None: + return None + + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class, chunk["id"], choice["finish_reason"] + ) + generation_info = {**base_generation_info} if base_generation_info else {} + + if finish_reason := choice.get("finish_reason"): + 
generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + + if usage_metadata and isinstance(message_chunk, AIMessageChunk): + message_chunk.usage_metadata = usage_metadata + + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + return generation_chunk + + class _FunctionCall(TypedDict): name: str @@ -378,7 +413,7 @@ class ChatWatsonx(BaseChatModel): version: Optional[SecretStr] = None """Version of the CPD instance.""" - params: Optional[dict] = None + params: Optional[Union[dict, TextChatParameters]] = None """Model parameters to use during request generation.""" verify: Union[str, bool, None] = None @@ -523,77 +558,9 @@ def _generate( return generate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop, **kwargs) - if message_dicts[-1].get("role") == "tool": - chat_prompt = ( - "User: Please summarize given sentences into " - "JSON containing Final Answer: '" - ) - for message in message_dicts: - if message["content"]: - chat_prompt += message["content"] + "\n" - chat_prompt += "'" - else: - chat_prompt = self._create_chat_prompt(message_dicts) - - tools = kwargs.get("tools") - - if tools: - chat_prompt = f""" -You are Mixtral Chat function calling, an AI language model developed by Mistral AI. -You are a cautious assistant. You carefully follow instructions. You are helpful and -harmless and you follow ethical guidelines and promote positive behavior. Here are a -few of the tools available to you: -[AVAILABLE_TOOLS] -{json.dumps(tools[0], indent=2)} -[/AVAILABLE_TOOLS] -To use these tools you must always respond in JSON format containing `"type"` and -`"function"` key-value pairs. Also `"function"` key-value pair always containing -`"name"` and `"arguments"` key-value pairs. For example, to answer the question, -"What is a length of word think?" you must use the get_word_length tool like so: - -```json -{{ - "type": "function", - "function": {{ - "name": "get_word_length", - "arguments": {{ - "word": "think" - }} - }} -}} -``` - - -Remember, even when answering to the user, you must still use this JSON format! -If you'd like to ask how the user is doing you must write: - -```json -{{ - "type": "function", - "function": {{ - "name": "Final Answer", - "arguments": {{ - "output": "How are you today?" 
- }} - }} -}} -``` - - -Remember to end your response with '' - -{chat_prompt} -(reminder to respond in a JSON blob no matter what and use tools only if necessary)""" - - params = params | {"stop_sequences": [""]} - - if "tools" in kwargs: - del kwargs["tools"] - if "tool_choice" in kwargs: - del kwargs["tool_choice"] - - response = self.watsonx_model.generate( - prompt=chat_prompt, **(kwargs | {"params": params}) + + response = self.watsonx_model.chat( + messages=message_dicts, **(kwargs | {"params": params}) ) return self._create_chat_result(response) @@ -605,140 +572,54 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop, **kwargs) - if message_dicts[-1].get("role") == "tool": - chat_prompt = ( - "User: Please summarize given sentences into JSON " - "containing Final Answer: '" - ) - for message in message_dicts: - if message["content"]: - chat_prompt += message["content"] + "\n" - chat_prompt += "'" - else: - chat_prompt = self._create_chat_prompt(message_dicts) - - tools = kwargs.get("tools") - - if tools: - chat_prompt = f""" -You are Mixtral Chat function calling, an AI language model developed by Mistral AI. -You are a cautious assistant. You carefully follow instructions. You are helpful and -harmless and you follow ethical guidelines and promote positive behavior. Here are a -few of the tools available to you: -[AVAILABLE_TOOLS] -{json.dumps(tools[0], indent=2)} -[/AVAILABLE_TOOLS] -To use these tools you must always respond in JSON format containing `"type"` and -`"function"` key-value pairs. Also `"function"` key-value pair always containing -`"name"` and `"arguments"` key-value pairs. For example, to answer the question, -"What is a length of word think?" you must use the get_word_length tool like so: - -```json -{{ - "type": "function", - "function": {{ - "name": "get_word_length", - "arguments": {{ - "word": "think" - }} - }} -}} -``` - - -Remember, even when answering to the user, you must still use this JSON format! -If you'd like to ask how the user is doing you must write: - -```json -{{ - "type": "function", - "function": {{ - "name": "Final Answer", - "arguments": {{ - "output": "How are you today?" 
- }} - }} -}} -``` - - -Remember to end your response with '' - -{chat_prompt[:-5]} -(reminder to respond in a JSON blob no matter what and use tools only if necessary)""" - - params = params | {"stop_sequences": [""]} - - if "tools" in kwargs: - del kwargs["tools"] - if "tool_choice" in kwargs: - del kwargs["tool_choice"] default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk + base_generation_info: dict = {} + + is_first_chunk = True - for chunk in self.watsonx_model.generate_text_stream( - prompt=chat_prompt, raw_response=True, **(kwargs | {"params": params}) + for chunk in self.watsonx_model.chat_stream( + messages=message_dicts, **(kwargs | {"params": params}) ): if not isinstance(chunk, dict): - chunk = chunk.dict() - if len(chunk["results"]) == 0: - continue - choice = chunk["results"][0] - - message_chunk = _convert_delta_to_message_chunk(choice, default_chunk_class) - generation_info = {} - if (finish_reason := choice.get("stop_reason")) != "not_finished": - generation_info["finish_reason"] = finish_reason - chunk = ChatGenerationChunk( - message=message_chunk, generation_info=generation_info or None + chunk = chunk.model_dump() + generation_chunk = _convert_chunk_to_generation_chunk( + chunk, + default_chunk_class, + base_generation_info if is_first_chunk else {}, ) + if generation_chunk is None: + continue + default_chunk_class = generation_chunk.message.__class__ + logprobs = (generation_chunk.generation_info or {}).get("logprobs") if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) - - yield chunk - - def _create_chat_prompt(self, messages: List[Dict[str, Any]]) -> str: - prompt = "" - - if self.model_id in ["ibm/granite-13b-chat-v1", "ibm/granite-13b-chat-v2"]: - for message in messages: - if message["role"] == "system": - prompt += "<|system|>\n" + message["content"] + "\n\n" - elif message["role"] == "assistant": - prompt += "<|assistant|>\n" + message["content"] + "\n\n" - elif message["role"] == "function": - prompt += "<|function|>\n" + message["content"] + "\n\n" - elif message["role"] == "tool": - prompt += "<|tool|>\n" + message["content"] + "\n\n" - else: - prompt += "<|user|>:\n" + message["content"] + "\n\n" - - prompt += "<|assistant|>\n" - - elif self.model_id in [ - "meta-llama/llama-2-13b-chat", - "meta-llama/llama-2-70b-chat", - ]: - for message in messages: - if message["role"] == "system": - prompt += "[INST] <>\n" + message["content"] + "<>\n\n" - elif message["role"] == "assistant": - prompt += message["content"] + "\n[INST]\n\n" - else: - prompt += message["content"] + "\n[/INST]\n" - - else: - prompt = ChatPromptValue( - messages=convert_to_messages(messages) + [AIMessage(content="")] - ).to_string() - - return prompt + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk, logprobs=logprobs + ) + is_first_chunk = False + yield generation_chunk def _create_message_dicts( self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params = {**self.params} if self.params else {} - params = params | {**kwargs.get("params", {})} + params = ( + { + **( + self.params.to_dict() + if isinstance(self.params, BaseSchema) + else self.params + ) + } + if self.params + else {} + ) + params = params | { + **( + kwargs.get("params", {}).to_dict() + if isinstance(kwargs.get("params", {}), BaseSchema) + else kwargs.get("params", {}) + ) + } if stop is not None: if params and "stop_sequences" in params: raise ValueError( @@ -748,47 +629,41 @@ def 
_create_message_dicts( message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params - def _create_chat_result(self, response: Union[dict]) -> ChatResult: + def _create_chat_result( + self, response: dict, generation_info: Optional[Dict] = None + ) -> ChatResult: generations = [] - sum_of_total_generated_tokens = 0 - sum_of_total_input_tokens = 0 - call_id = "" - date_string = response.get("created_at") - if date_string: - date_object = datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ") - call_id = str(date_object.timestamp()) if response.get("error"): raise ValueError(response.get("error")) - for res in response["results"]: - message = _convert_dict_to_message(res, call_id) - generation_info = dict(finish_reason=res.get("stop_reason")) - if "generated_token_count" in res: - sum_of_total_generated_tokens += res["generated_token_count"] - if "input_token_count" in res: - sum_of_total_input_tokens += res["input_token_count"] - total_token = sum_of_total_generated_tokens + sum_of_total_input_tokens - if total_token and isinstance(message, AIMessage): + token_usage = response.get("usage", {}) + + for res in response["choices"]: + message = _convert_dict_to_message(res["message"], response["id"]) + + if token_usage and isinstance(message, AIMessage): message.usage_metadata = { - "input_tokens": sum_of_total_input_tokens, - "output_tokens": sum_of_total_generated_tokens, - "total_tokens": total_token, + "input_tokens": token_usage.get("prompt_tokens", 0), + "output_tokens": token_usage.get("completion_tokens", 0), + "total_tokens": token_usage.get("total_tokens", 0), } - gen = ChatGeneration( - message=message, - generation_info=generation_info, + generation_info = generation_info or {} + generation_info["finish_reason"] = ( + res.get("finish_reason") + if res.get("finish_reason") is not None + else generation_info.get("finish_reason") ) + if "logprobs" in res: + generation_info["logprobs"] = res["logprobs"] + gen = ChatGeneration(message=message, generation_info=generation_info) generations.append(gen) - token_usage = { - "generated_token_count": sum_of_total_generated_tokens, - "input_token_count": sum_of_total_input_tokens, - } llm_output = { "token_usage": token_usage, - "model_name": self.model_id, + "model_name": response.get("model_id", self.model_id), "system_fingerprint": response.get("system_fingerprint", ""), } + return ChatResult(generations=generations, llm_output=llm_output) def bind_functions( @@ -845,7 +720,11 @@ def bind_functions( def bind_tools( self, - tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -855,30 +734,63 @@ def bind_tools( Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic models, callables, and BaseTools will be automatically converted to their schema dictionary representation. - **kwargs: Any additional parameters to pass to the - :class:`~langchain.runnable.Runnable` constructor. 
- """ - bind_tools_supported_models = ["mistralai/mixtral-8x7b-instruct-v01"] - if self.model_id not in bind_tools_supported_models: - raise Warning( - f"bind_tools() method for ChatWatsonx support only " - f"following models: {bind_tools_supported_models}" - ) - else: - warnings.warn( - "The `mistralai/mixtral-8x7b-instruct-v01` model, which " - "supports the `bind_tools()` method, is deprecated and will be " - "removed in version 0.3.x of `langchain_ibm`.", - LangChainDeprecationWarning, - ) - + tool_choice: Which tool to require the model to call. + Options are: + - str of the form ``"<>"``: calls <> tool. + - ``"auto"``: automatically selects a tool (including no tool). + - ``"none"``: does not call a tool. + - ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called. + - dict of the form ``{"type": "function", "function": {"name": <>}}``: calls <> tool. + - ``False`` or ``None``: no effect, default OpenAI behavior. + + kwargs: Any additional parameters are passed directly to + ``self.bind(**kwargs)``. + """ # noqa: E501 formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + if tool_choice: + if isinstance(tool_choice, str): + # tool_choice is a tool/function name + if tool_choice not in ("auto", "none", "any", "required"): + tool_choice = { + "type": "function", + "function": {"name": tool_choice}, + } + # We support 'any' since other models use this instead of 'required'. + if tool_choice == "any": + tool_choice = "required" + elif isinstance(tool_choice, bool): + tool_choice = "required" + elif isinstance(tool_choice, dict): + tool_names = [ + formatted_tool["function"]["name"] + for formatted_tool in formatted_tools + ] + if not any( + tool_name == tool_choice["function"]["name"] + for tool_name in tool_names + ): + raise ValueError( + f"Tool choice {tool_choice} was specified, but the only " + f"provided tools were {tool_names}." + ) + else: + raise ValueError( + f"Unrecognized tool_choice type. Expected str, bool or dict. " + f"Received: {tool_choice}" + ) + + if isinstance(tool_choice, str): + kwargs["tool_choice_option"] = tool_choice + else: + kwargs["tool_choice"] = tool_choice + else: + kwargs["tool_choice_option"] = "auto" return super().bind(tools=formatted_tools, **kwargs) def with_structured_output( self, - schema: Optional[Union[Dict, Type[BaseModel]]] = None, + schema: Optional[Union[Dict, Type]] = None, *, method: Literal["function_calling", "json_mode"] = "function_calling", include_raw: bool = False, @@ -1037,14 +949,15 @@ class AnswerWithJustification(BaseModel): """ # noqa: E501 if kwargs: raise ValueError(f"Received unsupported arguments {kwargs}") - is_pydantic_schema = _is_pydantic_class(schema) + is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) if method == "function_calling": if schema is None: raise ValueError( "schema must be specified when method is 'function_calling'. " "Received None." ) - llm = self.bind_tools([schema], tool_choice=True) + # specifying a tool. + llm = self.bind_tools([schema], tool_choice="auto") if is_pydantic_schema: output_parser: OutputParserLike = PydanticToolsParser( tools=[schema], # type: ignore[list-item] @@ -1062,12 +975,6 @@ class AnswerWithJustification(BaseModel): if is_pydantic_schema else JsonOutputParser() ) - else: - raise ValueError( - f"Unrecognized method argument. Expected one of 'function_calling' or " - f"'json_format'. 
Received: '{method}'" - ) - if include_raw: parser_assign = RunnablePassthrough.assign( parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None @@ -1083,3 +990,27 @@ class AnswerWithJustification(BaseModel): def _is_pydantic_class(obj: Any) -> bool: return isinstance(obj, type) and issubclass(obj, BaseModel) + + +def _lc_tool_call_to_watsonx_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + } + + +def _lc_invalid_tool_call_to_watsonx_tool_call( + invalid_tool_call: InvalidToolCall, +) -> dict: + return { + "type": "function", + "id": invalid_tool_call["id"], + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + }, + } diff --git a/libs/ibm/poetry.lock b/libs/ibm/poetry.lock index 6e1eaf1..6ef4ff9 100644 --- a/libs/ibm/poetry.lock +++ b/libs/ibm/poetry.lock @@ -345,57 +345,57 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "ibm-cos-sdk" -version = "2.13.6" +version = "2.13.5" description = "IBM SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "ibm-cos-sdk-2.13.6.tar.gz", hash = "sha256:171cf2ae4ab662a4b8ab58dcf4ac994b0577d6c92d78490295fd7704a83978f6"}, + {file = "ibm-cos-sdk-2.13.5.tar.gz", hash = "sha256:1aff7f9863ac9072a3db2f0053bec99478b26f3fb5fa797ce96a15bbb13cd40e"}, ] [package.dependencies] -ibm-cos-sdk-core = "2.13.6" -ibm-cos-sdk-s3transfer = "2.13.6" +ibm-cos-sdk-core = "2.13.5" +ibm-cos-sdk-s3transfer = "2.13.5" jmespath = ">=0.10.0,<=1.0.1" [[package]] name = "ibm-cos-sdk-core" -version = "2.13.6" +version = "2.13.5" description = "Low-level, data-driven core of IBM SDK for Python" optional = false python-versions = ">=3.6" files = [ - {file = "ibm-cos-sdk-core-2.13.6.tar.gz", hash = "sha256:dd41fb789eeb65546501afabcd50e78846ab4513b6ad4042e410b6a14ff88413"}, + {file = "ibm-cos-sdk-core-2.13.5.tar.gz", hash = "sha256:d3a99d8b06b3f8c00b1a9501f85538d592463e63ddf8cec32672ab5a0b107b83"}, ] [package.dependencies] jmespath = ">=0.10.0,<=1.0.1" python-dateutil = ">=2.9.0,<3.0.0" -requests = ">=2.32.0,<2.32.3" -urllib3 = ">=1.26.18,<3" +requests = ">=2.32.3,<3.0" +urllib3 = {version = ">=1.26.18,<2.2", markers = "python_version >= \"3.10\""} [[package]] name = "ibm-cos-sdk-s3transfer" -version = "2.13.6" +version = "2.13.5" description = "IBM S3 Transfer Manager" optional = false python-versions = ">=3.8" files = [ - {file = "ibm-cos-sdk-s3transfer-2.13.6.tar.gz", hash = "sha256:e0acce6f380c47d11e07c6765b684b4ababbf5c66cc0503bc246469a1e2b9790"}, + {file = "ibm-cos-sdk-s3transfer-2.13.5.tar.gz", hash = "sha256:9649b1f2201c6de96ff5a6b5a3686de3a809e6ef3b8b12c7c4f2f7ce72da7749"}, ] [package.dependencies] -ibm-cos-sdk-core = "2.13.6" +ibm-cos-sdk-core = "2.13.5" [[package]] name = "ibm-watsonx-ai" -version = "1.1.11" +version = "1.1.14" description = "IBM watsonx.ai API Client" optional = false python-versions = ">=3.10" files = [ - {file = "ibm_watsonx_ai-1.1.11-py3-none-any.whl", hash = "sha256:0b2c8b9abbe18acba3f987e2cb27cf0efcf0a7ba2373310afad6e3955b967a74"}, - {file = "ibm_watsonx_ai-1.1.11.tar.gz", hash = "sha256:47b25c927acacdcceb148cf0a2ebc75a965805bf89c818e9460dc9e67b895da8"}, + {file = "ibm_watsonx_ai-1.1.14-py3-none-any.whl", hash = "sha256:3b711dd4eb96a67ebfa406d5de115bf51e74b8bd5b58e0d80bd7cb7aebc11155"}, + {file = "ibm_watsonx_ai-1.1.14.tar.gz", hash = "sha256:746d838370b5c07e2591082530ff8ed232a2ef01a530b214da32367483deafa8"}, ] 
[package.dependencies] @@ -415,7 +415,7 @@ fl-crypto = ["pyhelayers (==1.5.0.3)"] fl-crypto-rt24-1 = ["pyhelayers (==1.5.3.1)"] fl-rt23-1-py3-10 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.1.1)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.23.5)", "pandas (==1.5.3)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.1.1)", "scipy (==1.10.1)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.12.0)", "torch (==2.0.1)", "websockets (==10.1)"] fl-rt24-1-py3-11 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.3.2)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.26.4)", "pandas (==2.1.4)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.3.0)", "scipy (==1.11.4)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.14.1)", "torch (==2.1.2)", "websockets (==10.1)"] -rag = ["beautifulsoup4 (==4.12.3)", "grpcio (>=1.60.0)", "langchain (==0.2.15)", "langchain-chroma (==0.1.1)", "langchain-core (==0.2.37)", "langchain-elasticsearch (==0.2.2)", "langchain-ibm", "langchain-milvus (==0.1.1)", "pypdf (==4.2.0)", "python-docx (==1.1.2)"] +rag = ["beautifulsoup4 (==4.12.3)", "grpcio (>=1.60.0)", "langchain (>=0.2.15,<0.3)", "langchain-chroma (==0.1.1)", "langchain-community (>=0.2.4,<0.3)", "langchain-core (>=0.2.37,<0.3)", "langchain-elasticsearch (==0.2.2)", "langchain-ibm", "langchain-milvus (==0.1.1)", "pypdf (==4.2.0)", "python-docx (==1.1.2)"] [[package]] name = "idna" @@ -526,18 +526,40 @@ typing-extensions = ">=4.7" type = "git" url = "https://github.com/langchain-ai/langchain.git" reference = "HEAD" -resolved_reference = "7a07196df683582c783edf164bfb6fe813135169" +resolved_reference = "16f5fdb38b307ee3f3b5065f671b28c90661b698" subdirectory = "libs/core" +[[package]] +name = "langchain-standard-tests" +version = "0.1.1" +description = "Standard tests for LangChain implementations" +optional = false +python-versions = ">=3.9,<4.0" +files = [] +develop = false + +[package.dependencies] +httpx = "^0.27.0" +langchain-core = "^0.3.0" +pytest = ">=7,<9" +syrupy = "^4" + +[package.source] +type = "git" +url = "https://github.com/langchain-ai/langchain.git" +reference = "HEAD" +resolved_reference = "16f5fdb38b307ee3f3b5065f671b28c90661b698" +subdirectory = "libs/standard-tests" + [[package]] name = "langsmith" -version = "0.1.131" +version = "0.1.132" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.131-py3-none-any.whl", hash = "sha256:80c106b1c42307195cc0bb3a596472c41ef91b79d15bcee9938307800336c563"}, - {file = "langsmith-0.1.131.tar.gz", hash = "sha256:626101a3bf3ca481e5110d5155ace8aa066e4e9cc2fa7d96c8290ade0fbff797"}, + {file = "langsmith-0.1.132-py3-none-any.whl", hash = "sha256:2320894203675c1c292b818cbecf68b69e47a9f7814d4e950237d1faaafd5dee"}, + {file = "langsmith-0.1.132.tar.gz", hash = "sha256:007b8fac469138abdba89db931900a26c5d316640e27ff4660d28c92a766aae1"}, ] [package.dependencies] @@ -1130,13 +1152,13 @@ files = [ [[package]] name = "requests" -version = "2.32.2" +version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" files = [ - {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, - {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -1304,18 +1326,17 @@ files = [ [[package]] name = "urllib3" -version = "2.2.3" +version = "2.1.0" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, - {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, + {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, + {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -1383,4 +1404,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "84b71aacace41099670cb831f589df0a2a9b27aef7dc6a88792873f5abe6265e" +content-hash = "33a6cea354dc8afa5d40c16c29e339a844adf6649610266d09340f46c696fe8b" diff --git a/libs/ibm/pyproject.toml b/libs/ibm/pyproject.toml index c41c51f..ca92174 100644 --- a/libs/ibm/pyproject.toml +++ b/libs/ibm/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-ibm" -version = "0.2.2" +version = "0.3.0" description = "An integration package connecting IBM watsonx.ai and LangChain" authors = ["IBM"] readme = "README.md" @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<4.0" langchain-core = ">=0.3.0,<0.4" -ibm-watsonx-ai = "^1.1.9" +ibm-watsonx-ai = "^1.1.14" [tool.poetry.group.test] optional = true @@ -27,6 +27,7 @@ pytest-watcher = "^0.3.4" pytest-asyncio = "^0.21.1" pytest-cov = "^4.1.0" langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } +langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests" } [tool.poetry.group.codespell] optional = true diff --git a/libs/ibm/tests/integration_tests/test_chat_models.py b/libs/ibm/tests/integration_tests/test_chat_models.py index 
d8e284b..4372fbe 100644 --- a/libs/ibm/tests/integration_tests/test_chat_models.py +++ b/libs/ibm/tests/integration_tests/test_chat_models.py @@ -1,17 +1,20 @@ import json import os -from typing import Any +from typing import Any, Optional import pytest +from ibm_watsonx_ai.foundation_models.schema import TextChatParameters # type: ignore from ibm_watsonx_ai.metanames import GenTextParamsMetaNames # type: ignore from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, + BaseMessageChunk, HumanMessage, SystemMessage, ) from langchain_core.prompts import ChatPromptTemplate +from langchain_core.tools import tool from pydantic import BaseModel from langchain_ibm import ChatWatsonx @@ -20,13 +23,15 @@ WX_PROJECT_ID = os.environ.get("WATSONX_PROJECT_ID", "") URL = "https://us-south.ml.cloud.ibm.com" -MODEL_ID = "mistralai/mixtral-8x7b-instruct-v01" + +MODEL_ID = "ibm/granite-34b-code-instruct" +MODEL_ID_TOOL = "mistralai/mistral-large" def test_01_generate_chat() -> None: chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type] messages = [ - ("system", "You are a helpful assistant that translates English to French."), + ("user", "You are a helpful assistant that translates English to French."), ( "human", "Translate this sentence from English to French. I love programming.", @@ -37,13 +42,24 @@ def test_01_generate_chat() -> None: assert response.content -def test_01a_generate_chat_with_invoke_params() -> None: - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames +def test_01a_generate_chat_with_params_as_dict_in_invoke() -> None: + params = {"max_tokens": 10} + chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type] + messages = [ + ("system", "You are a helpful assistant that translates English to French."), + ( + "human", + "Translate this sentence from English to French. I love programming.", + ), + ] + response = chat.invoke(messages, params=params) + assert response + assert response.content + print(response.content) - params = { - GenTextParamsMetaNames.MIN_NEW_TOKENS: 1, - GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, - } + +def test_01a_generate_chat_with_params_as_object_in_invoke() -> None: + params = TextChatParameters(max_tokens=10) chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type] messages = [ ("system", "You are a helpful assistant that translates English to French."), @@ -55,11 +71,31 @@ def test_01a_generate_chat_with_invoke_params() -> None: response = chat.invoke(messages, params=params) assert response assert response.content + print(response.content) -def test_01b_generate_chat_with_invoke_params() -> None: - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames +def test_01a_generate_chat_with_params_as_object_in_constructor() -> None: + params = TextChatParameters(max_tokens=10) + chat = ChatWatsonx( + model_id=MODEL_ID, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + params=params, + ) + messages = [ + ("system", "You are a helpful assistant that translates English to French."), + ( + "human", + "Translate this sentence from English to French. 
I love programming.", + ), + ] + response = chat.invoke(messages) + assert response + assert response.content + print(response.content) + +def test_01b_generate_chat_with_invoke_params() -> None: parameters_1 = { GenTextParamsMetaNames.DECODING_METHOD: "sample", GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, @@ -124,12 +160,7 @@ def test_05a_invoke_chat_with_streaming() -> None: def test_05_generate_chat_with_stream_with_param() -> None: - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames - - params = { - GenTextParamsMetaNames.MIN_NEW_TOKENS: 1, - GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, - } + params = TextChatParameters(max_tokens=10) chat = ChatWatsonx( model_id=MODEL_ID, url=URL, # type: ignore[arg-type] @@ -142,12 +173,7 @@ def test_05_generate_chat_with_stream_with_param() -> None: def test_05_generate_chat_with_stream_with_param_v2() -> None: - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames - - params = { - GenTextParamsMetaNames.MIN_NEW_TOKENS: 1, - GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, - } + params = TextChatParameters(max_tokens=10) chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID) # type: ignore[arg-type] response = chat.stream("What's the weather in san francisco", params=params) for chunk in response: @@ -232,16 +258,150 @@ def test_11_chaining_with_params() -> None: assert response.content -def test_20_tool_choice() -> None: - """Test that tool choice is respected.""" - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames +def test_20_bind_tools() -> None: + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) - params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500} + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather report for a city", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + }, + ] + + llm_with_tools = chat.bind_tools(tools=tools) + + response = llm_with_tools.invoke("what's the weather in san francisco, ca") + assert isinstance(response, AIMessage) + assert not response.content + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "get_weather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] + + +def test_21a_bind_tools_tool_choice_auto() -> None: chat = ChatWatsonx( - model_id=MODEL_ID, + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather report for a city", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + }, + ] + + llm_with_tools = chat.bind_tools(tools=tools, tool_choice="auto") + + response = llm_with_tools.invoke("what's the weather in san francisco, ca") + assert isinstance(response, AIMessage) + assert not response.content + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "get_weather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] + + +@pytest.mark.skip(reason="Not supported yet") +def test_21b_bind_tools_tool_choice_none() -> None: + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + + tools = [ + { 
+ "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather report for a city", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + }, + ] + + llm_with_tools = chat.bind_tools(tools=tools, tool_choice="none") + + response = llm_with_tools.invoke("what's the weather in san francisco, ca") + assert isinstance(response, AIMessage) + assert not response.content + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "get_weather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] + + +@pytest.mark.skip(reason="Not supported yet") +def test_21c_bind_tools_tool_choice_required() -> None: + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather report for a city", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + }, + ] + + llm_with_tools = chat.bind_tools(tools=tools, tool_choice="required") + + response = llm_with_tools.invoke("what's the weather in san francisco, ca") + assert isinstance(response, AIMessage) + assert not response.content + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "get_weather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] + + +def test_22a_bind_tools_tool_choice_as_class() -> None: + """Test that tool choice is respected.""" + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, url=URL, # type: ignore[arg-type] project_id=WX_PROJECT_ID, - params=params, ) class Person(BaseModel): @@ -262,23 +422,21 @@ class Person(BaseModel): } -def test_21_tool_choice_bool() -> None: +def test_22b_bind_tools_tool_choice_as_dict() -> None: """Test that tool choice is respected just passing in True.""" - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames - - params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500} chat = ChatWatsonx( - model_id=MODEL_ID, + model_id=MODEL_ID_TOOL, url=URL, # type: ignore[arg-type] project_id=WX_PROJECT_ID, - params=params, ) class Person(BaseModel): name: str age: int - with_tool = chat.bind_tools([Person], tool_choice=True) + tool_choice = {"type": "function", "function": {"name": "Person"}} + + with_tool = chat.bind_tools([Person], tool_choice=tool_choice) result = with_tool.invoke("Erick, 27 years old") assert isinstance(result, AIMessage) @@ -291,18 +449,13 @@ class Person(BaseModel): } -def test_22_tool_invoke() -> None: +def test_23a_bind_tools_list_tool_choice_dict() -> None: """Test that tool choice is respected just passing in True.""" - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames - - params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500} chat = ChatWatsonx( - model_id=MODEL_ID, + model_id=MODEL_ID_TOOL, url=URL, # type: ignore[arg-type] project_id=WX_PROJECT_ID, - params=params, ) - from langchain_core.tools import tool @tool def add(a: int, b: int) -> int: @@ -321,42 +474,137 @@ def get_word_length(word: str) -> int: tools = [add, multiply, get_word_length] - chat_with_tools = chat.bind_tools(tools) + tool_choice = { + "type": "function", + "function": { + "name": "add", + }, + } + + chat_with_tools = chat.bind_tools(tools, tool_choice=tool_choice) - query = "What 
is 3 + 12? What is 3 + 10?" + query = "What is 3 + 12? " resp = chat_with_tools.invoke(query) assert resp.content == "" - query = "Who was the famous painter from Italy?" + +def test_23_bind_tools_list_tool_choice_auto() -> None: + """Test that tool choice is respected just passing in True.""" + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + + @tool + def add(a: int, b: int) -> int: + """Adds a and b.""" + return a + b + + @tool + def multiply(a: int, b: int) -> int: + """Multiplies a and b.""" + return a * b + + @tool + def get_word_length(word: str) -> int: + """Get word length.""" + return len(word) + + tools = [add, multiply, get_word_length] + chat_with_tools = chat.bind_tools(tools, tool_choice="auto") + + query = "What is 3 + 12? " resp = chat_with_tools.invoke(query) + assert resp.content == "" + assert len(resp.tool_calls) == 1 # type: ignore + tool_call = resp.tool_calls[0] # type: ignore + assert tool_call["name"] == "add" + + query = "What is 3 * 12? " + resp = chat_with_tools.invoke(query) + assert resp.content == "" + assert len(resp.tool_calls) == 1 # type: ignore + tool_call = resp.tool_calls[0] # type: ignore + assert tool_call["name"] == "multiply" + query = "Who was the famous painter from Italy?" + resp = chat_with_tools.invoke(query) assert resp.content + assert len(resp.tool_calls) == 0 # type: ignore + + +def test_json_mode() -> None: + llm = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + response = llm.invoke( + "Return this as json: {'a': 1}", + params={"response_format": {"type": "json_object"}}, + ) + assert isinstance(response.content, str) + assert json.loads(response.content) == {"a": 1} + + # Test streaming + full: Optional[BaseMessageChunk] = None + for chunk in llm.stream( + "Return this as json: {'a': 1}", + params={"response_format": {"type": "json_object"}}, + ): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert isinstance(full.content, str) + assert json.loads(full.content) == {"a": 1} + + +async def test_json_mode_async() -> None: + llm = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + response = await llm.ainvoke( + "Return this as json: {'a': 1}", + params={"response_format": {"type": "json_object"}}, + ) + assert isinstance(response.content, str) + assert json.loads(response.content) == {"a": 1} + + # Test streaming + full: Optional[BaseMessageChunk] = None + async for chunk in llm.astream( + "Return this as json: {'a': 1}", + params={"response_format": {"type": "json_object"}}, + ): + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert isinstance(full.content, str) + assert json.loads(full.content) == {"a": 1} @pytest.mark.skip(reason="Not implemented") def test_streaming_tool_call() -> None: - from ibm_watsonx_ai.metanames import GenTextParamsMetaNames - - params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500} chat = ChatWatsonx( - model_id=MODEL_ID, + model_id=MODEL_ID_TOOL, url=URL, # type: ignore[arg-type] project_id=WX_PROJECT_ID, - params=params, ) class Person(BaseModel): name: str age: int - tool_llm = chat.bind_tools([Person]) + tool_choice = {"type": "function", "function": {"name": "Person"}} + + tool_llm = chat.bind_tools([Person], tool_choice=tool_choice) - # where it calls the tool - strm = tool_llm.stream("Erick, 27 years old") + stream_response = 
tool_llm.stream("Erick, 27 years old") additional_kwargs = None - for chunk in strm: + for chunk in stream_response: assert isinstance(chunk, AIMessageChunk) assert chunk.content == "" additional_kwargs = chunk.additional_kwargs @@ -384,3 +632,51 @@ class Person(BaseModel): acc = chunk if acc is None else acc + chunk assert acc.content != "" assert "tool_calls" not in acc.additional_kwargs + + +def test_structured_output() -> None: + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + schema = { + "title": "AnswerWithJustification", + "description": ( + "An answer to the user question along with justification for the answer." + ), + "type": "object", + "properties": { + "answer": {"title": "Answer", "type": "string"}, + "justification": {"title": "Justification", "type": "string"}, + }, + "required": ["answer", "justification"], + } + structured_llm = chat.with_structured_output(schema) + result = structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) + assert isinstance(result, dict) + + +@pytest.mark.skip(reason="Not implemented") +def test_streaming_structured_output() -> None: + chat = ChatWatsonx( + model_id=MODEL_ID_TOOL, + url=URL, # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + + class Person(BaseModel): + name: str + age: int + + structured_llm = chat.with_structured_output(Person) + strm_response = structured_llm.stream("Erick, 27 years old") + chunk_num = 0 + for chunk in strm_response: + assert chunk_num == 0, "should only have one chunk with model" + assert isinstance(chunk, Person) + assert chunk.name == "Erick" + assert chunk.age == 27 + chunk_num += 1
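Taken together, the changes above replace the old prompt-templating path with watsonx.ai's native chat endpoint, including tool binding via `tool_choice` and `params` accepted either as a dict or a `TextChatParameters` object. As a quick orientation, here is a minimal sketch of how the updated interface is exercised, mirroring the integration tests in this diff; the model ID, endpoint URL, and project ID are placeholders, and valid watsonx.ai credentials (e.g. a `WATSONX_APIKEY` environment variable) are assumed:

```python
import os

from ibm_watsonx_ai.foundation_models.schema import TextChatParameters
from langchain_core.tools import tool

from langchain_ibm import ChatWatsonx

# Placeholder model and project details -- substitute your own deployment.
chat = ChatWatsonx(
    model_id="mistralai/mistral-large",
    url="https://us-south.ml.cloud.ibm.com",
    project_id=os.environ["WATSONX_PROJECT_ID"],
    params=TextChatParameters(max_tokens=200),  # a plain dict works here too
)


@tool
def get_weather(location: str) -> str:
    """Get a weather report for a city."""
    return f"Sunny in {location}"


# tool_choice="auto" lets the model decide whether to call a tool; a dict of
# the form {"type": "function", "function": {"name": "get_weather"}} forces it.
llm_with_tools = chat.bind_tools([get_weather], tool_choice="auto")

response = llm_with_tools.invoke("What's the weather in San Francisco, CA?")
for tool_call in response.tool_calls:
    print(tool_call["name"], tool_call["args"])
```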
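Likewise, a sketch of the structured-output and JSON-mode paths covered by the new tests, again assuming placeholder credentials and the same tool-capable model; `with_structured_output` routes through `bind_tools()` under the hood, while JSON mode is requested through the chat parameters:

```python
from pydantic import BaseModel

from langchain_ibm import ChatWatsonx

chat = ChatWatsonx(
    model_id="mistralai/mistral-large",       # placeholder model ID
    url="https://us-south.ml.cloud.ibm.com",  # placeholder endpoint
    project_id="YOUR_PROJECT_ID",             # placeholder project
)


class AnswerWithJustification(BaseModel):
    answer: str
    justification: str


# Pydantic schemas are bound as a tool and parsed back into model instances.
structured_llm = chat.with_structured_output(AnswerWithJustification)
result = structured_llm.invoke(
    "What weighs more, a pound of bricks or a pound of feathers?"
)
print(result.answer, "-", result.justification)

# JSON mode goes through the request parameters rather than a tool schema.
json_response = chat.invoke(
    "Return this as json: {'a': 1}",
    params={"response_format": {"type": "json_object"}},
)
print(json_response.content)
```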