From 52bca631b3cf4218c7d31f73e7e7c66e8b40cd73 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 26 May 2025 20:40:46 +0200 Subject: [PATCH 1/2] simplifications --- .../samples/concepts/memory/complex_memory.py | 10 +- python/samples/concepts/memory/utils.py | 2 +- .../third_party/postgres-memory.ipynb | 1417 ++++++++--------- .../connectors/memory/azure_ai_search.py | 18 +- .../connectors/memory/azure_cosmos_db.py | 9 +- .../connectors/memory/chroma.py | 9 +- .../connectors/memory/faiss.py | 14 +- .../connectors/memory/in_memory.py | 9 +- .../connectors/memory/mongodb.py | 8 +- .../connectors/memory/pinecone.py | 8 +- .../connectors/memory/postgres.py | 9 +- .../connectors/memory/qdrant.py | 9 +- .../connectors/memory/redis.py | 9 +- .../connectors/memory/sql_server.py | 10 +- .../connectors/memory/weaviate.py | 9 +- .../connectors/search/brave.py | 8 +- .../connectors/search/google.py | 8 +- python/semantic_kernel/data/__init__.py | 23 +- .../data/{definitions.py => _definitions.py} | 15 +- .../data/{search.py => _search.py} | 10 +- .../data/{vectors.py => _vectors.py} | 169 +- python/semantic_kernel/data/const.py | 130 -- python/tests/conftest.py | 2 +- .../memory/azure_cosmos_db/conftest.py | 2 +- .../test_azure_cosmos_db_no_sql.py | 2 +- .../memory/postgres/test_postgres_int.py | 5 +- .../unit/connectors/memory/test_faiss.py | 3 +- .../unit/connectors/memory/test_in_memory.py | 2 +- .../connectors/memory/test_postgres_store.py | 4 +- .../unit/connectors/memory/test_qdrant.py | 4 +- .../unit/connectors/memory/test_sql_server.py | 5 +- .../connectors/search/test_brave_search.py | 2 +- .../connectors/search/test_google_search.py | 2 +- python/tests/unit/data/conftest.py | 4 +- python/tests/unit/data/test_filter.py | 2 +- python/tests/unit/data/test_text_search.py | 12 +- .../unit/data/test_vector_search_base.py | 2 +- .../data/test_vector_store_model_decorator.py | 2 +- .../test_vector_store_record_collection.py | 2 +- 39 files changed, 992 insertions(+), 978 deletions(-) rename python/semantic_kernel/data/{definitions.py => _definitions.py} (98%) rename python/semantic_kernel/data/{search.py => _search.py} (98%) rename python/semantic_kernel/data/{vectors.py => _vectors.py} (90%) delete mode 100644 python/semantic_kernel/data/const.py diff --git a/python/samples/concepts/memory/complex_memory.py b/python/samples/concepts/memory/complex_memory.py index 16ebc3cf92af..43165aa4f707 100644 --- a/python/samples/concepts/memory/complex_memory.py +++ b/python/samples/concepts/memory/complex_memory.py @@ -27,8 +27,8 @@ WeaviateCollection, ) from semantic_kernel.data import VectorStoreRecordCollection, vectorstoremodel -from semantic_kernel.data.definitions import VectorStoreField -from semantic_kernel.data.vectors import SearchType, VectorSearch +from semantic_kernel.data._definitions import VectorStoreField +from semantic_kernel.data._vectors import SearchType, VectorSearch # This is a rather complex sample, showing how to use the vector store # with a number of different collections. 
@@ -47,7 +47,7 @@ class DataModel: content: Annotated[str, VectorStoreField("data", is_full_text_indexed=True)] embedding: Annotated[ str | None, - VectorStoreField("vector", dimensions=1536, type_="float"), + VectorStoreField("vector", dimensions=1536, type="float"), ] = None id: Annotated[ str, @@ -55,7 +55,7 @@ class DataModel: "key", ), ] = field(default_factory=lambda: str(uuid4())) - tag: Annotated[str | None, VectorStoreField("data", type_="str", is_indexed=True)] = None + tag: Annotated[str | None, VectorStoreField("data", type="str", is_indexed=True)] = None def __post_init__(self, **kwargs): if self.embedding is None: @@ -172,7 +172,7 @@ async def main(collection: str, use_azure_openai: bool): keys = await record_collection.upsert(records) print(f" Upserted {keys=}") print_with_color("Getting records!", Colors.CBLUE) - results = await record_collection.get(top=10, order_by={"field": "content"}) + results = await record_collection.get(top=10, order_by="content") if results: [print_record(record=result) for result in results] else: diff --git a/python/samples/concepts/memory/utils.py b/python/samples/concepts/memory/utils.py index 07e52ae75f14..ba2aa3c187df 100644 --- a/python/samples/concepts/memory/utils.py +++ b/python/samples/concepts/memory/utils.py @@ -3,7 +3,7 @@ from typing import TypeVar from samples.concepts.resources.utils import Colors, print_with_color -from semantic_kernel.data.vectors import VectorSearchResult +from semantic_kernel.data._vectors import VectorSearchResult _T = TypeVar("_T") diff --git a/python/samples/getting_started/third_party/postgres-memory.ipynb b/python/samples/getting_started/third_party/postgres-memory.ipynb index 4b0f6fbef9ce..0c1b77da84a6 100644 --- a/python/samples/getting_started/third_party/postgres-memory.ipynb +++ b/python/samples/getting_started/third_party/postgres-memory.ipynb @@ -1,710 +1,709 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Using Postgres as memory\n", - "\n", - "This notebook shows how to use Postgres as a memory store in Semantic Kernel.\n", - "\n", - "The code below pulls the most recent papers from [ArviX](https://arxiv.org/), creates embeddings from the paper abstracts, and stores them in a Postgres database.\n", - "\n", - "In the future, we can use the Postgres vector store to search the database for similar papers based on the embeddings - stay tuned!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import textwrap\n", - "import xml.etree.ElementTree as ET\n", - "from dataclasses import dataclass\n", - "from datetime import datetime\n", - "from typing import Annotated, Any\n", - "\n", - "import requests\n", - "\n", - "from semantic_kernel import Kernel\n", - "from semantic_kernel.connectors.ai import FunctionChoiceBehavior\n", - "from semantic_kernel.connectors.ai.open_ai import (\n", - " AzureChatCompletion,\n", - " AzureChatPromptExecutionSettings,\n", - " AzureTextEmbedding,\n", - " OpenAITextEmbedding,\n", - ")\n", - "from semantic_kernel.connectors.memory.postgres import PostgresCollection\n", - "from semantic_kernel.contents import ChatHistory\n", - "from semantic_kernel.data import (\n", - " DistanceFunction,\n", - " IndexKind,\n", - " VectorStoreField,\n", - " vectorstoremodel,\n", - ")\n", - "from semantic_kernel.functions import KernelParameterMetadata\n", - "from semantic_kernel.functions.kernel_arguments import KernelArguments" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set up your environment\n", - "\n", - "You'll need to set up your environment to provide connection information to Postgres, as well as OpenAI or Azure OpenAI.\n", - "\n", - "To do this, copy the `.env.example` file to `.env` and fill in the necessary information.\n", - "\n", - "__Note__: If you're using VSCode to execute the notebook, the settings in `.env` in the root of the repository will be picked up automatically.\n", - "\n", - "### Postgres configuration\n", - "\n", - "You'll need to provide a connection string to a Postgres database. You can use a local Postgres instance, or a cloud-hosted one.\n", - "You can provide a connection string, or provide environment variables with the connection information. See the .env.example file for `POSTGRES_` settings.\n", - "\n", - "#### Using Docker\n", - "\n", - "You can also use docker to bring up a Postgres instance by following the steps below:\n", - "\n", - "Create an `init.sql` that has the following:\n", - "\n", - "```sql\n", - "CREATE EXTENSION IF NOT EXISTS vector;\n", - "```\n", - "\n", - "Now you can start a postgres instance with the following:\n", - "\n", - "```\n", - "docker pull pgvector/pgvector:pg16\n", - "docker run --rm -it --name pgvector -p 5432:5432 -v ./init.sql:/docker-entrypoint-initdb.d/init.sql -e POSTGRES_PASSWORD=example pgvector/pgvector:pg16\n", - "```\n", - "\n", - "_Note_: Use `.\\init.sql` on Windows and `./init.sql` on WSL or Linux/Mac.\n", - "\n", - "Then you could use the connection string:\n", - "\n", - "```\n", - "POSTGRES_CONNECTION_STRING=\"host=localhost port=5432 dbname=postgres user=postgres password=example\"\n", - "```\n", - "\n", - "### OpenAI configuration\n", - "\n", - "You can either use OpenAI or Azure OpenAI APIs. You provide the API key and other configuration in the `.env` file. Set either the `OPENAI_` or `AZURE_OPENAI_` settings.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# Path to the environment file\n", - "env_file_path = \".env\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we set some additional configuration." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# -- ArXiv settings --\n", - "\n", - "# The search term to use when searching for papers on arXiv. 
All metadata fields for the papers are searched.\n", - "SEARCH_TERM = \"RAG\"\n", - "\n", - "# The category of papers to search for on arXiv. See https://arxiv.org/category_taxonomy for a list of categories.\n", - "ARVIX_CATEGORY = \"cs.AI\"\n", - "\n", - "# The maximum number of papers to search for on arXiv.\n", - "MAX_RESULTS = 300\n", - "\n", - "# -- OpenAI settings --\n", - "\n", - "# Set this flag to False to use the OpenAI API instead of Azure OpenAI\n", - "USE_AZURE_OPENAI = True" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we define a vector store model. This model defines the table and column names for storing the embeddings. We use the `@vectorstoremodel` decorator to tell Semantic Kernel to create a vector store definition from the model. The VectorStoreRecordField annotations define the fields that will be stored in the database, including key and vector fields." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@vectorstoremodel\n", - "@dataclass\n", - "class ArxivPaper:\n", - " id: Annotated[str, VectorStoreField(\"key\")]\n", - " title: Annotated[str, VectorStoreField(\"data\")]\n", - " abstract: Annotated[str, VectorStoreField(\"data\")]\n", - " published: Annotated[datetime, VectorStoreField(\"data\")]\n", - " authors: Annotated[list[str], VectorStoreField(\"data\")]\n", - " link: Annotated[str | None, VectorStoreField(\"data\")]\n", - " abstract_vector: Annotated[\n", - " list[float] | str | None,\n", - " VectorStoreField(\n", - " \"vector\",\n", - " index_kind=IndexKind.HNSW,\n", - " dimensions=1536,\n", - " distance_function=DistanceFunction.COSINE_DISTANCE,\n", - " ),\n", - " ] = None\n", - "\n", - " def __post_init__(self):\n", - " if self.abstract_vector is None:\n", - " self.abstract_vector = self.abstract\n", - "\n", - " @classmethod\n", - " def from_arxiv_info(cls, arxiv_info: dict[str, Any]) -> \"ArxivPaper\":\n", - " return cls(\n", - " id=arxiv_info[\"id\"],\n", - " title=arxiv_info[\"title\"].replace(\"\\n \", \" \"),\n", - " abstract=arxiv_info[\"abstract\"].replace(\"\\n \", \" \"),\n", - " published=arxiv_info[\"published\"],\n", - " authors=arxiv_info[\"authors\"],\n", - " link=arxiv_info[\"link\"],\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Below is a function that queries the ArviX API for the most recent papers based on our search query and category." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "def query_arxiv(search_query: str, category: str = \"cs.AI\", max_results: int = 10) -> list[dict[str, Any]]:\n", - " \"\"\"\n", - " Query the ArXiv API and return a list of dictionaries with relevant metadata for each paper.\n", - "\n", - " Args:\n", - " search_query: The search term or topic to query for.\n", - " category: The category to restrict the search to (default is \"cs.AI\").\n", - " See https://arxiv.org/category_taxonomy for a list of categories.\n", - " max_results: Maximum number of results to retrieve (default is 10).\n", - " \"\"\"\n", - " response = requests.get(\n", - " \"http://export.arxiv.org/api/query?\"\n", - " f\"search_query=all:%22{search_query.replace(' ', '+')}%22\"\n", - " f\"+AND+cat:{category}&start=0&max_results={max_results}&sortBy=lastUpdatedDate&sortOrder=descending\"\n", - " )\n", - "\n", - " root = ET.fromstring(response.content)\n", - " ns = {\"atom\": \"http://www.w3.org/2005/Atom\"}\n", - "\n", - " return [\n", - " {\n", - " \"id\": entry.find(\"atom:id\", ns).text.split(\"/\")[-1],\n", - " \"title\": entry.find(\"atom:title\", ns).text,\n", - " \"abstract\": entry.find(\"atom:summary\", ns).text,\n", - " \"published\": entry.find(\"atom:published\", ns).text,\n", - " \"link\": entry.find(\"atom:id\", ns).text,\n", - " \"authors\": [author.find(\"atom:name\", ns).text for author in entry.findall(\"atom:author\", ns)],\n", - " \"categories\": [category.get(\"term\") for category in entry.findall(\"atom:category\", ns)],\n", - " \"pdf_link\": next(\n", - " (link_tag.get(\"href\") for link_tag in entry.findall(\"atom:link\", ns) if link_tag.get(\"title\") == \"pdf\"),\n", - " None,\n", - " ),\n", - " }\n", - " for entry in root.findall(\"atom:entry\", ns)\n", - " ]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We use this function to query papers and store them in memory as our model types." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 300 papers on 'RAG'\n" - ] - } - ], - "source": [ - "arxiv_papers: list[ArxivPaper] = [\n", - " ArxivPaper.from_arxiv_info(paper)\n", - " for paper in query_arxiv(SEARCH_TERM, category=ARVIX_CATEGORY, max_results=MAX_RESULTS)\n", - "]\n", - "\n", - "print(f\"Found {len(arxiv_papers)} papers on '{SEARCH_TERM}'\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create a `PostgresCollection`, which represents the table in Postgres where we will store the paper information and embeddings." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if USE_AZURE_OPENAI:\n", - " text_embedding = AzureTextEmbedding(service_id=\"embedding\", env_file_path=env_file_path)\n", - "else:\n", - " text_embedding = OpenAITextEmbedding(service_id=\"embedding\", env_file_path=env_file_path)\n", - "collection = PostgresCollection[str, ArxivPaper](\n", - " collection_name=\"arxiv_records\",\n", - " record_type=ArxivPaper,\n", - " env_file_path=env_file_path,\n", - " embedding_generator=text_embedding,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now that the models have embeddings, we can write them into the Postgres database." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async with collection:\n", - " await collection.ensure_collection_exists()\n", - " keys = await collection.upsert(arxiv_papers)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we retrieve the first few models from the database and print out their information." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "# Engineering LLM Powered Multi-agent Framework for Autonomous CloudOps\n", - "\n", - "Abstract: Cloud Operations (CloudOps) is a rapidly growing field focused on the\n", - "automated management and optimization of cloud infrastructure which is essential\n", - "for organizations navigating increasingly complex cloud environments. MontyCloud\n", - "Inc. is one of the major companies in the CloudOps domain that leverages\n", - "autonomous bots to manage cloud compliance, security, and continuous operations.\n", - "To make the platform more accessible and effective to the customers, we\n", - "leveraged the use of GenAI. Developing a GenAI-based solution for autonomous\n", - "CloudOps for the existing MontyCloud system presented us with various challenges\n", - "such as i) diverse data sources; ii) orchestration of multiple processes; and\n", - "iii) handling complex workflows to automate routine tasks. To this end, we\n", - "developed MOYA, a multi-agent framework that leverages GenAI and balances\n", - "autonomy with the necessary human control. This framework integrates various\n", - "internal and external systems and is optimized for factors like task\n", - "orchestration, security, and error mitigation while producing accurate,\n", - "reliable, and relevant insights by utilizing Retrieval Augmented Generation\n", - "(RAG). Evaluations of our multi-agent system with the help of practitioners as\n", - "well as using automated checks demonstrate enhanced accuracy, responsiveness,\n", - "and effectiveness over non-agentic approaches across complex workflows.\n", - "Published: 2025-01-14 16:30:10\n", - "Link: http://arxiv.org/abs/2501.08243v1\n", - "PDF Link: http://arxiv.org/abs/2501.08243v1\n", - "Authors: Kannan Parthasarathy, Karthik Vaidhyanathan, Rudra Dhar, Venkat Krishnamachari, Basil Muhammed, Adyansh Kakran, Sreemaee Akshathala, Shrikara Arun, Sumant Dubey, Mohan Veerubhotla, Amey Karan\n", - "Embedding: [ 0.01063822 0.02977918 0.04532182 ... -0.00264323 0.00081101\n", - " 0.01491571]\n", - "\n", - "\n", - "# Eliciting In-context Retrieval and Reasoning for Long-context Large Language Models\n", - "\n", - "Abstract: Recent advancements in long-context language models (LCLMs) promise to\n", - "transform Retrieval-Augmented Generation (RAG) by simplifying pipelines. With\n", - "their expanded context windows, LCLMs can process entire knowledge bases and\n", - "perform retrieval and reasoning directly -- a capability we define as In-Context\n", - "Retrieval and Reasoning (ICR^2). However, existing benchmarks like LOFT often\n", - "overestimate LCLM performance by providing overly simplified contexts. To\n", - "address this, we introduce ICR^2, a benchmark that evaluates LCLMs in more\n", - "realistic scenarios by including confounding passages retrieved with strong\n", - "retrievers. 
We then propose three methods to enhance LCLM performance: (1)\n", - "retrieve-then-generate fine-tuning, (2) retrieval-attention-probing, which uses\n", - "attention heads to filter and de-noise long contexts during decoding, and (3)\n", - "joint retrieval head training alongside the generation head. Our evaluation of\n", - "five well-known LCLMs on LOFT and ICR^2 demonstrates significant gains with our\n", - "best approach applied to Mistral-7B: +17 and +15 points by Exact Match on LOFT,\n", - "and +13 and +2 points on ICR^2, compared to vanilla RAG and supervised fine-\n", - "tuning, respectively. It even outperforms GPT-4-Turbo on most tasks despite\n", - "being a much smaller model.\n", - "Published: 2025-01-14 16:38:33\n", - "Link: http://arxiv.org/abs/2501.08248v1\n", - "PDF Link: http://arxiv.org/abs/2501.08248v1\n", - "Authors: Yifu Qiu, Varun Embar, Yizhe Zhang, Navdeep Jaitly, Shay B. Cohen, Benjamin Han\n", - "Embedding: [-0.01305697 0.01166064 0.06267344 ... -0.01627254 0.00974741\n", - " -0.00573298]\n", - "\n", - "\n", - "# ADAM-1: AI and Bioinformatics for Alzheimer's Detection and Microbiome-Clinical Data Integrations\n", - "\n", - "Abstract: The Alzheimer's Disease Analysis Model Generation 1 (ADAM) is a multi-agent\n", - "large language model (LLM) framework designed to integrate and analyze multi-\n", - "modal data, including microbiome profiles, clinical datasets, and external\n", - "knowledge bases, to enhance the understanding and detection of Alzheimer's\n", - "disease (AD). By leveraging retrieval-augmented generation (RAG) techniques\n", - "along with its multi-agent architecture, ADAM-1 synthesizes insights from\n", - "diverse data sources and contextualizes findings using literature-driven\n", - "evidence. Comparative evaluation against XGBoost revealed similar mean F1 scores\n", - "but significantly reduced variance for ADAM-1, highlighting its robustness and\n", - "consistency, particularly in small laboratory datasets. While currently tailored\n", - "for binary classification tasks, future iterations aim to incorporate additional\n", - "data modalities, such as neuroimaging and biomarkers, to broaden the scalability\n", - "and applicability for Alzheimer's research and diagnostics.\n", - "Published: 2025-01-14 18:56:33\n", - "Link: http://arxiv.org/abs/2501.08324v1\n", - "PDF Link: http://arxiv.org/abs/2501.08324v1\n", - "Authors: Ziyuan Huang, Vishaldeep Kaur Sekhon, Ouyang Guo, Mark Newman, Roozbeh Sadeghian, Maria L. Vaida, Cynthia Jo, Doyle Ward, Vanni Bucci, John P. Haran\n", - "Embedding: [ 0.03896349 0.00422515 0.05525447 ... 
0.03374933 -0.01468264\n", - " 0.01850895]\n", - "\n", - "\n" - ] - } - ], - "source": [ - "async with collection:\n", - " results = await collection.get(keys[:3])\n", - " if results:\n", - " for result in results:\n", - " print(f\"# {result.title}\")\n", - " print()\n", - " wrapped_abstract = textwrap.fill(result.abstract, width=80)\n", - " print(f\"Abstract: {wrapped_abstract}\")\n", - " print(f\"Published: {result.published}\")\n", - " print(f\"Link: {result.link}\")\n", - " print(f\"PDF Link: {result.link}\")\n", - " print(f\"Authors: {', '.join(result.authors)}\")\n", - " print(f\"Embedding: {result.abstract_vector}\")\n", - " print()\n", - " print()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `VectorStoreTextSearch` object gives us the ability to retrieve semantically similar documents directly from a prompt.\n", - "Here we search for the top 5 ArXiV abstracts in our database similar to the query about chunking strategies in RAG applications:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 5 results for query.\n", - "Advanced ingestion process powered by LLM parsing for RAG system: 0.38676463602221456\n", - "StructRAG: Boosting Knowledge Intensive Reasoning of LLMs via Inference-time Hybrid Information Structurization: 0.39733734194342085\n", - "UDA: A Benchmark Suite for Retrieval Augmented Generation in Real-world Document Analysis: 0.3981809737466562\n", - "R^2AG: Incorporating Retrieval Information into Retrieval Augmented Generation: 0.4134050114864055\n", - "Enhancing Retrieval-Augmented Generation: A Study of Best Practices: 0.4144733752075731\n" - ] - } - ], - "source": [ - "query = \"What are good chunking strategies to use for unstructured text in Retrieval-Augmented Generation applications?\"\n", - "\n", - "async with collection:\n", - " search_results = await collection.search(query, top=5, include_total_count=True)\n", - " print(f\"Found {search_results.total_count} results for query.\")\n", - " async for search_result in search_results.results:\n", - " title = search_result.record.title\n", - " score = search_result.score\n", - " print(f\"{title}: {score}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can enable chat completion to utilize the text search by creating a kernel function for searching the database..." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "kernel = Kernel()\n", - "plugin = kernel.add_functions(\n", - " plugin_name=\"arxiv_plugin\",\n", - " functions=[\n", - " collection.create_search_function(\n", - " # The default parameters match the parameters of the VectorSearchOptions class.\n", - " description=\"Searches for ArXiv papers that are related to the query.\",\n", - " parameters=[\n", - " KernelParameterMetadata(\n", - " name=\"query\", description=\"What to search for.\", type=\"str\", is_required=True, type_object=str\n", - " ),\n", - " KernelParameterMetadata(\n", - " name=\"top\",\n", - " description=\"Number of results to return.\",\n", - " type=\"int\",\n", - " default_value=2,\n", - " type_object=int,\n", - " ),\n", - " ],\n", - " ),\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "...and then setting up a chat completions service that uses `FunctionChoiceBehavior.Auto` to automatically call the search function when appropriate to the users query. We also create the chat function that will be invoked by the kernel." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# Create the chat completion service. This requires an Azure OpenAI completions model deployment and configuration.\n", - "chat_completion = AzureChatCompletion(service_id=\"completions\")\n", - "kernel.add_service(chat_completion)\n", - "\n", - "# Now we create the chat function that will use the chat service.\n", - "chat_function = kernel.add_function(\n", - " prompt=\"{{$chat_history}}{{$user_input}}\",\n", - " plugin_name=\"ChatBot\",\n", - " function_name=\"Chat\",\n", - ")\n", - "\n", - "# we set the function choice to Auto, so that the LLM can choose the correct function to call.\n", - "# and we exclude the ChatBot plugin, so that it does not call itself.\n", - "execution_settings = AzureChatPromptExecutionSettings(\n", - " function_choice_behavior=FunctionChoiceBehavior.Auto(filters={\"excluded_plugins\": [\"ChatBot\"]}),\n", - " service_id=\"chat\",\n", - " max_tokens=7000,\n", - " temperature=0.7,\n", - " top_p=0.8,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we create a chat history with a system message and some initial context:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "history = ChatHistory()\n", - "system_message = \"\"\"\n", - "You are a chat bot. Your name is Archie and\n", - "you have one goal: help people find answers\n", - "to technical questions by relying on the latest\n", - "research papers published on ArXiv.\n", - "You communicate effectively in the style of a helpful librarian. \n", - "You always make sure to include the\n", - "ArXiV paper references in your responses.\n", - "If you cannot find the answer in the papers,\n", - "you will let the user know, but also provide the papers\n", - "you did find to be most relevant. If the abstract of the \n", - "paper does not specifically reference the user's inquiry,\n", - "but you believe it might be relevant, you can still include it\n", - "BUT you must make sure to mention that the paper might not directly\n", - "address the user's inquiry. 
Make certain that the papers you link are\n", - "from a specific search result.\n", - "\"\"\"\n", - "history.add_system_message(system_message)\n", - "history.add_user_message(\"Hi there, who are you?\")\n", - "history.add_assistant_message(\n", - " \"I am Archie, the ArXiV chat bot. I'm here to help you find the latest research papers from ArXiv that relate to your inquiries.\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now invoke the chat function via the Kernel to get chat completions:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "arguments = KernelArguments(\n", - " user_input=query,\n", - " chat_history=history,\n", - " settings=execution_settings,\n", - ")\n", - "\n", - "result = await kernel.invoke(chat_function, arguments=arguments)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Printing the result shows that the chat completion service used our text search to locate relevant ArXiV papers based on the query:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Archie:>\n", - "What an excellent and timely question! Chunking strategies for unstructured text are\n", - "critical for optimizing Retrieval-Augmented Generation (RAG) systems since they\n", - "significantly affect how effectively a RAG model can retrieve and generate contextually\n", - "relevant information. Let me consult the latest papers on this topic from ArXiv and\n", - "provide you with relevant insights.\n", - "---\n", - "Here are some recent papers that dive into chunking strategies or similar concepts for\n", - "retrieval-augmented frameworks:\n", - "1. **\"Post-training optimization of retrieval-augmented generation models\"**\n", - " *Authors*: Vibhor Agarwal et al.\n", - " *Abstract*: While the paper discusses optimization strategies for retrieval-augmented\n", - "generation models, there is a discussion on handling unstructured text that could apply to\n", - "chunking methodologies. Chunking isn't always explicitly mentioned as \"chunking\" but may\n", - "be referred to in contexts like splitting data for retrieval.\n", - " *ArXiv link*: [arXiv:2308.10701](https://arxiv.org/abs/2308.10701)\n", - " *Note*: This paper may not focus entirely on chunking strategies but might discuss\n", - "relevant downstream considerations. It could still provide a foundation for you to explore\n", - "how chunking integrates with retrievers.\n", - "2. **\"Beyond Text: Retrieval-Augmented Reranking for Open-Domain Tasks\"**\n", - " *Authors*: Younggyo Seo et al.\n", - " *Abstract*: Although primarily focused on retrieval augmentation for reranking, there\n", - "are reflections on how document structure impacts task performance. Chunking unstructured\n", - "text to improve retrievability for such tasks could indirectly relate to this work.\n", - " *ArXiv link*: [arXiv:2310.03714](https://arxiv.org/abs/2310.03714)\n", - "3. **\"ALMA: Alignment of Generative and Retrieval Models for Long Documents\"**\n", - " *Authors*: Yao Fu et al.\n", - " *Abstract excerpt*: \"Our approach is designed to handle retrieval and generation for\n", - "long documents by aligning the retrieval and generation models more effectively.\"\n", - "Strategies to divide and process long documents into smaller chunks for efficient\n", - "alignment are explicitly discussed. 
A focus on handling unstructured long-form content\n", - "makes this paper highly relevant.\n", - " *ArXiv link*: [arXiv:2308.05467](https://arxiv.org/abs/2308.05467)\n", - "4. **\"Enhancing Context-aware Question Generation with Multi-modal Knowledge\"**\n", - " *Authors*: Jialong Han et al.\n", - " *Abstract excerpt*: \"Proposed techniques focus on improving retrievals through better\n", - "division of available knowledge.\" It doesn’t focus solely on text chunking in the RAG\n", - "framework but might be interesting since contextual awareness often relates to\n", - "preprocessing unstructured input into structured chunks.\n", - " *ArXiv link*: [arXiv:2307.12345](https://arxiv.org/abs/2307.12345)\n", - "---\n", - "### Practical Approaches Discussed in Literature:\n", - "From my broad understanding of RAG systems and some of the details in these papers, here\n", - "are common chunking strategies discussed in the research community:\n", - "1. **Sliding Window Approach**: Divide the text into overlapping chunks of fixed lengths\n", - "(e.g., 512 tokens with an overlap of 128 tokens). This helps ensure no important context\n", - "is left behind when chunks are created.\n", - "\n", - "2. **Semantic Chunking**: Use sentence embeddings or clustering techniques (e.g., via Bi-\n", - "Encoders or Sentence Transformers) to ensure chunks align semantically rather than naively\n", - "by token count.\n", - "3. **Dynamic Partitioning**: Implement chunking based on higher-order structure in the\n", - "text, such as splitting at sentence boundaries, paragraph breaks, or logical sections.\n", - "4. **Content-aware Chunking**: Experiment with LLMs to pre-identify contextual relevance\n", - "of different parts of the text and chunk accordingly.\n", - "---\n", - "If you'd like, I can search more specifically on a sub-part of chunking strategies or\n", - "related RAG optimizations. Let me know!\n" - ] - } - ], - "source": [ - "def wrap_text(text, width=90):\n", - " paragraphs = text.split(\"\\n\\n\") # Split the text into paragraphs\n", - " wrapped_paragraphs = [\n", - " \"\\n\".join(textwrap.fill(part, width=width) for paragraph in paragraphs for part in paragraph.split(\"\\n\"))\n", - " ] # Wrap each paragraph, split by newlines\n", - " return \"\\n\\n\".join(wrapped_paragraphs) # Join the wrapped paragraphs back together\n", - "\n", - "\n", - "print(f\"Archie:>\\n{wrap_text(str(result))}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.15" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} \ No newline at end of file + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using Postgres as memory\n", + "\n", + "This notebook shows how to use Postgres as a memory store in Semantic Kernel.\n", + "\n", + "The code below pulls the most recent papers from [ArviX](https://arxiv.org/), creates embeddings from the paper abstracts, and stores them in a Postgres database.\n", + "\n", + "In the future, we can use the Postgres vector store to search the database for similar papers based on the embeddings - stay tuned!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import textwrap\n", + "import xml.etree.ElementTree as ET\n", + "from dataclasses import dataclass\n", + "from datetime import datetime\n", + "from typing import Annotated, Any\n", + "\n", + "import requests\n", + "\n", + "from semantic_kernel import Kernel\n", + "from semantic_kernel.connectors.ai import FunctionChoiceBehavior\n", + "from semantic_kernel.connectors.ai.open_ai import (\n", + " AzureChatCompletion,\n", + " AzureChatPromptExecutionSettings,\n", + " AzureTextEmbedding,\n", + " OpenAITextEmbedding,\n", + ")\n", + "from semantic_kernel.connectors.memory.postgres import PostgresCollection\n", + "from semantic_kernel.contents import ChatHistory\n", + "from semantic_kernel.data import (\n", + " VectorStoreField,\n", + " vectorstoremodel,\n", + ")\n", + "from semantic_kernel.data._vectors import DistanceFunction, IndexKind\n", + "from semantic_kernel.functions import KernelParameterMetadata\n", + "from semantic_kernel.functions.kernel_arguments import KernelArguments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up your environment\n", + "\n", + "You'll need to set up your environment to provide connection information to Postgres, as well as OpenAI or Azure OpenAI.\n", + "\n", + "To do this, copy the `.env.example` file to `.env` and fill in the necessary information.\n", + "\n", + "__Note__: If you're using VSCode to execute the notebook, the settings in `.env` in the root of the repository will be picked up automatically.\n", + "\n", + "### Postgres configuration\n", + "\n", + "You'll need to provide a connection string to a Postgres database. You can use a local Postgres instance, or a cloud-hosted one.\n", + "You can provide a connection string, or provide environment variables with the connection information. See the .env.example file for `POSTGRES_` settings.\n", + "\n", + "#### Using Docker\n", + "\n", + "You can also use docker to bring up a Postgres instance by following the steps below:\n", + "\n", + "Create an `init.sql` that has the following:\n", + "\n", + "```sql\n", + "CREATE EXTENSION IF NOT EXISTS vector;\n", + "```\n", + "\n", + "Now you can start a postgres instance with the following:\n", + "\n", + "```\n", + "docker pull pgvector/pgvector:pg16\n", + "docker run --rm -it --name pgvector -p 5432:5432 -v ./init.sql:/docker-entrypoint-initdb.d/init.sql -e POSTGRES_PASSWORD=example pgvector/pgvector:pg16\n", + "```\n", + "\n", + "_Note_: Use `.\\init.sql` on Windows and `./init.sql` on WSL or Linux/Mac.\n", + "\n", + "Then you could use the connection string:\n", + "\n", + "```\n", + "POSTGRES_CONNECTION_STRING=\"host=localhost port=5432 dbname=postgres user=postgres password=example\"\n", + "```\n", + "\n", + "### OpenAI configuration\n", + "\n", + "You can either use OpenAI or Azure OpenAI APIs. You provide the API key and other configuration in the `.env` file. Set either the `OPENAI_` or `AZURE_OPENAI_` settings.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Path to the environment file\n", + "env_file_path = \".env\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we set some additional configuration." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# -- ArXiv settings --\n", + "\n", + "# The search term to use when searching for papers on arXiv. 
All metadata fields for the papers are searched.\n", + "SEARCH_TERM = \"RAG\"\n", + "\n", + "# The category of papers to search for on arXiv. See https://arxiv.org/category_taxonomy for a list of categories.\n", + "ARVIX_CATEGORY = \"cs.AI\"\n", + "\n", + "# The maximum number of papers to search for on arXiv.\n", + "MAX_RESULTS = 300\n", + "\n", + "# -- OpenAI settings --\n", + "\n", + "# Set this flag to False to use the OpenAI API instead of Azure OpenAI\n", + "USE_AZURE_OPENAI = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we define a vector store model. This model defines the table and column names for storing the embeddings. We use the `@vectorstoremodel` decorator to tell Semantic Kernel to create a vector store definition from the model. The VectorStoreRecordField annotations define the fields that will be stored in the database, including key and vector fields." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@vectorstoremodel\n", + "@dataclass\n", + "class ArxivPaper:\n", + " id: Annotated[str, VectorStoreField(\"key\")]\n", + " title: Annotated[str, VectorStoreField(\"data\")]\n", + " abstract: Annotated[str, VectorStoreField(\"data\")]\n", + " published: Annotated[datetime, VectorStoreField(\"data\")]\n", + " authors: Annotated[list[str], VectorStoreField(\"data\")]\n", + " link: Annotated[str | None, VectorStoreField(\"data\")]\n", + " abstract_vector: Annotated[\n", + " list[float] | str | None,\n", + " VectorStoreField(\n", + " \"vector\",\n", + " index_kind=IndexKind.HNSW,\n", + " dimensions=1536,\n", + " distance_function=DistanceFunction.COSINE_DISTANCE,\n", + " ),\n", + " ] = None\n", + "\n", + " def __post_init__(self):\n", + " if self.abstract_vector is None:\n", + " self.abstract_vector = self.abstract\n", + "\n", + " @classmethod\n", + " def from_arxiv_info(cls, arxiv_info: dict[str, Any]) -> \"ArxivPaper\":\n", + " return cls(\n", + " id=arxiv_info[\"id\"],\n", + " title=arxiv_info[\"title\"].replace(\"\\n \", \" \"),\n", + " abstract=arxiv_info[\"abstract\"].replace(\"\\n \", \" \"),\n", + " published=arxiv_info[\"published\"],\n", + " authors=arxiv_info[\"authors\"],\n", + " link=arxiv_info[\"link\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below is a function that queries the ArviX API for the most recent papers based on our search query and category." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def query_arxiv(search_query: str, category: str = \"cs.AI\", max_results: int = 10) -> list[dict[str, Any]]:\n", + " \"\"\"\n", + " Query the ArXiv API and return a list of dictionaries with relevant metadata for each paper.\n", + "\n", + " Args:\n", + " search_query: The search term or topic to query for.\n", + " category: The category to restrict the search to (default is \"cs.AI\").\n", + " See https://arxiv.org/category_taxonomy for a list of categories.\n", + " max_results: Maximum number of results to retrieve (default is 10).\n", + " \"\"\"\n", + " response = requests.get(\n", + " \"http://export.arxiv.org/api/query?\"\n", + " f\"search_query=all:%22{search_query.replace(' ', '+')}%22\"\n", + " f\"+AND+cat:{category}&start=0&max_results={max_results}&sortBy=lastUpdatedDate&sortOrder=descending\"\n", + " )\n", + "\n", + " root = ET.fromstring(response.content)\n", + " ns = {\"atom\": \"http://www.w3.org/2005/Atom\"}\n", + "\n", + " return [\n", + " {\n", + " \"id\": entry.find(\"atom:id\", ns).text.split(\"/\")[-1],\n", + " \"title\": entry.find(\"atom:title\", ns).text,\n", + " \"abstract\": entry.find(\"atom:summary\", ns).text,\n", + " \"published\": entry.find(\"atom:published\", ns).text,\n", + " \"link\": entry.find(\"atom:id\", ns).text,\n", + " \"authors\": [author.find(\"atom:name\", ns).text for author in entry.findall(\"atom:author\", ns)],\n", + " \"categories\": [category.get(\"term\") for category in entry.findall(\"atom:category\", ns)],\n", + " \"pdf_link\": next(\n", + " (link_tag.get(\"href\") for link_tag in entry.findall(\"atom:link\", ns) if link_tag.get(\"title\") == \"pdf\"),\n", + " None,\n", + " ),\n", + " }\n", + " for entry in root.findall(\"atom:entry\", ns)\n", + " ]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use this function to query papers and store them in memory as our model types." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 300 papers on 'RAG'\n" + ] + } + ], + "source": [ + "arxiv_papers: list[ArxivPaper] = [\n", + " ArxivPaper.from_arxiv_info(paper)\n", + " for paper in query_arxiv(SEARCH_TERM, category=ARVIX_CATEGORY, max_results=MAX_RESULTS)\n", + "]\n", + "\n", + "print(f\"Found {len(arxiv_papers)} papers on '{SEARCH_TERM}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a `PostgresCollection`, which represents the table in Postgres where we will store the paper information and embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if USE_AZURE_OPENAI:\n", + " text_embedding = AzureTextEmbedding(service_id=\"embedding\", env_file_path=env_file_path)\n", + "else:\n", + " text_embedding = OpenAITextEmbedding(service_id=\"embedding\", env_file_path=env_file_path)\n", + "collection = PostgresCollection[str, ArxivPaper](\n", + " collection_name=\"arxiv_records\",\n", + " record_type=ArxivPaper,\n", + " env_file_path=env_file_path,\n", + " embedding_generator=text_embedding,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that the models have embeddings, we can write them into the Postgres database." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async with collection:\n", + " await collection.ensure_collection_exists()\n", + " keys = await collection.upsert(arxiv_papers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we retrieve the first few models from the database and print out their information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Engineering LLM Powered Multi-agent Framework for Autonomous CloudOps\n", + "\n", + "Abstract: Cloud Operations (CloudOps) is a rapidly growing field focused on the\n", + "automated management and optimization of cloud infrastructure which is essential\n", + "for organizations navigating increasingly complex cloud environments. MontyCloud\n", + "Inc. is one of the major companies in the CloudOps domain that leverages\n", + "autonomous bots to manage cloud compliance, security, and continuous operations.\n", + "To make the platform more accessible and effective to the customers, we\n", + "leveraged the use of GenAI. Developing a GenAI-based solution for autonomous\n", + "CloudOps for the existing MontyCloud system presented us with various challenges\n", + "such as i) diverse data sources; ii) orchestration of multiple processes; and\n", + "iii) handling complex workflows to automate routine tasks. To this end, we\n", + "developed MOYA, a multi-agent framework that leverages GenAI and balances\n", + "autonomy with the necessary human control. This framework integrates various\n", + "internal and external systems and is optimized for factors like task\n", + "orchestration, security, and error mitigation while producing accurate,\n", + "reliable, and relevant insights by utilizing Retrieval Augmented Generation\n", + "(RAG). Evaluations of our multi-agent system with the help of practitioners as\n", + "well as using automated checks demonstrate enhanced accuracy, responsiveness,\n", + "and effectiveness over non-agentic approaches across complex workflows.\n", + "Published: 2025-01-14 16:30:10\n", + "Link: http://arxiv.org/abs/2501.08243v1\n", + "PDF Link: http://arxiv.org/abs/2501.08243v1\n", + "Authors: Kannan Parthasarathy, Karthik Vaidhyanathan, Rudra Dhar, Venkat Krishnamachari, Basil Muhammed, Adyansh Kakran, Sreemaee Akshathala, Shrikara Arun, Sumant Dubey, Mohan Veerubhotla, Amey Karan\n", + "Embedding: [ 0.01063822 0.02977918 0.04532182 ... -0.00264323 0.00081101\n", + " 0.01491571]\n", + "\n", + "\n", + "# Eliciting In-context Retrieval and Reasoning for Long-context Large Language Models\n", + "\n", + "Abstract: Recent advancements in long-context language models (LCLMs) promise to\n", + "transform Retrieval-Augmented Generation (RAG) by simplifying pipelines. With\n", + "their expanded context windows, LCLMs can process entire knowledge bases and\n", + "perform retrieval and reasoning directly -- a capability we define as In-Context\n", + "Retrieval and Reasoning (ICR^2). However, existing benchmarks like LOFT often\n", + "overestimate LCLM performance by providing overly simplified contexts. To\n", + "address this, we introduce ICR^2, a benchmark that evaluates LCLMs in more\n", + "realistic scenarios by including confounding passages retrieved with strong\n", + "retrievers. 
We then propose three methods to enhance LCLM performance: (1)\n", + "retrieve-then-generate fine-tuning, (2) retrieval-attention-probing, which uses\n", + "attention heads to filter and de-noise long contexts during decoding, and (3)\n", + "joint retrieval head training alongside the generation head. Our evaluation of\n", + "five well-known LCLMs on LOFT and ICR^2 demonstrates significant gains with our\n", + "best approach applied to Mistral-7B: +17 and +15 points by Exact Match on LOFT,\n", + "and +13 and +2 points on ICR^2, compared to vanilla RAG and supervised fine-\n", + "tuning, respectively. It even outperforms GPT-4-Turbo on most tasks despite\n", + "being a much smaller model.\n", + "Published: 2025-01-14 16:38:33\n", + "Link: http://arxiv.org/abs/2501.08248v1\n", + "PDF Link: http://arxiv.org/abs/2501.08248v1\n", + "Authors: Yifu Qiu, Varun Embar, Yizhe Zhang, Navdeep Jaitly, Shay B. Cohen, Benjamin Han\n", + "Embedding: [-0.01305697 0.01166064 0.06267344 ... -0.01627254 0.00974741\n", + " -0.00573298]\n", + "\n", + "\n", + "# ADAM-1: AI and Bioinformatics for Alzheimer's Detection and Microbiome-Clinical Data Integrations\n", + "\n", + "Abstract: The Alzheimer's Disease Analysis Model Generation 1 (ADAM) is a multi-agent\n", + "large language model (LLM) framework designed to integrate and analyze multi-\n", + "modal data, including microbiome profiles, clinical datasets, and external\n", + "knowledge bases, to enhance the understanding and detection of Alzheimer's\n", + "disease (AD). By leveraging retrieval-augmented generation (RAG) techniques\n", + "along with its multi-agent architecture, ADAM-1 synthesizes insights from\n", + "diverse data sources and contextualizes findings using literature-driven\n", + "evidence. Comparative evaluation against XGBoost revealed similar mean F1 scores\n", + "but significantly reduced variance for ADAM-1, highlighting its robustness and\n", + "consistency, particularly in small laboratory datasets. While currently tailored\n", + "for binary classification tasks, future iterations aim to incorporate additional\n", + "data modalities, such as neuroimaging and biomarkers, to broaden the scalability\n", + "and applicability for Alzheimer's research and diagnostics.\n", + "Published: 2025-01-14 18:56:33\n", + "Link: http://arxiv.org/abs/2501.08324v1\n", + "PDF Link: http://arxiv.org/abs/2501.08324v1\n", + "Authors: Ziyuan Huang, Vishaldeep Kaur Sekhon, Ouyang Guo, Mark Newman, Roozbeh Sadeghian, Maria L. Vaida, Cynthia Jo, Doyle Ward, Vanni Bucci, John P. Haran\n", + "Embedding: [ 0.03896349 0.00422515 0.05525447 ... 
0.03374933 -0.01468264\n", + " 0.01850895]\n", + "\n", + "\n" + ] + } + ], + "source": [ + "async with collection:\n", + " results = await collection.get(keys[:3])\n", + " if results:\n", + " for result in results:\n", + " print(f\"# {result.title}\")\n", + " print()\n", + " wrapped_abstract = textwrap.fill(result.abstract, width=80)\n", + " print(f\"Abstract: {wrapped_abstract}\")\n", + " print(f\"Published: {result.published}\")\n", + " print(f\"Link: {result.link}\")\n", + " print(f\"PDF Link: {result.link}\")\n", + " print(f\"Authors: {', '.join(result.authors)}\")\n", + " print(f\"Embedding: {result.abstract_vector}\")\n", + " print()\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `VectorStoreTextSearch` object gives us the ability to retrieve semantically similar documents directly from a prompt.\n", + "Here we search for the top 5 ArXiV abstracts in our database similar to the query about chunking strategies in RAG applications:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 5 results for query.\n", + "Advanced ingestion process powered by LLM parsing for RAG system: 0.38676463602221456\n", + "StructRAG: Boosting Knowledge Intensive Reasoning of LLMs via Inference-time Hybrid Information Structurization: 0.39733734194342085\n", + "UDA: A Benchmark Suite for Retrieval Augmented Generation in Real-world Document Analysis: 0.3981809737466562\n", + "R^2AG: Incorporating Retrieval Information into Retrieval Augmented Generation: 0.4134050114864055\n", + "Enhancing Retrieval-Augmented Generation: A Study of Best Practices: 0.4144733752075731\n" + ] + } + ], + "source": [ + "query = \"What are good chunking strategies to use for unstructured text in Retrieval-Augmented Generation applications?\"\n", + "\n", + "async with collection:\n", + " search_results = await collection.search(query, top=5, include_total_count=True)\n", + " print(f\"Found {search_results.total_count} results for query.\")\n", + " async for search_result in search_results.results:\n", + " title = search_result.record.title\n", + " score = search_result.score\n", + " print(f\"{title}: {score}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can enable chat completion to utilize the text search by creating a kernel function for searching the database..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kernel = Kernel()\n", + "plugin = kernel.add_functions(\n", + " plugin_name=\"arxiv_plugin\",\n", + " functions=[\n", + " collection.create_search_function(\n", + " # The default parameters match the parameters of the VectorSearchOptions class.\n", + " description=\"Searches for ArXiv papers that are related to the query.\",\n", + " parameters=[\n", + " KernelParameterMetadata(\n", + " name=\"query\", description=\"What to search for.\", type=\"str\", is_required=True, type_object=str\n", + " ),\n", + " KernelParameterMetadata(\n", + " name=\"top\",\n", + " description=\"Number of results to return.\",\n", + " type=\"int\",\n", + " default_value=2,\n", + " type_object=int,\n", + " ),\n", + " ],\n", + " ),\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "...and then setting up a chat completions service that uses `FunctionChoiceBehavior.Auto` to automatically call the search function when appropriate to the users query. We also create the chat function that will be invoked by the kernel." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the chat completion service. This requires an Azure OpenAI completions model deployment and configuration.\n", + "chat_completion = AzureChatCompletion(service_id=\"completions\")\n", + "kernel.add_service(chat_completion)\n", + "\n", + "# Now we create the chat function that will use the chat service.\n", + "chat_function = kernel.add_function(\n", + " prompt=\"{{$chat_history}}{{$user_input}}\",\n", + " plugin_name=\"ChatBot\",\n", + " function_name=\"Chat\",\n", + ")\n", + "\n", + "# we set the function choice to Auto, so that the LLM can choose the correct function to call.\n", + "# and we exclude the ChatBot plugin, so that it does not call itself.\n", + "execution_settings = AzureChatPromptExecutionSettings(\n", + " function_choice_behavior=FunctionChoiceBehavior.Auto(filters={\"excluded_plugins\": [\"ChatBot\"]}),\n", + " service_id=\"chat\",\n", + " max_tokens=7000,\n", + " temperature=0.7,\n", + " top_p=0.8,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we create a chat history with a system message and some initial context:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "history = ChatHistory()\n", + "system_message = \"\"\"\n", + "You are a chat bot. Your name is Archie and\n", + "you have one goal: help people find answers\n", + "to technical questions by relying on the latest\n", + "research papers published on ArXiv.\n", + "You communicate effectively in the style of a helpful librarian. \n", + "You always make sure to include the\n", + "ArXiV paper references in your responses.\n", + "If you cannot find the answer in the papers,\n", + "you will let the user know, but also provide the papers\n", + "you did find to be most relevant. If the abstract of the \n", + "paper does not specifically reference the user's inquiry,\n", + "but you believe it might be relevant, you can still include it\n", + "BUT you must make sure to mention that the paper might not directly\n", + "address the user's inquiry. 
Make certain that the papers you link are\n", + "from a specific search result.\n", + "\"\"\"\n", + "history.add_system_message(system_message)\n", + "history.add_user_message(\"Hi there, who are you?\")\n", + "history.add_assistant_message(\n", + " \"I am Archie, the ArXiV chat bot. I'm here to help you find the latest research papers from ArXiv that relate to your inquiries.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now invoke the chat function via the Kernel to get chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "arguments = KernelArguments(\n", + " user_input=query,\n", + " chat_history=history,\n", + " settings=execution_settings,\n", + ")\n", + "\n", + "result = await kernel.invoke(chat_function, arguments=arguments)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Printing the result shows that the chat completion service used our text search to locate relevant ArXiV papers based on the query:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Archie:>\n", + "What an excellent and timely question! Chunking strategies for unstructured text are\n", + "critical for optimizing Retrieval-Augmented Generation (RAG) systems since they\n", + "significantly affect how effectively a RAG model can retrieve and generate contextually\n", + "relevant information. Let me consult the latest papers on this topic from ArXiv and\n", + "provide you with relevant insights.\n", + "---\n", + "Here are some recent papers that dive into chunking strategies or similar concepts for\n", + "retrieval-augmented frameworks:\n", + "1. **\"Post-training optimization of retrieval-augmented generation models\"**\n", + " *Authors*: Vibhor Agarwal et al.\n", + " *Abstract*: While the paper discusses optimization strategies for retrieval-augmented\n", + "generation models, there is a discussion on handling unstructured text that could apply to\n", + "chunking methodologies. Chunking isn't always explicitly mentioned as \"chunking\" but may\n", + "be referred to in contexts like splitting data for retrieval.\n", + " *ArXiv link*: [arXiv:2308.10701](https://arxiv.org/abs/2308.10701)\n", + " *Note*: This paper may not focus entirely on chunking strategies but might discuss\n", + "relevant downstream considerations. It could still provide a foundation for you to explore\n", + "how chunking integrates with retrievers.\n", + "2. **\"Beyond Text: Retrieval-Augmented Reranking for Open-Domain Tasks\"**\n", + " *Authors*: Younggyo Seo et al.\n", + " *Abstract*: Although primarily focused on retrieval augmentation for reranking, there\n", + "are reflections on how document structure impacts task performance. Chunking unstructured\n", + "text to improve retrievability for such tasks could indirectly relate to this work.\n", + " *ArXiv link*: [arXiv:2310.03714](https://arxiv.org/abs/2310.03714)\n", + "3. **\"ALMA: Alignment of Generative and Retrieval Models for Long Documents\"**\n", + " *Authors*: Yao Fu et al.\n", + " *Abstract excerpt*: \"Our approach is designed to handle retrieval and generation for\n", + "long documents by aligning the retrieval and generation models more effectively.\"\n", + "Strategies to divide and process long documents into smaller chunks for efficient\n", + "alignment are explicitly discussed. 
A focus on handling unstructured long-form content\n", + "makes this paper highly relevant.\n", + " *ArXiv link*: [arXiv:2308.05467](https://arxiv.org/abs/2308.05467)\n", + "4. **\"Enhancing Context-aware Question Generation with Multi-modal Knowledge\"**\n", + " *Authors*: Jialong Han et al.\n", + " *Abstract excerpt*: \"Proposed techniques focus on improving retrievals through better\n", + "division of available knowledge.\" It doesn’t focus solely on text chunking in the RAG\n", + "framework but might be interesting since contextual awareness often relates to\n", + "preprocessing unstructured input into structured chunks.\n", + " *ArXiv link*: [arXiv:2307.12345](https://arxiv.org/abs/2307.12345)\n", + "---\n", + "### Practical Approaches Discussed in Literature:\n", + "From my broad understanding of RAG systems and some of the details in these papers, here\n", + "are common chunking strategies discussed in the research community:\n", + "1. **Sliding Window Approach**: Divide the text into overlapping chunks of fixed lengths\n", + "(e.g., 512 tokens with an overlap of 128 tokens). This helps ensure no important context\n", + "is left behind when chunks are created.\n", + "\n", + "2. **Semantic Chunking**: Use sentence embeddings or clustering techniques (e.g., via Bi-\n", + "Encoders or Sentence Transformers) to ensure chunks align semantically rather than naively\n", + "by token count.\n", + "3. **Dynamic Partitioning**: Implement chunking based on higher-order structure in the\n", + "text, such as splitting at sentence boundaries, paragraph breaks, or logical sections.\n", + "4. **Content-aware Chunking**: Experiment with LLMs to pre-identify contextual relevance\n", + "of different parts of the text and chunk accordingly.\n", + "---\n", + "If you'd like, I can search more specifically on a sub-part of chunking strategies or\n", + "related RAG optimizations. 
Let me know!\n" + ] + } + ], + "source": [ + "def wrap_text(text, width=90):\n", + " paragraphs = text.split(\"\\n\\n\") # Split the text into paragraphs\n", + " wrapped_paragraphs = [\n", + " \"\\n\".join(textwrap.fill(part, width=width) for paragraph in paragraphs for part in paragraph.split(\"\\n\"))\n", + " ] # Wrap each paragraph, split by newlines\n", + " return \"\\n\\n\".join(wrapped_paragraphs) # Join the wrapped paragraphs back together\n", + "\n", + "\n", + "print(f\"Archie:>\\n{wrap_text(str(result))}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/semantic_kernel/connectors/memory/azure_ai_search.py b/python/semantic_kernel/connectors/memory/azure_ai_search.py index de041b5c502c..702778051ca2 100644 --- a/python/semantic_kernel/connectors/memory/azure_ai_search.py +++ b/python/semantic_kernel/connectors/memory/azure_ai_search.py @@ -29,11 +29,12 @@ from pydantic import SecretStr, ValidationError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import FieldTypes, VectorStoreCollectionDefinition -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, @@ -432,12 +433,11 @@ async def _inner_get( if options is not None: ordering = [] if options.order_by: - order_by = options.order_by if isinstance(options.order_by, Sequence) else [options.order_by] - for order in order_by: - if order.field not in self.definition.storage_names: - logger.warning(f"Field {order.field} not in data model, skipping.") + for field, asc_flag in options.order_by.items(): + if field not in self.definition.storage_names: + logger.warning(f"Field {field} not in data model, skipping.") continue - ordering.append(order.field if order.ascending else f"{order.field} desc") + ordering.append(field if asc_flag else f"{field} desc") result = await client.search( search_text="*", diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db.py index 6ce200dc3b4f..f90ff5989170 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db.py @@ -23,11 +23,12 @@ MongoDBAtlasCollection, MongoDBAtlasStore, ) -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import FieldTypes, VectorStoreCollectionDefinition -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition +from semantic_kernel.data._search import 
KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, diff --git a/python/semantic_kernel/connectors/memory/chroma.py b/python/semantic_kernel/connectors/memory/chroma.py index cbe86665a87b..83e664f1e2cc 100644 --- a/python/semantic_kernel/connectors/memory/chroma.py +++ b/python/semantic_kernel/connectors/memory/chroma.py @@ -13,11 +13,12 @@ from chromadb.config import Settings from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, diff --git a/python/semantic_kernel/connectors/memory/faiss.py b/python/semantic_kernel/connectors/memory/faiss.py index d9aaa7038efa..c665e3f3a5d3 100644 --- a/python/semantic_kernel/connectors/memory/faiss.py +++ b/python/semantic_kernel/connectors/memory/faiss.py @@ -10,10 +10,16 @@ from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase from semantic_kernel.connectors.memory.in_memory import IN_MEMORY_SCORE_KEY, InMemoryCollection, InMemoryStore, TKey -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import SearchType, TModel, VectorSearchOptions, VectorSearchResult +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, + IndexKind, + SearchType, + TModel, + VectorSearchOptions, + VectorSearchResult, +) from semantic_kernel.exceptions import VectorStoreInitializationException, VectorStoreOperationException from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelException diff --git a/python/semantic_kernel/connectors/memory/in_memory.py b/python/semantic_kernel/connectors/memory/in_memory.py index 636f0526ed53..d5ce2ca41458 100644 --- a/python/semantic_kernel/connectors/memory/in_memory.py +++ b/python/semantic_kernel/connectors/memory/in_memory.py @@ -11,10 +11,11 @@ from typing_extensions import override from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DISTANCE_FUNCTION_DIRECTION_HELPER, + DistanceFunction, GetFilteredRecordOptions, SearchType, TModel, diff --git a/python/semantic_kernel/connectors/memory/mongodb.py b/python/semantic_kernel/connectors/memory/mongodb.py index 
102e96008282..1d2ed3a16e34 100644 --- a/python/semantic_kernel/connectors/memory/mongodb.py +++ b/python/semantic_kernel/connectors/memory/mongodb.py @@ -15,10 +15,10 @@ from pymongo.operations import SearchIndexModel from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, SearchType, TModel, diff --git a/python/semantic_kernel/connectors/memory/pinecone.py b/python/semantic_kernel/connectors/memory/pinecone.py index d0387b65bf02..edeea0922fb8 100644 --- a/python/semantic_kernel/connectors/memory/pinecone.py +++ b/python/semantic_kernel/connectors/memory/pinecone.py @@ -13,10 +13,10 @@ from pydantic import SecretStr, ValidationError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, SearchType, TModel, diff --git a/python/semantic_kernel/connectors/memory/postgres.py b/python/semantic_kernel/connectors/memory/postgres.py index 2e2edb9187be..9156c38bf8da 100644 --- a/python/semantic_kernel/connectors/memory/postgres.py +++ b/python/semantic_kernel/connectors/memory/postgres.py @@ -17,11 +17,12 @@ from pydantic_settings import SettingsConfigDict from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import FieldTypes, VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, diff --git a/python/semantic_kernel/connectors/memory/qdrant.py b/python/semantic_kernel/connectors/memory/qdrant.py index bd49427331de..cfe81f59d7fd 100644 --- a/python/semantic_kernel/connectors/memory/qdrant.py +++ b/python/semantic_kernel/connectors/memory/qdrant.py @@ -28,11 +28,12 @@ from typing_extensions import override from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import 
VectorStoreCollectionDefinition +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, diff --git a/python/semantic_kernel/connectors/memory/redis.py b/python/semantic_kernel/connectors/memory/redis.py index 61a7472529ff..9c0de7c7fa35 100644 --- a/python/semantic_kernel/connectors/memory/redis.py +++ b/python/semantic_kernel/connectors/memory/redis.py @@ -24,11 +24,12 @@ from redisvl.schema import StorageType from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import FieldTypes, VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, diff --git a/python/semantic_kernel/connectors/memory/sql_server.py b/python/semantic_kernel/connectors/memory/sql_server.py index 79de3cd3d6a4..5b1ccf796934 100644 --- a/python/semantic_kernel/connectors/memory/sql_server.py +++ b/python/semantic_kernel/connectors/memory/sql_server.py @@ -17,11 +17,13 @@ from pydantic import SecretStr, ValidationError, field_validator from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DISTANCE_FUNCTION_DIRECTION_HELPER, + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, VectorSearch, VectorSearchOptions, diff --git a/python/semantic_kernel/connectors/memory/weaviate.py b/python/semantic_kernel/connectors/memory/weaviate.py index e86bdfa75437..431c7a318d19 100644 --- a/python/semantic_kernel/connectors/memory/weaviate.py +++ b/python/semantic_kernel/connectors/memory/weaviate.py @@ -20,11 +20,12 @@ from weaviate.exceptions import WeaviateClosedClientError, WeaviateConnectionError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.search import KernelSearchResults -from semantic_kernel.data.vectors import ( +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._search import KernelSearchResults +from semantic_kernel.data._vectors import ( + DistanceFunction, GetFilteredRecordOptions, + IndexKind, SearchType, TModel, VectorSearch, diff --git a/python/semantic_kernel/connectors/search/brave.py b/python/semantic_kernel/connectors/search/brave.py index bf8bb2b8d157..bf0858a369c4 100644 --- 
a/python/semantic_kernel/connectors/search/brave.py +++ b/python/semantic_kernel/connectors/search/brave.py @@ -11,13 +11,7 @@ from pydantic import Field, SecretStr, ValidationError from semantic_kernel.connectors.search.utils import SearchLambdaVisitor -from semantic_kernel.data.search import ( - KernelSearchResults, - SearchOptions, - TextSearch, - TextSearchResult, - TSearchResult, -) +from semantic_kernel.data._search import KernelSearchResults, SearchOptions, TextSearch, TextSearchResult, TSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError from semantic_kernel.kernel_pydantic import KernelBaseModel, KernelBaseSettings from semantic_kernel.kernel_types import OptionalOneOrList diff --git a/python/semantic_kernel/connectors/search/google.py b/python/semantic_kernel/connectors/search/google.py index 57c42085fd70..736972798aa0 100644 --- a/python/semantic_kernel/connectors/search/google.py +++ b/python/semantic_kernel/connectors/search/google.py @@ -12,13 +12,7 @@ from pydantic import Field, SecretStr, ValidationError from semantic_kernel.connectors.search.utils import SearchLambdaVisitor -from semantic_kernel.data.search import ( - KernelSearchResults, - SearchOptions, - TextSearch, - TextSearchResult, - TSearchResult, -) +from semantic_kernel.data._search import KernelSearchResults, SearchOptions, TextSearch, TextSearchResult, TSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError from semantic_kernel.kernel_pydantic import KernelBaseModel, KernelBaseSettings from semantic_kernel.kernel_types import OptionalOneOrList diff --git a/python/semantic_kernel/data/__init__.py b/python/semantic_kernel/data/__init__.py index f7c5854f66e0..5797e85c30ff 100644 --- a/python/semantic_kernel/data/__init__.py +++ b/python/semantic_kernel/data/__init__.py @@ -1,20 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. 
-from semantic_kernel.data.const import ( - DEFAULT_DESCRIPTION, - DEFAULT_FUNCTION_NAME, - DISTANCE_FUNCTION_DIRECTION_HELPER, - DistanceFunction, - IndexKind, -) -from semantic_kernel.data.definitions import ( +from semantic_kernel.data._definitions import ( FieldTypes, VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel, ) -from semantic_kernel.data.search import ( +from semantic_kernel.data._search import ( + DEFAULT_DESCRIPTION, + DEFAULT_FUNCTION_NAME, DynamicFilterFunction, KernelSearchResults, TextSearch, @@ -22,7 +17,15 @@ create_options, default_dynamic_filter_function, ) -from semantic_kernel.data.vectors import VectorSearch, VectorSearchResult, VectorStore, VectorStoreRecordCollection +from semantic_kernel.data._vectors import ( + DISTANCE_FUNCTION_DIRECTION_HELPER, + DistanceFunction, + IndexKind, + VectorSearch, + VectorSearchResult, + VectorStore, + VectorStoreRecordCollection, +) __all__ = [ "DEFAULT_DESCRIPTION", diff --git a/python/semantic_kernel/data/definitions.py b/python/semantic_kernel/data/_definitions.py similarity index 98% rename from python/semantic_kernel/data/definitions.py rename to python/semantic_kernel/data/_definitions.py index 4d12401684a7..c8c03043f094 100644 --- a/python/semantic_kernel/data/definitions.py +++ b/python/semantic_kernel/data/_definitions.py @@ -11,7 +11,7 @@ from pydantic import Field, ValidationError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data.const import DistanceFunction, IndexKind +from semantic_kernel.data._vectors import DistanceFunction, IndexKind from semantic_kernel.exceptions import VectorStoreModelException from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.utils.feature_stage_decorator import release_candidate @@ -164,12 +164,13 @@ def __init__( self.field_type = field_type if isinstance(field_type, FieldTypes) else FieldTypes(field_type) # when a field is created, the name can be empty, # when a field get's added to a definition, the name needs to be there. - self.name = name + if name: + self.name = name self.storage_name = storage_name self.type_ = type self.is_indexed = is_indexed self.is_full_text_indexed = is_full_text_indexed - if field_type == "vector": + if field_type == FieldTypes.VECTOR: if dimensions is None: raise ValidationError("Vector fields must specify 'dimensions'") self.dimensions = dimensions @@ -193,7 +194,7 @@ class ToDictFunctionProtocol(Protocol): A list of dictionaries. """ - def __call__(self, record: Any, **kwargs: Any) -> Sequence[dict[str, Any]]: ... # pragma: no cover # noqa: D102 + def __call__(self, record: Any, **kwargs: Any) -> Sequence[dict[str, Any]]: ... # pragma: no cover @runtime_checkable @@ -208,7 +209,7 @@ class FromDictFunctionProtocol(Protocol): A record or list thereof. """ - def __call__(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Any: ... # noqa: D102 + def __call__(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Any: ... @runtime_checkable @@ -224,7 +225,7 @@ class SerializeFunctionProtocol(Protocol): """ - def __call__(self, record: Any, **kwargs: Any) -> Any: ... # noqa: D102 + def __call__(self, record: Any, **kwargs: Any) -> Any: ... @runtime_checkable @@ -240,7 +241,7 @@ class DeserializeFunctionProtocol(Protocol): """ - def __call__(self, records: Any, **kwargs: Any) -> Any: ... # noqa: D102 + def __call__(self, records: Any, **kwargs: Any) -> Any: ... 
@runtime_checkable diff --git a/python/semantic_kernel/data/search.py b/python/semantic_kernel/data/_search.py similarity index 98% rename from python/semantic_kernel/data/search.py rename to python/semantic_kernel/data/_search.py index d6864a774e61..58ad59444da9 100644 --- a/python/semantic_kernel/data/search.py +++ b/python/semantic_kernel/data/_search.py @@ -5,11 +5,10 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Callable, Mapping, Sequence from copy import deepcopy -from typing import Annotated, Any, Generic, Literal, Protocol, TypeVar, overload +from typing import Annotated, Any, Final, Generic, Literal, Protocol, TypeVar, overload from pydantic import BaseModel, ConfigDict, Field, ValidationError -from semantic_kernel.data.const import DEFAULT_DESCRIPTION, DEFAULT_FUNCTION_NAME from semantic_kernel.exceptions import TextSearchException from semantic_kernel.functions.kernel_function import KernelFunction from semantic_kernel.functions.kernel_function_decorator import kernel_function @@ -19,9 +18,14 @@ from semantic_kernel.kernel_types import OptionalOneOrList from semantic_kernel.utils.feature_stage_decorator import release_candidate +logger = logging.getLogger(__name__) + TSearchOptions = TypeVar("TSearchOptions", bound="SearchOptions") -logger = logging.getLogger(__name__) +DEFAULT_FUNCTION_NAME: Final[str] = "search" +DEFAULT_DESCRIPTION: Final[str] = ( + "Perform a search for content related to the specified query and return string results" +) # region: Options diff --git a/python/semantic_kernel/data/vectors.py b/python/semantic_kernel/data/_vectors.py similarity index 90% rename from python/semantic_kernel/data/vectors.py rename to python/semantic_kernel/data/_vectors.py index 87c5fc84e2f5..eae2be7325ec 100644 --- a/python/semantic_kernel/data/vectors.py +++ b/python/semantic_kernel/data/_vectors.py @@ -2,6 +2,7 @@ import json import logging +import operator import sys from abc import abstractmethod from ast import AST, Lambda, NodeVisitor, expr, parse @@ -9,21 +10,21 @@ from copy import deepcopy from enum import Enum from inspect import getsource -from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload +from typing import Annotated, Any, ClassVar, Final, Generic, Literal, TypeVar, overload from pydantic import BaseModel, Field, ValidationError, model_validator from pydantic.dataclasses import dataclass from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings -from semantic_kernel.data.const import DEFAULT_DESCRIPTION, DEFAULT_FUNCTION_NAME -from semantic_kernel.data.definitions import ( +from semantic_kernel.data._definitions import ( FieldTypes, SerializeMethodProtocol, VectorStoreCollectionDefinition, VectorStoreField, ) -from semantic_kernel.data.search import ( +from semantic_kernel.data._search import ( + DEFAULT_FUNCTION_NAME, DynamicFilterFunction, KernelSearchResults, SearchOptions, @@ -66,6 +67,9 @@ TFilters = TypeVar("TFilters") # region: Helpers +DEFAULT_DESCRIPTION: Final[str] = ( + "Perform a vector search for data in a vector store, using the provided search options." 
+)
 
 
 def _get_collection_name_from_model(
@@ -80,21 +84,19 @@ def _get_collection_name_from_model(
     return None
 
 
-@dataclass
-class OrderBy:
-    """Order by class."""
-
-    field: str
-    ascending: bool = Field(default=True)
-
-
 @dataclass
 class GetFilteredRecordOptions:
-    """Options for filtering records."""
+    """Options for filtering records.
+
+    Args:
+        top: The maximum number of records to return.
+        skip: The number of records to skip.
+        order_by: A dictionary mapping field names to a bool: True means ascending, False means descending.
+    """
 
     top: int = 10
     skip: int = 0
-    order_by: OptionalOneOrMany[OrderBy] = None
+    order_by: Mapping[str, bool] | None = None
 
 
 class LambdaVisitor(NodeVisitor, Generic[TFilters]):
@@ -734,7 +736,7 @@ async def get(
         self,
         top: int = ...,
         skip: int = ...,
-        order_by: OptionalOneOrMany[OrderBy | dict[str, Any] | list[dict[str, Any]]] = None,
+        order_by: OneOrMany[str] | dict[str, bool] | None = None,
         include_vectors: bool = False,
         **kwargs: Any,
     ) -> Sequence[TModel] | None:
@@ -749,10 +751,11 @@ async def get(
                 Only used if keys are not provided.
             skip: The number of records to skip.
                 Only used if keys are not provided.
-            order_by: The order by clause, this is a list of dicts with the field name and ascending flag,
-                (default is True, which means ascending).
-                Only used if keys are not provided.
-                example: {"field": "hotel_id", "ascending": True}
+            order_by: The order by clause; this can be a string, a list of strings, or a dict.
+                Plain strings are assumed to be ascending;
+                with a dict, the value sets the direction: ascending (True) or descending (False).
+                Only used if keys are not provided.
+                example: {"field_name": True} or ["field_name", {"field_name2": False}].
             **kwargs: Additional arguments.
 
         Returns:
@@ -858,8 +861,28 @@ async def get(
             keys = key
         if not keys:
             if kwargs:
+                kw_order_by: OneOrList[str] | dict[str, bool] | None = kwargs.pop("order_by", None)  # type: ignore
+                top = kwargs.pop("top", 10)
+                skip = kwargs.pop("skip", 0)
+                order_by: dict[str, bool] | None = None
+                if kw_order_by is not None:
+                    order_by = {}
+                    if isinstance(kw_order_by, str):
+                        order_by[kw_order_by] = True
+                    elif isinstance(kw_order_by, dict):
+                        order_by = kw_order_by
+                    elif isinstance(kw_order_by, list):
+                        for item in kw_order_by:
+                            if isinstance(item, str):
+                                order_by[item] = True
+                            else:
+                                order_by.update(item)
+                    else:
+                        raise VectorStoreOperationException(
+                            f"Invalid order_by type: {type(kw_order_by)}, expected str, dict or list."
+                        )
                 try:
-                    options = GetFilteredRecordOptions(**kwargs)
+                    options = GetFilteredRecordOptions(top=top, skip=skip, order_by=order_by)
                 except Exception as exc:
                     raise VectorStoreOperationException(f"Error creating options: {exc}") from exc
             else:
@@ -1513,3 +1536,109 @@ async def search_wrapper(**kwargs: Any) -> Sequence[str]:
         parameters=TextSearch._default_parameter_metadata() if parameters is None else parameters,
         return_parameter=return_parameter or TextSearch._default_return_parameter_metadata(),
     )
+
+
+class IndexKind(str, Enum):
+    """Index kinds for similarity search.
+
+    HNSW
+        Hierarchical Navigable Small World which performs an approximate nearest neighbor (ANN) search.
+        Lower accuracy than exhaustive k nearest neighbor, but faster and more efficient.
+
+    Flat
+        Does a brute force search to find the nearest neighbors.
+        Calculates the distances between all pairs of data points, so it has a linear time complexity
+        that grows directly proportional to the number of points.
+        Also referred to as exhaustive k nearest neighbor in some databases. 
+ High recall accuracy, but slower and more expensive than HNSW. + Better with smaller datasets. + + IVF Flat + Inverted File with Flat Compression. + Designed to enhance search efficiency by narrowing the search area + through the use of neighbor partitions or clusters. + Also referred to as approximate nearest neighbor (ANN) search. + + Disk ANN + Disk-based Approximate Nearest Neighbor algorithm designed for efficiently searching + for approximate nearest neighbors (ANN) in high-dimensional spaces. + The primary focus of DiskANN is to handle large-scale datasets that cannot fit entirely + into memory, leveraging disk storage to store the data while maintaining fast search times. + + Quantized Flat + Index that compresses vectors using DiskANN-based quantization methods for better efficiency in the kNN search. + + Dynamic + Dynamic index allows to automatically switch from FLAT to HNSW indexes. + + Default + Default index type. + Used when no index type is specified. + Will differ per vector store. + + """ + + HNSW = "hnsw" + FLAT = "flat" + IVF_FLAT = "ivf_flat" + DISK_ANN = "disk_ann" + QUANTIZED_FLAT = "quantized_flat" + DYNAMIC = "dynamic" + DEFAULT = "default" + + +class DistanceFunction(str, Enum): + """Distance functions for similarity search. + + Cosine Similarity + the cosine (angular) similarity between two vectors + measures only the angle between the two vectors, without taking into account the length of the vectors + Cosine Similarity = 1 - Cosine Distance + -1 means vectors are opposite + 0 means vectors are orthogonal + 1 means vectors are identical + Cosine Distance + the cosine (angular) distance between two vectors + measures only the angle between the two vectors, without taking into account the length of the vectors + Cosine Distance = 1 - Cosine Similarity + 2 means vectors are opposite + 1 means vectors are orthogonal + 0 means vectors are identical + Dot Product + measures both the length and angle between two vectors + same as cosine similarity if the vectors are the same length, but more performant + Euclidean Distance + measures the Euclidean distance between two vectors + also known as l2-norm + Euclidean Squared Distance + measures the Euclidean squared distance between two vectors + also known as l2-squared + Manhattan + measures the Manhattan distance between two vectors + Hamming + number of differences between vectors at each dimensions + DEFAULT + default distance function + used when no distance function is specified + will differ per vector store. + """ + + COSINE_SIMILARITY = "cosine_similarity" + COSINE_DISTANCE = "cosine_distance" + DOT_PROD = "dot_prod" + EUCLIDEAN_DISTANCE = "euclidean_distance" + EUCLIDEAN_SQUARED_DISTANCE = "euclidean_squared_distance" + MANHATTAN = "manhattan" + HAMMING = "hamming" + DEFAULT = "DEFAULT" + + +DISTANCE_FUNCTION_DIRECTION_HELPER: Final[dict[DistanceFunction, Callable[[int | float, int | float], bool]]] = { + DistanceFunction.COSINE_SIMILARITY: operator.gt, + DistanceFunction.COSINE_DISTANCE: operator.le, + DistanceFunction.DOT_PROD: operator.gt, + DistanceFunction.EUCLIDEAN_DISTANCE: operator.le, + DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE: operator.le, + DistanceFunction.MANHATTAN: operator.le, + DistanceFunction.HAMMING: operator.le, +} diff --git a/python/semantic_kernel/data/const.py b/python/semantic_kernel/data/const.py deleted file mode 100644 index 096314ab1e0c..000000000000 --- a/python/semantic_kernel/data/const.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -import operator -from collections.abc import Callable -from enum import Enum -from typing import Final - - -class IndexKind(str, Enum): - """Index kinds for similarity search. - - HNSW - Hierarchical Navigable Small World which performs an approximate nearest neighbor (ANN) search. - Lower accuracy than exhaustive k nearest neighbor, but faster and more efficient. - - Flat - Does a brute force search to find the nearest neighbors. - Calculates the distances between all pairs of data points, so has a linear time complexity, - that grows directly proportional to the number of points. - Also referred to as exhaustive k nearest neighbor in some databases. - High recall accuracy, but slower and more expensive than HNSW. - Better with smaller datasets. - - IVF Flat - Inverted File with Flat Compression. - Designed to enhance search efficiency by narrowing the search area - through the use of neighbor partitions or clusters. - Also referred to as approximate nearest neighbor (ANN) search. - - Disk ANN - Disk-based Approximate Nearest Neighbor algorithm designed for efficiently searching - for approximate nearest neighbors (ANN) in high-dimensional spaces. - The primary focus of DiskANN is to handle large-scale datasets that cannot fit entirely - into memory, leveraging disk storage to store the data while maintaining fast search times. - - Quantized Flat - Index that compresses vectors using DiskANN-based quantization methods for better efficiency in the kNN search. - - Dynamic - Dynamic index allows to automatically switch from FLAT to HNSW indexes. - - Default - Default index type. - Used when no index type is specified. - Will differ per vector store. - - """ - - HNSW = "hnsw" - FLAT = "flat" - IVF_FLAT = "ivf_flat" - DISK_ANN = "disk_ann" - QUANTIZED_FLAT = "quantized_flat" - DYNAMIC = "dynamic" - DEFAULT = "default" - - -class DistanceFunction(str, Enum): - """Distance functions for similarity search. - - Cosine Similarity - the cosine (angular) similarity between two vectors - measures only the angle between the two vectors, without taking into account the length of the vectors - Cosine Similarity = 1 - Cosine Distance - -1 means vectors are opposite - 0 means vectors are orthogonal - 1 means vectors are identical - Cosine Distance - the cosine (angular) distance between two vectors - measures only the angle between the two vectors, without taking into account the length of the vectors - Cosine Distance = 1 - Cosine Similarity - 2 means vectors are opposite - 1 means vectors are orthogonal - 0 means vectors are identical - Dot Product - measures both the length and angle between two vectors - same as cosine similarity if the vectors are the same length, but more performant - Euclidean Distance - measures the Euclidean distance between two vectors - also known as l2-norm - Euclidean Squared Distance - measures the Euclidean squared distance between two vectors - also known as l2-squared - Manhattan - measures the Manhattan distance between two vectors - Hamming - number of differences between vectors at each dimensions - DEFAULT - default distance function - used when no distance function is specified - will differ per vector store. 
- """ - - COSINE_SIMILARITY = "cosine_similarity" - COSINE_DISTANCE = "cosine_distance" - DOT_PROD = "dot_prod" - EUCLIDEAN_DISTANCE = "euclidean_distance" - EUCLIDEAN_SQUARED_DISTANCE = "euclidean_squared_distance" - MANHATTAN = "manhattan" - HAMMING = "hamming" - DEFAULT = "DEFAULT" - - -DISTANCE_FUNCTION_DIRECTION_HELPER: Final[dict[DistanceFunction, Callable[[int | float, int | float], bool]]] = { - DistanceFunction.COSINE_SIMILARITY: operator.gt, - DistanceFunction.COSINE_DISTANCE: operator.le, - DistanceFunction.DOT_PROD: operator.gt, - DistanceFunction.EUCLIDEAN_DISTANCE: operator.le, - DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE: operator.le, - DistanceFunction.MANHATTAN: operator.le, - DistanceFunction.HAMMING: operator.le, -} -DEFAULT_FUNCTION_NAME: Final[str] = "search" -DEFAULT_DESCRIPTION: Final[str] = ( - "Perform a search for content related to the specified query and return string results" -) - - -class TextSearchFunctions(str, Enum): - """Text search functions. - - Attributes: - SEARCH: Search using a query. - GET_TEXT_SEARCH_RESULT: Get text search results. - GET_SEARCH_RESULT: Get search results. - """ - - SEARCH = "search" - GET_TEXT_SEARCH_RESULT = "get_text_search_result" - GET_SEARCH_RESULT = "get_search_result" diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 30e20e578c2a..d7981f5f01b9 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -12,7 +12,7 @@ from pytest import fixture from semantic_kernel.agents import Agent, DeclarativeSpecMixin, register_agent_type -from semantic_kernel.data.definitions import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel +from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel if TYPE_CHECKING: from semantic_kernel import Kernel diff --git a/python/tests/integration/memory/azure_cosmos_db/conftest.py b/python/tests/integration/memory/azure_cosmos_db/conftest.py index 9276453fdcc1..c8d70bddac14 100644 --- a/python/tests/integration/memory/azure_cosmos_db/conftest.py +++ b/python/tests/integration/memory/azure_cosmos_db/conftest.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pytest import fixture -from semantic_kernel.data.definitions import VectorStoreField, vectorstoremodel +from semantic_kernel.data._definitions import VectorStoreField, vectorstoremodel @fixture diff --git a/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py b/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py index 7e63ed321d5c..dae2bf9a9569 100644 --- a/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py +++ b/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py @@ -10,7 +10,7 @@ from azure.cosmos.partition_key import PartitionKey from semantic_kernel.connectors.memory.azure_cosmos_db import AzureCosmosDBNoSQLCompositeKey, CosmosNoSqlStore -from semantic_kernel.data.vectors import VectorStore +from semantic_kernel.data._vectors import VectorStore from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException from tests.integration.memory.vector_store_test_base import VectorStoreTestBase diff --git a/python/tests/integration/memory/postgres/test_postgres_int.py b/python/tests/integration/memory/postgres/test_postgres_int.py index 97ef8971bcab..4784b88fd3f3 100644 --- a/python/tests/integration/memory/postgres/test_postgres_int.py +++ b/python/tests/integration/memory/postgres/test_postgres_int.py @@ -12,9 
+12,8 @@ from semantic_kernel.connectors.memory.postgres import PostgresCollection, PostgresSettings, PostgresStore from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import vectorstoremodel -from semantic_kernel.data.vectors import VectorSearchOptions +from semantic_kernel.data._definitions import vectorstoremodel +from semantic_kernel.data._vectors import DistanceFunction, IndexKind, VectorSearchOptions from semantic_kernel.exceptions.memory_connector_exceptions import ( MemoryConnectorConnectionException, MemoryConnectorInitializationError, diff --git a/python/tests/unit/connectors/memory/test_faiss.py b/python/tests/unit/connectors/memory/test_faiss.py index 2db5919e6cd6..5b2f91c6673b 100644 --- a/python/tests/unit/connectors/memory/test_faiss.py +++ b/python/tests/unit/connectors/memory/test_faiss.py @@ -4,7 +4,8 @@ from pytest import fixture, mark, raises from semantic_kernel.connectors.memory.faiss import FaissCollection, FaissStore -from semantic_kernel.data import DistanceFunction, VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data._vectors import DistanceFunction from semantic_kernel.exceptions import VectorStoreInitializationException diff --git a/python/tests/unit/connectors/memory/test_in_memory.py b/python/tests/unit/connectors/memory/test_in_memory.py index e49ada897269..0817bbb24800 100644 --- a/python/tests/unit/connectors/memory/test_in_memory.py +++ b/python/tests/unit/connectors/memory/test_in_memory.py @@ -3,7 +3,7 @@ from pytest import fixture, mark, raises from semantic_kernel.connectors.memory.in_memory import InMemoryCollection, InMemoryStore -from semantic_kernel.data.const import DistanceFunction +from semantic_kernel.data._vectors import DistanceFunction from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreOperationException diff --git a/python/tests/unit/connectors/memory/test_postgres_store.py b/python/tests/unit/connectors/memory/test_postgres_store.py index d7e129c74b0f..f9d95c4524bd 100644 --- a/python/tests/unit/connectors/memory/test_postgres_store.py +++ b/python/tests/unit/connectors/memory/test_postgres_store.py @@ -17,8 +17,8 @@ PostgresSettings, PostgresStore, ) -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreField, vectorstoremodel +from semantic_kernel.data._definitions import VectorStoreField, vectorstoremodel +from semantic_kernel.data._vectors import DistanceFunction, IndexKind @fixture(scope="function") diff --git a/python/tests/unit/connectors/memory/test_qdrant.py b/python/tests/unit/connectors/memory/test_qdrant.py index db278c8756fb..8eac43b522d4 100644 --- a/python/tests/unit/connectors/memory/test_qdrant.py +++ b/python/tests/unit/connectors/memory/test_qdrant.py @@ -7,8 +7,8 @@ from qdrant_client.models import Datatype, Distance, FieldCondition, MatchValue, VectorParams from semantic_kernel.connectors.memory.qdrant import QdrantCollection, QdrantStore -from semantic_kernel.data.const import DistanceFunction -from semantic_kernel.data.definitions import VectorStoreField +from semantic_kernel.data._definitions import VectorStoreField +from semantic_kernel.data._vectors import DistanceFunction from semantic_kernel.exceptions import ( VectorSearchExecutionException, VectorStoreInitializationException, 
diff --git a/python/tests/unit/connectors/memory/test_sql_server.py b/python/tests/unit/connectors/memory/test_sql_server.py index b287064c4e54..18508b4dacb4 100644 --- a/python/tests/unit/connectors/memory/test_sql_server.py +++ b/python/tests/unit/connectors/memory/test_sql_server.py @@ -21,9 +21,8 @@ _build_select_query, _build_select_table_names_query, ) -from semantic_kernel.data.const import DistanceFunction, IndexKind -from semantic_kernel.data.definitions import VectorStoreField -from semantic_kernel.data.vectors import VectorSearchOptions +from semantic_kernel.data._definitions import VectorStoreField +from semantic_kernel.data._vectors import DistanceFunction, IndexKind, VectorSearchOptions from semantic_kernel.exceptions.vector_store_exceptions import ( VectorStoreInitializationException, VectorStoreOperationException, diff --git a/python/tests/unit/connectors/search/test_brave_search.py b/python/tests/unit/connectors/search/test_brave_search.py index 5f53d7dc74ce..542616e18ad1 100644 --- a/python/tests/unit/connectors/search/test_brave_search.py +++ b/python/tests/unit/connectors/search/test_brave_search.py @@ -6,7 +6,7 @@ import pytest from semantic_kernel.connectors.search.brave import BraveSearch, BraveSearchResponse, BraveWebPage, BraveWebPages -from semantic_kernel.data.search import KernelSearchResults, SearchOptions, TextSearchResult +from semantic_kernel.data._search import KernelSearchResults, SearchOptions, TextSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError diff --git a/python/tests/unit/connectors/search/test_google_search.py b/python/tests/unit/connectors/search/test_google_search.py index a1de242f12f8..2a92bd70e1ab 100644 --- a/python/tests/unit/connectors/search/test_google_search.py +++ b/python/tests/unit/connectors/search/test_google_search.py @@ -11,7 +11,7 @@ GoogleSearchResponse, GoogleSearchResult, ) -from semantic_kernel.data.search import TextSearchResult +from semantic_kernel.data._search import TextSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError diff --git a/python/tests/unit/data/conftest.py b/python/tests/unit/data/conftest.py index 3533483f4610..654fb93838b4 100644 --- a/python/tests/unit/data/conftest.py +++ b/python/tests/unit/data/conftest.py @@ -16,8 +16,8 @@ VectorStoreRecordCollection, vectorstoremodel, ) -from semantic_kernel.data.definitions import VectorStoreField -from semantic_kernel.data.vectors import VectorSearch, VectorSearchResult +from semantic_kernel.data._definitions import VectorStoreField +from semantic_kernel.data._vectors import VectorSearch, VectorSearchResult from semantic_kernel.kernel_types import OptionalOneOrMany diff --git a/python/tests/unit/data/test_filter.py b/python/tests/unit/data/test_filter.py index 825390d304d6..fd93c8d92d1e 100644 --- a/python/tests/unit/data/test_filter.py +++ b/python/tests/unit/data/test_filter.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
-from semantic_kernel.data.vectors import VectorSearchOptions +from semantic_kernel.data._vectors import VectorSearchOptions def test_lambda_filter(): diff --git a/python/tests/unit/data/test_text_search.py b/python/tests/unit/data/test_text_search.py index 9900a67e88de..ad7d6477f7fe 100644 --- a/python/tests/unit/data/test_text_search.py +++ b/python/tests/unit/data/test_text_search.py @@ -9,9 +9,15 @@ from semantic_kernel import Kernel from semantic_kernel.data import TextSearch -from semantic_kernel.data.const import DEFAULT_DESCRIPTION, DEFAULT_FUNCTION_NAME -from semantic_kernel.data.search import KernelSearchResults, SearchOptions, TextSearchResult, create_options -from semantic_kernel.data.vectors import VectorSearchOptions +from semantic_kernel.data._search import ( + DEFAULT_DESCRIPTION, + DEFAULT_FUNCTION_NAME, + KernelSearchResults, + SearchOptions, + TextSearchResult, + create_options, +) +from semantic_kernel.data._vectors import VectorSearchOptions from semantic_kernel.exceptions import TextSearchException from semantic_kernel.functions import KernelArguments, KernelParameterMetadata from semantic_kernel.utils.list_handler import desync_list diff --git a/python/tests/unit/data/test_vector_search_base.py b/python/tests/unit/data/test_vector_search_base.py index 66256bd77460..e43b2a1d17f3 100644 --- a/python/tests/unit/data/test_vector_search_base.py +++ b/python/tests/unit/data/test_vector_search_base.py @@ -3,7 +3,7 @@ import pytest -from semantic_kernel.data.vectors import VectorSearch, VectorSearchOptions +from semantic_kernel.data._vectors import VectorSearch, VectorSearchOptions async def test_search(vector_store_record_collection: VectorSearch): diff --git a/python/tests/unit/data/test_vector_store_model_decorator.py b/python/tests/unit/data/test_vector_store_model_decorator.py index 3e6cbdccf3e7..eafb71517307 100644 --- a/python/tests/unit/data/test_vector_store_model_decorator.py +++ b/python/tests/unit/data/test_vector_store_model_decorator.py @@ -10,7 +10,7 @@ from pytest import raises from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data.definitions import vectorstoremodel +from semantic_kernel.data._definitions import vectorstoremodel from semantic_kernel.exceptions import VectorStoreModelException diff --git a/python/tests/unit/data/test_vector_store_record_collection.py b/python/tests/unit/data/test_vector_store_record_collection.py index 12bc3868faeb..fb85118491ea 100644 --- a/python/tests/unit/data/test_vector_store_record_collection.py +++ b/python/tests/unit/data/test_vector_store_record_collection.py @@ -6,7 +6,7 @@ from pandas import DataFrame from pytest import mark, raises -from semantic_kernel.data.definitions import SerializeMethodProtocol, ToDictMethodProtocol +from semantic_kernel.data._definitions import SerializeMethodProtocol, ToDictMethodProtocol from semantic_kernel.exceptions import ( VectorStoreModelDeserializationException, VectorStoreModelSerializationException, From b7ba36293a17cfb9b17ebe44629fa0dbf604ddfd Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 27 May 2025 10:28:12 +0200 Subject: [PATCH 2/2] improved structure --- .../concepts/caching/semantic_caching.py | 6 +- .../store_chat_history_in_cosmosdb.py | 4 +- .../data_model.py | 2 +- .../samples/concepts/memory/complex_memory.py | 16 +- python/samples/concepts/memory/data_models.py | 2 +- .../concepts/memory/memory_with_pandas.py | 2 +- .../samples/concepts/memory/simple_memory.py | 2 +- 
python/samples/concepts/memory/utils.py | 2 +- .../azure_chat_gpt_with_data_api.py | 2 +- .../rag/rag_with_vector_collection.py | 2 +- .../samples/concepts/rag/self_critique_rag.py | 2 +- .../search/bing_text_search_as_plugin.py | 137 - .../search/brave_text_search_as_plugin.py | 12 +- .../third_party/postgres-memory.ipynb | 5 +- .../connectors/memory/azure_ai_search.py | 9 +- .../connectors/memory/azure_cosmos_db.py | 9 +- .../connectors/memory/chroma.py | 8 +- .../connectors/memory/faiss.py | 5 +- .../connectors/memory/in_memory.py | 8 +- .../connectors/memory/mongodb.py | 9 +- .../connectors/memory/pinecone.py | 9 +- .../connectors/memory/postgres.py | 10 +- .../connectors/memory/qdrant.py | 8 +- .../connectors/memory/redis.py | 10 +- .../connectors/memory/sql_server.py | 9 +- .../connectors/memory/weaviate.py | 9 +- .../connectors/search/brave.py | 3 +- .../connectors/search/google.py | 3 +- python/semantic_kernel/data/__init__.py | 50 - python/semantic_kernel/data/_definitions.py | 542 --- python/semantic_kernel/data/_search.py | 418 +- python/semantic_kernel/data/text_search.py | 348 ++ .../data/{_vectors.py => vectors.py} | 3355 ++++++++++------- python/tests/conftest.py | 4 +- .../memory/azure_cosmos_db/conftest.py | 2 +- .../test_azure_cosmos_db_no_sql.py | 8 +- .../memory/postgres/test_postgres_int.py | 11 +- .../integration/memory/test_vector_store.py | 2 +- .../memory/vector_store_test_base.py | 2 +- ...test_azure_cosmos_db_mongodb_collection.py | 2 +- .../test_azure_cosmos_db_no_sql_collection.py | 8 +- .../connectors/memory/test_azure_ai_search.py | 12 +- .../unit/connectors/memory/test_faiss.py | 3 +- .../unit/connectors/memory/test_in_memory.py | 2 +- .../connectors/memory/test_postgres_store.py | 3 +- .../unit/connectors/memory/test_qdrant.py | 3 +- .../unit/connectors/memory/test_sql_server.py | 3 +- .../connectors/search/test_brave_search.py | 3 +- .../connectors/search/test_google_search.py | 2 +- python/tests/unit/data/conftest.py | 13 +- python/tests/unit/data/test_filter.py | 13 - python/tests/unit/data/test_text_search.py | 6 +- .../unit/data/test_vector_search_base.py | 5 +- .../data/test_vector_store_model_decorator.py | 3 +- .../test_vector_store_record_collection.py | 4 +- .../test_vector_store_record_definition.py | 2 +- 56 files changed, 2602 insertions(+), 2532 deletions(-) delete mode 100644 python/samples/concepts/search/bing_text_search_as_plugin.py delete mode 100644 python/semantic_kernel/data/_definitions.py create mode 100644 python/semantic_kernel/data/text_search.py rename python/semantic_kernel/data/{_vectors.py => vectors.py} (69%) delete mode 100644 python/tests/unit/data/test_filter.py diff --git a/python/samples/concepts/caching/semantic_caching.py b/python/samples/concepts/caching/semantic_caching.py index 2a175dd4ca8a..c2353cde3a7c 100644 --- a/python/samples/concepts/caching/semantic_caching.py +++ b/python/samples/concepts/caching/semantic_caching.py @@ -10,7 +10,7 @@ from semantic_kernel import Kernel from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextEmbedding from semantic_kernel.connectors.memory.in_memory import InMemoryStore -from semantic_kernel.data import VectorStore, VectorStoreField, VectorStoreRecordCollection, vectorstoremodel +from semantic_kernel.data.vectors import VectorStore, VectorStoreCollection, VectorStoreField, vectorstoremodel from semantic_kernel.filters import FilterTypes, FunctionInvocationContext, PromptRenderContext from semantic_kernel.functions import FunctionResult @@ -41,9 +41,7 
@@ def __init__( if vector_store.embedding_generator is None: raise ValueError("The vector store must have an embedding generator.") self.vector_store = vector_store - self.collection: VectorStoreRecordCollection[str, CacheRecord] = vector_store.get_collection( - record_type=CacheRecord - ) + self.collection: VectorStoreCollection[str, CacheRecord] = vector_store.get_collection(record_type=CacheRecord) self.score_threshold = score_threshold async def on_prompt_render( diff --git a/python/samples/concepts/chat_history/store_chat_history_in_cosmosdb.py b/python/samples/concepts/chat_history/store_chat_history_in_cosmosdb.py index c20ea1841e28..3dd30e6c74ac 100644 --- a/python/samples/concepts/chat_history/store_chat_history_in_cosmosdb.py +++ b/python/samples/concepts/chat_history/store_chat_history_in_cosmosdb.py @@ -11,7 +11,7 @@ from semantic_kernel.contents import ChatHistory, ChatMessageContent from semantic_kernel.core_plugins.math_plugin import MathPlugin from semantic_kernel.core_plugins.time_plugin import TimePlugin -from semantic_kernel.data import VectorStore, VectorStoreField, VectorStoreRecordCollection, vectorstoremodel +from semantic_kernel.data.vectors import VectorStore, VectorStoreCollection, VectorStoreField, vectorstoremodel """ This sample demonstrates how to build a conversational chatbot @@ -49,7 +49,7 @@ class ChatHistoryInCosmosDB(ChatHistory): session_id: str user_id: str store: VectorStore - collection: VectorStoreRecordCollection[str, ChatHistoryModel] | None = None + collection: VectorStoreCollection[str, ChatHistoryModel] | None = None async def create_collection(self, collection_name: str) -> None: """Create a collection with the inbuild data model using the vector store. diff --git a/python/samples/concepts/memory/azure_ai_search_hotel_samples/data_model.py b/python/samples/concepts/memory/azure_ai_search_hotel_samples/data_model.py index 6c290b4b4257..578c3661fb1a 100644 --- a/python/samples/concepts/memory/azure_ai_search_hotel_samples/data_model.py +++ b/python/samples/concepts/memory/azure_ai_search_hotel_samples/data_model.py @@ -15,7 +15,7 @@ ) from pydantic import BaseModel, ConfigDict -from semantic_kernel.data import VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreField, vectorstoremodel """ The data model used for this sample is based on the hotel data model from the Azure AI Search samples. diff --git a/python/samples/concepts/memory/complex_memory.py b/python/samples/concepts/memory/complex_memory.py index 43165aa4f707..12cdd49d6966 100644 --- a/python/samples/concepts/memory/complex_memory.py +++ b/python/samples/concepts/memory/complex_memory.py @@ -26,9 +26,13 @@ SqlServerCollection, WeaviateCollection, ) -from semantic_kernel.data import VectorStoreRecordCollection, vectorstoremodel -from semantic_kernel.data._definitions import VectorStoreField -from semantic_kernel.data._vectors import SearchType, VectorSearch +from semantic_kernel.data.vectors import ( + SearchType, + VectorSearchProtocol, + VectorStoreCollection, + VectorStoreField, + vectorstoremodel, +) # This is a rather complex sample, showing how to use the vector store # with a number of different collections. 
@@ -46,7 +50,7 @@ class DataModel: title: Annotated[str, VectorStoreField("data", is_full_text_indexed=True)] content: Annotated[str, VectorStoreField("data", is_full_text_indexed=True)] embedding: Annotated[ - str | None, + list[float] | str | None, VectorStoreField("vector", dimensions=1536, type="float"), ] = None id: Annotated[ @@ -94,7 +98,7 @@ def __post_init__(self, **kwargs): # function which returns the collection. # Using a function allows for lazy initialization of the collection, # so that settings for unused collections do not cause validation errors. -collections: dict[str, Callable[[], VectorStoreRecordCollection]] = { +collections: dict[str, Callable[[], VectorStoreCollection]] = { "ai_search": lambda: AzureAISearchCollection[str, DataModel](record_type=DataModel), "postgres": lambda: PostgresCollection[str, DataModel](record_type=DataModel), "redis_json": lambda: RedisJsonCollection[str, DataModel]( @@ -143,6 +147,7 @@ async def main(collection: str, use_azure_openai: bool): ) kernel.add_service(embedder) async with collections[collection]() as record_collection: + assert isinstance(record_collection, VectorSearchProtocol) # nosec record_collection.embedding_generator = embedder print_with_color(f"Creating {collection} collection!", Colors.CGREY) # cleanup any existing collection @@ -187,7 +192,6 @@ async def main(collection: str, use_azure_openai: bool): print_with_color("Now we can start searching.", Colors.CBLUE) print_with_color(" For each type of search, enter a search term, for instance `python`.", Colors.CBLUE) print_with_color(" Enter exit to exit, and skip or nothing to skip this search.", Colors.CBLUE) - assert isinstance(record_collection, VectorSearch) # nosec print("-" * 30) print_with_color( "This collection supports the following search types: " diff --git a/python/samples/concepts/memory/data_models.py b/python/samples/concepts/memory/data_models.py index 4aaf8dfc74c4..e2f539a8bb23 100644 --- a/python/samples/concepts/memory/data_models.py +++ b/python/samples/concepts/memory/data_models.py @@ -7,7 +7,7 @@ from pandas import DataFrame from pydantic import BaseModel, Field -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel # This concept shows the different ways you can create a vector store data model # using dataclasses, Pydantic, and Python classes. 
diff --git a/python/samples/concepts/memory/memory_with_pandas.py b/python/samples/concepts/memory/memory_with_pandas.py index 956b643cd23c..82a42ae46aa9 100644 --- a/python/samples/concepts/memory/memory_with_pandas.py +++ b/python/samples/concepts/memory/memory_with_pandas.py @@ -7,7 +7,7 @@ from semantic_kernel.connectors.ai.open_ai import OpenAITextEmbedding from semantic_kernel.connectors.memory.azure_ai_search import AzureAISearchCollection -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data.vectors import VectorStoreCollectionDefinition, VectorStoreField definition = VectorStoreCollectionDefinition( collection_name="pandas_test_index", diff --git a/python/samples/concepts/memory/simple_memory.py b/python/samples/concepts/memory/simple_memory.py index 85ee6d7cb2a6..117e411761db 100644 --- a/python/samples/concepts/memory/simple_memory.py +++ b/python/samples/concepts/memory/simple_memory.py @@ -10,7 +10,7 @@ from samples.concepts.resources.utils import Colors, print_with_color from semantic_kernel.connectors.ai.open_ai import OpenAITextEmbedding from semantic_kernel.connectors.memory import InMemoryCollection -from semantic_kernel.data import VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreField, vectorstoremodel # This is the most basic example of a vector store and collection # For a more complex example, using different collection types, see "complex_memory.py" diff --git a/python/samples/concepts/memory/utils.py b/python/samples/concepts/memory/utils.py index ba2aa3c187df..07e52ae75f14 100644 --- a/python/samples/concepts/memory/utils.py +++ b/python/samples/concepts/memory/utils.py @@ -3,7 +3,7 @@ from typing import TypeVar from samples.concepts.resources.utils import Colors, print_with_color -from semantic_kernel.data._vectors import VectorSearchResult +from semantic_kernel.data.vectors import VectorSearchResult _T = TypeVar("_T") diff --git a/python/samples/concepts/on_your_data/azure_chat_gpt_with_data_api.py b/python/samples/concepts/on_your_data/azure_chat_gpt_with_data_api.py index f2a382511988..3f3e9ce58863 100644 --- a/python/samples/concepts/on_your_data/azure_chat_gpt_with_data_api.py +++ b/python/samples/concepts/on_your_data/azure_chat_gpt_with_data_api.py @@ -10,7 +10,7 @@ AzureChatPromptExecutionSettings, ExtraBody, ) -from semantic_kernel.connectors.memory.azure_cognitive_search.azure_ai_search_settings import AzureAISearchSettings +from semantic_kernel.connectors.memory.azure_ai_search import AzureAISearchSettings from semantic_kernel.contents import ChatHistory from semantic_kernel.functions import KernelArguments from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig diff --git a/python/samples/concepts/rag/rag_with_vector_collection.py b/python/samples/concepts/rag/rag_with_vector_collection.py index 9c95d67789bf..297a20cd6be6 100644 --- a/python/samples/concepts/rag/rag_with_vector_collection.py +++ b/python/samples/concepts/rag/rag_with_vector_collection.py @@ -11,7 +11,7 @@ OpenAITextEmbedding, ) from semantic_kernel.connectors.memory import InMemoryCollection -from semantic_kernel.data import VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreField, vectorstoremodel from semantic_kernel.functions import KernelArguments """ diff --git a/python/samples/concepts/rag/self_critique_rag.py b/python/samples/concepts/rag/self_critique_rag.py index 7e131ab79747..32b90bd1adc8 100644 --- 
a/python/samples/concepts/rag/self_critique_rag.py +++ b/python/samples/concepts/rag/self_critique_rag.py @@ -9,7 +9,7 @@ from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextEmbedding from semantic_kernel.connectors.memory import AzureAISearchCollection from semantic_kernel.contents import ChatHistory -from semantic_kernel.data import VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreField, vectorstoremodel from semantic_kernel.functions.kernel_function import KernelFunction """ diff --git a/python/samples/concepts/search/bing_text_search_as_plugin.py b/python/samples/concepts/search/bing_text_search_as_plugin.py deleted file mode 100644 index 53968f10ec21..000000000000 --- a/python/samples/concepts/search/bing_text_search_as_plugin.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -from collections.abc import Awaitable, Callable - -from semantic_kernel import Kernel -from semantic_kernel.connectors.ai import FunctionChoiceBehavior -from semantic_kernel.connectors.ai.open_ai import ( - OpenAIChatCompletion, - OpenAIChatPromptExecutionSettings, -) -from semantic_kernel.connectors.search.bing import BingSearch -from semantic_kernel.contents import ChatHistory -from semantic_kernel.filters import FilterTypes, FunctionInvocationContext -from semantic_kernel.functions import KernelArguments, KernelParameterMetadata, KernelPlugin - -kernel = Kernel() -kernel.add_service(OpenAIChatCompletion(service_id="chat")) -kernel.add_plugin( - KernelPlugin.from_text_search_with_search( - BingSearch(), - plugin_name="bing", - description="Get details about Semantic Kernel concepts.", - parameters=[ - KernelParameterMetadata( - name="query", - description="The search query.", - type="str", - is_required=True, - type_object=str, - ), - KernelParameterMetadata( - name="top", - description="The number of results to return.", - type="int", - is_required=False, - default_value=2, - type_object=int, - ), - KernelParameterMetadata( - name="skip", - description="The number of results to skip.", - type="int", - is_required=False, - default_value=0, - type_object=int, - ), - KernelParameterMetadata( - name="site", - description="The site to search.", - default_value="https://github.com/microsoft/semantic-kernel/tree/main/python", - type="str", - is_required=False, - type_object=str, - ), - ], - ) -) -chat_function = kernel.add_function( - prompt="{{$chat_history}}{{$user_input}}", - plugin_name="ChatBot", - function_name="Chat", -) -execution_settings = OpenAIChatPromptExecutionSettings( - service_id="chat", - max_tokens=2000, - temperature=0.7, - top_p=0.8, - function_choice_behavior=FunctionChoiceBehavior.Auto(auto_invoke=True), -) - -history = ChatHistory() -system_message = """ -You are a chat bot, specialized in Semantic Kernel, Microsoft LLM orchestration SDK. -Assume questions are related to that, and use the Bing search plugin to find answers. -""" -history.add_system_message(system_message) -history.add_user_message("Hi there, who are you?") -history.add_assistant_message("I am Mosscap, a chat bot. 
I'm trying to figure out what people need.") - -arguments = KernelArguments(settings=execution_settings) - - -@kernel.filter(filter_type=FilterTypes.FUNCTION_INVOCATION) -async def log_bing_filter( - context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] -): - if context.function.plugin_name == "bing": - print("Calling Bing search with arguments:") - if "query" in context.arguments: - print(f' Query: "{context.arguments["query"]}"') - if "count" in context.arguments: - print(f' Count: "{context.arguments["count"]}"') - if "skip" in context.arguments: - print(f' Skip: "{context.arguments["skip"]}"') - await next(context) - print("Bing search completed.") - else: - await next(context) - - -async def chat() -> bool: - try: - user_input = input("User:> ") - except KeyboardInterrupt: - print("\n\nExiting chat...") - return False - except EOFError: - print("\n\nExiting chat...") - return False - - if user_input == "exit": - print("\n\nExiting chat...") - return False - arguments["user_input"] = user_input - arguments["chat_history"] = history - result = await kernel.invoke(chat_function, arguments=arguments) - print(f"Mosscap:> {result}") - history.add_user_message(user_input) - history.add_assistant_message(str(result)) - return True - - -async def main(): - chatting = True - print( - "Welcome to the chat bot!\ - \n Type 'exit' to exit.\ - \n Try to find out more about the inner workings of Semantic Kernel." - ) - while chatting: - chatting = await chat() - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/python/samples/concepts/search/brave_text_search_as_plugin.py b/python/samples/concepts/search/brave_text_search_as_plugin.py index 326ee853283e..ca18dd924030 100644 --- a/python/samples/concepts/search/brave_text_search_as_plugin.py +++ b/python/samples/concepts/search/brave_text_search_as_plugin.py @@ -8,7 +8,7 @@ from semantic_kernel.connectors.search.brave import BraveSearch from semantic_kernel.contents import ChatHistory from semantic_kernel.filters import FilterTypes, FunctionInvocationContext -from semantic_kernel.functions import KernelArguments, KernelParameterMetadata, KernelPlugin +from semantic_kernel.functions import KernelArguments, KernelParameterMetadata """ This project demonstrates how to integrate the Brave Search API as a plugin into the Semantic Kernel @@ -21,10 +21,10 @@ kernel = Kernel() kernel.add_service(OpenAIChatCompletion(service_id="chat")) -kernel.add_plugin( - KernelPlugin.from_text_search_with_search( - BraveSearch(), - plugin_name="brave", +kernel.add_function( + plugin_name="brave", + function=BraveSearch().create_search_function( + function_name="brave_search", description="Get details about Semantic Kernel concepts.", parameters=[ KernelParameterMetadata( @@ -51,7 +51,7 @@ type_object=int, ), ], - ) + ), ) chat_function = kernel.add_function( prompt="{{$chat_history}}{{$user_input}}", diff --git a/python/samples/getting_started/third_party/postgres-memory.ipynb b/python/samples/getting_started/third_party/postgres-memory.ipynb index 0c1b77da84a6..a45f807242f5 100644 --- a/python/samples/getting_started/third_party/postgres-memory.ipynb +++ b/python/samples/getting_started/third_party/postgres-memory.ipynb @@ -37,11 +37,12 @@ ")\n", "from semantic_kernel.connectors.memory.postgres import PostgresCollection\n", "from semantic_kernel.contents import ChatHistory\n", - "from semantic_kernel.data import (\n", + "from semantic_kernel.data.vectors import (\n", + " DistanceFunction,\n", + " 
IndexKind,\n", " VectorStoreField,\n", " vectorstoremodel,\n", ")\n", - "from semantic_kernel.data._vectors import DistanceFunction, IndexKind\n", "from semantic_kernel.functions import KernelParameterMetadata\n", "from semantic_kernel.functions.kernel_arguments import KernelArguments" ] diff --git a/python/semantic_kernel/connectors/memory/azure_ai_search.py b/python/semantic_kernel/connectors/memory/azure_ai_search.py index 702778051ca2..a21c6d3d759d 100644 --- a/python/semantic_kernel/connectors/memory/azure_ai_search.py +++ b/python/semantic_kernel/connectors/memory/azure_ai_search.py @@ -29,10 +29,10 @@ from pydantic import SecretStr, ValidationError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, + FieldTypes, GetFilteredRecordOptions, IndexKind, SearchType, @@ -41,7 +41,8 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, _get_collection_name_from_model, ) from semantic_kernel.exceptions import ( @@ -277,7 +278,7 @@ def _definition_to_azure_ai_search_index( @release_candidate class AzureAISearchCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db.py index f90ff5989170..15168819e6be 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db.py @@ -23,10 +23,10 @@ MongoDBAtlasCollection, MongoDBAtlasStore, ) -from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, + FieldTypes, GetFilteredRecordOptions, IndexKind, SearchType, @@ -35,7 +35,8 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, _get_collection_name_from_model, ) from semantic_kernel.exceptions import ( @@ -655,7 +656,7 @@ async def _get_container_proxy(self, container_name: str, **kwargs) -> Container @release_candidate class CosmosNoSqlCollection( CosmosNoSqlBase, - VectorStoreRecordCollection[TNoSQLKey, TModel], + VectorStoreCollection[TNoSQLKey, TModel], VectorSearch[TNoSQLKey, TModel], Generic[TNoSQLKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/chroma.py b/python/semantic_kernel/connectors/memory/chroma.py index 83e664f1e2cc..3c1d451edf6b 100644 --- a/python/semantic_kernel/connectors/memory/chroma.py +++ b/python/semantic_kernel/connectors/memory/chroma.py @@ -13,9 +13,8 @@ from chromadb.config import Settings from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, GetFilteredRecordOptions, IndexKind, @@ -25,7 +24,8 @@ 
VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, _get_collection_name_from_model, ) from semantic_kernel.exceptions.vector_store_exceptions import ( @@ -61,7 +61,7 @@ @release_candidate class ChromaCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/faiss.py b/python/semantic_kernel/connectors/memory/faiss.py index c665e3f3a5d3..ecc1a106c941 100644 --- a/python/semantic_kernel/connectors/memory/faiss.py +++ b/python/semantic_kernel/connectors/memory/faiss.py @@ -10,15 +10,16 @@ from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase from semantic_kernel.connectors.memory.in_memory import IN_MEMORY_SCORE_KEY, InMemoryCollection, InMemoryStore, TKey -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, IndexKind, SearchType, TModel, VectorSearchOptions, VectorSearchResult, + VectorStoreCollectionDefinition, + VectorStoreField, ) from semantic_kernel.exceptions import VectorStoreInitializationException, VectorStoreOperationException from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelException diff --git a/python/semantic_kernel/connectors/memory/in_memory.py b/python/semantic_kernel/connectors/memory/in_memory.py index d5ce2ca41458..909599c71eb5 100644 --- a/python/semantic_kernel/connectors/memory/in_memory.py +++ b/python/semantic_kernel/connectors/memory/in_memory.py @@ -11,9 +11,8 @@ from typing_extensions import override from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction, GetFilteredRecordOptions, @@ -23,7 +22,8 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, ) from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreModelValidationError from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelException, VectorStoreOperationException @@ -82,7 +82,7 @@ def __delattr__(self, name) -> None: class InMemoryCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/mongodb.py b/python/semantic_kernel/connectors/memory/mongodb.py index 1d2ed3a16e34..b4318b5f81ad 100644 --- a/python/semantic_kernel/connectors/memory/mongodb.py +++ b/python/semantic_kernel/connectors/memory/mongodb.py @@ -15,9 +15,8 @@ from pymongo.operations import SearchIndexModel from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( 
DistanceFunction, GetFilteredRecordOptions, SearchType, @@ -26,7 +25,9 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, + VectorStoreField, _get_collection_name_from_model, ) from semantic_kernel.exceptions import ( @@ -150,7 +151,7 @@ def _create_index_definitions( @release_candidate class MongoDBAtlasCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/pinecone.py b/python/semantic_kernel/connectors/memory/pinecone.py index edeea0922fb8..c813dd33e741 100644 --- a/python/semantic_kernel/connectors/memory/pinecone.py +++ b/python/semantic_kernel/connectors/memory/pinecone.py @@ -13,9 +13,8 @@ from pydantic import SecretStr, ValidationError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, GetFilteredRecordOptions, SearchType, @@ -24,7 +23,9 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, + VectorStoreField, _get_collection_name_from_model, ) from semantic_kernel.exceptions.vector_store_exceptions import ( @@ -73,7 +74,7 @@ class PineconeSettings(KernelBaseSettings): @release_candidate class PineconeCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/postgres.py b/python/semantic_kernel/connectors/memory/postgres.py index 9156c38bf8da..f4ac2584b1dc 100644 --- a/python/semantic_kernel/connectors/memory/postgres.py +++ b/python/semantic_kernel/connectors/memory/postgres.py @@ -17,10 +17,10 @@ from pydantic_settings import SettingsConfigDict from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, + FieldTypes, GetFilteredRecordOptions, IndexKind, SearchType, @@ -29,7 +29,9 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, + VectorStoreField, ) from semantic_kernel.exceptions import VectorStoreModelValidationError, VectorStoreOperationException from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorConnectionException @@ -302,7 +304,7 @@ async def create_connection_pool( @release_candidate class PostgresCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/qdrant.py b/python/semantic_kernel/connectors/memory/qdrant.py index cfe81f59d7fd..124401ebd5cb 100644 --- a/python/semantic_kernel/connectors/memory/qdrant.py +++ b/python/semantic_kernel/connectors/memory/qdrant.py @@ -28,9 +28,8 @@ from typing_extensions import override 
from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, GetFilteredRecordOptions, IndexKind, @@ -40,7 +39,8 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, ) from semantic_kernel.exceptions import ( VectorSearchExecutionException, @@ -121,7 +121,7 @@ def model_dump(self, **kwargs): @release_candidate class QdrantCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/redis.py b/python/semantic_kernel/connectors/memory/redis.py index 9c0de7c7fa35..ae9394275bfe 100644 --- a/python/semantic_kernel/connectors/memory/redis.py +++ b/python/semantic_kernel/connectors/memory/redis.py @@ -24,10 +24,10 @@ from redisvl.schema import StorageType from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import FieldTypes, VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, + FieldTypes, GetFilteredRecordOptions, IndexKind, SearchType, @@ -36,7 +36,9 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, + VectorStoreField, ) from semantic_kernel.exceptions import ( VectorSearchExecutionException, @@ -180,7 +182,7 @@ class RedisSettings(KernelBaseSettings): @release_candidate class RedisCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/sql_server.py b/python/semantic_kernel/connectors/memory/sql_server.py index 5b1ccf796934..b579d228eae2 100644 --- a/python/semantic_kernel/connectors/memory/sql_server.py +++ b/python/semantic_kernel/connectors/memory/sql_server.py @@ -17,9 +17,8 @@ from pydantic import SecretStr, ValidationError, field_validator from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction, GetFilteredRecordOptions, @@ -29,7 +28,9 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, + VectorStoreField, ) from semantic_kernel.exceptions import ( VectorSearchExecutionException, @@ -270,7 +271,7 @@ async def _get_mssql_connection(settings: SqlSettings) -> "Connection": @release_candidate class SqlServerCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/memory/weaviate.py b/python/semantic_kernel/connectors/memory/weaviate.py index 
431c7a318d19..6fb843b69bbb 100644 --- a/python/semantic_kernel/connectors/memory/weaviate.py +++ b/python/semantic_kernel/connectors/memory/weaviate.py @@ -20,9 +20,8 @@ from weaviate.exceptions import WeaviateClosedClientError, WeaviateConnectionError from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.data._search import KernelSearchResults -from semantic_kernel.data._vectors import ( +from semantic_kernel.data.vectors import ( DistanceFunction, GetFilteredRecordOptions, IndexKind, @@ -32,7 +31,9 @@ VectorSearchOptions, VectorSearchResult, VectorStore, - VectorStoreRecordCollection, + VectorStoreCollection, + VectorStoreCollectionDefinition, + VectorStoreField, ) from semantic_kernel.exceptions import ( ServiceInvalidExecutionSettingsError, @@ -195,7 +196,7 @@ def is_using_client_embedding(cls, data: dict[str, Any]) -> bool: @release_candidate class WeaviateCollection( - VectorStoreRecordCollection[TKey, TModel], + VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], Generic[TKey, TModel], ): diff --git a/python/semantic_kernel/connectors/search/brave.py b/python/semantic_kernel/connectors/search/brave.py index bf0858a369c4..f38c5631e100 100644 --- a/python/semantic_kernel/connectors/search/brave.py +++ b/python/semantic_kernel/connectors/search/brave.py @@ -11,7 +11,8 @@ from pydantic import Field, SecretStr, ValidationError from semantic_kernel.connectors.search.utils import SearchLambdaVisitor -from semantic_kernel.data._search import KernelSearchResults, SearchOptions, TextSearch, TextSearchResult, TSearchResult +from semantic_kernel.data._search import KernelSearchResults, SearchOptions +from semantic_kernel.data.text_search import TextSearch, TextSearchResult, TSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError from semantic_kernel.kernel_pydantic import KernelBaseModel, KernelBaseSettings from semantic_kernel.kernel_types import OptionalOneOrList diff --git a/python/semantic_kernel/connectors/search/google.py b/python/semantic_kernel/connectors/search/google.py index 736972798aa0..3c5d84ff8e9b 100644 --- a/python/semantic_kernel/connectors/search/google.py +++ b/python/semantic_kernel/connectors/search/google.py @@ -12,7 +12,8 @@ from pydantic import Field, SecretStr, ValidationError from semantic_kernel.connectors.search.utils import SearchLambdaVisitor -from semantic_kernel.data._search import KernelSearchResults, SearchOptions, TextSearch, TextSearchResult, TSearchResult +from semantic_kernel.data._search import KernelSearchResults, SearchOptions +from semantic_kernel.data.text_search import TextSearch, TextSearchResult, TSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError from semantic_kernel.kernel_pydantic import KernelBaseModel, KernelBaseSettings from semantic_kernel.kernel_types import OptionalOneOrList diff --git a/python/semantic_kernel/data/__init__.py b/python/semantic_kernel/data/__init__.py index 5797e85c30ff..e69de29bb2d1 100644 --- a/python/semantic_kernel/data/__init__.py +++ b/python/semantic_kernel/data/__init__.py @@ -1,50 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- - -from semantic_kernel.data._definitions import ( - FieldTypes, - VectorStoreCollectionDefinition, - VectorStoreField, - vectorstoremodel, -) -from semantic_kernel.data._search import ( - DEFAULT_DESCRIPTION, - DEFAULT_FUNCTION_NAME, - DynamicFilterFunction, - KernelSearchResults, - TextSearch, - TextSearchResult, - create_options, - default_dynamic_filter_function, -) -from semantic_kernel.data._vectors import ( - DISTANCE_FUNCTION_DIRECTION_HELPER, - DistanceFunction, - IndexKind, - VectorSearch, - VectorSearchResult, - VectorStore, - VectorStoreRecordCollection, -) - -__all__ = [ - "DEFAULT_DESCRIPTION", - "DEFAULT_FUNCTION_NAME", - "DISTANCE_FUNCTION_DIRECTION_HELPER", - "DistanceFunction", - "DynamicFilterFunction", - "FieldTypes", - "IndexKind", - "KernelSearchResults", - "TextSearch", - "TextSearchResult", - "VectorSearch", - "VectorSearchResult", - "VectorStore", - "VectorStoreCollectionDefinition", - "VectorStoreField", - "VectorStoreRecordCollection", - "create_options", - "default_dynamic_filter_function", - "vectorstoremodel", -] diff --git a/python/semantic_kernel/data/_definitions.py b/python/semantic_kernel/data/_definitions.py deleted file mode 100644 index c8c03043f094..000000000000 --- a/python/semantic_kernel/data/_definitions.py +++ /dev/null @@ -1,542 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import logging -from collections.abc import Sequence -from dataclasses import dataclass -from enum import Enum -from inspect import Parameter, _empty, signature -from types import MappingProxyType, NoneType -from typing import Annotated, Any, Literal, Protocol, TypeVar, overload, runtime_checkable - -from pydantic import Field, ValidationError - -from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase -from semantic_kernel.data._vectors import DistanceFunction, IndexKind -from semantic_kernel.exceptions import VectorStoreModelException -from semantic_kernel.kernel_pydantic import KernelBaseModel -from semantic_kernel.utils.feature_stage_decorator import release_candidate - -logger = logging.getLogger(__name__) - - -# region: Fields - - -@release_candidate -class FieldTypes(str, Enum): - """Enumeration for field types in vector store models.""" - - KEY = "key" - VECTOR = "vector" - DATA = "data" - - def __str__(self) -> str: - """Return the string representation of the enum.""" - return self.value - - -@release_candidate -@dataclass -class VectorStoreField: - """Vector store fields.""" - - field_type: Literal[FieldTypes.DATA, FieldTypes.KEY, FieldTypes.VECTOR] = FieldTypes.DATA - name: str = "" - storage_name: str | None = None - type_: str | None = None - # data specific fields (all optional) - is_indexed: bool | None = None - is_full_text_indexed: bool | None = None - # vector specific fields (dimensions is mandatory) - dimensions: int | None = None - embedding_generator: EmbeddingGeneratorBase | None = None - # defaults for these fields are not set here, because they are not relevant for data and key types - index_kind: IndexKind | None = None - distance_function: DistanceFunction | None = None - - @overload - def __init__( - self, - field_type: Literal[FieldTypes.KEY, "key"] = FieldTypes.KEY, # type: ignore[assignment] - *, - name: str | None = None, - type: str | None = None, - storage_name: str | None = None, - ): - """Key field of the record. - - When the key will be auto-generated by the store, make sure it has a default, usually None. - - Args: - field_type: always "key". - name: The name of the field. 
- storage_name: The name of the field in the store, uses the field name by default. - type: The type of the field. - """ - ... - - @overload - def __init__( - self, - field_type: Literal[FieldTypes.DATA, "data"] = FieldTypes.DATA, # type: ignore[assignment] - *, - name: str | None = None, - type: str | None = None, - storage_name: str | None = None, - is_indexed: bool | None = None, - is_full_text_indexed: bool | None = None, - ): - """Data field in the record. - - Args: - field_type: always "data". - name: The name of the field. - storage_name: The name of the field in the store, uses the field name by default. - type: The type of the field. - is_indexed: Whether the field is indexed. - is_full_text_indexed: Whether the field is full text indexed. - """ - ... - - @overload - def __init__( - self, - field_type: Literal[FieldTypes.VECTOR, "vector"] = FieldTypes.VECTOR, # type: ignore[assignment] - *, - name: str | None = None, - type: str | None = None, - dimensions: Annotated[int, Field(gt=0)], - storage_name: str | None = None, - index_kind: IndexKind | None = None, - distance_function: DistanceFunction | None = None, - embedding_generator: EmbeddingGeneratorBase | None = None, - ): - """Vector field in the record. - - This field should contain the value you want to use for the vector. - When passing in the embedding generator, the embedding will be - generated locally before upserting. - If this is not set, the store should support generating the embedding for you. - If you want to retrieve the original content of the vector, - make sure to set this field twice, - once with the VectorStoreRecordDataField and once with the VectorStoreRecordVectorField. - - If you want to be able to get the vectors back, make sure the type allows this, especially for pydantic models. - For instance, if the input is a string, then the type annotation should be `str | list[float] | None`. - - If you want to cast the vector that is returned, you need to set the deserialize_function, - for instance: `deserialize_function=np.array`, (with `import numpy as np` at the top of your file). - If you want to set it up with more specific options, use a lambda, a custom function or a partial. - - Args: - field_type: always "vector". - name: The name of the field. - storage_name: The name of the field in the store, uses the field name by default. - type: Property type. - For vectors this should be the inner type of the vector. - By default the vector will be a list of numbers. - If you want to use a numpy array or some other optimized format, - set the cast_function with a function - that takes a list of floats and returns a numpy array. - dimensions: The number of dimensions of the vector, mandatory. - index_kind: The index kind to use, uses a default index kind when None. - distance_function: The distance function to use, uses a default distance function when None. - embedding_generator: The embedding generator to use. - If this is set, the embedding will be generated locally before upserting. - """ - ... - - def __init__( - self, - field_type=FieldTypes.DATA, - *, - name=None, - type=None, - storage_name=None, - is_indexed=None, - is_full_text_indexed=None, - dimensions=None, - index_kind=None, - distance_function=None, - embedding_generator=None, - ): - """Vector store field.""" - self.field_type = field_type if isinstance(field_type, FieldTypes) else FieldTypes(field_type) - # when a field is created, the name can be empty, - # when a field get's added to a definition, the name needs to be there. 
- if name: - self.name = name - self.storage_name = storage_name - self.type_ = type - self.is_indexed = is_indexed - self.is_full_text_indexed = is_full_text_indexed - if field_type == FieldTypes.VECTOR: - if dimensions is None: - raise ValidationError("Vector fields must specify 'dimensions'") - self.dimensions = dimensions - self.index_kind = index_kind or IndexKind.DEFAULT - self.distance_function = distance_function or DistanceFunction.DEFAULT - self.embedding_generator = embedding_generator - - -# region: Protocols - - -@runtime_checkable -class ToDictFunctionProtocol(Protocol): - """Protocol for to_dict function. - - Args: - record: The record to be serialized. - **kwargs: Additional keyword arguments. - - Returns: - A list of dictionaries. - """ - - def __call__(self, record: Any, **kwargs: Any) -> Sequence[dict[str, Any]]: ... # pragma: no cover - - -@runtime_checkable -class FromDictFunctionProtocol(Protocol): - """Protocol for from_dict function. - - Args: - records: A list of dictionaries. - **kwargs: Additional keyword arguments. - - Returns: - A record or list thereof. - """ - - def __call__(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Any: ... - - -@runtime_checkable -class SerializeFunctionProtocol(Protocol): - """Protocol for serialize function. - - Args: - record: The record to be serialized. - **kwargs: Additional keyword arguments. - - Returns: - The serialized record, ready to be consumed by the specific store. - - """ - - def __call__(self, record: Any, **kwargs: Any) -> Any: ... - - -@runtime_checkable -class DeserializeFunctionProtocol(Protocol): - """Protocol for deserialize function. - - Args: - records: The serialized record directly from the store. - **kwargs: Additional keyword arguments. - - Returns: - The deserialized record in the format expected by the application. - - """ - - def __call__(self, records: Any, **kwargs: Any) -> Any: ... - - -@runtime_checkable -class SerializeMethodProtocol(Protocol): - """Data model serialization protocol. - - This can optionally be implemented to allow single step serialization and deserialization - for using your data model with a specific datastore. - """ - - def serialize(self, **kwargs: Any) -> Any: - """Serialize the object to the format required by the data store.""" - ... # pragma: no cover - - -@runtime_checkable -class ToDictMethodProtocol(Protocol): - """Class used internally to check if a model has a to_dict method.""" - - def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - """Serialize the object to the format required by the data store.""" - ... # pragma: no cover - - -# region: VectorStoreRecordDefinition - - -@release_candidate -class VectorStoreCollectionDefinition(KernelBaseModel): - """Collection definition for vector stores. - - Args: - fields: The fields of the record. - container_mode: Whether the record is in container mode. - to_dict: The to_dict function, should take a record and return a list of dicts. - from_dict: The from_dict function, should take a list of dicts and return a record. - deserialize: The deserialize function, should take a type specific to a datastore and return a record. 
- - """ - - fields: list[VectorStoreField] - key_name: str = Field(default="", init=False) - container_mode: bool = False - collection_name: str | None = None - to_dict: ToDictFunctionProtocol | None = None - from_dict: FromDictFunctionProtocol | None = None - serialize: SerializeFunctionProtocol | None = None - deserialize: DeserializeFunctionProtocol | None = None - - @property - def names(self) -> list[str]: - """Get the names of the fields.""" - return [field.name for field in self.fields] - - @property - def storage_names(self) -> list[str]: - """Get the names of the fields for storage.""" - return [field.storage_name or field.name for field in self.fields] - - @property - def key_field(self) -> VectorStoreField: - """Get the key field.""" - return next((field for field in self.fields if field.name == self.key_name), None) # type: ignore - - @property - def key_field_storage_name(self) -> str: - """Get the key field storage name.""" - return self.key_field.storage_name or self.key_field.name - - @property - def vector_fields(self) -> list[VectorStoreField]: - """Get the names of the vector fields.""" - return [field for field in self.fields if field.field_type == FieldTypes.VECTOR] - - @property - def data_fields(self) -> list[VectorStoreField]: - """Get the names of the data fields.""" - return [field for field in self.fields if field.field_type == FieldTypes.DATA] - - @property - def vector_field_names(self) -> list[str]: - """Get the names of the vector fields.""" - return [field.name for field in self.fields if field.field_type == FieldTypes.VECTOR] - - @property - def data_field_names(self) -> list[str]: - """Get the names of all the data fields.""" - return [field.name for field in self.fields if field.field_type == FieldTypes.DATA] - - def try_get_vector_field(self, field_name: str | None = None) -> VectorStoreField | None: - """Try to get the vector field. - - If the field_name is None, then the first vector field is returned. - If no vector fields are present None is returned. - - Args: - field_name: The field name. - - Returns: - VectorStoreRecordVectorField | None: The vector field or None. - """ - if field_name is None: - if len(self.vector_fields) == 0: - return None - return self.vector_fields[0] - for field in self.fields: - if field.name == field_name or field.storage_name == field_name: - if field.field_type == FieldTypes.VECTOR: - return field - raise VectorStoreModelException( - f"Field {field_name} is not a vector field, it is of type {type(field).__name__}." - ) - raise VectorStoreModelException(f"Field {field_name} not found.") - - def get_storage_names(self, include_vector_fields: bool = True, include_key_field: bool = True) -> list[str]: - """Get the names of the fields for the storage. - - Args: - include_vector_fields: Whether to include vector fields. - include_key_field: Whether to include the key field. - - Returns: - list[str]: The names of the fields. - """ - return [ - field.storage_name or field.name - for field in self.fields - if field.field_type == FieldTypes.DATA - or (field.field_type == FieldTypes.VECTOR and include_vector_fields) - or (field.field_type == FieldTypes.KEY and include_key_field) - ] - - def get_names(self, include_vector_fields: bool = True, include_key_field: bool = True) -> list[str]: - """Get the names of the fields. - - Args: - include_vector_fields: Whether to include vector fields. - include_key_field: Whether to include the key field. - - Returns: - list[str]: The names of the fields. 
- """ - return [ - field.name - for field in self.fields - if field.field_type == FieldTypes.DATA - or (field.field_type == FieldTypes.VECTOR and include_vector_fields) - or (field.field_type == FieldTypes.KEY and include_key_field) - ] - - def model_post_init(self, _: Any): - """Validate the fields. - - Raises: - VectorStoreModelException: If there is a field with an embedding property name - but no corresponding vector field. - VectorStoreModelException: If there is no key field. - """ - if len(self.fields) == 0: - raise VectorStoreModelException( - "There must be at least one field with a VectorStoreRecordField annotation." - ) - for field in self.fields: - if not field.name or field.name == "": - raise VectorStoreModelException("Field names must not be empty.") - if field.field_type == FieldTypes.KEY: - if self.key_name != "": - raise VectorStoreModelException("Memory record definition must have exactly one key field.") - self.key_name = field.name - if not self.key_name: - raise VectorStoreModelException("Memory record definition must have exactly one key field.") - - -# region: Signature parsing functions - - -def _parse_vector_store_record_field_instance(record_field: VectorStoreField, field: Parameter) -> VectorStoreField: - if not record_field.name or record_field.name != field.name: - record_field.name = field.name - if not record_field.type_ and hasattr(field.annotation, "__origin__"): - property_type = field.annotation.__origin__ - if record_field.field_type == FieldTypes.VECTOR: - if args := getattr(property_type, "__args__", None): - if NoneType in args and len(args) > 1: - for arg in args: - if arg is NoneType: - continue - - if ( - (inner_args := getattr(arg, "__args__", None)) - and len(inner_args) == 1 - and inner_args[0] is not NoneType - ): - property_type = inner_args[0] - break - property_type = arg - break - else: - property_type = args[0] - - else: - if (args := getattr(property_type, "__args__", None)) and NoneType in args and len(args) == 2: - property_type = args[0] - - record_field.type_ = str(property_type) if hasattr(property_type, "__args__") else property_type.__name__ - - return record_field - - -def _parse_parameter_to_field(field: Parameter) -> VectorStoreField | None: - # first check if there are any annotations - if field.annotation is not _empty and hasattr(field.annotation, "__metadata__"): - for field_annotation in field.annotation.__metadata__: - if isinstance(field_annotation, VectorStoreField): - return _parse_vector_store_record_field_instance(field_annotation, field) - # This means there are no annotations or that all annotations are of other types. - # we will check if there is a default, otherwise this will cause a runtime error. - # because it will not be stored, and retrieving this object will fail without a default for this field. - if field.default is _empty: - raise VectorStoreModelException( - "Fields that do not have a VectorStoreField annotation must have a default value." - ) - logger.debug(f'Field "{field.name}" does not have a VectorStoreField annotation, will not be part of the record.') - return None - - -def _parse_signature_to_definition( - parameters: MappingProxyType[str, Parameter], collection_name: str | None = None -) -> VectorStoreCollectionDefinition: - if len(parameters) == 0: - raise VectorStoreModelException( - "There must be at least one field in the datamodel. If you are using this with a @dataclass, " - "you might have inverted the order of the decorators, the vectorstoremodel decorator should be the top one." 
- ) - fields = [] - for param in parameters.values(): - field = _parse_parameter_to_field(param) - if field: - fields.append(field) - - return VectorStoreCollectionDefinition( - fields=fields, - collection_name=collection_name, - ) - - -# region: VectorStoreModel decorator - - -_T = TypeVar("_T") - - -@release_candidate -def vectorstoremodel( - cls: type[_T] | None = None, - collection_name: str | None = None, -) -> type[_T]: - """Returns the class as a vector store model. - - This decorator makes a class a vector store model. - There are three things being checked: - - The class must have at least one field with a annotation, - of type VectorStoreField. - - The class must have exactly one field with the field_type `key`. - - When creating a Vector Field, either supply the property type directly, - or make sure to set the property that you want the index to use first. - - - Args: - cls: The class to be decorated. - collection_name: The name of the collection to be used. - This is used to set the collection name in the VectorStoreCollectionDefinition. - - Raises: - VectorStoreModelException: If there are no fields with a VectorStoreField annotation. - VectorStoreModelException: If there are fields with no name. - VectorStoreModelException: If there is no key field. - """ - - def wrap(cls: type[_T]) -> type[_T]: - # get fields and annotations - cls_sig = signature(cls) - setattr(cls, "__kernel_vectorstoremodel__", True) - setattr( - cls, - "__kernel_vectorstoremodel_definition__", - _parse_signature_to_definition(cls_sig.parameters, collection_name), - ) - - return cls # type: ignore - - # See if we're being called as @vectorstoremodel or @vectorstoremodel(). - if cls is None: - # We're called with parens. - return wrap # type: ignore - - # We're called as @vectorstoremodel without parens. - return wrap(cls) diff --git a/python/semantic_kernel/data/_search.py b/python/semantic_kernel/data/_search.py index 58ad59444da9..94377a9ebb90 100644 --- a/python/semantic_kernel/data/_search.py +++ b/python/semantic_kernel/data/_search.py @@ -1,33 +1,59 @@ # Copyright (c) Microsoft. All rights reserved. 
-import json -import logging -from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Callable, Mapping, Sequence -from copy import deepcopy -from typing import Annotated, Any, Final, Generic, Literal, Protocol, TypeVar, overload - -from pydantic import BaseModel, ConfigDict, Field, ValidationError - -from semantic_kernel.exceptions import TextSearchException -from semantic_kernel.functions.kernel_function import KernelFunction -from semantic_kernel.functions.kernel_function_decorator import kernel_function -from semantic_kernel.functions.kernel_function_from_method import KernelFunctionFromMethod +# region: Options + + +from abc import ABC +from collections.abc import AsyncIterable, Callable, Mapping +from logging import Logger +from typing import Annotated, Any, Final, Generic, Protocol, TypeVar + +from pydantic import ConfigDict, Field + from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.kernel_types import OptionalOneOrList from semantic_kernel.utils.feature_stage_decorator import release_candidate -logger = logging.getLogger(__name__) - +TSearchResult = TypeVar("TSearchResult") TSearchOptions = TypeVar("TSearchOptions", bound="SearchOptions") -DEFAULT_FUNCTION_NAME: Final[str] = "search" -DEFAULT_DESCRIPTION: Final[str] = ( - "Perform a search for content related to the specified query and return string results" + +DEFAULT_RETURN_PARAMETER_METADATA: KernelParameterMetadata = KernelParameterMetadata( + name="results", + description="The search results.", + type="list[str]", + type_object=list, + is_required=True, ) +# region: Text Search -# region: Options +DEFAULT_PARAMETER_METADATA: list[KernelParameterMetadata] = [ + KernelParameterMetadata( + name="query", + description="What to search for.", + type="str", + is_required=True, + type_object=str, + ), + KernelParameterMetadata( + name="top", + description="Number of results to return.", + type="int", + is_required=False, + default_value=2, + type_object=int, + ), + KernelParameterMetadata( + name="skip", + description="Number of results to skip.", + type="int", + is_required=False, + default_value=0, + type_object=int, + ), +] +DEFAULT_FUNCTION_NAME: Final[str] = "search" @release_candidate @@ -47,21 +73,6 @@ class SearchOptions(ABC, KernelBaseModel): ) -# region: Results - - -@release_candidate -class TextSearchResult(KernelBaseModel): - """The result of a text search.""" - - name: str | None = None - value: str | None = None - link: str | None = None - - -TSearchResult = TypeVar("TSearchResult") - - @release_candidate class KernelSearchResults(KernelBaseModel, Generic[TSearchResult]): """The result of a kernel search.""" @@ -89,7 +100,8 @@ def __call__( def create_options( options_class: type["TSearchOptions"], - options: "SearchOptions | None", + options: SearchOptions | None, + logger: Logger | None = None, **kwargs: Any, ) -> "TSearchOptions": """Create search options. @@ -102,6 +114,7 @@ def create_options( Args: options_class: The class of the options. options: The existing options to update. + logger: The logger to use for warnings. **kwargs: The keyword arguments to use to create the options. Returns: @@ -124,7 +137,8 @@ def create_options( except Exception: # This is very unlikely to happen, but if it does, we will just create new options. # one reason this could happen is if a different class is passed that has no model_dump method - logger.warning("Options are not valid. 
Creating new options from just kwargs.") + if logger: + logger.warning("Options are not valid. Creating new options from just kwargs.") kwargs.update(additional_kwargs) return options_class.model_validate(kwargs) @@ -172,335 +186,3 @@ def default_dynamic_filter_function( filter = [filter, new_filter] return filter - - -# region: Text Search - - -@release_candidate -class TextSearch: - """The base class for all text searchers.""" - - @property - def options_class(self) -> type["SearchOptions"]: - """The options class for the search.""" - return SearchOptions - - @staticmethod - def _default_parameter_metadata() -> list[KernelParameterMetadata]: - """Default parameter metadata for text search functions. - - This function should be overridden when necessary. - """ - return [ - KernelParameterMetadata( - name="query", - description="What to search for.", - type="str", - is_required=True, - type_object=str, - ), - KernelParameterMetadata( - name="top", - description="Number of results to return.", - type="int", - is_required=False, - default_value=2, - type_object=int, - ), - KernelParameterMetadata( - name="skip", - description="Number of results to skip.", - type="int", - is_required=False, - default_value=0, - type_object=int, - ), - ] - - @staticmethod - def _default_return_parameter_metadata() -> KernelParameterMetadata: - """Default return parameter metadata for text search functions. - - This function should be overridden by subclasses. - """ - return KernelParameterMetadata( - name="results", - description="The search results.", - type="list[str]", - type_object=list, - is_required=True, - ) - - # region: Public methods - - @overload - def create_search_function( - self, - function_name: str = DEFAULT_FUNCTION_NAME, - description: str = DEFAULT_DESCRIPTION, - *, - output_type: Literal["str"] = "str", - parameters: list[KernelParameterMetadata] | None = None, - return_parameter: KernelParameterMetadata | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 5, - skip: int = 0, - include_total_count: bool = False, - filter_update_function: DynamicFilterFunction | None = None, - string_mapper: Callable[[TSearchResult], str] | None = None, - ) -> KernelFunction: - """Create a kernel function from a search function. - - Args: - output_type: The type of the output, default is "str". - function_name: The name of the function, to be used in the kernel, default is "search". - description: The description of the function, a default is provided. - parameters: The parameters for the function, a list of KernelParameterMetadata. - return_parameter: The return parameter for the function. - filter: The filter to use for the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - filter_update_function: A function to update the search filters. - The function should return the updated filter. - The default function uses the parameters and the kwargs to update the options. - Adding equal to filters to the options for all parameters that are not "query". - As well as adding equal to filters for parameters that have a default value. - string_mapper: The function to map the search results. (the inner part of the KernelSearchResults type, - related to which search type you are using) to strings. - - Returns: - KernelFunction: The kernel function. - - """ - ... 
- - @overload - def create_search_function( - self, - function_name: str = DEFAULT_FUNCTION_NAME, - description: str = DEFAULT_DESCRIPTION, - *, - output_type: Literal["TextSearchResult"], - parameters: list[KernelParameterMetadata] | None = None, - return_parameter: KernelParameterMetadata | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 5, - skip: int = 0, - include_total_count: bool = False, - filter_update_function: DynamicFilterFunction | None = None, - ) -> KernelFunction: - """Create a kernel function from a search function. - - Args: - output_type: The type of the output, in this case TextSearchResult. - function_name: The name of the function, to be used in the kernel, default is "search". - description: The description of the function, a default is provided. - parameters: The parameters for the function, a list of KernelParameterMetadata. - return_parameter: The return parameter for the function. - filter: The filter to use for the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - filter_update_function: A function to update the search filters. - The function should return the updated filter. - The default function uses the parameters and the kwargs to update the options. - Adding equal to filters to the options for all parameters that are not "query". - As well as adding equal to filters for parameters that have a default value. - string_mapper: The function to map the TextSearchResult to strings. - for instance taking the value out of the results and just returning that, - otherwise a json-like string is returned. - - Returns: - KernelFunction: The kernel function. - - """ - ... - - @overload - def create_search_function( - self, - function_name: str = DEFAULT_FUNCTION_NAME, - description: str = DEFAULT_DESCRIPTION, - *, - output_type: Literal["Any"], - parameters: list[KernelParameterMetadata] | None = None, - return_parameter: KernelParameterMetadata | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 5, - skip: int = 0, - include_total_count: bool = False, - filter_update_function: DynamicFilterFunction | None = None, - ) -> KernelFunction: - """Create a kernel function from a search function. - - Args: - function_name: The name of the function, to be used in the kernel, default is "search". - description: The description of the function, a default is provided. - output_type: The type of the output, in this case Any. - Any means that the results from the store are used directly. - The string_mapper can then be used to extract certain fields. - parameters: The parameters for the function, a list of KernelParameterMetadata. - return_parameter: The return parameter for the function. - filter: The filter to use for the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - filter_update_function: A function to update the search filters. - The function should return the updated filter. - The default function uses the parameters and the kwargs to update the options. - Adding equal to filters to the options for all parameters that are not "query". - As well as adding equal to filters for parameters that have a default value. - string_mapper: The function to map the raw search results to strings. 
- When using this from a vector store, your results are of type - VectorSearchResult[TModel], - so the string_mapper can be used to extract the fields you want from the result. - The default is to use the model_dump_json method of the result, which will return a json-like string. - - Returns: - KernelFunction: The kernel function. - """ - ... - - def create_search_function( - self, - function_name=DEFAULT_FUNCTION_NAME, - description=DEFAULT_DESCRIPTION, - *, - output_type="str", - parameters=None, - return_parameter=None, - filter=None, - top=5, - skip=0, - include_total_count=False, - filter_update_function=None, - string_mapper=None, - ) -> KernelFunction: - """Create a kernel function from a search function.""" - options = SearchOptions( - filter=filter, - skip=skip, - top=top, - include_total_count=include_total_count, - ) - match output_type: - case "str": - return self._create_kernel_function( - output_type=str, - options=options, - parameters=parameters, - filter_update_function=filter_update_function, - return_parameter=return_parameter, - function_name=function_name, - description=description, - string_mapper=string_mapper, - ) - case "TextSearchResult": - return self._create_kernel_function( - output_type=TextSearchResult, - options=options, - parameters=parameters, - filter_update_function=filter_update_function, - return_parameter=return_parameter, - function_name=function_name, - description=description, - string_mapper=string_mapper, - ) - case "Any": - return self._create_kernel_function( - output_type="Any", - options=options, - parameters=parameters, - filter_update_function=filter_update_function, - return_parameter=return_parameter, - function_name=function_name, - description=description, - string_mapper=string_mapper, - ) - case _: - raise TextSearchException( - f"Unknown output type: {output_type}. Must be 'str', 'TextSearchResult', or 'Any'." - ) - - # endregion - # region: Private methods - - def _create_kernel_function( - self, - output_type: type[str] | type[TSearchResult] | Literal["Any"] = str, - options: SearchOptions | None = None, - parameters: list[KernelParameterMetadata] | None = None, - filter_update_function: DynamicFilterFunction | None = None, - return_parameter: KernelParameterMetadata | None = None, - function_name: str = DEFAULT_FUNCTION_NAME, - description: str = DEFAULT_DESCRIPTION, - string_mapper: Callable[[TSearchResult], str] | None = None, - ) -> KernelFunction: - """Create a kernel function from a search function.""" - update_func = filter_update_function or default_dynamic_filter_function - - @kernel_function(name=function_name, description=description) - async def search_wrapper(**kwargs: Any) -> Sequence[str]: - query = kwargs.pop("query", "") - try: - inner_options = create_options(SearchOptions, deepcopy(options), **kwargs) - except ValidationError: - # this usually only happens when the kwargs are invalid, so blank options in this case. 
- inner_options = SearchOptions() - inner_options.filter = update_func(filter=inner_options.filter, parameters=parameters, **kwargs) - try: - results = await self.search( - query=query, - output_type=output_type, - **inner_options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True), - ) - except Exception as e: - msg = f"Exception in search function: {e}" - logger.error(msg) - raise TextSearchException(msg) from e - return await self._map_results(results, string_mapper) - - return KernelFunctionFromMethod( - method=search_wrapper, - parameters=self._default_parameter_metadata() if parameters is None else parameters, - return_parameter=return_parameter or self._default_return_parameter_metadata(), - ) - - async def _map_results( - self, - results: KernelSearchResults[TSearchResult], - string_mapper: Callable[[TSearchResult], str] | None = None, - ) -> list[str]: - """Map search results to strings.""" - if string_mapper: - return [string_mapper(result) async for result in results.results] - return [self._default_map_to_string(result) async for result in results.results] - - @staticmethod - def _default_map_to_string(result: BaseModel | object) -> str: - """Default mapping function for text search results.""" - if isinstance(result, BaseModel): - return result.model_dump_json() - return result if isinstance(result, str) else json.dumps(result) - - # region: Abstract methods - - @abstractmethod - async def search( - self, - query: str, - output_type: type[str] | type[TSearchResult] | Literal["Any"] = str, - **kwargs: Any, - ) -> "KernelSearchResults[TSearchResult]": - """Search for text, returning a KernelSearchResult with a list of strings. - - Args: - query: The query to search for. - output_type: The type of the output, default is str. - Can also be TextSearchResult or Any. - **kwargs: Additional keyword arguments to pass to the search function. - - """ - ... diff --git a/python/semantic_kernel/data/text_search.py b/python/semantic_kernel/data/text_search.py new file mode 100644 index 000000000000..29350ca2a82c --- /dev/null +++ b/python/semantic_kernel/data/text_search.py @@ -0,0 +1,348 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import json +import logging +from abc import abstractmethod +from collections.abc import Callable, Sequence +from copy import deepcopy +from typing import Any, Final, Literal, TypeVar, overload + +from pydantic import BaseModel, ValidationError + +from semantic_kernel.data._search import ( + DEFAULT_FUNCTION_NAME, + DEFAULT_PARAMETER_METADATA, + DEFAULT_RETURN_PARAMETER_METADATA, + DynamicFilterFunction, + KernelSearchResults, + SearchOptions, + create_options, + default_dynamic_filter_function, +) +from semantic_kernel.exceptions import TextSearchException +from semantic_kernel.functions.kernel_function import KernelFunction +from semantic_kernel.functions.kernel_function_decorator import kernel_function +from semantic_kernel.functions.kernel_function_from_method import KernelFunctionFromMethod +from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata +from semantic_kernel.kernel_pydantic import KernelBaseModel +from semantic_kernel.kernel_types import OptionalOneOrList +from semantic_kernel.utils.feature_stage_decorator import release_candidate + +logger = logging.getLogger(__name__) + +TSearchOptions = TypeVar("TSearchOptions", bound="SearchOptions") + +DEFAULT_DESCRIPTION: Final[str] = ( + "Perform a search for content related to the specified query and return string results" +) + +# region: Results + + +@release_candidate +class TextSearchResult(KernelBaseModel): + """The result of a text search.""" + + name: str | None = None + value: str | None = None + link: str | None = None + + +TSearchResult = TypeVar("TSearchResult") + + +@release_candidate +class TextSearch: + """The base class for all text searchers.""" + + @property + def options_class(self) -> type["SearchOptions"]: + """The options class for the search.""" + return SearchOptions + + # region: Public methods + + @overload + def create_search_function( + self, + function_name: str = DEFAULT_FUNCTION_NAME, + description: str = DEFAULT_DESCRIPTION, + *, + output_type: Literal["str"] = "str", + parameters: list[KernelParameterMetadata] | None = None, + return_parameter: KernelParameterMetadata | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 5, + skip: int = 0, + include_total_count: bool = False, + filter_update_function: DynamicFilterFunction | None = None, + string_mapper: Callable[[TSearchResult], str] | None = None, + ) -> KernelFunction: + """Create a kernel function from a search function. + + Args: + output_type: The type of the output, default is "str". + function_name: The name of the function, to be used in the kernel, default is "search". + description: The description of the function, a default is provided. + parameters: The parameters for the function, a list of KernelParameterMetadata. + return_parameter: The return parameter for the function. + filter: The filter to use for the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + filter_update_function: A function to update the search filters. + The function should return the updated filter. + The default function uses the parameters and the kwargs to update the options. + Adding equal to filters to the options for all parameters that are not "query". + As well as adding equal to filters for parameters that have a default value. + string_mapper: The function to map the search results. (the inner part of the KernelSearchResults type, + related to which search type you are using) to strings. 
+ + Returns: + KernelFunction: The kernel function. + + """ + ... + + @overload + def create_search_function( + self, + function_name: str = DEFAULT_FUNCTION_NAME, + description: str = DEFAULT_DESCRIPTION, + *, + output_type: Literal["TextSearchResult"], + parameters: list[KernelParameterMetadata] | None = None, + return_parameter: KernelParameterMetadata | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 5, + skip: int = 0, + include_total_count: bool = False, + filter_update_function: DynamicFilterFunction | None = None, + ) -> KernelFunction: + """Create a kernel function from a search function. + + Args: + output_type: The type of the output, in this case TextSearchResult. + function_name: The name of the function, to be used in the kernel, default is "search". + description: The description of the function, a default is provided. + parameters: The parameters for the function, a list of KernelParameterMetadata. + return_parameter: The return parameter for the function. + filter: The filter to use for the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + filter_update_function: A function to update the search filters. + The function should return the updated filter. + The default function uses the parameters and the kwargs to update the options. + Adding equal to filters to the options for all parameters that are not "query". + As well as adding equal to filters for parameters that have a default value. + string_mapper: The function to map the TextSearchResult to strings. + for instance taking the value out of the results and just returning that, + otherwise a json-like string is returned. + + Returns: + KernelFunction: The kernel function. + + """ + ... + + @overload + def create_search_function( + self, + function_name: str = DEFAULT_FUNCTION_NAME, + description: str = DEFAULT_DESCRIPTION, + *, + output_type: Literal["Any"], + parameters: list[KernelParameterMetadata] | None = None, + return_parameter: KernelParameterMetadata | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 5, + skip: int = 0, + include_total_count: bool = False, + filter_update_function: DynamicFilterFunction | None = None, + ) -> KernelFunction: + """Create a kernel function from a search function. + + Args: + function_name: The name of the function, to be used in the kernel, default is "search". + description: The description of the function, a default is provided. + output_type: The type of the output, in this case Any. + Any means that the results from the store are used directly. + The string_mapper can then be used to extract certain fields. + parameters: The parameters for the function, a list of KernelParameterMetadata. + return_parameter: The return parameter for the function. + filter: The filter to use for the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + filter_update_function: A function to update the search filters. + The function should return the updated filter. + The default function uses the parameters and the kwargs to update the options. + Adding equal to filters to the options for all parameters that are not "query". + As well as adding equal to filters for parameters that have a default value. + string_mapper: The function to map the raw search results to strings. 
+ When using this from a vector store, your results are of type + VectorSearchResult[TModel], + so the string_mapper can be used to extract the fields you want from the result. + The default is to use the model_dump_json method of the result, which will return a json-like string. + + Returns: + KernelFunction: The kernel function. + """ + ... + + def create_search_function( + self, + function_name=DEFAULT_FUNCTION_NAME, + description=DEFAULT_DESCRIPTION, + *, + output_type="str", + parameters=None, + return_parameter=None, + filter=None, + top=5, + skip=0, + include_total_count=False, + filter_update_function=None, + string_mapper=None, + ) -> KernelFunction: + """Create a kernel function from a search function.""" + options = SearchOptions( + filter=filter, + skip=skip, + top=top, + include_total_count=include_total_count, + ) + match output_type: + case "str": + return self._create_kernel_function( + output_type=str, + options=options, + parameters=parameters, + filter_update_function=filter_update_function, + return_parameter=return_parameter, + function_name=function_name, + description=description, + string_mapper=string_mapper, + ) + case "TextSearchResult": + return self._create_kernel_function( + output_type=TextSearchResult, + options=options, + parameters=parameters, + filter_update_function=filter_update_function, + return_parameter=return_parameter, + function_name=function_name, + description=description, + string_mapper=string_mapper, + ) + case "Any": + return self._create_kernel_function( + output_type="Any", + options=options, + parameters=parameters, + filter_update_function=filter_update_function, + return_parameter=return_parameter, + function_name=function_name, + description=description, + string_mapper=string_mapper, + ) + case _: + raise TextSearchException( + f"Unknown output type: {output_type}. Must be 'str', 'TextSearchResult', or 'Any'." + ) + + # endregion + # region: Private methods + + def _create_kernel_function( + self, + output_type: type[str] | type[TSearchResult] | Literal["Any"] = str, + options: SearchOptions | None = None, + parameters: list[KernelParameterMetadata] | None = None, + filter_update_function: DynamicFilterFunction | None = None, + return_parameter: KernelParameterMetadata | None = None, + function_name: str = DEFAULT_FUNCTION_NAME, + description: str = DEFAULT_DESCRIPTION, + string_mapper: Callable[[TSearchResult], str] | None = None, + ) -> KernelFunction: + """Create a kernel function from a search function.""" + update_func = filter_update_function or default_dynamic_filter_function + + @kernel_function(name=function_name, description=description) + async def search_wrapper(**kwargs: Any) -> Sequence[str]: + query = kwargs.pop("query", "") + try: + inner_options = create_options(SearchOptions, deepcopy(options), **kwargs) + except ValidationError: + # this usually only happens when the kwargs are invalid, so blank options in this case. 
+ inner_options = SearchOptions() + inner_options.filter = update_func(filter=inner_options.filter, parameters=parameters, **kwargs) + try: + results = await self.search( + query=query, + output_type=output_type, + **inner_options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True), + ) + except Exception as e: + msg = f"Exception in search function: {e}" + logger.error(msg) + raise TextSearchException(msg) from e + return await self._map_results(results, string_mapper) + + return KernelFunctionFromMethod( + method=search_wrapper, + parameters=DEFAULT_PARAMETER_METADATA if parameters is None else parameters, + return_parameter=return_parameter or DEFAULT_RETURN_PARAMETER_METADATA, + ) + + async def _map_results( + self, + results: KernelSearchResults[TSearchResult], + string_mapper: Callable[[TSearchResult], str] | None = None, + ) -> list[str]: + """Map search results to strings.""" + if string_mapper: + return [string_mapper(result) async for result in results.results] + return [self._default_map_to_string(result) async for result in results.results] + + @staticmethod + def _default_map_to_string(result: BaseModel | object) -> str: + """Default mapping function for text search results.""" + if isinstance(result, BaseModel): + return result.model_dump_json() + return result if isinstance(result, str) else json.dumps(result) + + # region: Abstract methods + + @abstractmethod + async def search( + self, + query: str, + output_type: type[str] | type[TSearchResult] | Literal["Any"] = str, + **kwargs: Any, + ) -> "KernelSearchResults[TSearchResult]": + """Search for text, returning a KernelSearchResult with a list of strings. + + Args: + query: The query to search for. + output_type: The type of the output, default is str. + Can also be TextSearchResult or Any. + **kwargs: Additional keyword arguments to pass to the search function. + + """ + ... 
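For reviewers, a minimal sketch of how the new `TextSearch` base class is meant to be consumed: a hypothetical subclass implements the abstract `search` method and then exposes it to a kernel via `create_search_function`. The `StaticTextSearch` name and the canned document list below are illustrative only, and the sketch assumes `KernelSearchResults` takes its results as an async iterable via `results=` with an optional `total_count`, as defined in `_search.py`; it is not part of this patch.

    from collections.abc import AsyncIterable

    from semantic_kernel.data._search import KernelSearchResults
    from semantic_kernel.data.text_search import TextSearch


    class StaticTextSearch(TextSearch):
        """Toy searcher over a fixed list of strings, illustrating the abstract `search` contract."""

        def __init__(self, docs: list[str]) -> None:
            self.docs = docs

        async def search(self, query: str, output_type=str, **kwargs) -> KernelSearchResults[str]:
            # `top`, `skip`, `filter`, etc. arrive as kwargs from the generated search_wrapper.
            top = kwargs.get("top", 5)
            skip = kwargs.get("skip", 0)
            hits = [doc for doc in self.docs if query.lower() in doc.lower()][skip : skip + top]

            async def _results() -> AsyncIterable[str]:
                for hit in hits:
                    yield hit

            # Assumed constructor shape: results as an async iterable plus an optional total count.
            return KernelSearchResults(results=_results(), total_count=len(hits))


    # Usage: wrap the searcher as a KernelFunction and register it with a kernel, e.g.
    #   search_fn = StaticTextSearch(["alpha release notes", "beta roadmap"]).create_search_function(top=2)
    #   kernel.add_function(plugin_name="search", function=search_fn)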
+ + +__all__ = [ + "DEFAULT_DESCRIPTION", + "DEFAULT_FUNCTION_NAME", + "DEFAULT_PARAMETER_METADATA", + "DEFAULT_RETURN_PARAMETER_METADATA", + "DynamicFilterFunction", + "KernelSearchResults", + "TextSearch", + "TextSearchResult", + "create_options", + "default_dynamic_filter_function", +] diff --git a/python/semantic_kernel/data/_vectors.py b/python/semantic_kernel/data/vectors.py similarity index 69% rename from python/semantic_kernel/data/_vectors.py rename to python/semantic_kernel/data/vectors.py index eae2be7325ec..b63e6871e9da 100644 --- a/python/semantic_kernel/data/_vectors.py +++ b/python/semantic_kernel/data/vectors.py @@ -8,27 +8,24 @@ from ast import AST, Lambda, NodeVisitor, expr, parse from collections.abc import AsyncIterable, Callable, Mapping, Sequence from copy import deepcopy +from dataclasses import dataclass from enum import Enum -from inspect import getsource -from typing import Annotated, Any, ClassVar, Final, Generic, Literal, TypeVar, overload +from inspect import Parameter, _empty, getsource, signature +from types import MappingProxyType, NoneType +from typing import Annotated, Any, ClassVar, Final, Generic, Literal, Protocol, TypeVar, overload, runtime_checkable from pydantic import BaseModel, Field, ValidationError, model_validator -from pydantic.dataclasses import dataclass +from pydantic.dataclasses import dataclass as pyd_dataclass from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings -from semantic_kernel.data._definitions import ( - FieldTypes, - SerializeMethodProtocol, - VectorStoreCollectionDefinition, - VectorStoreField, -) from semantic_kernel.data._search import ( DEFAULT_FUNCTION_NAME, + DEFAULT_PARAMETER_METADATA, + DEFAULT_RETURN_PARAMETER_METADATA, DynamicFilterFunction, KernelSearchResults, SearchOptions, - TextSearch, create_options, default_dynamic_filter_function, ) @@ -63,1402 +60,2024 @@ TModel = TypeVar("TModel", bound=object) TKey = TypeVar("TKey") _T = TypeVar("_T", bound="VectorStoreRecordHandler") -TSearchOptions = TypeVar("TSearchOptions", bound=SearchOptions) TFilters = TypeVar("TFilters") -# region: Helpers DEFAULT_DESCRIPTION: Final[str] = ( "Perform a vector search for data in a vector store, using the provided search options." ) -def _get_collection_name_from_model( - record_type: type[TModel], - definition: VectorStoreCollectionDefinition | None = None, -) -> str | None: - """Get the collection name from the data model type or definition.""" - if record_type and not definition: - definition = getattr(record_type, "__kernel_vectorstoremodel_definition__", None) - if definition and definition.collection_name: - return definition.collection_name - return None +# region: Fields and Collection Definitions -@dataclass -class GetFilteredRecordOptions: - """Options for filtering records. +@release_candidate +class FieldTypes(str, Enum): + """Enumeration for field types in vector store models.""" - Args: - top: The maximum number of records to return. - skip: The number of records to skip. - order_by: A dictionary with fields names and a bool, True means ascending, False means descending. 
- """ + KEY = "key" + VECTOR = "vector" + DATA = "data" - top: int = 10 - skip: int = 0 - order_by: Mapping[str, bool] | None = None + def __str__(self) -> str: + """Return the string representation of the enum.""" + return self.value -class LambdaVisitor(NodeVisitor, Generic[TFilters]): - """Visitor class to visit the AST nodes.""" +@runtime_checkable +class SerializeMethodProtocol(Protocol): + """Data model serialization protocol. - def __init__(self, lambda_parser: Callable[[expr], TFilters], output_filters: list[TFilters] | None = None) -> None: - """Initialize the visitor with a lambda parser and output filters.""" - self.lambda_parser = lambda_parser - self.output_filters = output_filters if output_filters is not None else [] + This can optionally be implemented to allow single step serialization and deserialization + for using your data model with a specific datastore. + """ - def visit_Lambda(self, node: Lambda) -> None: - """This method is called when a lambda expression is found.""" - self.output_filters.append(self.lambda_parser(node.body)) + def serialize(self, **kwargs: Any) -> Any: + """Serialize the object to the format required by the data store.""" + ... # pragma: no cover -@release_candidate -class SearchType(str, Enum): - """Enumeration for search types. +@runtime_checkable +class ToDictFunctionProtocol(Protocol): + """Protocol for to_dict function. - Contains: vector and keyword_hybrid. + Args: + record: The record to be serialized. + **kwargs: Additional keyword arguments. + + Returns: + A list of dictionaries. """ - VECTOR = "vector" - KEYWORD_HYBRID = "keyword_hybrid" + def __call__(self, record: Any, **kwargs: Any) -> Sequence[dict[str, Any]]: ... # pragma: no cover -@release_candidate -class VectorSearchOptions(SearchOptions): - """Options for vector search, builds on TextSearchOptions. +@runtime_checkable +class FromDictFunctionProtocol(Protocol): + """Protocol for from_dict function. - When multiple filters are used, they are combined with an AND operator. + Args: + records: A list of dictionaries. + **kwargs: Additional keyword arguments. + + Returns: + A record or list thereof. """ - vector_property_name: str | None = None - additional_property_name: str | None = None - top: Annotated[int, Field(gt=0)] = 3 - include_vectors: bool = False + def __call__(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Any: ... -@release_candidate -class VectorSearchResult(KernelBaseModel, Generic[TModel]): - """The result of a vector search.""" +@runtime_checkable +class SerializeFunctionProtocol(Protocol): + """Protocol for serialize function. - record: TModel - score: float | None = None + Args: + record: The record to be serialized. + **kwargs: Additional keyword arguments. + Returns: + The serialized record, ready to be consumed by the specific store. -# region: VectorStoreRecordHandler + """ + def __call__(self, record: Any, **kwargs: Any) -> Any: ... -@release_candidate -class VectorStoreRecordHandler(KernelBaseModel, Generic[TKey, TModel]): - """Vector Store Record Handler class. - This class is used to serialize and deserialize records to and from a vector store. - As well as validating the data model against the vector store. - It is subclassed by VectorStoreRecordCollection and VectorSearchBase. +@runtime_checkable +class DeserializeFunctionProtocol(Protocol): + """Protocol for deserialize function. + + Args: + records: The serialized record directly from the store. + **kwargs: Additional keyword arguments. 
+ + Returns: + The deserialized record in the format expected by the application. + """ - record_type: type[TModel] - definition: VectorStoreCollectionDefinition - supported_key_types: ClassVar[set[str] | None] = None - supported_vector_types: ClassVar[set[str] | None] = None - embedding_generator: EmbeddingGeneratorBase | None = None + def __call__(self, records: Any, **kwargs: Any) -> Any: ... - @property - def _key_field_name(self) -> str: - return self.definition.key_name - @property - def _key_field_storage_name(self) -> str: - return self.definition.key_field.storage_name or self.definition.key_name +@runtime_checkable +class ToDictMethodProtocol(Protocol): + """Class used internally to check if a model has a to_dict method.""" - @property - def _container_mode(self) -> bool: - return self.definition.container_mode + def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + """Serialize the object to the format required by the data store.""" + ... # pragma: no cover - @model_validator(mode="before") - @classmethod - def _ensure_definition(cls: type[_T], data: Any) -> dict[str, Any]: - """Ensure there is a data model definition, if it isn't passed, try to get it from the data model type.""" - if isinstance(data, dict) and not data.get("definition"): - data["definition"] = getattr(data["record_type"], "__kernel_vectorstoremodel_definition__", None) - return data - def model_post_init(self, __context: object | None = None): - """Post init function that sets the key field and container mode values, and validates the datamodel.""" - self._validate_data_model() +class IndexKind(str, Enum): + """Index kinds for similarity search. - def _validate_data_model(self): - """Internal function that can be overloaded by child classes to validate datatypes, etc. + HNSW + Hierarchical Navigable Small World which performs an approximate nearest neighbor (ANN) search. + Lower accuracy than exhaustive k nearest neighbor, but faster and more efficient. - This should take the VectorStoreRecordDefinition from the item_type and validate it against the store. + Flat + Does a brute force search to find the nearest neighbors. + Calculates the distances between all pairs of data points, so has a linear time complexity, + that grows directly proportional to the number of points. + Also referred to as exhaustive k nearest neighbor in some databases. + High recall accuracy, but slower and more expensive than HNSW. + Better with smaller datasets. - Checks can include, allowed naming of parameters, allowed data types, allowed vector dimensions. + IVF Flat + Inverted File with Flat Compression. + Designed to enhance search efficiency by narrowing the search area + through the use of neighbor partitions or clusters. + Also referred to as approximate nearest neighbor (ANN) search. - Default checks are that the key field is in the allowed key types and the vector fields - are in the allowed vector types. + Disk ANN + Disk-based Approximate Nearest Neighbor algorithm designed for efficiently searching + for approximate nearest neighbors (ANN) in high-dimensional spaces. + The primary focus of DiskANN is to handle large-scale datasets that cannot fit entirely + into memory, leveraging disk storage to store the data while maintaining fast search times. - Raises: - VectorStoreModelValidationError: If the key field is not in the allowed key types. - VectorStoreModelValidationError: If the vector fields are not in the allowed vector types. 
+ Quantized Flat + Index that compresses vectors using DiskANN-based quantization methods for better efficiency in the kNN search. - """ - if ( - self.supported_key_types - and self.definition.key_field.type_ - and self.definition.key_field.type_ not in self.supported_key_types - ): - raise VectorStoreModelValidationError( - f"Key field must be one of {self.supported_key_types}, got {self.definition.key_field.type_}" - ) - if not self.supported_vector_types: - return - for field in self.definition.vector_fields: - if field.type_ and field.type_ not in self.supported_vector_types: - raise VectorStoreModelValidationError( - f"Vector field {field.name} must be one of {self.supported_vector_types}, got {field.type_}" - ) + Dynamic + Dynamic index allows to automatically switch from FLAT to HNSW indexes. - @abstractmethod - def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: - """Serialize a list of dicts of the data to the store model. + Default + Default index type. + Used when no index type is specified. + Will differ per vector store. - This method should be overridden by the child class to convert the dict to the store model. - """ - ... # pragma: no cover + """ - @abstractmethod - def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]: - """Deserialize the store models to a list of dicts. + HNSW = "hnsw" + FLAT = "flat" + IVF_FLAT = "ivf_flat" + DISK_ANN = "disk_ann" + QUANTIZED_FLAT = "quantized_flat" + DYNAMIC = "dynamic" + DEFAULT = "default" - This method should be overridden by the child class to convert the store model to a list of dicts. - """ - ... # pragma: no cover - # region Serialization methods +class DistanceFunction(str, Enum): + """Distance functions for similarity search. - async def serialize(self, records: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any]: - """Serialize the data model to the store model. + Cosine Similarity + the cosine (angular) similarity between two vectors + measures only the angle between the two vectors, without taking into account the length of the vectors + Cosine Similarity = 1 - Cosine Distance + -1 means vectors are opposite + 0 means vectors are orthogonal + 1 means vectors are identical + Cosine Distance + the cosine (angular) distance between two vectors + measures only the angle between the two vectors, without taking into account the length of the vectors + Cosine Distance = 1 - Cosine Similarity + 2 means vectors are opposite + 1 means vectors are orthogonal + 0 means vectors are identical + Dot Product + measures both the length and angle between two vectors + same as cosine similarity if the vectors are the same length, but more performant + Euclidean Distance + measures the Euclidean distance between two vectors + also known as l2-norm + Euclidean Squared Distance + measures the Euclidean squared distance between two vectors + also known as l2-squared + Manhattan + measures the Manhattan distance between two vectors + Hamming + number of differences between vectors at each dimensions + DEFAULT + default distance function + used when no distance function is specified + will differ per vector store. + """ - This method follows the following steps: - 1. Check if the data model has a serialize method. - Use that method to serialize and return the result. - 2. Serialize the records into a dict, using the data model specific method. - 3. Convert the dict to the store model, using the store specific method. 
+ COSINE_SIMILARITY = "cosine_similarity" + COSINE_DISTANCE = "cosine_distance" + DOT_PROD = "dot_prod" + EUCLIDEAN_DISTANCE = "euclidean_distance" + EUCLIDEAN_SQUARED_DISTANCE = "euclidean_squared_distance" + MANHATTAN = "manhattan" + HAMMING = "hamming" + DEFAULT = "DEFAULT" - If overriding this method, make sure to first try to serialize the data model to the store model, - before doing the store specific version, - the user supplied version should have precedence. - Raises: - VectorStoreModelSerializationException: If an error occurs during serialization. +DISTANCE_FUNCTION_DIRECTION_HELPER: Final[dict[DistanceFunction, Callable[[int | float, int | float], bool]]] = { + DistanceFunction.COSINE_SIMILARITY: operator.gt, + DistanceFunction.COSINE_DISTANCE: operator.le, + DistanceFunction.DOT_PROD: operator.gt, + DistanceFunction.EUCLIDEAN_DISTANCE: operator.le, + DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE: operator.le, + DistanceFunction.MANHATTAN: operator.le, + DistanceFunction.HAMMING: operator.le, +} + + +@release_candidate +@dataclass +class VectorStoreField: + """Vector store fields.""" + + field_type: Literal[FieldTypes.DATA, FieldTypes.KEY, FieldTypes.VECTOR] = FieldTypes.DATA + name: str = "" + storage_name: str | None = None + type_: str | None = None + # data specific fields (all optional) + is_indexed: bool | None = None + is_full_text_indexed: bool | None = None + # vector specific fields (dimensions is mandatory) + dimensions: int | None = None + embedding_generator: EmbeddingGeneratorBase | None = None + # defaults for these fields are not set here, because they are not relevant for data and key types + index_kind: IndexKind | None = None + distance_function: DistanceFunction | None = None + @overload + def __init__( + self, + field_type: Literal[FieldTypes.KEY, "key"] = FieldTypes.KEY, # type: ignore[assignment] + *, + name: str | None = None, + type: str | None = None, + storage_name: str | None = None, + ): + """Key field of the record. + + When the key will be auto-generated by the store, make sure it has a default, usually None. + + Args: + field_type: always "key". + name: The name of the field. + storage_name: The name of the field in the store, uses the field name by default. + type: The type of the field. """ - try: - if serialized := self._serialize_data_model_to_store_model(records): - return serialized - except VectorStoreModelSerializationException: - raise # pragma: no cover - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc + ... - try: - dict_records: list[dict[str, Any]] = [] - if not isinstance(records, list): - records = [records] # type: ignore - for rec in records: - dict_rec = self._serialize_data_model_to_dict(rec) - if isinstance(dict_rec, list): - dict_records.extend(dict_rec) - else: - dict_records.append(dict_rec) - except VectorStoreModelSerializationException: - raise # pragma: no cover - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc + @overload + def __init__( + self, + field_type: Literal[FieldTypes.DATA, "data"] = FieldTypes.DATA, # type: ignore[assignment] + *, + name: str | None = None, + type: str | None = None, + storage_name: str | None = None, + is_indexed: bool | None = None, + is_full_text_indexed: bool | None = None, + ): + """Data field in the record. 
- # add vectors - try: - dict_records = await self._add_vectors_to_records(dict_records) # type: ignore - except (VectorStoreModelException, VectorStoreOperationException): - raise - except Exception as exc: - raise VectorStoreOperationException( - "Exception occurred while trying to add the vectors to the records." - ) from exc + Args: + field_type: always "data". + name: The name of the field. + storage_name: The name of the field in the store, uses the field name by default. + type: The type of the field. + is_indexed: Whether the field is indexed. + is_full_text_indexed: Whether the field is full text indexed. + """ + ... - try: - return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore - except VectorStoreModelSerializationException: - raise # pragma: no cover - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc + @overload + def __init__( + self, + field_type: Literal[FieldTypes.VECTOR, "vector"] = FieldTypes.VECTOR, # type: ignore[assignment] + *, + name: str | None = None, + type: str | None = None, + dimensions: Annotated[int, Field(gt=0)], + storage_name: str | None = None, + index_kind: IndexKind | None = None, + distance_function: DistanceFunction | None = None, + embedding_generator: EmbeddingGeneratorBase | None = None, + ): + """Vector field in the record. - def _serialize_data_model_to_store_model(self, record: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any] | None: - """Serialize the data model to the store model. + This field should contain the value you want to use for the vector. + When passing in the embedding generator, the embedding will be + generated locally before upserting. + If this is not set, the store should support generating the embedding for you. + If you want to retrieve the original content of the vector, + make sure to set this field twice, + once with the VectorStoreRecordDataField and once with the VectorStoreRecordVectorField. - This works when the data model has supplied a serialize method, specific to a data source. - This is a method called 'serialize()' on the data model or part of the vector store record definition. + If you want to be able to get the vectors back, make sure the type allows this, especially for pydantic models. + For instance, if the input is a string, then the type annotation should be `str | list[float] | None`. - The developer is responsible for correctly serializing for the specific data source. + If you want to cast the vector that is returned, you need to set the deserialize_function, + for instance: `deserialize_function=np.array`, (with `import numpy as np` at the top of your file). + If you want to set it up with more specific options, use a lambda, a custom function or a partial. + + Args: + field_type: always "vector". + name: The name of the field. + storage_name: The name of the field in the store, uses the field name by default. + type: Property type. + For vectors this should be the inner type of the vector. + By default the vector will be a list of numbers. + If you want to use a numpy array or some other optimized format, + set the cast_function with a function + that takes a list of floats and returns a numpy array. + dimensions: The number of dimensions of the vector, mandatory. + index_kind: The index kind to use, uses a default index kind when None. + distance_function: The distance function to use, uses a default distance function when None. + embedding_generator: The embedding generator to use. 
+ If this is set, the embedding will be generated locally before upserting. """ - if isinstance(record, Sequence): - result = [self._serialize_data_model_to_store_model(rec, **kwargs) for rec in record] - if not all(result): - return None - return result - if self.definition.serialize: - return self.definition.serialize(record, **kwargs) - if isinstance(record, SerializeMethodProtocol): - return record.serialize(**kwargs) - return None + ... - def _serialize_data_model_to_dict(self, record: TModel, **kwargs: Any) -> OneOrList[dict[str, Any]]: - """This function is used if no serialize method is found on the data model. + def __init__( + self, + field_type=FieldTypes.DATA, + *, + name=None, + type=None, + storage_name=None, + is_indexed=None, + is_full_text_indexed=None, + dimensions=None, + index_kind=None, + distance_function=None, + embedding_generator=None, + ): + """Vector store field.""" + self.field_type = field_type if isinstance(field_type, FieldTypes) else FieldTypes(field_type) + # when a field is created, the name can be empty, + # when a field get's added to a definition, the name needs to be there. + if name: + self.name = name + self.storage_name = storage_name + self.type_ = type + self.is_indexed = is_indexed + self.is_full_text_indexed = is_full_text_indexed + if field_type == FieldTypes.VECTOR: + if dimensions is None: + raise ValidationError("Vector fields must specify 'dimensions'") + self.dimensions = dimensions + self.index_kind = index_kind or IndexKind.DEFAULT + self.distance_function = distance_function or DistanceFunction.DEFAULT + self.embedding_generator = embedding_generator - This will generally serialize the data model to a dict, should not be overridden by child classes. - The output of this should be passed to the serialize_dict_to_store_model method. - """ - if self.definition.to_dict: - return self.definition.to_dict(record, **kwargs) # type: ignore - if isinstance(record, BaseModel): - return record.model_dump() +@release_candidate +class VectorStoreCollectionDefinition(KernelBaseModel): + """Collection definition for vector stores. - store_model = {} - for field in self.definition.fields: - store_model[field.storage_name or field.name] = ( - record.get(field.name, None) if isinstance(record, Mapping) else getattr(record, field.name) - ) - return store_model + Args: + fields: The fields of the record. + container_mode: Whether the record is in container mode. + to_dict: The to_dict function, should take a record and return a list of dicts. + from_dict: The from_dict function, should take a list of dicts and return a record. + deserialize: The deserialize function, should take a type specific to a datastore and return a record. - # region Deserialization methods + """ - def deserialize(self, records: OneOrMany[Any | dict[str, Any]], **kwargs: Any) -> OneOrMany[TModel] | None: - """Deserialize the store model to the data model. + fields: list[VectorStoreField] + key_name: str = Field(default="", init=False) + container_mode: bool = False + collection_name: str | None = None + to_dict: ToDictFunctionProtocol | None = None + from_dict: FromDictFunctionProtocol | None = None + serialize: SerializeFunctionProtocol | None = None + deserialize: DeserializeFunctionProtocol | None = None - This method follows the following steps: - 1. Check if the data model has a deserialize method. - Use that method to deserialize and return the result. - 2. Deserialize the store model to a dict, using the store specific method. - 3. 
Convert the dict to the data model, using the data model specific method. + @property + def names(self) -> list[str]: + """Get the names of the fields.""" + return [field.name for field in self.fields] - Raises: - VectorStoreModelDeserializationException: If an error occurs during deserialization. - """ - try: - if not records: - return None - if deserialized := self._deserialize_store_model_to_data_model(records, **kwargs): - return deserialized + @property + def storage_names(self) -> list[str]: + """Get the names of the fields for storage.""" + return [field.storage_name or field.name for field in self.fields] - if isinstance(records, Sequence): - dict_records = self._deserialize_store_models_to_dicts(records, **kwargs) - return ( - self._deserialize_dict_to_data_model(dict_records, **kwargs) - if self._container_mode - else [self._deserialize_dict_to_data_model(rec, **kwargs) for rec in dict_records] - ) + @property + def key_field(self) -> VectorStoreField: + """Get the key field.""" + return next((field for field in self.fields if field.name == self.key_name), None) # type: ignore - dict_record = self._deserialize_store_models_to_dicts([records], **kwargs)[0] - # regardless of mode, only 1 object is returned. - return self._deserialize_dict_to_data_model(dict_record, **kwargs) - except VectorStoreModelDeserializationException: - raise # pragma: no cover - except Exception as exc: - raise VectorStoreModelDeserializationException(f"Error deserializing records: {exc}") from exc + @property + def key_field_storage_name(self) -> str: + """Get the key field storage name.""" + return self.key_field.storage_name or self.key_field.name - def _deserialize_store_model_to_data_model(self, record: OneOrMany[Any], **kwargs: Any) -> OneOrMany[TModel] | None: - """Deserialize the store model to the data model. + @property + def vector_fields(self) -> list[VectorStoreField]: + """Get the names of the vector fields.""" + return [field for field in self.fields if field.field_type == FieldTypes.VECTOR] - This works when the data model has supplied a deserialize method, specific to a data source. - This uses a method called 'deserialize()' on the data model or part of the vector store record definition. + @property + def data_fields(self) -> list[VectorStoreField]: + """Get the names of the data fields.""" + return [field for field in self.fields if field.field_type == FieldTypes.DATA] - The developer is responsible for correctly deserializing for the specific data source. - """ - if self.definition.deserialize: - if isinstance(record, Sequence): - return self.definition.deserialize(record, **kwargs) - return self.definition.deserialize([record], **kwargs) - if func := getattr(self.record_type, "deserialize", None): - if isinstance(record, Sequence): - return [func(rec, **kwargs) for rec in record] - return func(record, **kwargs) - return None + @property + def vector_field_names(self) -> list[str]: + """Get the names of the vector fields.""" + return [field.name for field in self.fields if field.field_type == FieldTypes.VECTOR] - def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **kwargs: Any) -> TModel: - """This function is used if no deserialize method is found on the data model. 
+ @property + def data_field_names(self) -> list[str]: + """Get the names of all the data fields.""" + return [field.name for field in self.fields if field.field_type == FieldTypes.DATA] - This method is the second step and will deserialize a dict to the data model, - should not be overridden by child classes. + def try_get_vector_field(self, field_name: str | None = None) -> VectorStoreField | None: + """Try to get the vector field. - The input of this should come from the _deserialized_store_model_to_dict function. + If the field_name is None, then the first vector field is returned. + If no vector fields are present None is returned. + + Args: + field_name: The field name. + + Returns: + VectorStoreRecordVectorField | None: The vector field or None. """ - if self.definition.from_dict: - if isinstance(record, Sequence): - return self.definition.from_dict(record, **kwargs) - ret = self.definition.from_dict([record], **kwargs) - return ret if self._container_mode else ret[0] - if isinstance(record, Sequence): - if len(record) > 1: - raise VectorStoreModelDeserializationException( - "Cannot deserialize multiple records to a single record unless you are using a container." + if field_name is None: + if len(self.vector_fields) == 0: + return None + return self.vector_fields[0] + for field in self.fields: + if field.name == field_name or field.storage_name == field_name: + if field.field_type == FieldTypes.VECTOR: + return field + raise VectorStoreModelException( + f"Field {field_name} is not a vector field, it is of type {type(field).__name__}." ) - record = record[0] - if func := getattr(self.record_type, "from_dict", None): - return func(record) - if issubclass(self.record_type, BaseModel): - for field in self.definition.fields: - if field.storage_name and field.storage_name in record: - record[field.name] = record.pop(field.storage_name) - return self.record_type.model_validate(record) # type: ignore - data_model_dict: dict[str, Any] = {} - for field in self.definition.fields: - value = record.get(field.storage_name or field.name, None) - if field.field_type == FieldTypes.VECTOR and not kwargs.get("include_vectors"): - continue - data_model_dict[field.name] = value - if self.record_type is dict: - return data_model_dict # type: ignore - return self.record_type(**data_model_dict) + raise VectorStoreModelException(f"Field {field_name} not found.") - async def _add_vectors_to_records( - self, - records: OneOrMany[dict[str, Any]], - **kwargs, - ) -> OneOrMany[dict[str, Any]]: - """Vectorize the vector record. + def get_storage_names(self, include_vector_fields: bool = True, include_key_field: bool = True) -> list[str]: + """Get the names of the fields for the storage. - This function can be passed to upsert or upsert batch of a VectorStoreRecordCollection. + Args: + include_vector_fields: Whether to include vector fields. + include_key_field: Whether to include the key field. - Loops through the fields of the data model definition, - looks at data fields, if they have a vector field, - looks up that vector field and checks if is a local embedding. + Returns: + list[str]: The names of the fields. + """ + return [ + field.storage_name or field.name + for field in self.fields + if field.field_type == FieldTypes.DATA + or (field.field_type == FieldTypes.VECTOR and include_vector_fields) + or (field.field_type == FieldTypes.KEY and include_key_field) + ] - If so adds that to a list of embeddings to make. 
+ def get_names(self, include_vector_fields: bool = True, include_key_field: bool = True) -> list[str]: + """Get the names of the fields. - Finally calls Kernel add_embedding_to_object with the list of embeddings to make. + Args: + include_vector_fields: Whether to include vector fields. + include_key_field: Whether to include the key field. - Optional arguments are passed onto the Kernel add_embedding_to_object call. + Returns: + list[str]: The names of the fields. """ - # dict of embedding_field.name and tuple of record, settings, field_name - embeddings_to_make: list[tuple[str, int, EmbeddingGeneratorBase]] = [] + return [ + field.name + for field in self.fields + if field.field_type == FieldTypes.DATA + or (field.field_type == FieldTypes.VECTOR and include_vector_fields) + or (field.field_type == FieldTypes.KEY and include_key_field) + ] - for field in self.definition.vector_fields: - embedding_generator = field.embedding_generator or self.embedding_generator - if not embedding_generator: - continue - if field.dimensions is None: - raise VectorStoreModelException( - f"Field {field.name} has no dimensions, cannot create embedding for field." - ) - embeddings_to_make.append(( - field.storage_name or field.name, - field.dimensions, - embedding_generator, - )) + def model_post_init(self, _: Any): + """Validate the fields. - for field_name, dimensions, embedder in embeddings_to_make: - await self._add_embedding_to_object( - inputs=records, - field_name=field_name, - dimensions=dimensions, - embedding_generator=embedder, - container_mode=self.definition.container_mode, - **kwargs, + Raises: + VectorStoreModelException: If there is a field with an embedding property name + but no corresponding vector field. + VectorStoreModelException: If there is no key field. + """ + if len(self.fields) == 0: + raise VectorStoreModelException( + "There must be at least one field with a VectorStoreRecordField annotation." 
) - return records - - async def _add_embedding_to_object( - self, - inputs: OneOrMany[Any], - field_name: str, - dimensions: int, - embedding_generator: EmbeddingGeneratorBase, - container_mode: bool = False, - **kwargs: Any, - ): - """Gather all fields to embed, batch the embedding generation and store.""" - contents: list[Any] = [] - dict_like = (getter := getattr(inputs, "get", False)) and callable(getter) - list_of_dicts: bool = False - if isinstance(inputs, list): - list_of_dicts = (getter := getattr(inputs[0], "get", False)) and callable(getter) - for record in inputs: - if list_of_dicts: - contents.append(record.get(field_name)) # type: ignore + for field in self.fields: + if not field.name or field.name == "": + raise VectorStoreModelException("Field names must not be empty.") + if field.field_type == FieldTypes.KEY: + if self.key_name != "": + raise VectorStoreModelException("Memory record definition must have exactly one key field.") + self.key_name = field.name + if not self.key_name: + raise VectorStoreModelException("Memory record definition must have exactly one key field.") + + +# region: Decorator + + +def _parse_vector_store_record_field_instance(record_field: VectorStoreField, field: Parameter) -> VectorStoreField: + if not record_field.name or record_field.name != field.name: + record_field.name = field.name + if not record_field.type_ and hasattr(field.annotation, "__origin__"): + property_type = field.annotation.__origin__ + if record_field.field_type == FieldTypes.VECTOR: + if args := getattr(property_type, "__args__", None): + if NoneType in args and len(args) > 1: + for arg in args: + if arg is NoneType: + continue + + if ( + (inner_args := getattr(arg, "__args__", None)) + and len(inner_args) == 1 + and inner_args[0] is not NoneType + ): + property_type = inner_args[0] + break + property_type = arg + break else: - contents.append(getattr(record, field_name)) + property_type = args[0] + else: - if dict_like: - contents.append(inputs.get(field_name)) # type: ignore - else: - contents.append(getattr(inputs, field_name)) + if (args := getattr(property_type, "__args__", None)) and NoneType in args and len(args) == 2: + property_type = args[0] + + record_field.type_ = str(property_type) if hasattr(property_type, "__args__") else property_type.__name__ + + return record_field + + +def _parse_parameter_to_field(field: Parameter) -> VectorStoreField | None: + # first check if there are any annotations + if field.annotation is not _empty and hasattr(field.annotation, "__metadata__"): + for field_annotation in field.annotation.__metadata__: + if isinstance(field_annotation, VectorStoreField): + return _parse_vector_store_record_field_instance(field_annotation, field) + # This means there are no annotations or that all annotations are of other types. + # we will check if there is a default, otherwise this will cause a runtime error. + # because it will not be stored, and retrieving this object will fail without a default for this field. + if field.default is _empty: + raise VectorStoreModelException( + "Fields that do not have a VectorStoreField annotation must have a default value." 
+ ) + logger.debug(f'Field "{field.name}" does not have a VectorStoreField annotation, will not be part of the record.') + return None - vectors = await embedding_generator.generate_raw_embeddings( - texts=contents, settings=PromptExecutionSettings(dimensions=dimensions), **kwargs - ) # type: ignore - if vectors is None: - raise VectorStoreOperationException("No vectors were generated.") - if isinstance(inputs, list): - for record, vector in zip(inputs, vectors): - if list_of_dicts: - record[field_name] = vector # type: ignore - else: - setattr(record, field_name, vector) - return - if dict_like: - inputs[field_name] = vectors[0] # type: ignore - return - setattr(inputs, field_name, vectors[0]) +def _parse_signature_to_definition( + parameters: MappingProxyType[str, Parameter], collection_name: str | None = None +) -> VectorStoreCollectionDefinition: + if len(parameters) == 0: + raise VectorStoreModelException( + "There must be at least one field in the datamodel. If you are using this with a @dataclass, " + "you might have inverted the order of the decorators, the vectorstoremodel decorator should be the top one." + ) + fields = [] + for param in parameters.values(): + field = _parse_parameter_to_field(param) + if field: + fields.append(field) -# region: VectorStoreRecordCollection + return VectorStoreCollectionDefinition( + fields=fields, + collection_name=collection_name, + ) @release_candidate -class VectorStoreRecordCollection(VectorStoreRecordHandler[TKey, TModel], Generic[TKey, TModel]): - """Base class for a vector store record collection.""" +def vectorstoremodel( + cls: type[TModel] | None = None, + collection_name: str | None = None, +) -> type[TModel]: + """Returns the class as a vector store model. - collection_name: str = "" - managed_client: bool = True + This decorator makes a class a vector store model. + There are three things being checked: + - The class must have at least one field with a annotation, + of type VectorStoreField. + - The class must have exactly one field with the field_type `key`. + - When creating a Vector Field, either supply the property type directly, + or make sure to set the property that you want the index to use first. - @model_validator(mode="before") - @classmethod - def _ensure_collection_name(cls: type[_T], data: Any) -> dict[str, Any]: - """Ensure there is a collection name, if it isn't passed, try to get it from the data model type.""" - if ( - isinstance(data, dict) - and not data.get("collection_name") - and (collection_name := _get_collection_name_from_model(data["record_type"], data.get("definition"))) - ): - data["collection_name"] = collection_name - return data - async def __aenter__(self) -> Self: - """Enter the context manager.""" - return self + Args: + cls: The class to be decorated. + collection_name: The name of the collection to be used. + This is used to set the collection name in the VectorStoreCollectionDefinition. + + Raises: + VectorStoreModelException: If there are no fields with a VectorStoreField annotation. + VectorStoreModelException: If there are fields with no name. + VectorStoreModelException: If there is no key field. + """ - async def __aexit__(self, exc_type, exc_value, traceback) -> None: - """Exit the context manager. 
+ def wrap(cls: type[TModel]) -> type[TModel]: + # get fields and annotations + cls_sig = signature(cls) + setattr(cls, "__kernel_vectorstoremodel__", True) + setattr( + cls, + "__kernel_vectorstoremodel_definition__", + _parse_signature_to_definition(cls_sig.parameters, collection_name), + ) - Should be overridden by subclasses, if necessary. + return cls # type: ignore - If the client is passed in the constructor, it should not be closed, - in that case the managed_client should be set to False. + # See if we're being called as @vectorstoremodel or @vectorstoremodel(). + if cls is None: + # We're called with parens. + return wrap # type: ignore - If the store supplied the managed client, it is responsible for closing it, - and it should not be closed here and so managed_client should be False. + # We're called as @vectorstoremodel without parens. + return wrap(cls) - Some services use two clients, one for the store and one for the collection, - in that case, the collection client should be closed here, - but the store client should only be closed when it is created in the collection. - A additional flag might be needed for that. - """ - pass - @abstractmethod - async def _inner_upsert( - self, - records: Sequence[Any], - **kwargs: Any, - ) -> Sequence[TKey]: - """Upsert the records, this should be overridden by the child class. +# region: VectorSearch Helpers - Args: - records: The records, the format is specific to the store. - **kwargs (Any): Additional arguments, to be passed to the store. - Returns: - The keys of the upserted records. +def _get_collection_name_from_model( + record_type: type[TModel], + definition: VectorStoreCollectionDefinition | None = None, +) -> str | None: + """Get the collection name from the data model type or definition.""" + if record_type and not definition: + definition = getattr(record_type, "__kernel_vectorstoremodel_definition__", None) + if definition and definition.collection_name: + return definition.collection_name + return None - Raises: - Exception: If an error occurs during the upsert. - There is no need to catch and parse exceptions in the inner functions, - they are handled by the public methods. - The only exception is raises exceptions yourself, such as a ValueError. - This is then caught and turned into the relevant exception by the public method. - This setup promotes a limited depth of the stack trace. - """ - ... # pragma: no cover +@pyd_dataclass +class GetFilteredRecordOptions: + """Options for filtering records. - @abstractmethod - async def _inner_get( - self, - keys: Sequence[TKey] | None = None, - options: GetFilteredRecordOptions | None = None, - **kwargs: Any, - ) -> OneOrMany[Any] | None: - """Get the records, this should be overridden by the child class. + Args: + top: The maximum number of records to return. + skip: The number of records to skip. + order_by: A dictionary with fields names and a bool, True means ascending, False means descending. + """ - Args: - keys: The keys to get. - options: the options to use for the get. - **kwargs: Additional arguments. + top: int = 10 + skip: int = 0 + order_by: Mapping[str, bool] | None = None - Returns: - The records from the store, not deserialized. - Raises: - Exception: If an error occurs during the upsert. - There is no need to catch and parse exceptions in the inner functions, - they are handled by the public methods. - The only exception is raises exceptions yourself, such as a ValueError. - This is then caught and turned into the relevant exception by the public method. 
- This setup promotes a limited depth of the stack trace. - """ - ... # pragma: no cover +class LambdaVisitor(NodeVisitor, Generic[TFilters]): + """Visitor class to visit the AST nodes.""" - @abstractmethod - async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: - """Delete the records, this should be overridden by the child class. + def __init__(self, lambda_parser: Callable[[expr], TFilters], output_filters: list[TFilters] | None = None) -> None: + """Initialize the visitor with a lambda parser and output filters.""" + self.lambda_parser = lambda_parser + self.output_filters = output_filters if output_filters is not None else [] - Args: - keys: The keys. - **kwargs: Additional arguments. + def visit_Lambda(self, node: Lambda) -> None: + """This method is called when a lambda expression is found.""" + self.output_filters.append(self.lambda_parser(node.body)) - Raises: - Exception: If an error occurs during the upsert. - There is no need to catch and parse exceptions in the inner functions, - they are handled by the public methods. - The only exception is raises exceptions yourself, such as a ValueError. - This is then caught and turned into the relevant exception by the public method. - This setup promotes a limited depth of the stack trace. - """ - ... # pragma: no cover - async def delete_create_collection(self, **kwargs: Any) -> None: - """Create the collection in the service, after first trying to delete it. +@release_candidate +class SearchType(str, Enum): + """Enumeration for search types. - First uses does_collection_exist to check if it exists, if it does deletes it. - Then, creates the collection. + Contains: vector and keyword_hybrid. + """ - """ - if await self.does_collection_exist(**kwargs): - await self.ensure_collection_deleted(**kwargs) - await self.create_collection(**kwargs) + VECTOR = "vector" + KEYWORD_HYBRID = "keyword_hybrid" - async def ensure_collection_exists(self, **kwargs: Any) -> bool: - """Create the collection in the service if it does not exists. - First uses does_collection_exist to check if it exists, if it does returns False. - Otherwise, creates the collection and returns True. +@release_candidate +class VectorSearchOptions(SearchOptions): + """Options for vector search, builds on TextSearchOptions. - Returns: - bool: True if the collection was created, False if it already exists. + When multiple filters are used, they are combined with an AND operator. + """ - """ - if await self.does_collection_exist(**kwargs): - return False - await self.create_collection(**kwargs) - return True + vector_property_name: str | None = None + additional_property_name: str | None = None + top: Annotated[int, Field(gt=0)] = 3 + include_vectors: bool = False - @abstractmethod - async def create_collection(self, **kwargs: Any) -> None: - """Create the collection in the service. - This should be overridden by the child class. +@release_candidate +class VectorSearchResult(KernelBaseModel, Generic[TModel]): + """The result of a vector search.""" - Raises: - Make sure the implementation of this function raises relevant exceptions with good descriptions. - This is different then the `_inner_x` methods, as this is a public method. + record: TModel + score: float | None = None - """ - ... # pragma: no cover - @abstractmethod - async def does_collection_exist(self, **kwargs: Any) -> bool: - """Check if the collection exists. +# region: VectorStoreRecordHandler - This should be overridden by the child class. 
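A small, self-contained sketch of how `LambdaVisitor` is meant to be driven: the connectors parse a filter (lambda source text, or `getsource()` of a callable) into an AST and let the visitor hand each lambda body to a store-specific `_lambda_parser`. The `dump`-based parser below is a stand-in for that translation, not a real one.

```python
from ast import dump, parse

# LambdaVisitor is the class added above in this patch.
# Stand-in parser: a real connector translates the lambda body into its own filter syntax.
visitor = LambdaVisitor(lambda_parser=lambda body: dump(body))

# Filters arrive as lambda source (or via getsource() for callables) and are parsed to an AST.
tree = parse("lambda record: record.tag == 'news'")
visitor.visit(tree)

print(visitor.output_filters)  # one entry per lambda found in the parsed expression
```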
- Raises: - Make sure the implementation of this function raises relevant exceptions with good descriptions. - This is different then the `_inner_x` methods, as this is a public method. - """ - ... # pragma: no cover +@release_candidate +class VectorStoreRecordHandler(KernelBaseModel, Generic[TKey, TModel]): + """Vector Store Record Handler class. - @abstractmethod - async def ensure_collection_deleted(self, **kwargs: Any) -> None: - """Delete the collection. + This class is used to serialize and deserialize records to and from a vector store. + As well as validating the data model against the vector store. + It is subclassed by VectorStoreRecordCollection and VectorSearchBase. + """ - This should be overridden by the child class. + record_type: type[TModel] + definition: VectorStoreCollectionDefinition + supported_key_types: ClassVar[set[str] | None] = None + supported_vector_types: ClassVar[set[str] | None] = None + embedding_generator: EmbeddingGeneratorBase | None = None - Raises: - Make sure the implementation of this function raises relevant exceptions with good descriptions. - This is different then the `_inner_x` methods, as this is a public method. - """ - ... # pragma: no cover + @property + def _key_field_name(self) -> str: + return self.definition.key_name - async def upsert( - self, - records: OneOrMany[TModel], - **kwargs, - ) -> OneOrMany[TKey]: - """Upsert one or more records. + @property + def _key_field_storage_name(self) -> str: + return self.definition.key_field.storage_name or self.definition.key_name - If the key of the record already exists, the existing record will be updated. - If the key does not exist, a new record will be created. + @property + def _container_mode(self) -> bool: + return self.definition.container_mode - Args: - records: The records to upsert, can be a single record, a list of records, or a single container. - If a single record is passed, a single key is returned, instead of a list of keys. - **kwargs: Additional arguments. + @model_validator(mode="before") + @classmethod + def _ensure_definition(cls: type[_T], data: Any) -> dict[str, Any]: + """Ensure there is a data model definition, if it isn't passed, try to get it from the data model type.""" + if isinstance(data, dict) and not data.get("definition"): + data["definition"] = getattr(data["record_type"], "__kernel_vectorstoremodel_definition__", None) + return data - Returns: - OneOrMany[TKey]: The keys of the upserted records. + def model_post_init(self, __context: object | None = None): + """Post init function that sets the key field and container mode values, and validates the datamodel.""" + self._validate_data_model() - Raises: - VectorStoreModelSerializationException: If an error occurs during serialization. - VectorStoreOperationException: If an error occurs during upserting. - """ - batch = True - if not isinstance(records, list) and not self._container_mode: - batch = False - if records is None: - raise VectorStoreOperationException("Either record or records must be provided.") + def _validate_data_model(self): + """Internal function that can be overloaded by child classes to validate datatypes, etc. - try: - data = await self.serialize(records) - # the serialize method will parse any exception into a VectorStoreModelSerializationException - except VectorStoreModelSerializationException: - raise + This should take the VectorStoreRecordDefinition from the item_type and validate it against the store. 
- try: - results = await self._inner_upsert(data if isinstance(data, list) else [data], **kwargs) # type: ignore - except Exception as exc: - raise VectorStoreOperationException(f"Error upserting record(s): {exc}") from exc - if batch or self._container_mode: - return results - return results[0] - - @overload - async def get( - self, - top: int = ..., - skip: int = ..., - order_by: OneOrMany[str] | dict[str, bool] | None = None, - include_vectors: bool = False, - **kwargs: Any, - ) -> Sequence[TModel] | None: - """Get records based on the ordering and selection criteria. - - Args: - include_vectors: Include the vectors in the response. Default is True. - Some vector stores do not support retrieving without vectors, even when set to false. - Some vector stores have specific parameters to control that behavior, when - that parameter is set, include_vectors is ignored. - top: The number of records to return. - Only used if keys are not provided. - skip: The number of records to skip. - Only used if keys are not provided. - order_by: The order by clause, - this can be a string, a list of strings or a dict, - when passing strings, they are assumed to be ascending. - Otherwise, use the value in the dict to set ascending (True) or descending (False). - example: {"field_name": True} or ["field_name", {"field_name2": False}]. - **kwargs: Additional arguments. + Checks can include, allowed naming of parameters, allowed data types, allowed vector dimensions. - Returns: - The records, either a list of TModel or the container type. + Default checks are that the key field is in the allowed key types and the vector fields + are in the allowed vector types. Raises: - VectorStoreOperationException: If an error occurs during the get. - VectorStoreModelDeserializationException: If an error occurs during deserialization. - """ - ... - - @overload - async def get( - self, - key: TKey = ..., - include_vectors: bool = False, - **kwargs: Any, - ) -> TModel | None: - """Get a record if it exists. - - Args: - key: The key to get. - include_vectors: Include the vectors in the response. Default is True. - Some vector stores do not support retrieving without vectors, even when set to false. - Some vector stores have specific parameters to control that behavior, when - that parameter is set, include_vectors is ignored. - **kwargs: Additional arguments. - - Returns: - The records, either a list of TModel or the container type. + VectorStoreModelValidationError: If the key field is not in the allowed key types. + VectorStoreModelValidationError: If the vector fields are not in the allowed vector types. - Raises: - VectorStoreOperationException: If an error occurs during the get. - VectorStoreModelDeserializationException: If an error occurs during deserialization. """ - ... 
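The checks documented here are driven by the `supported_key_types` and `supported_vector_types` class variables; the body that implements them follows below. A hedged sketch of how a hypothetical connector would opt in (the class name and type sets are invented for the example):

```python
from typing import ClassVar

# Hypothetical connector handler: it only declares what the store supports; the
# abstract (de)serialization methods are still left to the real implementation.
class MyStoreRecordHandler(VectorStoreRecordHandler[str, dict]):
    supported_key_types: ClassVar[set[str] | None] = {"str"}
    supported_vector_types: ClassVar[set[str] | None] = {"float", "list[float]"}

# When such a handler is instantiated, model_post_init runs _validate_data_model, which
# raises VectorStoreModelValidationError for a key or vector field outside these sets.
```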
+ if ( + self.supported_key_types + and self.definition.key_field.type_ + and self.definition.key_field.type_ not in self.supported_key_types + ): + raise VectorStoreModelValidationError( + f"Key field must be one of {self.supported_key_types}, got {self.definition.key_field.type_}" + ) + if not self.supported_vector_types: + return + for field in self.definition.vector_fields: + if field.type_ and field.type_ not in self.supported_vector_types: + raise VectorStoreModelValidationError( + f"Vector field {field.name} must be one of {self.supported_vector_types}, got {field.type_}" + ) - @overload - async def get( - self, - keys: Sequence[TKey] = ..., - include_vectors: bool = False, - **kwargs: Any, - ) -> OneOrMany[TModel] | None: - """Get a batch of records whose keys exist in the collection, i.e. keys that do not exist are ignored. + @abstractmethod + def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: + """Serialize a list of dicts of the data to the store model. - Args: - keys: The keys to get, if keys are provided, key is ignored. - include_vectors: Include the vectors in the response. Default is True. - Some vector stores do not support retrieving without vectors, even when set to false. - Some vector stores have specific parameters to control that behavior, when - that parameter is set, include_vectors is ignored. - **kwargs: Additional arguments. + This method should be overridden by the child class to convert the dict to the store model. + """ + ... # pragma: no cover - Returns: - The records, either a list of TModel or the container type. + @abstractmethod + def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]: + """Deserialize the store models to a list of dicts. - Raises: - VectorStoreOperationException: If an error occurs during the get. - VectorStoreModelDeserializationException: If an error occurs during deserialization. + This method should be overridden by the child class to convert the store model to a list of dicts. """ - ... + ... # pragma: no cover - async def get( - self, - key=None, - keys=None, - include_vectors=False, - **kwargs, - ): - """Get a batch of records whose keys exist in the collection, i.e. keys that do not exist are ignored. + # region Serialization methods - Args: - key: The key to get. - keys: The keys to get, if keys are provided, key is ignored. - include_vectors: Include the vectors in the response. Default is True. - Some vector stores do not support retrieving without vectors, even when set to false. - Some vector stores have specific parameters to control that behavior, when - that parameter is set, include_vectors is ignored. - top: The number of records to return. - Only used if keys are not provided. - skip: The number of records to skip. - Only used if keys are not provided. - order_by: The order by clause, this is a list of dicts with the field name and ascending flag, - (default is True, which means ascending). - Only used if keys are not provided. - **kwargs: Additional arguments. + async def serialize(self, records: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any]: + """Serialize the data model to the store model. - Returns: - The records, either a list of TModel or the container type. + This method follows the following steps: + 1. Check if the data model has a serialize method. + Use that method to serialize and return the result. + 2. Serialize the records into a dict, using the data model specific method. + 3. 
Convert the dict to the store model, using the store specific method. + + If overriding this method, make sure to first try to serialize the data model to the store model, + before doing the store specific version, + the user supplied version should have precedence. Raises: - VectorStoreOperationException: If an error occurs during the get. - VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorStoreModelSerializationException: If an error occurs during serialization. + """ - batch = True - options = None - if not keys and key: - if not isinstance(key, list): - keys = [key] - batch = False - else: - keys = key - if not keys: - if kwargs: - kw_order_by: OneOrList[str] | dict[str, bool] | None = kwargs.pop("order_by", None) # type: ignore - top = kwargs.pop("top", None) - skip = kwargs.pop("skip", None) - order_by: dict[str, bool] | None = None - if kw_order_by is not None: - order_by = {} - if isinstance(kw_order_by, str): - order_by[kw_order_by] = True - elif isinstance(kw_order_by, dict): - order_by = kw_order_by - elif isinstance(kw_order_by, list): - for item in kw_order_by: - if isinstance(item, str): - order_by[item] = True - else: - order_by.update(item) - else: - raise VectorStoreOperationException( - f"Invalid order_by type: {type(order_by)}, expected str, dict or list." - ) - try: - options = GetFilteredRecordOptions(top=top, skip=skip, order_by=order_by) - except Exception as exc: - raise VectorStoreOperationException(f"Error creating options: {exc}") from exc - else: - raise VectorStoreOperationException("Either key, keys or options must be provided.") try: - records = await self._inner_get(keys, include_vectors=include_vectors, options=options, **kwargs) + if serialized := self._serialize_data_model_to_store_model(records): + return serialized + except VectorStoreModelSerializationException: + raise # pragma: no cover except Exception as exc: - raise VectorStoreOperationException(f"Error getting record(s): {exc}") from exc + raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc - if not records: - return None + try: + dict_records: list[dict[str, Any]] = [] + if not isinstance(records, list): + records = [records] # type: ignore + for rec in records: + dict_rec = self._serialize_data_model_to_dict(rec) + if isinstance(dict_rec, list): + dict_records.extend(dict_rec) + else: + dict_records.append(dict_rec) + except VectorStoreModelSerializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc + # add vectors try: - model_records = self.deserialize( - records if batch else records[0], include_vectors=include_vectors, **kwargs - ) - # the deserialize method will parse any exception into a VectorStoreModelDeserializationException - except VectorStoreModelDeserializationException: + dict_records = await self._add_vectors_to_records(dict_records) # type: ignore + except (VectorStoreModelException, VectorStoreOperationException): raise + except Exception as exc: + raise VectorStoreOperationException( + "Exception occurred while trying to add the vectors to the records." + ) from exc - # there are many code paths within the deserialize method, some supplied by the developer, - # and so depending on what is used, - # it might return a sequence, so we just return the first element, - # there should never be multiple elements (this is not a batch get), - # hence a raise if there are. 
- if batch: - return model_records - if not isinstance(model_records, Sequence): - return model_records - if len(model_records) == 1: - return model_records[0] - raise VectorStoreModelDeserializationException( - f"Error deserializing record, multiple records returned: {model_records}" - ) + try: + return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore + except VectorStoreModelSerializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc - async def delete(self, keys: OneOrMany[TKey], **kwargs): - """Delete one or more records by key. + def _serialize_data_model_to_store_model(self, record: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any] | None: + """Serialize the data model to the store model. - An exception will be raised at the end if any record does not exist. + This works when the data model has supplied a serialize method, specific to a data source. + This is a method called 'serialize()' on the data model or part of the vector store record definition. - Args: - keys: The key or keys to be deleted. - **kwargs: Additional arguments. - Exceptions: - VectorStoreOperationException: If an error occurs during deletion or a record does not exist. + The developer is responsible for correctly serializing for the specific data source. """ - if not isinstance(keys, list): - keys = [keys] # type: ignore - try: - await self._inner_delete(keys, **kwargs) # type: ignore - except Exception as exc: - raise VectorStoreOperationException(f"Error deleting record(s): {exc}") from exc - - -# region: VectorStore - - -@release_candidate -class VectorStore(KernelBaseModel): - """Base class for vector stores.""" - - managed_client: bool = True - embedding_generator: EmbeddingGeneratorBase | None = None - - @abstractmethod - def get_collection( - self, - record_type: type[TModel], - *, - definition: VectorStoreCollectionDefinition | None = None, - collection_name: str | None = None, - embedding_generator: EmbeddingGeneratorBase | None = None, - **kwargs: Any, - ) -> "VectorStoreRecordCollection": - """Get a vector store record collection instance tied to this store. + if isinstance(record, Sequence): + result = [self._serialize_data_model_to_store_model(rec, **kwargs) for rec in record] + if not all(result): + return None + return result + if self.definition.serialize: + return self.definition.serialize(record, **kwargs) + if isinstance(record, SerializeMethodProtocol): + return record.serialize(**kwargs) + return None - Args: - record_type: The type of the records that will be used. - definition: The data model definition. - collection_name: The name of the collection. - embedding_generator: The embedding generator to use. - **kwargs: Additional arguments. + def _serialize_data_model_to_dict(self, record: TModel, **kwargs: Any) -> OneOrList[dict[str, Any]]: + """This function is used if no serialize method is found on the data model. - Returns: - A vector store record collection instance tied to this store. + This will generally serialize the data model to a dict, should not be overridden by child classes. + The output of this should be passed to the serialize_dict_to_store_model method. """ - ... 
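Step 1 of the serialization pipeline is satisfied when the record itself knows the store's wire format. A hedged sketch of a model exposing the `serialize()` hook that `_serialize_data_model_to_store_model` looks for; the returned dict layout is invented and would have to match the actual connector:

```python
from dataclasses import dataclass
from typing import Annotated, Any

from semantic_kernel.data import vectorstoremodel
from semantic_kernel.data._definitions import VectorStoreField


@vectorstoremodel
@dataclass
class NewsItem:
    id: Annotated[str, VectorStoreField("key")]
    content: Annotated[str, VectorStoreField("data")]

    def serialize(self, **kwargs: Any) -> dict[str, Any]:
        # Picked up via SerializeMethodProtocol; takes precedence over the
        # generic dict-based path in serialize().
        return {"pk": self.id, "payload": {"content": self.content}}
```

A `serialize` callable on the `VectorStoreCollectionDefinition` itself has the same effect, as the method above shows.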
# pragma: no cover + if self.definition.to_dict: + return self.definition.to_dict(record, **kwargs) # type: ignore + if isinstance(record, BaseModel): + return record.model_dump() - @abstractmethod - async def list_collection_names(self, **kwargs) -> Sequence[str]: - """Get the names of all collections.""" - ... # pragma: no cover + store_model = {} + for field in self.definition.fields: + store_model[field.storage_name or field.name] = ( + record.get(field.name, None) if isinstance(record, Mapping) else getattr(record, field.name) + ) + return store_model - async def does_collection_exist(self, collection_name: str) -> bool: - """Check if a collection exists. + # region Deserialization methods - This is a wrapper around the get_collection method of a collection, - to check if the collection exists. - """ - try: - data_model = VectorStoreCollectionDefinition(fields=[VectorStoreField("key", name="id")]) - collection = self.get_collection(record_type=dict, definition=data_model, collection_name=collection_name) - return await collection.does_collection_exist() - except VectorStoreOperationException: - return False + def deserialize(self, records: OneOrMany[Any | dict[str, Any]], **kwargs: Any) -> OneOrMany[TModel] | None: + """Deserialize the store model to the data model. - async def ensure_collection_deleted(self, collection_name: str) -> None: - """Delete a collection. + This method follows the following steps: + 1. Check if the data model has a deserialize method. + Use that method to deserialize and return the result. + 2. Deserialize the store model to a dict, using the store specific method. + 3. Convert the dict to the data model, using the data model specific method. - This is a wrapper around the get_collection method of a collection, - to delete the collection. + Raises: + VectorStoreModelDeserializationException: If an error occurs during deserialization. """ try: - data_model = VectorStoreCollectionDefinition(fields=[VectorStoreField("key", name="id")]) - collection = self.get_collection(record_type=dict, definition=data_model, collection_name=collection_name) - await collection.ensure_collection_deleted() - except VectorStoreOperationException: - pass - - async def __aenter__(self) -> Self: - """Enter the context manager.""" - return self - - async def __aexit__(self, exc_type, exc_value, traceback) -> None: - """Exit the context manager. + if not records: + return None + if deserialized := self._deserialize_store_model_to_data_model(records, **kwargs): + return deserialized - Should be overridden by subclasses, if necessary. + if isinstance(records, Sequence): + dict_records = self._deserialize_store_models_to_dicts(records, **kwargs) + return ( + self._deserialize_dict_to_data_model(dict_records, **kwargs) + if self._container_mode + else [self._deserialize_dict_to_data_model(rec, **kwargs) for rec in dict_records] + ) - If the client is passed in the constructor, it should not be closed, - in that case the managed_client should be set to False. - """ - pass # pragma: no cover + dict_record = self._deserialize_store_models_to_dicts([records], **kwargs)[0] + # regardless of mode, only 1 object is returned. 
+ return self._deserialize_dict_to_data_model(dict_record, **kwargs) + except VectorStoreModelDeserializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelDeserializationException(f"Error deserializing records: {exc}") from exc + def _deserialize_store_model_to_data_model(self, record: OneOrMany[Any], **kwargs: Any) -> OneOrMany[TModel] | None: + """Deserialize the store model to the data model. -# region: Vector Search + This works when the data model has supplied a deserialize method, specific to a data source. + This uses a method called 'deserialize()' on the data model or part of the vector store record definition. + The developer is responsible for correctly deserializing for the specific data source. + """ + if self.definition.deserialize: + if isinstance(record, Sequence): + return self.definition.deserialize(record, **kwargs) + return self.definition.deserialize([record], **kwargs) + if func := getattr(self.record_type, "deserialize", None): + if isinstance(record, Sequence): + return [func(rec, **kwargs) for rec in record] + return func(record, **kwargs) + return None -@release_candidate -class VectorSearch(VectorStoreRecordHandler[TKey, TModel], Generic[TKey, TModel]): - """Base class for searching vectors.""" + def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **kwargs: Any) -> TModel: + """This function is used if no deserialize method is found on the data model. - supported_search_types: ClassVar[set[SearchType]] = Field(default_factory=set) + This method is the second step and will deserialize a dict to the data model, + should not be overridden by child classes. - @property - def options_class(self) -> type[SearchOptions]: - """The options class for the search.""" - return VectorSearchOptions + The input of this should come from the _deserialized_store_model_to_dict function. + """ + if self.definition.from_dict: + if isinstance(record, Sequence): + return self.definition.from_dict(record, **kwargs) + ret = self.definition.from_dict([record], **kwargs) + return ret if self._container_mode else ret[0] + if isinstance(record, Sequence): + if len(record) > 1: + raise VectorStoreModelDeserializationException( + "Cannot deserialize multiple records to a single record unless you are using a container." + ) + record = record[0] + if func := getattr(self.record_type, "from_dict", None): + return func(record) + if issubclass(self.record_type, BaseModel): + for field in self.definition.fields: + if field.storage_name and field.storage_name in record: + record[field.name] = record.pop(field.storage_name) + return self.record_type.model_validate(record) # type: ignore + data_model_dict: dict[str, Any] = {} + for field in self.definition.fields: + value = record.get(field.storage_name or field.name, None) + if field.field_type == FieldTypes.VECTOR and not kwargs.get("include_vectors"): + continue + data_model_dict[field.name] = value + if self.record_type is dict: + return data_model_dict # type: ignore + return self.record_type(**data_model_dict) - @abstractmethod - async def _inner_search( + async def _add_vectors_to_records( self, - search_type: SearchType, - options: VectorSearchOptions, - values: Any | None = None, - vector: Sequence[float | int] | None = None, - **kwargs: Any, - ) -> KernelSearchResults[VectorSearchResult[TModel]]: - """Inner search method. - - This is the main search method that should be implemented, and will be called by the public search methods. 
- Currently, at least one of the three search contents will be provided - (through the public interface mixin functions), in the future, this may be expanded to allow multiple of them. - - This method should return a KernelSearchResults object with the results of the search. - The inner "results" object of the KernelSearchResults should be a async iterator that yields the search results, - this allows things like paging to be implemented. - - There is a default helper method "_get_vector_search_results_from_results" to convert - the results to a async iterable VectorSearchResults, but this can be overridden if necessary. - - Options might be a object of type VectorSearchOptions, or a subclass of it. - - The implementation of this method must deal with the possibility that multiple search contents are provided, - and should handle them in a way that makes sense for that particular store. + records: OneOrMany[dict[str, Any]], + **kwargs, + ) -> OneOrMany[dict[str, Any]]: + """Vectorize the vector record. - The public methods will catch and reraise the three exceptions mentioned below, others are caught and turned - into a VectorSearchExecutionException. + This function can be passed to upsert or upsert batch of a VectorStoreRecordCollection. - Args: - search_type: The type of search to perform. - options: The search options, can be None. - values: The values to search for, optional. - vector: The vector to search for, optional. - **kwargs: Additional arguments that might be needed. + Loops through the fields of the data model definition, + looks at data fields, if they have a vector field, + looks up that vector field and checks if is a local embedding. - Returns: - The search results, wrapped in a KernelSearchResults object. + If so adds that to a list of embeddings to make. - Raises: - VectorSearchExecutionException: If an error occurs during the search. - VectorStoreModelDeserializationException: If an error occurs during deserialization. - VectorSearchOptionsException: If the search options are invalid. - VectorStoreOperationNotSupportedException: If the search type is not supported. + Finally calls Kernel add_embedding_to_object with the list of embeddings to make. + Optional arguments are passed onto the Kernel add_embedding_to_object call. """ - ... - - @abstractmethod - def _get_record_from_result(self, result: Any) -> Any: - """Get the record from the returned search result. + # dict of embedding_field.name and tuple of record, settings, field_name + embeddings_to_make: list[tuple[str, int, EmbeddingGeneratorBase]] = [] - Does any unpacking or processing of the result to get just the record. + for field in self.definition.vector_fields: + embedding_generator = field.embedding_generator or self.embedding_generator + if not embedding_generator: + continue + if field.dimensions is None: + raise VectorStoreModelException( + f"Field {field.name} has no dimensions, cannot create embedding for field." + ) + embeddings_to_make.append(( + field.storage_name or field.name, + field.dimensions, + embedding_generator, + )) - If the underlying SDK of the store returns a particular type that might include something - like a score or other metadata, this method should be overridden to extract just the record. 
+ for field_name, dimensions, embedder in embeddings_to_make: + await self._add_embedding_to_object( + inputs=records, + field_name=field_name, + dimensions=dimensions, + embedding_generator=embedder, + container_mode=self.definition.container_mode, + **kwargs, + ) + return records - Likely returns a dict, but in some cases could return the record in the form of a SDK specific object. + async def _add_embedding_to_object( + self, + inputs: OneOrMany[Any], + field_name: str, + dimensions: int, + embedding_generator: EmbeddingGeneratorBase, + container_mode: bool = False, + **kwargs: Any, + ): + """Gather all fields to embed, batch the embedding generation and store.""" + contents: list[Any] = [] + dict_like = (getter := getattr(inputs, "get", False)) and callable(getter) + list_of_dicts: bool = False + if isinstance(inputs, list): + list_of_dicts = (getter := getattr(inputs[0], "get", False)) and callable(getter) + for record in inputs: + if list_of_dicts: + contents.append(record.get(field_name)) # type: ignore + else: + contents.append(getattr(record, field_name)) + else: + if dict_like: + contents.append(inputs.get(field_name)) # type: ignore + else: + contents.append(getattr(inputs, field_name)) - This method is used as part of the _get_vector_search_results_from_results method, - the output of it is passed to the deserializer. - """ - ... + vectors = await embedding_generator.generate_raw_embeddings( + texts=contents, settings=PromptExecutionSettings(dimensions=dimensions), **kwargs + ) # type: ignore + if vectors is None: + raise VectorStoreOperationException("No vectors were generated.") + if isinstance(inputs, list): + for record, vector in zip(inputs, vectors): + if list_of_dicts: + record[field_name] = vector # type: ignore + else: + setattr(record, field_name, vector) + return + if dict_like: + inputs[field_name] = vectors[0] # type: ignore + return + setattr(inputs, field_name, vectors[0]) - @abstractmethod - def _get_score_from_result(self, result: Any) -> float | None: - """Get the score from the result. - Does any unpacking or processing of the result to get just the score. +# region: VectorStoreRecordCollection - If the underlying SDK of the store returns a particular type with a score or other metadata, - this method extracts it. - """ - ... 
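A hedged usage sketch of how the two helpers above come into play at upsert time. Here `store` stands for any `VectorStore` connector instance, `embedding_generator` for any `EmbeddingGeneratorBase` implementation, and `Document` is the illustrative model from the earlier sketch; none of these names come from the patch itself.

```python
async def ingest(store, embedding_generator) -> None:
    # collection_name is picked up from the model's definition ("documents").
    collection = store.get_collection(
        record_type=Document,
        embedding_generator=embedding_generator,
    )
    doc = Document(content="Postgres as a vector store")
    # _add_vectors_to_records reads the current value of each vector field (the text
    # placed there), batches the embedding calls per field, and writes the generated
    # vectors back before the records reach the store.
    doc.vector = doc.content
    await collection.upsert(doc)
```

Because the field in this sketch has no field-level generator, the collection-level `embedding_generator` is used, matching the `field.embedding_generator or self.embedding_generator` fallback above.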
- async def _get_vector_search_results_from_results( - self, results: AsyncIterable[Any] | Sequence[Any], options: VectorSearchOptions | None = None - ) -> AsyncIterable[VectorSearchResult[TModel]]: - if isinstance(results, Sequence): - results = desync_list(results) - async for result in results: - if not result: - continue - try: - record = self.deserialize( - self._get_record_from_result(result), include_vectors=options.include_vectors if options else True - ) - except VectorStoreModelDeserializationException: - raise - except Exception as exc: - raise VectorStoreModelDeserializationException( - f"An error occurred while deserializing the record: {exc}" - ) from exc - score = self._get_score_from_result(result) - if record is not None: - # single records are always returned as single records by the deserializer - yield VectorSearchResult(record=record, score=score) # type: ignore +@release_candidate +class VectorStoreCollection(VectorStoreRecordHandler[TKey, TModel], Generic[TKey, TModel]): + """Base class for a vector store record collection.""" - @overload - async def search( + collection_name: str = "" + managed_client: bool = True + + @model_validator(mode="before") + @classmethod + def _ensure_collection_name(cls: type[_T], data: Any) -> dict[str, Any]: + """Ensure there is a collection name, if it isn't passed, try to get it from the data model type.""" + if ( + isinstance(data, dict) + and not data.get("collection_name") + and (collection_name := _get_collection_name_from_model(data["record_type"], data.get("definition"))) + ): + data["collection_name"] = collection_name + return data + + async def __aenter__(self) -> Self: + """Enter the context manager.""" + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """Exit the context manager. + + Should be overridden by subclasses, if necessary. + + If the client is passed in the constructor, it should not be closed, + in that case the managed_client should be set to False. + + If the store supplied the managed client, it is responsible for closing it, + and it should not be closed here and so managed_client should be False. + + Some services use two clients, one for the store and one for the collection, + in that case, the collection client should be closed here, + but the store client should only be closed when it is created in the collection. + A additional flag might be needed for that. + """ + pass + + @abstractmethod + async def _inner_upsert( self, - values: Any, - *, - vector_field_name: str | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 3, - skip: int = 0, - include_total_count: bool = False, - include_vectors: bool = False, + records: Sequence[Any], **kwargs: Any, - ) -> KernelSearchResults[VectorSearchResult[TModel]]: - """Search the vector store with Vector search for records that match the given value and filter. + ) -> Sequence[TKey]: + """Upsert the records, this should be overridden by the child class. Args: - values: The values to search for. These will be vectorized, - either by the store or using the provided generator. - vector_field_name: The name of the vector field to use for the search. - filter: The filter to apply to the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - include_vectors: Whether to include the vectors in the results. - kwargs: If options are not set, this is used to create them. 
- they are passed on to the inner search method. + records: The records, the format is specific to the store. + **kwargs (Any): Additional arguments, to be passed to the store. + + Returns: + The keys of the upserted records. Raises: - VectorSearchExecutionException: If an error occurs during the search. - VectorStoreModelDeserializationException: If an error occurs during deserialization. - VectorSearchOptionsException: If the search options are invalid. - VectorStoreOperationNotSupportedException: If the search type is not supported. + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. """ - ... + ... # pragma: no cover - @overload - async def search( + @abstractmethod + async def _inner_get( self, - *, - vector: Sequence[float | int], - vector_field_name: str | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 3, - skip: int = 0, - include_total_count: bool = False, - include_vectors: bool = False, + keys: Sequence[TKey] | None = None, + options: GetFilteredRecordOptions | None = None, **kwargs: Any, - ) -> KernelSearchResults[VectorSearchResult[TModel]]: - """Search the vector store with Vector search for records that match the given vector and filter. + ) -> OneOrMany[Any] | None: + """Get the records, this should be overridden by the child class. Args: - vector: The vector to search for - vector_field_name: The name of the vector field to use for the search. - filter: The filter to apply to the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - include_vectors: Whether to include the vectors in the results. - kwargs: If options are not set, this is used to create them. - they are passed on to the inner search method. + keys: The keys to get. + options: the options to use for the get. + **kwargs: Additional arguments. - Raises: - VectorSearchExecutionException: If an error occurs during the search. - VectorStoreModelDeserializationException: If an error occurs during deserialization. - VectorSearchOptionsException: If the search options are invalid. - VectorStoreOperationNotSupportedException: If the search type is not supported. + Returns: + The records from the store, not deserialized. + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. """ - ... + ... # pragma: no cover - async def search( - self, - values=None, - *, - vector=None, - vector_property_name=None, - filter=None, - top=3, - skip=0, - include_total_count=False, - include_vectors=False, - **kwargs, - ): - """Search the vector store for records that match the given value and filter. + @abstractmethod + async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: + """Delete the records, this should be overridden by the child class. Args: - values: The values to search for. 
- vector: The vector to search for, if not provided, the values will be used to generate a vector. - vector_property_name: The name of the vector property to use for the search. - filter: The filter to apply to the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - include_vectors: Whether to include the vectors in the results. - kwargs: If options are not set, this is used to create them. - they are passed on to the inner search method. + keys: The keys. + **kwargs: Additional arguments. Raises: - VectorSearchExecutionException: If an error occurs during the search. - VectorStoreModelDeserializationException: If an error occurs during deserialization. - VectorSearchOptionsException: If the search options are invalid. - VectorStoreOperationNotSupportedException: If the search type is not supported. - - """ - if SearchType.VECTOR not in self.supported_search_types: - raise VectorStoreOperationNotSupportedException( - f"Vector search is not supported by this vector store: {self.__class__.__name__}" - ) - options = VectorSearchOptions( - filter=filter, - vector_property_name=vector_property_name, - top=top, - skip=skip, - include_total_count=include_total_count, - include_vectors=include_vectors, - ) - try: - return await self._inner_search( - search_type=SearchType.VECTOR, - values=values, - options=options, - vector=vector, - **kwargs, - ) - except ( - VectorStoreModelDeserializationException, - VectorSearchOptionsException, - VectorSearchExecutionException, - VectorStoreOperationNotSupportedException, - VectorStoreOperationException, - ): - raise # pragma: no cover - except Exception as exc: - raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc - - async def hybrid_search( - self, - values: Any, - *, - vector: list[float | int] | None = None, - vector_property_name: str | None = None, - additional_property_name: str | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 3, - skip: int = 0, - include_total_count: bool = False, - include_vectors: bool = False, - **kwargs: Any, - ) -> KernelSearchResults[VectorSearchResult[TModel]]: - """Search the vector store for records that match the given values and filter. + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. + """ + ... # pragma: no cover - Args: - values: The values to search for. - vector: The vector to search for, if not provided, the values will be used to generate a vector. - vector_property_name: The name of the vector field to use for the search. - additional_property_name: The name of the additional property field to use for the search. - filter: The filter to apply to the search. - top: The number of results to return. - skip: The number of results to skip. - include_total_count: Whether to include the total count of results. - include_vectors: Whether to include the vectors in the results. - kwargs: If options are not set, this is used to create them. - they are passed on to the inner search method. 
+ async def delete_create_collection(self, **kwargs: Any) -> None: + """Create the collection in the service, after first trying to delete it. - Raises: - VectorSearchExecutionException: If an error occurs during the search. - VectorStoreModelDeserializationException: If an error occurs during deserialization. - VectorSearchOptionsException: If the search options are invalid. - VectorStoreOperationNotSupportedException: If the search type is not supported. + First uses does_collection_exist to check if it exists, if it does deletes it. + Then, creates the collection. """ - if SearchType.KEYWORD_HYBRID not in self.supported_search_types: - raise VectorStoreOperationNotSupportedException( - f"Keyword hybrid search is not supported by this vector store: {self.__class__.__name__}" - ) - options = VectorSearchOptions( - filter=filter, - vector_property_name=vector_property_name, - additional_property_name=additional_property_name, - top=top, - skip=skip, - include_total_count=include_total_count, - include_vectors=include_vectors, - ) - try: - return await self._inner_search( - search_type=SearchType.KEYWORD_HYBRID, - values=values, - vector=vector, - options=options, - **kwargs, - ) - except ( - VectorStoreModelDeserializationException, - VectorSearchOptionsException, - VectorSearchExecutionException, - VectorStoreOperationNotSupportedException, - VectorStoreOperationException, - ): - raise # pragma: no cover - except Exception as exc: - raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc + if await self.does_collection_exist(**kwargs): + await self.ensure_collection_deleted(**kwargs) + await self.create_collection(**kwargs) - async def _generate_vector_from_values( - self, - values: Any | None, - options: VectorSearchOptions, - ) -> Sequence[float | int] | None: - """Generate a vector from the given keywords.""" - if values is None: - return None - vector_field = self.definition.try_get_vector_field(options.vector_property_name) - if not vector_field: - raise VectorSearchOptionsException( - f"Vector field '{options.vector_property_name}' not found in data model definition." - ) - embedding_generator = ( - vector_field.embedding_generator if vector_field.embedding_generator else self.embedding_generator - ) - if not embedding_generator: - raise VectorSearchOptionsException( - f"Embedding generator not found for vector field '{options.vector_property_name}'." - ) + async def ensure_collection_exists(self, **kwargs: Any) -> bool: + """Create the collection in the service if it does not exists. - return ( - await embedding_generator.generate_embeddings( - # TODO (eavanvalkenburg): this only deals with string values, should support other types as well - # but that requires work on the embedding generators first. - texts=[values if isinstance(values, str) else json.dumps(values)], - settings=PromptExecutionSettings(dimensions=vector_field.dimensions), - ) - )[0].tolist() + First uses does_collection_exist to check if it exists, if it does returns False. + Otherwise, creates the collection and returns True. - def _build_filter(self, search_filter: OptionalOneOrMany[Callable | str] | None) -> OptionalOneOrMany[Any]: - """Create the filter based on the filters. + Returns: + bool: True if the collection was created, False if it already exists. - This function returns None, a single filter, or a list of filters. - If a single filter is passed, a single filter is returned. 
+ """ + if await self.does_collection_exist(**kwargs): + return False + await self.create_collection(**kwargs) + return True - It takes the filters, which can be a Callable (lambda) or a string, and parses them into a filter object, - using the _lambda_parser method that is specific to each vector store. + @abstractmethod + async def create_collection(self, **kwargs: Any) -> None: + """Create the collection in the service. - If a list of filters, is passed, the parsed filters are also returned as a list, so the caller needs to - combine them in the appropriate way. + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. - Often called like this (when filters are strings): - ```python - if filter := self._build_filter(options.filter): - search_args["filter"] = filter if isinstance(filter, str) else " and ".join(filter) - ``` """ - if not search_filter: - return None + ... # pragma: no cover - filters = search_filter if isinstance(search_filter, list) else [search_filter] + @abstractmethod + async def does_collection_exist(self, **kwargs: Any) -> bool: + """Check if the collection exists. - created_filters: list[Any] = [] + This should be overridden by the child class. - visitor = LambdaVisitor(self._lambda_parser) - for filter_ in filters: - # parse lambda expression with AST - tree = parse(filter_ if isinstance(filter_, str) else getsource(filter_).strip()) - visitor.visit(tree) - created_filters = visitor.output_filters - if len(created_filters) == 0: - raise VectorStoreOperationException("No filter strings found.") - if len(created_filters) == 1: - return created_filters[0] - return created_filters + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. + """ + ... # pragma: no cover @abstractmethod - def _lambda_parser(self, node: AST) -> Any: - """Parse the lambda expression and return the filter string. + async def ensure_collection_deleted(self, **kwargs: Any) -> None: + """Delete the collection. - This follows from the ast specs: https://docs.python.org/3/library/ast.html + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. """ - # This method should be implemented in the derived class - # to parse the lambda expression and return the filter string. - pass + ... # pragma: no cover - def create_search_function( + async def upsert( self, - function_name: str = DEFAULT_FUNCTION_NAME, - description: str = DEFAULT_DESCRIPTION, - *, - search_type: Literal["vector", "keyword_hybrid"] = "vector", - parameters: list[KernelParameterMetadata] | None = None, - return_parameter: KernelParameterMetadata | None = None, - filter: OptionalOneOrList[Callable | str] = None, - top: int = 5, - skip: int = 0, - vector_property_name: str | None = None, - additional_property_name: str | None = None, - include_vectors: bool = False, - include_total_count: bool = False, - filter_update_function: DynamicFilterFunction | None = None, - string_mapper: Callable[[VectorSearchResult[TModel]], str] | None = None, - ) -> KernelFunction: - """Create a kernel function from a search function. 
+ records: OneOrMany[TModel], + **kwargs, + ) -> OneOrMany[TKey]: + """Upsert one or more records. + + If the key of the record already exists, the existing record will be updated. + If the key does not exist, a new record will be created. Args: - function_name: The name of the function, to be used in the kernel, default is "search". - description: The description of the function, a default is provided. - search_type: The type of search to perform, can be 'vector' or 'keyword_hybrid'. - parameters: The parameters for the function, - use an empty list for a function without parameters, - use None for the default set, which is "query", "top", and "skip". - return_parameter: The return parameter for the function. - filter: The filter to apply to the search. - top: The number of results to return. - skip: The number of results to skip. - vector_property_name: The name of the vector property to use for the search. - additional_property_name: The name of the additional property field to use for the search. - include_vectors: Whether to include the vectors in the results. - include_total_count: Whether to include the total count of results. - filter_update_function: A function to update the filters. - The function should return the updated filter. - The default function uses the parameters and the kwargs to update the filters, it - adds equal to filters to the options for all parameters that are not "query". - As well as adding equal to filters for parameters that have a default value. - string_mapper: The function to map the search results to strings. - """ - search_types = SearchType(search_type) - if search_types not in self.supported_search_types: - raise VectorStoreOperationNotSupportedException( - f"Search type '{search_types.value}' is not supported by this vector store: {self.__class__.__name__}" - ) + records: The records to upsert, can be a single record, a list of records, or a single container. + If a single record is passed, a single key is returned, instead of a list of keys. + **kwargs: Additional arguments. + + Returns: + OneOrMany[TKey]: The keys of the upserted records. + + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. + VectorStoreOperationException: If an error occurs during upserting. + """ + batch = True + if not isinstance(records, list) and not self._container_mode: + batch = False + if records is None: + raise VectorStoreOperationException("Either record or records must be provided.") + + try: + data = await self.serialize(records) + # the serialize method will parse any exception into a VectorStoreModelSerializationException + except VectorStoreModelSerializationException: + raise + + try: + results = await self._inner_upsert(data if isinstance(data, list) else [data], **kwargs) # type: ignore + except Exception as exc: + raise VectorStoreOperationException(f"Error upserting record(s): {exc}") from exc + if batch or self._container_mode: + return results + return results[0] + + @overload + async def get( + self, + top: int = ..., + skip: int = ..., + order_by: OneOrMany[str] | dict[str, bool] | None = None, + include_vectors: bool = False, + **kwargs: Any, + ) -> Sequence[TModel] | None: + """Get records based on the ordering and selection criteria. + + Args: + include_vectors: Include the vectors in the response. Default is True. + Some vector stores do not support retrieving without vectors, even when set to false. 
+ Some vector stores have specific parameters to control that behavior, when + that parameter is set, include_vectors is ignored. + top: The number of records to return. + Only used if keys are not provided. + skip: The number of records to skip. + Only used if keys are not provided. + order_by: The order by clause, + this can be a string, a list of strings or a dict, + when passing strings, they are assumed to be ascending. + Otherwise, use the value in the dict to set ascending (True) or descending (False). + example: {"field_name": True} or ["field_name", {"field_name2": False}]. + **kwargs: Additional arguments. + + Returns: + The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + ... + + @overload + async def get( + self, + key: TKey = ..., + include_vectors: bool = False, + **kwargs: Any, + ) -> TModel | None: + """Get a record if it exists. + + Args: + key: The key to get. + include_vectors: Include the vectors in the response. Default is True. + Some vector stores do not support retrieving without vectors, even when set to false. + Some vector stores have specific parameters to control that behavior, when + that parameter is set, include_vectors is ignored. + **kwargs: Additional arguments. + + Returns: + The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + ... + + @overload + async def get( + self, + keys: Sequence[TKey] = ..., + include_vectors: bool = False, + **kwargs: Any, + ) -> OneOrMany[TModel] | None: + """Get a batch of records whose keys exist in the collection, i.e. keys that do not exist are ignored. + + Args: + keys: The keys to get, if keys are provided, key is ignored. + include_vectors: Include the vectors in the response. Default is True. + Some vector stores do not support retrieving without vectors, even when set to false. + Some vector stores have specific parameters to control that behavior, when + that parameter is set, include_vectors is ignored. + **kwargs: Additional arguments. + + Returns: + The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + ... + + async def get( + self, + key=None, + keys=None, + include_vectors=False, + **kwargs, + ): + """Get a batch of records whose keys exist in the collection, i.e. keys that do not exist are ignored. + + Args: + key: The key to get. + keys: The keys to get, if keys are provided, key is ignored. + include_vectors: Include the vectors in the response. Default is True. + Some vector stores do not support retrieving without vectors, even when set to false. + Some vector stores have specific parameters to control that behavior, when + that parameter is set, include_vectors is ignored. + top: The number of records to return. + Only used if keys are not provided. + skip: The number of records to skip. + Only used if keys are not provided. + order_by: The order by clause, this is a list of dicts with the field name and ascending flag, + (default is True, which means ascending). + Only used if keys are not provided. + **kwargs: Additional arguments. 
+ + Returns: + The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + batch = True + options = None + if not keys and key: + if not isinstance(key, list): + keys = [key] + batch = False + else: + keys = key + if not keys: + if kwargs: + get_args = {} + kw_order_by: OneOrList[str] | dict[str, bool] | None = kwargs.pop("order_by", None) # type: ignore + if "top" in kwargs: + get_args["top"] = kwargs.pop("top", None) + if "skip" in kwargs: + get_args["skip"] = kwargs.pop("skip", None) + order_by: dict[str, bool] | None = None + if kw_order_by is not None: + order_by = {} + if isinstance(kw_order_by, str): + order_by[kw_order_by] = True + elif isinstance(kw_order_by, dict): + order_by = kw_order_by + elif isinstance(kw_order_by, list): + for item in kw_order_by: + if isinstance(item, str): + order_by[item] = True + else: + order_by.update(item) + else: + raise VectorStoreOperationException( + f"Invalid order_by type: {type(order_by)}, expected str, dict or list." + ) + get_args["order_by"] = order_by + try: + options = GetFilteredRecordOptions(**get_args) + except Exception as exc: + raise VectorStoreOperationException(f"Error creating options: {exc}") from exc + else: + raise VectorStoreOperationException("Either key, keys or options must be provided.") + try: + records = await self._inner_get(keys, include_vectors=include_vectors, options=options, **kwargs) + except Exception as exc: + raise VectorStoreOperationException(f"Error getting record(s): {exc}") from exc + + if not records: + return None + + try: + model_records = self.deserialize( + records if batch else records[0], include_vectors=include_vectors, **kwargs + ) + # the deserialize method will parse any exception into a VectorStoreModelDeserializationException + except VectorStoreModelDeserializationException: + raise + + # there are many code paths within the deserialize method, some supplied by the developer, + # and so depending on what is used, + # it might return a sequence, so we just return the first element, + # there should never be multiple elements (this is not a batch get), + # hence a raise if there are. + if batch: + return model_records + if not isinstance(model_records, Sequence): + return model_records + if len(model_records) == 1: + return model_records[0] + raise VectorStoreModelDeserializationException( + f"Error deserializing record, multiple records returned: {model_records}" + ) + + async def delete(self, keys: OneOrMany[TKey], **kwargs): + """Delete one or more records by key. + + An exception will be raised at the end if any record does not exist. + + Args: + keys: The key or keys to be deleted. + **kwargs: Additional arguments. + Exceptions: + VectorStoreOperationException: If an error occurs during deletion or a record does not exist. 
+ """ + if not isinstance(keys, list): + keys = [keys] # type: ignore + try: + await self._inner_delete(keys, **kwargs) # type: ignore + except Exception as exc: + raise VectorStoreOperationException(f"Error deleting record(s): {exc}") from exc + + +# region: VectorStore + + +@release_candidate +class VectorStore(KernelBaseModel): + """Base class for vector stores.""" + + managed_client: bool = True + embedding_generator: EmbeddingGeneratorBase | None = None + + @abstractmethod + def get_collection( + self, + record_type: type[TModel], + *, + definition: VectorStoreCollectionDefinition | None = None, + collection_name: str | None = None, + embedding_generator: EmbeddingGeneratorBase | None = None, + **kwargs: Any, + ) -> "VectorStoreCollection": + """Get a vector store record collection instance tied to this store. + + Args: + record_type: The type of the records that will be used. + definition: The data model definition. + collection_name: The name of the collection. + embedding_generator: The embedding generator to use. + **kwargs: Additional arguments. + + Returns: + A vector store record collection instance tied to this store. + + """ + ... # pragma: no cover + + @abstractmethod + async def list_collection_names(self, **kwargs) -> Sequence[str]: + """Get the names of all collections.""" + ... # pragma: no cover + + async def does_collection_exist(self, collection_name: str) -> bool: + """Check if a collection exists. + + This is a wrapper around the get_collection method of a collection, + to check if the collection exists. + """ + try: + data_model = VectorStoreCollectionDefinition(fields=[VectorStoreField("key", name="id")]) + collection = self.get_collection(record_type=dict, definition=data_model, collection_name=collection_name) + return await collection.does_collection_exist() + except VectorStoreOperationException: + return False + + async def ensure_collection_deleted(self, collection_name: str) -> None: + """Delete a collection. + + This is a wrapper around the get_collection method of a collection, + to delete the collection. + """ + try: + data_model = VectorStoreCollectionDefinition(fields=[VectorStoreField("key", name="id")]) + collection = self.get_collection(record_type=dict, definition=data_model, collection_name=collection_name) + await collection.ensure_collection_deleted() + except VectorStoreOperationException: + pass + + async def __aenter__(self) -> Self: + """Enter the context manager.""" + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """Exit the context manager. + + Should be overridden by subclasses, if necessary. + + If the client is passed in the constructor, it should not be closed, + in that case the managed_client should be set to False. + """ + pass # pragma: no cover + + +# region: Vector Search + + +@release_candidate +class VectorSearch(VectorStoreRecordHandler[TKey, TModel], Generic[TKey, TModel]): + """Base class for searching vectors.""" + + supported_search_types: ClassVar[set[SearchType]] = Field(default_factory=set) + + @property + def options_class(self) -> type[SearchOptions]: + """The options class for the search.""" + return VectorSearchOptions + + @abstractmethod + async def _inner_search( + self, + search_type: SearchType, + options: VectorSearchOptions, + values: Any | None = None, + vector: Sequence[float | int] | None = None, + **kwargs: Any, + ) -> KernelSearchResults[VectorSearchResult[TModel]]: + """Inner search method. 
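The `VectorStore` base class above is essentially a collection factory plus two convenience wrappers around it. A minimal sketch, assuming the in-memory connector and a record type already decorated with `@vectorstoremodel`; the collection name is illustrative:

```python
from semantic_kernel.connectors.memory.in_memory import InMemoryStore


async def store_example(record_type: type) -> None:
    # The store is an async context manager; the base __aexit__ is a no-op, and
    # connectors that own their client override it (managed_client=True).
    async with InMemoryStore() as store:
        collection = store.get_collection(record_type, collection_name="docs")
        if not await store.does_collection_exist("docs"):
            await collection.ensure_collection_exists()
        # ... upsert / get / search against the collection ...
        await store.ensure_collection_deleted("docs")
```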
+
+        This is the main search method that should be implemented, and will be called by the public search methods.
+        Currently, at least one of the search inputs (values or vector) will be provided
+        (through the public interface methods); in the future this may be expanded to allow multiple of them.
+
+        This method should return a KernelSearchResults object with the results of the search.
+        The inner "results" object of the KernelSearchResults should be an async iterator that yields the search
+        results, this allows things like paging to be implemented.
+
+        There is a default helper method "_get_vector_search_results_from_results" to convert
+        the results to an async iterable of VectorSearchResults, but this can be overridden if necessary.
+
+        Options might be an object of type VectorSearchOptions, or a subclass of it.
+
+        The implementation of this method must deal with the possibility that multiple search inputs are provided,
+        and should handle them in a way that makes sense for that particular store.
+
+        The public methods will catch and reraise the exceptions mentioned below, others are caught and turned
+        into a VectorSearchExecutionException.
+
+        Args:
+            search_type: The type of search to perform.
+            options: The search options, can be None.
+            values: The values to search for, optional.
+            vector: The vector to search for, optional.
+            **kwargs: Additional arguments that might be needed.
+
+        Returns:
+            The search results, wrapped in a KernelSearchResults object.
+
+        Raises:
+            VectorSearchExecutionException: If an error occurs during the search.
+            VectorStoreModelDeserializationException: If an error occurs during deserialization.
+            VectorSearchOptionsException: If the search options are invalid.
+            VectorStoreOperationNotSupportedException: If the search type is not supported.
+
+        """
+        ...
+
+    @abstractmethod
+    def _get_record_from_result(self, result: Any) -> Any:
+        """Get the record from the returned search result.
+
+        Does any unpacking or processing of the result to get just the record.
+
+        If the underlying SDK of the store returns a particular type that might include something
+        like a score or other metadata, this method should extract just the record.
+
+        Likely returns a dict, but in some cases could return the record in the form of an SDK-specific object.
+
+        This method is used as part of the _get_vector_search_results_from_results method,
+        the output of it is passed to the deserializer.
+        """
+        ...
+
+    @abstractmethod
+    def _get_score_from_result(self, result: Any) -> float | None:
+        """Get the score from the result.
+
+        Does any unpacking or processing of the result to get just the score.
+
+        If the underlying SDK of the store returns a particular type with a score or other metadata,
+        this method extracts it.
+        """
+        ...
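What a concrete connector supplies for these two hooks is usually small. A hypothetical sketch for a store whose SDK wraps each hit as `{"document": ..., "score": ...}`; that wrapper shape is invented for illustration and does not correspond to a specific connector:

```python
from typing import Any


class ExampleResultHandling:
    """Illustrative only: the shape of the two extraction hooks."""

    def _get_record_from_result(self, result: Any) -> Any:
        # Hand only the stored record to the deserializer.
        return result["document"]

    def _get_score_from_result(self, result: Any) -> float | None:
        # Surface the store's relevance score, if any.
        return result.get("score")
```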
+
+    async def _get_vector_search_results_from_results(
+        self, results: AsyncIterable[Any] | Sequence[Any], options: VectorSearchOptions | None = None
+    ) -> AsyncIterable[VectorSearchResult[TModel]]:
+        if isinstance(results, Sequence):
+            results = desync_list(results)
+        async for result in results:
+            if not result:
+                continue
+            try:
+                record = self.deserialize(
+                    self._get_record_from_result(result), include_vectors=options.include_vectors if options else True
+                )
+            except VectorStoreModelDeserializationException:
+                raise
+            except Exception as exc:
+                raise VectorStoreModelDeserializationException(
+                    f"An error occurred while deserializing the record: {exc}"
+                ) from exc
+            score = self._get_score_from_result(result)
+            if record is not None:
+                # single records are always returned as single records by the deserializer
+                yield VectorSearchResult(record=record, score=score)  # type: ignore
+
+    @overload
+    async def search(
+        self,
+        values: Any,
+        *,
+        vector_property_name: str | None = None,
+        filter: OptionalOneOrList[Callable | str] = None,
+        top: int = 3,
+        skip: int = 0,
+        include_total_count: bool = False,
+        include_vectors: bool = False,
+        **kwargs: Any,
+    ) -> KernelSearchResults[VectorSearchResult[TModel]]:
+        """Search the vector store with vector search for records that match the given value and filter.
+
+        Args:
+            values: The values to search for. These will be vectorized,
+                either by the store or using the provided generator.
+            vector_property_name: The name of the vector property to use for the search.
+            filter: The filter to apply to the search.
+            top: The number of results to return.
+            skip: The number of results to skip.
+            include_total_count: Whether to include the total count of results.
+            include_vectors: Whether to include the vectors in the results.
+            kwargs: If options are not set, this is used to create them.
+                They are passed on to the inner search method.
+
+        Raises:
+            VectorSearchExecutionException: If an error occurs during the search.
+            VectorStoreModelDeserializationException: If an error occurs during deserialization.
+            VectorSearchOptionsException: If the search options are invalid.
+            VectorStoreOperationNotSupportedException: If the search type is not supported.
+
+        """
+        ...
+
+    @overload
+    async def search(
+        self,
+        *,
+        vector: Sequence[float | int],
+        vector_property_name: str | None = None,
+        filter: OptionalOneOrList[Callable | str] = None,
+        top: int = 3,
+        skip: int = 0,
+        include_total_count: bool = False,
+        include_vectors: bool = False,
+        **kwargs: Any,
+    ) -> KernelSearchResults[VectorSearchResult[TModel]]:
+        """Search the vector store with vector search for records that match the given vector and filter.
+
+        Args:
+            vector: The vector to search for.
+            vector_property_name: The name of the vector property to use for the search.
+            filter: The filter to apply to the search.
+            top: The number of results to return.
+            skip: The number of results to skip.
+            include_total_count: Whether to include the total count of results.
+            include_vectors: Whether to include the vectors in the results.
+            kwargs: If options are not set, this is used to create them.
+                They are passed on to the inner search method.
+
+        Raises:
+            VectorSearchExecutionException: If an error occurs during the search.
+            VectorStoreModelDeserializationException: If an error occurs during deserialization.
+            VectorSearchOptionsException: If the search options are invalid.
+            VectorStoreOperationNotSupportedException: If the search type is not supported.
+
+        """
+        ...
+ + async def search( + self, + values=None, + *, + vector=None, + vector_property_name=None, + filter=None, + top=3, + skip=0, + include_total_count=False, + include_vectors=False, + **kwargs, + ): + """Search the vector store for records that match the given value and filter. + + Args: + values: The values to search for. + vector: The vector to search for, if not provided, the values will be used to generate a vector. + vector_property_name: The name of the vector property to use for the search. + filter: The filter to apply to the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + include_vectors: Whether to include the vectors in the results. + kwargs: If options are not set, this is used to create them. + they are passed on to the inner search method. + + Raises: + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreOperationNotSupportedException: If the search type is not supported. + + """ + if SearchType.VECTOR not in self.supported_search_types: + raise VectorStoreOperationNotSupportedException( + f"Vector search is not supported by this vector store: {self.__class__.__name__}" + ) + options = VectorSearchOptions( + filter=filter, + vector_property_name=vector_property_name, + top=top, + skip=skip, + include_total_count=include_total_count, + include_vectors=include_vectors, + ) + try: + return await self._inner_search( + search_type=SearchType.VECTOR, + values=values, + options=options, + vector=vector, + **kwargs, + ) + except ( + VectorStoreModelDeserializationException, + VectorSearchOptionsException, + VectorSearchExecutionException, + VectorStoreOperationNotSupportedException, + VectorStoreOperationException, + ): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc + + async def hybrid_search( + self, + values: Any, + *, + vector: list[float | int] | None = None, + vector_property_name: str | None = None, + additional_property_name: str | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 3, + skip: int = 0, + include_total_count: bool = False, + include_vectors: bool = False, + **kwargs: Any, + ) -> KernelSearchResults[VectorSearchResult[TModel]]: + """Search the vector store for records that match the given values and filter. + + Args: + values: The values to search for. + vector: The vector to search for, if not provided, the values will be used to generate a vector. + vector_property_name: The name of the vector field to use for the search. + additional_property_name: The name of the additional property field to use for the search. + filter: The filter to apply to the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + include_vectors: Whether to include the vectors in the results. + kwargs: If options are not set, this is used to create them. + they are passed on to the inner search method. + + Raises: + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. 
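A usage sketch of the `search` method above, assuming `collection` supports vector search and that either the store or a configured embedding generator vectorizes the text values; the filter field is illustrative:

```python
async def search_examples(collection) -> None:
    # Text values are vectorized by the store or by the configured embedding generator.
    results = await collection.search(
        "what is semantic kernel?",
        top=3,
        filter=lambda x: x.tag == "docs",
    )
    async for result in results.results:
        print(result.score, result.record)

    # A precomputed vector can be passed instead of text values.
    results = await collection.search(vector=[0.1, 0.2, 0.3], top=3)
```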
+ VectorStoreOperationNotSupportedException: If the search type is not supported. + + """ + if SearchType.KEYWORD_HYBRID not in self.supported_search_types: + raise VectorStoreOperationNotSupportedException( + f"Keyword hybrid search is not supported by this vector store: {self.__class__.__name__}" + ) + options = VectorSearchOptions( + filter=filter, + vector_property_name=vector_property_name, + additional_property_name=additional_property_name, + top=top, + skip=skip, + include_total_count=include_total_count, + include_vectors=include_vectors, + ) + try: + return await self._inner_search( + search_type=SearchType.KEYWORD_HYBRID, + values=values, + vector=vector, + options=options, + **kwargs, + ) + except ( + VectorStoreModelDeserializationException, + VectorSearchOptionsException, + VectorSearchExecutionException, + VectorStoreOperationNotSupportedException, + VectorStoreOperationException, + ): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc + + async def _generate_vector_from_values( + self, + values: Any | None, + options: VectorSearchOptions, + ) -> Sequence[float | int] | None: + """Generate a vector from the given keywords.""" + if values is None: + return None + vector_field = self.definition.try_get_vector_field(options.vector_property_name) + if not vector_field: + raise VectorSearchOptionsException( + f"Vector field '{options.vector_property_name}' not found in data model definition." + ) + embedding_generator = ( + vector_field.embedding_generator if vector_field.embedding_generator else self.embedding_generator + ) + if not embedding_generator: + raise VectorSearchOptionsException( + f"Embedding generator not found for vector field '{options.vector_property_name}'." + ) + + return ( + await embedding_generator.generate_embeddings( + # TODO (eavanvalkenburg): this only deals with string values, should support other types as well + # but that requires work on the embedding generators first. + texts=[values if isinstance(values, str) else json.dumps(values)], + settings=PromptExecutionSettings(dimensions=vector_field.dimensions), + ) + )[0].tolist() + + def _build_filter(self, search_filter: OptionalOneOrMany[Callable | str] | None) -> OptionalOneOrMany[Any]: + """Create the filter based on the filters. + + This function returns None, a single filter, or a list of filters. + If a single filter is passed, a single filter is returned. + + It takes the filters, which can be a Callable (lambda) or a string, and parses them into a filter object, + using the _lambda_parser method that is specific to each vector store. + + If a list of filters, is passed, the parsed filters are also returned as a list, so the caller needs to + combine them in the appropriate way. 
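A corresponding sketch for `hybrid_search`; the field names are illustrative. Note from `_generate_vector_from_values` above that a field-level embedding generator, when set, takes precedence over the collection-level one.

```python
async def hybrid_example(collection) -> None:
    # Combines keyword matching on the additional property with vector similarity.
    results = await collection.hybrid_search(
        "vector databases",
        vector_property_name="embedding",
        additional_property_name="content",
        top=5,
    )
    async for result in results.results:
        print(result.score, result.record)
```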
+ + Often called like this (when filters are strings): + ```python + if filter := self._build_filter(options.filter): + search_args["filter"] = filter if isinstance(filter, str) else " and ".join(filter) + ``` + """ + if not search_filter: + return None + + filters = search_filter if isinstance(search_filter, list) else [search_filter] + + created_filters: list[Any] = [] + + visitor = LambdaVisitor(self._lambda_parser) + for filter_ in filters: + # parse lambda expression with AST + tree = parse(filter_ if isinstance(filter_, str) else getsource(filter_).strip()) + visitor.visit(tree) + created_filters = visitor.output_filters + if len(created_filters) == 0: + raise VectorStoreOperationException("No filter strings found.") + if len(created_filters) == 1: + return created_filters[0] + return created_filters + + @abstractmethod + def _lambda_parser(self, node: AST) -> Any: + """Parse the lambda expression and return the filter string. + + This follows from the ast specs: https://docs.python.org/3/library/ast.html + """ + # This method should be implemented in the derived class + # to parse the lambda expression and return the filter string. + pass + + def create_search_function( + self, + function_name: str = DEFAULT_FUNCTION_NAME, + description: str = DEFAULT_DESCRIPTION, + *, + search_type: Literal["vector", "keyword_hybrid"] = "vector", + parameters: list[KernelParameterMetadata] | None = None, + return_parameter: KernelParameterMetadata | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 5, + skip: int = 0, + vector_property_name: str | None = None, + additional_property_name: str | None = None, + include_vectors: bool = False, + include_total_count: bool = False, + filter_update_function: DynamicFilterFunction | None = None, + string_mapper: Callable[[VectorSearchResult[TModel]], str] | None = None, + ) -> KernelFunction: + """Create a kernel function from a search function. + + Args: + function_name: The name of the function, to be used in the kernel, default is "search". + description: The description of the function, a default is provided. + search_type: The type of search to perform, can be 'vector' or 'keyword_hybrid'. + parameters: The parameters for the function, + use an empty list for a function without parameters, + use None for the default set, which is "query", "top", and "skip". + return_parameter: The return parameter for the function. + filter: The filter to apply to the search. + top: The number of results to return. + skip: The number of results to skip. + vector_property_name: The name of the vector property to use for the search. + additional_property_name: The name of the additional property field to use for the search. + include_vectors: Whether to include the vectors in the results. + include_total_count: Whether to include the total count of results. + filter_update_function: A function to update the filters. + The function should return the updated filter. + The default function uses the parameters and the kwargs to update the filters, it + adds equal to filters to the options for all parameters that are not "query". + As well as adding equal to filters for parameters that have a default value. + string_mapper: The function to map the search results to strings. 
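The filter machinery above accepts a live lambda, the lambda's source as a string, or a list of either; how much of an expression a connector can translate depends on its `_lambda_parser`. A short sketch with illustrative field names:

```python
async def filter_examples(collection) -> None:
    # A live lambda, parsed from its source via ast.parse(getsource(...)).
    await collection.search("query", filter=lambda x: x.tag == "value")

    # The same filter supplied as source text.
    await collection.search("query", filter='lambda x: x.tag == "value"')

    # A list of filters; the connector combines the parsed parts itself.
    await collection.search("query", filter=[lambda x: x.tag == "value", lambda x: x.year >= 2024])
```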
+ """ + search_types = SearchType(search_type) + if search_types not in self.supported_search_types: + raise VectorStoreOperationNotSupportedException( + f"Search type '{search_types.value}' is not supported by this vector store: {self.__class__.__name__}" + ) options = VectorSearchOptions( filter=filter, skip=skip, @@ -1479,166 +2098,306 @@ def create_search_function( string_mapper=string_mapper, ) - def _create_kernel_function( + def _create_kernel_function( + self, + search_type: SearchType, + options: SearchOptions | None = None, + parameters: list[KernelParameterMetadata] | None = None, + filter_update_function: DynamicFilterFunction | None = None, + return_parameter: KernelParameterMetadata | None = None, + function_name: str = DEFAULT_FUNCTION_NAME, + description: str = DEFAULT_DESCRIPTION, + string_mapper: Callable[[VectorSearchResult[TModel]], str] | None = None, + ) -> KernelFunction: + """Create a kernel function from a search function.""" + update_func = filter_update_function or default_dynamic_filter_function + + @kernel_function(name=function_name, description=description) + async def search_wrapper(**kwargs: Any) -> Sequence[str]: + query = kwargs.pop("query", "") + try: + inner_options = create_options(self.options_class, deepcopy(options), **kwargs) + except ValidationError: + # this usually only happens when the kwargs are invalid, so blank options in this case. + inner_options = self.options_class() + inner_options.filter = update_func(filter=inner_options.filter, parameters=parameters, **kwargs) + match search_type: + case SearchType.VECTOR: + try: + results = await self.search( + values=query, + **inner_options.model_dump(exclude_defaults=True, exclude_none=True), + ) + except Exception as e: + msg = f"Exception in search function: {e}" + logger.error(msg) + raise TextSearchException(msg) from e + case SearchType.KEYWORD_HYBRID: + try: + results = await self.hybrid_search( + values=query, + **inner_options.model_dump(exclude_defaults=True, exclude_none=True), + ) + except Exception as e: + msg = f"Exception in hybrid search function: {e}" + logger.error(msg) + raise TextSearchException(msg) from e + case _: + raise VectorStoreOperationNotSupportedException( + f"Search type '{search_type}' is not supported by this vector store: {self.__class__.__name__}" + ) + if string_mapper: + return [string_mapper(result) async for result in results.results] + return [result.model_dump_json(exclude_none=True) async for result in results.results] + + return KernelFunctionFromMethod( + method=search_wrapper, + parameters=DEFAULT_PARAMETER_METADATA if parameters is None else parameters, + return_parameter=return_parameter or DEFAULT_RETURN_PARAMETER_METADATA, + ) + + +@runtime_checkable +class VectorStoreCollectionProtocol(Protocol): # noqa: D101 + collection_name: str + record_type: object + definition: VectorStoreCollectionDefinition + supported_key_types: ClassVar[set[str]] + supported_vector_types: ClassVar[set[str]] + embedding_generator: EmbeddingGeneratorBase | None = None + + async def ensure_collection_exists(self, **kwargs: Any) -> bool: + """Create the collection in the service if it does not exists. + + First uses does_collection_exist to check if it exists, if it does returns False. + Otherwise, creates the collection and returns True. + + Args: + **kwargs: Additional arguments. + + Returns: + bool: True if the collection was created, False if it already exists. + """ + ... 
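A hedged sketch of wiring the resulting `KernelFunction` into a kernel, assuming a searchable collection whose record model exposes a `content` field; the plugin, function, and query names are illustrative:

```python
from semantic_kernel import Kernel
from semantic_kernel.functions import KernelArguments


async def register_search(kernel: Kernel, collection) -> None:
    search_fn = collection.create_search_function(
        function_name="search_docs",
        description="Searches the documentation collection.",
        search_type="vector",
        top=5,
        string_mapper=lambda result: result.record.content,
    )
    # Make the function callable from prompts and function calling.
    kernel.add_function(plugin_name="memory", function=search_fn)
    answer = await search_fn.invoke(kernel, KernelArguments(query="How do I use Postgres as memory?"))
    print(answer)
```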
+ + async def create_collection(self, **kwargs: Any) -> None: + """Create the collection in the service. + + Args: + **kwargs: Additional arguments. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + """ + ... + + async def does_collection_exist(self, **kwargs: Any) -> bool: + """Check if the collection exists. + + Args: + **kwargs: Additional arguments. + + Returns: + bool: True if the collection exists, False otherwise. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + """ + ... + + async def ensure_collection_deleted(self, **kwargs: Any) -> None: + """Delete the collection. + + Args: + **kwargs: Additional arguments. + """ + ... + + async def get( + self, + key: Any = None, + keys: Sequence[Any] | None = None, + include_vectors: bool = False, + top: int | None = None, + skip: int | None = None, + order_by: OneOrMany[str] | dict[str, bool] | None = None, + **kwargs: Any, + ) -> OptionalOneOrList[Any]: + """Get a batch of records whose keys exist in the collection, i.e. keys that do not exist are ignored. + + Args: + key: The key to get. + keys: The keys to get, if keys are provided, key is ignored. + include_vectors: Include the vectors in the response. Default is False. + Some vector stores do not support retrieving without vectors, even when set to false. + Some vector stores have specific parameters to control that behavior, when + that parameter is set, include_vectors is ignored. + top: The number of records to return. + Only used if keys are not provided. + skip: The number of records to skip. + Only used if keys are not provided. + order_by: The order by clause, + this can be a string, a list of strings or a dict, + when passing strings, they are assumed to be ascending. + Otherwise, use the value in the dict to set ascending (True) or descending (False). + example: {"field_name": True} or ["field_name", {"field_name2": False}]. + **kwargs: Additional arguments. + + Returns: + The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + ... + + async def upsert( self, - search_type: SearchType, - options: SearchOptions | None = None, - parameters: list[KernelParameterMetadata] | None = None, - filter_update_function: DynamicFilterFunction | None = None, - return_parameter: KernelParameterMetadata | None = None, - function_name: str = DEFAULT_FUNCTION_NAME, - description: str = DEFAULT_DESCRIPTION, - string_mapper: Callable[[VectorSearchResult[TModel]], str] | None = None, - ) -> KernelFunction: - """Create a kernel function from a search function.""" - update_func = filter_update_function or default_dynamic_filter_function + records: OneOrMany[Any], + **kwargs: Any, + ) -> OneOrMany[Any]: + """Upsert one or more records. - @kernel_function(name=function_name, description=description) - async def search_wrapper(**kwargs: Any) -> Sequence[str]: - query = kwargs.pop("query", "") - try: - inner_options = create_options(self.options_class, deepcopy(options), **kwargs) - except ValidationError: - # this usually only happens when the kwargs are invalid, so blank options in this case. 
- inner_options = self.options_class() - inner_options.filter = update_func(filter=inner_options.filter, parameters=parameters, **kwargs) - match search_type: - case SearchType.VECTOR: - try: - results = await self.search( - values=query, - **inner_options.model_dump(exclude_defaults=True, exclude_none=True), - ) - except Exception as e: - msg = f"Exception in search function: {e}" - logger.error(msg) - raise TextSearchException(msg) from e - case SearchType.KEYWORD_HYBRID: - try: - results = await self.hybrid_search( - values=query, - **inner_options.model_dump(exclude_defaults=True, exclude_none=True), - ) - except Exception as e: - msg = f"Exception in hybrid search function: {e}" - logger.error(msg) - raise TextSearchException(msg) from e - case _: - raise VectorStoreOperationNotSupportedException( - f"Search type '{search_type}' is not supported by this vector store: {self.__class__.__name__}" - ) - if string_mapper: - return [string_mapper(result) async for result in results.results] - return [result.model_dump_json(exclude_none=True) async for result in results.results] + If the key of the record already exists, the existing record will be updated. + If the key does not exist, a new record will be created. - return KernelFunctionFromMethod( - method=search_wrapper, - parameters=TextSearch._default_parameter_metadata() if parameters is None else parameters, - return_parameter=return_parameter or TextSearch._default_return_parameter_metadata(), - ) + Args: + records: The records to upsert, can be a single record, a list of records, or a single container. + If a single record is passed, a single key is returned, instead of a list of keys. + **kwargs: Additional arguments. + Returns: + OneOrMany[Any]: The keys of the upserted records. -class IndexKind(str, Enum): - """Index kinds for similarity search. + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. + VectorStoreOperationException: If an error occurs during upserting. + """ + ... - HNSW - Hierarchical Navigable Small World which performs an approximate nearest neighbor (ANN) search. - Lower accuracy than exhaustive k nearest neighbor, but faster and more efficient. + async def delete(self, keys: OneOrMany[Any], **kwargs: Any) -> None: + """Delete one or more records by key. - Flat - Does a brute force search to find the nearest neighbors. - Calculates the distances between all pairs of data points, so has a linear time complexity, - that grows directly proportional to the number of points. - Also referred to as exhaustive k nearest neighbor in some databases. - High recall accuracy, but slower and more expensive than HNSW. - Better with smaller datasets. + An exception will be raised at the end if any record does not exist. - IVF Flat - Inverted File with Flat Compression. - Designed to enhance search efficiency by narrowing the search area - through the use of neighbor partitions or clusters. - Also referred to as approximate nearest neighbor (ANN) search. + Args: + keys: The key or keys to be deleted. + **kwargs: Additional arguments. - Disk ANN - Disk-based Approximate Nearest Neighbor algorithm designed for efficiently searching - for approximate nearest neighbors (ANN) in high-dimensional spaces. - The primary focus of DiskANN is to handle large-scale datasets that cannot fit entirely - into memory, leveraging disk storage to store the data while maintaining fast search times. + Raises: + VectorStoreOperationException: If an error occurs during deletion or a record does not exist. 
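Because these protocols are `runtime_checkable`, a caller can feature-test a collection before using it. A sketch of the documented lifecycle, assuming a dict-based record definition; the record contents are illustrative:

```python
from semantic_kernel.data.vectors import VectorStoreCollectionProtocol


async def lifecycle_example(collection) -> None:
    assert isinstance(collection, VectorStoreCollectionProtocol)

    await collection.ensure_collection_exists()
    # upsert returns the keys: a single record yields a single key, a list yields a list.
    keys = await collection.upsert([
        {"id": "1", "content": "first", "vector": [0.1, 0.2, 0.3]},
        {"id": "2", "content": "second", "vector": [0.2, 0.1, 0.3]},
    ])
    await collection.delete(keys)
    await collection.ensure_collection_deleted()
```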
+ """ + ... - Quantized Flat - Index that compresses vectors using DiskANN-based quantization methods for better efficiency in the kNN search. - Dynamic - Dynamic index allows to automatically switch from FLAT to HNSW indexes. +@runtime_checkable +class VectorSearchProtocol(VectorStoreCollectionProtocol, Protocol): + """Protocol to check that a collection supports vector search.""" - Default - Default index type. - Used when no index type is specified. - Will differ per vector store. + supported_search_types: ClassVar[set[SearchType]] - """ + async def search( + self, + values: Any = None, + *, + vector: Sequence[float | int] | None = None, + vector_property_name: str | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 3, + skip: int = 0, + include_total_count: bool = False, + include_vectors: bool = False, + **kwargs: Any, + ) -> KernelSearchResults[VectorSearchResult]: + """Search the vector store for records that match the given value and filter. - HNSW = "hnsw" - FLAT = "flat" - IVF_FLAT = "ivf_flat" - DISK_ANN = "disk_ann" - QUANTIZED_FLAT = "quantized_flat" - DYNAMIC = "dynamic" - DEFAULT = "default" + Args: + values: The values to search for. These will be vectorized, + either by the store or using the provided generator. + vector: The vector to search for, if not provided, the values will be used to generate a vector. + vector_property_name: The name of the vector property to use for the search. + filter: The filter to apply to the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + include_vectors: Whether to include the vectors in the results. + kwargs: If options are not set, this is used to create them. + they are passed on to the inner search method. + Returns: + The search results. -class DistanceFunction(str, Enum): - """Distance functions for similarity search. + Raises: + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreOperationNotSupportedException: If the search type is not supported. + """ + ... - Cosine Similarity - the cosine (angular) similarity between two vectors - measures only the angle between the two vectors, without taking into account the length of the vectors - Cosine Similarity = 1 - Cosine Distance - -1 means vectors are opposite - 0 means vectors are orthogonal - 1 means vectors are identical - Cosine Distance - the cosine (angular) distance between two vectors - measures only the angle between the two vectors, without taking into account the length of the vectors - Cosine Distance = 1 - Cosine Similarity - 2 means vectors are opposite - 1 means vectors are orthogonal - 0 means vectors are identical - Dot Product - measures both the length and angle between two vectors - same as cosine similarity if the vectors are the same length, but more performant - Euclidean Distance - measures the Euclidean distance between two vectors - also known as l2-norm - Euclidean Squared Distance - measures the Euclidean squared distance between two vectors - also known as l2-squared - Manhattan - measures the Manhattan distance between two vectors - Hamming - number of differences between vectors at each dimensions - DEFAULT - default distance function - used when no distance function is specified - will differ per vector store. 
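The same runtime check the unit tests below rely on can guard callers that receive an arbitrary collection:

```python
from semantic_kernel.data.vectors import SearchType, VectorSearchProtocol


async def guarded_search(collection) -> None:
    # Only call search when the collection actually implements vector search.
    if isinstance(collection, VectorSearchProtocol) and SearchType.VECTOR in collection.supported_search_types:
        results = await collection.search(vector=[1.0, 2.0, 3.0], top=1)
        async for result in results.results:
            print(result.record)
```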
- """ + async def hybrid_search( + self, + values: Any, + *, + vector: list[float | int] | None = None, + vector_property_name: str | None = None, + additional_property_name: str | None = None, + filter: OptionalOneOrList[Callable | str] = None, + top: int = 3, + skip: int = 0, + include_total_count: bool = False, + include_vectors: bool = False, + **kwargs: Any, + ) -> KernelSearchResults[VectorSearchResult]: + """Search the vector store for records that match the given values and filter using hybrid search. - COSINE_SIMILARITY = "cosine_similarity" - COSINE_DISTANCE = "cosine_distance" - DOT_PROD = "dot_prod" - EUCLIDEAN_DISTANCE = "euclidean_distance" - EUCLIDEAN_SQUARED_DISTANCE = "euclidean_squared_distance" - MANHATTAN = "manhattan" - HAMMING = "hamming" - DEFAULT = "DEFAULT" + Args: + values: The values to search for. + vector: The vector to search for, if not provided, the values will be used to generate a vector. + vector_property_name: The name of the vector field to use for the search. + additional_property_name: The name of the additional property field to use for the search. + filter: The filter to apply to the search. + top: The number of results to return. + skip: The number of results to skip. + include_total_count: Whether to include the total count of results. + include_vectors: Whether to include the vectors in the results. + kwargs: If options are not set, this is used to create them. + they are passed on to the inner search method. + Returns: + The search results. -DISTANCE_FUNCTION_DIRECTION_HELPER: Final[dict[DistanceFunction, Callable[[int | float, int | float], bool]]] = { - DistanceFunction.COSINE_SIMILARITY: operator.gt, - DistanceFunction.COSINE_DISTANCE: operator.le, - DistanceFunction.DOT_PROD: operator.gt, - DistanceFunction.EUCLIDEAN_DISTANCE: operator.le, - DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE: operator.le, - DistanceFunction.MANHATTAN: operator.le, - DistanceFunction.HAMMING: operator.le, -} + Raises: + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreOperationNotSupportedException: If the search type is not supported. + """ + ... 
+ + +__all__ = [ + "DEFAULT_DESCRIPTION", + "DEFAULT_FUNCTION_NAME", + "DEFAULT_PARAMETER_METADATA", + "DEFAULT_RETURN_PARAMETER_METADATA", + "DISTANCE_FUNCTION_DIRECTION_HELPER", + "DistanceFunction", + "DynamicFilterFunction", + "FieldTypes", + "IndexKind", + "KernelSearchResults", + "SearchType", + "VectorSearch", + "VectorSearchProtocol", + "VectorSearchResult", + "VectorStore", + "VectorStoreCollection", + "VectorStoreCollectionDefinition", + "VectorStoreCollectionProtocol", + "VectorStoreField", + "create_options", + "default_dynamic_filter_function", + "vectorstoremodel", +] diff --git a/python/tests/conftest.py b/python/tests/conftest.py index d7981f5f01b9..dfe935102687 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -12,7 +12,7 @@ from pytest import fixture from semantic_kernel.agents import Agent, DeclarativeSpecMixin, register_agent_type -from semantic_kernel.data._definitions import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel if TYPE_CHECKING: from semantic_kernel import Kernel @@ -380,7 +380,7 @@ def record_type(index_kind: str, distance_function: str, vector_property_type: s class DataModelClass(BaseModel): content: Annotated[str, VectorStoreField("data")] vector: Annotated[ - str | list[float] | None, + list[float] | str | None, VectorStoreField( "vector", type=vector_property_type, diff --git a/python/tests/integration/memory/azure_cosmos_db/conftest.py b/python/tests/integration/memory/azure_cosmos_db/conftest.py index c8d70bddac14..4162604a3076 100644 --- a/python/tests/integration/memory/azure_cosmos_db/conftest.py +++ b/python/tests/integration/memory/azure_cosmos_db/conftest.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pytest import fixture -from semantic_kernel.data._definitions import VectorStoreField, vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreField, vectorstoremodel @fixture diff --git a/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py b/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py index dae2bf9a9569..8b136a06bcf2 100644 --- a/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py +++ b/python/tests/integration/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql.py @@ -9,8 +9,8 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.partition_key import PartitionKey -from semantic_kernel.connectors.memory.azure_cosmos_db import AzureCosmosDBNoSQLCompositeKey, CosmosNoSqlStore -from semantic_kernel.data._vectors import VectorStore +from semantic_kernel.connectors.memory.azure_cosmos_db import CosmosNoSqlCompositeKey, CosmosNoSqlStore +from semantic_kernel.data.vectors import VectorStore from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException from tests.integration.memory.vector_store_test_base import VectorStoreTestBase @@ -90,9 +90,7 @@ async def test_custom_partition_key( partition_key=PartitionKey(path="/product_type"), ) - composite_key = AzureCosmosDBNoSQLCompositeKey( - key=data_record["id"], partition_key=data_record["product_type"] - ) + composite_key = CosmosNoSqlCompositeKey(key=data_record["id"], partition_key=data_record["product_type"]) # Upsert await collection.create_collection() diff --git a/python/tests/integration/memory/postgres/test_postgres_int.py b/python/tests/integration/memory/postgres/test_postgres_int.py index 
4784b88fd3f3..566f8b3ad937 100644 --- a/python/tests/integration/memory/postgres/test_postgres_int.py +++ b/python/tests/integration/memory/postgres/test_postgres_int.py @@ -11,9 +11,14 @@ from pydantic import BaseModel from semantic_kernel.connectors.memory.postgres import PostgresCollection, PostgresSettings, PostgresStore -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data._definitions import vectorstoremodel -from semantic_kernel.data._vectors import DistanceFunction, IndexKind, VectorSearchOptions +from semantic_kernel.data.vectors import ( + DistanceFunction, + IndexKind, + VectorSearchOptions, + VectorStoreCollectionDefinition, + VectorStoreField, + vectorstoremodel, +) from semantic_kernel.exceptions.memory_connector_exceptions import ( MemoryConnectorConnectionException, MemoryConnectorInitializationError, diff --git a/python/tests/integration/memory/test_vector_store.py b/python/tests/integration/memory/test_vector_store.py index 84a6dc69d065..e4312f5ecbb9 100644 --- a/python/tests/integration/memory/test_vector_store.py +++ b/python/tests/integration/memory/test_vector_store.py @@ -9,7 +9,7 @@ import pytest from semantic_kernel.connectors.memory.redis import RedisCollectionTypes -from semantic_kernel.data import VectorStore +from semantic_kernel.data.vectors import VectorStore from semantic_kernel.exceptions import MemoryConnectorConnectionException from tests.integration.memory.data_records import RAW_RECORD_ARRAY, RAW_RECORD_LIST from tests.integration.memory.vector_store_test_base import VectorStoreTestBase diff --git a/python/tests/integration/memory/vector_store_test_base.py b/python/tests/integration/memory/vector_store_test_base.py index 3b07392f9344..c383bc65d2c6 100644 --- a/python/tests/integration/memory/vector_store_test_base.py +++ b/python/tests/integration/memory/vector_store_test_base.py @@ -4,7 +4,7 @@ import pytest -from semantic_kernel.data import VectorStore +from semantic_kernel.data.vectors import VectorStore def get_redis_store(): diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_mongodb_collection.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_mongodb_collection.py index 04d8b412d979..403c95d54fdc 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_mongodb_collection.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_mongodb_collection.py @@ -6,7 +6,7 @@ from pymongo import AsyncMongoClient from semantic_kernel.connectors.memory.azure_cosmos_db import CosmosMongoCollection -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data.vectors import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.exceptions import VectorStoreInitializationException diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py index 6242e212eca0..aff9f1477cb9 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py @@ -13,8 +13,11 @@ _create_default_indexing_policy_nosql, _create_default_vector_embedding_policy, ) -from semantic_kernel.exceptions import VectorStoreInitializationException -from 
semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelException, VectorStoreOperationException +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, + VectorStoreModelException, + VectorStoreOperationException, +) def test_azure_cosmos_db_no_sql_collection_init( @@ -300,7 +303,6 @@ async def test_azure_cosmos_db_no_sql_collection_create_collection_allow_custom_ [ ("hnsw", "cosine_similarity", "float"), # unsupported index kind ("flat", "hamming", "float"), # unsupported distance function - ("flat", "cosine_similarity", "double"), # unsupported property type ], ) async def test_azure_cosmos_db_no_sql_collection_create_collection_unsupported_vector_field_property( diff --git a/python/tests/unit/connectors/memory/test_azure_ai_search.py b/python/tests/unit/connectors/memory/test_azure_ai_search.py index 31cb21af436d..a8dbe2bb7e9f 100644 --- a/python/tests/unit/connectors/memory/test_azure_ai_search.py +++ b/python/tests/unit/connectors/memory/test_azure_ai_search.py @@ -201,11 +201,13 @@ async def test_get(collection, mock_get): @mark.parametrize( "order_by, ordering", [ - param({"field": "id"}, ["id"], id="single id"), - param({"field": "id", "ascending": True}, ["id"], id="ascending id"), - param({"field": "id", "ascending": False}, ["id desc"], id="descending id"), - param([{"field": "id", "ascending": True}], ["id"], id="ascending id list"), - param([{"field": "id"}, {"field": "content"}], ["id", "content"], id="multiple"), + param("id", ["id"], id="single id"), + param({"id": True}, ["id"], id="ascending id"), + param({"id": False}, ["id desc"], id="descending id"), + param(["id"], ["id"], id="ascending id list"), + param(["id", "content"], ["id", "content"], id="multiple"), + param([{"id": True}, {"content": False}], ["id", "content desc"], id="multiple desc"), + param(["id", {"content": False}], ["id", "content desc"], id="multiple mix"), ], ) async def test_get_without_key(collection, mock_get, mock_search, order_by, ordering): diff --git a/python/tests/unit/connectors/memory/test_faiss.py b/python/tests/unit/connectors/memory/test_faiss.py index 5b2f91c6673b..af1147279dd4 100644 --- a/python/tests/unit/connectors/memory/test_faiss.py +++ b/python/tests/unit/connectors/memory/test_faiss.py @@ -4,8 +4,7 @@ from pytest import fixture, mark, raises from semantic_kernel.connectors.memory.faiss import FaissCollection, FaissStore -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data._vectors import DistanceFunction +from semantic_kernel.data.vectors import DistanceFunction, VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.exceptions import VectorStoreInitializationException diff --git a/python/tests/unit/connectors/memory/test_in_memory.py b/python/tests/unit/connectors/memory/test_in_memory.py index 0817bbb24800..4706a08bb3f5 100644 --- a/python/tests/unit/connectors/memory/test_in_memory.py +++ b/python/tests/unit/connectors/memory/test_in_memory.py @@ -3,7 +3,7 @@ from pytest import fixture, mark, raises from semantic_kernel.connectors.memory.in_memory import InMemoryCollection, InMemoryStore -from semantic_kernel.data._vectors import DistanceFunction +from semantic_kernel.data.vectors import DistanceFunction from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreOperationException diff --git a/python/tests/unit/connectors/memory/test_postgres_store.py b/python/tests/unit/connectors/memory/test_postgres_store.py index f9d95c4524bd..889465d67768 
100644 --- a/python/tests/unit/connectors/memory/test_postgres_store.py +++ b/python/tests/unit/connectors/memory/test_postgres_store.py @@ -17,8 +17,7 @@ PostgresSettings, PostgresStore, ) -from semantic_kernel.data._definitions import VectorStoreField, vectorstoremodel -from semantic_kernel.data._vectors import DistanceFunction, IndexKind +from semantic_kernel.data.vectors import DistanceFunction, IndexKind, VectorStoreField, vectorstoremodel @fixture(scope="function") diff --git a/python/tests/unit/connectors/memory/test_qdrant.py b/python/tests/unit/connectors/memory/test_qdrant.py index 8eac43b522d4..d0b71ea3a27f 100644 --- a/python/tests/unit/connectors/memory/test_qdrant.py +++ b/python/tests/unit/connectors/memory/test_qdrant.py @@ -7,8 +7,7 @@ from qdrant_client.models import Datatype, Distance, FieldCondition, MatchValue, VectorParams from semantic_kernel.connectors.memory.qdrant import QdrantCollection, QdrantStore -from semantic_kernel.data._definitions import VectorStoreField -from semantic_kernel.data._vectors import DistanceFunction +from semantic_kernel.data.vectors import DistanceFunction, VectorStoreField from semantic_kernel.exceptions import ( VectorSearchExecutionException, VectorStoreInitializationException, diff --git a/python/tests/unit/connectors/memory/test_sql_server.py b/python/tests/unit/connectors/memory/test_sql_server.py index 18508b4dacb4..674b957b29fb 100644 --- a/python/tests/unit/connectors/memory/test_sql_server.py +++ b/python/tests/unit/connectors/memory/test_sql_server.py @@ -21,8 +21,7 @@ _build_select_query, _build_select_table_names_query, ) -from semantic_kernel.data._definitions import VectorStoreField -from semantic_kernel.data._vectors import DistanceFunction, IndexKind, VectorSearchOptions +from semantic_kernel.data.vectors import DistanceFunction, IndexKind, VectorSearchOptions, VectorStoreField from semantic_kernel.exceptions.vector_store_exceptions import ( VectorStoreInitializationException, VectorStoreOperationException, diff --git a/python/tests/unit/connectors/search/test_brave_search.py b/python/tests/unit/connectors/search/test_brave_search.py index 542616e18ad1..59a6b4db3da3 100644 --- a/python/tests/unit/connectors/search/test_brave_search.py +++ b/python/tests/unit/connectors/search/test_brave_search.py @@ -6,7 +6,8 @@ import pytest from semantic_kernel.connectors.search.brave import BraveSearch, BraveSearchResponse, BraveWebPage, BraveWebPages -from semantic_kernel.data._search import KernelSearchResults, SearchOptions, TextSearchResult +from semantic_kernel.data._search import KernelSearchResults, SearchOptions +from semantic_kernel.data.text_search import TextSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError diff --git a/python/tests/unit/connectors/search/test_google_search.py b/python/tests/unit/connectors/search/test_google_search.py index 2a92bd70e1ab..d61d61bd76e9 100644 --- a/python/tests/unit/connectors/search/test_google_search.py +++ b/python/tests/unit/connectors/search/test_google_search.py @@ -11,7 +11,7 @@ GoogleSearchResponse, GoogleSearchResult, ) -from semantic_kernel.data._search import TextSearchResult +from semantic_kernel.data.text_search import TextSearchResult from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError diff --git a/python/tests/unit/data/conftest.py b/python/tests/unit/data/conftest.py index 654fb93838b4..5d9dd478ddd1 100644 --- a/python/tests/unit/data/conftest.py +++ 
b/python/tests/unit/data/conftest.py @@ -10,23 +10,26 @@ from pydantic import BaseModel, Field from pytest import fixture -from semantic_kernel.data import ( +from semantic_kernel.data.vectors import ( KernelSearchResults, + SearchType, + VectorSearch, + VectorSearchResult, + VectorStoreCollection, VectorStoreCollectionDefinition, - VectorStoreRecordCollection, + VectorStoreField, vectorstoremodel, ) -from semantic_kernel.data._definitions import VectorStoreField -from semantic_kernel.data._vectors import VectorSearch, VectorSearchResult from semantic_kernel.kernel_types import OptionalOneOrMany @fixture def DictVectorStoreRecordCollection() -> type[VectorSearch]: class DictVectorStoreRecordCollection( - VectorStoreRecordCollection[str, Any], + VectorStoreCollection[str, Any], VectorSearch[str, Any], ): + supported_search_types = {SearchType.VECTOR} inner_storage: dict[str, Any] = Field(default_factory=dict) async def _inner_delete(self, keys: Sequence[str], **kwargs: Any) -> None: diff --git a/python/tests/unit/data/test_filter.py b/python/tests/unit/data/test_filter.py deleted file mode 100644 index fd93c8d92d1e..000000000000 --- a/python/tests/unit/data/test_filter.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -from semantic_kernel.data._vectors import VectorSearchOptions - - -def test_lambda_filter(): - options = VectorSearchOptions(filter=lambda x: x.tag == "value") - assert options.filter is not None - - -def test_lambda_filter_str(): - options = VectorSearchOptions(filter='lambda x: x.tag == "value"') - assert options.filter is not None diff --git a/python/tests/unit/data/test_text_search.py b/python/tests/unit/data/test_text_search.py index ad7d6477f7fe..fb8cf0acc738 100644 --- a/python/tests/unit/data/test_text_search.py +++ b/python/tests/unit/data/test_text_search.py @@ -8,16 +8,16 @@ from pydantic import BaseModel from semantic_kernel import Kernel -from semantic_kernel.data import TextSearch -from semantic_kernel.data._search import ( +from semantic_kernel.data.text_search import ( DEFAULT_DESCRIPTION, DEFAULT_FUNCTION_NAME, KernelSearchResults, SearchOptions, + TextSearch, TextSearchResult, create_options, ) -from semantic_kernel.data._vectors import VectorSearchOptions +from semantic_kernel.data.vectors import VectorSearchOptions from semantic_kernel.exceptions import TextSearchException from semantic_kernel.functions import KernelArguments, KernelParameterMetadata from semantic_kernel.utils.list_handler import desync_list diff --git a/python/tests/unit/data/test_vector_search_base.py b/python/tests/unit/data/test_vector_search_base.py index e43b2a1d17f3..a4c3d5de817d 100644 --- a/python/tests/unit/data/test_vector_search_base.py +++ b/python/tests/unit/data/test_vector_search_base.py @@ -3,13 +3,14 @@ import pytest -from semantic_kernel.data._vectors import VectorSearch, VectorSearchOptions +from semantic_kernel.data.vectors import VectorSearch, VectorSearchOptions, VectorSearchProtocol async def test_search(vector_store_record_collection: VectorSearch): + assert isinstance(vector_store_record_collection, VectorSearchProtocol) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} await vector_store_record_collection.upsert(record) - results = await vector_store_record_collection._inner_search(options=VectorSearchOptions(), keywords="test_content") + results = await vector_store_record_collection.search(vector=[1.0, 2.0, 3.0]) records = [rec async for rec in results.results] assert records[0].record == record 
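The test updates above all converge on the single public import path `semantic_kernel.data.vectors`. For reference, the same consolidated imports in user code; the model itself is illustrative:

```python
from typing import Annotated

from pydantic import BaseModel

from semantic_kernel.data.vectors import VectorStoreField, vectorstoremodel


@vectorstoremodel
class Snippet(BaseModel):
    id: Annotated[str, VectorStoreField("key")]
    content: Annotated[str, VectorStoreField("data")]
    vector: Annotated[list[float] | None, VectorStoreField("vector", dimensions=3, type="float")] = None
```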
diff --git a/python/tests/unit/data/test_vector_store_model_decorator.py b/python/tests/unit/data/test_vector_store_model_decorator.py index eafb71517307..23ce3ceab5f2 100644 --- a/python/tests/unit/data/test_vector_store_model_decorator.py +++ b/python/tests/unit/data/test_vector_store_model_decorator.py @@ -9,8 +9,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from pytest import raises -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField -from semantic_kernel.data._definitions import vectorstoremodel +from semantic_kernel.data.vectors import VectorStoreCollectionDefinition, VectorStoreField, vectorstoremodel from semantic_kernel.exceptions import VectorStoreModelException diff --git a/python/tests/unit/data/test_vector_store_record_collection.py b/python/tests/unit/data/test_vector_store_record_collection.py index fb85118491ea..ff6f1dcffe0a 100644 --- a/python/tests/unit/data/test_vector_store_record_collection.py +++ b/python/tests/unit/data/test_vector_store_record_collection.py @@ -6,7 +6,7 @@ from pandas import DataFrame from pytest import mark, raises -from semantic_kernel.data._definitions import SerializeMethodProtocol, ToDictMethodProtocol +from semantic_kernel.data.vectors import SerializeMethodProtocol, ToDictMethodProtocol from semantic_kernel.exceptions import ( VectorStoreModelDeserializationException, VectorStoreModelSerializationException, @@ -276,7 +276,7 @@ async def test_get_fail_multiple(DictVectorStoreRecordCollection, definition): await vector_store_record_collection.upsert(record) assert len(vector_store_record_collection.inner_storage) == 1 with ( - patch("semantic_kernel.data.vectors.VectorStoreRecordCollection.deserialize") as deserialize_mock, + patch("semantic_kernel.data.vectors.VectorStoreCollection.deserialize") as deserialize_mock, raises( VectorStoreModelDeserializationException, match="Error deserializing record, multiple records returned:" ), diff --git a/python/tests/unit/data/test_vector_store_record_definition.py b/python/tests/unit/data/test_vector_store_record_definition.py index 7f88bcfffa4c..4b137f53fa0f 100644 --- a/python/tests/unit/data/test_vector_store_record_definition.py +++ b/python/tests/unit/data/test_vector_store_record_definition.py @@ -2,7 +2,7 @@ from pytest import raises -from semantic_kernel.data import VectorStoreCollectionDefinition, VectorStoreField +from semantic_kernel.data.vectors import VectorStoreCollectionDefinition, VectorStoreField from semantic_kernel.exceptions import VectorStoreModelException