diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 521c55b0..ce17cf1c 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -70,6 +70,52 @@ jobs:
       - name: Run tests
         run: |
           pytest integrations/langchain/tests/unit_tests
+
+  langchain_cross_version_test:
+    runs-on: ubuntu-latest
+    name: langchain_test (${{ matrix.python-version }}, ${{ matrix.version.name }})
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        version:
+          - {ref: "databricks-ai-v0.4.0", name: "v0.4.0"}
+          - {ref: "databricks-ai-v0.3.0", name: "v0.3.0"}
+          - {ref: "databricks-ai-v0.2.0", name: "v0.2.0"}
+    timeout-minutes: 20
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install .
+      - name: Checkout langchain version
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ matrix.version.ref }}
+          fetch-depth: 1
+          path: older-version
+      - name: Replace langchain with older version
+        run: |
+          # Remove current langchain if it exists to avoid conflicts
+          rm -rf integrations/langchain
+
+          # Copy older version of langchain to the main repo
+          cp -r older-version/integrations/langchain integrations/
+      - name: Install langchain dependency
+        run: |
+          pip install integrations/langchain[dev]
+      - name: Run tests
+        run: |
+          # Only testing initialization since functionality can change
+          pytest integrations/langchain/tests/unit_tests/test_vector_search_retriever_tool.py::test_init
+          pytest integrations/langchain/tests/unit_tests/test_genie.py
+          pytest integrations/langchain/tests/unit_tests/test_embeddings.py
+          pytest integrations/langchain/tests/unit_tests/test_chat_models.py
 
   openai_test:
     runs-on: ubuntu-latest
@@ -92,6 +138,49 @@ jobs:
         run: |
           pytest integrations/openai/tests/unit_tests
 
+  openai_cross_version_test:
+    runs-on: ubuntu-latest
+    name: openai_test (${{ matrix.python-version }}, ${{ matrix.version.name }})
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        version:
+          - {ref: "databricks-ai-v0.4.0", name: "v0.4.0"}
+          - {ref: "databricks-ai-v0.3.0", name: "v0.3.0"}
+    timeout-minutes: 20
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install .
+      - name: Checkout openai version
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ matrix.version.ref }}
+          fetch-depth: 1
+          path: older-version
+      - name: Replace openai with older version
+        run: |
+          # Remove current openai if it exists to avoid conflicts
+          rm -rf integrations/openai
+
+          # Copy older version of openai to the main repo
+          cp -r older-version/integrations/openai integrations/
+      - name: Install openai dependency
+        run: |
+          pip install integrations/openai[dev]
+      - name: Run tests
+        run: |
+          # Only testing initialization since functionality can change
+          pytest integrations/openai/tests/unit_tests/test_vector_search_retriever_tool.py::test_vector_search_retriever_tool_init
+
+
   llamaindex_test:
     runs-on: ubuntu-latest
     strategy:
diff --git a/integrations/langchain/src/databricks_langchain/vectorstores.py b/integrations/langchain/src/databricks_langchain/vectorstores.py
index c0d35081..e4d64617 100644
--- a/integrations/langchain/src/databricks_langchain/vectorstores.py
+++ b/integrations/langchain/src/databricks_langchain/vectorstores.py
@@ -473,7 +473,7 @@ def similarity_search_with_score(
         )
         search_resp = self.index.similarity_search(**kwargs)
         return parse_vector_search_response(
-            search_resp, self._retriever_schema, document_class=Document
+            search_resp, retriever_schema=self._retriever_schema, document_class=Document
         )
 
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
@@ -586,7 +586,7 @@ def similarity_search_by_vector_with_score(
             **kwargs,
         )
         return parse_vector_search_response(
-            search_resp, self._retriever_schema, document_class=Document
+            search_resp, retriever_schema=self._retriever_schema, document_class=Document
         )
 
     def max_marginal_relevance_search(
@@ -723,7 +723,7 @@ def max_marginal_relevance_search_by_vector(
         ignore_cols: List = [embedding_column] if embedding_column not in self._columns else []
         candidates = parse_vector_search_response(
             search_resp,
-            self._retriever_schema,
+            retriever_schema=self._retriever_schema,
             ignore_cols=ignore_cols,
             document_class=Document,
         )
diff --git a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py
index c7ae3243..f9b86769 100644
--- a/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py
+++ b/integrations/llamaindex/src/databricks_llamaindex/vector_search_retriever_tool.py
@@ -127,7 +127,7 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
             )
             search_resp = self._index.similarity_search(**kwargs)
             return parse_vector_search_response(
-                search_resp, self._retriever_schema, document_class=dict
+                search_resp, retriever_schema=self._retriever_schema, document_class=dict
             )
 
         # Create tool metadata
diff --git a/src/databricks_ai_bridge/utils/vector_search.py b/src/databricks_ai_bridge/utils/vector_search.py
index 9fba15c1..9137682e 100644
--- a/src/databricks_ai_bridge/utils/vector_search.py
+++ b/src/databricks_ai_bridge/utils/vector_search.py
@@ -87,32 +87,40 @@ def get_metadata(columns: List[str], result: List[Any], retriever_schema, ignore
     """
 
     metadata = {}
-    # Skipping the last column, which is always the score
-    for col, value in zip(columns[:-1], result[:-1]):
-        if col == retriever_schema.doc_uri:
-            metadata["doc_uri"] = value
-        elif col == retriever_schema.primary_key:
-            metadata["chunk_id"] = value
-        elif col == "doc_uri" and retriever_schema.doc_uri:
-            # Prioritize retriever_schema.doc_uri, don't override with the actual "doc_uri" column
-            continue
- elif col == "chunk_id" and retriever_schema.primary_key: - # Prioritize retriever_schema.primary_key, don't override with the actual "chunk_id" column - continue - elif col in ignore_cols: - # ignore_cols has precedence over other_columns - continue - elif retriever_schema.other_columns is not None: - if col in retriever_schema.other_columns: + if retriever_schema: + # Skipping the last column, which is always the score + for col, value in zip(columns[:-1], result[:-1]): + if col == retriever_schema.doc_uri: + metadata["doc_uri"] = value + elif col == retriever_schema.primary_key: + metadata["chunk_id"] = value + elif col == "doc_uri" and retriever_schema.doc_uri: + # Prioritize retriever_schema.doc_uri, don't override with the actual "doc_uri" column + continue + elif col == "chunk_id" and retriever_schema.primary_key: + # Prioritize retriever_schema.primary_key, don't override with the actual "chunk_id" column + continue + elif col in ignore_cols: + # ignore_cols has precedence over other_columns + continue + elif retriever_schema.other_columns is not None: + if col in retriever_schema.other_columns: + metadata[col] = value + else: + metadata[col] = value + else: + for col, value in zip(columns[:-1], result[:-1]): + if col not in ignore_cols: metadata[col] = value - else: - metadata[col] = value return metadata def parse_vector_search_response( search_resp: Dict, - retriever_schema: RetrieverSchema, + index_details: IndexDetails = None, # deprecated + text_column: str = None, # deprecated + *, + retriever_schema: RetrieverSchema = None, ignore_cols: Optional[List[str]] = None, document_class: Any = dict, ) -> List[Tuple[Dict, float]]: @@ -123,7 +131,8 @@ def parse_vector_search_response( if ignore_cols is None: ignore_cols = [] - text_column = retriever_schema.text_column + if retriever_schema: + text_column = retriever_schema.text_column ignore_cols.extend([text_column]) columns = [col["name"] for col in search_resp.get("manifest", dict()).get("columns", [])] @@ -161,8 +170,8 @@ def validate_and_get_return_columns( columns: List[str], text_column: str, index_details: IndexDetails, - doc_uri: str, - primary_key: str, + doc_uri: str = None, + primary_key: str = None, ) -> List[str]: """ Get a list of columns to retrieve from the index. 
diff --git a/src/databricks_ai_bridge/vector_search_retriever_tool.py b/src/databricks_ai_bridge/vector_search_retriever_tool.py
index a289ef6a..8fab74cd 100644
--- a/src/databricks_ai_bridge/vector_search_retriever_tool.py
+++ b/src/databricks_ai_bridge/vector_search_retriever_tool.py
@@ -143,16 +143,20 @@ def _get_default_tool_description(self, index_details: IndexDetails) -> str:
         return description
 
     def _get_resources(
-        self, index_name: str, embedding_endpoint: str, index_details: IndexDetails
+        self, index_name: str, embedding_endpoint: str, index_details: IndexDetails = None
     ) -> List[Resource]:
         resources = []
         if index_name:
             resources.append(DatabricksVectorSearchIndex(index_name=index_name))
         if embedding_endpoint:
             resources.append(DatabricksServingEndpoint(endpoint_name=embedding_endpoint))
-        if index_details.is_databricks_managed_embeddings and (
-            managed_embedding := index_details.embedding_source_column.get(
-                "embedding_model_endpoint_name", None
+        if (
+            index_details
+            and index_details.is_databricks_managed_embeddings
+            and (
+                managed_embedding := index_details.embedding_source_column.get(
+                    "embedding_model_endpoint_name", None
+                )
             )
         ):
             if managed_embedding != embedding_endpoint:
diff --git a/tests/databricks_ai_bridge/utils/test_vector_search.py b/tests/databricks_ai_bridge/utils/test_vector_search.py
index e22b79da..399eef64 100644
--- a/tests/databricks_ai_bridge/utils/test_vector_search.py
+++ b/tests/databricks_ai_bridge/utils/test_vector_search.py
@@ -111,5 +111,19 @@ def make_document(row_index: int, score: float):
 )
 def test_parse_vector_search_response(retriever_schema, ignore_cols, docs_with_score):
     assert (
-        parse_vector_search_response(search_resp, retriever_schema, ignore_cols) == docs_with_score
+        parse_vector_search_response(
+            search_resp, retriever_schema=retriever_schema, ignore_cols=ignore_cols
+        )
+        == docs_with_score
+    )
+
+
+def test_parse_vector_search_response_without_retriever_schema():
+    assert (
+        parse_vector_search_response(search_resp, text_column="column_1", ignore_cols=["column_2"])
+        == construct_docs_with_score(
+            page_content_column="column_2",
+            column_3="column_3",
+            column_4="column_4",
+        )
     )
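A minimal usage sketch of the reworked parse_vector_search_response signature, not part of the diff itself: it shows the new keyword-only retriever_schema path used by the integrations above and the deprecated text_column path that the cross-version jobs and the new test exercise. The response layout and the RetrieverSchema constructor arguments below are assumptions inferred from the fields this diff reads (manifest.columns, text_column, doc_uri, primary_key), not verified against the full module.

# Hypothetical sketch; RetrieverSchema is assumed to be importable from this
# module, and fake_resp mimics the "manifest"/"result" shape the parser reads.
from databricks_ai_bridge.utils.vector_search import (
    RetrieverSchema,
    parse_vector_search_response,
)

# Columns end with the score, matching the "last column is always the score"
# convention in get_metadata.
fake_resp = {
    "manifest": {"columns": [{"name": "chunk_id"}, {"name": "text"}, {"name": "score"}]},
    "result": {"data_array": [["id-1", "some chunk text", 0.87]]},
}

# New-style call: retriever_schema is now keyword-only.
docs = parse_vector_search_response(
    fake_resp,
    retriever_schema=RetrieverSchema(text_column="text", primary_key="chunk_id"),
)

# Old-style call: callers from releases that predate RetrieverSchema can still
# pass text_column, as test_parse_vector_search_response_without_retriever_schema does.
legacy_docs = parse_vector_search_response(fake_resp, text_column="text")

If this reading is right, re-admitting the deprecated index_details/text_column parameters while forcing retriever_schema to be keyword-only is what lets the pinned databricks-ai-v0.2.0 through v0.4.0 integration packages in the new workflow jobs run against the current bridge without signature errors.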