From 8443e901adb212770bd211fc8388e6986d33d318 Mon Sep 17 00:00:00 2001 From: Bryan Qiu Date: Tue, 1 Apr 2025 16:19:52 -0700 Subject: [PATCH 1/3] . Signed-off-by: Bryan Qiu --- .../src/databricks_langchain/genie.py | 17 +++--- .../langchain/tests/unit_tests/test_genie.py | 60 ++++++++++++------- src/databricks_ai_bridge/genie.py | 1 + 3 files changed, 48 insertions(+), 30 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 2df34e40..890a021a 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -19,13 +19,9 @@ def _concat_messages_array(messages): @mlflow.trace() -def _query_genie_as_agent( - input, genie_space_id, genie_agent_name, client: Optional[WorkspaceClient] = None -): +def _query_genie_as_agent(input, genie: Genie, genie_agent_name): from langchain_core.messages import AIMessage - genie = Genie(genie_space_id, client=client) - message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" # Concatenate messages to form the chat history @@ -44,7 +40,6 @@ def _query_genie_as_agent( def GenieAgent( genie_space_id, genie_agent_name: str = "Genie", - description: str = "", client: Optional["WorkspaceClient"] = None, ): """Create a genie agent that can be used to query the API""" @@ -55,13 +50,15 @@ def GenieAgent( from langchain_core.runnables import RunnableLambda + genie = Genie(genie_space_id, client=client) + # Create a partial function with the genie_space_id pre-filled partial_genie_agent = partial( _query_genie_as_agent, - genie_space_id=genie_space_id, + genie=genie, genie_agent_name=genie_agent_name, - client=client, ) - # Use the partial function in the RunnableLambda - return RunnableLambda(partial_genie_agent) + runnable = RunnableLambda(partial_genie_agent) + runnable.description = genie.description + return runnable diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py index 43b7ff7e..8d0928cf 100644 --- a/integrations/langchain/tests/unit_tests/test_genie.py +++ b/integrations/langchain/tests/unit_tests/test_genie.py @@ -1,7 +1,8 @@ from unittest.mock import patch import pytest -from databricks_ai_bridge.genie import GenieResponse +from databricks.sdk.service.dashboards import GenieSpace +from databricks_ai_bridge.genie import Genie from langchain_core.messages import AIMessage from databricks_langchain.genie import ( @@ -42,46 +43,65 @@ def __init__(self, role, content): assert result == expected -@patch("databricks_langchain.genie.Genie") -def test_query_genie_as_agent(MockGenie): - # Mock the Genie class and its response - mock_genie = MockGenie.return_value - mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.") +@patch("databricks.sdk.WorkspaceClient") +def test_query_genie_as_agent(MockWorkspaceClient): + mock_space = GenieSpace( + space_id="space-id", + title="Sales Space", + description="description", + ) + MockWorkspaceClient.genie.get_space.return_value = mock_space + MockWorkspaceClient.genie._api.do.side_effect = [ + {"conversation_id": "123", "message_id": "abc"}, + {"status": "COMPLETED", "attachments": [{"text": {"content": "It is sunny."}}]}, + ] input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = _query_genie_as_agent(input_data, "space-id", "Genie", None) + genie = Genie("space-id", MockWorkspaceClient) + result = _query_genie_as_agent(input_data, genie, "Genie") expected_message = {"messages": [AIMessage(content="It is sunny.")]} assert result == expected_message - # Test the case when genie_response is empty - mock_genie.ask_question.return_value = GenieResponse(result=None) - result = _query_genie_as_agent(input_data, "space-id", "Genie", None) - - expected_message = {"messages": [AIMessage(content="")]} - assert result == expected_message - +@patch("databricks.sdk.WorkspaceClient") @patch("langchain_core.runnables.RunnableLambda") -def test_create_genie_agent(MockRunnableLambda): +def test_create_genie_agent(MockRunnableLambda, MockWorkspaceClient): + # Create a mock space with a description + mock_space = GenieSpace( + space_id="space-id", + title="Sales Space", + description="description", + ) + MockWorkspaceClient.genie.get_space.return_value = mock_space + + # Create the agent + agent = GenieAgent("space-id", "Genie", MockWorkspaceClient) + assert agent.description == "description" + + # Verify RunnableLambda was created with correct arguments + MockWorkspaceClient.genie.get_space.assert_called_once() mock_runnable = MockRunnableLambda.return_value - agent = GenieAgent("space-id", "Genie") + # Verify the agent is the same as the mock runnable assert agent == mock_runnable - # Check that the partial function is created with the correct arguments - MockRunnableLambda.assert_called() - @patch("databricks.sdk.WorkspaceClient") def test_query_genie_with_client(mock_workspace_client): + mock_workspace_client.genie.get_space.return_value = GenieSpace( + space_id="space-id", + title="Sales Space", + description="description", + ) mock_workspace_client.genie._api.do.side_effect = [ {"conversation_id": "123", "message_id": "abc"}, {"status": "COMPLETED", "attachments": [{"text": {"content": "It is sunny."}}]}, ] input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = _query_genie_as_agent(input_data, "space-id", "Genie", mock_workspace_client) + genie = Genie("space-id", mock_workspace_client) + result = _query_genie_as_agent(input_data, genie, "Genie") expected_message = {"messages": [AIMessage(content="It is sunny.")]} assert result == expected_message diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index ded1f908..c30deebb 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -75,6 +75,7 @@ def __init__(self, space_id, client: Optional["WorkspaceClient"] = None): self.space_id = space_id workspace_client = client or WorkspaceClient() self.genie = workspace_client.genie + self.description = self.genie.get_space(space_id).description self.headers = { "Accept": "application/json", "Content-Type": "application/json", From a2dd79a4835d3e75f9c7107634c7780c0058424f Mon Sep 17 00:00:00 2001 From: Bryan Qiu Date: Tue, 1 Apr 2025 16:21:08 -0700 Subject: [PATCH 2/3] . Signed-off-by: Bryan Qiu --- integrations/langchain/tests/unit_tests/test_genie.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py index 8d0928cf..a9f10e3e 100644 --- a/integrations/langchain/tests/unit_tests/test_genie.py +++ b/integrations/langchain/tests/unit_tests/test_genie.py @@ -67,7 +67,6 @@ def test_query_genie_as_agent(MockWorkspaceClient): @patch("databricks.sdk.WorkspaceClient") @patch("langchain_core.runnables.RunnableLambda") def test_create_genie_agent(MockRunnableLambda, MockWorkspaceClient): - # Create a mock space with a description mock_space = GenieSpace( space_id="space-id", title="Sales Space", @@ -75,16 +74,11 @@ def test_create_genie_agent(MockRunnableLambda, MockWorkspaceClient): ) MockWorkspaceClient.genie.get_space.return_value = mock_space - # Create the agent agent = GenieAgent("space-id", "Genie", MockWorkspaceClient) assert agent.description == "description" - # Verify RunnableLambda was created with correct arguments MockWorkspaceClient.genie.get_space.assert_called_once() - mock_runnable = MockRunnableLambda.return_value - - # Verify the agent is the same as the mock runnable - assert agent == mock_runnable + assert agent == MockRunnableLambda.return_value @patch("databricks.sdk.WorkspaceClient") From 095090480b3c7a2ebe4990dce84735b68b945358 Mon Sep 17 00:00:00 2001 From: Bryan Qiu Date: Tue, 1 Apr 2025 16:29:45 -0700 Subject: [PATCH 3/3] . Signed-off-by: Bryan Qiu --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 06f7979a..f68b70ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.9" dependencies = [ "typing_extensions", "pydantic", - "databricks-sdk>=0.44.1", + "databricks-sdk>=0.49.0", "pandas", "tiktoken>=0.8.0", "tabulate>=0.9.0",