Skip to content

Expose description in Genie integrations #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,9 @@ def _concat_messages_array(messages):


@mlflow.trace()
def _query_genie_as_agent(
input, genie_space_id, genie_agent_name, client: Optional[WorkspaceClient] = None
):
def _query_genie_as_agent(input, genie: Genie, genie_agent_name):
from langchain_core.messages import AIMessage

genie = Genie(genie_space_id, client=client)

message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n"

# Concatenate messages to form the chat history
Expand All @@ -44,7 +40,6 @@ def _query_genie_as_agent(
def GenieAgent(
genie_space_id,
genie_agent_name: str = "Genie",
description: str = "",
client: Optional["WorkspaceClient"] = None,
):
"""Create a genie agent that can be used to query the API"""
Expand All @@ -55,13 +50,15 @@ def GenieAgent(

from langchain_core.runnables import RunnableLambda

genie = Genie(genie_space_id, client=client)

# Create a partial function with the genie_space_id pre-filled
partial_genie_agent = partial(
_query_genie_as_agent,
genie_space_id=genie_space_id,
genie=genie,
genie_agent_name=genie_agent_name,
client=client,
)

# Use the partial function in the RunnableLambda
return RunnableLambda(partial_genie_agent)
runnable = RunnableLambda(partial_genie_agent)
runnable.description = genie.description
return runnable
56 changes: 35 additions & 21 deletions integrations/langchain/tests/unit_tests/test_genie.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from unittest.mock import patch

import pytest
from databricks_ai_bridge.genie import GenieResponse
from databricks.sdk.service.dashboards import GenieSpace
from databricks_ai_bridge.genie import Genie
from langchain_core.messages import AIMessage

from databricks_langchain.genie import (
Expand Down Expand Up @@ -42,46 +43,59 @@ def __init__(self, role, content):
assert result == expected


@patch("databricks_langchain.genie.Genie")
def test_query_genie_as_agent(MockGenie):
# Mock the Genie class and its response
mock_genie = MockGenie.return_value
mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.")
@patch("databricks.sdk.WorkspaceClient")
def test_query_genie_as_agent(MockWorkspaceClient):
mock_space = GenieSpace(
space_id="space-id",
title="Sales Space",
description="description",
)
MockWorkspaceClient.genie.get_space.return_value = mock_space
MockWorkspaceClient.genie._api.do.side_effect = [
{"conversation_id": "123", "message_id": "abc"},
{"status": "COMPLETED", "attachments": [{"text": {"content": "It is sunny."}}]},
]

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
result = _query_genie_as_agent(input_data, "space-id", "Genie", None)
genie = Genie("space-id", MockWorkspaceClient)
result = _query_genie_as_agent(input_data, genie, "Genie")

expected_message = {"messages": [AIMessage(content="It is sunny.")]}
assert result == expected_message

# Test the case when genie_response is empty
mock_genie.ask_question.return_value = GenieResponse(result=None)
result = _query_genie_as_agent(input_data, "space-id", "Genie", None)

expected_message = {"messages": [AIMessage(content="")]}
assert result == expected_message


@patch("databricks.sdk.WorkspaceClient")
@patch("langchain_core.runnables.RunnableLambda")
def test_create_genie_agent(MockRunnableLambda):
mock_runnable = MockRunnableLambda.return_value
def test_create_genie_agent(MockRunnableLambda, MockWorkspaceClient):
mock_space = GenieSpace(
space_id="space-id",
title="Sales Space",
description="description",
)
MockWorkspaceClient.genie.get_space.return_value = mock_space

agent = GenieAgent("space-id", "Genie")
assert agent == mock_runnable
agent = GenieAgent("space-id", "Genie", MockWorkspaceClient)
assert agent.description == "description"

# Check that the partial function is created with the correct arguments
MockRunnableLambda.assert_called()
MockWorkspaceClient.genie.get_space.assert_called_once()
assert agent == MockRunnableLambda.return_value


@patch("databricks.sdk.WorkspaceClient")
def test_query_genie_with_client(mock_workspace_client):
mock_workspace_client.genie.get_space.return_value = GenieSpace(
space_id="space-id",
title="Sales Space",
description="description",
)
mock_workspace_client.genie._api.do.side_effect = [
{"conversation_id": "123", "message_id": "abc"},
{"status": "COMPLETED", "attachments": [{"text": {"content": "It is sunny."}}]},
]

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
result = _query_genie_as_agent(input_data, "space-id", "Genie", mock_workspace_client)
genie = Genie("space-id", mock_workspace_client)
result = _query_genie_as_agent(input_data, genie, "Genie")

expected_message = {"messages": [AIMessage(content="It is sunny.")]}
assert result == expected_message
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ requires-python = ">=3.9"
dependencies = [
"typing_extensions",
"pydantic",
"databricks-sdk>=0.44.1",
"databricks-sdk>=0.49.0",
"pandas",
"tiktoken>=0.8.0",
"tabulate>=0.9.0",
Expand Down
1 change: 1 addition & 0 deletions src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, space_id, client: Optional["WorkspaceClient"] = None):
self.space_id = space_id
workspace_client = client or WorkspaceClient()
self.genie = workspace_client.genie
self.description = self.genie.get_space(space_id).description
self.headers = {
"Accept": "application/json",
"Content-Type": "application/json",
Expand Down