Fix breaking changes from databricks-ai-bridge 0.4.2 #112


Merged · 27 commits · May 12, 2025
89 changes: 89 additions & 0 deletions .github/workflows/main.yml
@@ -70,6 +70,52 @@ jobs:
- name: Run tests
run: |
pytest integrations/langchain/tests/unit_tests

langchain_cross_version_test:
runs-on: ubuntu-latest
name: langchain_test (${{ matrix.python-version }}, ${{ matrix.version.name }})
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
version:
- {ref: "databricks-ai-v0.4.0", name: "v0.4.0"}
- {ref: "databricks-ai-v0.3.0", name: "v0.3.0"}
- {ref: "databricks-ai-v0.2.0", name: "v0.2.0"}
Contributor Author: tests for vector search tool integration don't exist in 0.1.0

timeout-minutes: 20
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .
- name: Checkout langchain version
uses: actions/checkout@v4
with:
ref: ${{ matrix.version.ref }}
fetch-depth: 1
path: older-version
- name: Replace langchain with older version
run: |
# Remove current langchain if it exists to avoid conflicts
rm -rf integrations/langchain

# Copy older version of langchain to the main repo
cp -r older-version/integrations/langchain integrations/
- name: Install langchain dependency
run: |
pip install integrations/langchain[dev]
- name: Run tests
run: |
# Only testing initialization since functionality can change
pytest integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py::test_init
Contributor Author: Only testing basic init functionality here, since the functionality of other parts has changed. This still does catch the backwards compatibility issues that Serena saw.
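The init-only smoke test strategy can be illustrated with a self-contained sketch. The class below is a hypothetical stand-in, not the real VectorSearchRetrieverTool (the actual test lives in test_vector_search_retriever_tool.py); the point is that merely constructing the object exercises the cross-version surface — imports, signatures, default fields — without depending on behavior that legitimately changes between releases.

```python
class VectorSearchRetrieverToolStub:
    """Hypothetical stand-in for the real tool class."""

    def __init__(self, index_name: str, num_results: int = 5):
        if not index_name:
            raise ValueError("index_name is required")
        self.index_name = index_name
        self.num_results = num_results


def test_init():
    # Construction alone is the cross-version contract being pinned down.
    tool = VectorSearchRetrieverToolStub(index_name="catalog.schema.index")
    assert tool.index_name == "catalog.schema.index"
    assert tool.num_results == 5
```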

Collaborator: can we add a comment saying the same? this is really awesome work!

pytest integrations/langchain/tests/unit_tests/test_genie.py
pytest integrations/langchain/tests/unit_tests/test_embeddings.py
pytest integrations/langchain/tests/unit_tests/test_chat_models.py

openai_test:
runs-on: ubuntu-latest
@@ -92,6 +138,49 @@ jobs:
run: |
pytest integrations/openai/tests/unit_tests

openai_cross_version_test:
runs-on: ubuntu-latest
name: openai_test (${{ matrix.python-version }}, ${{ matrix.version.name }})
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
version:
- {ref: "databricks-ai-v0.4.0", name: "v0.4.0"}
- {ref: "databricks-ai-v0.3.0", name: "v0.3.0"}
Contributor Author (@annzhang-db, May 12, 2025): openai integration with the tests we are running doesn't exist for older version of databricks-ai-bridge

timeout-minutes: 20
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .
- name: Checkout openai version
uses: actions/checkout@v4
with:
ref: ${{ matrix.version.ref }}
fetch-depth: 1
path: older-version
- name: Replace openai with older version
run: |
# Remove current openai if it exists to avoid conflicts
rm -rf integrations/openai

# Copy older version of openai to the main repo
cp -r older-version/integrations/openai integrations/
- name: Install openai dependency
run: |
pip install integrations/openai[dev]
- name: Run tests
run: |
# Only testing initialization since functionality can change
pytest integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py::test_vector_search_retriever_tool_init


llamaindex_test:
Contributor Author: llamaindex is not published and thus has no cross-version tests rn

runs-on: ubuntu-latest
strategy:
@@ -473,7 +473,7 @@ def similarity_search_with_score(
)
search_resp = self.index.similarity_search(**kwargs)
return parse_vector_search_response(
-            search_resp, self._retriever_schema, document_class=Document
+            search_resp, retriever_schema=self._retriever_schema, document_class=Document
)

def _select_relevance_score_fn(self) -> Callable[[float], float]:
@@ -586,7 +586,7 @@ def similarity_search_by_vector_with_score(
**kwargs,
)
return parse_vector_search_response(
-            search_resp, self._retriever_schema, document_class=Document
+            search_resp, retriever_schema=self._retriever_schema, document_class=Document
)

def max_marginal_relevance_search(
@@ -723,7 +723,7 @@ def max_marginal_relevance_search_by_vector(
ignore_cols: List = [embedding_column] if embedding_column not in self._columns else []
candidates = parse_vector_search_response(
search_resp,
-        self._retriever_schema,
+        retriever_schema=self._retriever_schema,
ignore_cols=ignore_cols,
document_class=Document,
)
@@ -127,7 +127,7 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
)
search_resp = self._index.similarity_search(**kwargs)
return parse_vector_search_response(
-            search_resp, self._retriever_schema, document_class=dict
+            search_resp, retriever_schema=self._retriever_schema, document_class=dict
)

# Create tool metadata
55 changes: 32 additions & 23 deletions src/databricks_ai_bridge/utils/vector_search.py
@@ -87,32 +87,40 @@ def get_metadata(columns: List[str], result: List[Any], retriever_schema, ignore
"""
metadata = {}

-    # Skipping the last column, which is always the score
-    for col, value in zip(columns[:-1], result[:-1]):
-        if col == retriever_schema.doc_uri:
-            metadata["doc_uri"] = value
-        elif col == retriever_schema.primary_key:
-            metadata["chunk_id"] = value
-        elif col == "doc_uri" and retriever_schema.doc_uri:
-            # Prioritize retriever_schema.doc_uri, don't override with the actual "doc_uri" column
-            continue
-        elif col == "chunk_id" and retriever_schema.primary_key:
-            # Prioritize retriever_schema.primary_key, don't override with the actual "chunk_id" column
-            continue
-        elif col in ignore_cols:
-            # ignore_cols has precedence over other_columns
-            continue
-        elif retriever_schema.other_columns is not None:
-            if col in retriever_schema.other_columns:
-                metadata[col] = value
-        else:
-            metadata[col] = value
+    if retriever_schema:
+        # Skipping the last column, which is always the score
+        for col, value in zip(columns[:-1], result[:-1]):
+            if col == retriever_schema.doc_uri:
+                metadata["doc_uri"] = value
+            elif col == retriever_schema.primary_key:
+                metadata["chunk_id"] = value
+            elif col == "doc_uri" and retriever_schema.doc_uri:
+                # Prioritize retriever_schema.doc_uri, don't override with the actual "doc_uri" column
+                continue
+            elif col == "chunk_id" and retriever_schema.primary_key:
+                # Prioritize retriever_schema.primary_key, don't override with the actual "chunk_id" column
+                continue
+            elif col in ignore_cols:
+                # ignore_cols has precedence over other_columns
+                continue
+            elif retriever_schema.other_columns is not None:
+                if col in retriever_schema.other_columns:
+                    metadata[col] = value
+            else:
+                metadata[col] = value
+    else:
+        for col, value in zip(columns[:-1], result[:-1]):
+            if col not in ignore_cols:
+                metadata[col] = value
return metadata
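The behavioral effect of wrapping the loop in `if retriever_schema:` can be sketched with a standalone copy of the logic. The `RetrieverSchema` below is a simplified stand-in for the real class, and the function mirrors (rather than imports) the patched helper; as in the real code, the last column is assumed to be the score.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class RetrieverSchema:
    """Simplified stand-in for the real RetrieverSchema."""
    text_column: str
    doc_uri: Optional[str] = None
    primary_key: Optional[str] = None
    other_columns: Optional[List[str]] = None


def get_metadata(columns, result, retriever_schema=None, ignore_cols=()):
    """Mirror of the patched logic: fall back to 'keep everything except
    ignore_cols' when no retriever_schema is given (the pre-0.4 behavior)."""
    metadata = {}
    if retriever_schema:
        # Skip the trailing score column
        for col, value in zip(columns[:-1], result[:-1]):
            if col == retriever_schema.doc_uri:
                metadata["doc_uri"] = value
            elif col == retriever_schema.primary_key:
                metadata["chunk_id"] = value
            elif col == "doc_uri" and retriever_schema.doc_uri:
                continue  # schema mapping wins over a literal "doc_uri" column
            elif col == "chunk_id" and retriever_schema.primary_key:
                continue  # schema mapping wins over a literal "chunk_id" column
            elif col in ignore_cols:
                continue
            elif retriever_schema.other_columns is not None:
                if col in retriever_schema.other_columns:
                    metadata[col] = value
            else:
                metadata[col] = value
    else:
        # Legacy path: no schema, keep every non-ignored column
        for col, value in zip(columns[:-1], result[:-1]):
            if col not in ignore_cols:
                metadata[col] = value
    return metadata


cols = ["id", "text", "uri", "score"]
row = [1, "hello", "doc://1", 0.9]
# Schema-less callers (older databricks-langchain) get plain column metadata:
print(get_metadata(cols, row, ignore_cols=["text"]))
# Schema-aware callers get the mapped "chunk_id"/"doc_uri" keys:
schema = RetrieverSchema(text_column="text", doc_uri="uri", primary_key="id")
print(get_metadata(cols, row, retriever_schema=schema, ignore_cols=["text"]))
```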


 def parse_vector_search_response(
     search_resp: Dict,
-    retriever_schema: RetrieverSchema,
+    index_details: IndexDetails = None,  # deprecated
+    text_column: str = None,  # deprecated
+    *,
+    retriever_schema: RetrieverSchema = None,
     ignore_cols: Optional[List[str]] = None,
     document_class: Any = dict,
 ) -> List[Tuple[Dict, float]]:
@@ -123,7 +131,8 @@ def parse_vector_search_response(
if ignore_cols is None:
ignore_cols = []

-    text_column = retriever_schema.text_column
+    if retriever_schema:
+        text_column = retriever_schema.text_column
ignore_cols.extend([text_column])

columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])]
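The signature change above follows a common backward-compatibility pattern: the deprecated parameters keep the positional slots they held before 0.4, while the new parameter becomes keyword-only behind `*`. A toy sketch of the pattern (the body is illustrative, not the real parser):

```python
from typing import Any, Optional


def parse_vector_search_response(
    search_resp: dict,
    index_details: Any = None,          # deprecated: keeps its pre-0.4 positional slot
    text_column: Optional[str] = None,  # deprecated: same reason
    *,                                  # everything below must be passed by keyword
    retriever_schema: Any = None,
    ignore_cols: Optional[list] = None,
) -> str:
    # Toy body: report which calling convention was used
    return "schema" if retriever_schema is not None else "legacy"


# Pre-0.4 positional calls still bind to the slots they always used:
assert parse_vector_search_response({}, object(), "text") == "legacy"
# New-style callers must name retriever_schema, so it can never be
# swallowed by one of the deprecated positional parameters:
assert parse_vector_search_response({}, retriever_schema=object()) == "schema"
```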
@@ -161,8 +170,8 @@ def validate_and_get_return_columns(
columns: List[str],
text_column: str,
index_details: IndexDetails,
-    doc_uri: str,
-    primary_key: str,
+    doc_uri: str = None,
+    primary_key: str = None,
) -> List[str]:
"""
Get a list of columns to retrieve from the index.
12 changes: 8 additions & 4 deletions src/databricks_ai_bridge/vector_search_retriever_tool.py
@@ -143,16 +143,20 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str:
return description

def _get_resources(
-        self, index_name: str, embedding_endpoint: str, index_details: IndexDetails
+        self, index_name: str, embedding_endpoint: str, index_details: IndexDetails = None
) -> List[Resource]:
resources = []
if index_name:
resources.append(DatabricksVectorSearchIndex(index_name=index_name))
if embedding_endpoint:
resources.append(DatabricksServingEndpoint(endpoint_name=embedding_endpoint))
-        if index_details.is_databricks_managed_embeddings and (
-            managed_embedding := index_details.embedding_source_column.get(
-                "embedding_model_endpoint_name", None
-            )
+        if (
+            index_details
+            and index_details.is_databricks_managed_embeddings
+            and (
+                managed_embedding := index_details.embedding_source_column.get(
+                    "embedding_model_endpoint_name", None
+                )
+            )
         ):
if managed_embedding != embedding_endpoint:
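A minimal sketch of the None-guard pattern used in `_get_resources` (stand-in types, not the real Resource classes): short-circuiting on `index_details` before the attribute access and walrus assignment is what lets older callers pass `index_details=None` without an AttributeError.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class IndexDetails:
    """Simplified stand-in for the real IndexDetails."""
    is_databricks_managed_embeddings: bool = False
    embedding_source_column: Dict[str, Any] = field(default_factory=dict)


def get_resources(
    index_name: Optional[str],
    embedding_endpoint: Optional[str],
    index_details: Optional[IndexDetails] = None,  # may now be omitted, as pre-0.4 callers do
) -> List[str]:
    resources = []
    if index_name:
        resources.append(f"index:{index_name}")
    if embedding_endpoint:
        resources.append(f"endpoint:{embedding_endpoint}")
    # `and` short-circuits: with index_details=None, neither the attribute
    # access nor the walrus assignment below is ever evaluated.
    if (
        index_details
        and index_details.is_databricks_managed_embeddings
        and (managed := index_details.embedding_source_column.get("embedding_model_endpoint_name"))
    ):
        if managed != embedding_endpoint:
            resources.append(f"endpoint:{managed}")
    return resources


# Old-style call with no index_details no longer raises:
print(get_resources("idx", None))
```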
16 changes: 15 additions & 1 deletion tests/databricks_ai_bridge/utils/test_vector_search.py
@@ -111,5 +111,19 @@ def make_document(row_index: int, score: float):
)
def test_parse_vector_search_response(retriever_schema, ignore_cols, docs_with_score):
Collaborator: nit: can we also add a test for what an old version of databricks-langchain would've done?

     assert (
-        parse_vector_search_response(search_resp, retriever_schema, ignore_cols) == docs_with_score
+        parse_vector_search_response(
+            search_resp, retriever_schema=retriever_schema, ignore_cols=ignore_cols
+        )
+        == docs_with_score
     )


def test_parse_vector_search_response_without_retriever_schema():
    assert parse_vector_search_response(
        search_resp, text_column="column_1", ignore_cols=["column_2"]
    ) == construct_docs_with_score(
        page_content_column="column_2",
        column_3="column_3",
        column_4="column_4",
    )
)