diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 2df34e40..890a021a 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -19,13 +19,9 @@ def _concat_messages_array(messages): @mlflow.trace() -def _query_genie_as_agent( - input, genie_space_id, genie_agent_name, client: Optional[WorkspaceClient] = None -): +def _query_genie_as_agent(input, genie: Genie, genie_agent_name): from langchain_core.messages import AIMessage - genie = Genie(genie_space_id, client=client) - message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" # Concatenate messages to form the chat history @@ -44,7 +40,6 @@ def _query_genie_as_agent( def GenieAgent( genie_space_id, genie_agent_name: str = "Genie", - description: str = "", client: Optional["WorkspaceClient"] = None, ): """Create a genie agent that can be used to query the API""" @@ -55,13 +50,15 @@ def GenieAgent( from langchain_core.runnables import RunnableLambda + genie = Genie(genie_space_id, client=client) + # Create a partial function with the genie_space_id pre-filled partial_genie_agent = partial( _query_genie_as_agent, - genie_space_id=genie_space_id, + genie=genie, genie_agent_name=genie_agent_name, - client=client, ) - # Use the partial function in the RunnableLambda - return RunnableLambda(partial_genie_agent) + runnable = RunnableLambda(partial_genie_agent) + runnable.description = genie.description + return runnable diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py index 43b7ff7e..a9f10e3e 100644 --- a/integrations/langchain/tests/unit_tests/test_genie.py +++ b/integrations/langchain/tests/unit_tests/test_genie.py @@ -1,7 +1,8 @@ from unittest.mock import patch import pytest -from databricks_ai_bridge.genie import GenieResponse +from databricks.sdk.service.dashboards import GenieSpace +from databricks_ai_bridge.genie import Genie from langchain_core.messages import AIMessage from databricks_langchain.genie import ( @@ -42,46 +43,59 @@ def __init__(self, role, content): assert result == expected -@patch("databricks_langchain.genie.Genie") -def test_query_genie_as_agent(MockGenie): - # Mock the Genie class and its response - mock_genie = MockGenie.return_value - mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.") +@patch("databricks.sdk.WorkspaceClient") +def test_query_genie_as_agent(MockWorkspaceClient): + mock_space = GenieSpace( + space_id="space-id", + title="Sales Space", + description="description", + ) + MockWorkspaceClient.genie.get_space.return_value = mock_space + MockWorkspaceClient.genie._api.do.side_effect = [ + {"conversation_id": "123", "message_id": "abc"}, + {"status": "COMPLETED", "attachments": [{"text": {"content": "It is sunny."}}]}, + ] input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = _query_genie_as_agent(input_data, "space-id", "Genie", None) + genie = Genie("space-id", MockWorkspaceClient) + result = _query_genie_as_agent(input_data, genie, "Genie") expected_message = {"messages": [AIMessage(content="It is sunny.")]} assert result == expected_message - # Test the case when genie_response is empty - mock_genie.ask_question.return_value = GenieResponse(result=None) - result = _query_genie_as_agent(input_data, "space-id", "Genie", None) - - expected_message = {"messages": [AIMessage(content="")]} - assert result == expected_message - +@patch("databricks.sdk.WorkspaceClient") @patch("langchain_core.runnables.RunnableLambda") -def test_create_genie_agent(MockRunnableLambda): - mock_runnable = MockRunnableLambda.return_value +def test_create_genie_agent(MockRunnableLambda, MockWorkspaceClient): + mock_space = GenieSpace( + space_id="space-id", + title="Sales Space", + description="description", + ) + MockWorkspaceClient.genie.get_space.return_value = mock_space - agent = GenieAgent("space-id", "Genie") - assert agent == mock_runnable + agent = GenieAgent("space-id", "Genie", MockWorkspaceClient) + assert agent.description == "description" - # Check that the partial function is created with the correct arguments - MockRunnableLambda.assert_called() + MockWorkspaceClient.genie.get_space.assert_called_once() + assert agent == MockRunnableLambda.return_value @patch("databricks.sdk.WorkspaceClient") def test_query_genie_with_client(mock_workspace_client): + mock_workspace_client.genie.get_space.return_value = GenieSpace( + space_id="space-id", + title="Sales Space", + description="description", + ) mock_workspace_client.genie._api.do.side_effect = [ {"conversation_id": "123", "message_id": "abc"}, {"status": "COMPLETED", "attachments": [{"text": {"content": "It is sunny."}}]}, ] input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = _query_genie_as_agent(input_data, "space-id", "Genie", mock_workspace_client) + genie = Genie("space-id", mock_workspace_client) + result = _query_genie_as_agent(input_data, genie, "Genie") expected_message = {"messages": [AIMessage(content="It is sunny.")]} assert result == expected_message diff --git a/pyproject.toml b/pyproject.toml index 06f7979a..f68b70ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.9" dependencies = [ "typing_extensions", "pydantic", - "databricks-sdk>=0.44.1", + "databricks-sdk>=0.49.0", "pandas", "tiktoken>=0.8.0", "tabulate>=0.9.0", diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index ded1f908..c30deebb 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -75,6 +75,7 @@ def __init__(self, space_id, client: Optional["WorkspaceClient"] = None): self.space_id = space_id workspace_client = client or WorkspaceClient() self.genie = workspace_client.genie + self.description = self.genie.get_space(space_id).description self.headers = { "Accept": "application/json", "Content-Type": "application/json",