diff --git a/docs/decisions/0050-updated-vector-store-design.md b/docs/decisions/0050-updated-vector-store-design.md
new file mode 100644
index 000000000000..c008068b1e95
--- /dev/null
+++ b/docs/decisions/0050-updated-vector-store-design.md
@@ -0,0 +1,995 @@
+---
+# These are optional elements. Feel free to remove any of them.
+status: proposed
+contact: westey-m
+date: 2024-06-05
+deciders: sergeymenshykh, markwallace, rbarreto, dmytrostruk, westey-m, matthewbolanos, eavanvalkenburg
+consulted: stephentoub, dluc, ajcvickers, roji
+informed:
+---
+
+# Updated Memory Connector Design
+
+## Context and Problem Statement
+
+Semantic Kernel has a collection of connectors to popular Vector databases, e.g. Azure AI Search, Chroma, Milvus, ...
+Each Memory connector implements a memory abstraction defined by Semantic Kernel and allows developers to easily integrate Vector databases into their applications.
+The current abstractions are experimental and the purpose of this ADR is to progress the design of the abstractions so that they can graduate to non-experimental status.
+
+### Problems with current design
+
+1. The `IMemoryStore` interface has four responsibilities with different cardinalities. Some are schema aware and others schema agnostic.
+2. The `IMemoryStore` interface only supports a fixed schema for data storage, retrieval and search, which limits its usability by customers with existing data sets.
+3. The `IMemoryStore` implementations are opinionated around key encoding / decoding and collection name sanitization, which limits their usability by customers with existing data sets.
+
+Responsibilities:
+
+|Functional Area|Cardinality|Significance to Semantic Kernel|
+|-|-|-|
+|Collection/Index create|An implementation per store type and model|Valuable when building a store and adding data|
+|Collection/Index list names, exists and delete|An implementation per store type|Valuable when building a store and adding data|
+|Data Storage and Retrieval|An implementation per store type|Valuable when building a store and adding data|
+|Vector Search|An implementation per store type, model and search type|Valuable for many scenarios including RAG, finding contradictory facts based on user input, finding similar memories to merge, etc.|
+
+### Memory Store Today
+
+```cs
+interface IMemoryStore
+{
+    // Collection / Index Management
+    Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default);
+    IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancellationToken = default);
+    Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default);
+    Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default);
+
+    // Data Storage and Retrieval
+    Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default);
+    IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IEnumerable<MemoryRecord> records, CancellationToken cancellationToken = default);
+    Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default);
+    IAsyncEnumerable<MemoryRecord> GetBatchAsync(string collectionName, IEnumerable<string> keys, bool withVectors = false, CancellationToken cancellationToken = default);
+    Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default);
+    Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancellationToken = default);
+
+    // Vector Search
+    IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
+        string collectionName,
+        ReadOnlyMemory<float> embedding,
+        int limit,
+        double minRelevanceScore = 0.0,
+        bool withVectors = false,
+        CancellationToken cancellationToken = default);
+
+    Task<(MemoryRecord, double)?> GetNearestMatchAsync(
+        string collectionName,
+        ReadOnlyMemory<float> embedding,
+        double minRelevanceScore = 0.0,
+        bool withEmbedding = false,
+        CancellationToken cancellationToken = default);
+}
+```
+
+### Actions
+
+1. The `IMemoryStore` should be split into different interfaces, so that schema aware and schema agnostic operations are separated.
+2. The **Data Storage and Retrieval** and **Vector Search** areas should allow typed access to data and support any schema that is currently available in the customer's data store.
+3. The collection / index create functionality should allow developers to use a common definition that is part of the abstraction to create collections.
+4. The collection / index list/exists/delete functionality should allow management of any collection regardless of schema.
+5. Remove opinionated behaviors from connectors. The opinionated behavior limits the ability of these connectors to be used with pre-existing vector databases. As far as possible these behaviors should be moved into decorators or be injectable. Examples of opinionated behaviors:
+   1. The AzureAISearch connector encodes keys before storing and decodes them after retrieval, since keys in Azure AI Search support only a limited set of characters.
+   2. The AzureAISearch connector sanitizes collection names before using them, since Azure AI Search supports only a limited set of characters.
+   3. The Redis connector prepends the collection name onto the front of keys before storing records and also registers the collection name as a prefix for records to be indexed by the index.
+
+### Non-functional requirements for new connectors
+
+1. Ensure all connectors throw the same exceptions consistently, with data about the request provided in a consistent manner.
+2. Add consistent telemetry for all connectors.
+3. As far as possible, integration tests should be runnable on the build server.
+
+### New Designs
+
+The first design change is the separation between collection/index management and record management, illustrated by the consumption sketch and class diagram below.
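+
+As a rough sketch of how this split could look from the consumer side (the interface and method names come from the diagrams below; the exact signatures, the `Hotel` model and the `"hotels"` collection name are illustrative assumptions only):
+
+```cs
+// Sketch only: an ingestion component depends on collection creation, collection management
+// and typed record access as three separate abstractions, so each can vary independently.
+public class HotelIngestionService(
+    IVectorCollectionCreate collectionCreate,
+    IVectorCollectionNonSchema collections,
+    IVectorRecordStore<Hotel> hotelStore)
+{
+    public async Task IngestAsync(Hotel hotel)
+    {
+        // Create the collection only if it does not already exist.
+        if (!await collections.CollectionExistsAsync("hotels"))
+        {
+            await collectionCreate.CreateCollectionAsync("hotels");
+        }
+
+        // Record access is strongly typed via the TModel type parameter.
+        await hotelStore.UpsertAsync(hotel);
+    }
+}
+```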
+ +```mermaid +--- +title: SK Collection/Index and record management +--- +classDiagram + note for IVectorRecordStore "Can manage records for any scenario" + note for IVectorCollectionCreate "Can create collections and\nindexes" + note for IVectorCollectionNonSchema "Can retrieve/delete any collections and\nindexes" + + namespace SKAbstractions{ + class IVectorCollectionCreate{ + <> + +CreateCollection + } + + class IVectorCollectionNonSchema{ + <> + +GetCollectionNames + +CollectionExists + +DeleteCollection + } + + class IVectorRecordStore~TModel~{ + <> + +Upsert(TModel record) string + +UpserBatch(TModel record) string + +Get(string key) TModel + +GetBatch(string[] keys) TModel[] + +Delete(string key) + +DeleteBatch(string[] keys) + } + } + + namespace AzureAIMemory{ + class AzureAISearchVectorCollectionCreate{ + } + + class AzureAISearchVectorCollectionNonSchema{ + } + + class AzureAISearchVectorRecordStore{ + } + } + + namespace RedisMemory{ + class RedisVectorCollectionCreate{ + } + + class RedisVectorCollectionNonSchema{ + } + + class RedisVectorRecordStore{ + } + } + + IVectorCollectionCreate <|-- AzureAISearchVectorCollectionCreate + IVectorCollectionNonSchema <|-- AzureAISearchVectorCollectionNonSchema + IVectorRecordStore <|-- AzureAISearchVectorRecordStore + + IVectorCollectionCreate <|-- RedisVectorCollectionCreate + IVectorCollectionNonSchema <|-- RedisVectorCollectionNonSchema + IVectorRecordStore <|-- RedisVectorRecordStore +``` + +How to use your own schema with core sk functionality. + +```mermaid +--- +title: Chat History Break Glass +--- +classDiagram + note for IVectorRecordStore "Can manage records\nfor any scenario" + note for IVectorCollectionCreate "Can create collections\nan dindexes" + note for IVectorCollectionNonSchema "Can retrieve/delete any\ncollections and indexes" + note for CustomerHistoryVectorCollectionCreate "Creates history collections and indices\nusing Customer requirements" + note for CustomerHistoryVectorRecordStore "Decorator class for IVectorRecordStore that maps\nbetween the customer model to our model" + + namespace SKAbstractions{ + class IVectorCollectionCreate{ + <> + +CreateCollection + } + + class IVectorCollectionNonSchema{ + <> + +GetCollectionNames + +CollectionExists + +DeleteCollection + } + + class IVectorRecordStore~TModel~{ + <> + +Upsert(TModel record) string + +Get(string key) TModel + +Delete(string key) string + } + + class ISemanticTextMemory{ + <> + +SaveInformationAsync() + +SaveReferenceAsync() + +GetAsync() + +DeleteAsync() + +SearchAsync() + +GetCollectionsAsync() + } + } + + namespace CustomerProject{ + class CustomerHistoryModel{ + +string text + +float[] vector + +Dictionary~string, string~ properties + } + + class CustomerHistoryVectorCollectionCreate{ + +CreateCollection + } + + class CustomerHistoryVectorRecordStore{ + -IVectorRecordStore~CustomerHistoryModel~ _store + +Upsert(ChatHistoryModel record) string + +Get(string key) ChatHistoryModel + +Delete(string key) string + } + } + + namespace SKCore{ + class SemanticTextMemory{ + -IVectorRecordStore~ChatHistoryModel~ _VectorRecordStore + -IMemoryCollectionService _collectionsService + -ITextEmbeddingGenerationService _embeddingGenerationService + } + + class ChatHistoryPlugin{ + -ISemanticTextMemory memory + } + + class ChatHistoryModel{ + +string message + +float[] embedding + +Dictionary~string, string~ metadata + } + } + + IVectorCollectionCreate <|-- CustomerHistoryVectorCollectionCreate + + IVectorRecordStore <|-- CustomerHistoryVectorRecordStore + 
IVectorRecordStore <.. CustomerHistoryVectorRecordStore + CustomerHistoryModel <.. CustomerHistoryVectorRecordStore + ChatHistoryModel <.. CustomerHistoryVectorRecordStore + + ChatHistoryModel <.. SemanticTextMemory + IVectorRecordStore <.. SemanticTextMemory + IVectorCollectionCreate <.. SemanticTextMemory + + ISemanticTextMemory <.. ChatHistoryPlugin +``` + +### Vector Store Cross Store support - General Features + +A comparison of the different ways in which stores implement storage capabilities to help drive decisions: + +|Feature|Azure AI Search|Weaviate|Redis|Chroma|FAISS|Pinecone|LLamaIndex|PostgreSql|Qdrant|Milvus| +|-|-|-|-|-|-|-|-|-|-|-| +|Get Item Support|Y|Y|Y|Y||Y||Y|Y|Y| +|Batch Operation Support|Y|Y|Y|Y||Y||||Y| +|Per Item Results for Batch Operations|Y|Y|Y|N||N||||| +|Keys of upserted records|Y|Y|N3|N3||N3||||Y| +|Keys of removed records|Y||N3|N||N||||N3| +|Retrieval field selection for gets|Y||Y4|P2||N||Y|Y|Y| +|Include/Exclude Embeddings for gets|P1|Y|Y4,1|Y||N||P1|Y|N| +|Failure reasons when batch partially fails|Y|Y|Y|N||N||||| +|Is Key separate from data|N|Y|Y|Y||Y||N|Y|N| +|Can Generate Ids|N|Y|N|N||Y||Y|N|Y| +|Can Generate Embedding|Not Available Via API yet|Y|N|Client Side Abstraction|||||N|| + +Footnotes: +- P = Partial Support +- 1 Only if you have the schema, to select the appropriate fields. +- 2 Supports broad categories of fields only. +- 3 Id is required in request, so can be returned if needed. +- 4 No strong typed support when specifying field list. + +### Vector Store Cross Store support - Fields, types and indexing + +|Feature|Azure AI Search|Weaviate|Redis|Chroma|FAISS|Pinecone|LLamaIndex|PostgreSql|Qdrant|Milvus| +|-|-|-|-|-|-|-|-|-|-|-| +|Field Differentiation|Fields|Key, Props, Vectors|Key, Fields|Key, Document, Metadata, Vector||Key, Metadata, SparseValues, Vector||Fields|Key, Props(Payload), Vectors|Fields| +|Multiple Vector per record support|Y|Y|Y|N||[N](https://docs.pinecone.io/guides/data/upsert-data#upsert-records-with-metadata)||Y|Y|Y| +|Index to Collection|1 to 1|1 to 1|1 to many|1 to 1|-|1 to 1|-|1 to 1|1 to 1|1 to 1| +|Id Type|String|UUID|string with collection name prefix|string||string|UUID|64Bit Int / UUID / ULID|64Bit Unsigned Int / UUID|Int64 / varchar| +|Supported Vector Types|[Collection(Edm.Byte) / Collection(Edm.Single) / Collection(Edm.Half) / Collection(Edm.Int16) / Collection(Edm.SByte)](https://learn.microsoft.com/en-us/rest/api/searchservice/supported-data-types)|float32|FLOAT32 and FLOAT64|||[Rust f32](https://docs.pinecone.io/troubleshooting/embedding-values-changed-when-upserted)||[single-precision (4 byte float) / half-precision (2 byte float) / binary (1bit) / sparse vectors (4 bytes)](https://github.com/pgvector/pgvector?tab=readme-ov-file#pgvector)|UInt8 / Float32|Binary / Float32 / Float16 / BFloat16 / SparseFloat| +|Supported Distance Functions|[Cosine / dot prod / euclidean dist (l2 norm)](https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#similarity-metrics-used-to-measure-nearness)|[Cosine dist / dot prod / Squared L2 dist / hamming (num of diffs) / manhattan dist](https://weaviate.io/developers/weaviate/config-refs/distances#available-distance-metrics)|[Euclidean dist (L2) / Inner prod (IP) / Cosine dist](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/vectors/)|[Squared L2 / Inner prod / Cosine similarity](https://docs.trychroma.com/guides#changing-the-distance-function)||[cosine sim / euclidean dist / dot 
prod](https://docs.pinecone.io/reference/api/control-plane/create_index)||[L2 dist / inner prod / cosine dist / L1 dist / Hamming dist / Jaccard dist (NB: Specified at query time, not index creation time)](https://github.com/pgvector/pgvector?tab=readme-ov-file#pgvector)|[Dot prod / Cosine sim / Euclidean dist (L2) / Manhattan dist](https://qdrant.tech/documentation/concepts/search/)|[Cosine sim / Euclidean dist / Inner Prod](https://milvus.io/docs/index-vector-fields.md)| +|Supported index types|[Exhaustive KNN (FLAT) / HNSW](https://learn.microsoft.com/en-us/azure/search/vector-search-ranking#algorithms-used-in-vector-search)|[HNSW / Flat / Dynamic](https://weaviate.io/developers/weaviate/config-refs/schema/vector-index)|[HNSW / FLAT](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/vectors/#create-a-vector-field)|[HNSW not configurable](https://cookbook.chromadb.dev/core/concepts/#vector-index-hnsw-index)||[PGA](https://www.pinecone.io/blog/hnsw-not-enough/)||[HNSW / IVFFlat](https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing)|[HNSW for dense](https://qdrant.tech/documentation/concepts/indexing/#vector-index)|

[In Memory: FLAT / IVF_FLAT / IVF_SQ8 / IVF_PQ / HNSW / SCANN](https://milvus.io/docs/index.md) / [On Disk: DiskANN](https://milvus.io/docs/disk_index.md) / [GPU: GPU_CAGRA / GPU_IVF_FLAT / GPU_IVF_PQ / GPU_BRUTE_FORCE](https://milvus.io/docs/gpu_index.md)

| + +Footnotes: +- HNSW = Hierarchical Navigable Small World (HNSW performs an [approximate nearest neighbor (ANN)](https://learn.microsoft.com/en-us/azure/search/vector-search-overview#approximate-nearest-neighbors) search) +- KNN = k-nearest neighbors (performs a brute-force search that scans the entire vector space) +- IVFFlat = Inverted File with Flat Compression (This index type uses approximate nearest neighbor search (ANNS) to provide fast searches) +- Weaviate Dynamic = Starts as flat and switches to HNSW if the number of objects exceed a limit +- PGA = [Pinecone Graph Algorithm](https://www.pinecone.io/blog/hnsw-not-enough/) + +### Vector Store Cross Store support - Search and filtering + +|Feature|Azure AI Search|Weaviate|Redis|Chroma|FAISS|Pinecone|LLamaIndex|PostgreSql|Qdrant|Milvus| +|-|-|-|-|-|-|-|-|-|-|-| +|Index allows text search|Y|Y|Y|Y (On Metadata by default)||[Only in combination with Vector](https://docs.pinecone.io/guides/data/understanding-hybrid-search)||Y (with TSVECTOR field)|Y|Y| +|Text search query format|[Simple or Full Lucene](https://learn.microsoft.com/en-us/azure/search/search-query-create?tabs=portal-text-query#choose-a-query-type-simple--full)|[wildcard](https://weaviate.io/developers/weaviate/search/filters#filter-text-on-partial-matches)|wildcard & fuzzy|[contains & not contains](https://docs.trychroma.com/guides#filtering-by-document-contents)||Text only||[wildcard & binary operators](https://www.postgresql.org/docs/16/textsearch-controls.html#TEXTSEARCH-PARSING-QUERIES)|[Text only](https://qdrant.tech/documentation/concepts/filtering/#full-text-match)|[wildcard](https://milvus.io/docs/single-vector-search.md#Filtered-search)| +|Multi Field Vector Search Support|Y|[N](https://weaviate.io/developers/weaviate/search/similarity)||N (no multi vector support)||N||[Unclear due to order by syntax](https://github.com/pgvector/pgvector?tab=readme-ov-file#querying)|[N](https://qdrant.tech/documentation/concepts/search/)|[Y](https://milvus.io/api-reference/restful/v2.4.x/v2/Vector%20(v2)/Hybrid%20Search.md)| +|Targeted Multi Field Text Search Support|Y|[Y](https://weaviate.io/developers/weaviate/search/hybrid#set-weights-on-property-values)|[Y](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/query_syntax/#field-modifiers)|N (only on document)||N||Y|Y|Y| +|Vector per Vector Field for Search|Y|N/A||N/A|||N/A||N/A|N/A|[Y](https://milvus.io/docs/multi-vector-search.md#Step-1-Create-Multiple-AnnSearchRequest-Instances)| +|Separate text search query from vectors|Y|[Y](https://weaviate.io/developers/weaviate/search/hybrid#specify-a-search-vector)|Y|Y||Y||Y|Y|[Y](https://milvus.io/api-reference/restful/v2.4.x/v2/Vector%20(v2)/Hybrid%20Search.md)| +|Allows filtering|Y|Y|Y (on TAG)|Y (On Metadata by default)||[Y](https://docs.pinecone.io/guides/indexes/configure-pod-based-indexes#selective-metadata-indexing)||Y|Y|Y| +|Allows filter grouping|Y (Odata)|[Y](https://weaviate.io/developers/weaviate/search/filters#nested-filters)||[Y](https://docs.trychroma.com/guides#using-logical-operators)||Y||Y|[Y](https://qdrant.tech/documentation/concepts/filtering/#clauses-combination)|[Y](https://milvus.io/docs/get-and-scalar-query.md#Use-Basic-Operators)| +|Allows scalar index field setup|Y|Y|Y|N||Y||Y|Y|Y| +|Requires scalar index field setup to filter|Y|Y|Y|N||N (on by default for all)||N|N|N (can filter without index)| + +### Support for different mappers + +Mapping between data models and the storage models can also require custom logic depending on 
the type of data model and storage model involved.
+
+I'm therefore proposing that we allow mappers to be injectable for each `VectorStoreCollection` instance. The interfaces for these would vary depending
+on the storage models used by each vector store and any unique capabilities that each vector store may have, e.g. Qdrant can operate in `single` or
+`multiple named vector` modes, which means the mapper needs to know whether to set a single vector or fill a vector map.
+
+In addition to this, we should build first-party mappers for each of the vector stores, which will cater for built-in generic models or use metadata to perform the mapping.
+
+### Support for different storage schemas
+
+The different stores vary in many ways around how data is organized.
+- Some just store a record with fields on it, where fields can be a key, a data field or a vector, and their type is determined at collection creation time.
+- Others separate fields by type when interacting with the API, e.g. you have to specify a key explicitly, put metadata into a metadata dictionary and put vectors into a vector array.
+
+I'm proposing that we allow two ways in which to provide the information required to map data between the consumer data model and the storage data model.
+The first is a set of configuration objects that capture the types of each field. The second is a set of attributes that can be used to decorate the model itself
+and can be converted to the configuration objects, allowing a single execution path.
+Additional configuration properties can easily be added for each type of field as required, e.g. IsFilterable or IsFullTextSearchable, allowing us to also create an index from the provided configuration.
+
+I'm also proposing that even though similar attributes already exist in other systems, e.g. System.ComponentModel.DataAnnotations.KeyAttribute, we create our own.
+We will likely require additional properties on all these attributes that are not currently supported on the existing attributes, e.g. whether a field is or
+should be filterable. Requiring users to switch to new attributes later would be disruptive.
+
+Here is what the attributes would look like, plus a sample use case.
+
+```cs
+sealed class VectorStoreRecordKeyAttribute : Attribute
+{
+}
+sealed class VectorStoreRecordDataAttribute : Attribute
+{
+    public bool HasEmbedding { get; set; }
+    public string EmbeddingPropertyName { get; set; }
+}
+sealed class VectorStoreRecordVectorAttribute : Attribute
+{
+}
+
+public record HotelInfo(
+    [property: VectorStoreRecordKey, JsonPropertyName("hotel-id")] string HotelId,
+    [property: VectorStoreRecordData, JsonPropertyName("hotel-name")] string HotelName,
+    [property: VectorStoreRecordData(HasEmbedding = true, EmbeddingPropertyName = "DescriptionEmbeddings"), JsonPropertyName("description")] string Description,
+    [property: VectorStoreRecordVector, JsonPropertyName("description-embeddings")] ReadOnlyMemory<float>? DescriptionEmbeddings);
+```
+
+Here is what the configuration objects would look like.
+
+```cs
+abstract class VectorStoreRecordProperty(string propertyName);
+
+sealed class VectorStoreRecordKeyProperty(string propertyName) : VectorStoreRecordProperty(propertyName)
+{
+}
+sealed class VectorStoreRecordDataProperty(string propertyName) : VectorStoreRecordProperty(propertyName)
+{
+    bool HasEmbedding;
+    string EmbeddingPropertyName;
+}
+sealed class VectorStoreRecordVectorProperty(string propertyName) : VectorStoreRecordProperty(propertyName)
+{
+}
+
+sealed class VectorStoreRecordDefinition
+{
+    IReadOnlyList<VectorStoreRecordProperty> Properties;
+}
+```
+
+### Notable method signature changes from existing interface
+
+All methods currently existing on IMemoryStore will be ported to new interfaces, but in places I am proposing that we make changes to improve
+consistency and scalability.
+
+1. `RemoveAsync` and `RemoveBatchAsync` renamed to `DeleteAsync` and `DeleteBatchAsync`, since records are actually deleted, and this also matches the verb used for collections.
+2. `GetCollectionsAsync` renamed to `GetCollectionNamesAsync`, since we are only retrieving names and no other information about collections.
+3. `DoesCollectionExistAsync` renamed to `CollectionExistsAsync`, since this is shorter and is more commonly used in other APIs.
+
+### Comparison with other AI frameworks
+
+|Criteria|Current SK Implementation|Proposed SK Implementation|Spring AI|LlamaIndex|Langchain|
+|-|-|-|-|-|-|
+|Support for Custom Schemas|N|Y|N|N|N|
+|Naming of store|MemoryStore|VectorStore, VectorStoreCollection|VectorStore|VectorStore|VectorStore|
+|MultiVector support|N|Y|N|N|N|
+|Support Multiple Collections via SDK params|Y|Y|N (via app config)|Y|Y|
+
+## Decision Drivers
+
+From GitHub Issue:
+- API surface must be easy to use and intuitive
+- Alignment with other patterns in the SK
+- Design must allow Memory Plugins to be easily instantiated with any connector
+- Design must support all Kernel content types
+- Design must allow for database-specific configuration
+- All NFRs required to be production ready are implemented (see Roadmap for more detail)
+- Basic CRUD operations must be supported so that connectors can be used in a polymorphic manner
+- Official Database Clients must be used where available
+- Dynamic database schema must be supported
+- Dependency injection must be supported
+- Azure-ML YAML format must be supported
+- Breaking glass scenarios must be supported
+
+## Considered Questions
+
+1. Combined collection and record management vs separated.
+2. Collection name and key value normalization in decorator or main class.
+3. Collection name as method param or constructor param.
+4. How to normalize ids across different vector stores where different types are supported.
+5. Store Interface/Class Naming.
+
+### Question 1: Combined collection and record management vs separated.
+ +#### Option 1 - Combined collection and record management + +```cs +interface IVectorRecordStore +{ + Task CreateCollectionAsync(CollectionCreateConfig collectionConfig, CancellationToken cancellationToken = default); + IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default); + Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default); + Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default); + + Task UpsertAsync(TRecord data, CancellationToken cancellationToken = default); + IAsyncEnumerable UpsertBatchAsync(IEnumerable dataSet, CancellationToken cancellationToken = default); + Task GetAsync(string key, bool withEmbedding = false, CancellationToken cancellationToken = default); + IAsyncEnumerable GetBatchAsync(IEnumerable keys, bool withVectors = false, CancellationToken cancellationToken = default); + Task DeleteAsync(string key, CancellationToken cancellationToken = default); + Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default); +} + +class AzureAISearchVectorRecordStore( + Azure.Search.Documents.Indexes.SearchIndexClient client, + Schema schema): IVectorRecordStore; + +class WeaviateVectorRecordStore( + WeaviateClient client, + Schema schema): IVectorRecordStore; + +class RedisVectorRecordStore( + StackExchange.Redis.IDatabase database, + Schema schema): IVectorRecordStore; +``` + +#### Option 2 - Separated collection and record management with opinionated create implementations + +```cs + +interface IVectorCollectionStore +{ + virtual Task CreateChatHistoryCollectionAsync(string name, CancellationToken cancellationToken = default); + virtual Task CreateSemanticCacheCollectionAsync(string name, CancellationToken cancellationToken = default); + + IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default); + Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default); + Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default); +} + +class AzureAISearchVectorCollectionStore: IVectorCollectionStore; +class RedisVectorCollectionStore: IVectorCollectionStore; +class WeaviateVectorCollectionStore: IVectorCollectionStore; + +// Customers can inherit from our implementations and replace just the creation scenarios to match their schemas. +class CustomerCollectionStore: AzureAISearchVectorCollectionStore, IVectorCollectionStore; + +// We can also create implementations that create indices based on an MLIndex specification. +class MLIndexAzureAISearchVectorCollectionStore(MLIndex mlIndexSpec): AzureAISearchVectorCollectionStore, IVectorCollectionStore; + +interface IVectorRecordStore +{ + Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); + Task DeleteAsync(string key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default); + Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default); +} + +class AzureAISearchVectorRecordStore(): IVectorRecordStore; +``` + +#### Option 3 - Separated collection and record management with collection create separate from other operations. + +Vector store same as option 2 so not repeated for brevity. 
+ +```cs + +interface IVectorCollectionCreate +{ + virtual Task CreateCollectionAsync(string name, CancellationToken cancellationToken = default); +} + +// Implement a generic version of create that takes a configuration that should work for 80% of cases. +class AzureAISearchConfiguredVectorCollectionCreate(CollectionCreateConfig collectionConfig): IVectorCollectionCreate; + +// Allow custom implementations of create for break glass scenarios for outside the 80% case. +class AzureAISearchChatHistoryVectorCollectionCreate: IVectorCollectionCreate; +class AzureAISearchSemanticCacheVectorCollectionCreate: IVectorCollectionCreate; + +// Customers can create their own creation scenarios to match their schemas, but can continue to use our get, does exist and delete class. +class CustomerChatHistoryVectorCollectionCreate: IVectorCollectionCreate; + +interface IVectorCollectionNonSchema +{ + IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default); + Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default); + Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default); +} + +class AzureAISearchVectorCollectionNonSchema: IVectorCollectionNonSchema; +class RedisVectorCollectionNonSchema: IVectorCollectionNonSchema; +class WeaviateVectorCollectionNonSchema: IVectorCollectionNonSchema; + +``` + +#### Option 4 - Separated collection and record management with collection create separate from other operations, with collection management aggregation class on top. + +Variation on option 3. + +```cs + +interface IVectorCollectionCreate +{ + virtual Task CreateCollectionAsync(string name, CancellationToken cancellationToken = default); +} + +interface IVectorCollectionNonSchema +{ + IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default); + Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default); + Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default); +} + +// DB Specific NonSchema implementations +class AzureAISearchVectorCollectionNonSchema: IVectorCollectionNonSchema; +class RedisVectorCollectionNonSchema: IVectorCollectionNonSchema; + +// Combined Create + NonSchema Interface +interface IVectorCollectionStore: IVectorCollectionCreate, IVectorCollectionNonSchema {} + +// Base abstract class that forwards non-create operations to provided implementation. +abstract class VectorCollectionStore(IVectorCollectionNonSchema collectionNonSchema): IVectorCollectionStore +{ + public abstract Task CreateCollectionAsync(string name, CancellationToken cancellationToken = default); + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) { return collectionNonSchema.ListCollectionNamesAsync(cancellationToken); } + public Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) { return collectionNonSchema.CollectionExistsAsync(name, cancellationToken); } + public Task DeleteCollectionAsync(string name, CancellationToken cancellationToken = default) { return collectionNonSchema.DeleteCollectionAsync(name, cancellationToken); } +} + +// Collections store implementations, that inherit from base class, and just adds the different creation implementations. 
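+// (Each derived store below only needs to supply its own CreateCollectionAsync; list, exists and delete calls are forwarded to the injected IVectorCollectionNonSchema by the abstract base class above.)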
+class AzureAISearchChatHistoryVectorCollectionStore(AzureAISearchVectorCollectionNonSchema nonSchema): VectorCollectionStore(nonSchema); +class AzureAISearchSemanticCacheVectorCollectionStore(AzureAISearchVectorCollectionNonSchema nonSchema): VectorCollectionStore(nonSchema); +class AzureAISearchMLIndexVectorCollectionStore(AzureAISearchVectorCollectionNonSchema nonSchema): VectorCollectionStore(nonSchema); + +// Customer collections store implementation, that uses the base Azure AI Search implementation for get, doesExist and delete, but adds its own creation. +class ContosoProductsVectorCollectionStore(AzureAISearchVectorCollectionNonSchema nonSchema): VectorCollectionStore(nonSchema); + +``` + +#### Option 5 - Separated collection and record management with collection create separate from other operations, with overall aggregation class on top. + +Same as option 3 / 4, plus: + +```cs + +interface IVectorStore : IVectorCollectionStore, IVectorRecordStore +{ +} + +// Create a static factory that produces one of these, so only the interface is public, not the class. +internal class VectorStore(IVectorCollectionCreate create, IVectorCollectionNonSchema nonSchema, IVectorRecordStore records): IVectorStore +{ +} + +``` + +#### Option 6 - Collection store acts as factory for record store. + +`IVectorStore` acts as a factory for `IVectorStoreCollection`, and any schema agnostic multi-collection operations are kept on `IVectorStore`. + + +```cs +public interface IVectorStore +{ + IVectorStoreCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null); + IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default)); +} + +public interface IVectorStoreCollection +{ + public string Name { get; } + + // Collection Operations + Task CreateCollectionAsync(); + Task CreateCollectionIfNotExistsAsync(); + Task CollectionExistsAsync(); + Task DeleteCollectionAsync(); + + // Data manipulation + Task GetAsync(TKey key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); + IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default); + Task DeleteAsync(TKey key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default); + Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default); + Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default); + IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default); +} +``` + + +#### Decision Outcome + +Option 1 is problematic on its own, since we have to allow consumers to create custom implementations of collection create for break glass scenarios. With +a single interface like this, it will require them to implement many methods that they do not want to change. Options 4 & 5, gives us more flexibility while +still preserving the ease of use of an aggregated interface as described in Option 1. + +Option 2 doesn't give us the flexbility we need for break glass scenarios, since it only allows certain types of collections to be created. It also means +that each time a new collection type is required it introduces a breaking change, so it is not a viable option. 
+ +Since collection create and configuration and the possible options vary considerable across different database types, we will need to support an easy +to use break glass scenario for collection creation. While we would be able to develop a basic configurable create option, for complex create scenarios +users will need to implement their own. We will also need to support multiple create implementations out of the box, e.g. a configuration based option using +our own configuration, create implementations that re-create the current model for backward compatibility, create implementations that use other configuration +as input, e.g. Azure-ML YAML. Therefore separating create, which may have many implementations, from exists, list and delete, which requires only a single implementation per database type is useful. +Option 3 provides us this separation, but Option 4 + 5 builds on top of this, and allows us to combine different implementations together for simpler +consumption. + +Chosen option: 6 + +- Easy to use, and similar to many SDk implementations. +- Can pass a single object around for both collection and record access. + +### Question 2: Collection name and key value normalization in store, decorator or via injection. + +#### Option 1 - Normalization in main record store + +- Pros: Simple +- Cons: The normalization needs to vary separately from the record store, so this will not work + +```cs + public class AzureAISearchVectorStoreCollection : IVectorStoreCollection + { + ... + + // On input. + var normalizedCollectionName = this.NormalizeCollectionName(collectionName); + var encodedId = AzureAISearchMemoryRecord.EncodeId(key); + + ... + + // On output. + DecodeId(this.Id) + + ... + } +``` + +#### Option 2 - Normalization in decorator + +- Pros: Allows normalization to vary separately from the record store. +- Pros: No code executed when no normalization required. +- Pros: Easy to package matching encoders/decoders together. +- Pros: Easier to obsolete encoding/normalization as a concept. +- Cons: Not a major con, but need to implement the full VectorStoreCollection interface, instead of e.g. just providing the two translation functions, if we go with option 3. +- Cons: Hard to have a generic implementation that can work with any model, without either changing the data in the provided object on upsert or doing cloning in an expensive way. + +```cs + new KeyNormalizingAISearchVectorStoreCollection( + "keyField", + new AzureAISearchVectorStoreCollection(...)); +``` + +#### Option 3 - Normalization via optional function parameters to record store constructor + +- Pros: Allows normalization to vary separately from the record store. +- Pros: No need to implement the full VectorStoreCollection interface. +- Pros: Can modify values on serialization without changing the incoming record, if supported by DB SDK. +- Cons: Harder to package matching encoders/decoders together. + +```cs +public class AzureAISearchVectorStoreCollection(StoreOptions options); + +public class StoreOptions +{ + public Func? EncodeKey { get; init; } + public Func? DecodeKey { get; init; } + public Func? SanitizeCollectionName { get; init; } +} +``` + +#### Option 4 - Normalization via custom mapper + +If developer wants to change any values they can do so by creating a custom mapper. + +- Cons: Developer needs to implement a mapper if they want to do normalization. +- Cons: Developer cannot change collection name as part of the mapping. +- Pros: No new extension points required to support normalization. 
+- Pros: Developer can change any field in the record. + +#### Decision Outcome + +Chosen option 3, since it is similar to how we are doing mapper injection and would also work well in python. + +Option 1 won't work because if e.g. the data was written using another tool, it may be unlikely that it was encoded using the same mechanism as supported here +and therefore this functionality may not be appropriate. The developer should have the ability to not use this functionality or +provide their own encoding / decoding behavior. + +### Question 3: Collection name as method param or via constructor or either + +#### Option 1 - Collection name as method param + +```cs +public class MyVectorStoreCollection() +{ + public async Task GetAsync(string collectionName, string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); +} +``` + +#### Option 2 - Collection name via constructor + +```cs +public class MyVectorStoreCollection(string defaultCollectionName) +{ + public async Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); +} +``` + +#### Option 3 - Collection name via either + +```cs +public class MyVectorStoreCollection(string defaultCollectionName) +{ + public async Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); +} + +public class GetRecordOptions +{ + public string CollectionName { get; init; }; +} +``` + +#### Decision Outcome + +Chosen option 2. None of the other options work with the decision outcome of Question 1, since that design requires the `VectorStoreCollection` to be tied to a single collection instance. + +### Question 4: How to normalize ids across different vector stores where different types are supported. + +#### Option 1 - Take a string and convert to a type that was specified on the constructor + +```cs +public async Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) +{ + var convertedKey = this.keyType switch + { + KeyType.Int => int.parse(key), + KeyType.GUID => Guid.parse(key) + } + + ... +} +``` + +- No additional overloads are required over time so no breaking changes. +- Most data types can easily be represented in string form and converted to/from it. + +#### Option 2 - Take an object and cast to a type that was specified on the constructor. + +```cs +public async Task GetAsync(object key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) +{ + var convertedKey = this.keyType switch + { + KeyType.Int => key as int, + KeyType.GUID => key as Guid + } + + if (convertedKey is null) + { + throw new InvalidOperationException($"The provided key must be of type {this.keyType}") + } + + ... +} + +``` + +- No additional overloads are required over time so no breaking changes. +- Any data types can be represented as object. + +#### Option 3 - Multiple overloads where we convert where possible, throw when not possible. + +```cs +public async Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) +{ + var convertedKey = this.keyType switch + { + KeyType.Int => int.Parse(key), + KeyType.String => key, + KeyType.GUID => Guid.Parse(key) + } +} +public async Task GetAsync(int key, GetRecordOptions? 
options = default, CancellationToken cancellationToken = default) +{ + var convertedKey = this.keyType switch + { + KeyType.Int => key, + KeyType.String => key.ToString(), + KeyType.GUID => throw new InvalidOperationException($"The provided key must be convertible to a GUID.") + } +} +public async Task GetAsync(GUID key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) +{ + var convertedKey = this.keyType switch + { + KeyType.Int => throw new InvalidOperationException($"The provided key must be convertible to an int.") + KeyType.String => key.ToString(), + KeyType.GUID => key + } +} +``` + +- Additional overloads are required over time if new key types are found on new connectors, causing breaking changes. +- You can still call a method that causes a runtime error, when the type isn't supported. + +#### Option 4 - Add key type as generic to interface + +```cs +interface IVectorRecordStore +{ + Task GetAsync(TKey key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); +} + +class AzureAISearchVectorRecordStore: IVectorRecordStore +{ + public AzureAISearchVectorRecordStore() + { + // Check if TKey matches the type of the field marked as a key on TRecord and throw if they don't match. + // Also check if keytype is one of the allowed types for Azure AI Search and throw if it isn't. + } +} + +``` + +- No runtime issues after construction. +- More cumbersome interface. + +#### Decision Outcome + +Chosen option 4, since it is forwards compatible with any complex key types we may need to support but still allows +each implementation to hardcode allowed key types if the vector db only supports certain key types. + +### Question 5: Store Interface/Class Naming. + +#### Option 1 - VectorDB + +```cs +interface IVectorDBRecordService {} +interface IVectorDBCollectionUpdateService {} +interface IVectorDBCollectionCreateService {} +``` + +#### Option 2 - Memory + +```cs +interface IMemoryRecordService {} +interface IMemoryCollectionUpdateService {} +interface IMemoryCollectionCreateService {} +``` + +### Option 3 - VectorStore + +```cs +interface IVectorRecordStore {} +interface IVectorCollectionNonSchema {} +interface IVectorCollectionCreate {} +interface IVectorCollectionStore {}: IVectorCollectionCreate, IVectorCollectionNonSchema +interface IVectorStore {}: IVectorCollectionStore, IVectorRecordStore +``` + +### Option 4 - VectorStore + VectorStoreCollection + +```cs +interface IVectorStore +{ + IVectorStoreCollection GetCollection() +} +interface IVectorStoreCollection +{ + Get() + Delete() + Upsert() +} +``` + +#### Decision Outcome + +Chosen option 4. The word memory is broad enough to encompass any data, so using it seems arbitrary. All competitors are using the term vector store, so using something similar is good for recognition. +Option 4 also matches our design as chosen in question 1. + +## Usage Examples + +### DI Framework: .net 8 Keyed Services + +```cs +class CacheEntryModel(string prompt, string result, ReadOnlyMemory promptEmbedding); + +class SemanticTextMemory(IVectorStore configuredVectorStore, VectorStoreRecordDefinition? 
vectorStoreRecordDefinition) : ISemanticTextMemory
+{
+    public async Task SaveInformation<TDataType>(string collectionName, TDataType record)
+    {
+        var collection = configuredVectorStore.GetCollection<string, TDataType>(collectionName, vectorStoreRecordDefinition);
+        if (!await collection.CollectionExistsAsync())
+        {
+            await collection.CreateCollectionAsync();
+        }
+        await collection.UpsertAsync(record);
+    }
+}
+
+class CacheSetFunctionFilter(ISemanticTextMemory memory); // Saves results to cache.
+class CacheGetPromptFilter(ISemanticTextMemory memory);   // Checks cache for entries.
+
+var builder = Kernel.CreateBuilder();
+
+builder
+    // Existing registration:
+    .AddAzureOpenAITextEmbeddingGeneration(textEmbeddingDeploymentName, azureAIEndpoint, apiKey, serviceId: "AzureOpenAI:text-embedding-ada-002")
+
+    // Register an IVectorStore implementation under the given key.
+    .AddAzureAISearch("Cache", azureAISearchEndpoint, apiKey, new Options() { withEmbeddingGeneration = true });
+
+// Add Semantic Cache Memory for the cache entry model.
+builder.Services.AddTransient<ISemanticTextMemory>(sp => {
+    return new SemanticTextMemory(
+        sp.GetKeyedService<IVectorStore>("Cache"),
+        cacheRecordDefinition);
+});
+
+// Add a filter to retrieve items from the cache and one to add items to the cache.
+// Since these filters depend on ISemanticTextMemory and that is already registered, it should get matched automatically.
+builder.Services.AddTransient<CacheGetPromptFilter>();
+builder.Services.AddTransient<CacheSetFunctionFilter>();
+```
+
+## Roadmap
+
+### Record Management
+
+1. Release VectorStoreCollection public interface and implementations for Azure AI Search, Qdrant and Redis.
+2. Add support for registering record stores with the SK container to allow automatic dependency injection.
+3. Add VectorStoreCollection implementations for remaining stores.
+
+### Collection Management
+
+4. Release Collection Management public interface and implementations for Azure AI Search, Qdrant and Redis.
+5. Add support for registering collection management with the SK container to allow automatic dependency injection.
+6. Add Collection Management implementations for remaining stores.
+
+### Collection Creation
+
+7. Release Collection Creation public interface.
+8. Create a cross-db collection creation config that supports common functionality, and a per-database implementation that supports this configuration.
+9. Add support for registering collection creation with the SK container to allow automatic dependency injection.
+
+### First Party Memory Features and well known model support
+
+10. Add model and mappers for the legacy SK MemoryStore interface, so that consumers using it have an upgrade path to the new memory storage stack.
+11. Add model and mappers for popular loader systems, like Kernel Memory or LlamaIndex.
+12. Explore adding first-party implementations for common scenarios, e.g. semantic caching. Specifics TBD.
+
+### Cross Cutting Requirements
+
+Need the following for all features:
+
+- Unit tests
+- Integration tests
+- Logging / Telemetry
+- Common Exception Handling
+- Samples, including:
+  - Usage scenario for collection and record management using a custom model and configured collection creation.
+  - A simple consumption example like semantic caching, specifics TBD.
+  - Adding your own collection creation implementation.
+  - Adding your own custom model mapper.
+- Documentation, including:
+  - How to create models and annotate/describe them to use with the storage system.
+  - How to define configuration for creating collections using the common create implementation.
+  - How to use the record and collection management APIs.
+ - How to implement your own collection create implementation for break glass scenario. + - How to implement your own mapper. + - How to upgrade from the current storage system to the new one. diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 645c8a249d2a..b59aa7714c51 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -30,6 +30,7 @@ + @@ -68,6 +69,7 @@ + @@ -92,6 +94,7 @@ + diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index dd43184b6612..89cc2c897d61 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs new file mode 100644 index 000000000000..ea498f20c5ab --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet; +using Docker.DotNet.Models; + +namespace Memory.VectorStoreFixtures; + +/// +/// Helper class that creates and deletes containers for the vector store examples. +/// +internal static class VectorStoreInfra +{ + /// + /// Setup the qdrant container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + public static async Task SetupQdrantContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "qdrant/qdrant", + Tag = "latest", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "qdrant/qdrant", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"6333", new List {new() {HostPort = "6333" } }}, + {"6334", new List {new() {HostPort = "6334" } }} + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "6333", default }, + { "6334", default } + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + /// + /// Setup the redis container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + public static async Task SetupRedisContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "redis/redis-stack", + Tag = "latest", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "redis/redis-stack", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"6379", new List {new() {HostPort = "6379"}}}, + {"8001", new List {new() {HostPort = "8001"}}} + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "6379", default }, + { "8001", default } + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + /// + /// Stop and delete the container with the specified id. + /// + /// The docker client to delete the container in. + /// The id of the container to delete. + /// An async task. 
+ public static async Task DeleteContainerAsync(DockerClient client, string containerId) + { + await client.Containers.StopContainerAsync(containerId, new ContainerStopParameters()); + await client.Containers.RemoveContainerAsync(containerId, new ContainerRemoveParameters()); + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreQdrantContainerFixture.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreQdrantContainerFixture.cs new file mode 100644 index 000000000000..820b5d3bf172 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreQdrantContainerFixture.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet; +using Qdrant.Client; + +namespace Memory.VectorStoreFixtures; + +/// +/// Fixture to use for creating a Qdrant container before tests and delete it after tests. +/// +public class VectorStoreQdrantContainerFixture : IAsyncLifetime +{ + private DockerClient? _dockerClient; + private string? _qdrantContainerId; + + public async Task InitializeAsync() + { + } + + public async Task ManualInitializeAsync() + { + if (this._qdrantContainerId == null) + { + // Connect to docker and start the docker container. + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._dockerClient = dockerClientConfiguration.CreateClient(); + this._qdrantContainerId = await VectorStoreInfra.SetupQdrantContainerAsync(this._dockerClient); + + // Delay until the Qdrant server is ready. + var qdrantClient = new QdrantClient("localhost"); + var succeeded = false; + var attemptCount = 0; + while (!succeeded && attemptCount++ < 10) + { + try + { + await qdrantClient.ListCollectionsAsync(); + succeeded = true; + } + catch (Exception) + { + await Task.Delay(1000); + } + } + } + } + + public async Task DisposeAsync() + { + if (this._dockerClient != null && this._qdrantContainerId != null) + { + // Delete docker container. + await VectorStoreInfra.DeleteContainerAsync(this._dockerClient, this._qdrantContainerId); + } + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreRedisContainerFixture.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreRedisContainerFixture.cs new file mode 100644 index 000000000000..eb35b7ff555f --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreRedisContainerFixture.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet; + +namespace Memory.VectorStoreFixtures; + +/// +/// Fixture to use for creating a Redis container before tests and delete it after tests. +/// +public class VectorStoreRedisContainerFixture : IAsyncLifetime +{ + private DockerClient? _dockerClient; + private string? _redisContainerId; + + public async Task InitializeAsync() + { + } + + public async Task ManualInitializeAsync() + { + if (this._redisContainerId == null) + { + // Connect to docker and start the docker container. + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._dockerClient = dockerClientConfiguration.CreateClient(); + this._redisContainerId = await VectorStoreInfra.SetupRedisContainerAsync(this._dockerClient); + } + } + + public async Task DisposeAsync() + { + if (this._dockerClient != null && this._redisContainerId != null) + { + // Delete docker container. 
+ await VectorStoreInfra.DeleteContainerAsync(this._dockerClient, this._redisContainerId); + } + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs new file mode 100644 index 000000000000..db8e259f4e7a --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Memory.VectorStoreFixtures; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Embeddings; +using StackExchange.Redis; + +namespace Memory; + +/// +/// An example showing how to ingest data into a vector store using with a custom mapper. +/// In this example, the storage model differs significantly from the data model, so a custom mapper is used to map between the two. +/// A is used to define the schema of the storage model, and this means that the connector +/// will not try and infer the schema from the data model. +/// In storage the data is stored as a JSON object that looks similar to this: +/// +/// { +/// "Term": "API", +/// "Definition": "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data.", +/// "DefinitionEmbedding": [ ... ] +/// } +/// +/// However, the data model is a class with a property for key and two dictionaries for the data (Term and Definition) and vector (DefinitionEmbedding). +/// +/// The example shows the following steps: +/// 1. Create an embedding generator. +/// 2. Create a Redis Vector Store using a custom factory for creating collections. +/// When constructing a collection, the factory injects a custom mapper that maps between the data model and the storage model if required. +/// 3. Ingest some data into the vector store. +/// 4. Read the data back from the vector store. +/// +/// You need a local instance of Docker running, since the associated fixture will try and start a Redis container in the local docker instance to run against. +/// +public class VectorStore_DataIngestion_CustomMapper(ITestOutputHelper output, VectorStoreRedisContainerFixture redisFixture) : BaseTest(output), IClassFixture +{ + /// + /// A record definition for the glossary entries that defines the storage schema of the record. + /// + private static readonly VectorStoreRecordDefinition s_glossaryDefinition = new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Term", typeof(string)), + new VectorStoreRecordDataProperty("Definition", typeof(string)), + new VectorStoreRecordVectorProperty("DefinitionEmbedding", typeof(ReadOnlyMemory)) { Dimensions = 1536, DistanceFunction = DistanceFunction.DotProductSimilarity } + } + }; + + [Fact] + public async Task ExampleAsync() + { + // Create an embedding generation service. + var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + + // Initiate the docker container and construct the vector store using the custom factory for creating collections. 
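+        // Note: the fixture starts a local redis/redis-stack container via Docker, so this step assumes Docker is running and that port 6379 is available on localhost.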
+ await redisFixture.ManualInitializeAsync(); + ConnectionMultiplexer redis = ConnectionMultiplexer.Connect("localhost:6379"); + var vectorStore = new RedisVectorStore(redis.GetDatabase(), new() { VectorStoreCollectionFactory = new Factory() }); + + // Get and create collection if it doesn't exist, using the record definition containing the storage model. + var collection = vectorStore.GetCollection("skglossary", s_glossaryDefinition); + await collection.CreateCollectionIfNotExistsAsync(); + + // Create glossary entries and generate embeddings for them. + var glossaryEntries = CreateGlossaryEntries().ToList(); + var tasks = glossaryEntries.Select(entry => Task.Run(async () => + { + entry.Vectors["DefinitionEmbedding"] = await textEmbeddingGenerationService.GenerateEmbeddingAsync((string)entry.Data["Definition"]); + })); + await Task.WhenAll(tasks); + + // Upsert the glossary entries into the collection and return their keys. + var upsertedKeysTasks = glossaryEntries.Select(x => collection.UpsertAsync(x)); + var upsertedKeys = await Task.WhenAll(upsertedKeysTasks); + + // Retrieve one of the upserted records from the collection. + var upsertedRecord = await collection.GetAsync(upsertedKeys.First(), new() { IncludeVectors = true }); + + // Write upserted keys and one of the upserted records to the console. + Console.WriteLine($"Upserted keys: {string.Join(", ", upsertedKeys)}"); + Console.WriteLine($"Upserted record: {JsonSerializer.Serialize(upsertedRecord)}"); + } + + /// + /// A custom mapper that maps between the data model and the storage model. + /// + private sealed class Mapper : IVectorStoreRecordMapper + { + public (string Key, JsonNode Node) MapFromDataToStorageModel(GenericDataModel dataModel) + { + var jsonObject = new JsonObject(); + + jsonObject.Add("Term", dataModel.Data["Term"].ToString()); + jsonObject.Add("Definition", dataModel.Data["Definition"].ToString()); + + var vector = (ReadOnlyMemory)dataModel.Vectors["DefinitionEmbedding"]; + var jsonArray = new JsonArray(vector.ToArray().Select(x => JsonValue.Create(x)).ToArray()); + jsonObject.Add("DefinitionEmbedding", jsonArray); + + return (dataModel.Key, jsonObject); + } + + public GenericDataModel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) + { + var dataModel = new GenericDataModel + { + Key = storageModel.Key, + Data = new Dictionary + { + { "Term", (string)storageModel.Node["Term"]! }, + { "Definition", (string)storageModel.Node["Definition"]! } + }, + Vectors = new Dictionary + { + { "DefinitionEmbedding", new ReadOnlyMemory(storageModel.Node["DefinitionEmbedding"]!.AsArray().Select(x => (float)x!).ToArray()) } + } + }; + + return dataModel; + } + } + + /// + /// A factory for creating collections in the vector store + /// + private sealed class Factory : IRedisVectorStoreRecordCollectionFactory + { + public IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IDatabase database, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull + where TRecord : class + { + // If the record definition is the glossary definition and the record type is the generic data model, inject the custom mapper into the collection options. 
+ if (vectorStoreRecordDefinition == s_glossaryDefinition && typeof(TRecord) == typeof(GenericDataModel)) + { + var customCollection = new RedisJsonVectorStoreRecordCollection(database, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition, JsonNodeCustomMapper = new Mapper() }) as IVectorStoreRecordCollection; + return customCollection!; + } + + // Otherwise, just create a standard collection with the default mapper. + var collection = new RedisJsonVectorStoreRecordCollection(database, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + return collection!; + } + } + + /// + /// Sample generic data model class that can store any data. + /// + private sealed class GenericDataModel + { + public string Key { get; set; } + + public Dictionary Data { get; set; } + + public Dictionary Vectors { get; set; } + } + + /// + /// Create some sample glossary entries using the generic data model. + /// + /// A list of sample glossary entries. + private static IEnumerable CreateGlossaryEntries() + { + yield return new GenericDataModel + { + Key = "1", + Data = new() + { + { "Term", "API" }, + { "Definition", "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data." } + }, + Vectors = new() + }; + + yield return new GenericDataModel + { + Key = "2", + Data = new() + { + { "Term", "Connectors" }, + { "Definition", "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc." } + }, + Vectors = new() + }; + + yield return new GenericDataModel + { + Key = "3", + Data = new() + { + { "Term", "RAG" }, + { "Definition", "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt)." } + }, + Vectors = new() + }; + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_MultiStore.cs b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_MultiStore.cs new file mode 100644 index 000000000000..18f0e5b476ca --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_MultiStore.cs @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Memory.VectorStoreFixtures; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Embeddings; +using Qdrant.Client; +using StackExchange.Redis; + +namespace Memory; + +/// +/// An example showing how to ingest data into a vector store using , or . +/// Since Redis and Volatile supports string keys and Qdrant supports ulong or Guid keys, this example also shows how you can have common code +/// that works with both types of keys by using a generic key generator function. +/// +/// The example shows the following steps: +/// 1. Register a vector store and embedding generator with the DI container. +/// 2. Register a class (DataIngestor) with the DI container that uses the vector store and embedding generator to ingest data. +/// 3. Ingest some data into the vector store. +/// 4. Read the data back from the vector store. 
+/// +/// For some databases in this sample (Redis & Qdrant), you need a local instance of Docker running, since the associated fixtures will try and start containers in the local docker instance to run against. +/// +[Collection("Sequential")] +public class VectorStore_DataIngestion_MultiStore(ITestOutputHelper output, VectorStoreRedisContainerFixture redisFixture, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture, IClassFixture +{ + /// + /// Example with dependency injection. + /// + /// The type of database to run the example for. + [Theory] + [InlineData("Redis")] + [InlineData("Qdrant")] + [InlineData("Volatile")] + public async Task ExampleWithDIAsync(string databaseType) + { + // Use the kernel for DI purposes. + var kernelBuilder = Kernel + .CreateBuilder(); + + // Register an embedding generation service with the DI container. + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + endpoint: TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + apiKey: TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + + // Register the chosen vector store with the DI container and initialize docker containers via the fixtures where needed. + if (databaseType == "Redis") + { + await redisFixture.ManualInitializeAsync(); + kernelBuilder.AddRedisVectorStore("localhost:6379"); + } + else if (databaseType == "Qdrant") + { + await qdrantFixture.ManualInitializeAsync(); + kernelBuilder.AddQdrantVectorStore("localhost"); + } + else if (databaseType == "Volatile") + { + kernelBuilder.AddVolatileVectorStore(); + } + + // Register the DataIngestor with the DI container. + kernelBuilder.Services.AddTransient(); + + // Build the kernel. + var kernel = kernelBuilder.Build(); + + // Build a DataIngestor object using the DI container. + var dataIngestor = kernel.GetRequiredService(); + + // Invoke the data ingestor using an appropriate key generator function for each database type. + // Redis and Volatile supports string keys, while Qdrant supports ulong or Guid keys, so we use a different key generator for each key type. + if (databaseType == "Redis" || databaseType == "Volatile") + { + await this.UpsertDataAndReadFromVectorStoreAsync(dataIngestor, () => Guid.NewGuid().ToString()); + } + else if (databaseType == "Qdrant") + { + await this.UpsertDataAndReadFromVectorStoreAsync(dataIngestor, () => Guid.NewGuid()); + } + } + + /// + /// Example without dependency injection. + /// + /// The type of database to run the example for. + [Theory] + [InlineData("Redis")] + [InlineData("Qdrant")] + [InlineData("Volatile")] + public async Task ExampleWithoutDIAsync(string databaseType) + { + // Create an embedding generation service. + var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + + // Construct the chosen vector store and initialize docker containers via the fixtures where needed. 
+ IVectorStore vectorStore; + if (databaseType == "Redis") + { + await redisFixture.ManualInitializeAsync(); + var database = ConnectionMultiplexer.Connect("localhost:6379").GetDatabase(); + vectorStore = new RedisVectorStore(database); + } + else if (databaseType == "Qdrant") + { + await qdrantFixture.ManualInitializeAsync(); + var qdrantClient = new QdrantClient("localhost"); + vectorStore = new QdrantVectorStore(qdrantClient); + } + else if (databaseType == "Volatile") + { + vectorStore = new VolatileVectorStore(); + } + else + { + throw new ArgumentException("Invalid database type."); + } + + // Create the DataIngestor. + var dataIngestor = new DataIngestor(vectorStore, textEmbeddingGenerationService); + + // Invoke the data ingestor using an appropriate key generator function for each database type. + // Redis and Volatile supports string keys, while Qdrant supports ulong or Guid keys, so we use a different key generator for each key type. + if (databaseType == "Redis" || databaseType == "Volatile") + { + await this.UpsertDataAndReadFromVectorStoreAsync(dataIngestor, () => Guid.NewGuid().ToString()); + } + else if (databaseType == "Qdrant") + { + await this.UpsertDataAndReadFromVectorStoreAsync(dataIngestor, () => Guid.NewGuid()); + } + } + + private async Task UpsertDataAndReadFromVectorStoreAsync(DataIngestor dataIngestor, Func uniqueKeyGenerator) + where TKey : notnull + { + // Ingest some data into the vector store. + var upsertedKeys = await dataIngestor.ImportDataAsync(uniqueKeyGenerator); + + // Get one of the upserted records. + var upsertedRecord = await dataIngestor.GetGlossaryAsync(upsertedKeys.First()); + + // Write upserted keys and one of the upserted records to the console. + Console.WriteLine($"Upserted keys: {string.Join(", ", upsertedKeys)}"); + Console.WriteLine($"Upserted record: {JsonSerializer.Serialize(upsertedRecord)}"); + } + + /// + /// Sample class that does ingestion of sample data into a vector store and allows retrieval of data from the vector store. + /// + /// The vector store to ingest data into. + /// Used to generate embeddings for the data being ingested. + private sealed class DataIngestor(IVectorStore vectorStore, ITextEmbeddingGenerationService textEmbeddingGenerationService) + { + /// + /// Create some glossary entries and upsert them into the vector store. + /// + /// The keys of the upserted glossary entries. + /// The type of the keys in the vector store. + public async Task> ImportDataAsync(Func uniqueKeyGenerator) + where TKey : notnull + { + // Get and create collection if it doesn't exist. + var collection = vectorStore.GetCollection>("skglossary"); + await collection.CreateCollectionIfNotExistsAsync(); + + // Create glossary entries and generate embeddings for them. + var glossaryEntries = CreateGlossaryEntries(uniqueKeyGenerator).ToList(); + var tasks = glossaryEntries.Select(entry => Task.Run(async () => + { + entry.DefinitionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(entry.Definition); + })); + await Task.WhenAll(tasks); + + // Upsert the glossary entries into the collection and return their keys. + var upsertedKeys = glossaryEntries.Select(x => collection.UpsertAsync(x)); + return await Task.WhenAll(upsertedKeys); + } + + /// + /// Get a glossary entry from the vector store. + /// + /// The key of the glossary entry to retrieve. + /// The glossary entry. + /// The type of the keys in the vector store. 
+ public Task?> GetGlossaryAsync(TKey key) + where TKey : notnull + { + var collection = vectorStore.GetCollection>("skglossary"); + return collection.GetAsync(key, new() { IncludeVectors = true }); + } + } + + /// + /// Create some sample glossary entries. + /// + /// The type of the model key. + /// A function that can be used to generate unique keys for the model in the type that the model requires. + /// A list of sample glossary entries. + private static IEnumerable> CreateGlossaryEntries(Func uniqueKeyGenerator) + { + yield return new Glossary + { + Key = uniqueKeyGenerator(), + Term = "API", + Definition = "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data." + }; + + yield return new Glossary + { + Key = uniqueKeyGenerator(), + Term = "Connectors", + Definition = "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc." + }; + + yield return new Glossary + { + Key = uniqueKeyGenerator(), + Term = "RAG", + Definition = "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt)." + }; + } + + /// + /// Sample model class that represents a glossary entry. + /// + /// + /// Note that each property is decorated with an attribute that specifies how the property should be treated by the vector store. + /// This allows us to create a collection in the vector store and upsert and retrieve instances of this class without any further configuration. + /// + /// The type of the model key. + private sealed class Glossary + { + [VectorStoreRecordKey] + public TKey Key { get; set; } + + [VectorStoreRecordData] + public string Term { get; set; } + + [VectorStoreRecordData] + public string Definition { get; set; } + + [VectorStoreRecordVector(1536)] + public ReadOnlyMemory DefinitionEmbedding { get; set; } + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_Simple.cs b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_Simple.cs new file mode 100644 index 000000000000..341e5c2bbda2 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_Simple.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Memory.VectorStoreFixtures; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Microsoft.SemanticKernel.Data; +using Microsoft.SemanticKernel.Embeddings; +using Qdrant.Client; + +namespace Memory; + +/// +/// A simple example showing how to ingest data into a vector store using . +/// +/// The example shows the following steps: +/// 1. Create an embedding generator. +/// 2. Create a Qdrant Vector Store. +/// 3. Ingest some data into the vector store. +/// 4. Read the data back from the vector store. +/// +/// You need a local instance of Docker running, since the associated fixture will try and start a Qdrant container in the local docker instance to run against. +/// +[Collection("Sequential")] +public class VectorStore_DataIngestion_Simple(ITestOutputHelper output, VectorStoreQdrantContainerFixture qdrantFixture) : BaseTest(output), IClassFixture +{ + [Fact] + public async Task ExampleAsync() + { + // Create an embedding generation service. 
+ var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + TestConfiguration.AzureOpenAIEmbeddings.ApiKey); + + // Initiate the docker container and construct the vector store. + await qdrantFixture.ManualInitializeAsync(); + var vectorStore = new QdrantVectorStore(new QdrantClient("localhost")); + + // Get and create collection if it doesn't exist. + var collection = vectorStore.GetCollection("skglossary"); + await collection.CreateCollectionIfNotExistsAsync(); + + // Create glossary entries and generate embeddings for them. + var glossaryEntries = CreateGlossaryEntries().ToList(); + var tasks = glossaryEntries.Select(entry => Task.Run(async () => + { + entry.DefinitionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(entry.Definition); + })); + await Task.WhenAll(tasks); + + // Upsert the glossary entries into the collection and return their keys. + var upsertedKeysTasks = glossaryEntries.Select(x => collection.UpsertAsync(x)); + var upsertedKeys = await Task.WhenAll(upsertedKeysTasks); + + // Retrieve one of the upserted records from the collection. + var upsertedRecord = await collection.GetAsync(upsertedKeys.First(), new() { IncludeVectors = true }); + + // Write upserted keys and one of the upserted records to the console. + Console.WriteLine($"Upserted keys: {string.Join(", ", upsertedKeys)}"); + Console.WriteLine($"Upserted record: {JsonSerializer.Serialize(upsertedRecord)}"); + } + + /// + /// Sample model class that represents a glossary entry. + /// + /// + /// Note that each property is decorated with an attribute that specifies how the property should be treated by the vector store. + /// This allows us to create a collection in the vector store and upsert and retrieve instances of this class without any further configuration. + /// + private sealed class Glossary + { + [VectorStoreRecordKey] + public ulong Key { get; set; } + + [VectorStoreRecordData] + public string Term { get; set; } + + [VectorStoreRecordData] + public string Definition { get; set; } + + [VectorStoreRecordVector(1536)] + public ReadOnlyMemory DefinitionEmbedding { get; set; } + } + + /// + /// Create some sample glossary entries. + /// + /// A list of sample glossary entries. + private static IEnumerable CreateGlossaryEntries() + { + yield return new Glossary + { + Key = 1, + Term = "API", + Definition = "Application Programming Interface. A set of rules and specifications that allow software components to communicate and exchange data." + }; + + yield return new Glossary + { + Key = 2, + Term = "Connectors", + Definition = "Connectors allow you to integrate with various services provide AI capabilities, including LLM, AudioToText, TextToAudio, Embedding generation, etc." + }; + + yield return new Glossary + { + Key = 3, + Term = "RAG", + Definition = "Retrieval Augmented Generation - a term that refers to the process of retrieving additional data to provide as context to an LLM to use when generating a response (completion) to a user’s question (prompt)." 
+ }; + } +} diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 77427c605193..26eef28982a7 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -104,6 +104,9 @@ Down below you can find the code snippets that demonstrate the usage of many Sem - [TextMemoryPlugin_GeminiEmbeddingGeneration](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextMemoryPlugin_GeminiEmbeddingGeneration.cs) - [TextMemoryPlugin_MultipleMemoryStore](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) - [TextMemoryPlugin_RecallJsonSerializationWithOptions](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/TextMemoryPlugin_RecallJsonSerializationWithOptions.cs) +- [VectorStore_DataIngestion_Simple: A simple example of how to do data ingestion into a vector store when getting started.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_Simple.cs) +- [VectorStore_DataIngestion_MultiStore: An example of data ingestion that uses the same code to ingest into multiple vector stores types.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_MultiStore.cs) +- [VectorStore_DataIngestion_CustomMapper: An example that shows how to use a custom mapper for when your data model and storage model doesn't match.](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Memory/VectorStore_DataIngestion_CustomMapper.cs) ## Optimization - Examples of different cost and performance optimization techniques diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..740c3898ce03 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchKernelBuilderExtensionsTests.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure; +using Azure.Core; +using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Microsoft.SemanticKernel.Data; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; + +/// +/// Tests for the class. +/// +public class AzureAISearchKernelBuilderExtensionsTests +{ + private readonly IKernelBuilder _kernelBuilder; + + public AzureAISearchKernelBuilderExtensionsTests() + { + this._kernelBuilder = Kernel.CreateBuilder(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange. + this._kernelBuilder.Services.AddSingleton(Mock.Of()); + + // Act. + this._kernelBuilder.AddAzureAISearchVectorStore(); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithUriAndCredsRegistersClass() + { + // Act. + this._kernelBuilder.AddAzureAISearchVectorStore(new Uri("https://localhost"), new AzureKeyCredential("fakeKey")); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithUriAndTokenCredsRegistersClass() + { + // Act. + this._kernelBuilder.AddAzureAISearchVectorStore(new Uri("https://localhost"), Mock.Of()); + + // Assert. 
+ this.AssertVectorStoreCreated(); + } + + private void AssertVectorStoreCreated() + { + var kernel = this._kernelBuilder.Build(); + var vectorStore = kernel.Services.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..f2446ea7a809 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchServiceCollectionExtensionsTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure; +using Azure.Core; +using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Microsoft.SemanticKernel.Data; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; + +/// +/// Tests for the class. +/// +public class AzureAISearchServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection; + + public AzureAISearchServiceCollectionExtensionsTests() + { + this._serviceCollection = new ServiceCollection(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange. + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act. + this._serviceCollection.AddAzureAISearchVectorStore(); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithUriAndCredsRegistersClass() + { + // Act. + this._serviceCollection.AddAzureAISearchVectorStore(new Uri("https://localhost"), new AzureKeyCredential("fakeKey")); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithUriAndTokenCredsRegistersClass() + { + // Act. + this._serviceCollection.AddAzureAISearchVectorStore(new Uri("https://localhost"), Mock.Of()); + + // Assert. + this.AssertVectorStoreCreated(); + } + + private void AssertVectorStoreCreated() + { + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs new file mode 100644 index 000000000000..075880775324 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreCollectionCreateMappingTests.cs @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Azure.Search.Documents.Indexes.Models; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; + +/// +/// Contains tests for the class. 
+/// +public class AzureAISearchVectorStoreCollectionCreateMappingTests +{ + [Fact] + public void MapKeyFieldCreatesSearchableField() + { + // Arrange + var keyProperty = new VectorStoreRecordKeyProperty("testkey", typeof(string)); + var storagePropertyName = "test_key"; + + // Act + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapKeyField(keyProperty, storagePropertyName); + + // Assert + Assert.NotNull(result); + Assert.Equal(storagePropertyName, result.Name); + Assert.True(result.IsKey); + Assert.True(result.IsFilterable); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFilterableStringDataFieldCreatesSimpleField(bool isFilterable) + { + // Arrange + var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(string)) { IsFilterable = isFilterable }; + var storagePropertyName = "test_data"; + + // Act + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + Assert.Equal(storagePropertyName, result.Name); + Assert.False(result.IsKey); + Assert.Equal(isFilterable, result.IsFilterable); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFullTextSearchableStringDataFieldCreatesSearchableField(bool isFilterable) + { + // Arrange + var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(string)) { IsFilterable = isFilterable, IsFullTextSearchable = true }; + var storagePropertyName = "test_data"; + + // Act + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + Assert.Equal(storagePropertyName, result.Name); + Assert.False(result.IsKey); + Assert.Equal(isFilterable, result.IsFilterable); + } + + [Fact] + public void MapFullTextSearchableStringDataFieldThrowsForInvalidType() + { + // Arrange + var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(int)) { IsFullTextSearchable = true }; + var storagePropertyName = "test_data"; + + // Act & Assert + Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapDataFieldCreatesSimpleField(bool isFilterable) + { + // Arrange + var dataProperty = new VectorStoreRecordDataProperty("testdata", typeof(int)) { IsFilterable = isFilterable }; + var storagePropertyName = "test_data"; + + // Act + var result = AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, storagePropertyName); + + // Assert + Assert.NotNull(result); + Assert.IsType(result); + Assert.Equal(storagePropertyName, result.Name); + Assert.Equal(SearchFieldDataType.Int32, result.Type); + Assert.False(result.IsKey); + Assert.Equal(isFilterable, result.IsFilterable); + } + + [Fact] + public void MapVectorFieldCreatesVectorSearchField() + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10, IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.DotProductSimilarity }; + var storagePropertyName = "test_vector"; + + // Act + var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName); + + // Assert + Assert.NotNull(vectorSearchField); + Assert.NotNull(algorithmConfiguration); + 
Assert.NotNull(vectorSearchProfile); + Assert.Equal(storagePropertyName, vectorSearchField.Name); + Assert.Equal(vectorProperty.Dimensions, vectorSearchField.VectorSearchDimensions); + + Assert.Equal("test_vectorAlgoConfig", algorithmConfiguration.Name); + Assert.IsType(algorithmConfiguration); + var flatConfig = algorithmConfiguration as ExhaustiveKnnAlgorithmConfiguration; + Assert.Equal(VectorSearchAlgorithmMetric.DotProduct, flatConfig!.Parameters.Metric); + + Assert.Equal("test_vectorProfile", vectorSearchProfile.Name); + Assert.Equal("test_vectorAlgoConfig", vectorSearchProfile.AlgorithmConfigurationName); + } + + [Theory] + [InlineData(IndexKind.Hnsw, typeof(HnswAlgorithmConfiguration))] + [InlineData(IndexKind.Flat, typeof(ExhaustiveKnnAlgorithmConfiguration))] + public void MapVectorFieldCreatesExpectedAlgoConfigTypes(string indexKind, Type algoConfigType) + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10, IndexKind = indexKind, DistanceFunction = DistanceFunction.DotProductSimilarity }; + var storagePropertyName = "test_vector"; + + // Act + var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName); + + // Assert + Assert.Equal("test_vectorAlgoConfig", algorithmConfiguration.Name); + Assert.Equal(algoConfigType, algorithmConfiguration.GetType()); + } + + [Fact] + public void MapVectorFieldDefaultsToHsnwAndCosine() + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10 }; + var storagePropertyName = "test_vector"; + + // Act + var (vectorSearchField, algorithmConfiguration, vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName); + + // Assert + Assert.IsType(algorithmConfiguration); + var hnswConfig = algorithmConfiguration as HnswAlgorithmConfiguration; + Assert.Equal(VectorSearchAlgorithmMetric.Cosine, hnswConfig!.Parameters.Metric); + } + + [Fact] + public void MapVectorFieldThrowsForUnsupportedDistanceFunction() + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 10, DistanceFunction = DistanceFunction.ManhattanDistance }; + var storagePropertyName = "test_vector"; + + // Act & Assert + Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName)); + } + + [Fact] + public void MapVectorFieldThrowsForMissingDimensionsCount() + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)); + var storagePropertyName = "test_vector"; + + // Act & Assert + Assert.Throws(() => AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField(vectorProperty, storagePropertyName)); + } + + [Theory] + [MemberData(nameof(DataTypeMappingOptions))] + public void GetSDKFieldDataTypeMapsTypesCorrectly(Type propertyType, SearchFieldDataType searchFieldDataType) + { + // Act & Assert + Assert.Equal(searchFieldDataType, AzureAISearchVectorStoreCollectionCreateMapping.GetSDKFieldDataType(propertyType)); + } + + public static IEnumerable DataTypeMappingOptions() + { + yield return new object[] { typeof(string), SearchFieldDataType.String }; + yield return new object[] { typeof(bool), SearchFieldDataType.Boolean }; + yield return new object[] { typeof(int), SearchFieldDataType.Int32 
}; + yield return new object[] { typeof(long), SearchFieldDataType.Int64 }; + yield return new object[] { typeof(float), SearchFieldDataType.Double }; + yield return new object[] { typeof(double), SearchFieldDataType.Double }; + yield return new object[] { typeof(DateTime), SearchFieldDataType.DateTimeOffset }; + yield return new object[] { typeof(DateTimeOffset), SearchFieldDataType.DateTimeOffset }; + + yield return new object[] { typeof(string[]), SearchFieldDataType.Collection(SearchFieldDataType.String) }; + yield return new object[] { typeof(IEnumerable), SearchFieldDataType.Collection(SearchFieldDataType.String) }; + yield return new object[] { typeof(List), SearchFieldDataType.Collection(SearchFieldDataType.String) }; + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..c303613248f0 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -0,0 +1,618 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.Search.Documents; +using Azure.Search.Documents.Indexes; +using Azure.Search.Documents.Indexes.Models; +using Azure.Search.Documents.Models; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Microsoft.SemanticKernel.Data; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; + +/// +/// Contains tests for the class. +/// +public class AzureAISearchVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + private const string TestRecordKey1 = "testid1"; + private const string TestRecordKey2 = "testid2"; + + private readonly Mock _searchIndexClientMock; + private readonly Mock _searchClientMock; + + private readonly CancellationToken _testCancellationToken = new(false); + + public AzureAISearchVectorStoreRecordCollectionTests() + { + this._searchClientMock = new Mock(MockBehavior.Strict); + this._searchIndexClientMock = new Mock(MockBehavior.Strict); + this._searchIndexClientMock.Setup(x => x.GetSearchClient(TestCollectionName)).Returns(this._searchClientMock.Object); + } + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + this._searchIndexClientMock.Setup(x => x.GetSearchClient(collectionName)).Returns(this._searchClientMock.Object); + + // Arrange. + if (expectedExists) + { + this._searchIndexClientMock + .Setup(x => x.GetIndexAsync(collectionName, this._testCancellationToken)) + .Returns(Task.FromResult?>(null)); + } + else + { + this._searchIndexClientMock + .Setup(x => x.GetIndexAsync(collectionName, this._testCancellationToken)) + .ThrowsAsync(new RequestFailedException(404, "Index not found")); + } + + var sut = new AzureAISearchVectorStoreRecordCollection(this._searchIndexClientMock.Object, collectionName); + + // Act. + var actual = await sut.CollectionExistsAsync(this._testCancellationToken); + + // Assert. 
+ Assert.Equal(expectedExists, actual); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CreateCollectionCallsSDKAsync(bool useDefinition, bool useCustomJsonSerializerOptions) + { + // Arrange. + this._searchIndexClientMock + .Setup(x => x.CreateIndexAsync(It.IsAny(), this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(new SearchIndex(TestCollectionName), Mock.Of())); + + var sut = this.CreateRecordCollection(useDefinition, useCustomJsonSerializerOptions); + + // Act. + await sut.CreateCollectionAsync(); + + // Assert. + var expectedFieldNames = useCustomJsonSerializerOptions ? new[] { "key", "storage_data1", "data2", "storage_vector1", "vector2" } : new[] { "Key", "storage_data1", "Data2", "storage_vector1", "Vector2" }; + this._searchIndexClientMock + .Verify( + x => x.CreateIndexAsync( + It.Is(si => si.Fields.Count == 5 && si.Fields.Select(f => f.Name).SequenceEqual(expectedFieldNames) && si.Name == TestCollectionName && si.VectorSearch.Profiles.Count == 2 && si.VectorSearch.Algorithms.Count == 2), + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CreateCollectionIfNotExistsSDKAsync(bool useDefinition, bool expectedExists) + { + // Arrange. + if (expectedExists) + { + this._searchIndexClientMock + .Setup(x => x.GetIndexAsync(TestCollectionName, this._testCancellationToken)) + .Returns(Task.FromResult?>(null)); + } + else + { + this._searchIndexClientMock + .Setup(x => x.GetIndexAsync(TestCollectionName, this._testCancellationToken)) + .ThrowsAsync(new RequestFailedException(404, "Index not found")); + } + + this._searchIndexClientMock + .Setup(x => x.CreateIndexAsync(It.IsAny(), this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(new SearchIndex(TestCollectionName), Mock.Of())); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act. + await sut.CreateCollectionIfNotExistsAsync(); + + // Assert. + if (expectedExists) + { + this._searchIndexClientMock + .Verify( + x => x.CreateIndexAsync( + It.IsAny(), + this._testCancellationToken), + Times.Never); + } + else + { + this._searchIndexClientMock + .Verify( + x => x.CreateIndexAsync( + It.Is(si => si.Fields.Count == 5 && si.Name == TestCollectionName && si.VectorSearch.Profiles.Count == 2 && si.VectorSearch.Algorithms.Count == 2), + this._testCancellationToken), + Times.Once); + } + } + + [Fact] + public async Task CanDeleteCollectionAsync() + { + // Arrange. + this._searchIndexClientMock + .Setup(x => x.DeleteIndexAsync(TestCollectionName, this._testCancellationToken)) + .Returns(Task.FromResult(null)); + + var sut = this.CreateRecordCollection(false); + + // Act. + await sut.DeleteCollectionAsync(this._testCancellationToken); + + // Assert. + this._searchIndexClientMock.Verify(x => x.DeleteIndexAsync(TestCollectionName, this._testCancellationToken), Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanGetRecordWithVectorsAsync(bool useDefinition) + { + // Arrange. + this._searchClientMock.Setup( + x => x.GetDocumentAsync( + TestRecordKey1, + It.Is(x => !x.SelectedFields.Any()), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(CreateModel(TestRecordKey1, true), Mock.Of())); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act. 
+ var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = true }, + this._testCancellationToken); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector2!.Value.ToArray()); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CanGetRecordWithoutVectorsAsync(bool useDefinition, bool useCustomJsonSerializerOptions) + { + // Arrange. + var storageObject = JsonSerializer.SerializeToNode(CreateModel(TestRecordKey1, false))!.AsObject(); + + var expectedSelectFields = useCustomJsonSerializerOptions ? new[] { "storage_data1", "data2", "key" } : new[] { "storage_data1", "Data2", "Key" }; + this._searchClientMock.Setup( + x => x.GetDocumentAsync( + TestRecordKey1, + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(CreateModel(TestRecordKey1, true), Mock.Of())); + + var sut = this.CreateRecordCollection(useDefinition, useCustomJsonSerializerOptions); + + // Act. + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = false }, + this._testCancellationToken); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + + this._searchClientMock.Verify( + x => x.GetDocumentAsync( + TestRecordKey1, + It.Is(x => x.SelectedFields.SequenceEqual(expectedSelectFields)), + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) + { + // Arrange. + this._searchClientMock.Setup( + x => x.GetDocumentAsync( + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync((string id, GetDocumentOptions options, CancellationToken cancellationToken) => + { + return Response.FromValue(CreateModel(id, true), Mock.Of()); + }); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act. + var actual = await sut.GetBatchAsync( + [TestRecordKey1, TestRecordKey2], + new() { IncludeVectors = true }, + this._testCancellationToken).ToListAsync(); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(TestRecordKey1, actual[0].Key); + Assert.Equal(TestRecordKey2, actual[1].Key); + } + + [Fact] + public async Task CanGetRecordWithCustomMapperAsync() + { + // Arrange. + var storageObject = JsonSerializer.SerializeToNode(CreateModel(TestRecordKey1, true))!.AsObject(); + + // Arrange GetDocumentAsync mock returning JsonObject. + this._searchClientMock.Setup( + x => x.GetDocumentAsync( + TestRecordKey1, + It.Is(x => !x.SelectedFields.Any()), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(storageObject, Mock.Of())); + + // Arrange mapper mock from JsonObject to data model. + var mapperMock = new Mock>(MockBehavior.Strict); + mapperMock.Setup( + x => x.MapFromStorageToDataModel( + storageObject, + It.Is(x => x.IncludeVectors))) + .Returns(CreateModel(TestRecordKey1, true)); + + // Arrange target with custom mapper. + var sut = new AzureAISearchVectorStoreRecordCollection( + this._searchIndexClientMock.Object, + TestCollectionName, + new() + { + JsonObjectCustomMapper = mapperMock.Object + }); + + // Act. 
+ var actual = await sut.GetAsync(TestRecordKey1, new() { IncludeVectors = true }, this._testCancellationToken); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector2!.Value.ToArray()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanDeleteRecordAsync(bool useDefinition) + { + // Arrange. +#pragma warning disable Moq1002 // Moq: No matching constructor + var indexDocumentsResultMock = new Mock(MockBehavior.Strict, new List()); +#pragma warning restore Moq1002 // Moq: No matching constructor + + this._searchClientMock.Setup( + x => x.DeleteDocumentsAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(indexDocumentsResultMock.Object, Mock.Of())); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act. + await sut.DeleteAsync( + TestRecordKey1, + cancellationToken: this._testCancellationToken); + + // Assert. + this._searchClientMock.Verify( + x => x.DeleteDocumentsAsync( + "Key", + It.Is>(x => x.Count() == 1 && x.Contains(TestRecordKey1)), + It.IsAny(), + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition) + { + // Arrange. +#pragma warning disable Moq1002 // Moq: No matching constructor + var indexDocumentsResultMock = new Mock(MockBehavior.Strict, new List()); +#pragma warning restore Moq1002 // Moq: No matching constructor + + this._searchClientMock.Setup( + x => x.DeleteDocumentsAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(indexDocumentsResultMock.Object, Mock.Of())); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act. + await sut.DeleteBatchAsync( + [TestRecordKey1, TestRecordKey2], + cancellationToken: this._testCancellationToken); + + // Assert. + this._searchClientMock.Verify( + x => x.DeleteDocumentsAsync( + "Key", + It.Is>(x => x.Count() == 2 && x.Contains(TestRecordKey1) && x.Contains(TestRecordKey2)), + It.IsAny(), + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanUpsertRecordAsync(bool useDefinition) + { + // Arrange upload result object. +#pragma warning disable Moq1002 // Moq: No matching constructor + var indexingResult = new Mock(MockBehavior.Strict, TestRecordKey1, true, 200); + var indexingResults = new List(); + indexingResults.Add(indexingResult.Object); + var indexDocumentsResultMock = new Mock(MockBehavior.Strict, indexingResults); +#pragma warning restore Moq1002 // Moq: No matching constructor + + // Arrange upload. + this._searchClientMock.Setup( + x => x.UploadDocumentsAsync( + It.IsAny>(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(indexDocumentsResultMock.Object, Mock.Of())); + + // Arrange sut. + var sut = this.CreateRecordCollection(useDefinition); + + var model = CreateModel(TestRecordKey1, true); + + // Act. + var actual = await sut.UpsertAsync( + model, + cancellationToken: this._testCancellationToken); + + // Assert. 
+ Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual); + this._searchClientMock.Verify( + x => x.UploadDocumentsAsync( + It.Is>(x => x.Count() == 1 && x.First().Key == TestRecordKey1), + It.Is(x => x.ThrowOnAnyError == true), + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanUpsertManyRecordsAsync(bool useDefinition) + { + // Arrange upload result object. +#pragma warning disable Moq1002 // Moq: No matching constructor + var indexingResult1 = new Mock(MockBehavior.Strict, TestRecordKey1, true, 200); + var indexingResult2 = new Mock(MockBehavior.Strict, TestRecordKey2, true, 200); + + var indexingResults = new List(); + indexingResults.Add(indexingResult1.Object); + indexingResults.Add(indexingResult2.Object); + var indexDocumentsResultMock = new Mock(MockBehavior.Strict, indexingResults); +#pragma warning restore Moq1002 // Moq: No matching constructor + + // Arrange upload. + this._searchClientMock.Setup( + x => x.UploadDocumentsAsync( + It.IsAny>(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(Response.FromValue(indexDocumentsResultMock.Object, Mock.Of())); + + // Arrange sut. + var sut = this.CreateRecordCollection(useDefinition); + + var model1 = CreateModel(TestRecordKey1, true); + var model2 = CreateModel(TestRecordKey2, true); + + // Act. + var actual = await sut.UpsertBatchAsync( + [model1, model2], + cancellationToken: this._testCancellationToken).ToListAsync(); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(TestRecordKey1, actual[0]); + Assert.Equal(TestRecordKey2, actual[1]); + + this._searchClientMock.Verify( + x => x.UploadDocumentsAsync( + It.Is>(x => x.Count() == 2 && x.First().Key == TestRecordKey1 && x.ElementAt(1).Key == TestRecordKey2), + It.Is(x => x.ThrowOnAnyError == true), + this._testCancellationToken), + Times.Once); + } + + [Fact] + public async Task CanUpsertRecordWithCustomMapperAsync() + { + // Arrange. +#pragma warning disable Moq1002 // Moq: No matching constructor + var indexingResult = new Mock(MockBehavior.Strict, TestRecordKey1, true, 200); + var indexingResults = new List(); + indexingResults.Add(indexingResult.Object); + var indexDocumentsResultMock = new Mock(MockBehavior.Strict, indexingResults); +#pragma warning restore Moq1002 // Moq: No matching constructor + + var model = CreateModel(TestRecordKey1, true); + var storageObject = JsonSerializer.SerializeToNode(model)!.AsObject(); + + // Arrange UploadDocumentsAsync mock returning upsert result. + this._searchClientMock.Setup( + x => x.UploadDocumentsAsync( + It.IsAny>(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync((IEnumerable documents, IndexDocumentsOptions options, CancellationToken cancellationToken) => + { + // Need to force a materialization of the documents enumerable here, otherwise the mapper (and therefore its mock) doesn't get invoked. + var materializedDocuments = documents.ToList(); + return Response.FromValue(indexDocumentsResultMock.Object, Mock.Of()); + }); + + // Arrange mapper mock from data model to JsonObject. + var mapperMock = new Mock>(MockBehavior.Strict); + mapperMock + .Setup(x => x.MapFromDataToStorageModel(It.IsAny())) + .Returns(storageObject); + + // Arrange target with custom mapper. + var sut = new AzureAISearchVectorStoreRecordCollection( + this._searchIndexClientMock.Object, + TestCollectionName, + new() + { + JsonObjectCustomMapper = mapperMock.Object + }); + + // Act. 
+ await sut.UpsertAsync( + model, + null, + this._testCancellationToken); + + // Assert. + mapperMock + .Verify( + x => x.MapFromDataToStorageModel(It.Is(x => x.Key == TestRecordKey1)), + Times.Once); + } + + /// + /// Tests that the collection can be created even if the definition and the type do not match. + /// In this case, the expectation is that a custom mapper will be provided to map between the + /// schema as defined by the definition and the different data model. + /// + [Fact] + public void CanCreateCollectionWithMismatchedDefinitionAndType() + { + // Arrange. + var definition = new VectorStoreRecordDefinition() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + } + }; + + // Act. + var sut = new AzureAISearchVectorStoreRecordCollection( + this._searchIndexClientMock.Object, + TestCollectionName, + new() { VectorStoreRecordDefinition = definition, JsonObjectCustomMapper = Mock.Of>() }); + } + + private AzureAISearchVectorStoreRecordCollection CreateRecordCollection(bool useDefinition, bool useCustomJsonSerializerOptions = false) + { + return new AzureAISearchVectorStoreRecordCollection( + this._searchIndexClientMock.Object, + TestCollectionName, + new() + { + VectorStoreRecordDefinition = useDefinition ? this._multiPropsDefinition : null, + JsonSerializerOptions = useCustomJsonSerializerOptions ? this._customJsonSerializerOptions : null + }); + } + + private static MultiPropsModel CreateModel(string key, bool withVectors) + { + return new MultiPropsModel + { + Key = key, + Data1 = "data 1", + Data2 = "data 2", + Vector1 = withVectors ? new float[] { 1, 2, 3, 4 } : null, + Vector2 = withVectors ? new float[] { 1, 2, 3, 4 } : null, + NotAnnotated = null, + }; + } + + private readonly JsonSerializerOptions _customJsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + + private readonly VectorStoreRecordDefinition _multiPropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)), + new VectorStoreRecordDataProperty("Data2", typeof(string)), + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { Dimensions = 4 } + ] + }; + + public sealed class MultiPropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [JsonPropertyName("storage_data1")] + [VectorStoreRecordData] + public string Data1 { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string Data2 { get; set; } = string.Empty; + + [JsonPropertyName("storage_vector1")] + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector1 { get; set; } + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector2 { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs new file mode 100644 index 000000000000..889b486da2ad --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreTests.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.Search.Documents; +using Azure.Search.Documents.Indexes; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Microsoft.SemanticKernel.Data; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; + +/// +/// Contains tests for the class. +/// +public class AzureAISearchVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _searchIndexClientMock; + private readonly Mock _searchClientMock; + + private readonly CancellationToken _testCancellationToken = new(false); + + public AzureAISearchVectorStoreTests() + { + this._searchClientMock = new Mock(MockBehavior.Strict); + this._searchIndexClientMock = new Mock(MockBehavior.Strict); + this._searchIndexClientMock.Setup(x => x.GetSearchClient(TestCollectionName)).Returns(this._searchClientMock.Object); + } + + [Fact] + public void GetCollectionReturnsCollection() + { + // Arrange. + var sut = new AzureAISearchVectorStore(this._searchIndexClientMock.Object); + + // Act. + var actual = sut.GetCollection(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>(actual); + } + + [Fact] + public void GetCollectionCallsFactoryIfProvided() + { + // Arrange. + var factoryMock = new Mock(MockBehavior.Strict); + var collectionMock = new Mock>(MockBehavior.Strict); + factoryMock + .Setup(x => x.CreateVectorStoreRecordCollection(this._searchIndexClientMock.Object, TestCollectionName, null)) + .Returns(collectionMock.Object); + var sut = new AzureAISearchVectorStore(this._searchIndexClientMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); + + // Act. + var actual = sut.GetCollection(TestCollectionName); + + // Assert. + Assert.Equal(collectionMock.Object, actual); + } + + [Fact] + public void GetCollectionThrowsForInvalidKeyType() + { + // Arrange. + var sut = new AzureAISearchVectorStore(this._searchIndexClientMock.Object); + + // Act & Assert. + Assert.Throws(() => sut.GetCollection(TestCollectionName)); + } + + [Fact] + public async Task ListCollectionNamesCallsSDKAsync() + { + // Arrange async enumerator mock. + var iterationCounter = 0; + var asyncEnumeratorMock = new Mock>(MockBehavior.Strict); + asyncEnumeratorMock.Setup(x => x.MoveNextAsync()).Returns(() => ValueTask.FromResult(iterationCounter++ <= 4)); + asyncEnumeratorMock.Setup(x => x.Current).Returns(() => $"testcollection{iterationCounter}"); + + // Arrange pageable mock. + var pageableMock = new Mock>(MockBehavior.Strict); + pageableMock.Setup(x => x.GetAsyncEnumerator(this._testCancellationToken)).Returns(asyncEnumeratorMock.Object); + + // Arrange search index client mock and sut. + this._searchIndexClientMock + .Setup(x => x.GetIndexNamesAsync(this._testCancellationToken)) + .Returns(pageableMock.Object); + var sut = new AzureAISearchVectorStore(this._searchIndexClientMock.Object); + + // Act. + var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert. 
+ Assert.NotNull(actual); + var actualList = await actual.ToListAsync(); + Assert.Equal(5, actualList.Count); + Assert.All(actualList, (value, index) => Assert.Equal($"testcollection{index + 1}", value)); + } + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs new file mode 100644 index 000000000000..16d48e60a66d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchKernelBuilderExtensions.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure; +using Azure.Core; +using Azure.Search.Documents.Indexes; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Extension methods to register Azure AI Search instances on the . +/// +public static class AzureAISearchKernelBuilderExtensions +{ + /// + /// Register an Azure AI Search with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The builder to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddAzureAISearchVectorStore(this IKernelBuilder builder, AzureAISearchVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddAzureAISearchVectorStore(options, serviceId); + return builder; + } + + /// + /// Register an Azure AI Search with the provided and and the specified service ID. + /// + /// The builder to register the on. + /// The service endpoint for Azure AI Search. + /// The credential to authenticate to Azure AI Search with. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddAzureAISearchVectorStore(this IKernelBuilder builder, Uri endpoint, TokenCredential tokenCredential, AzureAISearchVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddAzureAISearchVectorStore(endpoint, tokenCredential, options, serviceId); + return builder; + } + + /// + /// Register an Azure AI Search with the provided and and the specified service ID. + /// + /// The builder to register the on. + /// The service endpoint for Azure AI Search. + /// The credential to authenticate to Azure AI Search with. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddAzureAISearchVectorStore(this IKernelBuilder builder, Uri endpoint, AzureKeyCredential credential, AzureAISearchVectorStoreOptions? options = default, string? 
serviceId = default) + { + builder.Services.AddAzureAISearchVectorStore(endpoint, credential, options, serviceId); + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs new file mode 100644 index 000000000000..3c55d6ade628 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchServiceCollectionExtensions.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure; +using Azure.Core; +using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Extension methods to register Azure AI Search instances on an . +/// +public static class AzureAISearchServiceCollectionExtensions +{ + /// + /// Register an Azure AI Search with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddAzureAISearchVectorStore(this IServiceCollection services, AzureAISearchVectorStoreOptions? options = default, string? serviceId = default) + { + // If we are not constructing the SearchIndexClient, add the IVectorStore as transient, since we + // cannot make assumptions about how SearchIndexClient is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var searchIndexClient = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new AzureAISearchVectorStore( + searchIndexClient, + selectedOptions); + }); + + return services; + } + + /// + /// Register an Azure AI Search with the provided and and the specified service ID. + /// + /// The to register the on. + /// The service endpoint for Azure AI Search. + /// The credential to authenticate to Azure AI Search with. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddAzureAISearchVectorStore(this IServiceCollection services, Uri endpoint, TokenCredential tokenCredential, AzureAISearchVectorStoreOptions? options = default, string? serviceId = default) + { + Verify.NotNull(endpoint); + Verify.NotNull(tokenCredential); + + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var searchIndexClient = new SearchIndexClient(endpoint, tokenCredential); + var selectedOptions = options ?? sp.GetService(); + + return new AzureAISearchVectorStore( + searchIndexClient, + selectedOptions); + }); + + return services; + } + + /// + /// Register an Azure AI Search with the provided and and the specified service ID. + /// + /// The to register the on. + /// The service endpoint for Azure AI Search. + /// The credential to authenticate to Azure AI Search with. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddAzureAISearchVectorStore(this IServiceCollection services, Uri endpoint, AzureKeyCredential credential, AzureAISearchVectorStoreOptions? options = default, string? 
serviceId = default) + { + Verify.NotNull(endpoint); + Verify.NotNull(credential); + + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var searchIndexClient = new SearchIndexClient(endpoint, credential); + var selectedOptions = options ?? sp.GetService(); + + return new AzureAISearchVectorStore( + searchIndexClient, + selectedOptions); + }); + + return services; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs new file mode 100644 index 000000000000..2ca2bf9577f5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStore.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.Search.Documents.Indexes; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Class for accessing the list of collections in a Azure AI Search vector store. +/// +/// +/// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// +public sealed class AzureAISearchVectorStore : IVectorStore +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "AzureAISearch"; + + /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. + private readonly SearchIndexClient _searchIndexClient; + + /// Optional configuration options for this class. + private readonly AzureAISearchVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. + /// Optional configuration options for this class. + public AzureAISearchVectorStore(SearchIndexClient searchIndexClient, AzureAISearchVectorStoreOptions? options = default) + { + Verify.NotNull(searchIndexClient); + + this._searchIndexClient = searchIndexClient; + this._options = options ?? new AzureAISearchVectorStoreOptions(); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? 
vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : class + { + if (typeof(TKey) != typeof(string)) + { + throw new NotSupportedException("Only string keys are supported."); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._searchIndexClient, name, vectorStoreRecordDefinition); + } + + var directlyCreatedStore = new AzureAISearchVectorStoreRecordCollection(this._searchIndexClient, name, new AzureAISearchVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + return directlyCreatedStore!; + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var indexNamesEnumerable = this._searchIndexClient.GetIndexNamesAsync(cancellationToken).ConfigureAwait(false); + var indexNamesEnumerator = indexNamesEnumerable.GetAsyncEnumerator(); + + var nextResult = await GetNextIndexNameAsync(indexNamesEnumerator).ConfigureAwait(false); + while (nextResult.more) + { + yield return nextResult.name; + nextResult = await GetNextIndexNameAsync(indexNamesEnumerator).ConfigureAwait(false); + } + } + + /// + /// Helper method to get the next index name from the enumerator with a try catch around the move next call to convert + /// any to , since try catch is not supported + /// around a yield return. + /// + /// The enumerator to get the next result from. + /// A value indicating whether there are more results and the current string if true. + private static async Task<(string name, bool more)> GetNextIndexNameAsync(ConfiguredCancelableAsyncEnumerable.Enumerator enumerator) + { + const string OperationName = "GetIndexNames"; + + try + { + var more = await enumerator.MoveNextAsync(); + return (enumerator.Current, more); + } + catch (AggregateException ex) when (ex.InnerException is RequestFailedException innerEx) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + OperationName = OperationName + }; + } + catch (RequestFailedException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + OperationName = OperationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs new file mode 100644 index 000000000000..2ee086d69d53 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionCreateMapping.cs @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Azure.Search.Documents.Indexes.Models; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Contains mapping helpers to use when creating a Azure AI Search vector collection. +/// +internal static class AzureAISearchVectorStoreCollectionCreateMapping +{ + /// + /// Map from a to an Azure AI Search . + /// + /// The key property definition. + /// The name of the property in storage. + /// The for the provided property definition. 
+    public static SearchableField MapKeyField(VectorStoreRecordKeyProperty keyProperty, string storagePropertyName)
+    {
+        return new SearchableField(storagePropertyName) { IsKey = true, IsFilterable = true };
+    }
+
+    /// <summary>
+    /// Map from a <see cref="VectorStoreRecordDataProperty"/> to an Azure AI Search <see cref="SimpleField"/>.
+    /// </summary>
+    /// <param name="dataProperty">The data property definition.</param>
+    /// <param name="storagePropertyName">The name of the property in storage.</param>
+    /// <returns>The <see cref="SimpleField"/> for the provided property definition.</returns>
+    /// <exception cref="InvalidOperationException">Throws when the definition is missing required information.</exception>
+    public static SimpleField MapDataField(VectorStoreRecordDataProperty dataProperty, string storagePropertyName)
+    {
+        if (dataProperty.IsFullTextSearchable)
+        {
+            if (dataProperty.PropertyType != typeof(string))
+            {
+                throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextSearchable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not a string. The Azure AI Search VectorStore supports {nameof(dataProperty.IsFullTextSearchable)} on string properties only.");
+            }
+
+            return new SearchableField(storagePropertyName) { IsFilterable = dataProperty.IsFilterable };
+        }
+
+        return new SimpleField(storagePropertyName, AzureAISearchVectorStoreCollectionCreateMapping.GetSDKFieldDataType(dataProperty.PropertyType)) { IsFilterable = dataProperty.IsFilterable };
+    }
+
+    /// <summary>
+    /// Map from a <see cref="VectorStoreRecordVectorProperty"/> to an Azure AI Search <see cref="VectorSearchField"/> and generate the required index configuration.
+    /// </summary>
+    /// <param name="vectorProperty">The vector property definition.</param>
+    /// <param name="storagePropertyName">The name of the property in storage.</param>
+    /// <returns>The <see cref="VectorSearchField"/> and required index configuration.</returns>
+    /// <exception cref="InvalidOperationException">Throws when the definition is missing required information, or unsupported options are configured.</exception>
+    public static (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfiguration algorithmConfiguration, VectorSearchProfile vectorSearchProfile) MapVectorField(VectorStoreRecordVectorProperty vectorProperty, string storagePropertyName)
+    {
+        if (vectorProperty.Dimensions is not > 0)
+        {
+            throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection.");
+        }
+
+        // Build a name for the profile and algorithm configuration based on the property name
+        // since we'll just create a separate one for each vector property.
+        var vectorSearchProfileName = $"{storagePropertyName}Profile";
+        var algorithmConfigName = $"{storagePropertyName}AlgoConfig";
+
+        // Read the vector index settings from the property definition and create the right index configuration.
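+        // IndexKind.Hnsw maps to an HnswAlgorithmConfiguration and IndexKind.Flat maps to an
+        // ExhaustiveKnnAlgorithmConfiguration; each vector property gets its own VectorSearchProfile
+        // that points at its own algorithm configuration.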
+ var indexKind = AzureAISearchVectorStoreCollectionCreateMapping.GetSKIndexKind(vectorProperty); + var algorithmMetric = AzureAISearchVectorStoreCollectionCreateMapping.GetSDKDistanceAlgorithm(vectorProperty); + + VectorSearchAlgorithmConfiguration algorithmConfiguration = indexKind switch + { + IndexKind.Hnsw => new HnswAlgorithmConfiguration(algorithmConfigName) { Parameters = new HnswParameters { Metric = algorithmMetric } }, + IndexKind.Flat => new ExhaustiveKnnAlgorithmConfiguration(algorithmConfigName) { Parameters = new ExhaustiveKnnParameters { Metric = algorithmMetric } }, + _ => throw new InvalidOperationException($"Index kind '{indexKind}' on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Azure AI Search VectorStore.") + }; + var vectorSearchProfile = new VectorSearchProfile(vectorSearchProfileName, algorithmConfigName); + + return (new VectorSearchField(storagePropertyName, vectorProperty.Dimensions.Value, vectorSearchProfileName), algorithmConfiguration, vectorSearchProfile); + } + + /// + /// Get the configured from the given . + /// If none is configured the default is . + /// + /// The vector property definition. + /// The configured or default . + public static string GetSKIndexKind(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.IndexKind is null) + { + return IndexKind.Hnsw; + } + + return vectorProperty.IndexKind; + } + + /// + /// Get the configured from the given . + /// If none is configured, the default is . + /// + /// The vector property definition. + /// The chosen . + /// Thrown if a distance function is chosen that isn't supported by Azure AI Search. + public static VectorSearchAlgorithmMetric GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.DistanceFunction is null) + { + return VectorSearchAlgorithmMetric.Cosine; + } + + return vectorProperty.DistanceFunction switch + { + DistanceFunction.CosineSimilarity => VectorSearchAlgorithmMetric.Cosine, + DistanceFunction.DotProductSimilarity => VectorSearchAlgorithmMetric.DotProduct, + DistanceFunction.EuclideanDistance => VectorSearchAlgorithmMetric.Euclidean, + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Azure AI Search VectorStore.") + }; + } + + /// + /// Maps the given property type to the corresponding . + /// + /// The property type to map. + /// The that corresponds to the given property type." + /// Thrown if the given type is not supported. + public static SearchFieldDataType GetSDKFieldDataType(Type propertyType) + { + return propertyType switch + { + Type stringType when stringType == typeof(string) => SearchFieldDataType.String, + Type boolType when boolType == typeof(bool) || boolType == typeof(bool?) => SearchFieldDataType.Boolean, + Type intType when intType == typeof(int) || intType == typeof(int?) => SearchFieldDataType.Int32, + Type longType when longType == typeof(long) || longType == typeof(long?) => SearchFieldDataType.Int64, + Type floatType when floatType == typeof(float) || floatType == typeof(float?) => SearchFieldDataType.Double, + Type doubleType when doubleType == typeof(double) || doubleType == typeof(double?) => SearchFieldDataType.Double, + Type dateTimeType when dateTimeType == typeof(DateTime) || dateTimeType == typeof(DateTime?) 
=> SearchFieldDataType.DateTimeOffset, + Type dateTimeOffsetType when dateTimeOffsetType == typeof(DateTimeOffset) || dateTimeOffsetType == typeof(DateTimeOffset?) => SearchFieldDataType.DateTimeOffset, + Type collectionType when typeof(IEnumerable).IsAssignableFrom(collectionType) => SearchFieldDataType.Collection(GetSDKFieldDataType(GetEnumerableType(propertyType))), + _ => throw new InvalidOperationException($"Data type '{propertyType}' for {nameof(VectorStoreRecordDataProperty)} is not supported by the Azure AI Search VectorStore.") + }; + } + + /// + /// Gets the type of object stored in the given enumerable type. + /// + /// The enumerable to get the stored type for. + /// The type of object stored in the given enumerable type. + /// Thrown when the given type is not enumerable. + public static Type GetEnumerableType(Type type) + { + if (type is IEnumerable) + { + return typeof(object); + } + else if (type.IsArray) + { + return type.GetElementType()!; + } + + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + return type.GetGenericArguments()[0]; + } + + if (type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)) is Type enumerableInterface) + { + return enumerableInterface.GetGenericArguments()[0]; + } + + throw new InvalidOperationException($"Data type '{type}' for {nameof(VectorStoreRecordDataProperty)} is not supported by the Azure AI Search VectorStore."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs new file mode 100644 index 000000000000..e8d54c8b7740 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreOptions.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Options when creating a . +/// +public sealed class AzureAISearchVectorStoreOptions +{ + /// + /// An optional factory to use for constructing instances, if custom options are required. + /// + public IAzureAISearchVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..21018b39c223 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs @@ -0,0 +1,450 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.Search.Documents; +using Azure.Search.Documents.Indexes; +using Azure.Search.Documents.Indexes.Models; +using Azure.Search.Documents.Models; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Service for storing and retrieving vector records, that uses Azure AI Search as the underlying storage. +/// +/// The data model to use for adding, updating and retrieving data from storage. 
+#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class AzureAISearchVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TRecord : class +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "AzureAISearch"; + + /// A set of types that a key on the provided model may have. + private static readonly HashSet s_supportedKeyTypes = + [ + typeof(string) + ]; + + /// A set of types that data properties on the provided model may have. + private static readonly HashSet s_supportedDataTypes = + [ + typeof(string), + typeof(int), + typeof(long), + typeof(double), + typeof(float), + typeof(bool), + typeof(DateTimeOffset), + typeof(int?), + typeof(long?), + typeof(double?), + typeof(float?), + typeof(bool?), + typeof(DateTimeOffset?), + ]; + + /// A set of types that vectors on the provided model may have. + /// + /// Azure AI Search is adding support for more types than just float32, but these are not available for use via the + /// SDK yet. We will update this list as the SDK is updated. + /// + /// + private static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; + + /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. + private readonly SearchIndexClient _searchIndexClient; + + /// Azure AI Search client that can be used to manage data in an Azure AI Search Service index. + private readonly SearchClient _searchClient; + + /// The name of the collection that this will access. + private readonly string _collectionName; + + /// Optional configuration options for this class. + private readonly AzureAISearchVectorStoreRecordCollectionOptions _options; + + /// A definition of the current storage model. + private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; + + /// The storage name of the key field for the collections that this class is used with. + private readonly string _keyStoragePropertyName; + + /// The storage names of all non vector fields on the current model. + private readonly List _nonVectorStoragePropertyNames = new(); + + /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. + private readonly Dictionary _storagePropertyNames = new(); + + /// + /// Initializes a new instance of the class. + /// + /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. + /// The name of the collection that this will access. + /// Optional configuration options for this class. + /// Thrown when is null. + /// Thrown when options are misconfigured. + public AzureAISearchVectorStoreRecordCollection(SearchIndexClient searchIndexClient, string collectionName, AzureAISearchVectorStoreRecordCollectionOptions? options = default) + { + // Verify. + Verify.NotNull(searchIndexClient); + Verify.NotNullOrWhiteSpace(collectionName); + + // Assign. + this._searchIndexClient = searchIndexClient; + this._collectionName = collectionName; + this._options = options ?? new AzureAISearchVectorStoreRecordCollectionOptions(); + this._searchClient = this._searchIndexClient.GetSearchClient(collectionName); + this._vectorStoreRecordDefinition = this._options.VectorStoreRecordDefinition ?? 
VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(TRecord), true); + var jsonSerializerOptions = this._options.JsonSerializerOptions ?? JsonSerializerOptions.Default; + + // Validate property types. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify(typeof(TRecord).Name, this._vectorStoreRecordDefinition, supportsMultipleVectors: true, requiresAtLeastOneVector: false); + VectorStoreRecordPropertyReader.VerifyPropertyTypes([properties.KeyProperty], s_supportedKeyTypes, "Key"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.DataProperties, s_supportedDataTypes, "Data", supportEnumerable: true); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.VectorProperties, s_supportedVectorTypes, "Vector"); + + // Get storage names and store for later use. + this._storagePropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToJsonPropertyNameMap(properties, typeof(TRecord), jsonSerializerOptions); + this._keyStoragePropertyName = this._storagePropertyNames[properties.KeyProperty.DataModelPropertyName]; + this._nonVectorStoragePropertyNames = properties.DataProperties + .Cast() + .Concat([properties.KeyProperty]) + .Select(x => this._storagePropertyNames[x.DataModelPropertyName]) + .ToList(); + } + + /// + public string CollectionName => this._collectionName; + + /// + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + try + { + await this._searchIndexClient.GetIndexAsync(this._collectionName, cancellationToken).ConfigureAwait(false); + return true; + } + catch (RequestFailedException ex) when (ex.Status == 404) + { + return false; + } + catch (RequestFailedException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = "GetIndex" + }; + } + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + var vectorSearchConfig = new VectorSearch(); + var searchFields = new List(); + + // Loop through all properties and create the search fields. + foreach (var property in this._vectorStoreRecordDefinition.Properties) + { + // Key property. + if (property is VectorStoreRecordKeyProperty keyProperty) + { + searchFields.Add(AzureAISearchVectorStoreCollectionCreateMapping.MapKeyField(keyProperty, this._keyStoragePropertyName)); + } + + // Data property. + if (property is VectorStoreRecordDataProperty dataProperty) + { + searchFields.Add(AzureAISearchVectorStoreCollectionCreateMapping.MapDataField(dataProperty, this._storagePropertyNames[dataProperty.DataModelPropertyName])); + } + + // Vector property. + if (property is VectorStoreRecordVectorProperty vectorProperty) + { + (VectorSearchField vectorSearchField, VectorSearchAlgorithmConfiguration algorithmConfiguration, VectorSearchProfile vectorSearchProfile) = AzureAISearchVectorStoreCollectionCreateMapping.MapVectorField( + vectorProperty, + this._storagePropertyNames[vectorProperty.DataModelPropertyName]); + + // Add the search field, plus its profile and algorithm configuration to the search config. + searchFields.Add(vectorSearchField); + vectorSearchConfig.Algorithms.Add(algorithmConfiguration); + vectorSearchConfig.Profiles.Add(vectorSearchProfile); + } + } + + // Create the index. 
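+ // Combine the accumulated search fields with the vector search profiles and algorithm configurations
+ // into a single SearchIndex definition before issuing the create call.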
+ var searchIndex = new SearchIndex(this._collectionName, searchFields); + searchIndex.VectorSearch = vectorSearchConfig; + + return this.RunOperationAsync( + "CreateIndex", + () => this._searchIndexClient.CreateIndexAsync(searchIndex, cancellationToken)); + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this.RunOperationAsync( + "DeleteIndex", + () => this._searchIndexClient.DeleteIndexAsync(this._collectionName, cancellationToken)); + } + + /// + public Task GetAsync(string key, GetRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + // Create Options. + var innerOptions = this.ConvertGetDocumentOptions(options); + var includeVectors = options?.IncludeVectors ?? false; + + // Get record. + return this.GetDocumentAndMapToDataModelAsync(key, includeVectors, innerOptions, cancellationToken); + } + + /// + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + // Create Options + var innerOptions = this.ConvertGetDocumentOptions(options); + var includeVectors = options?.IncludeVectors ?? false; + + // Get records in parallel. + var tasks = keys.Select(key => this.GetDocumentAndMapToDataModelAsync(key, includeVectors, innerOptions, cancellationToken)); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + foreach (var result in results) + { + if (result is not null) + { + yield return result; + } + } + } + + /// + public Task DeleteAsync(string key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + // Remove record. + return this.RunOperationAsync( + "DeleteDocuments", + () => this._searchClient.DeleteDocumentsAsync(this._keyStoragePropertyName, [key], new IndexDocumentsOptions(), cancellationToken)); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + // Remove records. + return this.RunOperationAsync( + "DeleteDocuments", + () => this._searchClient.DeleteDocumentsAsync(this._keyStoragePropertyName, keys, new IndexDocumentsOptions(), cancellationToken)); + } + + /// + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + // Create options. + var innerOptions = new IndexDocumentsOptions { ThrowOnAnyError = true }; + + // Upsert record. + var results = await this.MapToStorageModelAndUploadDocumentAsync([record], innerOptions, cancellationToken).ConfigureAwait(false); + return results.Value.Results[0].Key; + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? 
options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + // Create Options + var innerOptions = new IndexDocumentsOptions { ThrowOnAnyError = true }; + + // Upsert records + var results = await this.MapToStorageModelAndUploadDocumentAsync(records, innerOptions, cancellationToken).ConfigureAwait(false); + + // Get results + var resultKeys = results.Value.Results.Select(x => x.Key).ToList(); + foreach (var resultKey in resultKeys) { yield return resultKey; } + } + + /// + /// Get the document with the given key and map it to the data model using the configured mapper type. + /// + /// The key of the record to get. + /// A value indicating whether to include vectors in the result or not. + /// The Azure AI Search sdk options for getting a document. + /// The to monitor for cancellation requests. The default is . + /// The retrieved document, mapped to the consumer data model. + private async Task GetDocumentAndMapToDataModelAsync( + string key, + bool includeVectors, + GetDocumentOptions innerOptions, + CancellationToken cancellationToken) + { + const string OperationName = "GetDocument"; + + // Use the user provided mapper. + if (this._options.JsonObjectCustomMapper is not null) + { + var jsonObject = await this.RunOperationAsync( + OperationName, + () => GetDocumentWithNotFoundHandlingAsync(this._searchClient, key, innerOptions, cancellationToken)).ConfigureAwait(false); + + if (jsonObject is null) + { + return null; + } + + return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + OperationName, + () => this._options.JsonObjectCustomMapper!.MapFromStorageToDataModel(jsonObject, new() { IncludeVectors = includeVectors })); + } + + // Use the built in Azure AI Search mapper. + return await this.RunOperationAsync( + OperationName, + () => GetDocumentWithNotFoundHandlingAsync(this._searchClient, key, innerOptions, cancellationToken)).ConfigureAwait(false); + } + + /// + /// Map the data model to the storage model and upload the document. + /// + /// The records to upload. + /// The Azure AI Search sdk options for uploading a document. + /// The to monitor for cancellation requests. The default is . + /// The document upload result. + private Task> MapToStorageModelAndUploadDocumentAsync( + IEnumerable records, + IndexDocumentsOptions innerOptions, + CancellationToken cancellationToken) + { + const string OperationName = "UploadDocuments"; + + // Use the user provided mapper. + if (this._options.JsonObjectCustomMapper is not null) + { + var jsonObjects = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + OperationName, + () => records.Select(this._options.JsonObjectCustomMapper!.MapFromDataToStorageModel)); + + return this.RunOperationAsync( + OperationName, + () => this._searchClient.UploadDocumentsAsync(jsonObjects, innerOptions, cancellationToken)); + } + + // Use the built in Azure AI Search mapper. + return this.RunOperationAsync( + OperationName, + () => this._searchClient.UploadDocumentsAsync(records, innerOptions, cancellationToken)); + } + + /// + /// Convert the public options model to the Azure AI Search options model. + /// + /// The public options model. + /// The Azure AI Search options model. + private GetDocumentOptions ConvertGetDocumentOptions(GetRecordOptions? 
options) + { + var innerOptions = new GetDocumentOptions(); + if (options?.IncludeVectors is false) + { + innerOptions.SelectedFields.AddRange(this._nonVectorStoragePropertyNames); + } + + return innerOptions; + } + + /// + /// Get a document with the given key, and return null if it is not found. + /// + /// The type to deserialize the document to. + /// The search client to use when fetching the document. + /// The key of the record to get. + /// The Azure AI Search sdk options for getting a document. + /// The to monitor for cancellation requests. The default is . + /// The retrieved document, mapped to the consumer data model, or null if not found. + private static async Task GetDocumentWithNotFoundHandlingAsync( + SearchClient searchClient, + string key, + GetDocumentOptions innerOptions, + CancellationToken cancellationToken) + { + try + { + return await searchClient.GetDocumentAsync(key, innerOptions, cancellationToken).ConfigureAwait(false); + } + catch (RequestFailedException ex) when (ex.Status == 404) + { + return default; + } + } + + /// + /// Run the given operation and wrap any with ."/> + /// + /// The response type of the operation. + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (AggregateException ex) when (ex.InnerException is RequestFailedException innerEx) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + catch (RequestFailedException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..462dcd5d6e66 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Azure.Search.Documents.Indexes; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Options when creating a . +/// +public sealed class AzureAISearchVectorStoreRecordCollectionOptions + where TRecord : class +{ + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Azure AI Search record. + /// + /// + /// If not set, the default mapper that is provided by the Azure AI Search client SDK will be used. + /// + public IVectorStoreRecordMapper? JsonObjectCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? 
VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Gets or sets the JSON serializer options to use when converting between the data model and the Azure AI Search record. + /// Note that when using the default mapper, you will need to provide the same set of both here and when constructing the . + /// + public JsonSerializerOptions? JsonSerializerOptions { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..3e7dc2d82bc9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/IAzureAISearchVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure.Search.Documents.Indexes; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +/// +/// Interface for constructing Azure AI Search instances when using to retrieve these. +/// +public interface IAzureAISearchVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(SearchIndexClient searchIndexClient, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull + where TRecord : class; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj index 462a89b0bd8b..69b47fe172f0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/Connectors.Memory.Pinecone.csproj @@ -19,6 +19,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..965639e93c8e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/IPineconeVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using Sdk = Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Interface for constructing Pinecone instances when using to retrieve these. +/// +public interface IPineconeVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// Pinecone client that can be used to manage the collections and points in a Pinecone store. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . 
+ IVectorStoreRecordCollection CreateVectorStoreRecordCollection(Sdk.PineconeClient pineconeClient, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull + where TRecord : class; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs new file mode 100644 index 000000000000..f4c6e643ecc5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeKernelBuilderExtensions.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using Sdk = Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Extension methods to register Pinecone instances on the . +/// +public static class PineconeKernelBuilderExtensions +{ + /// + /// Register a Pinecone with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The builder to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddPineconeVectorStore(this IKernelBuilder builder, PineconeVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddPineconeVectorStore(options, serviceId); + return builder; + } + + /// + /// Register a Pinecone with the specified service ID and where is constructed using the provided apikey. + /// + /// The builder to register the on. + /// The api key for Pinecone. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddPineconeVectorStore(this IKernelBuilder builder, string apiKey, PineconeVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddPineconeVectorStore(apiKey, options, serviceId); + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs new file mode 100644 index 000000000000..8473d4fbd79e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeServiceCollectionExtensions.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Data; +using Sdk = Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Extension methods to register Pinecone instances on an . +/// +public static class PineconeServiceCollectionExtensions +{ + /// + /// Register a Pinecone with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPineconeVectorStore(this IServiceCollection services, PineconeVectorStoreOptions? options = default, string? serviceId = default) + { + // If we are not constructing the PineconeClient, add the IVectorStore as transient, since we + // cannot make assumptions about how PineconeClient is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var pineconeClient = sp.GetRequiredService(); + var selectedOptions = options ?? 
sp.GetService(); + + return new PineconeVectorStore( + pineconeClient, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Pinecone with the specified service ID and where is constructed using the provided apikey. + /// + /// The to register the on. + /// The api key for Pinecone. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPineconeVectorStore(this IServiceCollection services, string apiKey, PineconeVectorStoreOptions? options = default, string? serviceId = default) + { + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var pineconeClient = new Sdk.PineconeClient(apiKey); + var selectedOptions = options ?? sp.GetService(); + + return new PineconeVectorStore( + pineconeClient, + selectedOptions); + }); + + return services; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs new file mode 100644 index 000000000000..ec5b6114c801 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStore.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using Grpc.Core; +using Microsoft.SemanticKernel.Data; +using Pinecone; +using Sdk = Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Class for accessing the list of collections in a Pinecone vector store. +/// +/// +/// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// +public sealed class PineconeVectorStore : IVectorStore +{ + private const string DatabaseName = "Pinecone"; + private const string ListCollectionsName = "ListCollections"; + + private readonly Sdk.PineconeClient _pineconeClient; + private readonly PineconeVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// Pinecone client that can be used to manage the collections and points in a Pinecone store. + /// Optional configuration options for this class. + public PineconeVectorStore(Sdk.PineconeClient pineconeClient, PineconeVectorStoreOptions? options = default) + { + Verify.NotNull(pineconeClient); + + this._pineconeClient = pineconeClient; + this._options = options ?? new PineconeVectorStoreOptions(); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? 
vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : class + { + if (typeof(TKey) != typeof(string)) + { + throw new NotSupportedException("Only string keys are supported."); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._pineconeClient, name, vectorStoreRecordDefinition); + } + + return (new PineconeVectorStoreRecordCollection( + this._pineconeClient, + name, + new PineconeVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection)!; + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + IndexDetails[] collections; + + try + { + collections = await this._pineconeClient.ListIndexes(cancellationToken).ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + OperationName = ListCollectionsName + }; + } + + foreach (var collection in collections) + { + yield return collection.Name; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionCreateMapping.cs new file mode 100644 index 000000000000..0a50cf2ac399 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionCreateMapping.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.SemanticKernel.Data; +using Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Contains mapping helpers to use when creating a Pinecone vector collection. +/// +internal static class PineconeVectorStoreCollectionCreateMapping +{ + /// + /// Maps information stored in to a structure used by Pinecone SDK to create a serverless index. + /// + /// The property to map. + /// The structure containing settings used to create a serverless index. + /// Thrown if the property is missing information or has unsupported options specified. + public static (uint Dimension, Metric Metric) MapServerlessIndex(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty!.Dimensions is not > 0) + { + throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); + } + + return (Dimension: (uint)vectorProperty.Dimensions, Metric: GetSDKMetricAlgorithm(vectorProperty)); + } + + /// + /// Get the configured from the given . + /// If none is configured, the default is . + /// + /// The vector property definition. + /// The chosen . + /// Thrown if a distance function is chosen that isn't supported by Pinecone. 
+ public static Metric GetSDKMetricAlgorithm(VectorStoreRecordVectorProperty vectorProperty) + => vectorProperty.DistanceFunction switch + { + DistanceFunction.CosineSimilarity => Metric.Cosine, + DistanceFunction.DotProductSimilarity => Metric.DotProduct, + DistanceFunction.EuclideanDistance => Metric.Euclidean, + null => Metric.Cosine, + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Pinecone VectorStore.") + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs new file mode 100644 index 000000000000..7a6fc9767f62 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreOptions.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Options when creating a . +/// +public sealed class PineconeVectorStoreOptions +{ + /// + /// An optional factory to use for constructing instances, if custom options are required. + /// + public IPineconeVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..323681f629be --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Microsoft.SemanticKernel.Data; +using Pinecone.Grpc; +using Sdk = Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Service for storing and retrieving vector records, that uses Pinecone as the underlying storage. +/// +/// The data model to use for adding, updating and retrieving data from storage. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class PineconeVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TRecord : class +{ + private const string DatabaseName = "Pinecone"; + private const string CreateCollectionName = "CreateCollection"; + private const string CollectionExistsName = "CollectionExists"; + private const string DeleteCollectionName = "DeleteCollection"; + + private const string UpsertOperationName = "Upsert"; + private const string DeleteOperationName = "Delete"; + private const string GetOperationName = "Get"; + + private readonly Sdk.PineconeClient _pineconeClient; + private readonly PineconeVectorStoreRecordCollectionOptions _options; + private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; + private readonly IVectorStoreRecordMapper _mapper; + + private Sdk.Index? _index; + + /// + public string CollectionName { get; } + + /// + /// Initializes a new instance of the class. + /// + /// Pinecone client that can be used to manage the collections and vectors in a Pinecone store. + /// Optional configuration options for this class. + /// Thrown if the is null. 
+ /// The name of the collection that this will access. + /// Thrown for any misconfigured options. + public PineconeVectorStoreRecordCollection(Sdk.PineconeClient pineconeClient, string collectionName, PineconeVectorStoreRecordCollectionOptions? options = null) + { + Verify.NotNull(pineconeClient); + + this._pineconeClient = pineconeClient; + this.CollectionName = collectionName; + this._options = options ?? new PineconeVectorStoreRecordCollectionOptions(); + this._vectorStoreRecordDefinition = this._options.VectorStoreRecordDefinition ?? VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(TRecord), true); + + if (this._options.VectorCustomMapper is null) + { + this._mapper = new PineconeVectorStoreRecordMapper(this._vectorStoreRecordDefinition); + } + else + { + this._mapper = this._options.VectorCustomMapper; + } + } + + /// + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + var result = await this.RunOperationAsync( + CollectionExistsName, + async () => + { + var collections = await this._pineconeClient.ListIndexes(cancellationToken).ConfigureAwait(false); + + return collections.Any(x => x.Name == this.CollectionName); + }).ConfigureAwait(false); + + return result; + } + + /// + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + // we already run through record property validation, so a single VectorStoreRecordVectorProperty is guaranteed. + var vectorProperty = this._vectorStoreRecordDefinition.Properties.OfType().First(); + var (dimension, metric) = PineconeVectorStoreCollectionCreateMapping.MapServerlessIndex(vectorProperty); + + await this.RunOperationAsync( + CreateCollectionName, + () => this._pineconeClient.CreateServerlessIndex( + this.CollectionName, + dimension, + metric, + this._options.ServerlessIndexCloud, + this._options.ServerlessIndexRegion, + cancellationToken)).ConfigureAwait(false); + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + => this.RunOperationAsync( + DeleteCollectionName, + () => this._pineconeClient.DeleteIndex(this.CollectionName, cancellationToken)); + + /// + public async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + var records = await this.GetBatchAsync([key], options, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + + return records.FirstOrDefault(); + } + + /// + public async IAsyncEnumerable GetBatchAsync( + IEnumerable keys, + GetRecordOptions? options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + var indexNamespace = this.GetIndexNamespace(); + var mapperOptions = new StorageToDataModelMapperOptions { IncludeVectors = options?.IncludeVectors ?? 
false }; + + var index = await this.GetIndexAsync(this.CollectionName, cancellationToken).ConfigureAwait(false); + + var results = await this.RunOperationAsync( + GetOperationName, + () => index.Fetch(keys, indexNamespace, cancellationToken)).ConfigureAwait(false); + + var records = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + GetOperationName, + () => results.Values.Select(x => this._mapper.MapFromStorageToDataModel(x, mapperOptions)).ToList()); + + foreach (var record in records) + { + yield return record; + } + } + + /// + public Task DeleteAsync(string key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + return this.DeleteBatchAsync([key], options, cancellationToken); + } + + /// + public async Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + var indexNamespace = this.GetIndexNamespace(); + + var index = await this.GetIndexAsync(this.CollectionName, cancellationToken).ConfigureAwait(false); + + await this.RunOperationAsync( + DeleteOperationName, + () => index.Delete(keys, indexNamespace, cancellationToken)).ConfigureAwait(false); + } + + /// + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + var indexNamespace = this.GetIndexNamespace(); + + var index = await this.GetIndexAsync(this.CollectionName, cancellationToken).ConfigureAwait(false); + + var vector = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + UpsertOperationName, + () => this._mapper.MapFromDataToStorageModel(record)); + + await this.RunOperationAsync( + UpsertOperationName, + () => index.Upsert([vector], indexNamespace, cancellationToken)).ConfigureAwait(false); + + return vector.Id; + } + + /// + public async IAsyncEnumerable UpsertBatchAsync( + IEnumerable records, + UpsertRecordOptions? 
options = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + var indexNamespace = this.GetIndexNamespace(); + + var index = await this.GetIndexAsync(this.CollectionName, cancellationToken).ConfigureAwait(false); + + var vectors = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + UpsertOperationName, + () => records.Select(this._mapper.MapFromDataToStorageModel).ToList()); + + await this.RunOperationAsync( + UpsertOperationName, + () => index.Upsert(vectors, indexNamespace, cancellationToken)).ConfigureAwait(false); + + foreach (var vector in vectors) + { + yield return vector.Id; + } + } + + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task> GetIndexAsync(string indexName, CancellationToken cancellationToken) + { + this._index ??= await this._pineconeClient.GetIndex(indexName, cancellationToken).ConfigureAwait(false); + + return this._index; + } + + private string? GetIndexNamespace() + => this._options.IndexNamespace; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..f328524ec758 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Options when creating a . +/// +public sealed class PineconeVectorStoreRecordCollectionOptions + where TRecord : class +{ + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Pinecone vector. + /// + public IVectorStoreRecordMapper? VectorCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Gets or sets the value for a namespace within the Pinecone index that will be used for operations involving records (Get, Upsert, Delete)."/> + /// + public string? IndexNamespace { get; init; } = null; + + /// + /// Gets or sets the value for public cloud where the serverless index is hosted. + /// + /// + /// This value is only used when creating a new Pinecone index. Default value is 'aws'. 
+ /// + public string ServerlessIndexCloud { get; init; } = "aws"; + + /// + /// Gets or sets the value for region where the serverless index is created. + /// + /// + /// This option is only used when creating a new Pinecone index. Default value is 'us-east-1'. + /// + public string ServerlessIndexRegion { get; init; } = "us-east-1"; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..da1d95ad6de9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordMapper.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.SemanticKernel.Data; +using Pinecone; + +namespace Microsoft.SemanticKernel.Connectors.Pinecone; + +/// +/// Mapper between a Pinecone record and the consumer data model that uses json as an intermediary to allow supporting a wide range of models. +/// +/// The consumer data model to map to or from. +internal sealed class PineconeVectorStoreRecordMapper : IVectorStoreRecordMapper + where TRecord : class +{ + /// A set of types that a key on the provided model may have. + private static readonly HashSet s_supportedKeyTypes = [typeof(string)]; + + /// A set of types that data properties on the provided model may have. + private static readonly HashSet s_supportedDataTypes = + [ + typeof(bool), + typeof(bool?), + typeof(string), + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?), + ]; + + /// A set of types that enumerable data properties on the provided model may use as their element types. + private static readonly HashSet s_supportedEnumerableDataElementTypes = + [ + typeof(string) + ]; + + /// A set of types that vectors on the provided model may have. + private static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + ]; + + private readonly PropertyInfo _keyPropertyInfo; + + private readonly List _dataPropertiesInfo; + + private readonly PropertyInfo _vectorPropertyInfo; + + private readonly Dictionary _storagePropertyNames = []; + + private readonly Dictionary _jsonPropertyNames = []; + + /// + /// Initializes a new instance of the class. + /// + /// The record definition that defines the schema of the record type. + public PineconeVectorStoreRecordMapper( + VectorStoreRecordDefinition vectorStoreRecordDefinition) + { + // Validate property types. + var propertiesInfo = VectorStoreRecordPropertyReader.FindProperties(typeof(TRecord), vectorStoreRecordDefinition, supportsMultipleVectors: false); + VectorStoreRecordPropertyReader.VerifyPropertyTypes([propertiesInfo.KeyProperty], s_supportedKeyTypes, "Key"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(propertiesInfo.DataProperties, s_supportedDataTypes, s_supportedEnumerableDataElementTypes, "Data"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(propertiesInfo.VectorProperties, s_supportedVectorTypes, "Vector"); + + // Assign. 
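+ // Illustrative note (comment only, not part of this change): given the key/data/vector type checks above,
+ // a compatible TRecord uses a string key, scalar data properties (or enumerables of string), and a
+ // ReadOnlyMemory<float> vector. A hypothetical model could look like the following; the type name,
+ // property names and attribute usage are assumptions for illustration, not definitions from this diff.
+ //
+ //   public sealed class Hotel
+ //   {
+ //       [VectorStoreRecordKey] public string HotelId { get; set; }
+ //       [VectorStoreRecordData] public string Description { get; set; }
+ //       [VectorStoreRecordVector] public ReadOnlyMemory<float>? DescriptionEmbedding { get; set; }
+ //   }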
+ this._keyPropertyInfo = propertiesInfo.KeyProperty; + this._dataPropertiesInfo = propertiesInfo.DataProperties; + this._vectorPropertyInfo = propertiesInfo.VectorProperties[0]; + + // Get storage names and store for later use. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify(typeof(TRecord).Name, vectorStoreRecordDefinition, supportsMultipleVectors: false, requiresAtLeastOneVector: true); + this._jsonPropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToJsonPropertyNameMap(properties, typeof(TRecord), JsonSerializerOptions.Default); + this._storagePropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToStorageNameMap(properties); + } + + /// + public Vector MapFromDataToStorageModel(TRecord dataModel) + { + var keyObject = this._keyPropertyInfo.GetValue(dataModel); + if (keyObject is null) + { + throw new VectorStoreRecordMappingException($"Key property {this._keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName} may not be null."); + } + + var metadata = new MetadataMap(); + foreach (var dataPropertyInfo in this._dataPropertiesInfo) + { + var propertyName = this._storagePropertyNames[dataPropertyInfo.Name]; + var propertyValue = dataPropertyInfo.GetValue(dataModel); + if (propertyValue != null) + { + metadata[propertyName] = ConvertToMetadataValue(propertyValue); + } + } + + var valuesObject = this._vectorPropertyInfo.GetValue(dataModel); + if (valuesObject is not ReadOnlyMemory values) + { + throw new VectorStoreRecordMappingException($"Vector property {this._vectorPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName} may not be null."); + } + + // TODO: what about sparse values? + var result = new Vector + { + Id = (string)keyObject, + Values = values.ToArray(), + Metadata = metadata, + SparseValues = null + }; + + return result; + } + + /// + public TRecord MapFromStorageToDataModel(Vector storageModel, StorageToDataModelMapperOptions options) + { + var keyJsonName = this._jsonPropertyNames[this._keyPropertyInfo.Name]; + var outputJsonObject = new JsonObject + { + { keyJsonName, JsonValue.Create(storageModel.Id) }, + }; + + if (options?.IncludeVectors is true) + { + var propertyName = this._storagePropertyNames[this._vectorPropertyInfo.Name]; + var jsonName = this._jsonPropertyNames[this._vectorPropertyInfo.Name]; + outputJsonObject.Add(jsonName, new JsonArray(storageModel.Values.Select(x => JsonValue.Create(x)).ToArray())); + } + + if (storageModel.Metadata != null) + { + foreach (var dataProperty in this._dataPropertiesInfo) + { + var propertyName = this._storagePropertyNames[dataProperty.Name]; + var jsonName = this._jsonPropertyNames[dataProperty.Name]; + + if (storageModel.Metadata.TryGetValue(propertyName, out var value)) + { + outputJsonObject.Add(jsonName, ConvertFromMetadataValueToJsonNode(value)); + } + } + } + + return outputJsonObject.Deserialize()!; + } + + private static JsonNode? 
ConvertFromMetadataValueToJsonNode(MetadataValue metadataValue) + => metadataValue.Inner switch + { + null => null, + bool boolValue => JsonValue.Create(boolValue), + string stringValue => JsonValue.Create(stringValue), + int intValue => JsonValue.Create(intValue), + long longValue => JsonValue.Create(longValue), + float floatValue => JsonValue.Create(floatValue), + double doubleValue => JsonValue.Create(doubleValue), + decimal decimalValue => JsonValue.Create(decimalValue), + MetadataValue[] array => new JsonArray(array.Select(ConvertFromMetadataValueToJsonNode).ToArray()), + List list => new JsonArray(list.Select(ConvertFromMetadataValueToJsonNode).ToArray()), + _ => throw new VectorStoreRecordMappingException($"Unsupported metadata type: '{metadataValue.Inner?.GetType().FullName}'."), + }; + + // TODO: take advantage of MetadataValue.TryCreate once we upgrade the version of Pinecone.NET + private static MetadataValue ConvertToMetadataValue(object? sourceValue) + => sourceValue switch + { + bool boolValue => boolValue, + string stringValue => stringValue, + int intValue => intValue, + long longValue => longValue, + float floatValue => floatValue, + double doubleValue => doubleValue, + decimal decimalValue => decimalValue, + string[] stringArray => stringArray, + List stringList => stringList, + IEnumerable stringEnumerable => stringEnumerable.ToArray(), + _ => throw new VectorStoreRecordMappingException($"Unsupported source value type '{sourceValue?.GetType().FullName}'.") + }; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj index d9037605f6e5..322a58d22400 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/Connectors.Memory.Qdrant.csproj @@ -20,10 +20,12 @@ + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..2f93e14dfb82 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/IQdrantVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Interface for constructing Qdrant instances when using to retrieve these. +/// +public interface IQdrantVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(QdrantClient qdrantClient, string name, VectorStoreRecordDefinition? 
vectorStoreRecordDefinition) + where TKey : notnull + where TRecord : class; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs new file mode 100644 index 000000000000..020455558b7d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/MockableQdrantClient.cs @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Qdrant.Client; +using Qdrant.Client.Grpc; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Decorator class for that exposes the required methods as virtual allowing for mocking in unit tests. +/// +internal class MockableQdrantClient +{ + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + private readonly QdrantClient _qdrantClient; + + /// + /// Initializes a new instance of the class. + /// + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + public MockableQdrantClient(QdrantClient qdrantClient) + { + Verify.NotNull(qdrantClient); + this._qdrantClient = qdrantClient; + } + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// + /// Constructor for mocking purposes only. + /// + internal MockableQdrantClient() + { + } + +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// + /// Gets the internal that this mockable instance wraps. + /// + public QdrantClient QdrantClient => this._qdrantClient; + + /// + /// Check if a collection exists. + /// + /// The name of the collection. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task CollectionExistsAsync( + string collectionName, + CancellationToken cancellationToken = default) + => this._qdrantClient.CollectionExistsAsync(collectionName, cancellationToken); + + /// + /// Creates a new collection with the given parameters. + /// + /// The name of the collection to be created. + /// + /// Configuration of the vector storage. Vector params contains size and distance for the vector storage. + /// This overload creates a single anonymous vector storage. + /// + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task CreateCollectionAsync( + string collectionName, + VectorParams vectorsConfig, + CancellationToken cancellationToken = default) + => this._qdrantClient.CreateCollectionAsync( + collectionName, + vectorsConfig, + cancellationToken: cancellationToken); + + /// + /// Creates a new collection with the given parameters. + /// + /// The name of the collection to be created. + /// + /// Configuration of the vector storage. Vector params contains size and distance for the vector storage. + /// This overload creates a vector storage for each key in the provided map. + /// + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task CreateCollectionAsync( + string collectionName, + VectorParamsMap? 
vectorsConfig = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.CreateCollectionAsync( + collectionName, + vectorsConfig, + cancellationToken: cancellationToken); + + /// + /// Creates a payload field index in a collection. + /// + /// The name of the collection. + /// Field name to index. + /// The schema type of the field. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task CreatePayloadIndexAsync( + string collectionName, + string fieldName, + PayloadSchemaType schemaType = PayloadSchemaType.Keyword, + CancellationToken cancellationToken = default) + => this._qdrantClient.CreatePayloadIndexAsync(collectionName, fieldName, schemaType, cancellationToken: cancellationToken); + + /// + /// Drop a collection and all its associated data. + /// + /// The name of the collection. + /// Wait timeout for operation commit in seconds, if not specified - default value will be supplied + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task DeleteCollectionAsync( + string collectionName, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.DeleteCollectionAsync(collectionName, timeout, cancellationToken); + + /// + /// Gets the names of all existing collections. + /// + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task> ListCollectionsAsync(CancellationToken cancellationToken = default) + => this._qdrantClient.ListCollectionsAsync(cancellationToken); + + /// + /// Delete a point. + /// + /// The name of the collection. + /// The ID to delete. + /// Whether to wait until the changes have been applied. Defaults to true. + /// Write ordering guarantees. Defaults to Weak. + /// Option for custom sharding to specify used shard keys. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task DeleteAsync( + string collectionName, + ulong id, + bool wait = true, + WriteOrderingType? ordering = null, + ShardKeySelector? shardKeySelector = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.DeleteAsync(collectionName, id, wait, ordering, shardKeySelector, cancellationToken: cancellationToken); + + /// + /// Delete a point. + /// + /// The name of the collection. + /// The ID to delete. + /// Whether to wait until the changes have been applied. Defaults to true. + /// Write ordering guarantees. Defaults to Weak. + /// Option for custom sharding to specify used shard keys. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task DeleteAsync( + string collectionName, + Guid id, + bool wait = true, + WriteOrderingType? ordering = null, + ShardKeySelector? shardKeySelector = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.DeleteAsync(collectionName, id, wait, ordering, shardKeySelector, cancellationToken: cancellationToken); + + /// + /// Delete a point. + /// + /// The name of the collection. + /// The IDs to delete. + /// Whether to wait until the changes have been applied. Defaults to true. + /// Write ordering guarantees. Defaults to Weak. + /// Option for custom sharding to specify used shard keys. + /// + /// The token to monitor for cancellation requests. The default value is . 
+ /// + public virtual Task DeleteAsync( + string collectionName, + IReadOnlyList ids, + bool wait = true, + WriteOrderingType? ordering = null, + ShardKeySelector? shardKeySelector = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.DeleteAsync(collectionName, ids, wait, ordering, shardKeySelector, cancellationToken: cancellationToken); + + /// + /// Delete a point. + /// + /// The name of the collection. + /// The IDs to delete. + /// Whether to wait until the changes have been applied. Defaults to true. + /// Write ordering guarantees. Defaults to Weak. + /// Option for custom sharding to specify used shard keys. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task DeleteAsync( + string collectionName, + IReadOnlyList ids, + bool wait = true, + WriteOrderingType? ordering = null, + ShardKeySelector? shardKeySelector = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.DeleteAsync(collectionName, ids, wait, ordering, shardKeySelector, cancellationToken: cancellationToken); + + /// + /// Perform insert and updates on points. If a point with a given ID already exists, it will be overwritten. + /// + /// The name of the collection. + /// The points to be upserted. + /// Whether to wait until the changes have been applied. Defaults to true. + /// Write ordering guarantees. + /// Option for custom sharding to specify used shard keys. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task UpsertAsync( + string collectionName, + IReadOnlyList points, + bool wait = true, + WriteOrderingType? ordering = null, + ShardKeySelector? shardKeySelector = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.UpsertAsync(collectionName, points, wait, ordering, shardKeySelector, cancellationToken); + + /// + /// Retrieve points. + /// + /// The name of the collection. + /// List of points to retrieve. + /// Whether to include the payload or not. + /// Whether to include the vectors or not. + /// Options for specifying read consistency guarantees. + /// Option for custom sharding to specify used shard keys. + /// + /// The token to monitor for cancellation requests. The default value is . + /// + public virtual Task> RetrieveAsync( + string collectionName, + IReadOnlyList ids, + bool withPayload = true, + bool withVectors = false, + ReadConsistency? readConsistency = null, + ShardKeySelector? shardKeySelector = null, + CancellationToken cancellationToken = default) + => this._qdrantClient.RetrieveAsync(collectionName, ids, withPayload, withVectors, readConsistency, shardKeySelector, cancellationToken); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs new file mode 100644 index 000000000000..213aef587653 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantKernelBuilderExtensions.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Extension methods to register Qdrant instances on the . +/// +public static class QdrantKernelBuilderExtensions +{ + /// + /// Register a Qdrant with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The builder to register the on. 
+ /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddQdrantVectorStore(this IKernelBuilder builder, QdrantVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddQdrantVectorStore(options, serviceId); + return builder; + } + /// + /// Register a Qdrant with the specified service ID and where is constructed using the provided parameters. + /// + /// The builder to register the on. + /// The Qdrant service host name. + /// The Qdrant service port. + /// A value indicating whether to use HTTPS for communicating with Qdrant. + /// The Qdrant service API key. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddQdrantVectorStore(this IKernelBuilder builder, string host, int port = 6334, bool https = false, string? apiKey = default, QdrantVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddQdrantVectorStore(host, port, https, apiKey, options, serviceId); + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs new file mode 100644 index 000000000000..b534b2ea7578 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantServiceCollectionExtensions.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Extension methods to register Qdrant instances on an . +/// +public static class QdrantServiceCollectionExtensions +{ + /// + /// Register a Qdrant with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddQdrantVectorStore(this IServiceCollection services, QdrantVectorStoreOptions? options = default, string? serviceId = default) + { + // If we are not constructing the QdrantClient, add the IVectorStore as transient, since we + // cannot make assumptions about how QdrantClient is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var qdrantClient = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new QdrantVectorStore( + qdrantClient, + selectedOptions); + }); + + return services; + } + /// + /// Register a Qdrant with the specified service ID and where is constructed using the provided parameters. + /// + /// The to register the on. + /// The Qdrant service host name. + /// The Qdrant service port. + /// A value indicating whether to use HTTPS for communicating with Qdrant. + /// The Qdrant service API key. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddQdrantVectorStore(this IServiceCollection services, string host, int port = 6334, bool https = false, string? apiKey = default, QdrantVectorStoreOptions? options = default, string? 
serviceId = default) + { + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var qdrantClient = new QdrantClient(host, port, https, apiKey); + var selectedOptions = options ?? sp.GetService(); + + return new QdrantVectorStore( + qdrantClient, + selectedOptions); + }); + + return services; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs new file mode 100644 index 000000000000..ef9c9f1593f0 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStore.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using Grpc.Core; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Class for accessing the list of collections in a Qdrant vector store. +/// +/// +/// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// +public sealed class QdrantVectorStore : IVectorStore +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "Qdrant"; + + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + private readonly MockableQdrantClient _qdrantClient; + + /// Optional configuration options for this class. + private readonly QdrantVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + /// Optional configuration options for this class. + public QdrantVectorStore(QdrantClient qdrantClient, QdrantVectorStoreOptions? options = default) + : this(new MockableQdrantClient(qdrantClient), options) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + /// Optional configuration options for this class. + internal QdrantVectorStore(MockableQdrantClient qdrantClient, QdrantVectorStoreOptions? options = default) + { + Verify.NotNull(qdrantClient); + + this._qdrantClient = qdrantClient; + this._options = options ?? new QdrantVectorStoreOptions(); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? 
vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : class + { + if (typeof(TKey) != typeof(ulong) && typeof(TKey) != typeof(Guid)) + { + throw new NotSupportedException("Only ulong and Guid keys are supported."); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._qdrantClient.QdrantClient, name, vectorStoreRecordDefinition); + } + + var directlyCreatedStore = new QdrantVectorStoreRecordCollection(this._qdrantClient, name, new QdrantVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }); + var castCreatedStore = directlyCreatedStore as IVectorStoreRecordCollection; + return castCreatedStore!; + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + IReadOnlyList collections; + + try + { + collections = await this._qdrantClient.ListCollectionsAsync(cancellationToken).ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + OperationName = "ListCollections" + }; + } + + foreach (var collection in collections) + { + yield return collection; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs new file mode 100644 index 000000000000..e637ae2e06ab --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionCreateMapping.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client.Grpc; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Contains mapping helpers to use when creating a qdrant vector collection. +/// +internal static class QdrantVectorStoreCollectionCreateMapping +{ + /// A dictionary of types and their matching qdrant index schema type. 
+ public static readonly Dictionary s_schemaTypeMap = new() + { + { typeof(short), PayloadSchemaType.Integer }, + { typeof(sbyte), PayloadSchemaType.Integer }, + { typeof(byte), PayloadSchemaType.Integer }, + { typeof(ushort), PayloadSchemaType.Integer }, + { typeof(int), PayloadSchemaType.Integer }, + { typeof(uint), PayloadSchemaType.Integer }, + { typeof(long), PayloadSchemaType.Integer }, + { typeof(ulong), PayloadSchemaType.Integer }, + { typeof(float), PayloadSchemaType.Float }, + { typeof(double), PayloadSchemaType.Float }, + { typeof(decimal), PayloadSchemaType.Float }, + + { typeof(short?), PayloadSchemaType.Integer }, + { typeof(sbyte?), PayloadSchemaType.Integer }, + { typeof(byte?), PayloadSchemaType.Integer }, + { typeof(ushort?), PayloadSchemaType.Integer }, + { typeof(int?), PayloadSchemaType.Integer }, + { typeof(uint?), PayloadSchemaType.Integer }, + { typeof(long?), PayloadSchemaType.Integer }, + { typeof(ulong?), PayloadSchemaType.Integer }, + { typeof(float?), PayloadSchemaType.Float }, + { typeof(double?), PayloadSchemaType.Float }, + { typeof(decimal?), PayloadSchemaType.Float }, + + { typeof(string), PayloadSchemaType.Keyword }, + { typeof(DateTime), PayloadSchemaType.Datetime }, + { typeof(bool), PayloadSchemaType.Bool }, + + { typeof(DateTime?), PayloadSchemaType.Datetime }, + { typeof(bool?), PayloadSchemaType.Bool }, + }; + + /// + /// Maps a single to a qdrant . + /// + /// The property to map. + /// The mapped . + /// Thrown if the property is missing information or has unsupported options specified. + public static VectorParams MapSingleVector(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty!.Dimensions is not > 0) + { + throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); + } + + if (vectorProperty!.IndexKind is not null && vectorProperty!.IndexKind != IndexKind.Hnsw) + { + throw new InvalidOperationException($"Index kind '{vectorProperty!.IndexKind}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Qdrant VectorStore."); + } + + return new VectorParams { Size = (ulong)vectorProperty.Dimensions, Distance = QdrantVectorStoreCollectionCreateMapping.GetSDKDistanceAlgorithm(vectorProperty) }; + } + + /// + /// Maps a collection of to a qdrant . + /// + /// The properties to map. + /// The mapping of property names to storage names. + /// THe mapped . + /// Thrown if the property is missing information or has unsupported options specified. + public static VectorParamsMap MapNamedVectors(IEnumerable vectorProperties, Dictionary storagePropertyNames) + { + var vectorParamsMap = new VectorParamsMap(); + + foreach (var vectorProperty in vectorProperties) + { + var storageName = storagePropertyNames[vectorProperty.DataModelPropertyName]; + + // Add each vector property to the vectors map. + vectorParamsMap.Map.Add( + storageName, + MapSingleVector(vectorProperty)); + } + + return vectorParamsMap; + } + + /// + /// Get the configured from the given . + /// If none is configured, the default is . + /// + /// The vector property definition. + /// The chosen . + /// Thrown if a distance function is chosen that isn't supported by qdrant. 
+ public static Distance GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.DistanceFunction is null) + { + return Distance.Cosine; + } + + return vectorProperty.DistanceFunction switch + { + DistanceFunction.CosineSimilarity => Distance.Cosine, + DistanceFunction.DotProductSimilarity => Distance.Dot, + DistanceFunction.EuclideanDistance => Distance.Euclid, + DistanceFunction.ManhattanDistance => Distance.Manhattan, + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Qdrant VectorStore.") + }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs new file mode 100644 index 000000000000..c3ead1bdee2d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreOptions.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Options when creating a . +/// +public sealed class QdrantVectorStoreOptions +{ + /// + /// Gets or sets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector per qdrant point. + /// Defaults to single vector per point. + /// + public bool HasNamedVectors { get; set; } = false; + + /// + /// An optional factory to use for constructing instances, if custom options are required. + /// + public IQdrantVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..a49c530b2cdb --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs @@ -0,0 +1,481 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; +using Qdrant.Client.Grpc; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Service for storing and retrieving vector records, that uses Qdrant as the underlying storage. +/// +/// The data model to use for adding, updating and retrieving data from storage. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class QdrantVectorStoreRecordCollection : IVectorStoreRecordCollection, IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TRecord : class +{ + /// A set of types that a key on the provided model may have. + private static readonly HashSet s_supportedKeyTypes = + [ + typeof(ulong), + typeof(Guid) + ]; + + /// The name of this database for telemetry purposes. + private const string DatabaseName = "Qdrant"; + + /// The name of the upsert operation for telemetry purposes. + private const string UpsertName = "Upsert"; + + /// The name of the Delete operation for telemetry purposes. 
+ private const string DeleteName = "Delete"; + + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + private readonly MockableQdrantClient _qdrantClient; + + /// The name of the collection that this will access. + private readonly string _collectionName; + + /// Optional configuration options for this class. + private readonly QdrantVectorStoreRecordCollectionOptions _options; + + /// A definition of the current storage model. + private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; + + /// A mapper to use for converting between qdrant point and consumer models. + private readonly IVectorStoreRecordMapper _mapper; + + /// A dictionary that maps from a property name to the configured name that should be used when storing it. + private readonly Dictionary _storagePropertyNames = new(); + + /// + /// Initializes a new instance of the class. + /// + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + /// The name of the collection that this will access. + /// Optional configuration options for this class. + /// Thrown if the is null. + /// Thrown for any misconfigured options. + public QdrantVectorStoreRecordCollection(QdrantClient qdrantClient, string collectionName, QdrantVectorStoreRecordCollectionOptions? options = null) + : this(new MockableQdrantClient(qdrantClient), collectionName, options) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// Qdrant client that can be used to manage the collections and points in a Qdrant store. + /// The name of the collection that this will access. + /// Optional configuration options for this class. + /// Thrown if the is null. + /// Thrown for any misconfigured options. + internal QdrantVectorStoreRecordCollection(MockableQdrantClient qdrantClient, string collectionName, QdrantVectorStoreRecordCollectionOptions? options = null) + { + // Verify. + Verify.NotNull(qdrantClient); + Verify.NotNullOrWhiteSpace(collectionName); + + // Assign. + this._qdrantClient = qdrantClient; + this._collectionName = collectionName; + this._options = options ?? new QdrantVectorStoreRecordCollectionOptions(); + this._vectorStoreRecordDefinition = this._options.VectorStoreRecordDefinition ?? VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(TRecord), true); + + // Validate property types. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify(typeof(TRecord).Name, this._vectorStoreRecordDefinition, supportsMultipleVectors: this._options.HasNamedVectors, requiresAtLeastOneVector: !this._options.HasNamedVectors); + VectorStoreRecordPropertyReader.VerifyPropertyTypes([properties.KeyProperty], s_supportedKeyTypes, "Key"); + + // Build a map of property names to storage names. + this._storagePropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToStorageNameMap(properties); + + // Assign Mapper. + if (this._options.PointStructCustomMapper is not null) + { + // Custom Mapper. + this._mapper = this._options.PointStructCustomMapper; + } + else + { + // Default Mapper. 
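+ // Illustrative note (comment only, not part of this change): the key type check above means a compatible
+ // TRecord uses a ulong or Guid key. With HasNamedVectors = false exactly one vector property is expected and
+ // is stored as the point's single unnamed vector; with HasNamedVectors = true each vector property is stored
+ // under its configured storage name. A hypothetical model, with assumed attribute and property names:
+ //
+ //   public sealed class Hotel
+ //   {
+ //       [VectorStoreRecordKey] public ulong HotelId { get; set; }
+ //       [VectorStoreRecordData] public string Description { get; set; }
+ //       [VectorStoreRecordVector] public ReadOnlyMemory<float>? DescriptionEmbedding { get; set; }
+ //   }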
+ this._mapper = new QdrantVectorStoreRecordMapper( + this._vectorStoreRecordDefinition, + this._options.HasNamedVectors, + this._storagePropertyNames); + } + } + + /// + public string CollectionName => this._collectionName; + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + return this.RunOperationAsync( + "CollectionExists", + () => this._qdrantClient.CollectionExistsAsync(this._collectionName, cancellationToken)); + } + + /// + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + if (!this._options.HasNamedVectors) + { + // If we are not using named vectors, we can only have one vector property. We can assume we have exactly one, since this is already verified in the constructor. + var singleVectorProperty = this._vectorStoreRecordDefinition.Properties.OfType().First(); + + // Map the single vector property to the qdrant config. + var vectorParams = QdrantVectorStoreCollectionCreateMapping.MapSingleVector(singleVectorProperty!); + + // Create the collection with the single unnamed vector. + await this.RunOperationAsync( + "CreateCollection", + () => this._qdrantClient.CreateCollectionAsync( + this._collectionName, + vectorParams, + cancellationToken: cancellationToken)).ConfigureAwait(false); + } + else + { + // Since we are using named vectors, iterate over all vector properties. + var vectorProperties = this._vectorStoreRecordDefinition.Properties.OfType(); + + // Map the named vectors to the qdrant config. + var vectorParamsMap = QdrantVectorStoreCollectionCreateMapping.MapNamedVectors(vectorProperties, this._storagePropertyNames); + + // Create the collection with named vectors. + await this.RunOperationAsync( + "CreateCollection", + () => this._qdrantClient.CreateCollectionAsync( + this._collectionName, + vectorParamsMap, + cancellationToken: cancellationToken)).ConfigureAwait(false); + } + + // Add indexes for each of the data properties that require filtering. + var dataProperties = this._vectorStoreRecordDefinition.Properties.OfType().Where(x => x.IsFilterable); + foreach (var dataProperty in dataProperties) + { + var storageFieldName = this._storagePropertyNames[dataProperty.DataModelPropertyName]; + var schemaType = QdrantVectorStoreCollectionCreateMapping.s_schemaTypeMap[dataProperty.PropertyType!]; + + await this.RunOperationAsync( + "CreatePayloadIndex", + () => this._qdrantClient.CreatePayloadIndexAsync( + this._collectionName, + storageFieldName, + schemaType, + cancellationToken: cancellationToken)).ConfigureAwait(false); + } + + // Add indexes for each of the data properties that require full text search. + dataProperties = this._vectorStoreRecordDefinition.Properties.OfType().Where(x => x.IsFullTextSearchable); + foreach (var dataProperty in dataProperties) + { + if (dataProperty.PropertyType != typeof(string)) + { + throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextSearchable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not a string. 
The Qdrant VectorStore supports {nameof(dataProperty.IsFullTextSearchable)} on string properties only."); + } + + var storageFieldName = this._storagePropertyNames[dataProperty.DataModelPropertyName]; + + await this.RunOperationAsync( + "CreatePayloadIndex", + () => this._qdrantClient.CreatePayloadIndexAsync( + this._collectionName, + storageFieldName, + PayloadSchemaType.Text, + cancellationToken: cancellationToken)).ConfigureAwait(false); + } + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this.RunOperationAsync( + "DeleteCollection", + () => this._qdrantClient.DeleteCollectionAsync(this._collectionName, null, cancellationToken)); + } + + /// + public async Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + var retrievedPoints = await this.GetBatchAsync([key], options, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + return retrievedPoints.FirstOrDefault(); + } + + /// + public async Task GetAsync(Guid key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + var retrievedPoints = await this.GetBatchAsync([key], options, cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false); + return retrievedPoints.FirstOrDefault(); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default) + { + return this.GetBatchByPointIdAsync(keys, key => new PointId { Num = key }, options, cancellationToken); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default) + { + return this.GetBatchByPointIdAsync(keys, key => new PointId { Uuid = key.ToString("D") }, options, cancellationToken); + } + + /// + public Task DeleteAsync(ulong key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + return this.RunOperationAsync( + DeleteName, + () => this._qdrantClient.DeleteAsync( + this._collectionName, + key, + wait: true, + cancellationToken: cancellationToken)); + } + + /// + public Task DeleteAsync(Guid key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(key); + + return this.RunOperationAsync( + DeleteName, + () => this._qdrantClient.DeleteAsync( + this._collectionName, + key, + wait: true, + cancellationToken: cancellationToken)); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + return this.RunOperationAsync( + DeleteName, + () => this._qdrantClient.DeleteAsync( + this._collectionName, + keys.ToList(), + wait: true, + cancellationToken: cancellationToken)); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? 
options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + return this.RunOperationAsync( + DeleteName, + () => this._qdrantClient.DeleteAsync( + this._collectionName, + keys.ToList(), + wait: true, + cancellationToken: cancellationToken)); + } + + /// + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + // Create point from record. + var pointStruct = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + UpsertName, + () => this._mapper.MapFromDataToStorageModel(record)); + + // Upsert. + await this.RunOperationAsync( + UpsertName, + () => this._qdrantClient.UpsertAsync(this._collectionName, [pointStruct], true, cancellationToken: cancellationToken)).ConfigureAwait(false); + return pointStruct.Id.Num; + } + + /// + async Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, UpsertRecordOptions? options, CancellationToken cancellationToken) + { + Verify.NotNull(record); + + // Create point from record. + var pointStruct = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + UpsertName, + () => this._mapper.MapFromDataToStorageModel(record)); + + // Upsert. + await this.RunOperationAsync( + UpsertName, + () => this._qdrantClient.UpsertAsync(this._collectionName, [pointStruct], true, cancellationToken: cancellationToken)).ConfigureAwait(false); + return Guid.Parse(pointStruct.Id.Uuid); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + // Create points from records. + var pointStructs = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + UpsertName, + () => records.Select(this._mapper.MapFromDataToStorageModel).ToList()); + + // Upsert. + await this.RunOperationAsync( + UpsertName, + () => this._qdrantClient.UpsertAsync(this._collectionName, pointStructs, true, cancellationToken: cancellationToken)).ConfigureAwait(false); + + foreach (var pointStruct in pointStructs) + { + yield return pointStruct.Id.Num; + } + } + + /// + async IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + Verify.NotNull(records); + + // Create points from records. + var pointStructs = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + UpsertName, + () => records.Select(this._mapper.MapFromDataToStorageModel).ToList()); + + // Upsert. + await this.RunOperationAsync( + UpsertName, + () => this._qdrantClient.UpsertAsync(this._collectionName, pointStructs, true, cancellationToken: cancellationToken)).ConfigureAwait(false); + + foreach (var pointStruct in pointStructs) + { + yield return Guid.Parse(pointStruct.Id.Uuid); + } + } + + /// + /// Get the requested records from the Qdrant store using the provided keys. + /// + /// The keys of the points to retrieve. + /// Function to convert the provided keys to point ids. + /// The retrieval options. + /// The to monitor for cancellation requests. The default is . + /// The retrieved points. + private async IAsyncEnumerable GetBatchByPointIdAsync( + IEnumerable keys, + Func keyConverter, + GetRecordOptions? 
options, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + const string OperationName = "Retrieve"; + Verify.NotNull(keys); + + // Create options. + var pointsIds = keys.Select(key => keyConverter(key)).ToArray(); + var includeVectors = options?.IncludeVectors ?? false; + + // Retrieve data points. + var retrievedPoints = await this.RunOperationAsync( + OperationName, + () => this._qdrantClient.RetrieveAsync(this._collectionName, pointsIds, true, includeVectors, cancellationToken: cancellationToken)).ConfigureAwait(false); + + // Convert the retrieved points to the target data model. + foreach (var retrievedPoint in retrievedPoints) + { + var pointStruct = new PointStruct + { + Id = retrievedPoint.Id, + Vectors = retrievedPoint.Vectors, + Payload = { } + }; + + foreach (KeyValuePair payloadEntry in retrievedPoint.Payload) + { + pointStruct.Payload.Add(payloadEntry.Key, payloadEntry.Value); + } + + yield return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(pointStruct, new() { IncludeVectors = includeVectors })); + } + } + + /// + /// Run the given operation and wrap any with ."/> + /// + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } + + /// + /// Run the given operation and wrap any with ."/> + /// + /// The response type of the operation. + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..e6c51c97f6a6 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using Qdrant.Client.Grpc; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Options when creating a . +/// +public sealed class QdrantVectorStoreRecordCollectionOptions + where TRecord : class +{ + /// + /// Gets or sets a value indicating whether the vectors in the store are named and multiple vectors are supported, or whether there is just a single unnamed vector per qdrant point. + /// Defaults to single vector per point. + /// + public bool HasNamedVectors { get; set; } = false; + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the qdrant point. 
+ /// + /// + /// If not set, a default mapper that uses json as an intermediary to allow automatic mapping to a wide variety of types will be used. + /// + public IVectorStoreRecordMapper? PointStructCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..2c4238982391 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordMapper.cs @@ -0,0 +1,298 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client.Grpc; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +/// +/// Mapper between a Qdrant record and the consumer data model that uses json as an intermediary to allow supporting a wide range of models. +/// +/// The consumer data model to map to or from. +internal sealed class QdrantVectorStoreRecordMapper : IVectorStoreRecordMapper + where TRecord : class +{ + /// A set of types that data properties on the provided model may have. + private static readonly HashSet s_supportedDataTypes = + [ + typeof(string), + typeof(int), + typeof(long), + typeof(double), + typeof(float), + typeof(bool), + typeof(int?), + typeof(long?), + typeof(double?), + typeof(float?), + typeof(bool?) + ]; + + /// A set of types that vectors on the provided model may have. + /// + /// While qdrant supports float32 and uint64, the api only supports float64, therefore + /// any float32 vectors will be converted to float64 before being sent to qdrant. + /// + private static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; + + /// A property info object that points at the key property for the current model, allowing easy reading and writing of this property. + private readonly PropertyInfo _keyPropertyInfo; + + /// A list of property info objects that point at the data properties in the current model, and allows easy reading and writing of these properties. + private readonly List _dataPropertiesInfo; + + /// A list of property info objects that point at the vector properties in the current model, and allows easy reading and writing of these properties. + private readonly List _vectorPropertiesInfo; + + /// A dictionary that maps from a property name to the configured name that should be used when storing it. + private readonly Dictionary _storagePropertyNames; + + /// A dictionary that maps from a property name to the configured name that should be used when serializing it to json. + private readonly Dictionary _jsonPropertyNames = new(); + + /// A value indicating whether the vectors in the store are named, or whether there is just a single unnamed vector per qdrant point. 
+ private readonly bool _hasNamedVectors; + + /// + /// Initializes a new instance of the class. + /// + /// The record definition that defines the schema of the record type. + /// A value indicating whether the vectors in the store are named, or whether there is just a single unnamed vector per qdrant point. + /// A dictionary that maps from a property name to the configured name that should be used when storing it. + public QdrantVectorStoreRecordMapper( + VectorStoreRecordDefinition vectorStoreRecordDefinition, + bool hasNamedVectors, + Dictionary storagePropertyNames) + { + Verify.NotNull(vectorStoreRecordDefinition); + Verify.NotNull(storagePropertyNames); + + // Validate property types. + var propertiesInfo = VectorStoreRecordPropertyReader.FindProperties(typeof(TRecord), vectorStoreRecordDefinition, supportsMultipleVectors: hasNamedVectors); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(propertiesInfo.DataProperties, s_supportedDataTypes, "Data", supportEnumerable: true); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(propertiesInfo.VectorProperties, s_supportedVectorTypes, "Vector"); + + // Assign. + this._hasNamedVectors = hasNamedVectors; + this._keyPropertyInfo = propertiesInfo.KeyProperty; + this._dataPropertiesInfo = propertiesInfo.DataProperties; + this._vectorPropertiesInfo = propertiesInfo.VectorProperties; + this._storagePropertyNames = storagePropertyNames; + + // Get json storage names and store for later use. + this._jsonPropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToJsonPropertyNameMap(propertiesInfo, typeof(TRecord), JsonSerializerOptions.Default); + } + + /// + public PointStruct MapFromDataToStorageModel(TRecord dataModel) + { + PointId pointId; + if (this._keyPropertyInfo.PropertyType == typeof(ulong)) + { + var key = this._keyPropertyInfo.GetValue(dataModel) as ulong? ?? throw new VectorStoreRecordMappingException($"Missing key property {this._keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName}."); + pointId = new PointId { Num = key }; + } + else if (this._keyPropertyInfo.PropertyType == typeof(Guid)) + { + var key = this._keyPropertyInfo.GetValue(dataModel) as Guid? ?? throw new VectorStoreRecordMappingException($"Missing key property {this._keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName}."); + pointId = new PointId { Uuid = key.ToString("D") }; + } + else + { + throw new VectorStoreRecordMappingException($"Unsupported key type {this._keyPropertyInfo.PropertyType.FullName} for key property {this._keyPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName}."); + } + + // Create point. + var pointStruct = new PointStruct + { + Id = pointId, + Vectors = new Vectors(), + Payload = { }, + }; + + // Add point payload. + foreach (var dataPropertyInfo in this._dataPropertiesInfo) + { + var propertyName = this._storagePropertyNames[dataPropertyInfo.Name]; + var propertyValue = dataPropertyInfo.GetValue(dataModel); + pointStruct.Payload.Add(propertyName, ConvertToGrpcFieldValue(propertyValue)); + } + + // Add vectors. 
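+        // When named vectors are enabled, each vector property is written into a NamedVectors map under its
+        // configured storage name; otherwise the single (unnamed) vector is written directly onto the point.
+        // As noted in the class remarks, float32 vectors are converted to the element type expected by the
+        // qdrant API before being sent.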
+ if (this._hasNamedVectors) + { + var namedVectors = new NamedVectors(); + foreach (var vectorPropertyInfo in this._vectorPropertiesInfo) + { + var propertyName = this._storagePropertyNames[vectorPropertyInfo.Name]; + var propertyValue = vectorPropertyInfo.GetValue(dataModel); + if (propertyValue is not null) + { + var castPropertyValue = (ReadOnlyMemory)propertyValue; + namedVectors.Vectors.Add(propertyName, castPropertyValue.ToArray()); + } + } + + pointStruct.Vectors.Vectors_ = namedVectors; + } + else + { + // We already verified in the constructor via FindProperties that there is exactly one vector property when not using named vectors. + var vectorPropertyInfo = this._vectorPropertiesInfo.First(); + if (vectorPropertyInfo.GetValue(dataModel) is ReadOnlyMemory floatROM) + { + pointStruct.Vectors.Vector = floatROM.ToArray(); + } + else + { + throw new VectorStoreRecordMappingException($"Vector property {vectorPropertyInfo.Name} on provided record of type {typeof(TRecord).FullName} may not be null when not using named vectors."); + } + } + + return pointStruct; + } + + /// + public TRecord MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) + { + // Get the key property name and value. + var keyJsonName = this._jsonPropertyNames[this._keyPropertyInfo.Name]; + var keyPropertyValue = storageModel.Id.HasNum ? storageModel.Id.Num as object : storageModel.Id.Uuid as object; + + // Create a json object to represent the point. + var outputJsonObject = new JsonObject + { + { keyJsonName, JsonValue.Create(keyPropertyValue) }, + }; + + // Add each vector property if embeddings are included in the point. + if (options?.IncludeVectors is true) + { + foreach (var vectorProperty in this._vectorPropertiesInfo) + { + var propertyName = this._storagePropertyNames[vectorProperty.Name]; + var jsonName = this._jsonPropertyNames[vectorProperty.Name]; + + if (this._hasNamedVectors) + { + if (storageModel.Vectors.Vectors_.Vectors.TryGetValue(propertyName, out var vector)) + { + outputJsonObject.Add(jsonName, new JsonArray(vector.Data.Select(x => JsonValue.Create(x)).ToArray())); + } + } + else + { + outputJsonObject.Add(jsonName, new JsonArray(storageModel.Vectors.Vector.Data.Select(x => JsonValue.Create(x)).ToArray())); + } + } + } + + // Add each data property. + foreach (var dataProperty in this._dataPropertiesInfo) + { + var propertyName = this._storagePropertyNames[dataProperty.Name]; + var jsonName = this._jsonPropertyNames[dataProperty.Name]; + + if (storageModel.Payload.TryGetValue(propertyName, out var value)) + { + outputJsonObject.Add(jsonName, ConvertFromGrpcFieldValueToJsonNode(value)); + } + } + + // Convert from json object to the target data model. + return JsonSerializer.Deserialize(outputJsonObject)!; + } + + /// + /// Convert the given to the correct native type based on its properties. + /// + /// The value to convert to a native type. + /// The converted native value. + /// Thrown when an unsupported type is encountered. + private static JsonNode? 
ConvertFromGrpcFieldValueToJsonNode(Value payloadValue) + { + return payloadValue.KindCase switch + { + Value.KindOneofCase.NullValue => null, + Value.KindOneofCase.IntegerValue => JsonValue.Create(payloadValue.IntegerValue), + Value.KindOneofCase.StringValue => JsonValue.Create(payloadValue.StringValue), + Value.KindOneofCase.DoubleValue => JsonValue.Create(payloadValue.DoubleValue), + Value.KindOneofCase.BoolValue => JsonValue.Create(payloadValue.BoolValue), + Value.KindOneofCase.ListValue => new JsonArray(payloadValue.ListValue.Values.Select(x => ConvertFromGrpcFieldValueToJsonNode(x)).ToArray()), + Value.KindOneofCase.StructValue => new JsonObject(payloadValue.StructValue.Fields.ToDictionary(x => x.Key, x => ConvertFromGrpcFieldValueToJsonNode(x.Value))), + _ => throw new VectorStoreRecordMappingException($"Unsupported grpc value kind {payloadValue.KindCase}."), + }; + } + + /// + /// Convert the given to a object that can be stored in Qdrant. + /// + /// The object to convert. + /// The converted Qdrant value. + /// Thrown when an unsupported type is encountered. + private static Value ConvertToGrpcFieldValue(object? sourceValue) + { + var value = new Value(); + if (sourceValue is null) + { + value.NullValue = NullValue.NullValue; + } + else if (sourceValue is int intValue) + { + value.IntegerValue = intValue; + } + else if (sourceValue is long longValue) + { + value.IntegerValue = longValue; + } + else if (sourceValue is string stringValue) + { + value.StringValue = stringValue; + } + else if (sourceValue is float floatValue) + { + value.DoubleValue = floatValue; + } + else if (sourceValue is double doubleValue) + { + value.DoubleValue = doubleValue; + } + else if (sourceValue is bool boolValue) + { + value.BoolValue = boolValue; + } + else if (sourceValue is IEnumerable || + sourceValue is IEnumerable || + sourceValue is IEnumerable || + sourceValue is IEnumerable || + sourceValue is IEnumerable || + sourceValue is IEnumerable) + { + var listValue = sourceValue as IEnumerable; + value.ListValue = new ListValue(); + foreach (var item in listValue!) + { + value.ListValue.Values.Add(ConvertToGrpcFieldValue(item)); + } + } + else + { + throw new VectorStoreRecordMappingException($"Unsupported source value type {sourceValue?.GetType().FullName}."); + } + + return value; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..f4eae7661b7a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/IRedisVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Interface for constructing Redis instances when using to retrieve these. +/// +public interface IRedisVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// The Redis database to read/write records from. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . 
+ IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IDatabase database, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull + where TRecord : class; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..e68edb98870e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs @@ -0,0 +1,374 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using NRedisStack.RedisStackCommands; +using NRedisStack.Search; +using NRedisStack.Search.Literals.Enums; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Service for storing and retrieving vector records, that uses Redis HashSets as the underlying storage. +/// +/// The data model to use for adding, updating and retrieving data from storage. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class RedisHashSetVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TRecord : class +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "Redis"; + + /// A set of types that a key on the provided model may have. + private static readonly HashSet s_supportedKeyTypes = + [ + typeof(string) + ]; + + /// A set of types that data properties on the provided model may have. + private static readonly HashSet s_supportedDataTypes = + [ + typeof(string), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(double), + typeof(float), + typeof(bool), + typeof(int?), + typeof(uint?), + typeof(long?), + typeof(ulong?), + typeof(double?), + typeof(float?), + typeof(bool?) + ]; + + /// A set of types that vectors on the provided model may have. + private static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + typeof(ReadOnlyMemory?) + ]; + + /// The Redis database to read/write records from. + private readonly IDatabase _database; + + /// The name of the collection that this will access. + private readonly string _collectionName; + + /// Optional configuration options for this class. + private readonly RedisHashSetVectorStoreRecordCollectionOptions _options; + + /// A definition of the current storage model. + private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; + + /// An array of the names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties. + private readonly RedisValue[] _dataStoragePropertyNames; + + /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. + private readonly Dictionary _storagePropertyNames = new(); + + /// The mapper to use when mapping between the consumer data model and the Redis record. + private readonly IVectorStoreRecordMapper _mapper; + + /// + /// Initializes a new instance of the class. + /// + /// The Redis database to read/write records from. 
+ /// The name of the collection that this will access. + /// Optional configuration options for this class. + /// Throw when parameters are invalid. + public RedisHashSetVectorStoreRecordCollection(IDatabase database, string collectionName, RedisHashSetVectorStoreRecordCollectionOptions? options = null) + { + // Verify. + Verify.NotNull(database); + Verify.NotNullOrWhiteSpace(collectionName); + + // Assign. + this._database = database; + this._collectionName = collectionName; + this._options = options ?? new RedisHashSetVectorStoreRecordCollectionOptions(); + this._vectorStoreRecordDefinition = this._options.VectorStoreRecordDefinition ?? VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(TRecord), true); + + // Validate property types. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify(typeof(TRecord).Name, this._vectorStoreRecordDefinition, supportsMultipleVectors: true, requiresAtLeastOneVector: false); + VectorStoreRecordPropertyReader.VerifyPropertyTypes([properties.KeyProperty], s_supportedKeyTypes, "Key"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.DataProperties, s_supportedDataTypes, "Data"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.VectorProperties, s_supportedVectorTypes, "Vector"); + + // Lookup storage property names. + this._storagePropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToStorageNameMap(properties); + this._dataStoragePropertyNames = properties + .DataProperties + .Select(x => this._storagePropertyNames[x.DataModelPropertyName]) + .Select(RedisValue.Unbox) + .ToArray(); + + // Assign Mapper. + if (this._options.HashEntriesCustomMapper is not null) + { + this._mapper = this._options.HashEntriesCustomMapper; + } + else + { + this._mapper = new RedisHashSetVectorStoreRecordMapper(this._vectorStoreRecordDefinition, this._storagePropertyNames); + } + } + + /// + public string CollectionName => this._collectionName; + + /// + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + try + { + await this._database.FT().InfoAsync(this._collectionName).ConfigureAwait(false); + return true; + } + catch (RedisServerException ex) when (ex.Message.Contains("Unknown index name")) + { + return false; + } + catch (RedisConnectionException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = "FT.INFO" + }; + } + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + // Map the record definition to a schema. + var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(this._vectorStoreRecordDefinition.Properties, this._storagePropertyNames); + + // Create the index creation params. + // Add the collection name and colon as the index prefix, which means that any record where the key is prefixed with this text will be indexed by this index + var createParams = new FTCreateParams() + .AddPrefix($"{this._collectionName}:") + .On(IndexDataType.HASH); + + // Create the index. 
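+        // The index is created under the collection name and only indexes keys that carry the
+        // "{collectionName}:" prefix registered above. Callers can either supply keys that are already
+        // prefixed, or enable the PrefixCollectionNameToKeyNames option so this class applies the prefix.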
+ return this.RunOperationAsync("FT.CREATE", () => this._database.FT().CreateAsync(this._collectionName, createParams, schema)); + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this.RunOperationAsync("FT.DROPINDEX", () => this._database.FT().DropIndexAsync(this._collectionName)); + } + + /// + public async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + // Create Options + var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + var includeVectors = options?.IncludeVectors ?? false; + var operationName = includeVectors ? "HGETALL" : "HMGET"; + + // Get the Redis value. + HashEntry[] retrievedHashEntries; + if (includeVectors) + { + retrievedHashEntries = await this.RunOperationAsync( + operationName, + () => this._database.HashGetAllAsync(maybePrefixedKey)).ConfigureAwait(false); + } + else + { + var fieldKeys = this._dataStoragePropertyNames; + var retrievedValues = await this.RunOperationAsync( + operationName, + () => this._database.HashGetAsync(maybePrefixedKey, fieldKeys)).ConfigureAwait(false); + retrievedHashEntries = fieldKeys.Zip(retrievedValues, (field, value) => new HashEntry(field, value)).Where(x => x.Value.HasValue).ToArray(); + } + + // Return null if we found nothing. + if (retrievedHashEntries == null || retrievedHashEntries.Length == 0) + { + return null; + } + + // Convert to the caller's data model. + return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + operationName, + () => + { + return this._mapper.MapFromStorageToDataModel((key, retrievedHashEntries), new() { IncludeVectors = includeVectors }); + }); + } + + /// + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + // Get records in parallel. + var tasks = keys.Select(x => this.GetAsync(x, options, cancellationToken)); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + foreach (var result in results) + { + if (result is not null) + { + yield return result; + } + } + } + + /// + public Task DeleteAsync(string key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + // Create Options + var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + + // Remove. + return this.RunOperationAsync( + "DEL", + () => this._database + .KeyDeleteAsync(maybePrefixedKey)); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + // Remove records in parallel. + var tasks = keys.Select(key => this.DeleteAsync(key, options, cancellationToken)); + return Task.WhenAll(tasks); + } + + /// + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + // Map. 
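+        // Convert the record to a (key, HashEntry[]) pair using the configured mapper. The conversion is
+        // routed through VectorStoreErrorHandler.RunModelConversion together with the database, collection
+        // and operation name, so mapping failures are reported consistently with other vector store errors.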
+ var redisHashSetRecord = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + "HSET", + () => this._mapper.MapFromDataToStorageModel(record)); + + // Upsert. + var maybePrefixedKey = this.PrefixKeyIfNeeded(redisHashSetRecord.Key); + await this.RunOperationAsync( + "HSET", + () => this._database + .HashSetAsync( + maybePrefixedKey, + redisHashSetRecord.HashEntries)).ConfigureAwait(false); + + return redisHashSetRecord.Key; + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + // Upsert records in parallel. + var tasks = records.Select(x => this.UpsertAsync(x, options, cancellationToken)); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + foreach (var result in results) + { + if (result is not null) + { + yield return result; + } + } + } + + /// + /// Prefix the key with the collection name if the option is set. + /// + /// The key to prefix. + /// The updated key if updating is required, otherwise the input key. + private string PrefixKeyIfNeeded(string key) + { + if (this._options.PrefixCollectionNameToKeyNames) + { + return $"{this._collectionName}:{key}"; + } + + return key; + } + + /// + /// Run the given operation and wrap any Redis exceptions with ."/> + /// + /// The response type of the operation. + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (RedisConnectionException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } + + /// + /// Run the given operation and wrap any Redis exceptions with ."/> + /// + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (RedisConnectionException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..7e17859ae0c9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Options when creating a . +/// +public sealed class RedisHashSetVectorStoreRecordCollectionOptions + where TRecord : class +{ + /// + /// Gets or sets a value indicating whether the collection name should be prefixed to the + /// key names before reading or writing to the Redis store. Default is false. 
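+    /// Unlike the JSON-based collection options, which default this setting to true, the HashSet collection
+    /// defaults it to false.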
+ /// + /// + /// For a record to be indexed by a specific Redis index, the key name must be prefixed with the matching prefix configured on the Redis index. + /// You can either pass in keys that are already prefixed, or set this option to true to have the collection name prefixed to the key names automatically. + /// + public bool PrefixCollectionNameToKeyNames { get; init; } = false; + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Redis record. + /// + public IVectorStoreRecordMapper? HashEntriesCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..ef31bf09f475 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordMapper.cs @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.SemanticKernel.Data; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Class for mapping between a hashset stored in redis, and the consumer data model. +/// +/// The consumer data model to map to or from. +internal sealed class RedisHashSetVectorStoreRecordMapper : IVectorStoreRecordMapper + where TConsumerDataModel : class +{ + /// A property info object that points at the key property for the current model, allowing easy reading and writing of this property. + private readonly PropertyInfo _keyPropertyInfo; + + /// The name of the temporary json property that the key field will be serialized / parsed from. + private readonly string _keyFieldJsonPropertyName; + + /// A list of property info objects that point at the data properties in the current model, and allows easy reading and writing of these properties. + private readonly IEnumerable _dataPropertiesInfo; + + /// A list of property info objects that point at the vector properties in the current model, and allows easy reading and writing of these properties. + private readonly IEnumerable _vectorPropertiesInfo; + + /// A dictionary that maps from a property name to the configured name that should be used when storing it. + private readonly Dictionary _storagePropertyNames; + + /// A dictionary that maps from a property name to the configured name that should be used when serializing it to json for data and vector properties. + private readonly Dictionary _jsonPropertyNames = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The record definition that defines the schema of the record type. + /// A dictionary that maps from a property name to the configured name that should be used when storing it. 
+ public RedisHashSetVectorStoreRecordMapper( + VectorStoreRecordDefinition vectorStoreRecordDefinition, + Dictionary storagePropertyNames) + { + Verify.NotNull(vectorStoreRecordDefinition); + Verify.NotNull(storagePropertyNames); + + (PropertyInfo keyPropertyInfo, List dataPropertiesInfo, List vectorPropertiesInfo) = VectorStoreRecordPropertyReader.FindProperties(typeof(TConsumerDataModel), vectorStoreRecordDefinition, supportsMultipleVectors: true); + + this._keyPropertyInfo = keyPropertyInfo; + this._dataPropertiesInfo = dataPropertiesInfo; + this._vectorPropertiesInfo = vectorPropertiesInfo; + this._storagePropertyNames = storagePropertyNames; + + this._keyFieldJsonPropertyName = VectorStoreRecordPropertyReader.GetJsonPropertyName(JsonSerializerOptions.Default, keyPropertyInfo); + foreach (var property in dataPropertiesInfo.Concat(vectorPropertiesInfo)) + { + this._jsonPropertyNames[property.Name] = VectorStoreRecordPropertyReader.GetJsonPropertyName(JsonSerializerOptions.Default, property); + } + } + + /// + public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(TConsumerDataModel dataModel) + { + var keyValue = this._keyPropertyInfo.GetValue(dataModel) as string ?? throw new VectorStoreRecordMappingException($"Missing key property {this._keyPropertyInfo.Name} on provided record of type {typeof(TConsumerDataModel).FullName}."); + + var hashEntries = new List(); + foreach (var property in this._dataPropertiesInfo) + { + var storageName = this._storagePropertyNames[property.Name]; + var value = property.GetValue(dataModel); + hashEntries.Add(new HashEntry(storageName, RedisValue.Unbox(value))); + } + + foreach (var property in this._vectorPropertiesInfo) + { + var storageName = this._storagePropertyNames[property.Name]; + var value = property.GetValue(dataModel); + if (value is not null) + { + // Convert the vector to a byte array and store it in the hash entry. + // We only support float and double vectors and we do checking in the + // collection constructor to ensure that the model has no other vector types. + if (value is ReadOnlyMemory rom) + { + hashEntries.Add(new HashEntry(storageName, ConvertVectorToBytes(rom))); + } + else if (value is ReadOnlyMemory rod) + { + hashEntries.Add(new HashEntry(storageName, ConvertVectorToBytes(rod))); + } + } + } + + return (keyValue, hashEntries.ToArray()); + } + + /// + public TConsumerDataModel MapFromStorageToDataModel((string Key, HashEntry[] HashEntries) storageModel, StorageToDataModelMapperOptions options) + { + var jsonObject = new JsonObject(); + + foreach (var property in this._dataPropertiesInfo) + { + var storageName = this._storagePropertyNames[property.Name]; + var jsonName = this._jsonPropertyNames[property.Name]; + var hashEntry = storageModel.HashEntries.FirstOrDefault(x => x.Name == storageName); + if (hashEntry.Name.HasValue) + { + var typeOrNullableType = Nullable.GetUnderlyingType(property.PropertyType) ?? 
property.PropertyType; + var convertedValue = Convert.ChangeType(hashEntry.Value, typeOrNullableType); + jsonObject.Add(jsonName, JsonValue.Create(convertedValue)); + } + } + + if (options.IncludeVectors) + { + foreach (var property in this._vectorPropertiesInfo) + { + var storageName = this._storagePropertyNames[property.Name]; + var jsonName = this._jsonPropertyNames[property.Name]; + + var hashEntry = storageModel.HashEntries.FirstOrDefault(x => x.Name == storageName); + if (hashEntry.Name.HasValue) + { + if (property.PropertyType == typeof(ReadOnlyMemory) || property.PropertyType == typeof(ReadOnlyMemory?)) + { + var array = MemoryMarshal.Cast((byte[])hashEntry.Value!).ToArray(); + jsonObject.Add(jsonName, JsonValue.Create(array)); + } + else if (property.PropertyType == typeof(ReadOnlyMemory) || property.PropertyType == typeof(ReadOnlyMemory?)) + { + var array = MemoryMarshal.Cast((byte[])hashEntry.Value!).ToArray(); + jsonObject.Add(jsonName, JsonValue.Create(array)); + } + else + { + throw new VectorStoreRecordMappingException($"Invalid vector type '{property.PropertyType.Name}' found on property '{property.Name}' on provided record of type '{typeof(TConsumerDataModel).FullName}'. Only float and double vectors are supported."); + } + } + } + } + + // Check that the key field is not already present in the redis value. + if (jsonObject.ContainsKey(this._keyFieldJsonPropertyName)) + { + throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'. Key property '{this._keyFieldJsonPropertyName}' is already present on retrieved object."); + } + + // Since the key is not stored in the redis value, add it back in before deserializing into the data model. + jsonObject.Add(this._keyFieldJsonPropertyName, storageModel.Key); + + return JsonSerializer.Deserialize(jsonObject)!; + } + + private static byte[] ConvertVectorToBytes(ReadOnlyMemory vector) + { + return MemoryMarshal.AsBytes(vector.Span).ToArray(); + } + + private static byte[] ConvertVectorToBytes(ReadOnlyMemory vector) + { + return MemoryMarshal.AsBytes(vector.Span).ToArray(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..44a6bc41d195 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs @@ -0,0 +1,426 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using NRedisStack.Json.DataTypes; +using NRedisStack.RedisStackCommands; +using NRedisStack.Search; +using NRedisStack.Search.Literals.Enums; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Service for storing and retrieving vector records, that uses Redis JSON as the underlying storage. +/// +/// The data model to use for adding, updating and retrieving data from storage. 
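+/// Minimal illustrative usage (the Hotel model, key value and connection string below are placeholders for
+/// a record type annotated with the vector store record attributes introduced elsewhere in this change):
+///
+///     var database = ConnectionMultiplexer.Connect("localhost:6379").GetDatabase();
+///     var collection = new RedisJsonVectorStoreRecordCollection<Hotel>(database, "hotels");
+///     await collection.CreateCollectionIfNotExistsAsync();
+///     var hotel = await collection.GetAsync("hotel-1", new() { IncludeVectors = true });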
+#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class RedisJsonVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TRecord : class +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "Redis"; + + /// A set of types that a key on the provided model may have. + private static readonly HashSet s_supportedKeyTypes = + [ + typeof(string) + ]; + + /// A set of types that vectors on the provided model may have. + private static readonly HashSet s_supportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + typeof(ReadOnlyMemory?) + ]; + + /// The Redis database to read/write records from. + private readonly IDatabase _database; + + /// The name of the collection that this will access. + private readonly string _collectionName; + + /// Optional configuration options for this class. + private readonly RedisJsonVectorStoreRecordCollectionOptions _options; + + /// A definition of the current storage model. + private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; + + /// An array of the storage names of all the data properties that are part of the Redis payload, i.e. all properties except the key and vector properties. + private readonly string[] _dataStoragePropertyNames; + + /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. + private readonly Dictionary _storagePropertyNames = new(); + + /// The mapper to use when mapping between the consumer data model and the Redis record. + private readonly IVectorStoreRecordMapper _mapper; + + /// The JSON serializer options to use when converting between the data model and the Redis record. + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// + /// Initializes a new instance of the class. + /// + /// The Redis database to read/write records from. + /// The name of the collection that this will access. + /// Optional configuration options for this class. + /// Throw when parameters are invalid. + public RedisJsonVectorStoreRecordCollection(IDatabase database, string collectionName, RedisJsonVectorStoreRecordCollectionOptions? options = null) + { + // Verify. + Verify.NotNull(database); + Verify.NotNullOrWhiteSpace(collectionName); + + // Assign. + this._database = database; + this._collectionName = collectionName; + this._options = options ?? new RedisJsonVectorStoreRecordCollectionOptions(); + this._jsonSerializerOptions = this._options.JsonSerializerOptions ?? JsonSerializerOptions.Default; + this._vectorStoreRecordDefinition = this._options.VectorStoreRecordDefinition ?? VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(TRecord), true); + + // Validate property types. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify(typeof(TRecord).Name, this._vectorStoreRecordDefinition, supportsMultipleVectors: true, requiresAtLeastOneVector: false); + VectorStoreRecordPropertyReader.VerifyPropertyTypes([properties.KeyProperty], s_supportedKeyTypes, "Key"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.VectorProperties, s_supportedVectorTypes, "Vector"); + + // Lookup json storage property names. 
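+        // Resolve the serialized name of the key property up front. The default mapper strips this field out
+        // of the stored JSON payload (the Redis key carries the value instead) and adds it back when reading,
+        // so it needs the exact json name, respecting any configured JsonSerializerOptions.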
+ var keyJsonPropertyName = VectorStoreRecordPropertyReader.GetJsonPropertyName(properties.KeyProperty, typeof(TRecord), this._jsonSerializerOptions); + + // Lookup storage property names. + this._storagePropertyNames = VectorStoreRecordPropertyReader.BuildPropertyNameToJsonPropertyNameMap(properties, typeof(TRecord), this._jsonSerializerOptions); + this._dataStoragePropertyNames = properties + .DataProperties + .Select(x => this._storagePropertyNames[x.DataModelPropertyName]) + .ToArray(); + + // Assign Mapper. + if (this._options.JsonNodeCustomMapper is not null) + { + this._mapper = this._options.JsonNodeCustomMapper; + } + else + { + this._mapper = new RedisJsonVectorStoreRecordMapper(keyJsonPropertyName, this._jsonSerializerOptions); + } + } + + /// + public string CollectionName => this._collectionName; + + /// + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + try + { + await this._database.FT().InfoAsync(this._collectionName).ConfigureAwait(false); + return true; + } + catch (RedisServerException ex) when (ex.Message.Contains("Unknown index name")) + { + return false; + } + catch (RedisConnectionException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = "FT.INFO" + }; + } + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + // Map the record definition to a schema. + var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(this._vectorStoreRecordDefinition.Properties, this._storagePropertyNames); + + // Create the index creation params. + // Add the collection name and colon as the index prefix, which means that any record where the key is prefixed with this text will be indexed by this index + var createParams = new FTCreateParams() + .AddPrefix($"{this._collectionName}:") + .On(IndexDataType.JSON); + + // Create the index. + return this.RunOperationAsync("FT.CREATE", () => this._database.FT().CreateAsync(this._collectionName, createParams, schema)); + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this.RunOperationAsync("FT.DROPINDEX", () => this._database.FT().DropIndexAsync(this._collectionName)); + } + + /// + public async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + // Create Options + var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + var includeVectors = options?.IncludeVectors ?? false; + + // Get the Redis value. + var redisResult = await this.RunOperationAsync( + "GET", + () => options?.IncludeVectors is true ? + this._database + .JSON() + .GetAsync(maybePrefixedKey) : + this._database + .JSON() + .GetAsync(maybePrefixedKey, this._dataStoragePropertyNames)).ConfigureAwait(false); + + // Check if the key was found before trying to parse the result. + if (redisResult.IsNull || redisResult is null) + { + return null; + } + + // Check if the value contained any JSON text before trying to parse the result. 
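+        // Note that when IncludeVectors is false the GET above only requested the data fields, so the
+        // returned JSON contains no vector properties and the mapper is told not to expect them.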
+ var redisResultString = redisResult.ToString(); + if (redisResultString is null) + { + throw new VectorStoreRecordMappingException($"Document with key '{key}' does not contain any json."); + } + + // Convert to the caller's data model. + return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + "GET", + () => + { + var node = JsonSerializer.Deserialize(redisResultString, this._jsonSerializerOptions)!; + return this._mapper.MapFromStorageToDataModel((key, node), new() { IncludeVectors = includeVectors }); + }); + } + + /// + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + var keysList = keys.ToList(); + + // Create Options + var maybePrefixedKeys = keysList.Select(key => this.PrefixKeyIfNeeded(key)); + var redisKeys = maybePrefixedKeys.Select(x => new RedisKey(x)).ToArray(); + var includeVectors = options?.IncludeVectors ?? false; + + // Get the list of Redis results. + var redisResults = await this.RunOperationAsync( + "MGET", + () => this._database + .JSON() + .MGetAsync(redisKeys, "$")).ConfigureAwait(false); + + // Loop through each key and result and convert to the caller's data model. + for (int i = 0; i < keysList.Count; i++) + { + var key = keysList[i]; + var redisResult = redisResults[i]; + + // Check if the key was found before trying to parse the result. + if (redisResult.IsNull || redisResult is null) + { + continue; + } + + // Check if the value contained any JSON text before trying to parse the result. + var redisResultString = redisResult.ToString(); + if (redisResultString is null) + { + throw new VectorStoreRecordMappingException($"Document with key '{key}' does not contain any json."); + } + + // Convert to the caller's data model. + yield return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + "MGET", + () => + { + var node = JsonSerializer.Deserialize(redisResultString, this._jsonSerializerOptions)!; + return this._mapper.MapFromStorageToDataModel((key, node), new() { IncludeVectors = includeVectors }); + }); + } + } + + /// + public Task DeleteAsync(string key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + // Create Options + var maybePrefixedKey = this.PrefixKeyIfNeeded(key); + + // Remove. + return this.RunOperationAsync( + "DEL", + () => this._database + .JSON() + .DelAsync(maybePrefixedKey)); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + // Remove records in parallel. + var tasks = keys.Select(key => this.DeleteAsync(key, options, cancellationToken)); + return Task.WhenAll(tasks); + } + + /// + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + // Map. + var redisJsonRecord = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + "SET", + () => + { + var mapResult = this._mapper.MapFromDataToStorageModel(record); + var serializedRecord = JsonSerializer.Serialize(mapResult.Node, this._jsonSerializerOptions); + return new { Key = mapResult.Key, SerializedRecord = serializedRecord }; + }); + + // Upsert. 
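+        // The serialized record is written to the root path "$", so upserting an existing key replaces the
+        // previous JSON document rather than merging into it.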
+ var maybePrefixedKey = this.PrefixKeyIfNeeded(redisJsonRecord.Key); + await this.RunOperationAsync( + "SET", + () => this._database + .JSON() + .SetAsync( + maybePrefixedKey, + "$", + redisJsonRecord.SerializedRecord)).ConfigureAwait(false); + + return redisJsonRecord.Key; + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = default, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + // Map. + var redisRecords = new List<(string maybePrefixedKey, string originalKey, string serializedRecord)>(); + foreach (var record in records) + { + var redisJsonRecord = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this._collectionName, + "MSET", + () => + { + var mapResult = this._mapper.MapFromDataToStorageModel(record); + var serializedRecord = JsonSerializer.Serialize(mapResult.Node, this._jsonSerializerOptions); + return new { Key = mapResult.Key, SerializedRecord = serializedRecord }; + }); + + var maybePrefixedKey = this.PrefixKeyIfNeeded(redisJsonRecord.Key); + redisRecords.Add((maybePrefixedKey, redisJsonRecord.Key, redisJsonRecord.SerializedRecord)); + } + + // Upsert. + var keyPathValues = redisRecords.Select(x => new KeyPathValue(x.maybePrefixedKey, "$", x.serializedRecord)).ToArray(); + await this.RunOperationAsync( + "MSET", + () => this._database + .JSON() + .MSetAsync(keyPathValues)).ConfigureAwait(false); + + // Return keys of upserted records. + foreach (var record in redisRecords) + { + yield return record.originalKey; + } + } + + /// + /// Prefix the key with the collection name if the option is set. + /// + /// The key to prefix. + /// The updated key if updating is required, otherwise the input key. + private string PrefixKeyIfNeeded(string key) + { + if (this._options.PrefixCollectionNameToKeyNames) + { + return $"{this._collectionName}:{key}"; + } + + return key; + } + + /// + /// Run the given operation and wrap any Redis exceptions with ."/> + /// + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (RedisException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } + + /// + /// Run the given operation and wrap any Redis exceptions with ."/> + /// + /// The response type of the operation. + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. 
+ private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (RedisException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this._collectionName, + OperationName = operationName + }; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..382484e9cea9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Options when creating a . +/// +public sealed class RedisJsonVectorStoreRecordCollectionOptions + where TRecord : class +{ + /// + /// Gets or sets a value indicating whether the collection name should be prefixed to the + /// key names before reading or writing to the Redis store. Default is true. + /// + /// + /// For a record to be indexed by a specific Redis index, the key name must be prefixed with the matching prefix configured on the Redis index. + /// You can either pass in keys that are already prefixed, or set this option to true to have the collection name prefixed to the key names automatically. + /// + public bool PrefixCollectionNameToKeyNames { get; init; } = true; + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Redis record. + /// + /// + /// If not set, the default built in mapper will be used, which uses record attrigutes or the provided to map the record. + /// + public IVectorStoreRecordMapper? JsonNodeCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Gets or sets the JSON serializer options to use when converting between the data model and the Redis record. + /// + public JsonSerializerOptions? JsonSerializerOptions { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..3237c50c992e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordMapper.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.SemanticKernel.Data; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Class for mapping between a json node stored in redis, and the consumer data model. +/// +/// The consumer data model to map to or from. 
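+/// On write, the model is serialized to a JsonNode, the key property is lifted out to become the Redis key
+/// and removed from the stored payload; on read, the key is added back into the JSON before deserializing
+/// into the data model.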
+internal sealed class RedisJsonVectorStoreRecordMapper : IVectorStoreRecordMapper + where TConsumerDataModel : class +{ + /// The name of the temporary json property that the key field will be serialized / parsed from. + private readonly string _keyFieldJsonPropertyName; + + /// The JSON serializer options to use when converting between the data model and the Redis record. + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the key field on the model when serialized to json. + /// The JSON serializer options to use when converting between the data model and the Redis record. + public RedisJsonVectorStoreRecordMapper(string keyFieldJsonPropertyName, JsonSerializerOptions jsonSerializerOptions) + { + Verify.NotNullOrWhiteSpace(keyFieldJsonPropertyName); + Verify.NotNull(jsonSerializerOptions); + + this._keyFieldJsonPropertyName = keyFieldJsonPropertyName; + this._jsonSerializerOptions = jsonSerializerOptions; + } + + /// + public (string Key, JsonNode Node) MapFromDataToStorageModel(TConsumerDataModel dataModel) + { + // Convert the provided record into a JsonNode object and try to get the key field for it. + // Since we already checked that the key field is a string in the constructor, and that it exists on the model, + // the only edge case we have to be concerned about is if the key field is null. + var jsonNode = JsonSerializer.SerializeToNode(dataModel, this._jsonSerializerOptions); + if (jsonNode!.AsObject().TryGetPropertyValue(this._keyFieldJsonPropertyName, out var keyField) && keyField is JsonValue jsonValue) + { + // Remove the key field from the JSON object since we don't want to store it in the redis payload. + var keyValue = jsonValue.ToString(); + jsonNode.AsObject().Remove(this._keyFieldJsonPropertyName); + + return (keyValue, jsonNode); + } + + throw new VectorStoreRecordMappingException($"Missing key field {this._keyFieldJsonPropertyName} on provided record of type {typeof(TConsumerDataModel).FullName}."); + } + + /// + public TConsumerDataModel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) + { + JsonObject jsonObject; + + // The redis result can be either a single object or an array with a single object in the case where we are doing an MGET. + if (storageModel.Node is JsonObject topLevelJsonObject) + { + jsonObject = topLevelJsonObject; + } + else if (storageModel.Node is JsonArray jsonArray && jsonArray.Count == 1 && jsonArray[0] is JsonObject arrayEntryJsonObject) + { + jsonObject = arrayEntryJsonObject; + } + else + { + throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'"); + } + + // Check that the key field is not already present in the redis value. + if (jsonObject.ContainsKey(this._keyFieldJsonPropertyName)) + { + throw new VectorStoreRecordMappingException($"Invalid data format for document with key '{storageModel.Key}'. Key property '{this._keyFieldJsonPropertyName}' is already present on retrieved object."); + } + + // Since the key is not stored in the redis value, add it back in before deserializing into the data model. 
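+        // The key passed in here is the caller-supplied key, without any collection-name prefix the
+        // collection may have applied for storage, so the prefix never leaks into the data model.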
+ jsonObject.Add(this._keyFieldJsonPropertyName, storageModel.Key); + + return JsonSerializer.Deserialize(jsonObject, this._jsonSerializerOptions)!; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs new file mode 100644 index 000000000000..2b20b4d87de2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisKernelBuilderExtensions.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Extension methods to register Redis instances on the . +/// +public static class RedisKernelBuilderExtensions +{ + /// + /// Register a Redis with the specified service ID and where the Redis is retrieved from the dependency injection container. + /// + /// The builder to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddRedisVectorStore(this IKernelBuilder builder, RedisVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddRedisVectorStore(options, serviceId); + return builder; + } + + /// + /// Register a Redis with the specified service ID and where the Redis is constructed using the provided . + /// + /// The builder to register the on. + /// The Redis connection configuration string. If not provided, an instance will be requested from the dependency injection container. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddRedisVectorStore(this IKernelBuilder builder, string redisConnectionConfiguration, RedisVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddRedisVectorStore(redisConnectionConfiguration, options, serviceId); + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs new file mode 100644 index 000000000000..5a55b12f8c39 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisServiceCollectionExtensions.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Data; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Extension methods to register Redis instances on an . +/// +public static class RedisServiceCollectionExtensions +{ + /// + /// Register a Redis with the specified service ID and where the Redis is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddRedisVectorStore(this IServiceCollection services, RedisVectorStoreOptions? options = default, string? serviceId = default) + { + // If we are not constructing the ConnectionMultiplexer, add the IVectorStore as transient, since we + // cannot make assumptions about how IDatabase is being managed. 
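+        // Illustrative wiring, assuming the application registers its own IDatabase (names and connection
+        // string are placeholders):
+        //   services.AddSingleton<IDatabase>(sp => ConnectionMultiplexer.Connect("localhost:6379").GetDatabase());
+        //   services.AddRedisVectorStore();
+        // After this, an IVectorStore can be resolved from the container, optionally by service key.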
+ services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var database = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new RedisVectorStore( + database, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Redis with the specified service ID and where the Redis is constructed using the provided . + /// + /// The to register the on. + /// The Redis connection configuration string. If not provided, an instance will be requested from the dependency injection container. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddRedisVectorStore(this IServiceCollection services, string redisConnectionConfiguration, RedisVectorStoreOptions? options = default, string? serviceId = default) + { + // If we are constructing the ConnectionMultiplexer, add the IVectorStore as singleton, since we are managing the lifetime + // of the ConnectionMultiplexer, and the recommendation from StackExchange.Redis is to share the ConnectionMultiplexer. + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var database = ConnectionMultiplexer.Connect(redisConnectionConfiguration).GetDatabase(); + var selectedOptions = options ?? sp.GetService(); + + return new RedisVectorStore( + database, + selectedOptions); + }); + + return services; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisStorageType.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisStorageType.cs new file mode 100644 index 000000000000..9360fe448998 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisStorageType.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Indicates the way in which data is stored in redis. +/// +public enum RedisStorageType +{ + /// + /// Data is stored as JSON. + /// + Json, + + /// + /// Data is stored as collections of field-value pairs. + /// + HashSet +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs new file mode 100644 index 000000000000..51a933d36062 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStore.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using Microsoft.SemanticKernel.Data; +using NRedisStack.RedisStackCommands; +using StackExchange.Redis; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Class for accessing the list of collections in a Redis vector store. +/// +/// +/// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// +public sealed class RedisVectorStore : IVectorStore +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "Redis"; + + /// The redis database to read/write indices from. + private readonly IDatabase _database; + + /// Optional configuration options for this class. + private readonly RedisVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// The redis database to read/write indices from. + /// Optional configuration options for this class. + public RedisVectorStore(IDatabase database, RedisVectorStoreOptions? 
options = default) + { + Verify.NotNull(database); + + this._database = database; + this._options = options ?? new RedisVectorStoreOptions(); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : class + { + if (typeof(TKey) != typeof(string)) + { + throw new NotSupportedException("Only string keys are supported."); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._database, name, vectorStoreRecordDefinition); + } + + if (this._options.StorageType == RedisStorageType.HashSet) + { + var directlyCreatedStore = new RedisHashSetVectorStoreRecordCollection(this._database, name, new RedisHashSetVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + return directlyCreatedStore!; + } + else + { + var directlyCreatedStore = new RedisJsonVectorStoreRecordCollection(this._database, name, new RedisJsonVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + return directlyCreatedStore!; + } + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + const string OperationName = ""; + RedisResult[] listResult; + + try + { + listResult = await this._database.FT()._ListAsync().ConfigureAwait(false); + } + catch (RedisException ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + OperationName = OperationName + }; + } + + foreach (var item in listResult) + { + var name = item.ToString(); + if (name != null) + { + yield return name; + } + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs new file mode 100644 index 000000000000..2bdb6a67b5ef --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionCreateMapping.cs @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using Microsoft.SemanticKernel.Data; +using NRedisStack.Search; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Contains mapping helpers to use when creating a redis vector collection. +/// +internal static class RedisVectorStoreCollectionCreateMapping +{ + /// A set of number types that are supported for filtering. + public static readonly HashSet s_supportedFilterableNumericDataTypes = + [ + typeof(short), + typeof(sbyte), + typeof(byte), + typeof(ushort), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(float), + typeof(double), + typeof(decimal), + + typeof(short?), + typeof(sbyte?), + typeof(byte?), + typeof(ushort?), + typeof(int?), + typeof(uint?), + typeof(long?), + typeof(ulong?), + typeof(float?), + typeof(double?), + typeof(decimal?), + ]; + + /// + /// Map from the given list of items to the Redis . + /// + /// The property definitions to map from. + /// A dictionary that maps from a property name to the storage name that should be used when serializing it to json for data and vector properties. 
+ /// The mapped Redis . + /// Thrown if there are missing required or unsupported configuration options set. + public static Schema MapToSchema(IEnumerable properties, Dictionary storagePropertyNames) + { + var schema = new Schema(); + + // Loop through all properties and create the index fields. + foreach (var property in properties) + { + // Key property. + if (property is VectorStoreRecordKeyProperty keyProperty) + { + // Do nothing, since key is not stored as part of the payload and therefore doesn't have to be added to the index. + continue; + } + + // Data property. + if (property is VectorStoreRecordDataProperty dataProperty && (dataProperty.IsFilterable || dataProperty.IsFullTextSearchable)) + { + var storageName = storagePropertyNames[dataProperty.DataModelPropertyName]; + + if (dataProperty.IsFilterable && dataProperty.IsFullTextSearchable) + { + throw new InvalidOperationException($"Property '{dataProperty.DataModelPropertyName}' has both {nameof(VectorStoreRecordDataProperty.IsFilterable)} and {nameof(VectorStoreRecordDataProperty.IsFullTextSearchable)} set to true, and this is not supported by the Redis VectorStore."); + } + + // Add full text search field index. + if (dataProperty.IsFullTextSearchable) + { + if (dataProperty.PropertyType == typeof(string) || (typeof(IEnumerable).IsAssignableFrom(dataProperty.PropertyType) && GetEnumerableType(dataProperty.PropertyType) == typeof(string))) + { + schema.AddTextField(new FieldName($"$.{storageName}", storageName)); + } + else + { + throw new InvalidOperationException($"Property {nameof(dataProperty.IsFullTextSearchable)} on {nameof(VectorStoreRecordDataProperty)} '{dataProperty.DataModelPropertyName}' is set to true, but the property type is not a string or IEnumerable. The Redis VectorStore supports {nameof(dataProperty.IsFullTextSearchable)} on string or IEnumerable properties only."); + } + } + + // Add filter field index. + if (dataProperty.IsFilterable) + { + if (dataProperty.PropertyType == typeof(string)) + { + schema.AddTagField(new FieldName($"$.{storageName}", storageName)); + } + else if (typeof(IEnumerable).IsAssignableFrom(dataProperty.PropertyType) && GetEnumerableType(dataProperty.PropertyType) == typeof(string)) + { + schema.AddTagField(new FieldName($"$.{storageName}.*", storageName)); + } + else if (RedisVectorStoreCollectionCreateMapping.s_supportedFilterableNumericDataTypes.Contains(dataProperty.PropertyType)) + { + schema.AddNumericField(new FieldName($"$.{storageName}", storageName)); + } + else + { + throw new InvalidOperationException($"Property '{dataProperty.DataModelPropertyName}' is marked as {nameof(VectorStoreRecordDataProperty.IsFilterable)}, but the property type '{dataProperty.PropertyType}' is not supported. Only string, IEnumerable and numeric properties are supported for filtering by the Redis VectorStore."); + } + } + + continue; + } + + // Vector property. 
+ if (property is VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.Dimensions is not > 0) + { + throw new InvalidOperationException($"Property {nameof(vectorProperty.Dimensions)} on {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' must be set to a positive integer to create a collection."); + } + + var storageName = storagePropertyNames[vectorProperty.DataModelPropertyName]; + var indexKind = GetSDKIndexKind(vectorProperty); + var distanceAlgorithm = GetSDKDistanceAlgorithm(vectorProperty); + var dimensions = vectorProperty.Dimensions.Value.ToString(CultureInfo.InvariantCulture); + schema.AddVectorField(new FieldName($"$.{storageName}", storageName), indexKind, new Dictionary() + { + ["TYPE"] = "FLOAT32", + ["DIM"] = dimensions, + ["DISTANCE_METRIC"] = distanceAlgorithm + }); + } + } + + return schema; + } + + /// + /// Get the configured from the given . + /// If none is configured the default is . + /// + /// The vector property definition. + /// The chosen . + /// Thrown if a index type was chosen that isn't supported by Redis. + public static Schema.VectorField.VectorAlgo GetSDKIndexKind(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.IndexKind is null) + { + return Schema.VectorField.VectorAlgo.HNSW; + } + + return vectorProperty.IndexKind switch + { + IndexKind.Hnsw => Schema.VectorField.VectorAlgo.HNSW, + IndexKind.Flat => Schema.VectorField.VectorAlgo.FLAT, + _ => throw new InvalidOperationException($"Index kind '{vectorProperty.IndexKind}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.") + }; + } + + /// + /// Get the configured distance metric from the given . + /// If none is configured, the default is cosine. + /// + /// The vector property definition. + /// The chosen distance metric. + /// Thrown if a distance function is chosen that isn't supported by Redis. + public static string GetSDKDistanceAlgorithm(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.DistanceFunction is null) + { + return "COSINE"; + } + + return vectorProperty.DistanceFunction switch + { + DistanceFunction.CosineSimilarity => "COSINE", + DistanceFunction.DotProductSimilarity => "IP", + DistanceFunction.EuclideanDistance => "L2", + _ => throw new InvalidOperationException($"Distance function '{vectorProperty.DistanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorProperty.DataModelPropertyName}' is not supported by the Redis VectorStore.") + }; + } + + /// + /// Gets the type of object stored in the given enumerable type. + /// + /// The enumerable to get the stored type for. + /// The type of object stored in the given enumerable type. + /// Thrown when the given type is not enumerable. 
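To make the schema mapping above concrete: if I am reading `MapToSchema` correctly, a record definition along the following lines would translate into TAG, NUMERIC, TEXT and HNSW vector index fields. The property names, dimension count and collection shape are invented for illustration only.

```cs
var definition = new VectorStoreRecordDefinition
{
    Properties =
    [
        new VectorStoreRecordKeyProperty("HotelId", typeof(string)),                                      // not indexed; keys are not part of the payload
        new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true },           // TAG $.HotelName
        new VectorStoreRecordDataProperty("Rating", typeof(double)) { IsFilterable = true },              // NUMERIC $.Rating
        new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextSearchable = true }, // TEXT $.Description
        new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory<float>))
        {
            Dimensions = 4,                                        // DIM 4, TYPE FLOAT32
            IndexKind = IndexKind.Hnsw,                            // VECTOR HNSW
            DistanceFunction = DistanceFunction.CosineSimilarity,  // DISTANCE_METRIC COSINE
        },
    ],
};
```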
+ private static Type GetEnumerableType(Type type) + { + if (type is IEnumerable) + { + return typeof(object); + } + else if (type.IsArray) + { + return type.GetElementType()!; + } + + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + return type.GetGenericArguments()[0]; + } + + if (type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)) is Type enumerableInterface) + { + return enumerableInterface.GetGenericArguments()[0]; + } + + throw new InvalidOperationException($"Data type '{type}' for {nameof(VectorStoreRecordDataProperty)} is not supported by the Redis VectorStore."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs new file mode 100644 index 000000000000..0434b3c633ec --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreOptions.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +/// +/// Options when creating a . +/// +public sealed class RedisVectorStoreOptions +{ + /// + /// An optional factory to use for constructing instances, if custom options are required. + /// + public IRedisVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } + + /// + /// Indicates the way in which data should be stored in redis. Default is . + /// + public RedisStorageType? StorageType { get; init; } = RedisStorageType.Json; +} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..f0b4f327c0f0 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantKernelBuilderExtensionsTests.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; +using Xunit; + +namespace SemanticKernel.Connectors.Qdrant.UnitTests; + +/// +/// Tests for the class. +/// +public class QdrantKernelBuilderExtensionsTests +{ + private readonly IKernelBuilder _kernelBuilder; + + public QdrantKernelBuilderExtensionsTests() + { + this._kernelBuilder = Kernel.CreateBuilder(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange. + using var qdrantClient = new QdrantClient("localhost"); + this._kernelBuilder.Services.AddSingleton(qdrantClient); + + // Act. + this._kernelBuilder.AddQdrantVectorStore(); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithHostAndPortAndCredsRegistersClass() + { + // Act. + this._kernelBuilder.AddQdrantVectorStore("localhost", 8080, true, "apikey"); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithHostRegistersClass() + { + // Act. + this._kernelBuilder.AddQdrantVectorStore("localhost"); + + // Assert. 
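Putting the store, options and storage type together, a minimal consumption sketch (the `Hotel` model, key values and collection name are assumptions; the attribute usage follows the test models elsewhere in this change):

```cs
// using Microsoft.SemanticKernel.Connectors.Redis;
// using Microsoft.SemanticKernel.Data;
// using StackExchange.Redis;

var database = ConnectionMultiplexer.Connect("localhost:6379").GetDatabase();

// JSON storage is the default; hash set storage is opted into via the option shown above.
var vectorStore = new RedisVectorStore(database, new RedisVectorStoreOptions { StorageType = RedisStorageType.HashSet });

// Redis record keys must be strings.
var hotels = vectorStore.GetCollection<string, Hotel>("skhotels");
await hotels.CreateCollectionAsync();
await hotels.UpsertAsync(new Hotel { HotelId = "h1", HotelName = "Hotel 1", DescriptionEmbedding = new float[] { 1f, 2f, 3f, 4f } });

// Assumed data model for the sketch above.
public sealed class Hotel
{
    [VectorStoreRecordKey]
    public string HotelId { get; set; } = string.Empty;

    [VectorStoreRecordData(IsFilterable = true)]
    public string HotelName { get; set; } = string.Empty;

    [VectorStoreRecordVector(4)]
    public ReadOnlyMemory<float>? DescriptionEmbedding { get; set; }
}
```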
+ this.AssertVectorStoreCreated(); + } + + private void AssertVectorStoreCreated() + { + var kernel = this._kernelBuilder.Build(); + var vectorStore = kernel.Services.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..056b8cfaf9d1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantServiceCollectionExtensionsTests.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; +using Xunit; + +namespace SemanticKernel.Connectors.Qdrant.UnitTests; + +/// +/// Tests for the class. +/// +public class QdrantServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection; + + public QdrantServiceCollectionExtensionsTests() + { + this._serviceCollection = new ServiceCollection(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange. + using var qdrantClient = new QdrantClient("localhost"); + this._serviceCollection.AddSingleton(qdrantClient); + + // Act. + this._serviceCollection.AddQdrantVectorStore(); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithHostAndPortAndCredsRegistersClass() + { + // Act. + this._serviceCollection.AddQdrantVectorStore("localhost", 8080, true, "apikey"); + + // Assert. + this.AssertVectorStoreCreated(); + } + + [Fact] + public void AddVectorStoreWithHostRegistersClass() + { + // Act. + this._serviceCollection.AddQdrantVectorStore("localhost"); + + // Assert. + this.AssertVectorStoreCreated(); + } + + private void AssertVectorStoreCreated() + { + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs new file mode 100644 index 000000000000..37cd1d8af53f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionCreateMappingTests.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client.Grpc; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; + +/// +/// Contains tests for the class. +/// +public class QdrantVectorStoreCollectionCreateMappingTests +{ + [Fact] + public void MapSingleVectorCreatesVectorParams() + { + // Arrange. + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.DotProductSimilarity }; + + // Act. + var actual = QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(Distance.Dot, actual.Distance); + Assert.Equal(4ul, actual.Size); + } + + [Fact] + public void MapSingleVectorDefaultsToCosine() + { + // Arrange. 
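For reference, the registration paths exercised by the Qdrant DI tests above look roughly like this in application code (host, port and API key values are placeholders):

```cs
// using Microsoft.Extensions.DependencyInjection;
// using Microsoft.SemanticKernel;
// using Microsoft.SemanticKernel.Connectors.Qdrant;
// using Qdrant.Client;

var builder = Kernel.CreateBuilder();

// Option 1: supply your own QdrantClient via the container.
builder.Services.AddSingleton(new QdrantClient("localhost"));
builder.AddQdrantVectorStore();

// Option 2 (alternative): let the connector construct the client from host, port, TLS flag and API key.
// builder.AddQdrantVectorStore("localhost", 8080, true, "apikey");

var kernel = builder.Build();
var vectorStore = kernel.Services.GetRequiredService<IVectorStore>();
```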
+ var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4 }; + + // Act. + var actual = QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty); + + // Assert. + Assert.Equal(Distance.Cosine, actual.Distance); + } + + [Fact] + public void MapSingleVectorThrowsForUnsupportedDistanceFunction() + { + // Arrange. + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = 4, DistanceFunction = DistanceFunction.CosineDistance }; + + // Act and assert. + Assert.Throws(() => QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty)); + } + + [Theory] + [InlineData(null)] + [InlineData(0)] + public void MapSingleVectorThrowsIfDimensionsIsInvalid(int? dimensions) + { + // Arrange. + var vectorProperty = new VectorStoreRecordVectorProperty("testvector", typeof(ReadOnlyMemory)) { Dimensions = dimensions }; + + // Act and assert. + Assert.Throws(() => QdrantVectorStoreCollectionCreateMapping.MapSingleVector(vectorProperty)); + } + + [Fact] + public void MapNamedVectorsCreatesVectorParamsMap() + { + // Arrange. + var vectorProperties = new VectorStoreRecordVectorProperty[] + { + new("testvector1", typeof(ReadOnlyMemory)) { Dimensions = 10, DistanceFunction = DistanceFunction.EuclideanDistance }, + new("testvector2", typeof(ReadOnlyMemory)) { Dimensions = 20 } + }; + + var storagePropertyNames = new Dictionary + { + { "testvector1", "storage_testvector1" }, + { "testvector2", "storage_testvector2" } + }; + + // Act. + var actual = QdrantVectorStoreCollectionCreateMapping.MapNamedVectors(vectorProperties, storagePropertyNames); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(2, actual.Map.Count); + Assert.Equal(10ul, actual.Map["storage_testvector1"].Size); + Assert.Equal(Distance.Euclid, actual.Map["storage_testvector1"].Distance); + Assert.Equal(20ul, actual.Map["storage_testvector2"].Size); + Assert.Equal(Distance.Cosine, actual.Map["storage_testvector2"].Distance); + } +} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..1889ceef5fef --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs @@ -0,0 +1,757 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Moq; +using Qdrant.Client.Grpc; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; + +/// +/// Contains tests for the class. 
+/// +public class QdrantVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + private const ulong UlongTestRecordKey1 = 1; + private const ulong UlongTestRecordKey2 = 2; + private static readonly Guid s_guidTestRecordKey1 = Guid.Parse("11111111-1111-1111-1111-111111111111"); + private static readonly Guid s_guidTestRecordKey2 = Guid.Parse("22222222-2222-2222-2222-222222222222"); + + private readonly Mock _qdrantClientMock; + + private readonly CancellationToken _testCancellationToken = new(false); + + public QdrantVectorStoreRecordCollectionTests() + { + this._qdrantClientMock = new Mock(MockBehavior.Strict); + } + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange. + var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, collectionName); + + this._qdrantClientMock + .Setup(x => x.CollectionExistsAsync( + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(expectedExists); + + // Act. + var actual = await sut.CollectionExistsAsync(this._testCancellationToken); + + // Assert. + Assert.Equal(expectedExists, actual); + } + + [Fact] + public async Task CanCreateCollectionAsync() + { + // Arrange. + var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, TestCollectionName); + + this._qdrantClientMock + .Setup(x => x.CreateCollectionAsync( + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .Returns(Task.CompletedTask); + + this._qdrantClientMock + .Setup(x => x.CreatePayloadIndexAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(new UpdateResult()); + + // Act. + await sut.CreateCollectionAsync(this._testCancellationToken); + + // Assert. + this._qdrantClientMock + .Verify( + x => x.CreateCollectionAsync( + TestCollectionName, + It.Is(x => x.Size == 4), + this._testCancellationToken), + Times.Once); + + this._qdrantClientMock + .Verify( + x => x.CreatePayloadIndexAsync( + TestCollectionName, + "OriginalNameData", + PayloadSchemaType.Keyword, + this._testCancellationToken), + Times.Once); + + this._qdrantClientMock + .Verify( + x => x.CreatePayloadIndexAsync( + TestCollectionName, + "OriginalNameData", + PayloadSchemaType.Text, + this._testCancellationToken), + Times.Once); + + this._qdrantClientMock + .Verify( + x => x.CreatePayloadIndexAsync( + TestCollectionName, + "data_storage_name", + PayloadSchemaType.Keyword, + this._testCancellationToken), + Times.Once); + } + + [Fact] + public async Task CanDeleteCollectionAsync() + { + // Arrange. + var sut = new QdrantVectorStoreRecordCollection>(this._qdrantClientMock.Object, TestCollectionName); + + this._qdrantClientMock + .Setup(x => x.DeleteCollectionAsync( + It.IsAny(), + null, + this._testCancellationToken)) + .Returns(Task.CompletedTask); + + // Act. + await sut.DeleteCollectionAsync(this._testCancellationToken); + + // Assert. + this._qdrantClientMock + .Verify( + x => x.DeleteCollectionAsync( + TestCollectionName, + null, + this._testCancellationToken), + Times.Once); + } + + [Theory] + [MemberData(nameof(TestOptions))] + public async Task CanGetRecordWithVectorsAsync(bool useDefinition, bool hasNamedVectors, TKey testRecordKey) + where TKey : notnull + { + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + + // Arrange. 
+ var retrievedPoint = CreateRetrievedPoint(hasNamedVectors, testRecordKey); + this.SetupRetrieveMock([retrievedPoint]); + + // Act. + var actual = await sut.GetAsync( + testRecordKey, + new() { IncludeVectors = true }, + this._testCancellationToken); + + // Assert. + this._qdrantClientMock + .Verify( + x => x.RetrieveAsync( + TestCollectionName, + It.Is>(x => x.Count == 1 && (testRecordKey!.GetType() == typeof(ulong) && x[0].Num == (testRecordKey as ulong?) || testRecordKey!.GetType() == typeof(Guid) && x[0].Uuid == (testRecordKey as Guid?).ToString())), + true, + true, + null, + null, + this._testCancellationToken), + Times.Once); + + Assert.NotNull(actual); + Assert.Equal(testRecordKey, actual.Key); + Assert.Equal("data 1", actual.OriginalNameData); + Assert.Equal("data 1", actual.Data); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + } + + [Theory] + [MemberData(nameof(TestOptions))] + public async Task CanGetRecordWithoutVectorsAsync(bool useDefinition, bool hasNamedVectors, TKey testRecordKey) + where TKey : notnull + { + // Arrange. + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + var retrievedPoint = CreateRetrievedPoint(hasNamedVectors, testRecordKey); + this.SetupRetrieveMock([retrievedPoint]); + + // Act. + var actual = await sut.GetAsync( + testRecordKey, + new() { IncludeVectors = false }, + this._testCancellationToken); + + // Assert. + this._qdrantClientMock + .Verify( + x => x.RetrieveAsync( + TestCollectionName, + It.Is>(x => x.Count == 1 && (testRecordKey!.GetType() == typeof(ulong) && x[0].Num == (testRecordKey as ulong?) || testRecordKey!.GetType() == typeof(Guid) && x[0].Uuid == (testRecordKey as Guid?).ToString())), + true, + false, + null, + null, + this._testCancellationToken), + Times.Once); + + Assert.NotNull(actual); + Assert.Equal(testRecordKey, actual.Key); + Assert.Equal("data 1", actual.OriginalNameData); + Assert.Equal("data 1", actual.Data); + Assert.Null(actual.Vector); + } + + [Theory] + [MemberData(nameof(MultiRecordTestOptions))] + public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition, bool hasNamedVectors, TKey[] testRecordKeys) + where TKey : notnull + { + // Arrange. + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + var retrievedPoint1 = CreateRetrievedPoint(hasNamedVectors, UlongTestRecordKey1); + var retrievedPoint2 = CreateRetrievedPoint(hasNamedVectors, UlongTestRecordKey2); + this.SetupRetrieveMock(testRecordKeys.Select(x => CreateRetrievedPoint(hasNamedVectors, x)).ToList()); + + // Act. + var actual = await sut.GetBatchAsync( + testRecordKeys, + new() { IncludeVectors = true }, + this._testCancellationToken).ToListAsync(); + + // Assert. + this._qdrantClientMock + .Verify( + x => x.RetrieveAsync( + TestCollectionName, + It.Is>(x => + x.Count == 2 && + (testRecordKeys[0]!.GetType() == typeof(ulong) && x[0].Num == (testRecordKeys[0] as ulong?) || testRecordKeys[0]!.GetType() == typeof(Guid) && x[0].Uuid == (testRecordKeys[0] as Guid?).ToString()) && + (testRecordKeys[1]!.GetType() == typeof(ulong) && x[1].Num == (testRecordKeys[1] as ulong?) 
|| testRecordKeys[1]!.GetType() == typeof(Guid) && x[1].Uuid == (testRecordKeys[1] as Guid?).ToString())), + true, + true, + null, + null, + this._testCancellationToken), + Times.Once); + + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(testRecordKeys[0], actual[0].Key); + Assert.Equal(testRecordKeys[1], actual[1].Key); + } + + [Fact] + public async Task CanGetRecordWithCustomMapperAsync() + { + // Arrange. + var retrievedPoint = CreateRetrievedPoint(true, UlongTestRecordKey1); + this.SetupRetrieveMock([retrievedPoint]); + + // Arrange mapper mock from PointStruct to data model. + var mapperMock = new Mock, PointStruct>>(MockBehavior.Strict); + mapperMock.Setup( + x => x.MapFromStorageToDataModel( + It.IsAny(), + It.IsAny())) + .Returns(CreateModel(UlongTestRecordKey1, true)); + + // Arrange target with custom mapper. + var sut = new QdrantVectorStoreRecordCollection>( + this._qdrantClientMock.Object, + TestCollectionName, + new() + { + HasNamedVectors = true, + PointStructCustomMapper = mapperMock.Object + }); + + // Act + var actual = await sut.GetAsync( + UlongTestRecordKey1, + new() { IncludeVectors = true }, + this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + Assert.Equal(UlongTestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.OriginalNameData); + Assert.Equal("data 1", actual.Data); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + + mapperMock + .Verify( + x => x.MapFromStorageToDataModel( + It.Is(x => x.Id.Num == UlongTestRecordKey1), + It.Is(x => x.IncludeVectors)), + Times.Once); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CanDeleteUlongRecordAsync(bool useDefinition, bool hasNamedVectors) + { + // Arrange + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + this.SetupDeleteMocks(); + + // Act + await sut.DeleteAsync( + UlongTestRecordKey1, + cancellationToken: this._testCancellationToken); + + // Assert + this._qdrantClientMock + .Verify( + x => x.DeleteAsync( + TestCollectionName, + It.Is(x => x == UlongTestRecordKey1), + true, + null, + null, + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CanDeleteGuidRecordAsync(bool useDefinition, bool hasNamedVectors) + { + // Arrange + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + this.SetupDeleteMocks(); + + // Act + await sut.DeleteAsync( + s_guidTestRecordKey1, + cancellationToken: this._testCancellationToken); + + // Assert + this._qdrantClientMock + .Verify( + x => x.DeleteAsync( + TestCollectionName, + It.Is(x => x == s_guidTestRecordKey1), + true, + null, + null, + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CanDeleteManyUlongRecordsAsync(bool useDefinition, bool hasNamedVectors) + { + // Arrange + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + this.SetupDeleteMocks(); + + // Act + await sut.DeleteBatchAsync( + [UlongTestRecordKey1, UlongTestRecordKey2], + cancellationToken: this._testCancellationToken); + + // Assert + this._qdrantClientMock + .Verify( + x => x.DeleteAsync( + TestCollectionName, + It.Is>(x => x.Count == 2 && x.Contains(UlongTestRecordKey1) && 
x.Contains(UlongTestRecordKey2)), + true, + null, + null, + this._testCancellationToken), + Times.Once); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CanDeleteManyGuidRecordsAsync(bool useDefinition, bool hasNamedVectors) + { + // Arrange + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + this.SetupDeleteMocks(); + + // Act + await sut.DeleteBatchAsync( + [s_guidTestRecordKey1, s_guidTestRecordKey2], + cancellationToken: this._testCancellationToken); + + // Assert + this._qdrantClientMock + .Verify( + x => x.DeleteAsync( + TestCollectionName, + It.Is>(x => x.Count == 2 && x.Contains(s_guidTestRecordKey1) && x.Contains(s_guidTestRecordKey2)), + true, + null, + null, + this._testCancellationToken), + Times.Once); + } + + [Theory] + [MemberData(nameof(TestOptions))] + public async Task CanUpsertRecordAsync(bool useDefinition, bool hasNamedVectors, TKey testRecordKey) + where TKey : notnull + { + // Arrange + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + this.SetupUpsertMock(); + + // Act + await sut.UpsertAsync( + CreateModel(testRecordKey, true), + cancellationToken: this._testCancellationToken); + + // Assert + this._qdrantClientMock + .Verify( + x => x.UpsertAsync( + TestCollectionName, + It.Is>(x => x.Count == 1 && (testRecordKey!.GetType() == typeof(ulong) && x[0].Id.Num == (testRecordKey as ulong?) || testRecordKey!.GetType() == typeof(Guid) && x[0].Id.Uuid == (testRecordKey as Guid?).ToString())), + true, + null, + null, + this._testCancellationToken), + Times.Once); + } + + [Theory] + [MemberData(nameof(MultiRecordTestOptions))] + public async Task CanUpsertManyRecordsAsync(bool useDefinition, bool hasNamedVectors, TKey[] testRecordKeys) + where TKey : notnull + { + // Arrange + var sut = this.CreateRecordCollection(useDefinition, hasNamedVectors); + this.SetupUpsertMock(); + + var models = testRecordKeys.Select(x => CreateModel(x, true)); + + // Act + var actual = await sut.UpsertBatchAsync( + models, + cancellationToken: this._testCancellationToken).ToListAsync(); + + // Assert + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(testRecordKeys[0], actual[0]); + Assert.Equal(testRecordKeys[1], actual[1]); + + this._qdrantClientMock + .Verify( + x => x.UpsertAsync( + TestCollectionName, + It.Is>(x => + x.Count == 2 && + (testRecordKeys[0]!.GetType() == typeof(ulong) && x[0].Id.Num == (testRecordKeys[0] as ulong?) || testRecordKeys[0]!.GetType() == typeof(Guid) && x[0].Id.Uuid == (testRecordKeys[0] as Guid?).ToString()) && + (testRecordKeys[1]!.GetType() == typeof(ulong) && x[1].Id.Num == (testRecordKeys[1] as ulong?) || testRecordKeys[1]!.GetType() == typeof(Guid) && x[1].Id.Uuid == (testRecordKeys[1] as Guid?).ToString())), + true, + null, + null, + this._testCancellationToken), + Times.Once); + } + + [Fact] + public async Task CanUpsertRecordWithCustomMapperAsync() + { + // Arrange. + this.SetupUpsertMock(); + var pointStruct = new PointStruct + { + Id = new() { Num = UlongTestRecordKey1 }, + Payload = { ["OriginalNameData"] = "data 1", ["data_storage_name"] = "data 1" }, + Vectors = new[] { 1f, 2f, 3f, 4f } + }; + + // Arrange mapper mock from data model to PointStruct. + var mapperMock = new Mock, PointStruct>>(MockBehavior.Strict); + mapperMock + .Setup(x => x.MapFromDataToStorageModel(It.IsAny>())) + .Returns(pointStruct); + + // Arrange target with custom mapper. 
+ var sut = new QdrantVectorStoreRecordCollection>( + this._qdrantClientMock.Object, + TestCollectionName, + new() + { + HasNamedVectors = false, + PointStructCustomMapper = mapperMock.Object + }); + + var model = CreateModel(UlongTestRecordKey1, true); + + // Act + await sut.UpsertAsync( + model, + null, + this._testCancellationToken); + + // Assert + mapperMock + .Verify( + x => x.MapFromDataToStorageModel(It.Is>(x => x == model)), + Times.Once); + } + + /// + /// Tests that the collection can be created even if the definition and the type do not match. + /// In this case, the expectation is that a custom mapper will be provided to map between the + /// schema as defined by the definition and the different data model. + /// + [Fact] + public void CanCreateCollectionWithMismatchedDefinitionAndType() + { + // Arrange. + var definition = new VectorStoreRecordDefinition() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(ulong)), + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + } + }; + + // Act. + var sut = new QdrantVectorStoreRecordCollection>( + this._qdrantClientMock.Object, + TestCollectionName, + new() { VectorStoreRecordDefinition = definition, PointStructCustomMapper = Mock.Of, PointStruct>>() }); + } + + private void SetupRetrieveMock(List retrievedPoints) + { + this._qdrantClientMock + .Setup(x => x.RetrieveAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), // With Payload + It.IsAny(), // With Vectors + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(retrievedPoints); + } + + private void SetupDeleteMocks() + { + this._qdrantClientMock + .Setup(x => x.DeleteAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), // wait + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(new UpdateResult()); + + this._qdrantClientMock + .Setup(x => x.DeleteAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), // wait + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(new UpdateResult()); + + this._qdrantClientMock + .Setup(x => x.DeleteAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), // wait + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(new UpdateResult()); + + this._qdrantClientMock + .Setup(x => x.DeleteAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), // wait + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(new UpdateResult()); + } + + private void SetupUpsertMock() + { + this._qdrantClientMock + .Setup(x => x.UpsertAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), // wait + It.IsAny(), + It.IsAny(), + this._testCancellationToken)) + .ReturnsAsync(new UpdateResult()); + } + + private static RetrievedPoint CreateRetrievedPoint(bool hasNamedVectors, TKey recordKey) + { + RetrievedPoint point; + if (hasNamedVectors) + { + var namedVectors = new NamedVectors(); + namedVectors.Vectors.Add("vector_storage_name", new[] { 1f, 2f, 3f, 4f }); + point = new RetrievedPoint() + { + Payload = { ["OriginalNameData"] = "data 1", ["data_storage_name"] = "data 1" }, + Vectors = new Vectors { Vectors_ = namedVectors } + }; + } + else + { + point = new RetrievedPoint() + { + Payload = { ["OriginalNameData"] = "data 1", ["data_storage_name"] = "data 1" }, + Vectors = new[] { 1f, 2f, 3f, 4f } + }; + } + + if (recordKey is ulong ulongKey) + { + point.Id = ulongKey; + } + + if (recordKey is Guid guidKey) + { + point.Id = guidKey; + } + + return point; + } + + 
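The custom-mapper tests above correspond to a configuration roughly like the following. `MyRecord` and `MyRecordPointStructMapper` are invented names, and the mapper interface shape is inferred from the mocks, so treat this purely as a sketch:

```cs
// Sketch only: a record collection that bypasses the built-in mapping. MyRecordPointStructMapper is
// assumed to implement IVectorStoreRecordMapper<MyRecord, PointStruct>, i.e. the interface mocked above,
// mapping between the user's data model and the Qdrant PointStruct payload/vectors.
var collection = new QdrantVectorStoreRecordCollection<MyRecord>(
    new QdrantClient("localhost"),
    "mycollection",
    new()
    {
        HasNamedVectors = true,
        PointStructCustomMapper = new MyRecordPointStructMapper(),
    });
```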
private IVectorStoreRecordCollection> CreateRecordCollection(bool useDefinition, bool hasNamedVectors) + where T : notnull + { + var store = new QdrantVectorStoreRecordCollection>( + this._qdrantClientMock.Object, + TestCollectionName, + new() + { + VectorStoreRecordDefinition = useDefinition ? CreateSinglePropsDefinition(typeof(T)) : null, + HasNamedVectors = hasNamedVectors + }) as IVectorStoreRecordCollection>; + return store!; + } + + private static SinglePropsModel CreateModel(T key, bool withVectors) + { + return new SinglePropsModel + { + Key = key, + OriginalNameData = "data 1", + Data = "data 1", + Vector = withVectors ? new float[] { 1, 2, 3, 4 } : null, + NotAnnotated = null, + }; + } + + private static VectorStoreRecordDefinition CreateSinglePropsDefinition(Type keyType) + { + return new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", keyType), + new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("Data", typeof(string)) { IsFilterable = true, StoragePropertyName = "data_storage_name" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector_storage_name" } + ] + }; + } + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public required T Key { get; set; } + + [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + public string OriginalNameData { get; set; } = string.Empty; + + [JsonPropertyName("ignored_data_json_name")] + [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "data_storage_name")] + public string Data { get; set; } = string.Empty; + + [JsonPropertyName("ignored_vector_json_name")] + [VectorStoreRecordVector(4, StoragePropertyName = "vector_storage_name")] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } + + public static IEnumerable TestOptions + => GenerateAllCombinations(new object[][] { + new object[] { true, false }, + new object[] { true, false }, + new object[] { UlongTestRecordKey1, s_guidTestRecordKey1 } + }); + + public static IEnumerable MultiRecordTestOptions + => GenerateAllCombinations(new object[][] { + new object[] { true, false }, + new object[] { true, false }, + new object[] { new ulong[] { UlongTestRecordKey1, UlongTestRecordKey2 }, new Guid[] { s_guidTestRecordKey1, s_guidTestRecordKey2 } } + }); + + private static object[][] GenerateAllCombinations(object[][] input) + { + var counterArray = Enumerable.Range(0, input.Length).Select(x => 0).ToArray(); + + // Add each item from the first option set as a separate row. + object[][] currentCombinations = input[0].Select(x => new object[1] { x }).ToArray(); + + // Loop through each additional option set. + for (int currentOptionSetIndex = 1; currentOptionSetIndex < input.Length; currentOptionSetIndex++) + { + var iterationCombinations = new List(); + var currentOptionSet = input[currentOptionSetIndex]; + + // Loop through each row we have already. + foreach (var currentCombination in currentCombinations) + { + // Add each of the values from the new options set to the current row to generate a new row. 
+ for (var currentColumnRow = 0; currentColumnRow < currentOptionSet.Length; currentColumnRow++) + { + iterationCombinations.Add(currentCombination.Append(currentOptionSet[currentColumnRow]).ToArray()); + } + } + + currentCombinations = iterationCombinations.ToArray(); + } + + return currentCombinations; + } +} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs new file mode 100644 index 000000000000..68ff1d46a86b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordMapperTests.cs @@ -0,0 +1,440 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client.Grpc; +using Xunit; + +namespace SemanticKernel.Connectors.Qdrant.UnitTests; + +/// +/// Contains tests for the class. +/// +public class QdrantVectorStoreRecordMapperTests +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapsSinglePropsFromDataToStorageModelWithUlong(bool hasNamedVectors) + { + // Arrange. + var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(ulong)); + var sut = new QdrantVectorStoreRecordMapper>(definition, hasNamedVectors, s_singlePropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateSinglePropsModel(5ul)); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(5ul, actual.Id.Num); + Assert.Single(actual.Payload); + Assert.Equal("data value", actual.Payload["data"].StringValue); + + if (hasNamedVectors) + { + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vectors.Vectors_.Vectors["vector"].Data.ToArray()); + } + else + { + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vectors.Vector.Data.ToArray()); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapsSinglePropsFromDataToStorageModelWithGuid(bool hasNamedVectors) + { + // Arrange. + var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(Guid)); + var sut = new QdrantVectorStoreRecordMapper>(definition, hasNamedVectors, s_singlePropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateSinglePropsModel(Guid.Parse("11111111-1111-1111-1111-111111111111"))); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(Guid.Parse("11111111-1111-1111-1111-111111111111"), Guid.Parse(actual.Id.Uuid)); + Assert.Single(actual.Payload); + Assert.Equal("data value", actual.Payload["data"].StringValue); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public void MapsSinglePropsFromStorageToDataModelWithUlong(bool hasNamedVectors, bool includeVectors) + { + // Arrange. + var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(ulong)); + var sut = new QdrantVectorStoreRecordMapper>(definition, hasNamedVectors, s_singlePropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromStorageToDataModel(CreateSinglePropsPointStruct(5, hasNamedVectors), new() { IncludeVectors = includeVectors }); + + // Assert. 
+ Assert.NotNull(actual); + Assert.Equal(5ul, actual.Key); + Assert.Equal("data value", actual.Data); + + if (includeVectors) + { + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + } + else + { + Assert.Null(actual.Vector); + } + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public void MapsSinglePropsFromStorageToDataModelWithGuid(bool hasNamedVectors, bool includeVectors) + { + // Arrange. + var definition = CreateSinglePropsVectorStoreRecordDefinition(typeof(Guid)); + var sut = new QdrantVectorStoreRecordMapper>(definition, hasNamedVectors, s_singlePropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromStorageToDataModel(CreateSinglePropsPointStruct(Guid.Parse("11111111-1111-1111-1111-111111111111"), hasNamedVectors), new() { IncludeVectors = includeVectors }); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(Guid.Parse("11111111-1111-1111-1111-111111111111"), actual.Key); + Assert.Equal("data value", actual.Data); + + if (includeVectors) + { + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + } + else + { + Assert.Null(actual.Vector); + } + } + + [Fact] + public void MapsMultiPropsFromDataToStorageModelWithUlong() + { + // Arrange. + var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(ulong)); + var sut = new QdrantVectorStoreRecordMapper>(definition, true, s_multiPropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateMultiPropsModel(5ul)); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(5ul, actual.Id.Num); + Assert.Equal(7, actual.Payload.Count); + Assert.Equal("data 1", actual.Payload["dataString"].StringValue); + Assert.Equal(5, actual.Payload["dataInt"].IntegerValue); + Assert.Equal(5, actual.Payload["dataLong"].IntegerValue); + Assert.Equal(5.5f, actual.Payload["dataFloat"].DoubleValue); + Assert.Equal(5.5d, actual.Payload["dataDouble"].DoubleValue); + Assert.True(actual.Payload["dataBool"].BoolValue); + Assert.Equal(new int[] { 1, 2, 3, 4 }, actual.Payload["dataArrayInt"].ListValue.Values.Select(x => (int)x.IntegerValue).ToArray()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vectors.Vectors_.Vectors["vector1"].Data.ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual.Vectors.Vectors_.Vectors["vector2"].Data.ToArray()); + } + + [Fact] + public void MapsMultiPropsFromDataToStorageModelWithGuid() + { + // Arrange. + var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(Guid)); + var sut = new QdrantVectorStoreRecordMapper>(definition, true, s_multiPropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateMultiPropsModel(Guid.Parse("11111111-1111-1111-1111-111111111111"))); + + // Assert. 
+ Assert.NotNull(actual); + Assert.Equal(Guid.Parse("11111111-1111-1111-1111-111111111111"), Guid.Parse(actual.Id.Uuid)); + Assert.Equal(7, actual.Payload.Count); + Assert.Equal("data 1", actual.Payload["dataString"].StringValue); + Assert.Equal(5, actual.Payload["dataInt"].IntegerValue); + Assert.Equal(5, actual.Payload["dataLong"].IntegerValue); + Assert.Equal(5.5f, actual.Payload["dataFloat"].DoubleValue); + Assert.Equal(5.5d, actual.Payload["dataDouble"].DoubleValue); + Assert.True(actual.Payload["dataBool"].BoolValue); + Assert.Equal(new int[] { 1, 2, 3, 4 }, actual.Payload["dataArrayInt"].ListValue.Values.Select(x => (int)x.IntegerValue).ToArray()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vectors.Vectors_.Vectors["vector1"].Data.ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual.Vectors.Vectors_.Vectors["vector2"].Data.ToArray()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapsMultiPropsFromStorageToDataModelWithUlong(bool includeVectors) + { + // Arrange. + var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(ulong)); + var sut = new QdrantVectorStoreRecordMapper>(definition, true, s_multiPropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromStorageToDataModel(CreateMultiPropsPointStruct(5), new() { IncludeVectors = includeVectors }); + + // Assert. + Assert.NotNull(actual); + Assert.Equal(5ul, actual.Key); + Assert.Equal("data 1", actual.DataString); + Assert.Equal(5, actual.DataInt); + Assert.Equal(5L, actual.DataLong); + Assert.Equal(5.5f, actual.DataFloat); + Assert.Equal(5.5d, actual.DataDouble); + Assert.True(actual.DataBool); + Assert.Equal(new int[] { 1, 2, 3, 4 }, actual.DataArrayInt); + + if (includeVectors) + { + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual.Vector2!.Value.ToArray()); + } + else + { + Assert.Null(actual.Vector1); + Assert.Null(actual.Vector2); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapsMultiPropsFromStorageToDataModelWithGuid(bool includeVectors) + { + // Arrange. + var definition = CreateMultiPropsVectorStoreRecordDefinition(typeof(Guid)); + var sut = new QdrantVectorStoreRecordMapper>(definition, true, s_multiPropsModelStorageNamesMap); + + // Act. + var actual = sut.MapFromStorageToDataModel(CreateMultiPropsPointStruct(Guid.Parse("11111111-1111-1111-1111-111111111111")), new() { IncludeVectors = includeVectors }); + + // Assert. 
+ Assert.NotNull(actual); + Assert.Equal(Guid.Parse("11111111-1111-1111-1111-111111111111"), actual.Key); + Assert.Equal("data 1", actual.DataString); + Assert.Equal(5, actual.DataInt); + Assert.Equal(5L, actual.DataLong); + Assert.Equal(5.5f, actual.DataFloat); + Assert.Equal(5.5d, actual.DataDouble); + Assert.True(actual.DataBool); + Assert.Equal(new int[] { 1, 2, 3, 4 }, actual.DataArrayInt); + + if (includeVectors) + { + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual.Vector2!.Value.ToArray()); + } + else + { + Assert.Null(actual.Vector1); + Assert.Null(actual.Vector2); + } + } + + private static SinglePropsModel CreateSinglePropsModel(TKey key) + { + return new SinglePropsModel + { + Key = key, + Data = "data value", + Vector = new float[] { 1, 2, 3, 4 }, + NotAnnotated = "notAnnotated", + }; + } + + private static MultiPropsModel CreateMultiPropsModel(TKey key) + { + return new MultiPropsModel + { + Key = key, + DataString = "data 1", + DataInt = 5, + DataLong = 5L, + DataFloat = 5.5f, + DataDouble = 5.5d, + DataBool = true, + DataArrayInt = new List { 1, 2, 3, 4 }, + Vector1 = new float[] { 1, 2, 3, 4 }, + Vector2 = new float[] { 5, 6, 7, 8 }, + NotAnnotated = "notAnnotated", + }; + } + + private static PointStruct CreateSinglePropsPointStruct(ulong id, bool hasNamedVectors) + { + var pointStruct = new PointStruct(); + pointStruct.Id = new PointId() { Num = id }; + AddDataToSinglePropsPointStruct(pointStruct, hasNamedVectors); + return pointStruct; + } + + private static PointStruct CreateSinglePropsPointStruct(Guid id, bool hasNamedVectors) + { + var pointStruct = new PointStruct(); + pointStruct.Id = new PointId() { Uuid = id.ToString() }; + AddDataToSinglePropsPointStruct(pointStruct, hasNamedVectors); + return pointStruct; + } + + private static void AddDataToSinglePropsPointStruct(PointStruct pointStruct, bool hasNamedVectors) + { + pointStruct.Payload.Add("data", "data value"); + + if (hasNamedVectors) + { + var namedVectors = new NamedVectors(); + namedVectors.Vectors.Add("vector", new[] { 1f, 2f, 3f, 4f }); + pointStruct.Vectors = new Vectors() { Vectors_ = namedVectors }; + } + else + { + pointStruct.Vectors = new[] { 1f, 2f, 3f, 4f }; + } + } + + private static PointStruct CreateMultiPropsPointStruct(ulong id) + { + var pointStruct = new PointStruct(); + pointStruct.Id = new PointId() { Num = id }; + AddDataToMultiPropsPointStruct(pointStruct); + return pointStruct; + } + + private static PointStruct CreateMultiPropsPointStruct(Guid id) + { + var pointStruct = new PointStruct(); + pointStruct.Id = new PointId() { Uuid = id.ToString() }; + AddDataToMultiPropsPointStruct(pointStruct); + return pointStruct; + } + + private static void AddDataToMultiPropsPointStruct(PointStruct pointStruct) + { + pointStruct.Payload.Add("dataString", "data 1"); + pointStruct.Payload.Add("dataInt", 5); + pointStruct.Payload.Add("dataLong", 5L); + pointStruct.Payload.Add("dataFloat", 5.5f); + pointStruct.Payload.Add("dataDouble", 5.5d); + pointStruct.Payload.Add("dataBool", true); + + var dataIntArray = new ListValue(); + dataIntArray.Values.Add(1); + dataIntArray.Values.Add(2); + dataIntArray.Values.Add(3); + dataIntArray.Values.Add(4); + pointStruct.Payload.Add("dataArrayInt", new Value { ListValue = dataIntArray }); + + var namedVectors = new NamedVectors(); + namedVectors.Vectors.Add("vector1", new[] { 1f, 2f, 3f, 4f }); + namedVectors.Vectors.Add("vector2", new[] { 5f, 6f, 7f, 8f }); + pointStruct.Vectors = new 
Vectors() { Vectors_ = namedVectors }; + } + + private static readonly Dictionary s_singlePropsModelStorageNamesMap = new() + { + { "Key", "key" }, + { "Data", "data" }, + { "Vector", "vector" }, + }; + + private static VectorStoreRecordDefinition CreateSinglePropsVectorStoreRecordDefinition(Type keyType) => new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", keyType), + new VectorStoreRecordDataProperty("Data", typeof(string)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)), + }, + }; + + private sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } = default; + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector] + public ReadOnlyMemory? Vector { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } + + private static readonly Dictionary s_multiPropsModelStorageNamesMap = new() + { + { "Key", "key" }, + { "DataString", "dataString" }, + { "DataInt", "dataInt" }, + { "DataLong", "dataLong" }, + { "DataFloat", "dataFloat" }, + { "DataDouble", "dataDouble" }, + { "DataBool", "dataBool" }, + { "DataArrayInt", "dataArrayInt" }, + { "Vector1", "vector1" }, + { "Vector2", "vector2" }, + }; + + private static VectorStoreRecordDefinition CreateMultiPropsVectorStoreRecordDefinition(Type keyType) => new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", keyType), + new VectorStoreRecordDataProperty("DataString", typeof(string)), + new VectorStoreRecordDataProperty("DataInt", typeof(int)), + new VectorStoreRecordDataProperty("DataLong", typeof(long)), + new VectorStoreRecordDataProperty("DataFloat", typeof(float)), + new VectorStoreRecordDataProperty("DataDouble", typeof(double)), + new VectorStoreRecordDataProperty("DataBool", typeof(bool)), + new VectorStoreRecordDataProperty("DataArrayInt", typeof(List)), + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)), + }, + }; + + private sealed class MultiPropsModel + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } = default; + + [VectorStoreRecordData] + public string DataString { get; set; } = string.Empty; + + [JsonPropertyName("data_int_json")] + [VectorStoreRecordData] + public int DataInt { get; set; } = 0; + + [VectorStoreRecordData] + public long DataLong { get; set; } = 0; + + [VectorStoreRecordData] + public float DataFloat { get; set; } = 0; + + [VectorStoreRecordData] + public double DataDouble { get; set; } = 0; + + [VectorStoreRecordData] + public bool DataBool { get; set; } = false; + + [VectorStoreRecordData] + public List? DataArrayInt { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory? Vector1 { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory? Vector2 { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } +} diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs new file mode 100644 index 000000000000..2a234f08442a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreTests.cs @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft. All rights reserved. 
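Ahead of the store-level tests that follow, basic Qdrant store usage looks roughly like this (the client endpoint, collection name and `SomeRecord` model are assumptions; the model mirrors the test models in this change):

```cs
// using Microsoft.SemanticKernel.Connectors.Qdrant;
// using Microsoft.SemanticKernel.Data;
// using Qdrant.Client;

var client = new QdrantClient("localhost");
var vectorStore = new QdrantVectorStore(client);

// Qdrant record keys must be ulong or Guid; other key types throw NotSupportedException.
var collection = vectorStore.GetCollection<ulong, SomeRecord>("mycollection");

await foreach (var name in vectorStore.ListCollectionNamesAsync())
{
    Console.WriteLine(name);
}

// Assumed minimal data model for the sketch above.
public sealed class SomeRecord
{
    [VectorStoreRecordKey]
    public ulong Key { get; set; }

    [VectorStoreRecordData]
    public string Data { get; set; } = string.Empty;

    [VectorStoreRecordVector(4)]
    public ReadOnlyMemory<float>? Vector { get; set; }
}
```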
+ +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Moq; +using Qdrant.Client; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; + +/// +/// Contains tests for the class. +/// +public class QdrantVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _qdrantClientMock; + + private readonly CancellationToken _testCancellationToken = new(false); + + public QdrantVectorStoreTests() + { + this._qdrantClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public void GetCollectionReturnsCollection() + { + // Arrange. + var sut = new QdrantVectorStore(this._qdrantClientMock.Object); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionCallsFactoryIfProvided() + { + // Arrange. + var factoryMock = new Mock(MockBehavior.Strict); + var collectionMock = new Mock>>(MockBehavior.Strict); + factoryMock + .Setup(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null)) + .Returns(collectionMock.Object); + var sut = new QdrantVectorStore(this._qdrantClientMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.Equal(collectionMock.Object, actual); + factoryMock.Verify(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null), Times.Once); + } + + [Fact] + public void GetCollectionThrowsForInvalidKeyType() + { + // Arrange. + var sut = new QdrantVectorStore(this._qdrantClientMock.Object); + + // Act & Assert. + Assert.Throws(() => sut.GetCollection>(TestCollectionName)); + } + + [Fact] + public async Task ListCollectionNamesCallsSDKAsync() + { + // Arrange. + this._qdrantClientMock + .Setup(x => x.ListCollectionsAsync(It.IsAny())) + .ReturnsAsync(new[] { "collection1", "collection2" }); + var sut = new QdrantVectorStore(this._qdrantClientMock.Object); + + // Act. + var collectionNames = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert. + var collectionNamesList = await collectionNames.ToListAsync(); + Assert.Equal(new[] { "collection1", "collection2" }, collectionNamesList); + } + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public required TKey Key { get; set; } + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..a95179e86346 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -0,0 +1,534 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Moq; +using NRedisStack; +using StackExchange.Redis; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains tests for the class. 
+/// +public class RedisHashSetVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + private const string TestRecordKey1 = "testid1"; + private const string TestRecordKey2 = "testid2"; + + private readonly Mock _redisDatabaseMock; + + public RedisHashSetVectorStoreRecordCollectionTests() + { + this._redisDatabaseMock = new Mock(MockBehavior.Strict); + + var batchMock = new Mock(); + this._redisDatabaseMock.Setup(x => x.CreateBatch(It.IsAny())).Returns(batchMock.Object); + } + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange + if (expectedExists) + { + SetupExecuteMock(this._redisDatabaseMock, ["index_name", collectionName]); + } + else + { + SetupExecuteMock(this._redisDatabaseMock, new RedisServerException("Unknown index name")); + } + var sut = new RedisHashSetVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + collectionName); + + // Act + var actual = await sut.CollectionExistsAsync(); + + // Assert + var expectedArgs = new object[] { collectionName }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "FT.INFO", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + Assert.Equal(expectedExists, actual); + } + + [Fact] + public async Task CanCreateCollectionAsync() + { + // Arrange. + SetupExecuteMock(this._redisDatabaseMock, string.Empty); + var sut = new RedisHashSetVectorStoreRecordCollection(this._redisDatabaseMock.Object, TestCollectionName); + + // Act. + await sut.CreateCollectionAsync(); + + // Assert. + var expectedArgs = new object[] { + "testcollection", + "PREFIX", + 1, + "testcollection:", + "SCHEMA", + "$.OriginalNameData", + "AS", + "OriginalNameData", + "TAG", + "$.data_storage_name", + "AS", + "data_storage_name", + "TAG", + "$.vector_storage_name", + "AS", + "vector_storage_name", + "VECTOR", + "HNSW", + 6, + "TYPE", + "FLOAT32", + "DIM", + "4", + "DISTANCE_METRIC", + "COSINE" }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "FT.CREATE", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Fact] + public async Task CanDeleteCollectionAsync() + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, string.Empty); + var sut = this.CreateRecordCollection(false); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + var expectedArgs = new object[] { TestCollectionName }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "FT.DROPINDEX", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanGetRecordWithVectorsAsync(bool useDefinition) + { + // Arrange + var hashEntries = new HashEntry[] + { + new("OriginalNameData", "data 1"), + new("data_storage_name", "data 1"), + new("vector_storage_name", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()) + }; + this._redisDatabaseMock.Setup(x => x.HashGetAllAsync(It.IsAny(), CommandFlags.None)).ReturnsAsync(hashEntries); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = true }); + + // Assert + this._redisDatabaseMock.Verify(x => x.HashGetAllAsync(TestRecordKey1, CommandFlags.None), Times.Once); + + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.OriginalNameData); + 
Assert.Equal("data 1", actual.Data); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanGetRecordWithoutVectorsAsync(bool useDefinition) + { + // Arrange + var redisValues = new RedisValue[] { new("data 1"), new("data 1") }; + this._redisDatabaseMock.Setup(x => x.HashGetAsync(It.IsAny(), It.IsAny(), CommandFlags.None)).ReturnsAsync(redisValues); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = false }); + + // Assert + var fieldNames = new RedisValue[] { "OriginalNameData", "data_storage_name" }; + this._redisDatabaseMock.Verify(x => x.HashGetAsync(TestRecordKey1, fieldNames, CommandFlags.None), Times.Once); + + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.OriginalNameData); + Assert.Equal("data 1", actual.Data); + Assert.False(actual.Vector.HasValue); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) + { + // Arrange + var hashEntries1 = new HashEntry[] + { + new("OriginalNameData", "data 1"), + new("data_storage_name", "data 1"), + new("vector_storage_name", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()) + }; + var hashEntries2 = new HashEntry[] + { + new("OriginalNameData", "data 2"), + new("data_storage_name", "data 2"), + new("vector_storage_name", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 5, 6, 7, 8 })).ToArray()) + }; + this._redisDatabaseMock.Setup(x => x.HashGetAllAsync(It.IsAny(), CommandFlags.None)).Returns((RedisKey key, CommandFlags flags) => + { + return key switch + { + RedisKey k when k == TestRecordKey1 => Task.FromResult(hashEntries1), + RedisKey k when k == TestRecordKey2 => Task.FromResult(hashEntries2), + _ => throw new ArgumentException("Unexpected key."), + }; + }); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await sut.GetBatchAsync( + [TestRecordKey1, TestRecordKey2], + new() { IncludeVectors = true }).ToListAsync(); + + // Assert + this._redisDatabaseMock.Verify(x => x.HashGetAllAsync(TestRecordKey1, CommandFlags.None), Times.Once); + this._redisDatabaseMock.Verify(x => x.HashGetAllAsync(TestRecordKey2, CommandFlags.None), Times.Once); + + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(TestRecordKey1, actual[0].Key); + Assert.Equal("data 1", actual[0].OriginalNameData); + Assert.Equal("data 1", actual[0].Data); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual[0].Vector!.Value.ToArray()); + Assert.Equal(TestRecordKey2, actual[1].Key); + Assert.Equal("data 2", actual[1].OriginalNameData); + Assert.Equal("data 2", actual[1].Data); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual[1].Vector!.Value.ToArray()); + } + + [Fact] + public async Task CanGetRecordWithCustomMapperAsync() + { + // Arrange. + var hashEntries = new HashEntry[] + { + new("OriginalNameData", "data 1"), + new("data_storage_name", "data 1"), + new("vector_storage_name", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()) + }; + this._redisDatabaseMock.Setup(x => x.HashGetAllAsync(It.IsAny(), CommandFlags.None)).ReturnsAsync(hashEntries); + + // Arrange mapper mock from JsonNode to data model. 
+ var mapperMock = new Mock>(MockBehavior.Strict); + mapperMock.Setup( + x => x.MapFromStorageToDataModel( + It.IsAny<(string key, HashEntry[] hashEntries)>(), + It.IsAny())) + .Returns(CreateModel(TestRecordKey1, true)); + + // Arrange target with custom mapper. + var sut = new RedisHashSetVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() + { + HashEntriesCustomMapper = mapperMock.Object + }); + + // Act + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = true }); + + // Assert + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.OriginalNameData); + Assert.Equal("data 1", actual.Data); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + + mapperMock + .Verify( + x => x.MapFromStorageToDataModel( + It.Is<(string key, HashEntry[] hashEntries)>(x => x.key == TestRecordKey1), + It.Is(x => x.IncludeVectors)), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanDeleteRecordAsync(bool useDefinition) + { + // Arrange + this._redisDatabaseMock.Setup(x => x.KeyDeleteAsync(It.IsAny(), CommandFlags.None)).ReturnsAsync(true); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + await sut.DeleteAsync(TestRecordKey1); + + // Assert + this._redisDatabaseMock.Verify(x => x.KeyDeleteAsync(TestRecordKey1, CommandFlags.None), Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition) + { + // Arrange + this._redisDatabaseMock.Setup(x => x.KeyDeleteAsync(It.IsAny(), CommandFlags.None)).ReturnsAsync(true); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + await sut.DeleteBatchAsync([TestRecordKey1, TestRecordKey2]); + + // Assert + this._redisDatabaseMock.Verify(x => x.KeyDeleteAsync(TestRecordKey1, CommandFlags.None), Times.Once); + this._redisDatabaseMock.Verify(x => x.KeyDeleteAsync(TestRecordKey2, CommandFlags.None), Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanUpsertRecordAsync(bool useDefinition) + { + // Arrange + this._redisDatabaseMock.Setup(x => x.HashSetAsync(It.IsAny(), It.IsAny(), CommandFlags.None)).Returns(Task.CompletedTask); + var sut = this.CreateRecordCollection(useDefinition); + var model = CreateModel(TestRecordKey1, true); + + // Act + await sut.UpsertAsync(model); + + // Assert + this._redisDatabaseMock.Verify( + x => x.HashSetAsync( + TestRecordKey1, + It.Is(x => x.Length == 3 && x[0].Name == "OriginalNameData" && x[1].Name == "data_storage_name" && x[2].Name == "vector_storage_name"), + CommandFlags.None), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanUpsertManyRecordsAsync(bool useDefinition) + { + // Arrange + this._redisDatabaseMock.Setup(x => x.HashSetAsync(It.IsAny(), It.IsAny(), CommandFlags.None)).Returns(Task.CompletedTask); + var sut = this.CreateRecordCollection(useDefinition); + + var model1 = CreateModel(TestRecordKey1, true); + var model2 = CreateModel(TestRecordKey2, true); + + // Act + var actual = await sut.UpsertBatchAsync([model1, model2]).ToListAsync(); + + // Assert + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(TestRecordKey1, actual[0]); + Assert.Equal(TestRecordKey2, actual[1]); + + this._redisDatabaseMock.Verify( + x => x.HashSetAsync( + TestRecordKey1, + It.Is(x => x.Length == 3 && x[0].Name == "OriginalNameData" && 
x[1].Name == "data_storage_name" && x[2].Name == "vector_storage_name"), + CommandFlags.None), + Times.Once); + this._redisDatabaseMock.Verify( + x => x.HashSetAsync( + TestRecordKey2, + It.Is(x => x.Length == 3 && x[0].Name == "OriginalNameData" && x[1].Name == "data_storage_name" && x[2].Name == "vector_storage_name"), + CommandFlags.None), + Times.Once); + } + + [Fact] + public async Task CanUpsertRecordWithCustomMapperAsync() + { + // Arrange. + this._redisDatabaseMock.Setup(x => x.HashSetAsync(It.IsAny(), It.IsAny(), CommandFlags.None)).Returns(Task.CompletedTask); + + // Arrange mapper mock from data model to JsonNode. + var mapperMock = new Mock>(MockBehavior.Strict); + var hashEntries = new HashEntry[] + { + new("OriginalNameData", "data 1"), + new("data_storage_name", "data 1"), + new("vector_storage_name", "[1,2,3,4]"), + new("NotAnnotated", RedisValue.Null) + }; + mapperMock + .Setup(x => x.MapFromDataToStorageModel(It.IsAny())) + .Returns((TestRecordKey1, hashEntries)); + + // Arrange target with custom mapper. + var sut = new RedisHashSetVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() + { + HashEntriesCustomMapper = mapperMock.Object + }); + + var model = CreateModel(TestRecordKey1, true); + + // Act + await sut.UpsertAsync(model); + + // Assert + mapperMock + .Verify( + x => x.MapFromDataToStorageModel(It.Is(x => x == model)), + Times.Once); + } + + /// + /// Tests that the collection can be created even if the definition and the type do not match. + /// In this case, the expectation is that a custom mapper will be provided to map between the + /// schema as defined by the definition and the different data model. + /// + [Fact] + public void CanCreateCollectionWithMismatchedDefinitionAndType() + { + // Arrange. + var definition = new VectorStoreRecordDefinition() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + } + }; + + // Act. + var sut = new RedisHashSetVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() { VectorStoreRecordDefinition = definition, HashEntriesCustomMapper = Mock.Of>() }); + } + + private RedisHashSetVectorStoreRecordCollection CreateRecordCollection(bool useDefinition) + { + return new RedisHashSetVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() + { + PrefixCollectionNameToKeyNames = false, + VectorStoreRecordDefinition = useDefinition ? 
this._singlePropsDefinition : null + }); + } + + private static void SetupExecuteMock(Mock redisDatabaseMock, Exception exception) + { + redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(exception); + } + + private static void SetupExecuteMock(Mock redisDatabaseMock, IEnumerable redisResultStrings) + { + var results = redisResultStrings + .Select(x => RedisResult.Create(new RedisValue(x))) + .ToArray(); + redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .ReturnsAsync(RedisResult.Create(results)); + } + + private static void SetupExecuteMock(Mock redisDatabaseMock, string redisResultString) + { + redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .Callback((string command, object[] args) => + { + Console.WriteLine(args); + }) + .ReturnsAsync(RedisResult.Create(new RedisValue(redisResultString))); + } + + private static SinglePropsModel CreateModel(string key, bool withVectors) + { + return new SinglePropsModel + { + Key = key, + OriginalNameData = "data 1", + Data = "data 1", + Vector = withVectors ? new float[] { 1, 2, 3, 4 } : null, + NotAnnotated = null, + }; + } + + private readonly VectorStoreRecordDefinition _singlePropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("OriginalNameData", typeof(string)), + new VectorStoreRecordDataProperty("Data", typeof(string)) { StoragePropertyName = "data_storage_name" }, + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) { StoragePropertyName = "vector_storage_name" } + ] + }; + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData(IsFilterable = true)] + public string OriginalNameData { get; set; } = string.Empty; + + [JsonPropertyName("ignored_data_json_name")] + [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "data_storage_name")] + public string Data { get; set; } = string.Empty; + + [JsonPropertyName("ignored_vector_json_name")] + [VectorStoreRecordVector(4, StoragePropertyName = "vector_storage_name")] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs new file mode 100644 index 000000000000..fd7a56d8765c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordMapperTests.cs @@ -0,0 +1,268 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using StackExchange.Redis; +using Xunit; + +namespace SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class RedisHashSetVectorStoreRecordMapperTests +{ + [Fact] + public void MapsAllFieldsFromDataToStorageModel() + { + // Arrange. + var sut = new RedisHashSetVectorStoreRecordMapper(s_vectorStoreRecordDefinition, s_storagePropertyNames); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateModel("test key")); + + // Assert. 
+        Assert.NotNull(actual.HashEntries);
+        Assert.Equal("test key", actual.Key);
+
+        Assert.Equal("storage_string_data", actual.HashEntries[0].Name.ToString());
+        Assert.Equal("data 1", actual.HashEntries[0].Value.ToString());
+
+        Assert.Equal("IntData", actual.HashEntries[1].Name.ToString());
+        Assert.Equal(1, (int)actual.HashEntries[1].Value);
+
+        Assert.Equal("UIntData", actual.HashEntries[2].Name.ToString());
+        Assert.Equal(2u, (uint)actual.HashEntries[2].Value);
+
+        Assert.Equal("LongData", actual.HashEntries[3].Name.ToString());
+        Assert.Equal(3, (long)actual.HashEntries[3].Value);
+
+        Assert.Equal("ULongData", actual.HashEntries[4].Name.ToString());
+        Assert.Equal(4ul, (ulong)actual.HashEntries[4].Value);
+
+        Assert.Equal("DoubleData", actual.HashEntries[5].Name.ToString());
+        Assert.Equal(5.5d, (double)actual.HashEntries[5].Value);
+
+        Assert.Equal("FloatData", actual.HashEntries[6].Name.ToString());
+        Assert.Equal(6.6f, (float)actual.HashEntries[6].Value);
+
+        Assert.Equal("BoolData", actual.HashEntries[7].Name.ToString());
+        Assert.True((bool)actual.HashEntries[7].Value);
+
+        Assert.Equal("NullableIntData", actual.HashEntries[8].Name.ToString());
+        Assert.Equal(7, (int)actual.HashEntries[8].Value);
+
+        Assert.Equal("NullableUIntData", actual.HashEntries[9].Name.ToString());
+        Assert.Equal(8u, (uint)actual.HashEntries[9].Value);
+
+        Assert.Equal("NullableLongData", actual.HashEntries[10].Name.ToString());
+        Assert.Equal(9, (long)actual.HashEntries[10].Value);
+
+        Assert.Equal("NullableULongData", actual.HashEntries[11].Name.ToString());
+        Assert.Equal(10ul, (ulong)actual.HashEntries[11].Value);
+
+        Assert.Equal("NullableDoubleData", actual.HashEntries[12].Name.ToString());
+        Assert.Equal(11.1d, (double)actual.HashEntries[12].Value);
+
+        Assert.Equal("NullableFloatData", actual.HashEntries[13].Name.ToString());
+        Assert.Equal(12.2f, (float)actual.HashEntries[13].Value);
+
+        Assert.Equal("NullableBoolData", actual.HashEntries[14].Name.ToString());
+        Assert.False((bool)actual.HashEntries[14].Value);
+
+        Assert.Equal("FloatVector", actual.HashEntries[15].Name.ToString());
+        Assert.Equal(new float[] { 1, 2, 3, 4 }, MemoryMarshal.Cast<byte, float>((byte[])actual.HashEntries[15].Value!).ToArray());
+
+        Assert.Equal("DoubleVector", actual.HashEntries[16].Name.ToString());
+        Assert.Equal(new double[] { 5, 6, 7, 8 }, MemoryMarshal.Cast<byte, double>((byte[])actual.HashEntries[16].Value!).ToArray());
+    }
+
+    [Fact]
+    public void MapsAllFieldsFromStorageToDataModel()
+    {
+        // Arrange.
+        var sut = new RedisHashSetVectorStoreRecordMapper<AllTypesModel>(s_vectorStoreRecordDefinition, s_storagePropertyNames);
+
+        // Act.
+        var actual = sut.MapFromStorageToDataModel(("test key", CreateHashSet()), new() { IncludeVectors = true });
+
+        // Assert.
+        Assert.NotNull(actual);
+        Assert.Equal("test key", actual.Key);
+        Assert.Equal("data 1", actual.StringData);
+        Assert.Equal(1, actual.IntData);
+        Assert.Equal(2u, actual.UIntData);
+        Assert.Equal(3, actual.LongData);
+        Assert.Equal(4ul, actual.ULongData);
+        Assert.Equal(5.5d, actual.DoubleData);
+        Assert.Equal(6.6f, actual.FloatData);
+        Assert.True(actual.BoolData);
+        Assert.Equal(7, actual.NullableIntData);
+        Assert.Equal(8u, actual.NullableUIntData);
+        Assert.Equal(9, actual.NullableLongData);
+        Assert.Equal(10ul, actual.NullableULongData);
+        Assert.Equal(11.1d, actual.NullableDoubleData);
+        Assert.Equal(12.2f, actual.NullableFloatData);
+        Assert.False(actual.NullableBoolData);
+
+        Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.FloatVector!.Value.ToArray());
+        Assert.Equal(new double[] { 5, 6, 7, 8 }, actual.DoubleVector!.Value.ToArray());
+    }
+
+    private static AllTypesModel CreateModel(string key)
+    {
+        return new AllTypesModel
+        {
+            Key = key,
+            StringData = "data 1",
+            IntData = 1,
+            UIntData = 2,
+            LongData = 3,
+            ULongData = 4,
+            DoubleData = 5.5d,
+            FloatData = 6.6f,
+            BoolData = true,
+            NullableIntData = 7,
+            NullableUIntData = 8,
+            NullableLongData = 9,
+            NullableULongData = 10,
+            NullableDoubleData = 11.1d,
+            NullableFloatData = 12.2f,
+            NullableBoolData = false,
+            FloatVector = new float[] { 1, 2, 3, 4 },
+            DoubleVector = new double[] { 5, 6, 7, 8 },
+            NotAnnotated = "notAnnotated",
+        };
+    }
+
+    private static HashEntry[] CreateHashSet()
+    {
+        var hashSet = new HashEntry[17];
+        hashSet[0] = new HashEntry("storage_string_data", "data 1");
+        hashSet[1] = new HashEntry("IntData", 1);
+        hashSet[2] = new HashEntry("UIntData", 2);
+        hashSet[3] = new HashEntry("LongData", 3);
+        hashSet[4] = new HashEntry("ULongData", 4);
+        hashSet[5] = new HashEntry("DoubleData", 5.5);
+        hashSet[6] = new HashEntry("FloatData", 6.6);
+        hashSet[7] = new HashEntry("BoolData", true);
+        hashSet[8] = new HashEntry("NullableIntData", 7);
+        hashSet[9] = new HashEntry("NullableUIntData", 8);
+        hashSet[10] = new HashEntry("NullableLongData", 9);
+        hashSet[11] = new HashEntry("NullableULongData", 10);
+        hashSet[12] = new HashEntry("NullableDoubleData", 11.1);
+        hashSet[13] = new HashEntry("NullableFloatData", 12.2);
+        hashSet[14] = new HashEntry("NullableBoolData", false);
+        hashSet[15] = new HashEntry("FloatVector", MemoryMarshal.AsBytes(new ReadOnlySpan<float>(new float[] { 1, 2, 3, 4 })).ToArray());
+        hashSet[16] = new HashEntry("DoubleVector", MemoryMarshal.AsBytes(new ReadOnlySpan<double>(new double[] { 5, 6, 7, 8 })).ToArray());
+        return hashSet;
+    }
+
+    private static readonly Dictionary<string, string> s_storagePropertyNames = new()
+    {
+        ["StringData"] = "storage_string_data",
+        ["IntData"] = "IntData",
+        ["UIntData"] = "UIntData",
+        ["LongData"] = "LongData",
+        ["ULongData"] = "ULongData",
+        ["DoubleData"] = "DoubleData",
+        ["FloatData"] = "FloatData",
+        ["BoolData"] = "BoolData",
+        ["NullableIntData"] = "NullableIntData",
+        ["NullableUIntData"] = "NullableUIntData",
+        ["NullableLongData"] = "NullableLongData",
+        ["NullableULongData"] = "NullableULongData",
+        ["NullableDoubleData"] = "NullableDoubleData",
+        ["NullableFloatData"] = "NullableFloatData",
+        ["NullableBoolData"] = "NullableBoolData",
+        ["FloatVector"] = "FloatVector",
+        ["DoubleVector"] = "DoubleVector",
+    };
+
+    private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new()
+    {
+        Properties = new List<VectorStoreRecordProperty>()
+        {
+            new VectorStoreRecordKeyProperty("Key", typeof(string)),
+            new VectorStoreRecordDataProperty("StringData", typeof(string)),
+            new
VectorStoreRecordDataProperty("IntData", typeof(int)), + new VectorStoreRecordDataProperty("UIntData", typeof(uint)), + new VectorStoreRecordDataProperty("LongData", typeof(long)), + new VectorStoreRecordDataProperty("ULongData", typeof(ulong)), + new VectorStoreRecordDataProperty("DoubleData", typeof(double)), + new VectorStoreRecordDataProperty("FloatData", typeof(float)), + new VectorStoreRecordDataProperty("BoolData", typeof(bool)), + new VectorStoreRecordDataProperty("NullableIntData", typeof(int?)), + new VectorStoreRecordDataProperty("NullableUIntData", typeof(uint?)), + new VectorStoreRecordDataProperty("NullableLongData", typeof(long?)), + new VectorStoreRecordDataProperty("NullableULongData", typeof(ulong?)), + new VectorStoreRecordDataProperty("NullableDoubleData", typeof(double?)), + new VectorStoreRecordDataProperty("NullableFloatData", typeof(float?)), + new VectorStoreRecordDataProperty("NullableBoolData", typeof(bool?)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(float)), + new VectorStoreRecordVectorProperty("DoubleVector", typeof(double)), + } + }; + + private sealed class AllTypesModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string StringData { get; set; } = string.Empty; + + [VectorStoreRecordData] + public int IntData { get; set; } + + [VectorStoreRecordData] + public uint UIntData { get; set; } + + [VectorStoreRecordData] + public long LongData { get; set; } + + [VectorStoreRecordData] + public ulong ULongData { get; set; } + + [VectorStoreRecordData] + public double DoubleData { get; set; } + + [VectorStoreRecordData] + public float FloatData { get; set; } + + [VectorStoreRecordData] + public bool BoolData { get; set; } + + [VectorStoreRecordData] + public int? NullableIntData { get; set; } + + [VectorStoreRecordData] + public uint? NullableUIntData { get; set; } + + [VectorStoreRecordData] + public long? NullableLongData { get; set; } + + [VectorStoreRecordData] + public ulong? NullableULongData { get; set; } + + [VectorStoreRecordData] + public double? NullableDoubleData { get; set; } + + [VectorStoreRecordData] + public float? NullableFloatData { get; set; } + + [VectorStoreRecordData] + public bool? NullableBoolData { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory? FloatVector { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory? DoubleVector { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..58cda992db4d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs @@ -0,0 +1,568 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Moq; +using NRedisStack; +using StackExchange.Redis; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains tests for the class. 
+/// +public class RedisJsonVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + private const string TestRecordKey1 = "testid1"; + private const string TestRecordKey2 = "testid2"; + + private readonly Mock _redisDatabaseMock; + + public RedisJsonVectorStoreRecordCollectionTests() + { + this._redisDatabaseMock = new Mock(MockBehavior.Strict); + + var batchMock = new Mock(); + this._redisDatabaseMock.Setup(x => x.CreateBatch(It.IsAny())).Returns(batchMock.Object); + } + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange + if (expectedExists) + { + SetupExecuteMock(this._redisDatabaseMock, ["index_name", collectionName]); + } + else + { + SetupExecuteMock(this._redisDatabaseMock, new RedisServerException("Unknown index name")); + } + var sut = new RedisJsonVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + collectionName); + + // Act + var actual = await sut.CollectionExistsAsync(); + + // Assert + var expectedArgs = new object[] { collectionName }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "FT.INFO", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + Assert.Equal(expectedExists, actual); + } + + [Theory] + [InlineData(true, true, "data2", "vector2")] + [InlineData(true, false, "Data2", "Vector2")] + [InlineData(false, true, "data2", "vector2")] + [InlineData(false, false, "Data2", "Vector2")] + public async Task CanCreateCollectionAsync(bool useDefinition, bool useCustomJsonSerializerOptions, string expectedData2Name, string expectedVector2Name) + { + // Arrange. + SetupExecuteMock(this._redisDatabaseMock, string.Empty); + var sut = this.CreateRecordCollection(useDefinition, useCustomJsonSerializerOptions); + + // Act. + await sut.CreateCollectionAsync(); + + // Assert. 
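+        // Note (explanatory comment, assumption about intent): the expected arguments below mirror the RediSearch
+        // FT.CREATE syntax (index name, ON JSON, PREFIX, then SCHEMA entries aliasing JSON paths to TAG and HNSW
+        // VECTOR fields), which is why the call is verified as a single flat argument array.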
+ var expectedArgs = new object[] { + "testcollection", + "ON", + "JSON", + "PREFIX", + 1, + "testcollection:", + "SCHEMA", + "$.data1_json_name", + "AS", + "data1_json_name", + "TAG", + $"$.{expectedData2Name}", + "AS", + expectedData2Name, + "TAG", + "$.vector1_json_name", + "AS", + "vector1_json_name", + "VECTOR", + "HNSW", + 6, + "TYPE", + "FLOAT32", + "DIM", + "4", + "DISTANCE_METRIC", + "COSINE", + $"$.{expectedVector2Name}", + "AS", + expectedVector2Name, + "VECTOR", + "HNSW", + 6, + "TYPE", + "FLOAT32", + "DIM", + "4", + "DISTANCE_METRIC", + "COSINE" }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "FT.CREATE", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Fact] + public async Task CanDeleteCollectionAsync() + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, string.Empty); + var sut = this.CreateRecordCollection(false); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + var expectedArgs = new object[] { TestCollectionName }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "FT.DROPINDEX", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Theory] + [InlineData(true, true, """{ "data1_json_name": "data 1", "data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "vector2": [1, 2, 3, 4] }""")] + [InlineData(true, false, """{ "data1_json_name": "data 1", "Data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "Vector2": [1, 2, 3, 4] }""")] + [InlineData(false, true, """{ "data1_json_name": "data 1", "data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "vector2": [1, 2, 3, 4] }""")] + [InlineData(false, false, """{ "data1_json_name": "data 1", "Data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "Vector2": [1, 2, 3, 4] }""")] + public async Task CanGetRecordWithVectorsAsync(bool useDefinition, bool useCustomJsonSerializerOptions, string redisResultString) + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, redisResultString); + var sut = this.CreateRecordCollection(useDefinition, useCustomJsonSerializerOptions); + + // Act + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = true }); + + // Assert + var expectedArgs = new object[] { TestRecordKey1 }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.GET", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector2!.Value.ToArray()); + } + + [Theory] + [InlineData(true, true, """{ "data1_json_name": "data 1", "data2": "data 2" }""", "data2")] + [InlineData(true, false, """{ "data1_json_name": "data 1", "Data2": "data 2" }""", "Data2")] + [InlineData(false, true, """{ "data1_json_name": "data 1", "data2": "data 2" }""", "data2")] + [InlineData(false, false, """{ "data1_json_name": "data 1", "Data2": "data 2" }""", "Data2")] + public async Task CanGetRecordWithoutVectorsAsync(bool useDefinition, bool useCustomJsonSerializerOptions, string redisResultString, string expectedData2Name) + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, redisResultString); + var sut = this.CreateRecordCollection(useDefinition, useCustomJsonSerializerOptions); + + // Act + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = false }); + + // Assert + var expectedArgs = new object[] { TestRecordKey1, 
"data1_json_name", expectedData2Name }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.GET", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.False(actual.Vector1.HasValue); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition) + { + // Arrange + var redisResultString1 = """{ "data1_json_name": "data 1", "Data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "Vector2": [1, 2, 3, 4] }"""; + var redisResultString2 = """{ "data1_json_name": "data 1", "Data2": "data 2", "vector1_json_name": [5, 6, 7, 8], "Vector2": [1, 2, 3, 4] }"""; + SetupExecuteMock(this._redisDatabaseMock, [redisResultString1, redisResultString2]); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await sut.GetBatchAsync( + [TestRecordKey1, TestRecordKey2], + new() { IncludeVectors = true }).ToListAsync(); + + // Assert + var expectedArgs = new object[] { TestRecordKey1, TestRecordKey2, "$" }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.MGET", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(TestRecordKey1, actual[0].Key); + Assert.Equal("data 1", actual[0].Data1); + Assert.Equal("data 2", actual[0].Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual[0].Vector1!.Value.ToArray()); + Assert.Equal(TestRecordKey2, actual[1].Key); + Assert.Equal("data 1", actual[1].Data1); + Assert.Equal("data 2", actual[1].Data2); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual[1].Vector1!.Value.ToArray()); + } + + [Fact] + public async Task CanGetRecordWithCustomMapperAsync() + { + // Arrange. + var redisResultString = """{ "data1_json_name": "data 1", "Data2": "data 2", "vector1_json_name": [1, 2, 3, 4], "Vector2": [1, 2, 3, 4] }"""; + SetupExecuteMock(this._redisDatabaseMock, redisResultString); + + // Arrange mapper mock from JsonNode to data model. + var mapperMock = new Mock>(MockBehavior.Strict); + mapperMock.Setup( + x => x.MapFromStorageToDataModel( + It.IsAny<(string key, JsonNode node)>(), + It.IsAny())) + .Returns(CreateModel(TestRecordKey1, true)); + + // Arrange target with custom mapper. 
+ var sut = new RedisJsonVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() + { + JsonNodeCustomMapper = mapperMock.Object + }); + + // Act + var actual = await sut.GetAsync( + TestRecordKey1, + new() { IncludeVectors = true }); + + // Assert + Assert.NotNull(actual); + Assert.Equal(TestRecordKey1, actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector2!.Value.ToArray()); + + mapperMock + .Verify( + x => x.MapFromStorageToDataModel( + It.Is<(string key, JsonNode node)>(x => x.key == TestRecordKey1), + It.Is(x => x.IncludeVectors)), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanDeleteRecordAsync(bool useDefinition) + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, "200"); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + await sut.DeleteAsync(TestRecordKey1); + + // Assert + var expectedArgs = new object[] { TestRecordKey1 }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.DEL", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition) + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, "200"); + var sut = this.CreateRecordCollection(useDefinition); + + // Act + await sut.DeleteBatchAsync([TestRecordKey1, TestRecordKey2]); + + // Assert + var expectedArgs1 = new object[] { TestRecordKey1 }; + var expectedArgs2 = new object[] { TestRecordKey2 }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.DEL", + It.Is(x => x.SequenceEqual(expectedArgs1))), + Times.Once); + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.DEL", + It.Is(x => x.SequenceEqual(expectedArgs2))), + Times.Once); + } + + [Theory] + [InlineData(true, true, """{"data1_json_name":"data 1","data2":"data 2","vector1_json_name":[1,2,3,4],"vector2":[1,2,3,4],"notAnnotated":null}""")] + [InlineData(true, false, """{"data1_json_name":"data 1","Data2":"data 2","vector1_json_name":[1,2,3,4],"Vector2":[1,2,3,4],"NotAnnotated":null}""")] + [InlineData(false, true, """{"data1_json_name":"data 1","data2":"data 2","vector1_json_name":[1,2,3,4],"vector2":[1,2,3,4],"notAnnotated":null}""")] + [InlineData(false, false, """{"data1_json_name":"data 1","Data2":"data 2","vector1_json_name":[1,2,3,4],"Vector2":[1,2,3,4],"NotAnnotated":null}""")] + public async Task CanUpsertRecordAsync(bool useDefinition, bool useCustomJsonSerializerOptions, string expectedUpsertedJson) + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, "OK"); + var sut = this.CreateRecordCollection(useDefinition, useCustomJsonSerializerOptions); + var model = CreateModel(TestRecordKey1, true); + + // Act + await sut.UpsertAsync(model); + + // Assert + // TODO: Fix issue where NotAnnotated is being included in the JSON. 
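+        // Probable cause (assumption, not verified here): the un-annotated public property on the data model is
+        // picked up by System.Text.Json's default serialization, so the expected payloads in the InlineData include it.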
+ var expectedArgs = new object[] { TestRecordKey1, "$", expectedUpsertedJson }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.SET", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CanUpsertManyRecordsAsync(bool useDefinition) + { + // Arrange + SetupExecuteMock(this._redisDatabaseMock, "OK"); + var sut = this.CreateRecordCollection(useDefinition); + + var model1 = CreateModel(TestRecordKey1, true); + var model2 = CreateModel(TestRecordKey2, true); + + // Act + var actual = await sut.UpsertBatchAsync([model1, model2]).ToListAsync(); + + // Assert + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(TestRecordKey1, actual[0]); + Assert.Equal(TestRecordKey2, actual[1]); + + // TODO: Fix issue where NotAnnotated is being included in the JSON. + var expectedArgs = new object[] { TestRecordKey1, "$", """{"data1_json_name":"data 1","Data2":"data 2","vector1_json_name":[1,2,3,4],"Vector2":[1,2,3,4],"NotAnnotated":null}""", TestRecordKey2, "$", """{"data1_json_name":"data 1","Data2":"data 2","vector1_json_name":[1,2,3,4],"Vector2":[1,2,3,4],"NotAnnotated":null}""" }; + this._redisDatabaseMock + .Verify( + x => x.ExecuteAsync( + "JSON.MSET", + It.Is(x => x.SequenceEqual(expectedArgs))), + Times.Once); + } + + [Fact] + public async Task CanUpsertRecordWithCustomMapperAsync() + { + // Arrange. + SetupExecuteMock(this._redisDatabaseMock, "OK"); + + // Arrange mapper mock from data model to JsonNode. + var mapperMock = new Mock>(MockBehavior.Strict); + var jsonNode = """{"data1_json_name":"data 1","Data2": "data 2","vector1_json_name":[1,2,3,4],"Vector2":[1,2,3,4],"NotAnnotated":null}"""; + mapperMock + .Setup(x => x.MapFromDataToStorageModel(It.IsAny())) + .Returns((TestRecordKey1, JsonNode.Parse(jsonNode)!)); + + // Arrange target with custom mapper. + var sut = new RedisJsonVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() + { + JsonNodeCustomMapper = mapperMock.Object + }); + + var model = CreateModel(TestRecordKey1, true); + + // Act + await sut.UpsertAsync(model); + + // Assert + mapperMock + .Verify( + x => x.MapFromDataToStorageModel(It.Is(x => x == model)), + Times.Once); + } + + /// + /// Tests that the collection can be created even if the definition and the type do not match. + /// In this case, the expectation is that a custom mapper will be provided to map between the + /// schema as defined by the definition and the different data model. + /// + [Fact] + public void CanCreateCollectionWithMismatchedDefinitionAndType() + { + // Arrange. + var definition = new VectorStoreRecordDefinition() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + } + }; + + // Act. + var sut = new RedisJsonVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() { VectorStoreRecordDefinition = definition, JsonNodeCustomMapper = Mock.Of>() }); + } + + private RedisJsonVectorStoreRecordCollection CreateRecordCollection(bool useDefinition, bool useCustomJsonSerializerOptions = false) + { + return new RedisJsonVectorStoreRecordCollection( + this._redisDatabaseMock.Object, + TestCollectionName, + new() + { + PrefixCollectionNameToKeyNames = false, + VectorStoreRecordDefinition = useDefinition ? 
this._multiPropsDefinition : null, + JsonSerializerOptions = useCustomJsonSerializerOptions ? this._customJsonSerializerOptions : null + }); + } + + private static void SetupExecuteMock(Mock redisDatabaseMock, Exception exception) + { + redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(exception); + } + + private static void SetupExecuteMock(Mock redisDatabaseMock, IEnumerable redisResultStrings) + { + var results = redisResultStrings + .Select(x => RedisResult.Create(new RedisValue(x))) + .ToArray(); + redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .ReturnsAsync(RedisResult.Create(results)); + } + + private static void SetupExecuteMock(Mock redisDatabaseMock, string redisResultString) + { + redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .Callback((string command, object[] args) => + { + Console.WriteLine(args); + }) + .ReturnsAsync(RedisResult.Create(new RedisValue(redisResultString))); + } + + private static MultiPropsModel CreateModel(string key, bool withVectors) + { + return new MultiPropsModel + { + Key = key, + Data1 = "data 1", + Data2 = "data 2", + Vector1 = withVectors ? new float[] { 1, 2, 3, 4 } : null, + Vector2 = withVectors ? new float[] { 1, 2, 3, 4 } : null, + NotAnnotated = null, + }; + } + + private readonly JsonSerializerOptions _customJsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + + private readonly VectorStoreRecordDefinition _multiPropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFilterable = true, StoragePropertyName = "ignored_data1_storage_name" }, + new VectorStoreRecordDataProperty("Data2", typeof(string)) { IsFilterable = true }, + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4, StoragePropertyName = "ignored_vector1_storage_name" }, + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) { Dimensions = 4 } + ] + }; + + public sealed class MultiPropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [JsonPropertyName("data1_json_name")] + [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "ignored_data1_storage_name")] + public string Data1 { get; set; } = string.Empty; + + [VectorStoreRecordData(IsFilterable = true)] + public string Data2 { get; set; } = string.Empty; + + [JsonPropertyName("vector1_json_name")] + [VectorStoreRecordVector(4, StoragePropertyName = "ignored_vector1_storage_name")] + public ReadOnlyMemory? Vector1 { get; set; } + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector2 { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs new file mode 100644 index 000000000000..a7ae97c06355 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordMapperTests.cs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class RedisJsonVectorStoreRecordMapperTests +{ + [Fact] + public void MapsAllFieldsFromDataToStorageModel() + { + // Arrange. + var sut = new RedisJsonVectorStoreRecordMapper("Key", JsonSerializerOptions.Default); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateModel("test key")); + + // Assert. + Assert.NotNull(actual.Node); + Assert.Equal("test key", actual.Key); + var jsonObject = actual.Node.AsObject(); + Assert.Equal("data 1", jsonObject?["Data1"]?.ToString()); + Assert.Equal("data 2", jsonObject?["Data2"]?.ToString()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, jsonObject?["Vector1"]?.AsArray().GetValues().ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, jsonObject?["Vector2"]?.AsArray().GetValues().ToArray()); + } + + [Fact] + public void MapsAllFieldsFromDataToStorageModelWithCustomSerializerOptions() + { + // Arrange. + var sut = new RedisJsonVectorStoreRecordMapper("key", new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + + // Act. + var actual = sut.MapFromDataToStorageModel(CreateModel("test key")); + + // Assert. + Assert.NotNull(actual.Node); + Assert.Equal("test key", actual.Key); + var jsonObject = actual.Node.AsObject(); + Assert.Equal("data 1", jsonObject?["data1"]?.ToString()); + Assert.Equal("data 2", jsonObject?["data2"]?.ToString()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, jsonObject?["vector1"]?.AsArray().GetValues().ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, jsonObject?["vector2"]?.AsArray().GetValues().ToArray()); + } + + [Fact] + public void MapsAllFieldsFromStorageToDataModel() + { + // Arrange. + var sut = new RedisJsonVectorStoreRecordMapper("Key", JsonSerializerOptions.Default); + + // Act. + var jsonObject = new JsonObject(); + jsonObject.Add("Data1", "data 1"); + jsonObject.Add("Data2", "data 2"); + jsonObject.Add("Vector1", new JsonArray(new[] { 1, 2, 3, 4 }.Select(x => JsonValue.Create(x)).ToArray())); + jsonObject.Add("Vector2", new JsonArray(new[] { 5, 6, 7, 8 }.Select(x => JsonValue.Create(x)).ToArray())); + var actual = sut.MapFromStorageToDataModel(("test key", jsonObject), new()); + + // Assert. + Assert.NotNull(actual); + Assert.Equal("test key", actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual.Vector2!.Value.ToArray()); + } + + [Fact] + public void MapsAllFieldsFromStorageToDataModelWithCustomSerializerOptions() + { + // Arrange. + var sut = new RedisJsonVectorStoreRecordMapper("key", new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + + // Act. + var jsonObject = new JsonObject(); + jsonObject.Add("data1", "data 1"); + jsonObject.Add("data2", "data 2"); + jsonObject.Add("vector1", new JsonArray(new[] { 1, 2, 3, 4 }.Select(x => JsonValue.Create(x)).ToArray())); + jsonObject.Add("vector2", new JsonArray(new[] { 5, 6, 7, 8 }.Select(x => JsonValue.Create(x)).ToArray())); + var actual = sut.MapFromStorageToDataModel(("test key", jsonObject), new()); + + // Assert. 
+ Assert.NotNull(actual); + Assert.Equal("test key", actual.Key); + Assert.Equal("data 1", actual.Data1); + Assert.Equal("data 2", actual.Data2); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector1!.Value.ToArray()); + Assert.Equal(new float[] { 5, 6, 7, 8 }, actual.Vector2!.Value.ToArray()); + } + + private static MultiPropsModel CreateModel(string key) + { + return new MultiPropsModel + { + Key = key, + Data1 = "data 1", + Data2 = "data 2", + Vector1 = new float[] { 1, 2, 3, 4 }, + Vector2 = new float[] { 5, 6, 7, 8 }, + NotAnnotated = "notAnnotated", + }; + } + + private sealed class MultiPropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string Data1 { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string Data2 { get; set; } = string.Empty; + + [VectorStoreRecordVector] + public ReadOnlyMemory? Vector1 { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory? Vector2 { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..dcb8383b1525 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisKernelBuilderExtensionsTests.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using Moq; +using StackExchange.Redis; +using Xunit; + +namespace SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Tests for the class. +/// +public class RedisKernelBuilderExtensionsTests +{ + private readonly IKernelBuilder _kernelBuilder; + + public RedisKernelBuilderExtensionsTests() + { + this._kernelBuilder = Kernel.CreateBuilder(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange. + this._kernelBuilder.Services.AddSingleton(Mock.Of()); + + // Act. + this._kernelBuilder.AddRedisVectorStore(); + + // Assert. + this.AssertVectorStoreCreated(); + } + + private void AssertVectorStoreCreated() + { + var kernel = this._kernelBuilder.Build(); + var vectorStore = kernel.Services.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..fe08b6d568b6 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisServiceCollectionExtensionsTests.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using Moq; +using StackExchange.Redis; +using Xunit; + +namespace SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Tests for the class. +/// +public class RedisServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection; + + public RedisServiceCollectionExtensionsTests() + { + this._serviceCollection = new ServiceCollection(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange. + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act. 
+        this._serviceCollection.AddRedisVectorStore();
+
+        // Assert.
+        this.AssertVectorStoreCreated();
+    }
+
+    private void AssertVectorStoreCreated()
+    {
+        var serviceProvider = this._serviceCollection.BuildServiceProvider();
+        var vectorStore = serviceProvider.GetRequiredService<IVectorStore>();
+        Assert.NotNull(vectorStore);
+        Assert.IsType<RedisVectorStore>(vectorStore);
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs
new file mode 100644
index 000000000000..c5bb3b12b2c5
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionCreateMappingTests.cs
@@ -0,0 +1,127 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.SemanticKernel.Data;
+using NRedisStack.Search;
+using Xunit;
+using static NRedisStack.Search.Schema;
+
+namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests;
+
+/// <summary>
+/// Contains tests for the <see cref="RedisVectorStoreCollectionCreateMapping"/> class.
+/// </summary>
+public class RedisVectorStoreCollectionCreateMappingTests
+{
+    [Fact]
+    public void MapToSchemaCreatesSchema()
+    {
+        // Arrange.
+        var properties = new VectorStoreRecordProperty[]
+        {
+            new VectorStoreRecordKeyProperty("Key", typeof(string)),
+
+            new VectorStoreRecordDataProperty("FilterableString", typeof(string)) { IsFilterable = true },
+            new VectorStoreRecordDataProperty("FullTextSearchableString", typeof(string)) { IsFullTextSearchable = true },
+            new VectorStoreRecordDataProperty("FilterableStringEnumerable", typeof(string[])) { IsFilterable = true },
+            new VectorStoreRecordDataProperty("FullTextSearchableStringEnumerable", typeof(string[])) { IsFullTextSearchable = true },
+
+            new VectorStoreRecordDataProperty("FilterableInt", typeof(int)) { IsFilterable = true },
+            new VectorStoreRecordDataProperty("FilterableNullableInt", typeof(int)) { IsFilterable = true },
+
+            new VectorStoreRecordDataProperty("NonFilterableString", typeof(string)),
+
+            new VectorStoreRecordVectorProperty("VectorDefaultIndexingOptions", typeof(ReadOnlyMemory<float>)) { Dimensions = 10 },
+            new VectorStoreRecordVectorProperty("VectorSpecificIndexingOptions", typeof(ReadOnlyMemory<float>)) { Dimensions = 20, IndexKind = IndexKind.Flat, DistanceFunction = DistanceFunction.EuclideanDistance },
+        };
+
+        var storagePropertyNames = new Dictionary<string, string>()
+        {
+            { "FilterableString", "FilterableString" },
+            { "FullTextSearchableString", "FullTextSearchableString" },
+            { "FilterableStringEnumerable", "FilterableStringEnumerable" },
+            { "FullTextSearchableStringEnumerable", "FullTextSearchableStringEnumerable" },
+            { "FilterableInt", "FilterableInt" },
+            { "FilterableNullableInt", "FilterableNullableInt" },
+            { "NonFilterableString", "NonFilterableString" },
+            { "VectorDefaultIndexingOptions", "VectorDefaultIndexingOptions" },
+            { "VectorSpecificIndexingOptions", "vector_specific_indexing_options" },
+        };
+
+        // Act.
+        var schema = RedisVectorStoreCollectionCreateMapping.MapToSchema(properties, storagePropertyNames);
+
+        // Assert.
+ Assert.NotNull(schema); + Assert.Equal(8, schema.Fields.Count); + + Assert.IsType(schema.Fields[0]); + Assert.IsType(schema.Fields[1]); + Assert.IsType(schema.Fields[2]); + Assert.IsType(schema.Fields[3]); + Assert.IsType(schema.Fields[4]); + Assert.IsType(schema.Fields[5]); + Assert.IsType(schema.Fields[6]); + Assert.IsType(schema.Fields[7]); + + VerifyFieldName(schema.Fields[0].FieldName, new List { "$.FilterableString", "AS", "FilterableString" }); + VerifyFieldName(schema.Fields[1].FieldName, new List { "$.FullTextSearchableString", "AS", "FullTextSearchableString" }); + VerifyFieldName(schema.Fields[2].FieldName, new List { "$.FilterableStringEnumerable.*", "AS", "FilterableStringEnumerable" }); + VerifyFieldName(schema.Fields[3].FieldName, new List { "$.FullTextSearchableStringEnumerable", "AS", "FullTextSearchableStringEnumerable" }); + + VerifyFieldName(schema.Fields[4].FieldName, new List { "$.FilterableInt", "AS", "FilterableInt" }); + VerifyFieldName(schema.Fields[5].FieldName, new List { "$.FilterableNullableInt", "AS", "FilterableNullableInt" }); + + VerifyFieldName(schema.Fields[6].FieldName, new List { "$.VectorDefaultIndexingOptions", "AS", "VectorDefaultIndexingOptions" }); + VerifyFieldName(schema.Fields[7].FieldName, new List { "$.vector_specific_indexing_options", "AS", "vector_specific_indexing_options" }); + + Assert.Equal("10", ((VectorField)schema.Fields[6]).Attributes!["DIM"]); + Assert.Equal("FLOAT32", ((VectorField)schema.Fields[6]).Attributes!["TYPE"]); + Assert.Equal("COSINE", ((VectorField)schema.Fields[6]).Attributes!["DISTANCE_METRIC"]); + + Assert.Equal("20", ((VectorField)schema.Fields[7]).Attributes!["DIM"]); + Assert.Equal("FLOAT32", ((VectorField)schema.Fields[7]).Attributes!["TYPE"]); + Assert.Equal("L2", ((VectorField)schema.Fields[7]).Attributes!["DISTANCE_METRIC"]); + } + + [Theory] + [InlineData(null)] + [InlineData(0)] + public void MapToSchemaThrowsOnInvalidVectorDimensions(int? dimensions) + { + // Arrange. + var properties = new VectorStoreRecordProperty[] { new VectorStoreRecordVectorProperty("VectorProperty", typeof(ReadOnlyMemory)) { Dimensions = dimensions } }; + var storagePropertyNames = new Dictionary() { { "VectorProperty", "VectorProperty" } }; + + // Act and assert. + Assert.Throws(() => RedisVectorStoreCollectionCreateMapping.MapToSchema(properties, storagePropertyNames)); + } + + [Fact] + public void GetSDKIndexKindThrowsOnUnsupportedIndexKind() + { + // Arrange. + var vectorProperty = new VectorStoreRecordVectorProperty("VectorProperty", typeof(ReadOnlyMemory)) { IndexKind = "Unsupported" }; + + // Act and assert. + Assert.Throws(() => RedisVectorStoreCollectionCreateMapping.GetSDKIndexKind(vectorProperty)); + } + + [Fact] + public void GetSDKDistanceAlgorithmThrowsOnUnsupportedDistanceFunction() + { + // Arrange. + var vectorProperty = new VectorStoreRecordVectorProperty("VectorProperty", typeof(ReadOnlyMemory)) { DistanceFunction = "Unsupported" }; + + // Act and assert. 
+ Assert.Throws(() => RedisVectorStoreCollectionCreateMapping.GetSDKDistanceAlgorithm(vectorProperty)); + } + + private static void VerifyFieldName(FieldName fieldName, List expected) + { + var args = new List(); + fieldName.AddCommandArguments(args); + Assert.Equal(expected, args); + } +} diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs new file mode 100644 index 000000000000..28f8f6cc5bcb --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreTests.cs @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Moq; +using StackExchange.Redis; +using Xunit; + +namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; + +/// +/// Contains tests for the class. +/// +public class RedisVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _redisDatabaseMock; + + public RedisVectorStoreTests() + { + this._redisDatabaseMock = new Mock(MockBehavior.Strict); + + var batchMock = new Mock(); + this._redisDatabaseMock.Setup(x => x.CreateBatch(It.IsAny())).Returns(batchMock.Object); + } + + [Fact] + public void GetCollectionReturnsJsonCollection() + { + // Arrange. + var sut = new RedisVectorStore(this._redisDatabaseMock.Object); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionReturnsHashSetCollection() + { + // Arrange. + var sut = new RedisVectorStore(this._redisDatabaseMock.Object, new() { StorageType = RedisStorageType.HashSet }); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionCallsFactoryIfProvided() + { + // Arrange. + var factoryMock = new Mock(MockBehavior.Strict); + var collectionMock = new Mock>>(MockBehavior.Strict); + factoryMock + .Setup(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null)) + .Returns(collectionMock.Object); + var sut = new RedisVectorStore(this._redisDatabaseMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.Equal(collectionMock.Object, actual); + factoryMock.Verify(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null), Times.Once); + } + + [Fact] + public void GetCollectionThrowsForInvalidKeyType() + { + // Arrange. + var sut = new RedisVectorStore(this._redisDatabaseMock.Object); + + // Act & Assert. + Assert.Throws(() => sut.GetCollection>(TestCollectionName)); + } + + [Fact] + public async Task ListCollectionNamesCallsSDKAsync() + { + // Arrange. + var redisResultStrings = new string[] { "collection1", "collection2" }; + var results = redisResultStrings + .Select(x => RedisResult.Create(new RedisValue(x))) + .ToArray(); + this._redisDatabaseMock + .Setup( + x => x.ExecuteAsync( + It.IsAny(), + It.IsAny())) + .ReturnsAsync(RedisResult.Create(results)); + var sut = new RedisVectorStore(this._redisDatabaseMock.Object); + + // Act. + var collectionNames = sut.ListCollectionNamesAsync(); + + // Assert. 
+        var collectionNamesList = await collectionNames.ToListAsync();
+        Assert.Equal(new[] { "collection1", "collection2" }, collectionNamesList);
+    }
+
+    public sealed class SinglePropsModel<TKey>
+    {
+        [VectorStoreRecordKey]
+        public required TKey Key { get; set; }
+
+        [VectorStoreRecordData]
+        public string Data { get; set; } = string.Empty;
+
+        [VectorStoreRecordVector(4)]
+        public ReadOnlyMemory<float>? Vector { get; set; }
+
+        public string? NotAnnotated { get; set; }
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeKernelBuilderExtensionsTests.cs
new file mode 100644
index 000000000000..67cd1588e0dd
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeKernelBuilderExtensionsTests.cs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.Pinecone;
+using Microsoft.SemanticKernel.Data;
+using Xunit;
+using Sdk = Pinecone;
+
+namespace SemanticKernel.Connectors.UnitTests.Pinecone;
+
+/// <summary>
+/// Tests for the <see cref="PineconeKernelBuilderExtensions"/> class.
+/// </summary>
+public class PineconeKernelBuilderExtensionsTests
+{
+    private readonly IKernelBuilder _kernelBuilder;
+
+    public PineconeKernelBuilderExtensionsTests()
+    {
+        this._kernelBuilder = Kernel.CreateBuilder();
+    }
+
+    [Fact]
+    public void AddVectorStoreRegistersClass()
+    {
+        // Arrange.
+        using var client = new Sdk.PineconeClient("fake api key");
+        this._kernelBuilder.Services.AddSingleton(client);
+
+        // Act.
+        this._kernelBuilder.AddPineconeVectorStore();
+
+        // Assert.
+        this.AssertVectorStoreCreated();
+    }
+
+    [Fact]
+    public void AddVectorStoreWithApiKeyRegistersClass()
+    {
+        // Act.
+        this._kernelBuilder.AddPineconeVectorStore("fake api key");
+
+        // Assert.
+        this.AssertVectorStoreCreated();
+    }
+
+    private void AssertVectorStoreCreated()
+    {
+        var kernel = this._kernelBuilder.Build();
+        var vectorStore = kernel.Services.GetRequiredService<IVectorStore>();
+        Assert.NotNull(vectorStore);
+        Assert.IsType<PineconeVectorStore>(vectorStore);
+    }
+}
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeServiceCollectionExtensionsTests.cs
new file mode 100644
index 000000000000..3894e3b65dc5
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeServiceCollectionExtensionsTests.cs
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel.Connectors.Pinecone;
+using Microsoft.SemanticKernel.Data;
+using Xunit;
+using Sdk = Pinecone;
+
+namespace SemanticKernel.Connectors.UnitTests.Pinecone;
+
+/// <summary>
+/// Tests for the <see cref="PineconeServiceCollectionExtensions"/> class.
+/// </summary>
+public class PineconeServiceCollectionExtensionsTests
+{
+    private readonly IServiceCollection _serviceCollection;
+
+    public PineconeServiceCollectionExtensionsTests()
+    {
+        this._serviceCollection = new ServiceCollection();
+    }
+
+    [Fact]
+    public void AddVectorStoreRegistersClass()
+    {
+        // Arrange.
+        using var client = new Sdk.PineconeClient("fake api key");
+        this._serviceCollection.AddSingleton(client);
+
+        // Act.
+        this._serviceCollection.AddPineconeVectorStore();
+
+        // Assert.
+        this.AssertVectorStoreCreated();
+    }
+
+    [Fact]
+    public void AddVectorStoreWithApiKeyRegistersClass()
+    {
+        // Act.
+ this._serviceCollection.AddPineconeVectorStore("fake api key"); + + // Assert. + this.AssertVectorStoreCreated(); + } + + private void AssertVectorStoreCreated() + { + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..d8e10c71491d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.SemanticKernel.Connectors.Pinecone; +using Microsoft.SemanticKernel.Data; +using Moq; +using Xunit; +using Sdk = Pinecone; + +namespace SemanticKernel.Connectors.UnitTests.Pinecone; + +/// +/// Contains tests for the class. +/// +public class PineconeVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + + /// + /// Tests that the collection can be created even if the definition and the type do not match. + /// In this case, the expectation is that a custom mapper will be provided to map between the + /// schema as defined by the definition and the different data model. + /// + [Fact] + public void CanCreateCollectionWithMismatchedDefinitionAndType() + { + // Arrange. + var definition = new VectorStoreRecordDefinition() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("Text", typeof(string)), + new VectorStoreRecordVectorProperty("Embedding", typeof(ReadOnlyMemory)) { Dimensions = 4 }, + } + }; + using var pineconeClient = new Sdk.PineconeClient("fake api key"); + + // Act. + var sut = new PineconeVectorStoreRecordCollection( + pineconeClient, + TestCollectionName, + new() { VectorStoreRecordDefinition = definition, VectorCustomMapper = Mock.Of>() }); + } + + public sealed class SinglePropsModel + { + public string Key { get; set; } = string.Empty; + + public string OriginalNameData { get; set; } = string.Empty; + + public string Data { get; set; } = string.Empty; + + public ReadOnlyMemory? Vector { get; set; } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreCollectionFixture.cs new file mode 100644 index 000000000000..6c9870cf0327 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreCollectionFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. 
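// A minimal usage sketch for the Pinecone registrations exercised by the tests above.
// Assumptions, not taken from this change: the store resolves as IVectorStore, GetCollection
// takes <TKey, TRecord> generic arguments, the collection exposes UpsertAsync/GetAsync as used
// by the integration tests later in this diff, and the relevant usings are in scope.
internal static class PineconeRegistrationUsageSketch
{
    // Hypothetical record type, annotated with the attributes used throughout these tests.
    internal sealed class SketchRecord
    {
        [VectorStoreRecordKey]
        public string Key { get; set; } = string.Empty;

        [VectorStoreRecordData]
        public string Data { get; set; } = string.Empty;

        [VectorStoreRecordVector(8)]
        public ReadOnlyMemory<float>? Vector { get; set; }
    }

    internal static async Task<SketchRecord?> RegisterAndRoundTripAsync()
    {
        var services = new ServiceCollection();
        services.AddPineconeVectorStore("fake api key");

        var vectorStore = services.BuildServiceProvider().GetRequiredService<IVectorStore>();

        // "sketchcollection" is a hypothetical index name; the index would need to exist or be created first.
        var collection = vectorStore.GetCollection<string, SketchRecord>("sketchcollection");

        await collection.UpsertAsync(new SketchRecord { Key = "sketch-1", Data = "hello", Vector = new ReadOnlyMemory<float>(new float[8]) });
        return await collection.GetAsync("sketch-1");
    }
}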
+ +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; + +[CollectionDefinition("AzureAISearchVectorStoreCollection")] +public class AzureAISearchVectorStoreCollectionFixture : ICollectionFixture +{ +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs new file mode 100644 index 000000000000..19158ce56e4f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreFixture.cs @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json.Serialization; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Azure; +using Azure.Search.Documents; +using Azure.Search.Documents.Indexes; +using Azure.Search.Documents.Indexes.Models; +using Azure.Search.Documents.Models; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Data; +using SemanticKernel.IntegrationTests.TestSettings.Memory; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; + +/// +/// Helper class for setting up and tearing down Azure AI Search indexes for testing purposes. +/// +public class AzureAISearchVectorStoreFixture : IAsyncLifetime +{ + /// + /// Test index name which consists out of "hotels-" and the machine name with any non-alphanumeric characters removed. + /// +#pragma warning disable CA1308 // Normalize strings to uppercase + private readonly string _testIndexName = "hotels-" + new Regex("[^a-zA-Z0-9]").Replace(Environment.MachineName.ToLowerInvariant(), ""); +#pragma warning restore CA1308 // Normalize strings to uppercase + + /// + /// Test Configuration setup. + /// + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + /// + /// Initializes a new instance of the class. + /// + public AzureAISearchVectorStoreFixture() + { + var config = this._configuration.GetRequiredSection("AzureAISearch").Get(); + Assert.NotNull(config); + this.Config = config; + this.SearchIndexClient = new SearchIndexClient(new Uri(config.ServiceUrl), new AzureKeyCredential(config.ApiKey)); + this.VectorStoreRecordDefinition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("HotelId", typeof(string)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4 }, + new VectorStoreRecordDataProperty("Tags", typeof(string[])) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool?)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset?)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Rating", typeof(float?)) + } + }; + } + + /// + /// Gets the Search Index Client to use for connecting to the Azure AI Search service. 
+ /// + public SearchIndexClient SearchIndexClient { get; private set; } + + /// + /// Gets the name of the index that this fixture sets up and tears down. + /// + public string TestIndexName { get => this._testIndexName; } + + /// + /// Gets the manually created vector store record definition for our test model. + /// + public VectorStoreRecordDefinition VectorStoreRecordDefinition { get; private set; } + + /// + /// Gets the configuration for the Azure AI Search service. + /// + public AzureAISearchConfiguration Config { get; private set; } + + /// + /// Create / Recreate index and upload documents before test run. + /// + /// An async task. + public async Task InitializeAsync() + { + await AzureAISearchVectorStoreFixture.DeleteIndexIfExistsAsync(this._testIndexName, this.SearchIndexClient); + await AzureAISearchVectorStoreFixture.CreateIndexAsync(this._testIndexName, this.SearchIndexClient); + AzureAISearchVectorStoreFixture.UploadDocuments(this.SearchIndexClient.GetSearchClient(this._testIndexName)); + } + + /// + /// Delete the index after the test run. + /// + /// An async task. + public async Task DisposeAsync() + { + await AzureAISearchVectorStoreFixture.DeleteIndexIfExistsAsync(this._testIndexName, this.SearchIndexClient); + } + + /// + /// Delete the index if it exists. + /// + /// The name of the index to delete. + /// The search index client to use for deleting the index. + /// An async task. + public static async Task DeleteIndexIfExistsAsync(string indexName, SearchIndexClient adminClient) + { + adminClient.GetIndexNames(); + { + await adminClient.DeleteIndexAsync(indexName); + } + } + + /// + /// Create an index with the given name. + /// + /// The name of the index to create. + /// The search index client to use for creating the index. + /// An async task. + public static async Task CreateIndexAsync(string indexName, SearchIndexClient adminClient) + { + FieldBuilder fieldBuilder = new(); + var searchFields = fieldBuilder.Build(typeof(Hotel)); + var embeddingfield = searchFields.First(x => x.Name == "DescriptionEmbedding"); + searchFields.Remove(embeddingfield); + searchFields.Add(new VectorSearchField("DescriptionEmbedding", 4, "my-vector-profile")); + + var definition = new SearchIndex(indexName, searchFields); + definition.VectorSearch = new VectorSearch(); + definition.VectorSearch.Algorithms.Add(new HnswAlgorithmConfiguration("my-hnsw-vector-config-1") { Parameters = new HnswParameters { Metric = VectorSearchAlgorithmMetric.Cosine } }); + definition.VectorSearch.Profiles.Add(new VectorSearchProfile("my-vector-profile", "my-hnsw-vector-config-1")); + + var suggester = new SearchSuggester("sg", new[] { "HotelName" }); + definition.Suggesters.Add(suggester); + + await adminClient.CreateOrUpdateIndexAsync(definition); + } + + /// + /// Upload test documents to the index. + /// + /// The client to use for uploading the documents. 
+ public static void UploadDocuments(SearchClient searchClient) + { + IndexDocumentsBatch batch = IndexDocumentsBatch.Create( + IndexDocumentsAction.Upload( + new Hotel() + { + HotelId = "BaseSet-1", + HotelName = "Hotel 1", + Description = "This is a great hotel", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + Tags = new[] { "pool", "air conditioning", "concierge" }, + ParkingIncluded = false, + LastRenovationDate = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + Rating = 3.6 + }), + IndexDocumentsAction.Upload( + new Hotel() + { + HotelId = "BaseSet-2", + HotelName = "Hotel 2", + Description = "This is a great hotel", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + Tags = new[] { "pool", "free wifi", "concierge" }, + ParkingIncluded = false, + LastRenovationDate = new DateTimeOffset(1979, 2, 18, 0, 0, 0, TimeSpan.Zero), + Rating = 3.60 + }), + IndexDocumentsAction.Upload( + new Hotel() + { + HotelId = "BaseSet-3", + HotelName = "Hotel 3", + Description = "This is a great hotel", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + Tags = new[] { "air conditioning", "bar", "continental breakfast" }, + ParkingIncluded = true, + LastRenovationDate = new DateTimeOffset(2015, 9, 20, 0, 0, 0, TimeSpan.Zero), + Rating = 4.80 + }), + IndexDocumentsAction.Upload( + new Hotel() + { + HotelId = "BaseSet-4", + HotelName = "Hotel 4", + Description = "This is a great hotel", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + Tags = new[] { "concierge", "view", "24-hour front desk service" }, + ParkingIncluded = true, + LastRenovationDate = new DateTimeOffset(1960, 2, 06, 0, 0, 0, TimeSpan.Zero), + Rating = 4.60 + }) + ); + + searchClient.IndexDocuments(batch); + } + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + public class Hotel + { + [SimpleField(IsKey = true, IsFilterable = true)] + [VectorStoreRecordKey] + public string HotelId { get; set; } + + [SearchableField(IsSortable = true)] + [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + public string HotelName { get; set; } + + [SearchableField(AnalyzerName = LexicalAnalyzerName.Values.EnLucene)] + [VectorStoreRecordData] + public string Description { get; set; } + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + + [SearchableField(IsFilterable = true, IsFacetable = true)] + [VectorStoreRecordData(IsFilterable = true)] +#pragma warning disable CA1819 // Properties should not return arrays + public string[] Tags { get; set; } +#pragma warning restore CA1819 // Properties should not return arrays + + [JsonPropertyName("parking_is_included")] + [SimpleField(IsFilterable = true, IsSortable = true, IsFacetable = true)] + [VectorStoreRecordData(IsFilterable = true)] + public bool? ParkingIncluded { get; set; } + + [SimpleField(IsFilterable = true, IsSortable = true, IsFacetable = true)] + [VectorStoreRecordData(IsFilterable = true)] + public DateTimeOffset? LastRenovationDate { get; set; } + + [SimpleField(IsFilterable = true, IsSortable = true, IsFacetable = true)] + [VectorStoreRecordData] + public double? Rating { get; set; } + } +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. 
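    // Hypothetical helper: a minimal sketch of creating the same index through the new abstraction
    // instead of the FieldBuilder path above, driven by the record definition. Assumptions: the
    // collection and options types take the data model as a generic argument, and
    // Microsoft.SemanticKernel.Connectors.AzureAISearch is in scope. CreateCollectionAsync itself
    // is exercised by the integration tests that follow.
    private static async Task CreateIndexFromRecordDefinitionSketchAsync(
        string indexName,
        SearchIndexClient searchIndexClient,
        VectorStoreRecordDefinition recordDefinition)
    {
        var collection = new AzureAISearchVectorStoreRecordCollection<Hotel>(
            searchIndexClient,
            indexName,
            new AzureAISearchVectorStoreRecordCollectionOptions<Hotel>
            {
                VectorStoreRecordDefinition = recordDefinition
            });

        // Here the index schema is derived from the record definition rather than from the
        // Azure.Search.Documents attributes on the Hotel class.
        await collection.CreateCollectionAsync();
    }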
+} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..7f810dc87fbd --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Text.Json.Nodes; +using System.Threading.Tasks; +using Azure; +using Azure.Search.Documents.Indexes; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Microsoft.SemanticKernel.Data; +using Xunit; +using Xunit.Abstractions; +using static SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch.AzureAISearchVectorStoreFixture; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; + +/// +/// Integration tests for class. +/// Tests work with an Azure AI Search Instance. +/// +[Collection("AzureAISearchVectorStoreCollection")] +public sealed class AzureAISearchVectorStoreRecordCollectionTests(ITestOutputHelper output, AzureAISearchVectorStoreFixture fixture) +{ + // If null, all tests will be enabled + private const string SkipReason = "Requires Azure AI Search Service instance up and running"; + + [Theory(Skip = SkipReason)] + [InlineData(true)] + [InlineData(false)] + public async Task CollectionExistsReturnsCollectionStateAsync(bool expectedExists) + { + // Arrange. + var collectionName = expectedExists ? fixture.TestIndexName : "nonexistentcollection"; + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, collectionName); + + // Act. + var actual = await sut.CollectionExistsAsync(); + + // Assert. + Assert.Equal(expectedExists, actual); + } + + [Theory(Skip = SkipReason)] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanCreateACollectionUpsertAndGetAsync(bool useRecordDefinition) + { + // Arrange + var hotel = CreateTestHotel("Upsert-1"); + var testCollectionName = $"{fixture.TestIndexName}-createtest"; + var options = new AzureAISearchVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = useRecordDefinition ? 
fixture.VectorStoreRecordDefinition : null + }; + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, testCollectionName, options); + + await sut.DeleteCollectionAsync(); + + // Act + await sut.CreateCollectionAsync(); + var upsertResult = await sut.UpsertAsync(hotel); + var getResult = await sut.GetAsync("Upsert-1"); + + // Assert + var collectionExistResult = await sut.CollectionExistsAsync(); + Assert.True(collectionExistResult); + await sut.DeleteCollectionAsync(); + + Assert.NotNull(upsertResult); + Assert.Equal("Upsert-1", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal(hotel.HotelName, getResult.HotelName); + Assert.Equal(hotel.Description, getResult.Description); + Assert.NotNull(getResult.DescriptionEmbedding); + Assert.Equal(hotel.DescriptionEmbedding?.ToArray(), getResult.DescriptionEmbedding?.ToArray()); + Assert.Equal(hotel.Tags, getResult.Tags); + Assert.Equal(hotel.ParkingIncluded, getResult.ParkingIncluded); + Assert.Equal(hotel.LastRenovationDate, getResult.LastRenovationDate); + Assert.Equal(hotel.Rating, getResult.Rating); + + // Output + output.WriteLine(collectionExistResult.ToString()); + output.WriteLine(upsertResult); + output.WriteLine(getResult.ToString()); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanDeleteCollectionAsync() + { + // Arrange + var tempCollectionName = fixture.TestIndexName + "-delete"; + await AzureAISearchVectorStoreFixture.CreateIndexAsync(tempCollectionName, fixture.SearchIndexClient); + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, tempCollectionName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + Assert.False(await sut.CollectionExistsAsync()); + } + + [Theory(Skip = SkipReason)] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition) + { + // Arrange + var options = new AzureAISearchVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = useRecordDefinition ? 
fixture.VectorStoreRecordDefinition : null + }; + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); + + // Act + var hotel = CreateTestHotel("Upsert-1"); + var upsertResult = await sut.UpsertAsync(hotel); + var getResult = await sut.GetAsync("Upsert-1"); + + // Assert + Assert.NotNull(upsertResult); + Assert.Equal("Upsert-1", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal(hotel.HotelName, getResult.HotelName); + Assert.Equal(hotel.Description, getResult.Description); + Assert.NotNull(getResult.DescriptionEmbedding); + Assert.Equal(hotel.DescriptionEmbedding?.ToArray(), getResult.DescriptionEmbedding?.ToArray()); + Assert.Equal(hotel.Tags, getResult.Tags); + Assert.Equal(hotel.ParkingIncluded, getResult.ParkingIncluded); + Assert.Equal(hotel.LastRenovationDate, getResult.LastRenovationDate); + Assert.Equal(hotel.Rating, getResult.Rating); + + // Output + output.WriteLine(upsertResult); + output.WriteLine(getResult.ToString()); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanUpsertManyDocumentsToVectorStoreAsync() + { + // Arrange + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + + // Act + var results = sut.UpsertBatchAsync( + [ + CreateTestHotel("UpsertMany-1"), + CreateTestHotel("UpsertMany-2"), + CreateTestHotel("UpsertMany-3"), + ]); + + // Assert + Assert.NotNull(results); + var resultsList = await results.ToListAsync(); + + Assert.Equal(3, resultsList.Count); + Assert.Contains("UpsertMany-1", resultsList); + Assert.Contains("UpsertMany-2", resultsList); + Assert.Contains("UpsertMany-3", resultsList); + + // Output + foreach (var result in resultsList) + { + output.WriteLine(result); + } + } + + [Theory(Skip = SkipReason)] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool useRecordDefinition) + { + // Arrange + var options = new AzureAISearchVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null + }; + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); + + // Act + var getResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = includeVectors }); + + // Assert + Assert.NotNull(getResult); + + Assert.Equal("Hotel 1", getResult.HotelName); + Assert.Equal("This is a great hotel", getResult.Description); + Assert.Equal(includeVectors, getResult.DescriptionEmbedding != null); + if (includeVectors) + { + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, getResult.DescriptionEmbedding!.Value.ToArray()); + } + Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, getResult.Tags); + Assert.False(getResult.ParkingIncluded); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), getResult.LastRenovationDate); + Assert.Equal(3.6, getResult.Rating); + + // Output + output.WriteLine(getResult.ToString()); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanGetManyDocumentsFromVectorStoreAsync() + { + // Arrange + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + + // Act + // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. 
+ var hotels = sut.GetBatchAsync(["BaseSet-1", "BaseSet-2", "BaseSet-3", "BaseSet-5", "BaseSet-4"], new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(hotels); + var hotelsList = await hotels.ToListAsync(); + Assert.Equal(4, hotelsList.Count); + + // Output + foreach (var hotel in hotelsList) + { + output.WriteLine(hotel.ToString()); + } + } + + [Theory(Skip = SkipReason)] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefinition) + { + // Arrange + var options = new AzureAISearchVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null + }; + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + await sut.UpsertAsync(CreateTestHotel("Remove-1")); + + // Act + await sut.DeleteAsync("Remove-1"); + // Also delete a non-existing key to test that the operation does not fail for these. + await sut.DeleteAsync("Remove-2"); + + // Assert + Assert.Null(await sut.GetAsync("Remove-1", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() + { + // Arrange + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-1")); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-2")); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-3")); + + // Act + // Also include a non-existing key to test that the operation does not fail for these. + await sut.DeleteBatchAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); + + // Assert + Assert.Null(await sut.GetAsync("RemoveMany-1", new GetRecordOptions { IncludeVectors = true })); + Assert.Null(await sut.GetAsync("RemoveMany-2", new GetRecordOptions { IncludeVectors = true })); + Assert.Null(await sut.GetAsync("RemoveMany-3", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact(Skip = SkipReason)] + public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() + { + // Arrange + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName); + + // Act & Assert + Assert.Null(await sut.GetAsync("BaseSet-5", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact(Skip = SkipReason)] + public async Task ItThrowsOperationExceptionForFailedConnectionAsync() + { + // Arrange + var searchIndexClient = new SearchIndexClient(new Uri("https://localhost:12345"), new AzureKeyCredential("12345")); + var sut = new AzureAISearchVectorStoreRecordCollection(searchIndexClient, fixture.TestIndexName); + + // Act & Assert + await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact(Skip = SkipReason)] + public async Task ItThrowsOperationExceptionForFailedAuthenticationAsync() + { + // Arrange + var searchIndexClient = new SearchIndexClient(new Uri(fixture.Config.ServiceUrl), new AzureKeyCredential("12345")); + var sut = new AzureAISearchVectorStoreRecordCollection(searchIndexClient, fixture.TestIndexName); + + // Act & Assert + await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact(Skip = SkipReason)] + public async Task ItThrowsMappingExceptionForFailedMapperAsync() + { + // Arrange + var options = new 
AzureAISearchVectorStoreRecordCollectionOptions { JsonObjectCustomMapper = new FailingMapper() }; + var sut = new AzureAISearchVectorStoreRecordCollection(fixture.SearchIndexClient, fixture.TestIndexName, options); + + // Act & Assert + await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); + } + + private static Hotel CreateTestHotel(string hotelId) => new() + { + HotelId = hotelId, + HotelName = $"MyHotel {hotelId}", + Description = "My Hotel is great.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + Tags = ["pool", "air conditioning", "concierge"], + ParkingIncluded = true, + LastRenovationDate = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + Rating = 3.6 + }; + + private sealed class FailingMapper : IVectorStoreRecordMapper + { + public JsonObject MapFromDataToStorageModel(Hotel dataModel) + { + throw new NotImplementedException(); + } + + public Hotel MapFromStorageToDataModel(JsonObject storageModel, StorageToDataModelMapperOptions options) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreTests.cs new file mode 100644 index 000000000000..7bda8cb0fff9 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreTests.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; + +/// +/// Contains integration tests for the class. +/// Tests work with an Azure AI Search Instance. 
+/// +[Collection("AzureAISearchVectorStoreCollection")] +public class AzureAISearchVectorStoreTests(ITestOutputHelper output, AzureAISearchVectorStoreFixture fixture) +{ + // If null, all tests will be enabled + private const string SkipReason = "Requires Azure AI Search Service instance up and running"; + + [Fact(Skip = SkipReason)] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var additionalCollectionName = fixture.TestIndexName + "-listnames"; + await AzureAISearchVectorStoreFixture.DeleteIndexIfExistsAsync(additionalCollectionName, fixture.SearchIndexClient); + await AzureAISearchVectorStoreFixture.CreateIndexAsync(additionalCollectionName, fixture.SearchIndexClient); + var sut = new AzureAISearchVectorStore(fixture.SearchIndexClient); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Equal(2, collectionNames.Where(x => x.StartsWith(fixture.TestIndexName, StringComparison.InvariantCultureIgnoreCase)).Count()); + Assert.Contains(fixture.TestIndexName, collectionNames); + Assert.Contains(additionalCollectionName, collectionNames); + + // Output + output.WriteLine(string.Join(",", collectionNames)); + + // Cleanup + await AzureAISearchVectorStoreFixture.DeleteIndexIfExistsAsync(additionalCollectionName, fixture.SearchIndexClient); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs new file mode 100644 index 000000000000..63216da7046f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.SemanticKernel.Data; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. +public record PineconeAllTypes() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. +{ + [VectorStoreRecordKey] + public string Id { get; init; } + + [VectorStoreRecordData] + public bool BoolProperty { get; set; } + [VectorStoreRecordData] + public bool? NullableBoolProperty { get; set; } + [VectorStoreRecordData] + public string StringProperty { get; set; } + [VectorStoreRecordData] + public string? NullableStringProperty { get; set; } + [VectorStoreRecordData] + public int IntProperty { get; set; } + [VectorStoreRecordData] + public int? NullableIntProperty { get; set; } + [VectorStoreRecordData] + public long LongProperty { get; set; } + [VectorStoreRecordData] + public long? NullableLongProperty { get; set; } + [VectorStoreRecordData] + public float FloatProperty { get; set; } + [VectorStoreRecordData] + public float? NullableFloatProperty { get; set; } + [VectorStoreRecordData] + public double DoubleProperty { get; set; } + [VectorStoreRecordData] + public double? NullableDoubleProperty { get; set; } + [VectorStoreRecordData] + public decimal DecimalProperty { get; set; } + [VectorStoreRecordData] + public decimal? 
NullableDecimalProperty { get; set; } + +#pragma warning disable CA1819 // Properties should not return arrays + [VectorStoreRecordData] + public string[] StringArray { get; set; } + [VectorStoreRecordData] + public string[]? NullableStringArray { get; set; } +#pragma warning restore CA1819 // Properties should not return arrays + + [VectorStoreRecordData] + public List StringList { get; set; } + [VectorStoreRecordData] + public List? NullableStringList { get; set; } + + [VectorStoreRecordData] + public IReadOnlyCollection Collection { get; set; } + [VectorStoreRecordData] + public IEnumerable Enumerable { get; set; } + + [VectorStoreRecordVector(Dimensions: 8, IndexKind: null, DistanceFunction: DistanceFunction.DotProductSimilarity)] + public ReadOnlyMemory? Embedding { get; set; } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs new file mode 100644 index 000000000000..c648b10f2c62 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Data; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. +public record PineconeHotel() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. +{ + [VectorStoreRecordKey] + public string HotelId { get; init; } + + [VectorStoreRecordData] + public string HotelName { get; set; } + + [JsonPropertyName("code_of_the_hotel")] + [VectorStoreRecordData] + public int HotelCode { get; set; } + + [VectorStoreRecordData] + public float HotelRating { get; set; } + + [JsonPropertyName("json_parking")] + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + [VectorStoreRecordData] + public string Description { get; set; } + + [VectorStoreRecordVector(Dimensions: 8, IndexKind: null, DistanceFunction: DistanceFunction.DotProductSimilarity)] + public ReadOnlyMemory DescriptionEmbedding { get; set; } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeUserSecretsExtensions.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeUserSecretsExtensions.cs new file mode 100644 index 000000000000..1644b7427e99 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeUserSecretsExtensions.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Text.Json; +using Microsoft.Extensions.Configuration.UserSecrets; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; +public static class PineconeUserSecretsExtensions +{ + public const string PineconeApiKeyUserSecretEntry = "PineconeApiKey"; + + public static string ReadPineconeApiKey() + => JsonSerializer.Deserialize>( + File.ReadAllText(PathHelper.GetSecretsPathFromSecretsId( + typeof(PineconeUserSecretsExtensions).Assembly.GetCustomAttribute()! 
+ .UserSecretsId)))![PineconeApiKeyUserSecretEntry].Trim(); + + public static bool ContainsPineconeApiKey() + { + var userSecretsIdAttribute = typeof(PineconeUserSecretsExtensions).Assembly.GetCustomAttribute(); + if (userSecretsIdAttribute == null) + { + return false; + } + + var path = PathHelper.GetSecretsPathFromSecretsId(userSecretsIdAttribute.UserSecretsId); + if (!File.Exists(path)) + { + return false; + } + + return JsonSerializer.Deserialize>( + File.ReadAllText(path))!.ContainsKey(PineconeApiKeyUserSecretEntry); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreFixture.cs new file mode 100644 index 000000000000..28559cb0d19f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreFixture.cs @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Pinecone; +using Microsoft.SemanticKernel.Data; +using Pinecone.Grpc; +using Xunit; +using Sdk = Pinecone; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; + +public class PineconeVectorStoreFixture : IAsyncLifetime +{ + private const int MaxAttemptCount = 100; + private const int DelayInterval = 300; + + public string IndexName { get; } = "sk-index" +#pragma warning disable CA1308 // Normalize strings to uppercase + + new Regex("[^a-zA-Z0-9]", RegexOptions.None, matchTimeout: new TimeSpan(0, 0, 10)).Replace(Environment.MachineName.ToLowerInvariant(), ""); +#pragma warning restore CA1308 // Normalize strings to uppercase + + public Sdk.PineconeClient Client { get; private set; } = null!; + public PineconeVectorStore VectorStore { get; private set; } = null!; + public PineconeVectorStoreRecordCollection HotelRecordCollection { get; set; } = null!; + public PineconeVectorStoreRecordCollection AllTypesRecordCollection { get; set; } = null!; + public PineconeVectorStoreRecordCollection HotelRecordCollectionWithCustomNamespace { get; set; } = null!; + public IVectorStoreRecordCollection HotelRecordCollectionFromVectorStore { get; set; } = null!; + + public virtual Sdk.Index Index { get; set; } = null!; + + public virtual async Task InitializeAsync() + { + this.Client = new Sdk.PineconeClient(PineconeUserSecretsExtensions.ReadPineconeApiKey()); + this.VectorStore = new PineconeVectorStore(this.Client); + + var hotelRecordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(PineconeHotel.HotelId), typeof(string)), + new VectorStoreRecordDataProperty(nameof(PineconeHotel.HotelName), typeof(string)), + new VectorStoreRecordDataProperty(nameof(PineconeHotel.HotelCode), typeof(int)), + new VectorStoreRecordDataProperty(nameof(PineconeHotel.ParkingIncluded), typeof(bool)) { StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty(nameof(PineconeHotel.HotelRating), typeof(float)), + new VectorStoreRecordDataProperty(nameof(PineconeHotel.Tags), typeof(List)), + new VectorStoreRecordDataProperty(nameof(PineconeHotel.Description), typeof(string)), + new VectorStoreRecordVectorProperty(nameof(PineconeHotel.DescriptionEmbedding), typeof(ReadOnlyMemory)) { Dimensions = 8, DistanceFunction = DistanceFunction.DotProductSimilarity } + ] + }; + + var 
allTypesRecordDefinition = new VectorStoreRecordDefinition + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(PineconeAllTypes.Id), typeof(string)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.BoolProperty), typeof(bool)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableBoolProperty), typeof(bool?)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.StringProperty), typeof(string)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableStringProperty), typeof(string)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.IntProperty), typeof(int)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableIntProperty), typeof(int?)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.LongProperty), typeof(long)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableLongProperty), typeof(long?)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.FloatProperty), typeof(float)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableFloatProperty), typeof(float?)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.DoubleProperty), typeof(double)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableDoubleProperty), typeof(double?)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.DecimalProperty), typeof(decimal)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableDecimalProperty), typeof(decimal?)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.StringArray), typeof(string[])), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableStringArray), typeof(string[])), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.StringList), typeof(List)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.NullableStringList), typeof(List)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.Collection), typeof(IReadOnlyCollection)), + new VectorStoreRecordDataProperty(nameof(PineconeAllTypes.Enumerable), typeof(IEnumerable)), + new VectorStoreRecordVectorProperty(nameof(PineconeAllTypes.Embedding), typeof(ReadOnlyMemory?)) { Dimensions = 8, DistanceFunction = DistanceFunction.DotProductSimilarity } + ] + }; + + this.HotelRecordCollection = new PineconeVectorStoreRecordCollection( + this.Client, + this.IndexName, + new PineconeVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = hotelRecordDefinition + }); + + this.AllTypesRecordCollection = new PineconeVectorStoreRecordCollection( + this.Client, + this.IndexName, + new PineconeVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = allTypesRecordDefinition + }); + + this.HotelRecordCollectionWithCustomNamespace = new PineconeVectorStoreRecordCollection( + this.Client, + this.IndexName, + new PineconeVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = hotelRecordDefinition, + IndexNamespace = "my-namespace" + }); + + this.HotelRecordCollectionFromVectorStore = this.VectorStore.GetCollection( + this.IndexName, + hotelRecordDefinition); + + await this.ClearIndexesAsync(); + await this.CreateIndexAndWaitAsync(); + await this.AddSampleDataAsync(); + } + + private async Task CreateIndexAndWaitAsync() + { + var attemptCount = 0; + + await this.HotelRecordCollection.CreateCollectionAsync(); + + do + { + await Task.Delay(DelayInterval); + attemptCount++; + this.Index = await this.Client.GetIndex(this.IndexName); + } while (!this.Index.Status.IsReady && attemptCount <= 
MaxAttemptCount); + + if (!this.Index.Status.IsReady) + { + throw new InvalidOperationException("'Create index' operation didn't complete in time. Index name: " + this.IndexName); + } + } + + public async Task DisposeAsync() + { + if (this.Client is not null) + { + await this.ClearIndexesAsync(); + this.Client.Dispose(); + } + } + + private async Task AddSampleDataAsync() + { + var fiveSeasons = new PineconeHotel + { + HotelId = "five-seasons", + HotelName = "Five Seasons Hotel", + Description = "Great service any season.", + HotelCode = 7, + HotelRating = 4.5f, + ParkingIncluded = true, + DescriptionEmbedding = new ReadOnlyMemory([7.5f, 71.0f, 71.5f, 72.0f, 72.5f, 73.0f, 73.5f, 74.0f]), + Tags = ["wi-fi", "sauna", "gym", "pool"] + }; + + var vacationInn = new PineconeHotel + { + HotelId = "vacation-inn", + HotelName = "Vacation Inn Hotel", + Description = "On vacation? Stay with us.", + HotelCode = 11, + HotelRating = 4.3f, + ParkingIncluded = true, + DescriptionEmbedding = new ReadOnlyMemory([17.5f, 721.0f, 731.5f, 742.0f, 762.5f, 783.0f, 793.5f, 704.0f]), + Tags = ["wi-fi", "breakfast", "gym"] + }; + + var bestEastern = new PineconeHotel + { + HotelId = "best-eastern", + HotelName = "Best Eastern Hotel", + Description = "Best hotel east of New York.", + HotelCode = 42, + HotelRating = 4.7f, + ParkingIncluded = true, + DescriptionEmbedding = new ReadOnlyMemory([47.5f, 421.0f, 741.5f, 744.0f, 742.5f, 483.0f, 743.5f, 744.0f]), + Tags = ["wi-fi", "breakfast", "gym"] + }; + + var stats = await this.Index.DescribeStats(); + var vectorCountBefore = stats.TotalVectorCount; + + // use both Upsert and BatchUpsert methods and also use record collections created directly and using vector store + await this.HotelRecordCollection.UpsertAsync(fiveSeasons); + vectorCountBefore = await this.VerifyVectorCountModifiedAsync(vectorCountBefore, delta: 1); + + await this.HotelRecordCollectionFromVectorStore.UpsertBatchAsync([vacationInn, bestEastern]).ToListAsync(); + vectorCountBefore = await this.VerifyVectorCountModifiedAsync(vectorCountBefore, delta: 2); + + var allTypes1 = new PineconeAllTypes + { + Id = "all-types-1", + BoolProperty = true, + NullableBoolProperty = false, + StringProperty = "string prop 1", + NullableStringProperty = "nullable prop 1", + IntProperty = 1, + NullableIntProperty = 10, + LongProperty = 100L, + NullableLongProperty = 1000L, + FloatProperty = 10.5f, + NullableFloatProperty = 100.5f, + DoubleProperty = 23.75d, + NullableDoubleProperty = 233.75d, + DecimalProperty = 50.75m, + NullableDecimalProperty = 500.75m, + StringArray = ["one", "two"], + NullableStringArray = ["five", "six"], + StringList = ["eleven", "twelve"], + NullableStringList = ["fifteen", "sixteen"], + Collection = ["Foo", "Bar"], + Enumerable = ["another", "and another"], + Embedding = new ReadOnlyMemory([1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f]) + }; + + var allTypes2 = new PineconeAllTypes + { + Id = "all-types-2", + BoolProperty = false, + NullableBoolProperty = null, + StringProperty = "string prop 2", + NullableStringProperty = null, + IntProperty = 2, + NullableIntProperty = null, + LongProperty = 200L, + NullableLongProperty = null, + FloatProperty = 20.5f, + NullableFloatProperty = null, + DoubleProperty = 43.75, + NullableDoubleProperty = null, + DecimalProperty = 250.75M, + NullableDecimalProperty = null, + StringArray = [], + NullableStringArray = null, + StringList = [], + NullableStringList = null, + Collection = [], + Enumerable = [], + Embedding = new ReadOnlyMemory([10.5f, 20.5f, 30.5f, 
40.5f, 50.5f, 60.5f, 70.5f, 80.5f]) + }; + + await this.AllTypesRecordCollection.UpsertBatchAsync([allTypes1, allTypes2]).ToListAsync(); + vectorCountBefore = await this.VerifyVectorCountModifiedAsync(vectorCountBefore, delta: 2); + + var custom = new PineconeHotel + { + HotelId = "custom-hotel", + HotelName = "Custom Hotel", + Description = "Everything customizable!", + HotelCode = 17, + HotelRating = 4.25f, + ParkingIncluded = true, + DescriptionEmbedding = new ReadOnlyMemory([147.5f, 1421.0f, 1741.5f, 1744.0f, 1742.5f, 1483.0f, 1743.5f, 1744.0f]), + }; + + await this.HotelRecordCollectionWithCustomNamespace.UpsertAsync(custom); + vectorCountBefore = await this.VerifyVectorCountModifiedAsync(vectorCountBefore, delta: 1); + } + + public async Task VerifyVectorCountModifiedAsync(uint vectorCountBefore, int delta) + { + var attemptCount = 0; + Sdk.IndexStats stats; + + do + { + await Task.Delay(DelayInterval); + attemptCount++; + stats = await this.Index.DescribeStats(); + } while (stats.TotalVectorCount != vectorCountBefore + delta && attemptCount <= MaxAttemptCount); + + if (stats.TotalVectorCount != vectorCountBefore + delta) + { + throw new InvalidOperationException("'Upsert'/'Delete' operation didn't complete in time."); + } + + return stats.TotalVectorCount; + } + + public async Task DeleteAndWaitAsync(IEnumerable ids, string? indexNamespace = null) + { + var stats = await this.Index.DescribeStats(); + var vectorCountBefore = stats.Namespaces.Single(x => x.Name == (indexNamespace ?? "")).VectorCount; + var idCount = ids.Count(); + + var attemptCount = 0; + await this.Index.Delete(ids, indexNamespace); + long vectorCount; + do + { + await Task.Delay(DelayInterval); + attemptCount++; + stats = await this.Index.DescribeStats(); + vectorCount = stats.Namespaces.Single(x => x.Name == (indexNamespace ?? "")).VectorCount; + } while (vectorCount > vectorCountBefore - idCount && attemptCount <= MaxAttemptCount); + + if (vectorCount > vectorCountBefore - idCount) + { + throw new InvalidOperationException("'Delete' operation didn't complete in time."); + } + } + + private async Task ClearIndexesAsync() + { + var indexes = await this.Client.ListIndexes(); + var deletions = indexes.Select(x => this.DeleteExistingIndexAndWaitAsync(x.Name)); + + await Task.WhenAll(deletions); + } + + private async Task DeleteExistingIndexAndWaitAsync(string indexName) + { + var exists = true; + try + { + var attemptCount = 0; + await this.Client.DeleteIndex(indexName); + + do + { + await Task.Delay(DelayInterval); + var indexes = (await this.Client.ListIndexes()).Select(x => x.Name).ToArray(); + if (indexes.Length == 0 || !indexes.Contains(indexName)) + { + exists = false; + } + } while (exists && attemptCount <= MaxAttemptCount); + } + catch (HttpRequestException ex) when (ex.Message.Contains("NOT_FOUND")) + { + // index was already deleted + exists = false; + } + + if (exists) + { + throw new InvalidOperationException("'Delete index' operation didn't complete in time. Index name: " + indexName); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..411225101ffc --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs @@ -0,0 +1,564 @@ +// Copyright (c) Microsoft. All rights reserved. 
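// The fixture above repeatedly polls Index.DescribeStats until Pinecone's eventually consistent
// counters reflect an upsert or delete. A hypothetical generic form of that wait loop, shown as a
// minimal sketch (assumes System and System.Threading.Tasks are in scope; the default attempt count
// and delay mirror the fixture's constants):
internal static class PineconePollingSketch
{
    internal static async Task WaitUntilAsync(Func<Task<bool>> condition, int maxAttemptCount = 100, int delayIntervalMilliseconds = 300)
    {
        for (var attemptCount = 0; attemptCount <= maxAttemptCount; attemptCount++)
        {
            // Wait before each check, then stop as soon as the condition holds.
            await Task.Delay(delayIntervalMilliseconds);
            if (await condition())
            {
                return;
            }
        }

        throw new InvalidOperationException("Operation didn't complete in time.");
    }
}
// Example use, following the fixture's pattern (index and expectedCount are hypothetical locals):
// await PineconePollingSketch.WaitUntilAsync(async () => (await index.DescribeStats()).TotalVectorCount == expectedCount);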
+ +using System; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Grpc.Core; +using Microsoft.SemanticKernel.Connectors.Pinecone; +using Microsoft.SemanticKernel.Data; +using Pinecone; +using SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; + +[Collection("PineconeVectorStoreTests")] +[PineconeApiKeySetCondition] +public class PineconeVectorStoreRecordCollectionTests(PineconeVectorStoreFixture fixture) : IClassFixture +{ + private PineconeVectorStoreFixture Fixture { get; } = fixture; + + [PineconeFact] + public async Task TryCreateExistingIndexIsNoopAsync() + { + await this.Fixture.HotelRecordCollection.CreateCollectionIfNotExistsAsync(); + } + + [PineconeFact] + public async Task CollectionExistsReturnsTrueForExistingCollectionAsync() + { + var result = await this.Fixture.HotelRecordCollection.CollectionExistsAsync(); + + Assert.True(result); + } + + [PineconeTheory] + [InlineData(true)] + [InlineData(false)] + public async Task BasicGetAsync(bool includeVectors) + { + var fiveSeasons = await this.Fixture.HotelRecordCollection.GetAsync("five-seasons", new GetRecordOptions { IncludeVectors = includeVectors }); + + Assert.NotNull(fiveSeasons); + Assert.Equal("five-seasons", fiveSeasons.HotelId); + Assert.Equal("Five Seasons Hotel", fiveSeasons.HotelName); + Assert.Equal("Great service any season.", fiveSeasons.Description); + Assert.Equal(7, fiveSeasons.HotelCode); + Assert.Equal(4.5f, fiveSeasons.HotelRating); + Assert.True(fiveSeasons.ParkingIncluded); + Assert.Contains("wi-fi", fiveSeasons.Tags); + Assert.Contains("sauna", fiveSeasons.Tags); + Assert.Contains("gym", fiveSeasons.Tags); + Assert.Contains("pool", fiveSeasons.Tags); + + if (includeVectors) + { + Assert.Equal(new ReadOnlyMemory([7.5f, 71.0f, 71.5f, 72.0f, 72.5f, 73.0f, 73.5f, 74.0f]), fiveSeasons.DescriptionEmbedding); + } + else + { + Assert.Equal(new ReadOnlyMemory([]), fiveSeasons.DescriptionEmbedding); + } + } + + [PineconeTheory] + [InlineData(true)] + [InlineData(false)] + public async Task BatchGetAsync(bool collectionFromVectorStore) + { + var hotelsCollection = collectionFromVectorStore + ? this.Fixture.HotelRecordCollection + : this.Fixture.HotelRecordCollectionFromVectorStore; + + var hotels = await hotelsCollection.GetBatchAsync(["five-seasons", "vacation-inn", "best-eastern"]).ToListAsync(); + + var fiveSeasons = hotels.Single(x => x.HotelId == "five-seasons"); + var vacationInn = hotels.Single(x => x.HotelId == "vacation-inn"); + var bestEastern = hotels.Single(x => x.HotelId == "best-eastern"); + + Assert.Equal("Five Seasons Hotel", fiveSeasons.HotelName); + Assert.Equal("Great service any season.", fiveSeasons.Description); + Assert.Equal(7, fiveSeasons.HotelCode); + Assert.Equal(4.5f, fiveSeasons.HotelRating); + Assert.True(fiveSeasons.ParkingIncluded); + Assert.Contains("wi-fi", fiveSeasons.Tags); + Assert.Contains("sauna", fiveSeasons.Tags); + Assert.Contains("gym", fiveSeasons.Tags); + Assert.Contains("pool", fiveSeasons.Tags); + + Assert.Equal("Vacation Inn Hotel", vacationInn.HotelName); + Assert.Equal("On vacation? 
Stay with us.", vacationInn.Description); + Assert.Equal(11, vacationInn.HotelCode); + Assert.Equal(4.3f, vacationInn.HotelRating); + Assert.True(vacationInn.ParkingIncluded); + Assert.Contains("wi-fi", vacationInn.Tags); + Assert.Contains("breakfast", vacationInn.Tags); + Assert.Contains("gym", vacationInn.Tags); + + Assert.Equal("Best Eastern Hotel", bestEastern.HotelName); + Assert.Equal("Best hotel east of New York.", bestEastern.Description); + Assert.Equal(42, bestEastern.HotelCode); + Assert.Equal(4.7f, bestEastern.HotelRating); + Assert.True(bestEastern.ParkingIncluded); + Assert.Contains("wi-fi", bestEastern.Tags); + Assert.Contains("breakfast", bestEastern.Tags); + Assert.Contains("gym", bestEastern.Tags); + } + + [PineconeTheory] + [InlineData(true)] + [InlineData(false)] + public async Task AllTypesBatchGetAsync(bool includeVectors) + { + var allTypes = await this.Fixture.AllTypesRecordCollection.GetBatchAsync(["all-types-1", "all-types-2"], new GetRecordOptions { IncludeVectors = includeVectors }).ToListAsync(); + + var allTypes1 = allTypes.Single(x => x.Id == "all-types-1"); + var allTypes2 = allTypes.Single(x => x.Id == "all-types-2"); + + Assert.True(allTypes1.BoolProperty); + Assert.Equal("string prop 1", allTypes1.StringProperty); + Assert.Equal(1, allTypes1.IntProperty); + Assert.Equal(100L, allTypes1.LongProperty); + Assert.Equal(10.5f, allTypes1.FloatProperty); + Assert.Equal(23.75d, allTypes1.DoubleProperty); + Assert.Equal(50.75m, allTypes1.DecimalProperty); + Assert.Contains("one", allTypes1.StringArray); + Assert.Contains("two", allTypes1.StringArray); + Assert.Contains("eleven", allTypes1.StringList); + Assert.Contains("twelve", allTypes1.StringList); + Assert.Contains("Foo", allTypes1.Collection); + Assert.Contains("Bar", allTypes1.Collection); + Assert.Contains("another", allTypes1.Enumerable); + Assert.Contains("and another", allTypes1.Enumerable); + + Assert.False(allTypes2.BoolProperty); + Assert.Equal("string prop 2", allTypes2.StringProperty); + Assert.Equal(2, allTypes2.IntProperty); + Assert.Equal(200L, allTypes2.LongProperty); + Assert.Equal(20.5f, allTypes2.FloatProperty); + Assert.Equal(43.75d, allTypes2.DoubleProperty); + Assert.Equal(250.75m, allTypes2.DecimalProperty); + Assert.Empty(allTypes2.StringArray); + Assert.Empty(allTypes2.StringList); + Assert.Empty(allTypes2.Collection); + Assert.Empty(allTypes2.Enumerable); + + if (includeVectors) + { + Assert.True(allTypes1.Embedding.HasValue); + Assert.Equal(new ReadOnlyMemory([1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f]), allTypes1.Embedding.Value); + + Assert.True(allTypes2.Embedding.HasValue); + Assert.Equal(new ReadOnlyMemory([10.5f, 20.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 80.5f]), allTypes2.Embedding.Value); + } + else + { + Assert.Null(allTypes1.Embedding); + Assert.Null(allTypes2.Embedding); + } + } + + [PineconeFact] + public async Task BatchGetIncludingNonExistingRecordAsync() + { + var hotels = await this.Fixture.HotelRecordCollection.GetBatchAsync(["vacation-inn", "non-existing"]).ToListAsync(); + + Assert.Single(hotels); + var vacationInn = hotels.Single(x => x.HotelId == "vacation-inn"); + + Assert.Equal("Vacation Inn Hotel", vacationInn.HotelName); + Assert.Equal("On vacation? 
Stay with us.", vacationInn.Description); + Assert.Equal(11, vacationInn.HotelCode); + Assert.Equal(4.3f, vacationInn.HotelRating); + Assert.True(vacationInn.ParkingIncluded); + Assert.Contains("wi-fi", vacationInn.Tags); + Assert.Contains("breakfast", vacationInn.Tags); + Assert.Contains("gym", vacationInn.Tags); + } + + [PineconeFact] + public async Task GetNonExistingRecordAsync() + { + var result = await this.Fixture.HotelRecordCollection.GetAsync("non-existing"); + Assert.Null(result); + } + + [PineconeTheory] + [InlineData(true)] + [InlineData(false)] + public async Task GetFromCustomNamespaceAsync(bool includeVectors) + { + var custom = await this.Fixture.HotelRecordCollectionWithCustomNamespace.GetAsync("custom-hotel", new GetRecordOptions { IncludeVectors = includeVectors }); + + Assert.NotNull(custom); + Assert.Equal("custom-hotel", custom.HotelId); + Assert.Equal("Custom Hotel", custom.HotelName); + if (includeVectors) + { + Assert.Equal(new ReadOnlyMemory([147.5f, 1421.0f, 1741.5f, 1744.0f, 1742.5f, 1483.0f, 1743.5f, 1744.0f]), custom.DescriptionEmbedding); + } + else + { + Assert.Equal(new ReadOnlyMemory([]), custom.DescriptionEmbedding); + } + } + + [PineconeFact] + public async Task TryGetVectorLocatedInDefaultNamespaceButLookInCustomNamespaceAsync() + { + var badFiveSeasons = await this.Fixture.HotelRecordCollectionWithCustomNamespace.GetAsync("five-seasons"); + + Assert.Null(badFiveSeasons); + } + + [PineconeFact] + public async Task TryGetVectorLocatedInCustomNamespaceButLookInDefaultNamespaceAsync() + { + var badCustomHotel = await this.Fixture.HotelRecordCollection.GetAsync("custom-hotel"); + + Assert.Null(badCustomHotel); + } + + [PineconeFact] + public async Task DeleteNonExistingRecordAsync() + { + await this.Fixture.HotelRecordCollection.DeleteAsync("non-existing"); + } + + [PineconeFact] + public async Task TryDeleteExistingVectorLocatedInDefaultNamespaceButUseCustomNamespaceDoesNotDoAnythingAsync() + { + await this.Fixture.HotelRecordCollectionWithCustomNamespace.DeleteAsync("five-seasons"); + + var stillThere = await this.Fixture.HotelRecordCollection.GetAsync("five-seasons"); + Assert.NotNull(stillThere); + Assert.Equal("five-seasons", stillThere.HotelId); + } + + [PineconeFact] + public async Task TryDeleteExistingVectorLocatedInCustomNamespaceButUseDefaultNamespaceDoesNotDoAnythingAsync() + { + await this.Fixture.HotelRecordCollection.DeleteAsync("custom-hotel"); + + var stillThere = await this.Fixture.HotelRecordCollectionWithCustomNamespace.GetAsync("custom-hotel"); + Assert.NotNull(stillThere); + Assert.Equal("custom-hotel", stillThere.HotelId); + } + + [PineconeTheory] + [InlineData(true)] + [InlineData(false)] + public async Task InsertGetModifyDeleteVectorAsync(bool collectionFromVectorStore) + { + var langriSha = new PineconeHotel + { + HotelId = "langri-sha", + HotelName = "Langri-Sha Hotel", + Description = "Lorem ipsum", + HotelCode = 100, + HotelRating = 4.2f, + ParkingIncluded = false, + DescriptionEmbedding = new ReadOnlyMemory([1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f]) + }; + + var stats = await this.Fixture.Index.DescribeStats(); + var vectorCountBefore = stats.TotalVectorCount; + + var hotelRecordCollection = collectionFromVectorStore + ? 
this.Fixture.HotelRecordCollectionFromVectorStore + : this.Fixture.HotelRecordCollection; + + // insert + await hotelRecordCollection.UpsertAsync(langriSha); + + vectorCountBefore = await this.Fixture.VerifyVectorCountModifiedAsync(vectorCountBefore, delta: 1); + + var inserted = await hotelRecordCollection.GetAsync("langri-sha", new GetRecordOptions { IncludeVectors = true }); + + Assert.NotNull(inserted); + Assert.Equal(langriSha.HotelName, inserted.HotelName); + Assert.Equal(langriSha.Description, inserted.Description); + Assert.Equal(langriSha.HotelCode, inserted.HotelCode); + Assert.Equal(langriSha.HotelRating, inserted.HotelRating); + Assert.Equal(langriSha.ParkingIncluded, inserted.ParkingIncluded); + Assert.Equal(langriSha.DescriptionEmbedding, inserted.DescriptionEmbedding); + + langriSha.Description += " dolor sit amet"; + langriSha.ParkingIncluded = true; + langriSha.DescriptionEmbedding = new ReadOnlyMemory([11f, 12f, 13f, 14f, 15f, 16f, 17f, 18f]); + + // update + await hotelRecordCollection.UpsertAsync(langriSha); + + // this is not great but no vectors are added so we can't query status for number of vectors like we do for insert/delete + await Task.Delay(2000); + + var updated = await hotelRecordCollection.GetAsync("langri-sha", new GetRecordOptions { IncludeVectors = true }); + + Assert.NotNull(updated); + Assert.Equal(langriSha.HotelName, updated.HotelName); + Assert.Equal(langriSha.Description, updated.Description); + Assert.Equal(langriSha.HotelCode, updated.HotelCode); + Assert.Equal(langriSha.HotelRating, updated.HotelRating); + Assert.Equal(langriSha.ParkingIncluded, updated.ParkingIncluded); + Assert.Equal(langriSha.DescriptionEmbedding, updated.DescriptionEmbedding); + + // delete + await hotelRecordCollection.DeleteAsync("langri-sha"); + + await this.Fixture.VerifyVectorCountModifiedAsync(vectorCountBefore, delta: -1); + } + + [PineconeFact] + public async Task UseCollectionExistsOnNonExistingStoreReturnsFalseAsync() + { + var incorrectRecordStore = new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "incorrect"); + + var result = await incorrectRecordStore.CollectionExistsAsync(); + + Assert.False(result); + } + + [PineconeFact] + public async Task UseNonExistingIndexThrowsAsync() + { + var incorrectRecordStore = new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "incorrect"); + + var statusCode = (await Assert.ThrowsAsync( + () => incorrectRecordStore.GetAsync("best-eastern"))).StatusCode; + + Assert.Equal(HttpStatusCode.NotFound, statusCode); + } + + [PineconeFact] + public async Task UseRecordStoreWithCustomMapperAsync() + { + var recordStore = new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + this.Fixture.IndexName, + new PineconeVectorStoreRecordCollectionOptions { VectorCustomMapper = new CustomHotelRecordMapper() }); + + var vacationInn = await recordStore.GetAsync("vacation-inn", new GetRecordOptions { IncludeVectors = true }); + + Assert.NotNull(vacationInn); + Assert.Equal("Custom Vacation Inn Hotel", vacationInn.HotelName); + Assert.Equal("On vacation? 
Stay with us.", vacationInn.Description); + Assert.Equal(11, vacationInn.HotelCode); + Assert.Equal(4.3f, vacationInn.HotelRating); + Assert.True(vacationInn.ParkingIncluded); + Assert.Contains("wi-fi", vacationInn.Tags); + Assert.Contains("breakfast", vacationInn.Tags); + Assert.Contains("gym", vacationInn.Tags); + } + + private sealed class CustomHotelRecordMapper : IVectorStoreRecordMapper + { + public Vector MapFromDataToStorageModel(PineconeHotel dataModel) + { + var metadata = new MetadataMap + { + [nameof(PineconeHotel.HotelName)] = dataModel.HotelName, + [nameof(PineconeHotel.Description)] = dataModel.Description, + [nameof(PineconeHotel.HotelCode)] = dataModel.HotelCode, + [nameof(PineconeHotel.HotelRating)] = dataModel.HotelRating, + ["parking_is_included"] = dataModel.ParkingIncluded, + [nameof(PineconeHotel.Tags)] = dataModel.Tags.ToArray(), + }; + + return new Vector + { + Id = dataModel.HotelId, + Values = dataModel.DescriptionEmbedding.ToArray(), + Metadata = metadata, + }; + } + + public PineconeHotel MapFromStorageToDataModel(Vector storageModel, StorageToDataModelMapperOptions options) + { + if (storageModel.Metadata == null) + { + throw new InvalidOperationException("Missing metadata."); + } + + return new PineconeHotel + { + HotelId = storageModel.Id, + HotelName = "Custom " + (string)storageModel.Metadata[nameof(PineconeHotel.HotelName)].Inner!, + Description = (string)storageModel.Metadata[nameof(PineconeHotel.Description)].Inner!, + HotelCode = (int)(double)storageModel.Metadata[nameof(PineconeHotel.HotelCode)].Inner!, + HotelRating = (float)(double)storageModel.Metadata[nameof(PineconeHotel.HotelRating)].Inner!, + ParkingIncluded = (bool)storageModel.Metadata["parking_is_included"].Inner!, + Tags = ((MetadataValue[])storageModel.Metadata[nameof(PineconeHotel.Tags)].Inner!)!.Select(x => (string)x.Inner!).ToList(), + }; + } + } + + #region Negative + + [PineconeFact] + public void UseRecordWithNoEmbeddingThrows() + { + var exception = Assert.Throws( + () => new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "Whatever")); + + Assert.Equal( + $"No vector property found on type {typeof(PineconeRecordNoEmbedding).FullName}.", + exception.Message); + } + +#pragma warning disable CA1812 + private sealed record PineconeRecordNoEmbedding + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordData] + public string? Name { get; set; } + } +#pragma warning restore CA1812 + + [PineconeFact] + public void UseRecordWithMultipleEmbeddingsThrows() + { + var exception = Assert.Throws( + () => new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "Whatever")); + + Assert.Equal( + $"Multiple vector properties found on type {typeof(PineconeRecordMultipleEmbeddings).FullName} while only one is supported.", + exception.Message); + } + +#pragma warning disable CA1812 + private sealed record PineconeRecordMultipleEmbeddings + { + [VectorStoreRecordKey] + public string Id { get; set; } = null!; + + [VectorStoreRecordVector] + public ReadOnlyMemory Embedding1 { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory Embedding2 { get; set; } + } +#pragma warning restore CA1812 + + [PineconeFact] + public void UseRecordWithUnsupportedKeyTypeThrows() + { + var message = Assert.Throws( + () => new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "Whatever")).Message; + + Assert.Equal( + $"Key properties must be one of the supported types: {typeof(string).FullName}. 
Type of the property '{nameof(PineconeRecordUnsupportedKeyType.Id)}' is {typeof(int).FullName}.", + message); + } + +#pragma warning disable CA1812 + private sealed record PineconeRecordUnsupportedKeyType + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordData] + public string? Name { get; set; } + + [VectorStoreRecordVector] + public ReadOnlyMemory Embedding { get; set; } + } +#pragma warning restore CA1812 + + [PineconeFact] + public async Task TryAddingVectorWithUnsupportedValuesAsync() + { + var badAllTypes = new PineconeAllTypes + { + Id = "bad", + BoolProperty = true, + DecimalProperty = 1m, + DoubleProperty = 1.5d, + FloatProperty = 2.5f, + IntProperty = 1, + LongProperty = 11L, + NullableStringArray = ["foo", null!, "bar",], + Embedding = new ReadOnlyMemory([1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f]) + }; + + var exception = await Assert.ThrowsAsync( + () => this.Fixture.AllTypesRecordCollection.UpsertAsync(badAllTypes)); + + Assert.Equal("Microsoft.SemanticKernel.Connectors.Pinecone", exception.Source); + Assert.Equal("Pinecone", exception.VectorStoreType); + Assert.Equal("Upsert", exception.OperationName); + Assert.Equal(this.Fixture.IndexName, exception.CollectionName); + + var inner = exception.InnerException as RpcException; + Assert.NotNull(inner); + Assert.Equal(StatusCode.InvalidArgument, inner.StatusCode); + } + + [PineconeFact] + public async Task TryCreateIndexWithIncorrectDimensionFailsAsync() + { + var recordCollection = new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "negative-dimension"); + + var message = (await Assert.ThrowsAsync(() => recordCollection.CreateCollectionAsync())).Message; + + Assert.Equal("Property Dimensions on VectorStoreRecordVectorProperty 'Embedding' must be set to a positive integer to create a collection.", message); + } + +#pragma warning disable CA1812 + private sealed record PineconeRecordWithIncorrectDimension + { + [VectorStoreRecordKey] + public string Id { get; set; } = null!; + + [VectorStoreRecordData] + public string? Name { get; set; } + + [VectorStoreRecordVector(Dimensions: -7)] + public ReadOnlyMemory Embedding { get; set; } + } +#pragma warning restore CA1812 + + [PineconeFact] + public async Task TryCreateIndexWithUnsSupportedMetricFailsAsync() + { + var recordCollection = new PineconeVectorStoreRecordCollection( + this.Fixture.Client, + "bad-metric"); + + var message = (await Assert.ThrowsAsync(() => recordCollection.CreateCollectionAsync())).Message; + + Assert.Equal("Distance function 'just eyeball it' for VectorStoreRecordVectorProperty 'Embedding' is not supported by the Pinecone VectorStore.", message); + } + +#pragma warning disable CA1812 + private sealed record PineconeRecordWithUnsupportedMetric + { + [VectorStoreRecordKey] + public string Id { get; set; } = null!; + + [VectorStoreRecordData] + public string? Name { get; set; } + + [VectorStoreRecordVector(Dimensions: 5, IndexKind: null, DistanceFunction: "just eyeball it")] + public ReadOnlyMemory Embedding { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreTests.cs new file mode 100644 index 000000000000..8aa50e6fa2fa --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreTests.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft. All rights reserved. 
+
+using System;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.SemanticKernel.Connectors.Pinecone;
+using Microsoft.SemanticKernel.Data;
+using SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+using Xunit;
+using Sdk = Pinecone;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone;
+
+[Collection("PineconeVectorStoreTests")]
+[PineconeApiKeySetCondition]
+public class PineconeVectorStoreTests(PineconeVectorStoreFixture fixture) : IClassFixture<PineconeVectorStoreFixture>
+{
+    private PineconeVectorStoreFixture Fixture { get; } = fixture;
+
+    [PineconeFact]
+    public async Task ListCollectionNamesAsync()
+    {
+        var collectionNames = await this.Fixture.VectorStore.ListCollectionNamesAsync().ToListAsync();
+
+        Assert.Equal([this.Fixture.IndexName], collectionNames);
+    }
+
+    [PineconeFact]
+    public void CreateCollectionUsingFactory()
+    {
+        var vectorStore = new PineconeVectorStore(
+            this.Fixture.Client,
+            new PineconeVectorStoreOptions
+            {
+                VectorStoreCollectionFactory = new MyVectorStoreRecordCollectionFactory()
+            });
+
+        var factoryCollection = vectorStore.GetCollection<string, PineconeHotel>(this.Fixture.IndexName);
+
+        Assert.NotNull(factoryCollection);
+        Assert.Equal("factory" + this.Fixture.IndexName, factoryCollection.CollectionName);
+    }
+
+    private sealed class MyVectorStoreRecordCollectionFactory : IPineconeVectorStoreRecordCollectionFactory
+    {
+        public IVectorStoreRecordCollection<TKey, TRecord> CreateVectorStoreRecordCollection<TKey, TRecord>(
+            Sdk.PineconeClient pineconeClient,
+            string name,
+            VectorStoreRecordDefinition? vectorStoreRecordDefinition)
+            where TKey : notnull
+            where TRecord : class
+        {
+            if (typeof(TKey) != typeof(string))
+            {
+                throw new InvalidOperationException("Only string keys are supported.");
+            }
+
+            return (new PineconeVectorStoreRecordCollection<TRecord>(pineconeClient, "factory" + name) as IVectorStoreRecordCollection<TKey, TRecord>)!;
+        }
+    }
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/ITestCondition.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/ITestCondition.cs
new file mode 100644
index 000000000000..361e13d60cd0
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/ITestCondition.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Threading.Tasks;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+public interface ITestCondition
+{
+    ValueTask<bool> IsMetAsync();
+
+    string SkipReason { get; }
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeApiKeySetConditionAttribute.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeApiKeySetConditionAttribute.cs
new file mode 100644
index 000000000000..ef144699fb7c
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeApiKeySetConditionAttribute.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Threading.Tasks;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class)]
+public sealed class PineconeApiKeySetConditionAttribute : Attribute, ITestCondition
+{
+    public ValueTask<bool> IsMetAsync()
+    {
+        var isMet = PineconeUserSecretsExtensions.ContainsPineconeApiKey();
+
+        return ValueTask.FromResult(isMet);
+    }
+
+    public string SkipReason
+        => $"Pinecone API key was not specified in user secrets. Use the following command to set it: dotnet user-secrets set \"{PineconeUserSecretsExtensions.PineconeApiKeyUserSecretEntry}\" \"your_Pinecone_API_key\"";
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactAttribute.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactAttribute.cs
new file mode 100644
index 000000000000..d4ebff8869e0
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactAttribute.cs
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using Xunit;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+[AttributeUsage(AttributeTargets.Method)]
+[XunitTestCaseDiscoverer("SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit.PineconeFactDiscoverer", "IntegrationTests")]
+public sealed class PineconeFactAttribute : FactAttribute;
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactDiscoverer.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactDiscoverer.cs
new file mode 100644
index 000000000000..c1923ad72a2e
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactDiscoverer.cs
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Xunit.Abstractions;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+public class PineconeFactDiscoverer(IMessageSink messageSink) : FactDiscoverer(messageSink)
+{
+    protected override IXunitTestCase CreateTestCase(
+        ITestFrameworkDiscoveryOptions discoveryOptions,
+        ITestMethod testMethod,
+        IAttributeInfo factAttribute)
+        => new PineconeFactTestCase(
+            this.DiagnosticMessageSink,
+            discoveryOptions.MethodDisplayOrDefault(),
+            discoveryOptions.MethodDisplayOptionsOrDefault(),
+            testMethod);
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactTestCase.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactTestCase.cs
new file mode 100644
index 000000000000..4a27031ff45b
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeFactTestCase.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit.Abstractions;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+public sealed class PineconeFactTestCase : XunitTestCase
+{
+    [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes")]
+    public PineconeFactTestCase()
+    {
+    }
+
+    public PineconeFactTestCase(
+        IMessageSink diagnosticMessageSink,
+        TestMethodDisplay defaultMethodDisplay,
+        TestMethodDisplayOptions defaultMethodDisplayOptions,
+        ITestMethod testMethod,
+        object[]? testMethodArguments = null)
+        : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod, testMethodArguments)
+    {
+    }
+
+    public override async Task<RunSummary> RunAsync(
+        IMessageSink diagnosticMessageSink,
+        IMessageBus messageBus,
+        object[] constructorArguments,
+        ExceptionAggregator aggregator,
+        CancellationTokenSource cancellationTokenSource)
+        => await XunitTestCaseExtensions.TrySkipAsync(this, messageBus)
+            ? new RunSummary { Total = 1, Skipped = 1 }
+            : await base.RunAsync(
+                diagnosticMessageSink,
+                messageBus,
+                constructorArguments,
+                aggregator,
+                cancellationTokenSource);
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryAttribute.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryAttribute.cs
new file mode 100644
index 000000000000..bff77c952c24
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryAttribute.cs
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using Xunit;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+[AttributeUsage(AttributeTargets.Method)]
+[XunitTestCaseDiscoverer("SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit.PineconeTheoryDiscoverer", "IntegrationTests")]
+public sealed class PineconeTheoryAttribute : TheoryAttribute;
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryDiscoverer.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryDiscoverer.cs
new file mode 100644
index 000000000000..79a60afd69b8
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryDiscoverer.cs
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Collections.Generic;
+using Xunit.Abstractions;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+public class PineconeTheoryDiscoverer(IMessageSink messageSink) : TheoryDiscoverer(messageSink)
+{
+    protected override IEnumerable<IXunitTestCase> CreateTestCasesForTheory(
+        ITestFrameworkDiscoveryOptions discoveryOptions,
+        ITestMethod testMethod,
+        IAttributeInfo theoryAttribute)
+    {
+        yield return new PineconeTheoryTestCase(
+            this.DiagnosticMessageSink,
+            discoveryOptions.MethodDisplayOrDefault(),
+            discoveryOptions.MethodDisplayOptionsOrDefault(),
+            testMethod);
+    }
+
+    protected override IEnumerable<IXunitTestCase> CreateTestCasesForDataRow(
+        ITestFrameworkDiscoveryOptions discoveryOptions,
+        ITestMethod testMethod,
+        IAttributeInfo theoryAttribute,
+        object[] dataRow)
+    {
+        yield return new PineconeFactTestCase(
+            this.DiagnosticMessageSink,
+            discoveryOptions.MethodDisplayOrDefault(),
+            discoveryOptions.MethodDisplayOptionsOrDefault(),
+            testMethod,
+            dataRow);
+    }
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryTestCase.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryTestCase.cs
new file mode 100644
index 000000000000..1a9ebff92e1f
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/PineconeTheoryTestCase.cs
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit.Abstractions;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+public sealed class PineconeTheoryTestCase : XunitTheoryTestCase
+{
+    [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes")]
+    public PineconeTheoryTestCase()
+    {
+    }
+
+    public PineconeTheoryTestCase(
+        IMessageSink diagnosticMessageSink,
+        TestMethodDisplay defaultMethodDisplay,
+        TestMethodDisplayOptions defaultMethodDisplayOptions,
+        ITestMethod testMethod)
+        : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod)
+    {
+    }
+
+    public override async Task<RunSummary> RunAsync(
+        IMessageSink diagnosticMessageSink,
+        IMessageBus messageBus,
+        object[] constructorArguments,
+        ExceptionAggregator aggregator,
+        CancellationTokenSource cancellationTokenSource)
+        => await XunitTestCaseExtensions.TrySkipAsync(this, messageBus)
+            ? new RunSummary { Total = 1, Skipped = 1 }
+            : await base.RunAsync(
+                diagnosticMessageSink,
+                messageBus,
+                constructorArguments,
+                aggregator,
+                cancellationTokenSource);
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/XunitTestCaseExtensions.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/XunitTestCaseExtensions.cs
new file mode 100644
index 000000000000..75d22e4e5ae9
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/Xunit/XunitTestCaseExtensions.cs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Xunit.Abstractions;
+using Xunit.Sdk;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone.Xunit;
+
+public static class XunitTestCaseExtensions
+{
+    private static readonly ConcurrentDictionary<string, List<IAttributeInfo>> s_typeAttributes = new();
+    private static readonly ConcurrentDictionary<string, List<IAttributeInfo>> s_assemblyAttributes = new();
+
+    public static async ValueTask<bool> TrySkipAsync(XunitTestCase testCase, IMessageBus messageBus)
+    {
+        var method = testCase.Method;
+        var type = testCase.TestMethod.TestClass.Class;
+        var assembly = type.Assembly;
+
+        var skipReasons = new List<string>();
+        var attributes =
+            s_assemblyAttributes.GetOrAdd(
+                assembly.Name,
+                a => assembly.GetCustomAttributes(typeof(ITestCondition)).ToList())
+            .Concat(
+                s_typeAttributes.GetOrAdd(
+                    type.Name,
+                    t => type.GetCustomAttributes(typeof(ITestCondition)).ToList()))
+            .Concat(method.GetCustomAttributes(typeof(ITestCondition)))
+            .OfType<ReflectionAttributeInfo>()
+            .Select(attributeInfo => (ITestCondition)attributeInfo.Attribute);
+
+        foreach (var attribute in attributes)
+        {
+            if (!await attribute.IsMetAsync())
+            {
+                skipReasons.Add(attribute.SkipReason);
+            }
+        }
+
+        if (skipReasons.Count > 0)
+        {
+            messageBus.QueueMessage(
+                new TestSkipped(new XunitTest(testCase, testCase.DisplayName), string.Join(Environment.NewLine, skipReasons)));
+
+            return true;
+        }
+
+        return false;
+    }
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreCollectionFixture.cs
new file mode 100644
index 000000000000..a7b565d71c2d
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreCollectionFixture.cs
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft. All rights reserved.
+ +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; + +[CollectionDefinition("QdrantVectorStoreCollection")] +public class QdrantVectorStoreCollectionFixture : ICollectionFixture +{ +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs new file mode 100644 index 000000000000..d1a314829547 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreFixture.cs @@ -0,0 +1,325 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Docker.DotNet; +using Docker.DotNet.Models; +using Grpc.Core; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client; +using Qdrant.Client.Grpc; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; + +public class QdrantVectorStoreFixture : IAsyncLifetime +{ + /// The docker client we are using to create a qdrant container with. + private readonly DockerClient _client; + + /// The id of the qdrant container that we are testing with. + private string? _containerId = null; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// + /// Initializes a new instance of the class. + /// + public QdrantVectorStoreFixture() + { + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._client = dockerClientConfiguration.CreateClient(); + this.HotelVectorStoreRecordDefinition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance } + } + }; + this.HotelWithGuidIdVectorStoreRecordDefinition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance } + } + }; + } + +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// Gets the qdrant client connection to use for tests. + public QdrantClient QdrantClient { get; private set; } + + /// Gets the manually created vector store record definition for our test model. 
+ public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } + + /// Gets the manually created vector store record definition for our test model. + public VectorStoreRecordDefinition HotelWithGuidIdVectorStoreRecordDefinition { get; private set; } + + /// + /// Create / Recreate qdrant docker container and run it. + /// + /// An async task. + public async Task InitializeAsync() + { + this._containerId = await SetupQdrantContainerAsync(this._client); + + // Connect to qdrant. + this.QdrantClient = new QdrantClient("localhost"); + + // Create schemas for the vector store. + var vectorParamsMap = new VectorParamsMap(); + vectorParamsMap.Map.Add("DescriptionEmbedding", new VectorParams { Size = 4, Distance = Distance.Cosine }); + + // Wait for the qdrant container to be ready. + var retryCount = 0; + while (retryCount++ < 5) + { + try + { + await this.QdrantClient.ListCollectionsAsync(); + } + catch (RpcException e) + { + if (e.StatusCode != Grpc.Core.StatusCode.Unavailable) + { + throw; + } + + await Task.Delay(1000); + } + } + + await this.QdrantClient.CreateCollectionAsync( + "namedVectorsHotels", + vectorParamsMap); + + await this.QdrantClient.CreateCollectionAsync( + "singleVectorHotels", + new VectorParams { Size = 4, Distance = Distance.Cosine }); + + await this.QdrantClient.CreateCollectionAsync( + "singleVectorGuidIdHotels", + new VectorParams { Size = 4, Distance = Distance.Cosine }); + + // Create test data common to both named and unnamed vectors. + var tags = new ListValue(); + tags.Values.Add("t1"); + tags.Values.Add("t2"); + var tagsValue = new Value(); + tagsValue.ListValue = tags; + + // Create some test data using named vectors. + var embedding = new[] { 30f, 31f, 32f, 33f }; + + var namedVectors1 = new NamedVectors(); + var namedVectors2 = new NamedVectors(); + var namedVectors3 = new NamedVectors(); + + namedVectors1.Vectors.Add("DescriptionEmbedding", embedding); + namedVectors2.Vectors.Add("DescriptionEmbedding", embedding); + namedVectors3.Vectors.Add("DescriptionEmbedding", embedding); + + List namedVectorPoints = + [ + new PointStruct + { + Id = 11, + Vectors = new Vectors { Vectors_ = namedVectors1 }, + Payload = { ["HotelName"] = "My Hotel 11", ["HotelCode"] = 11, ["parking_is_included"] = true, ["Tags"] = tagsValue, ["HotelRating"] = 4.5f, ["Description"] = "This is a great hotel." } + }, + new PointStruct + { + Id = 12, + Vectors = new Vectors { Vectors_ = namedVectors2 }, + Payload = { ["HotelName"] = "My Hotel 12", ["HotelCode"] = 12, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + }, + new PointStruct + { + Id = 13, + Vectors = new Vectors { Vectors_ = namedVectors3 }, + Payload = { ["HotelName"] = "My Hotel 13", ["HotelCode"] = 13, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + }, + ]; + + await this.QdrantClient.UpsertAsync("namedVectorsHotels", namedVectorPoints); + + // Create some test data using a single unnamed vector. + List unnamedVectorPoints = + [ + new PointStruct + { + Id = 11, + Vectors = embedding, + Payload = { ["HotelName"] = "My Hotel 11", ["HotelCode"] = 11, ["parking_is_included"] = true, ["Tags"] = tagsValue, ["HotelRating"] = 4.5f, ["Description"] = "This is a great hotel." } + }, + new PointStruct + { + Id = 12, + Vectors = embedding, + Payload = { ["HotelName"] = "My Hotel 12", ["HotelCode"] = 12, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." 
} + }, + new PointStruct + { + Id = 13, + Vectors = embedding, + Payload = { ["HotelName"] = "My Hotel 13", ["HotelCode"] = 13, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + }, + ]; + + await this.QdrantClient.UpsertAsync("singleVectorHotels", unnamedVectorPoints); + + // Create some test data using a single unnamed vector and a guid id. + List unnamedVectorGuidIdPoints = + [ + new PointStruct + { + Id = Guid.Parse("11111111-1111-1111-1111-111111111111"), + Vectors = embedding, + Payload = { ["HotelName"] = "My Hotel 11", ["Description"] = "This is a great hotel." } + }, + new PointStruct + { + Id = Guid.Parse("22222222-2222-2222-2222-222222222222"), + Vectors = embedding, + Payload = { ["HotelName"] = "My Hotel 12", ["Description"] = "This is a great hotel." } + }, + new PointStruct + { + Id = Guid.Parse("33333333-3333-3333-3333-333333333333"), + Vectors = embedding, + Payload = { ["HotelName"] = "My Hotel 13", ["Description"] = "This is a great hotel." } + }, + ]; + + await this.QdrantClient.UpsertAsync("singleVectorGuidIdHotels", unnamedVectorGuidIdPoints); + } + + /// + /// Delete the docker container after the test run. + /// + /// An async task. + public async Task DisposeAsync() + { + if (this._containerId != null) + { + await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); + await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); + } + } + + /// + /// Setup the qdrant container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + private static async Task SetupQdrantContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "qdrant/qdrant", + Tag = "latest", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "qdrant/qdrant", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"6333", new List {new() {HostPort = "6333" } }}, + {"6334", new List {new() {HostPort = "6334" } }} + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "6333", default }, + { "6334", default } + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + /// + /// A test model for the qdrant vector store. + /// +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + public record HotelInfo() + { + /// The key of the record. + [VectorStoreRecordKey] + public ulong HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData(IsFilterable = true)] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData(IsFilterable = true)] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(IsFilterable = true, StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = new List(); + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A vector field. 
+ [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + } + + /// + /// A test model for the qdrant vector store. + /// +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + public record HotelInfoWithGuidId() + { + /// The key of the record. + [VectorStoreRecordKey] + public Guid HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + public string? HotelName { get; set; } + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A vector field. + [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + } +} +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..7e2e9b1f7d78 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs @@ -0,0 +1,381 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Globalization; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Microsoft.SemanticKernel.Data; +using Qdrant.Client.Grpc; +using Xunit; +using Xunit.Abstractions; +using static SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant.QdrantVectorStoreFixture; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; + +/// +/// Contains tests for the class. +/// +/// Used for logging. +/// Qdrant setup and teardown. +[Collection("QdrantVectorStoreCollection")] +public sealed class QdrantVectorStoreRecordCollectionTests(ITestOutputHelper output, QdrantVectorStoreFixture fixture) +{ + [Theory] + [InlineData("singleVectorHotels", true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange. + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName); + + // Act. + var actual = await sut.CollectionExistsAsync(); + + // Assert. + Assert.Equal(expectedExists, actual); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ItCanCreateACollectionUpsertAndGetAsync(bool hasNamedVectors, bool useRecordDefinition) + { + // Arrange + var collectionNamePostfix1 = useRecordDefinition ? "WithDefinition" : "WithType"; + var collectionNamePostfix2 = hasNamedVectors ? "HasNamedVectors" : "SingleUnnamedVector"; + var testCollectionName = $"createtest{collectionNamePostfix1}{collectionNamePostfix2}"; + + var options = new QdrantVectorStoreRecordCollectionOptions + { + HasNamedVectors = hasNamedVectors, + VectorStoreRecordDefinition = useRecordDefinition ? 
fixture.HotelVectorStoreRecordDefinition : null + }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, testCollectionName, options); + + var record = this.CreateTestHotel(30); + + // Act + await sut.CreateCollectionAsync(); + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(30, new GetRecordOptions { IncludeVectors = true }); + + // Assert + var collectionExistResult = await sut.CollectionExistsAsync(); + Assert.True(collectionExistResult); + await sut.DeleteCollectionAsync(); + + Assert.Equal(30ul, upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.HotelCode, getResult?.HotelCode); + Assert.Equal(record.HotelRating, getResult?.HotelRating); + Assert.Equal(record.ParkingIncluded, getResult?.ParkingIncluded); + Assert.Equal(record.Tags.ToArray(), getResult?.Tags.ToArray()); + Assert.Equal(record.Description, getResult?.Description); + + // Output + output.WriteLine(collectionExistResult.ToString()); + output.WriteLine(upsertResult.ToString(CultureInfo.InvariantCulture)); + output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanDeleteCollectionAsync() + { + // Arrange + var tempCollectionName = "temp-test"; + await fixture.QdrantClient.CreateCollectionAsync( + tempCollectionName, + new VectorParams { Size = 4, Distance = Distance.Cosine }); + + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, tempCollectionName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + Assert.False(await sut.CollectionExistsAsync()); + } + + [Theory] + [InlineData(true, "singleVectorHotels", false)] + [InlineData(false, "singleVectorHotels", false)] + [InlineData(true, "namedVectorsHotels", true)] + [InlineData(false, "namedVectorsHotels", true)] + public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition, string collectionName, bool hasNamedVectors) + { + // Arrange. + var options = new QdrantVectorStoreRecordCollectionOptions + { + HasNamedVectors = hasNamedVectors, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null + }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + + var record = this.CreateTestHotel(20); + + // Act. + var upsertResult = await sut.UpsertAsync(record); + + // Assert. + var getResult = await sut.GetAsync(20, new GetRecordOptions { IncludeVectors = true }); + Assert.Equal(20ul, upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.HotelCode, getResult?.HotelCode); + Assert.Equal(record.HotelRating, getResult?.HotelRating); + Assert.Equal(record.ParkingIncluded, getResult?.ParkingIncluded); + Assert.Equal(record.Tags.ToArray(), getResult?.Tags.ToArray()); + Assert.Equal(record.Description, getResult?.Description); + + // TODO: figure out why original array is different from the one we get back. + //Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); + + // Output. + output.WriteLine(upsertResult.ToString(CultureInfo.InvariantCulture)); + output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanUpsertAndRemoveDocumentWithGuidIdToVectorStoreAsync() + { + // Arrange. 
+ var options = new QdrantVectorStoreRecordCollectionOptions { HasNamedVectors = false }; + IVectorStoreRecordCollection sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorGuidIdHotels", options); + + var record = new HotelInfoWithGuidId + { + HotelId = Guid.Parse("55555555-5555-5555-5555-555555555555"), + HotelName = "My Hotel 5", + Description = "This is a great hotel.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + }; + + // Act. + var upsertResult = await sut.UpsertAsync(record); + + // Assert. + var getResult = await sut.GetAsync(Guid.Parse("55555555-5555-5555-5555-555555555555"), new GetRecordOptions { IncludeVectors = true }); + Assert.Equal(Guid.Parse("55555555-5555-5555-5555-555555555555"), upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.Description, getResult?.Description); + + // Act. + await sut.DeleteAsync(Guid.Parse("55555555-5555-5555-5555-555555555555")); + + // Assert. + Assert.Null(await sut.GetAsync(Guid.Parse("55555555-5555-5555-5555-555555555555"))); + + // Output. + output.WriteLine(upsertResult.ToString("D")); + output.WriteLine(getResult?.ToString()); + } + + [Theory] + [InlineData(true, true, "singleVectorHotels", false)] + [InlineData(true, false, "singleVectorHotels", false)] + [InlineData(false, true, "singleVectorHotels", false)] + [InlineData(false, false, "singleVectorHotels", false)] + [InlineData(true, true, "namedVectorsHotels", true)] + [InlineData(true, false, "namedVectorsHotels", true)] + [InlineData(false, true, "namedVectorsHotels", true)] + [InlineData(false, false, "namedVectorsHotels", true)] + public async Task ItCanGetDocumentFromVectorStoreAsync(bool useRecordDefinition, bool withEmbeddings, string collectionName, bool hasNamedVectors) + { + // Arrange. + var options = new QdrantVectorStoreRecordCollectionOptions + { + HasNamedVectors = hasNamedVectors, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null + }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + + // Act. + var getResult = await sut.GetAsync(11, new GetRecordOptions { IncludeVectors = withEmbeddings }); + + // Assert. + Assert.Equal(11ul, getResult?.HotelId); + Assert.Equal("My Hotel 11", getResult?.HotelName); + Assert.Equal(11, getResult?.HotelCode); + Assert.True(getResult?.ParkingIncluded); + Assert.Equal(4.5f, getResult?.HotelRating); + Assert.Equal(2, getResult?.Tags.Count); + Assert.Equal("t1", getResult?.Tags[0]); + Assert.Equal("t2", getResult?.Tags[1]); + Assert.Equal("This is a great hotel.", getResult?.Description); + if (withEmbeddings) + { + Assert.NotNull(getResult?.DescriptionEmbedding); + } + else + { + Assert.Null(getResult?.DescriptionEmbedding); + } + + // Output. + output.WriteLine(getResult?.ToString()); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ItCanGetDocumentWithGuidIdFromVectorStoreAsync(bool useRecordDefinition, bool withEmbeddings) + { + // Arrange. + var options = new QdrantVectorStoreRecordCollectionOptions + { + HasNamedVectors = false, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelWithGuidIdVectorStoreRecordDefinition : null + }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorGuidIdHotels", options); + + // Act. 
+ var getResult = await sut.GetAsync(Guid.Parse("11111111-1111-1111-1111-111111111111"), new GetRecordOptions { IncludeVectors = withEmbeddings }); + + // Assert. + Assert.Equal(Guid.Parse("11111111-1111-1111-1111-111111111111"), getResult?.HotelId); + Assert.Equal("My Hotel 11", getResult?.HotelName); + Assert.Equal("This is a great hotel.", getResult?.Description); + if (withEmbeddings) + { + Assert.NotNull(getResult?.DescriptionEmbedding); + } + else + { + Assert.Null(getResult?.DescriptionEmbedding); + } + + // Output. + output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanGetManyDocumentsFromVectorStoreAsync() + { + // Arrange + var options = new QdrantVectorStoreRecordCollectionOptions { HasNamedVectors = true }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "namedVectorsHotels", options); + + // Act + // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. + var hotels = sut.GetBatchAsync([11, 15, 12], new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(hotels); + var hotelsList = await hotels.ToListAsync(); + Assert.Equal(2, hotelsList.Count); + + // Output + foreach (var hotel in hotelsList) + { + output.WriteLine(hotel?.ToString() ?? "Null"); + } + } + + [Theory] + [InlineData(true, "singleVectorHotels", false)] + [InlineData(false, "singleVectorHotels", false)] + [InlineData(true, "namedVectorsHotels", true)] + [InlineData(false, "namedVectorsHotels", true)] + public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefinition, string collectionName, bool hasNamedVectors) + { + // Arrange. + var options = new QdrantVectorStoreRecordCollectionOptions + { + HasNamedVectors = hasNamedVectors, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null + }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + + await sut.UpsertAsync(this.CreateTestHotel(20)); + + // Act. + await sut.DeleteAsync(20); + // Also delete a non-existing key to test that the operation does not fail for these. + await sut.DeleteAsync(21); + + // Assert. + Assert.Null(await sut.GetAsync(20)); + } + + [Theory] + [InlineData(true, "singleVectorHotels", false)] + [InlineData(false, "singleVectorHotels", false)] + [InlineData(true, "namedVectorsHotels", true)] + [InlineData(false, "namedVectorsHotels", true)] + public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync(bool useRecordDefinition, string collectionName, bool hasNamedVectors) + { + // Arrange. + var options = new QdrantVectorStoreRecordCollectionOptions + { + HasNamedVectors = hasNamedVectors, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null + }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, collectionName, options); + + await sut.UpsertAsync(this.CreateTestHotel(20)); + + // Act. + // Also delete a non-existing key to test that the operation does not fail for these. + await sut.DeleteBatchAsync([20, 21]); + + // Assert. 
+ Assert.Null(await sut.GetAsync(20)); + } + + [Fact] + public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() + { + // Arrange + var options = new QdrantVectorStoreRecordCollectionOptions { HasNamedVectors = false }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorHotels", options); + + // Act & Assert + Assert.Null(await sut.GetAsync(15, new GetRecordOptions { IncludeVectors = true })); + } + + [Fact] + public async Task ItThrowsMappingExceptionForFailedMapperAsync() + { + // Arrange + var options = new QdrantVectorStoreRecordCollectionOptions { PointStructCustomMapper = new FailingMapper() }; + var sut = new QdrantVectorStoreRecordCollection(fixture.QdrantClient, "singleVectorHotels", options); + + // Act & Assert + await Assert.ThrowsAsync(async () => await sut.GetAsync(11, new GetRecordOptions { IncludeVectors = true })); + } + + private HotelInfo CreateTestHotel(uint hotelId) + { + return new HotelInfo + { + HotelId = hotelId, + HotelName = $"My Hotel {hotelId}", + HotelCode = (int)hotelId, + HotelRating = 4.5f, + ParkingIncluded = true, + Tags = { "t1", "t2" }, + Description = "This is a great hotel.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, + }; + } + + private sealed class FailingMapper : IVectorStoreRecordMapper + { + public PointStruct MapFromDataToStorageModel(HotelInfo dataModel) + { + throw new NotImplementedException(); + } + + public HotelInfo MapFromStorageToDataModel(PointStruct storageModel, StorageToDataModelMapperOptions options) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs new file mode 100644 index 000000000000..0da44530f5c0 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreTests.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; + +[Collection("QdrantVectorStoreCollection")] +public class QdrantVectorStoreTests(ITestOutputHelper output, QdrantVectorStoreFixture fixture) +{ + [Fact] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var sut = new QdrantVectorStore(fixture.QdrantClient); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Equal(3, collectionNames.Count); + Assert.Contains("namedVectorsHotels", collectionNames); + Assert.Contains("singleVectorHotels", collectionNames); + Assert.Contains("singleVectorGuidIdHotels", collectionNames); + + // Output + output.WriteLine(string.Join(",", collectionNames)); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..b80d85551e6d --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using NRedisStack.RedisStackCommands; +using NRedisStack.Search; +using StackExchange.Redis; +using Xunit; +using Xunit.Abstractions; +using static SemanticKernel.IntegrationTests.Connectors.Memory.Redis.RedisVectorStoreFixture; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; + +/// +/// Contains tests for the class. +/// +/// Used for logging. +/// Redis setup and teardown. +[Collection("RedisVectorStoreCollection")] +public sealed class RedisHashSetVectorStoreRecordCollectionTests(ITestOutputHelper output, RedisVectorStoreFixture fixture) +{ + private const string TestCollectionName = "hashhotels"; + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange. + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, collectionName); + + // Act. + var actual = await sut.CollectionExistsAsync(); + + // Assert. + Assert.Equal(expectedExists, actual); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanCreateACollectionUpsertAndGetAsync(bool useRecordDefinition) + { + // Arrange + var record = CreateTestHotel("Upsert-1", 1); + var collectionNamePostfix = useRecordDefinition ? "WithDefinition" : "WithType"; + var testCollectionName = $"hashsetcreatetest{collectionNamePostfix}"; + + var options = new RedisHashSetVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null + }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, testCollectionName, options); + + // Act + await sut.CreateCollectionAsync(); + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync("Upsert-1", new GetRecordOptions { IncludeVectors = true }); + + // Assert + var collectionExistResult = await sut.CollectionExistsAsync(); + Assert.True(collectionExistResult); + await sut.DeleteCollectionAsync(); + + Assert.Equal("Upsert-1", upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.HotelCode, getResult?.HotelCode); + Assert.Equal(record.ParkingIncluded, getResult?.ParkingIncluded); + Assert.Equal(record.Rating, getResult?.Rating); + Assert.Equal(record.Description, getResult?.Description); + Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); + + // Output + output.WriteLine(collectionExistResult.ToString()); + output.WriteLine(upsertResult); + output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanDeleteCollectionAsync() + { + // Arrange + var tempCollectionName = "temp-test"; + var schema = new Schema(); + schema.AddTextField("HotelName"); + var createParams = new FTCreateParams(); + createParams.AddPrefix(tempCollectionName); + await fixture.Database.FT().CreateAsync(tempCollectionName, createParams, schema); + + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, tempCollectionName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + Assert.False(await sut.CollectionExistsAsync()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task 
ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition) + { + // Arrange. + var options = new RedisHashSetVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null + }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var record = CreateTestHotel("Upsert-2", 2); + + // Act. + var upsertResult = await sut.UpsertAsync(record); + + // Assert. + var getResult = await sut.GetAsync("Upsert-2", new GetRecordOptions { IncludeVectors = true }); + Assert.Equal("Upsert-2", upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.HotelCode, getResult?.HotelCode); + Assert.Equal(record.ParkingIncluded, getResult?.ParkingIncluded); + Assert.Equal(record.Rating, getResult?.Rating); + Assert.Equal(record.Description, getResult?.Description); + Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); + + // Output. + output.WriteLine(upsertResult); + output.WriteLine(getResult?.ToString()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanUpsertManyDocumentsToVectorStoreAsync(bool useRecordDefinition) + { + // Arrange. + var options = new RedisHashSetVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null + }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act. + var results = sut.UpsertBatchAsync( + [ + CreateTestHotel("UpsertMany-1", 1), + CreateTestHotel("UpsertMany-2", 2), + CreateTestHotel("UpsertMany-3", 3), + ]); + + // Assert. + Assert.NotNull(results); + var resultsList = await results.ToListAsync(); + + Assert.Equal(3, resultsList.Count); + Assert.Contains("UpsertMany-1", resultsList); + Assert.Contains("UpsertMany-2", resultsList); + Assert.Contains("UpsertMany-3", resultsList); + + // Output + foreach (var result in resultsList) + { + output.WriteLine(result); + } + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool useRecordDefinition) + { + // Arrange. + var options = new RedisHashSetVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null + }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act. + var getResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = includeVectors }); + + // Assert. + Assert.Equal("BaseSet-1", getResult?.HotelId); + Assert.Equal("My Hotel 1", getResult?.HotelName); + Assert.Equal(1, getResult?.HotelCode); + Assert.True(getResult?.ParkingIncluded); + Assert.Equal(3.6, getResult?.Rating); + Assert.Equal("This is a great hotel.", getResult?.Description); + if (includeVectors) + { + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, getResult?.DescriptionEmbedding?.ToArray()); + } + else + { + Assert.Null(getResult?.DescriptionEmbedding); + } + + // Output. 
+ output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanGetManyDocumentsFromVectorStoreAsync() + { + // Arrange + var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act + // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. + var hotels = sut.GetBatchAsync(["BaseSet-1", "BaseSet-5", "BaseSet-2"], new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(hotels); + var hotelsList = await hotels.ToListAsync(); + Assert.Equal(2, hotelsList.Count); + + // Output + foreach (var hotel in hotelsList) + { + output.WriteLine(hotel?.ToString() ?? "Null"); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefinition) + { + // Arrange. + var options = new RedisHashSetVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.BasicVectorStoreRecordDefinition : null + }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var record = new BasicHotel + { + HotelId = "Remove-1", + HotelName = "Remove Test Hotel", + HotelCode = 20, + Description = "This is a great hotel.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } + }; + + await sut.UpsertAsync(record); + + // Act. + await sut.DeleteAsync("Remove-1"); + // Also delete a non-existing key to test that the operation does not fail for these. + await sut.DeleteAsync("Remove-2"); + + // Assert. + Assert.Null(await sut.GetAsync("Remove-1")); + } + + [Fact] + public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() + { + // Arrange + var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-1", 1)); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-2", 2)); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-3", 3)); + + // Act + // Also include a non-existing key to test that the operation does not fail for these. 
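+        // "RemoveMany-4" was never upserted above, so the batch delete is expected to simply skip it.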
+ await sut.DeleteBatchAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); + + // Assert + Assert.Null(await sut.GetAsync("RemoveMany-1", new GetRecordOptions { IncludeVectors = true })); + Assert.Null(await sut.GetAsync("RemoveMany-2", new GetRecordOptions { IncludeVectors = true })); + Assert.Null(await sut.GetAsync("RemoveMany-3", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact] + public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() + { + // Arrange + var options = new RedisHashSetVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act & Assert + Assert.Null(await sut.GetAsync("BaseSet-5", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact] + public async Task ItThrowsMappingExceptionForFailedMapperAsync() + { + // Arrange + var options = new RedisHashSetVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + HashEntriesCustomMapper = new FailingMapper() + }; + var sut = new RedisHashSetVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act & Assert + await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); + } + + private static BasicHotel CreateTestHotel(string hotelId, int hotelCode) + { + var record = new BasicHotel + { + HotelId = hotelId, + HotelName = $"My Hotel {hotelCode}", + HotelCode = 1, + ParkingIncluded = true, + Rating = 3.6, + Description = "This is a great hotel.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } + }; + return record; + } + + private sealed class FailingMapper : IVectorStoreRecordMapper + { + public (string Key, HashEntry[] HashEntries) MapFromDataToStorageModel(BasicHotel dataModel) + { + throw new NotImplementedException(); + } + + public BasicHotel MapFromStorageToDataModel((string Key, HashEntry[] HashEntries) storageModel, StorageToDataModelMapperOptions options) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..9e7a2fde0561 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs @@ -0,0 +1,371 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Text.Json.Nodes; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Redis; +using Microsoft.SemanticKernel.Data; +using NRedisStack.RedisStackCommands; +using NRedisStack.Search; +using Xunit; +using Xunit.Abstractions; +using static SemanticKernel.IntegrationTests.Connectors.Memory.Redis.RedisVectorStoreFixture; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; + +/// +/// Contains tests for the class. +/// +/// Used for logging. +/// Redis setup and teardown. 
+[Collection("RedisVectorStoreCollection")] +public sealed class RedisJsonVectorStoreRecordCollectionTests(ITestOutputHelper output, RedisVectorStoreFixture fixture) +{ + private const string TestCollectionName = "jsonhotels"; + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange. + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, collectionName); + + // Act. + var actual = await sut.CollectionExistsAsync(); + + // Assert. + Assert.Equal(expectedExists, actual); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanCreateACollectionUpsertAndGetAsync(bool useRecordDefinition) + { + // Arrange + var record = CreateTestHotel("Upsert-1", 1); + var collectionNamePostfix = useRecordDefinition ? "WithDefinition" : "WithType"; + var testCollectionName = $"jsoncreatetest{collectionNamePostfix}"; + + var options = new RedisJsonVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null + }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, testCollectionName, options); + + // Act + await sut.CreateCollectionAsync(); + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync("Upsert-1", new GetRecordOptions { IncludeVectors = true }); + + // Assert + var collectionExistResult = await sut.CollectionExistsAsync(); + Assert.True(collectionExistResult); + await sut.DeleteCollectionAsync(); + + Assert.Equal("Upsert-1", upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.HotelCode, getResult?.HotelCode); + Assert.Equal(record.Tags, getResult?.Tags); + Assert.Equal(record.FTSTags, getResult?.FTSTags); + Assert.Equal(record.ParkingIncluded, getResult?.ParkingIncluded); + Assert.Equal(record.LastRenovationDate, getResult?.LastRenovationDate); + Assert.Equal(record.Rating, getResult?.Rating); + Assert.Equal(record.Address.Country, getResult?.Address.Country); + Assert.Equal(record.Address.City, getResult?.Address.City); + Assert.Equal(record.Description, getResult?.Description); + Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); + + // Output + output.WriteLine(collectionExistResult.ToString()); + output.WriteLine(upsertResult); + output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanDeleteCollectionAsync() + { + // Arrange + var tempCollectionName = "temp-test"; + var schema = new Schema(); + schema.AddTextField("HotelName"); + var createParams = new FTCreateParams(); + createParams.AddPrefix(tempCollectionName); + await fixture.Database.FT().CreateAsync(tempCollectionName, createParams, schema); + + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, tempCollectionName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + Assert.False(await sut.CollectionExistsAsync()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanUpsertDocumentToVectorStoreAsync(bool useRecordDefinition) + { + // Arrange. + var options = new RedisJsonVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? 
fixture.VectorStoreRecordDefinition : null + }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + Hotel record = CreateTestHotel("Upsert-2", 2); + + // Act. + var upsertResult = await sut.UpsertAsync(record); + + // Assert. + var getResult = await sut.GetAsync("Upsert-2", new GetRecordOptions { IncludeVectors = true }); + Assert.Equal("Upsert-2", upsertResult); + Assert.Equal(record.HotelId, getResult?.HotelId); + Assert.Equal(record.HotelName, getResult?.HotelName); + Assert.Equal(record.HotelCode, getResult?.HotelCode); + Assert.Equal(record.Tags, getResult?.Tags); + Assert.Equal(record.FTSTags, getResult?.FTSTags); + Assert.Equal(record.ParkingIncluded, getResult?.ParkingIncluded); + Assert.Equal(record.LastRenovationDate, getResult?.LastRenovationDate); + Assert.Equal(record.Rating, getResult?.Rating); + Assert.Equal(record.Address.Country, getResult?.Address.Country); + Assert.Equal(record.Address.City, getResult?.Address.City); + Assert.Equal(record.Description, getResult?.Description); + Assert.Equal(record.DescriptionEmbedding?.ToArray(), getResult?.DescriptionEmbedding?.ToArray()); + + // Output. + output.WriteLine(upsertResult); + output.WriteLine(getResult?.ToString()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanUpsertManyDocumentsToVectorStoreAsync(bool useRecordDefinition) + { + // Arrange. + var options = new RedisJsonVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null + }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act. + var results = sut.UpsertBatchAsync( + [ + CreateTestHotel("UpsertMany-1", 1), + CreateTestHotel("UpsertMany-2", 2), + CreateTestHotel("UpsertMany-3", 3), + ]); + + // Assert. + Assert.NotNull(results); + var resultsList = await results.ToListAsync(); + + Assert.Equal(3, resultsList.Count); + Assert.Contains("UpsertMany-1", resultsList); + Assert.Contains("UpsertMany-2", resultsList); + Assert.Contains("UpsertMany-3", resultsList); + + // Output + foreach (var result in resultsList) + { + output.WriteLine(result); + } + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ItCanGetDocumentFromVectorStoreAsync(bool includeVectors, bool useRecordDefinition) + { + // Arrange. + var options = new RedisJsonVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null + }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act. + var getResult = await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = includeVectors }); + + // Assert. 
+ Assert.Equal("BaseSet-1", getResult?.HotelId); + Assert.Equal("My Hotel 1", getResult?.HotelName); + Assert.Equal(1, getResult?.HotelCode); + Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, getResult?.Tags); + Assert.Equal(new[] { "pool", "air conditioning", "concierge" }, getResult?.FTSTags); + Assert.True(getResult?.ParkingIncluded); + Assert.Equal(new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), getResult?.LastRenovationDate); + Assert.Equal(3.6, getResult?.Rating); + Assert.Equal("Seattle", getResult?.Address.City); + Assert.Equal("This is a great hotel.", getResult?.Description); + if (includeVectors) + { + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, getResult?.DescriptionEmbedding?.ToArray()); + } + else + { + Assert.Null(getResult?.DescriptionEmbedding); + } + + // Output. + output.WriteLine(getResult?.ToString()); + } + + [Fact] + public async Task ItCanGetManyDocumentsFromVectorStoreAsync() + { + // Arrange + var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act + // Also include one non-existing key to test that the operation does not fail for these and returns only the found ones. + var hotels = sut.GetBatchAsync(["BaseSet-1", "BaseSet-5", "BaseSet-2"], new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(hotels); + var hotelsList = await hotels.ToListAsync(); + Assert.Equal(2, hotelsList.Count); + + // Output + foreach (var hotel in hotelsList) + { + output.WriteLine(hotel?.ToString() ?? "Null"); + } + } + + [Fact] + public async Task ItFailsToGetDocumentsWithInvalidSchemaAsync() + { + // Arrange. + var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act & Assert. + await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-4-Invalid", new GetRecordOptions { IncludeVectors = true })); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanRemoveDocumentFromVectorStoreAsync(bool useRecordDefinition) + { + // Arrange. + var options = new RedisJsonVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + VectorStoreRecordDefinition = useRecordDefinition ? fixture.VectorStoreRecordDefinition : null + }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + var address = new HotelAddress { City = "Seattle", Country = "USA" }; + var record = new Hotel + { + HotelId = "Remove-1", + HotelName = "Remove Test Hotel", + HotelCode = 20, + Description = "This is a great hotel.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } + }; + + await sut.UpsertAsync(record); + + // Act. + await sut.DeleteAsync("Remove-1"); + // Also delete a non-existing key to test that the operation does not fail for these. + await sut.DeleteAsync("Remove-2"); + + // Assert. 
+ Assert.Null(await sut.GetAsync("Remove-1")); + } + + [Fact] + public async Task ItCanRemoveManyDocumentsFromVectorStoreAsync() + { + // Arrange + var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-1", 1)); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-2", 2)); + await sut.UpsertAsync(CreateTestHotel("RemoveMany-3", 3)); + + // Act + // Also include a non-existing key to test that the operation does not fail for these. + await sut.DeleteBatchAsync(["RemoveMany-1", "RemoveMany-2", "RemoveMany-3", "RemoveMany-4"]); + + // Assert + Assert.Null(await sut.GetAsync("RemoveMany-1", new GetRecordOptions { IncludeVectors = true })); + Assert.Null(await sut.GetAsync("RemoveMany-2", new GetRecordOptions { IncludeVectors = true })); + Assert.Null(await sut.GetAsync("RemoveMany-3", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact] + public async Task ItReturnsNullWhenGettingNonExistentRecordAsync() + { + // Arrange + var options = new RedisJsonVectorStoreRecordCollectionOptions { PrefixCollectionNameToKeyNames = true }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act & Assert + Assert.Null(await sut.GetAsync("BaseSet-5", new GetRecordOptions { IncludeVectors = true })); + } + + [Fact] + public async Task ItThrowsMappingExceptionForFailedMapperAsync() + { + // Arrange + var options = new RedisJsonVectorStoreRecordCollectionOptions + { + PrefixCollectionNameToKeyNames = true, + JsonNodeCustomMapper = new FailingMapper() + }; + var sut = new RedisJsonVectorStoreRecordCollection(fixture.Database, TestCollectionName, options); + + // Act & Assert + await Assert.ThrowsAsync(async () => await sut.GetAsync("BaseSet-1", new GetRecordOptions { IncludeVectors = true })); + } + + private static Hotel CreateTestHotel(string hotelId, int hotelCode) + { + var address = new HotelAddress { City = "Seattle", Country = "USA" }; + var record = new Hotel + { + HotelId = hotelId, + HotelName = $"My Hotel {hotelCode}", + HotelCode = 1, + Tags = ["pool", "air conditioning", "concierge"], + FTSTags = ["pool", "air conditioning", "concierge"], + ParkingIncluded = true, + LastRenovationDate = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + Rating = 3.6, + Address = address, + Description = "This is a great hotel.", + DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } + }; + return record; + } + + private sealed class FailingMapper : IVectorStoreRecordMapper + { + public (string Key, JsonNode Node) MapFromDataToStorageModel(Hotel dataModel) + { + throw new NotImplementedException(); + } + + public Hotel MapFromStorageToDataModel((string Key, JsonNode Node) storageModel, StorageToDataModelMapperOptions options) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreCollectionFixture.cs new file mode 100644 index 000000000000..1bebd51d8f5f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreCollectionFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. 
+
+using Xunit;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis;
+
+[CollectionDefinition("RedisVectorStoreCollection")]
+public class RedisVectorStoreCollectionFixture : ICollectionFixture<RedisVectorStoreFixture>
+{
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs
new file mode 100644
index 000000000000..3256cae3e79e
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreFixture.cs
@@ -0,0 +1,300 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text.Json.Serialization;
+using System.Threading.Tasks;
+using Docker.DotNet;
+using Docker.DotNet.Models;
+using Microsoft.SemanticKernel.Data;
+using NRedisStack.RedisStackCommands;
+using NRedisStack.Search;
+using NRedisStack.Search.Literals.Enums;
+using StackExchange.Redis;
+using Xunit;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis;
+
+#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
+/// <summary>
+/// Does setup and teardown of redis docker container and associated test data.
+/// </summary>
+public class RedisVectorStoreFixture : IAsyncLifetime
+{
+    /// <summary>The docker client we are using to create a redis container with.</summary>
+    private readonly DockerClient _client;
+
+    /// <summary>The id of the redis container that we are testing with.</summary>
+    private string? _containerId = null;
+
+    /// <summary>
+    /// Initializes a new instance of the <see cref="RedisVectorStoreFixture"/> class.
+    /// </summary>
+    public RedisVectorStoreFixture()
+    {
+        using var dockerClientConfiguration = new DockerClientConfiguration();
+        this._client = dockerClientConfiguration.CreateClient();
+        this.VectorStoreRecordDefinition = new VectorStoreRecordDefinition
+        {
+            Properties = new List<VectorStoreRecordProperty>
+            {
+                new VectorStoreRecordKeyProperty("HotelId", typeof(string)),
+                new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true },
+                new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true },
+                new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextSearchable = true },
+                new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory<float>?)) { Dimensions = 4 },
+                new VectorStoreRecordDataProperty("Tags", typeof(string[])) { IsFilterable = true },
+                new VectorStoreRecordDataProperty("FTSTags", typeof(string[])) { IsFullTextSearchable = true },
+                new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" },
+                new VectorStoreRecordDataProperty("LastRenovationDate", typeof(DateTimeOffset)),
+                new VectorStoreRecordDataProperty("Rating", typeof(double)),
+                new VectorStoreRecordDataProperty("Address", typeof(HotelAddress))
+            }
+        };
+        this.BasicVectorStoreRecordDefinition = new VectorStoreRecordDefinition
+        {
+            Properties = new List<VectorStoreRecordProperty>
+            {
+                new VectorStoreRecordKeyProperty("HotelId", typeof(string)),
+                new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true },
+                new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true },
+                new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextSearchable = true },
+                new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory<float>?)) { Dimensions = 4 },
+                new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName =
"parking_is_included" }, + new VectorStoreRecordDataProperty("Rating", typeof(double)), + } + }; + } + + /// Gets the redis database connection to use for tests. + public IDatabase Database { get; private set; } + + /// Gets the manually created vector store record definition for our test model. + public VectorStoreRecordDefinition VectorStoreRecordDefinition { get; private set; } + + /// Gets the manually created vector store record definition for our basic test model. + public VectorStoreRecordDefinition BasicVectorStoreRecordDefinition { get; private set; } + + /// + /// Create / Recreate redis docker container, create an index and add test data. + /// + /// An async task. + public async Task InitializeAsync() + { + this._containerId = await SetupRedisContainerAsync(this._client); + + // Connect to redis. + ConnectionMultiplexer redis = ConnectionMultiplexer.Connect("localhost:6379,connectTimeout=60000,connectRetry=5"); + this.Database = redis.GetDatabase(); + + // Create a schema for the vector store. + var schema = new Schema(); + schema.AddTextField(new FieldName("$.HotelName", "HotelName")); + schema.AddNumericField(new FieldName("$.HotelCode", "HotelCode")); + schema.AddTextField(new FieldName("$.Description", "Description")); + schema.AddVectorField(new FieldName("$.DescriptionEmbedding", "DescriptionEmbedding"), Schema.VectorField.VectorAlgo.HNSW, new Dictionary() + { + ["TYPE"] = "FLOAT32", + ["DIM"] = "4", + ["DISTANCE_METRIC"] = "L2" + }); + var jsonCreateParams = new FTCreateParams().AddPrefix("jsonhotels:").On(IndexDataType.JSON); + await this.Database.FT().CreateAsync("jsonhotels", jsonCreateParams, schema); + + // Create a hashset index. + var hashsetCreateParams = new FTCreateParams().AddPrefix("hashhotels:").On(IndexDataType.HASH); + await this.Database.FT().CreateAsync("hashhotels", hashsetCreateParams, schema); + + // Create some test data. + var address = new HotelAddress { City = "Seattle", Country = "USA" }; + var embedding = new[] { 30f, 31f, 32f, 33f }; + + // Add JSON test data. + await this.Database.JSON().SetAsync("jsonhotels:BaseSet-1", "$", new + { + HotelName = "My Hotel 1", + HotelCode = 1, + Description = "This is a great hotel.", + DescriptionEmbedding = embedding, + Tags = new[] { "pool", "air conditioning", "concierge" }, + FTSTags = new[] { "pool", "air conditioning", "concierge" }, + parking_is_included = true, + LastRenovationDate = new DateTimeOffset(1970, 1, 18, 0, 0, 0, TimeSpan.Zero), + Rating = 3.6, + Address = address + }); + await this.Database.JSON().SetAsync("jsonhotels:BaseSet-2", "$", new { HotelName = "My Hotel 2", HotelCode = 2, Description = "This is a great hotel.", DescriptionEmbedding = embedding, parking_is_included = false }); + await this.Database.JSON().SetAsync("jsonhotels:BaseSet-3", "$", new { HotelName = "My Hotel 3", HotelCode = 3, Description = "This is a great hotel.", DescriptionEmbedding = embedding, parking_is_included = false }); + await this.Database.JSON().SetAsync("jsonhotels:BaseSet-4-Invalid", "$", new { HotelId = "AnotherId", HotelName = "My Invalid Hotel", HotelCode = 4, Description = "This is an invalid hotel.", DescriptionEmbedding = embedding, parking_is_included = false }); + + // Add hashset test data. 
+ await this.Database.HashSetAsync("hashhotels:BaseSet-1", new HashEntry[] + { + new("HotelName", "My Hotel 1"), + new("HotelCode", 1), + new("Description", "This is a great hotel."), + new("DescriptionEmbedding", MemoryMarshal.AsBytes(new ReadOnlySpan(embedding)).ToArray()), + new("parking_is_included", true), + new("Rating", 3.6) + }); + await this.Database.HashSetAsync("hashhotels:BaseSet-2", new HashEntry[] + { + new("HotelName", "My Hotel 2"), + new("HotelCode", 2), + new("Description", "This is a great hotel."), + new("DescriptionEmbedding", MemoryMarshal.AsBytes(new ReadOnlySpan(embedding)).ToArray()), + new("parking_is_included", false), + }); + await this.Database.HashSetAsync("hashhotels:BaseSet-3", new HashEntry[] + { + new("HotelName", "My Hotel 3"), + new("HotelCode", 3), + new("Description", "This is a great hotel."), + new("DescriptionEmbedding", MemoryMarshal.AsBytes(new ReadOnlySpan(embedding)).ToArray()), + new("parking_is_included", false), + }); + await this.Database.HashSetAsync("hashhotels:BaseSet-4-Invalid", new HashEntry[] + { + new("HotelId", "AnotherId"), + new("HotelName", "My Invalid Hotel"), + new("HotelCode", 4), + new("Description", "This is an invalid hotel."), + new("DescriptionEmbedding", MemoryMarshal.AsBytes(new ReadOnlySpan(embedding)).ToArray()), + new("parking_is_included", false), + }); + } + + /// + /// Delete the docker container after the test run. + /// + /// An async task. + public async Task DisposeAsync() + { + if (this._containerId != null) + { + await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); + await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); + } + } + + /// + /// Setup the redis container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + private static async Task SetupRedisContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "redis/redis-stack", + Tag = "latest", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "redis/redis-stack", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"6379", new List {new() {HostPort = "6379"}}} + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "6379", default } + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + /// + /// A test model for the vector store that has complex properties as supported by JSON redis mode. + /// + public class Hotel + { + [VectorStoreRecordKey] + public string HotelId { get; init; } + + [VectorStoreRecordData(IsFilterable = true)] + public string HotelName { get; init; } + + [VectorStoreRecordData(IsFilterable = true)] + public int HotelCode { get; init; } + + [VectorStoreRecordData(IsFullTextSearchable = true)] + public string Description { get; init; } + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? 
DescriptionEmbedding { get; init; } + +#pragma warning disable CA1819 // Properties should not return arrays + [VectorStoreRecordData(IsFilterable = true)] + public string[] Tags { get; init; } + + [VectorStoreRecordData(IsFullTextSearchable = true)] + public string[] FTSTags { get; init; } +#pragma warning restore CA1819 // Properties should not return arrays + + [JsonPropertyName("parking_is_included")] + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; init; } + + [VectorStoreRecordData] + public DateTimeOffset LastRenovationDate { get; init; } + + [VectorStoreRecordData] + public double Rating { get; init; } + + [VectorStoreRecordData] + public HotelAddress Address { get; init; } + } + + /// + /// A test model for the vector store to simulate a complex type. + /// + public class HotelAddress + { + public string City { get; init; } + public string Country { get; init; } + } + + /// + /// A test model for the vector store that only uses basic types as supported by HashSets Redis mode. + /// + public class BasicHotel + { + [VectorStoreRecordKey] + public string HotelId { get; init; } + + [VectorStoreRecordData(IsFilterable = true)] + public string HotelName { get; init; } + + [VectorStoreRecordData(IsFilterable = true)] + public int HotelCode { get; init; } + + [VectorStoreRecordData(IsFullTextSearchable = true)] + public string Description { get; init; } + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? DescriptionEmbedding { get; init; } + + [JsonPropertyName("parking_is_included")] + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; init; } + + [VectorStoreRecordData] + public double Rating { get; init; } + } +} +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreTests.cs new file mode 100644 index 000000000000..a6bda9559480 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisVectorStoreTests.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Redis; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; + +/// +/// Contains tests for the class. +/// +/// Used to write to the test output stream. +/// The test fixture. 
+[Collection("RedisVectorStoreCollection")] +public class RedisVectorStoreTests(ITestOutputHelper output, RedisVectorStoreFixture fixture) +{ + [Fact] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var sut = new RedisVectorStore(fixture.Database); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Equal(2, collectionNames.Count); + Assert.Contains("jsonhotels", collectionNames); + Assert.Contains("hashhotels", collectionNames); + + // Output + output.WriteLine(string.Join(",", collectionNames)); + } +} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index df5afa473ce7..55a6ac6d1006 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -44,6 +44,7 @@ + @@ -59,13 +60,17 @@ + + + + diff --git a/dotnet/src/IntegrationTests/TestSettings/Memory/AzureAISearchConfiguration.cs b/dotnet/src/IntegrationTests/TestSettings/Memory/AzureAISearchConfiguration.cs new file mode 100644 index 000000000000..fd4043ef9b83 --- /dev/null +++ b/dotnet/src/IntegrationTests/TestSettings/Memory/AzureAISearchConfiguration.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace SemanticKernel.IntegrationTests.TestSettings.Memory; + +[SuppressMessage("Design", "CA1054:URI-like parameters should not be strings", Justification = "This is just for test configuration")] +public sealed class AzureAISearchConfiguration(string serviceUrl, string apiKey) +{ + [SuppressMessage("Design", "CA1056:URI-like properties should not be strings", Justification = "This is just for test configuration")] + public string ServiceUrl { get; set; } = serviceUrl; + + public string ApiKey { get; set; } = apiKey; +} diff --git a/dotnet/src/IntegrationTests/TestSettings/Memory/AzureAISearchSetup.psm1 b/dotnet/src/IntegrationTests/TestSettings/Memory/AzureAISearchSetup.psm1 new file mode 100644 index 000000000000..64563abdeeb0 --- /dev/null +++ b/dotnet/src/IntegrationTests/TestSettings/Memory/AzureAISearchSetup.psm1 @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft. All rights reserved. + +# This module requires powershell 7 and the Az and Az.Search modules. You may need to import Az and install Az.Search. +# Import-Module -Name Az +# Install-Module -Name Az.Search + +# Before running any of the functions you will need to connect to your azure account and pick the appropriate subscription. +# Connect-AzAccount +# Select-AzSubscription -SubscriptionName "My Dev Subscription" + +$resourceGroup = "sk-integration-test-infra" +$aiSearchResourceName = "aisearch-integration-test-basic" + +<# +.SYNOPSIS + Setup the infra required for Azure AI Search Integration tests, + retrieve the connection information for it, and update the secrets + store with these settings. + +.Parameter OverrideResourceGroup + Optional override resource group name if the default doesn't work. + +.Parameter OverrideAISearchResourceName + Optional override ai search resource name if the default doesn't work. +#> +function New-AzureAISearchIntegrationInfra($overrideResourceGroup = $resourceGroup, $overrideAISearchResourceName = $aiSearchResourceName) { + # Create the resource group if it doesn't exist. 
+ Get-AzResourceGroup -Name $overrideResourceGroup -ErrorVariable notPresent -ErrorAction SilentlyContinue + if ($notPresent) { + Write-Host "Resource Group does not exist, creating '$overrideResourceGroup' ..." + New-AzResourceGroup -Name $overrideResourceGroup -Location "North Europe" + } + + # Create the ai search service if it doesn't exist. + $service = Get-AzSearchService -ResourceGroupName $resourceGroup -Name $aiSearchResourceName + if (-not $service) { + Write-Host "Service does not exist, creating '$overrideAISearchResourceName' ..." + New-AzSearchService -ResourceGroupName $overrideResourceGroup -Name $overrideAISearchResourceName -Sku "Basic" -Location "North Europe" -PartitionCount 1 -ReplicaCount 1 -HostingMode Default + } + + # Set the required local secrets. + Set-AzureAISearchIntegrationInfraUserSecrets -OverrideResourceGroup $overrideResourceGroup -OverrideAISearchResourceName $overrideAISearchResourceName +} + +<# +.SYNOPSIS + Set the user secrets required to run the Azure AI Search integration tests. + +.Parameter OverrideResourceGroup + Optional override resource group name if the default doesn't work. + +.Parameter OverrideAISearchResourceName + Optional override ai search resource name if the default doesn't work. +#> +function Set-AzureAISearchIntegrationInfraUserSecrets($overrideResourceGroup = $resourceGroup, $overrideAISearchResourceName = $aiSearchResourceName) { + # Set the required local secrets. + $keys = Get-AzSearchAdminKeyPair -ResourceGroupName $overrideResourceGroup -ServiceName $overrideAISearchResourceName + dotnet user-secrets set "AzureAISearch:ServiceUrl" "https://$overrideAISearchResourceName.search.windows.net" --project ../../IntegrationTests.csproj + dotnet user-secrets set "AzureAISearch:ApiKey" $keys.Primary --project ../../IntegrationTests.csproj +} + +<# +.SYNOPSIS + Tear down the infra required for Azure AI Search Integration tests. + +.Parameter OverrideResourceGroup + Optional override resource group name if the default doesn't work. + +.Parameter OverrideAISearchResourceName + Optional override ai search resource name if the default doesn't work. +#> +function Remove-AzureAISearchIntegrationInfra($overrideResourceGroup = $resourceGroup, $overrideAISearchResourceName = $aiSearchResourceName) { + Remove-AzSearchService -ResourceGroupName $overrideResourceGroup -Name $overrideAISearchResourceName +} \ No newline at end of file diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs new file mode 100644 index 000000000000..1aa2e6f479ad --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreErrorHandler.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Contains helpers for reading vector store model properties and their attributes. +/// +[ExcludeFromCodeCoverage] +internal static class VectorStoreErrorHandler +{ + /// + /// Run the given model conversion and wrap any exceptions with . + /// + /// The response type of the operation. + /// The name of the database system the operation is being run on. + /// The name of the collection the operation is being run on. + /// The type of database operation being run. + /// The operation to run. + /// The result of the operation. 
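+    /// <exception cref="VectorStoreRecordMappingException">Thrown if <paramref name="operation"/> fails; the original exception is attached as the inner exception.</exception>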
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    public static T RunModelConversion<T>(string databaseSystemName, string collectionName, string operationName, Func<T> operation)
+    {
+        try
+        {
+            return operation.Invoke();
+        }
+        catch (Exception ex)
+        {
+            throw new VectorStoreRecordMappingException("Failed to convert vector store record.", ex)
+            {
+                VectorStoreType = databaseSystemName,
+                CollectionName = collectionName,
+                OperationName = operationName
+            };
+        }
+    }
+}
diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs
new file mode 100644
index 000000000000..d4f06071f66b
--- /dev/null
+++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs
@@ -0,0 +1,532 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Reflection;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.SemanticKernel.Data;
+
+/// <summary>
+/// Contains helpers for reading vector store model properties and their attributes.
+/// </summary>
+[ExcludeFromCodeCoverage]
+internal static class VectorStoreRecordPropertyReader
+{
+    /// <summary>Cache of property enumerations so that we don't incur reflection costs with each invocation.</summary>
+    private static readonly ConcurrentDictionary<Type, (PropertyInfo keyProperty, List<PropertyInfo> dataProperties, List<PropertyInfo> vectorProperties)> s_singleVectorPropertiesCache = new();
+
+    /// <summary>Cache of property enumerations so that we don't incur reflection costs with each invocation.</summary>
+    private static readonly ConcurrentDictionary<Type, (PropertyInfo keyProperty, List<PropertyInfo> dataProperties, List<PropertyInfo> vectorProperties)> s_multipleVectorsPropertiesCache = new();
+
+    /// <summary>
+    /// Split the given <paramref name="definition"/> into key, data and vector properties and verify that we have the expected numbers of each type.
+    /// </summary>
+    /// <param name="typeName">The name of the type that the definition relates to.</param>
+    /// <param name="definition">The <see cref="VectorStoreRecordDefinition"/> to split.</param>
+    /// <param name="supportsMultipleVectors">A value indicating whether multiple vectors are supported.</param>
+    /// <param name="requiresAtLeastOneVector">A value indicating whether we need at least one vector.</param>
+    /// <returns>The properties on the <see cref="VectorStoreRecordDefinition"/> split into key, data and vector groupings.</returns>
+    /// <exception cref="ArgumentException">Thrown if there are any validation failures with the provided <paramref name="definition"/>.</exception>
+ public static (VectorStoreRecordKeyProperty KeyProperty, List DataProperties, List VectorProperties) SplitDefinitionAndVerify( + string typeName, + VectorStoreRecordDefinition definition, + bool supportsMultipleVectors, + bool requiresAtLeastOneVector) + { + var keyProperties = definition.Properties.OfType().ToList(); + + if (keyProperties.Count > 1) + { + throw new ArgumentException($"Multiple key properties found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)}."); + } + + var keyProperty = keyProperties.FirstOrDefault(); + var dataProperties = definition.Properties.OfType().ToList(); + var vectorProperties = definition.Properties.OfType().ToList(); + + if (keyProperty is null) + { + throw new ArgumentException($"No key property found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)}."); + } + + if (requiresAtLeastOneVector && vectorProperties.Count == 0) + { + throw new ArgumentException($"No vector property found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)}."); + } + + if (!supportsMultipleVectors && vectorProperties.Count > 1) + { + throw new ArgumentException($"Multiple vector properties found on type {typeName} or the provided {nameof(VectorStoreRecordDefinition)} while only one is supported."); + } + + return (keyProperty, dataProperties, vectorProperties); + } + + /// + /// Find the properties with , and attributes + /// and verify that they exist and that we have the expected numbers of each type. + /// Return those properties in separate categories. + /// + /// The data model to find the properties on. + /// A value indicating whether multiple vector properties are supported instead of just one. + /// The categorized properties. + public static (PropertyInfo KeyProperty, List DataProperties, List VectorProperties) FindProperties(Type type, bool supportsMultipleVectors) + { + var cache = supportsMultipleVectors ? s_multipleVectorsPropertiesCache : s_singleVectorPropertiesCache; + + // First check the cache. + if (cache.TryGetValue(type, out var cachedProperties)) + { + return cachedProperties; + } + + PropertyInfo? keyProperty = null; + List dataProperties = new(); + List vectorProperties = new(); + bool singleVectorPropertyFound = false; + + foreach (var property in type.GetProperties()) + { + // Get Key property. + if (property.GetCustomAttribute() is not null) + { + if (keyProperty is not null) + { + throw new ArgumentException($"Multiple key properties found on type {type.FullName}."); + } + + keyProperty = property; + } + + // Get data properties. + if (property.GetCustomAttribute() is not null) + { + dataProperties.Add(property); + } + + // Get Vector properties. + if (property.GetCustomAttribute() is not null) + { + // Add all vector properties if we support multiple vectors. + if (supportsMultipleVectors) + { + vectorProperties.Add(property); + } + // Add only one vector property if we don't support multiple vectors. + else if (!singleVectorPropertyFound) + { + vectorProperties.Add(property); + singleVectorPropertyFound = true; + } + else + { + throw new ArgumentException($"Multiple vector properties found on type {type.FullName} while only one is supported."); + } + } + } + + // Check that we have a key property. + if (keyProperty is null) + { + throw new ArgumentException($"No key property found on type {type.FullName}."); + } + + // Check that we have one vector property if we don't have named vectors. 
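+        // (When multiple/named vectors are supported, a model with zero vector properties is allowed here.)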
+ if (!supportsMultipleVectors && !singleVectorPropertyFound) + { + throw new ArgumentException($"No vector property found on type {type.FullName}."); + } + + // Update the cache. + cache[type] = (keyProperty, dataProperties, vectorProperties); + + return (keyProperty, dataProperties, vectorProperties); + } + + /// + /// Find the properties listed in the on the and verify + /// that they exist and that we have the expected numbers of each type. + /// Return those properties in separate categories. + /// + /// The data model to find the properties on. + /// The property configuration. + /// A value indicating whether multiple vector properties are supported instead of just one. + /// The categorized properties. + public static (PropertyInfo KeyProperty, List DataProperties, List VectorProperties) FindProperties(Type type, VectorStoreRecordDefinition vectorStoreRecordDefinition, bool supportsMultipleVectors) + { + PropertyInfo? keyProperty = null; + List dataProperties = new(); + List vectorProperties = new(); + bool singleVectorPropertyFound = false; + + foreach (VectorStoreRecordProperty property in vectorStoreRecordDefinition.Properties) + { + // Key. + if (property is VectorStoreRecordKeyProperty keyPropertyInfo) + { + if (keyProperty is not null) + { + throw new ArgumentException($"Multiple key properties configured for type {type.FullName}."); + } + + keyProperty = type.GetProperty(keyPropertyInfo.DataModelPropertyName); + if (keyProperty == null) + { + throw new ArgumentException($"Key property '{keyPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}."); + } + } + // Data. + else if (property is VectorStoreRecordDataProperty dataPropertyInfo) + { + var dataProperty = type.GetProperty(dataPropertyInfo.DataModelPropertyName); + if (dataProperty == null) + { + throw new ArgumentException($"Data property '{dataPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}."); + } + + dataProperties.Add(dataProperty); + } + // Vector. + else if (property is VectorStoreRecordVectorProperty vectorPropertyInfo) + { + var vectorProperty = type.GetProperty(vectorPropertyInfo.DataModelPropertyName); + if (vectorProperty == null) + { + throw new ArgumentException($"Vector property '{vectorPropertyInfo.DataModelPropertyName}' not found on type {type.FullName}."); + } + + // Add all vector properties if we support multiple vectors. + if (supportsMultipleVectors) + { + vectorProperties.Add(vectorProperty); + } + // Add only one vector property if we don't support multiple vectors. + else if (!singleVectorPropertyFound) + { + vectorProperties.Add(vectorProperty); + singleVectorPropertyFound = true; + } + else + { + throw new ArgumentException($"Multiple vector properties configured for type {type.FullName} while only one is supported."); + } + } + else + { + throw new ArgumentException($"Unknown property type '{property.GetType().FullName}' in vector store record definition."); + } + } + + // Check that we have a key property. + if (keyProperty is null) + { + throw new ArgumentException($"No key property configured for type {type.FullName}."); + } + + // Check that we have one vector property if we don't have named vectors. + if (!supportsMultipleVectors && !singleVectorPropertyFound) + { + throw new ArgumentException($"No vector property configured for type {type.FullName}."); + } + + return (keyProperty!, dataProperties, vectorProperties); + } + + /// + /// Create a by reading the attributes on the properties of the given type. 
+ /// + /// The type to create the definition for. + /// if the store supports multiple vectors, otherwise. + /// The based on the given type. + public static VectorStoreRecordDefinition CreateVectorStoreRecordDefinitionFromType(Type type, bool supportsMultipleVectors) + { + var properties = FindProperties(type, supportsMultipleVectors); + var definitionProperties = new List(); + + // Key property. + var keyAttribute = properties.KeyProperty.GetCustomAttribute(); + definitionProperties.Add(new VectorStoreRecordKeyProperty(properties.KeyProperty.Name, properties.KeyProperty.PropertyType) { StoragePropertyName = keyAttribute!.StoragePropertyName }); + + // Data properties. + foreach (var dataProperty in properties.DataProperties) + { + var dataAttribute = dataProperty.GetCustomAttribute(); + if (dataAttribute is not null) + { + definitionProperties.Add(new VectorStoreRecordDataProperty(dataProperty.Name, dataProperty.PropertyType) + { + IsFilterable = dataAttribute.IsFilterable, + IsFullTextSearchable = dataAttribute.IsFullTextSearchable, + StoragePropertyName = dataAttribute.StoragePropertyName + }); + } + } + + // Vector properties. + foreach (var vectorProperty in properties.VectorProperties) + { + var vectorAttribute = vectorProperty.GetCustomAttribute(); + if (vectorAttribute is not null) + { + definitionProperties.Add(new VectorStoreRecordVectorProperty(vectorProperty.Name, vectorProperty.PropertyType) + { + Dimensions = vectorAttribute.Dimensions, + IndexKind = vectorAttribute.IndexKind, + DistanceFunction = vectorAttribute.DistanceFunction, + StoragePropertyName = vectorAttribute.StoragePropertyName + }); + } + } + + return new VectorStoreRecordDefinition { Properties = definitionProperties }; + } + + /// + /// Verify that the given properties are of the supported types. + /// + /// The properties to check. + /// A set of supported types that the provided properties may have. + /// A description of the category of properties being checked. Used for error messaging. + /// A value indicating whether versions of all the types should also be supported. + /// Thrown if any of the properties are not in the given set of types. + public static void VerifyPropertyTypes(List properties, HashSet supportedTypes, string propertyCategoryDescription, bool? supportEnumerable = false) + { + var supportedEnumerableTypes = supportEnumerable == true + ? supportedTypes + : []; + + VerifyPropertyTypes(properties, supportedTypes, supportedEnumerableTypes, propertyCategoryDescription); + } + + /// + /// Verify that the given properties are of the supported types. + /// + /// The properties to check. + /// A set of supported types that the provided properties may have. + /// A set of supported types that the provided enumerable properties may use as their element type. + /// A description of the category of properties being checked. Used for error messaging. + /// Thrown if any of the properties are not in the given set of types. + public static void VerifyPropertyTypes(List properties, HashSet supportedTypes, HashSet supportedEnumerableTypes, string propertyCategoryDescription) + { + foreach (var property in properties) + { + VerifyPropertyType(property.Name, property.PropertyType, supportedTypes, supportedEnumerableTypes, propertyCategoryDescription); + } + } + + /// + /// Verify that the given properties are of the supported types. + /// + /// The properties to check. + /// A set of supported types that the provided properties may have. + /// A description of the category of properties being checked. 
Used for error messaging. + /// A value indicating whether versions of all the types should also be supported. + /// Thrown if any of the properties are not in the given set of types. + public static void VerifyPropertyTypes(IEnumerable properties, HashSet supportedTypes, string propertyCategoryDescription, bool? supportEnumerable = false) + { + var supportedEnumerableTypes = supportEnumerable == true + ? supportedTypes + : []; + + VerifyPropertyTypes(properties, supportedTypes, supportedEnumerableTypes, propertyCategoryDescription); + } + + /// + /// Verify that the given properties are of the supported types. + /// + /// The properties to check. + /// A set of supported types that the provided properties may have. + /// A set of supported types that the provided enumerable properties may use as their element type. + /// A description of the category of properties being checked. Used for error messaging. + /// Thrown if any of the properties are not in the given set of types. + public static void VerifyPropertyTypes(IEnumerable properties, HashSet supportedTypes, HashSet supportedEnumerableTypes, string propertyCategoryDescription) + { + foreach (var property in properties) + { + VerifyPropertyType(property.DataModelPropertyName, property.PropertyType, supportedTypes, supportedEnumerableTypes, propertyCategoryDescription); + } + } + + /// + /// Verify that the given property is of the supported types. + /// + /// The name of the property being checked. Used for error messaging. + /// The type of the property being checked. + /// A set of supported types that the provided property may have. + /// A set of supported types that the provided property may use as its element type if it's enumerable. + /// A description of the category of property being checked. Used for error messaging. + /// Thrown if the property is not in the given set of types. + public static void VerifyPropertyType(string propertyName, Type propertyType, HashSet supportedTypes, HashSet supportedEnumerableTypes, string propertyCategoryDescription) + { + // Add shortcut before testing all the more expensive scenarios. + if (supportedTypes.Contains(propertyType)) + { + return; + } + + // Check all collection scenarios and get stored type. + if (supportedEnumerableTypes.Count > 0 && typeof(IEnumerable).IsAssignableFrom(propertyType)) + { + var typeToCheck = propertyType switch + { + IEnumerable => typeof(object), + var enumerableType when enumerableType.IsGenericType && enumerableType.GetGenericTypeDefinition() == typeof(IEnumerable<>) => enumerableType.GetGenericArguments()[0], + var arrayType when arrayType.IsArray => arrayType.GetElementType()!, + var interfaceType when interfaceType.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)) is Type enumerableInterface => + enumerableInterface.GetGenericArguments()[0], + _ => propertyType + }; + + if (!supportedEnumerableTypes.Contains(typeToCheck)) + { + var supportedEnumerableElementTypesString = string.Join(", ", supportedEnumerableTypes!.Select(t => t.FullName)); + throw new ArgumentException($"Enumerable {propertyCategoryDescription} properties must have one of the supported element types: {supportedEnumerableElementTypesString}. 
Element type of the property '{propertyName}' is {typeToCheck.FullName}."); + } + } + else + { + // if we got here, we know the type is not supported + var supportedTypesString = string.Join(", ", supportedTypes.Select(t => t.FullName)); + throw new ArgumentException($"{propertyCategoryDescription} properties must be one of the supported types: {supportedTypesString}. Type of the property '{propertyName}' is {propertyType.FullName}."); + } + } + + /// + /// Get the JSON property name of a property by using the if available, otherwise + /// using the if available, otherwise falling back to the property name. + /// The provided may not actually contain the property, e.g. when the user has a data model that + /// doesn't resemble the stored data and where they are using a custom mapper. + /// + /// The property to retrieve a storage name for. + /// The data model type that the property belongs to. + /// The options used for JSON serialization. + /// The JSON storage property name. + public static string GetJsonPropertyName(VectorStoreRecordProperty property, Type dataModel, JsonSerializerOptions options) + { + var propertyInfo = dataModel.GetProperty(property.DataModelPropertyName); + + if (propertyInfo != null) + { + var jsonPropertyNameAttribute = propertyInfo.GetCustomAttribute(); + if (jsonPropertyNameAttribute is not null) + { + return jsonPropertyNameAttribute.Name; + } + } + + if (options.PropertyNamingPolicy is not null) + { + return options.PropertyNamingPolicy.ConvertName(property.DataModelPropertyName); + } + + return property.DataModelPropertyName; + } + + /// + /// Get the JSON property name of a property by using the if available, otherwise + /// using the if available, otherwise falling back to the property name. + /// + /// The options used for JSON serialization. + /// The property to retrieve a storage name for. + /// The JSON storage property name. + public static string GetJsonPropertyName(JsonSerializerOptions options, PropertyInfo property) + { + var jsonPropertyNameAttribute = property.GetCustomAttribute(); + if (jsonPropertyNameAttribute is not null) + { + return jsonPropertyNameAttribute.Name; + } + + if (options.PropertyNamingPolicy is not null) + { + return options.PropertyNamingPolicy.ConvertName(property.Name); + } + + return property.Name; + } + + /// + /// Build a map of property names to the names under which they should be saved in storage if using JSON serialization. + /// + /// The properties to build the map for. + /// The data model type that the property belongs to. + /// The options used for JSON serialization. + /// The map from property names to the names under which they should be saved in storage if using JSON serialization. 
+ public static Dictionary BuildPropertyNameToJsonPropertyNameMap( + (VectorStoreRecordKeyProperty keyProperty, List dataProperties, List vectorProperties) properties, + Type dataModel, + JsonSerializerOptions options) + { + var jsonPropertyNameMap = new Dictionary(); + jsonPropertyNameMap.Add(properties.keyProperty.DataModelPropertyName, GetJsonPropertyName(properties.keyProperty, dataModel, options)); + + foreach (var dataProperty in properties.dataProperties) + { + jsonPropertyNameMap.Add(dataProperty.DataModelPropertyName, GetJsonPropertyName(dataProperty, dataModel, options)); + } + + foreach (var vectorProperty in properties.vectorProperties) + { + jsonPropertyNameMap.Add(vectorProperty.DataModelPropertyName, GetJsonPropertyName(vectorProperty, dataModel, options)); + } + + return jsonPropertyNameMap; + } + + /// + /// Build a map of property names to the names under which they should be saved in storage if using JSON serialization. + /// + /// The properties to build the map for. + /// The data model type that the property belongs to. + /// The options used for JSON serialization. + /// The map from property names to the names under which they should be saved in storage if using JSON serialization. + public static Dictionary BuildPropertyNameToJsonPropertyNameMap( + (PropertyInfo keyProperty, List dataProperties, List vectorProperties) properties, + Type dataModel, + JsonSerializerOptions options) + { + var jsonPropertyNameMap = new Dictionary(); + jsonPropertyNameMap.Add(properties.keyProperty.Name, GetJsonPropertyName(options, properties.keyProperty)); + + foreach (var dataProperty in properties.dataProperties) + { + jsonPropertyNameMap.Add(dataProperty.Name, GetJsonPropertyName(options, dataProperty)); + } + + foreach (var vectorProperty in properties.vectorProperties) + { + jsonPropertyNameMap.Add(vectorProperty.Name, GetJsonPropertyName(options, vectorProperty)); + } + + return jsonPropertyNameMap; + } + + /// + /// Build a map of property names to the names under which they should be saved in storage, for the given properties. + /// + /// The properties to build the map for. + /// The map from property names to the names under which they should be saved in storage. + public static Dictionary BuildPropertyNameToStorageNameMap((VectorStoreRecordKeyProperty keyProperty, List dataProperties, List vectorProperties) properties) + { + var storagePropertyNameMap = new Dictionary(); + storagePropertyNameMap.Add(properties.keyProperty.DataModelPropertyName, properties.keyProperty.StoragePropertyName ?? properties.keyProperty.DataModelPropertyName); + + foreach (var dataProperty in properties.dataProperties) + { + storagePropertyNameMap.Add(dataProperty.DataModelPropertyName, dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName); + } + + foreach (var vectorProperty in properties.vectorProperties) + { + storagePropertyNameMap.Add(vectorProperty.DataModelPropertyName, vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName); + } + + return storagePropertyNameMap; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStore.cs b/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStore.cs new file mode 100644 index 000000000000..31246a3138d6 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStore.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. 
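// ---------------------------------------------------------------------------
// [Editorial illustration - not part of this change-set]
// A minimal sketch of how the VectorStoreRecordPropertyReader shown above could
// be driven from an attribute-annotated data model. The "Hotel" model, its
// property names and the dimension value below are hypothetical, and the
// reader's accessibility (public vs. internal to the abstractions assembly)
// is assumed.
// ---------------------------------------------------------------------------
using System;
using Microsoft.SemanticKernel.Data;

public sealed class Hotel
{
    [VectorStoreRecordKey]
    public string HotelId { get; set; } = string.Empty;

    [VectorStoreRecordData(StoragePropertyName = "hotel_name")]
    public string HotelName { get; set; } = string.Empty;

    [VectorStoreRecordVector(4)]
    public ReadOnlyMemory<float> DescriptionEmbedding { get; set; }
}

public static class PropertyReaderSketch
{
    // Infers a VectorStoreRecordDefinition from the attributes on the model,
    // mirroring what connectors do when no explicit definition is supplied.
    public static VectorStoreRecordDefinition InferHotelDefinition() =>
        VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(Hotel), supportsMultipleVectors: true);
}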
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Threading;
+
+namespace Microsoft.SemanticKernel.Data;
+
+/// <summary>
+/// Interface for accessing the list of collections in a vector store.
+/// </summary>
+/// <remarks>
+/// This interface can be used with collections of any schema type, but requires you to provide schema information when getting a collection.
+/// </remarks>
+[Experimental("SKEXP0001")]
+public interface IVectorStore
+{
+    /// <summary>
+    /// Get a collection from the vector store.
+    /// </summary>
+    /// <typeparam name="TKey">The data type of the record key.</typeparam>
+    /// <typeparam name="TRecord">The record data model to use for adding, updating and retrieving data from the collection.</typeparam>
+    /// <param name="name">The name of the collection.</param>
+    /// <param name="vectorStoreRecordDefinition">Defines the schema of the record type.</param>
+    /// <returns>A new <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> instance for managing the records in the collection.</returns>
+    /// <remarks>
+    /// To successfully request a collection, either <typeparamref name="TRecord"/> must be annotated with attributes that define the schema of
+    /// the record type, or <paramref name="vectorStoreRecordDefinition"/> must be provided.
+    /// </remarks>
+    ///
+    ///
+    IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, TRecord>(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null)
+        where TKey : notnull
+        where TRecord : class;
+
+    /// <summary>
+    /// Retrieve the names of all the collections in the vector store.
+    /// </summary>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <returns>The list of names of all the collections in the vector store.</returns>
+    IAsyncEnumerable<string> ListCollectionNamesAsync(CancellationToken cancellationToken = default);
+}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStoreRecordCollection.cs b/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStoreRecordCollection.cs
new file mode 100644
index 000000000000..5071412014a8
--- /dev/null
+++ b/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStoreRecordCollection.cs
@@ -0,0 +1,130 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.SemanticKernel.Data;
+
+/// <summary>
+/// A schema aware interface for managing a named collection of records in a vector store and for creating or deleting the collection itself.
+/// </summary>
+/// <typeparam name="TKey">The data type of the record key.</typeparam>
+/// <typeparam name="TRecord">The record data model to use for adding, updating and retrieving data from the store.</typeparam>
+[Experimental("SKEXP0001")]
+#pragma warning disable CA1711 // Identifiers should not have incorrect suffix
+public interface IVectorStoreRecordCollection<TKey, TRecord>
+#pragma warning restore CA1711 // Identifiers should not have incorrect suffix
+    where TKey : notnull
+    where TRecord : class
+{
+    /// <summary>
+    /// Gets the name of the collection.
+    /// </summary>
+    public string CollectionName { get; }
+
+    /// <summary>
+    /// Check if the collection exists in the vector store.
+    /// </summary>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <returns><see langword="true"/> if the collection exists, <see langword="false"/> otherwise.</returns>
+    Task<bool> CollectionExistsAsync(CancellationToken cancellationToken = default);
+
+    /// <summary>
+    /// Create this collection in the vector store.
+    /// </summary>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <returns>A <see cref="Task"/> that completes when the collection has been created.</returns>
+    Task CreateCollectionAsync(CancellationToken cancellationToken = default);
+
+    /// <summary>
+    /// Create this collection in the vector store if it does not already exist.
+    /// </summary>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <returns>A <see cref="Task"/> that completes when the collection has been created.</returns>
+ Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default); + + /// + /// Delete the collection from the vector store. + /// + /// The to monitor for cancellation requests. The default is . + /// A that completes when the collection has been deleted. + Task DeleteCollectionAsync(CancellationToken cancellationToken = default); + + /// + /// Gets a record from the vector store. Does not guarantee that the collection exists. + /// Returns null if the record is not found. + /// + /// The unique id associated with the record to get. + /// Optional options for retrieving the record. + /// The to monitor for cancellation requests. The default is . + /// The record if found, otherwise null. + /// Throw when the command fails to execute for any reason. + /// Throw when mapping between the storage model and record data model fails. + Task GetAsync(TKey key, GetRecordOptions? options = default, CancellationToken cancellationToken = default); + + /// + /// Gets a batch of records from the vector store. Does not guarantee that the collection exists. + /// Gets will be made in a single request or in a single parallel batch depending on the available store functionality. + /// Only found records will be returned, so the resultset may be smaller than the requested keys. + /// Throws for any issues other than records not being found. + /// + /// The unique ids associated with the record to get. + /// Optional options for retrieving the records. + /// The to monitor for cancellation requests. The default is . + /// The records associated with the unique keys provided. + /// Throw when the command fails to execute for any reason. + /// Throw when mapping between the storage model and record data model fails. + IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = default, CancellationToken cancellationToken = default); + + /// + /// Deletes a record from the vector store. Does not guarantee that the collection exists. + /// + /// The unique id associated with the record to remove. + /// Optional options for removing the record. + /// The to monitor for cancellation requests. The default is . + /// The unique identifier for the record. + /// Throw when the command fails to execute for any reason other than that the record does not exit. + Task DeleteAsync(TKey key, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default); + + /// + /// Deletes a batch of records from the vector store. Does not guarantee that the collection exists. + /// Deletes will be made in a single request or in a single parallel batch depending on the available store functionality. + /// If a record is not found, it will be ignored and the batch will succeed. + /// If any record cannot be deleted for any other reason, the operation will throw. Some records may have already been deleted, while others may not, so the entire operation should be retried. + /// + /// The unique ids associated with the records to remove. + /// Optional options for removing the records. + /// The to monitor for cancellation requests. The default is . + /// A that completes when the records have been deleted. + /// Throw when the command fails to execute for any reason other than that a record does not exist. + Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = default, CancellationToken cancellationToken = default); + + /// + /// Upserts a record into the vector store. Does not guarantee that the collection exists. 
+ /// If the record already exists, it will be updated. + /// If the record does not exist, it will be created. + /// + /// The record to upsert. + /// Optional options for upserting the record. + /// The to monitor for cancellation requests. The default is . + /// The unique identifier for the record. + /// Throw when the command fails to execute for any reason. + /// Throw when mapping between the storage model and record data model fails. + Task UpsertAsync(TRecord record, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default); + + /// + /// Upserts a group of records into the vector store. Does not guarantee that the collection exists. + /// If the record already exists, it will be updated. + /// If the record does not exist, it will be created. + /// Upserts will be made in a single request or in a single parallel batch depending on the available store functionality. + /// + /// The records to upsert. + /// Optional options for upserting the records. + /// The to monitor for cancellation requests. The default is . + /// The unique identifiers for the records. + /// Throw when the command fails to execute for any reason. + /// Throw when mapping between the storage model and record data model fails. + IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = default, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStoreRecordMapper.cs b/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..4125c4a1b3ad --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/IVectorStoreRecordMapper.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Interface for mapping between a storage model, and the consumer record data model. +/// +/// The consumer record data model to map to or from. +/// The storage model to map to or from. +[Experimental("SKEXP0001")] +public interface IVectorStoreRecordMapper + where TRecordDataModel : class +{ + /// + /// Map from the consumer record data model to the storage model. + /// + /// The consumer record data model record to map. + /// The mapped result. + TStorageModel MapFromDataToStorageModel(TRecordDataModel dataModel); + + /// + /// Map from the storage model to the consumer record data model. + /// + /// The storage data model record to map. + /// Options to control the mapping behavior. + /// The mapped result. + TRecordDataModel MapFromStorageToDataModel(TStorageModel storageModel, StorageToDataModelMapperOptions options); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordDataAttribute.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordDataAttribute.cs new file mode 100644 index 000000000000..f31b5c38352e --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordDataAttribute.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Attribute to mark a property on a record class as 'data'. +/// +/// +/// Marking a property as 'data' means that the property is not a key, and not a vector, but optionally +/// this property may have an associated vector field containing an embedding for this data. 
+/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +[AttributeUsage(AttributeTargets.Property, AllowMultiple = false)] +public sealed class VectorStoreRecordDataAttribute : Attribute +{ + /// + /// Gets or sets a value indicating whether this data property is filterable. + /// + /// + /// Default is . + /// + public bool IsFilterable { get; init; } + + /// + /// Gets or sets a value indicating whether this data property is full text searchable. + /// + /// + /// Default is . + /// + public bool IsFullTextSearchable { get; init; } + + /// + /// Gets or sets an optional name to use for the property in storage, if different from the property name. + /// E.g. the property name might be "MyProperty" but the storage name might be "my_property". + /// + public string? StoragePropertyName { get; set; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordKeyAttribute.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordKeyAttribute.cs new file mode 100644 index 000000000000..32376956b853 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordKeyAttribute.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Attribute to mark a property on a record class as the key under which the record is stored in a vector store. +/// +/// +/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +[AttributeUsage(AttributeTargets.Property, AllowMultiple = false)] +public sealed class VectorStoreRecordKeyAttribute : Attribute +{ + /// + /// Gets or sets an optional name to use for the property in storage, if different from the property name. + /// E.g. the property name might be "MyProperty" but the storage name might be "my_property". + /// + public string? StoragePropertyName { get; set; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordVectorAttribute.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordVectorAttribute.cs new file mode 100644 index 000000000000..74a2a0796811 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordAttributes/VectorStoreRecordVectorAttribute.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Attribute to mark a property on a record class as a vector. +/// +/// +/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +[AttributeUsage(AttributeTargets.Property, AllowMultiple = false)] +public sealed class VectorStoreRecordVectorAttribute : Attribute +{ + /// + /// Initializes a new instance of the class. + /// + public VectorStoreRecordVectorAttribute() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The number of dimensions that the vector has. + public VectorStoreRecordVectorAttribute(int Dimensions) + { + this.Dimensions = Dimensions; + } + + /// + /// Initializes a new instance of the class. + /// + /// The number of dimensions that the vector has. + /// The kind of index to use. + /// The distance function to use when comparing vectors. 
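    // [Editorial note] Typical usage of this constructor, with a hypothetical
    // dimension count and the well-known constants defined later in this
    // change-set:
    //   [VectorStoreRecordVector(1536, IndexKind.Hnsw, DistanceFunction.CosineSimilarity)]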
+ public VectorStoreRecordVectorAttribute(int Dimensions, string? IndexKind, string? DistanceFunction) + { + this.Dimensions = Dimensions; + this.IndexKind = IndexKind; + this.DistanceFunction = DistanceFunction; + } + + /// + /// Gets or sets the number of dimensions that the vector has. + /// + /// + /// This property is required when creating collections, but may be omitted if not using that functionality. + /// If not provided when trying to create a collection, create will fail. + /// + public int? Dimensions { get; private set; } + + /// + /// Gets the kind of index to use. + /// + /// + /// + /// Default varies by database type. See the documentation of your chosen database connector for more information. + /// + public string? IndexKind { get; private set; } + + /// + /// Gets the distance function to use when comparing vectors. + /// + /// + /// + /// Default varies by database type. See the documentation of your chosen database connector for more information. + /// + public string? DistanceFunction { get; private set; } + + /// + /// Gets or sets an optional name to use for the property in storage, if different from the property name. + /// E.g. the property name might be "MyProperty" but the storage name might be "my_property". + /// + public string? StoragePropertyName { get; set; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/DistanceFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/DistanceFunction.cs new file mode 100644 index 000000000000..32601243966b --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/DistanceFunction.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Defines a list of well known distance functions that can be used to compare vectors. +/// +/// +/// Not all Vector Store connectors support all distance functions and some connectors may +/// support additional distance functions that are not defined here. See the documentation +/// for each connector for more information on what is supported. +/// +[Experimental("SKEXP0001")] +public static class DistanceFunction +{ + /// + /// The cosine (angular) similarity between two vectors. + /// + /// + /// Measures only the angle between the two vectors, without taking into account the length of the vectors. + /// ConsineSimilarity = 1 - CosineDistance. + /// -1 means vectors are opposite. + /// 0 means vectors are orthogonal. + /// 1 means vectors are identical. + /// + public const string CosineSimilarity = nameof(CosineSimilarity); + + /// + /// The cosine (angular) similarity between two vectors. + /// + /// + /// CosineDistance = 1 - CosineSimilarity. + /// 2 means vectors are opposite. + /// 1 means vectors are orthogonal. + /// 0 means vectors are identical. + /// + public const string CosineDistance = nameof(CosineDistance); + + /// + /// Measures both the length and angle between two vectors. + /// + /// + /// Same as cosine similarity if the vectors are the same length, but more performant. + /// + public const string DotProductSimilarity = nameof(DotProductSimilarity); + + /// + /// Measures the Euclidean distance between two vectors. + /// + /// + /// Also known as l2-norm. + /// + public const string EuclideanDistance = nameof(EuclideanDistance); + + /// + /// Measures the Manhattan distance between two vectors. 
+ /// + public const string ManhattanDistance = nameof(ManhattanDistance); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/IndexKind.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/IndexKind.cs new file mode 100644 index 000000000000..364baaa8e727 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/IndexKind.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Defines a list of well known index types that can be used to index vectors. +/// +/// +/// Not all Vector Store connectors support all index types and some connectors may +/// support additional index types that are not defined here. See the documentation +/// for each connector for more information on what is supported. +/// +[Experimental("SKEXP0001")] +public static class IndexKind +{ + /// + /// Hierarchical Navigable Small World, which performs an approximate nearest neighbour (ANN) search. + /// + /// + /// Lower accuracy than exhaustive k nearest neighbor, but faster and more efficient. + /// + public const string Hnsw = nameof(Hnsw); + + /// + /// Does a brute force search to find the nearest neighbors. + /// Calculates the distances between all pairs of data points, so has a linear time complexity, that grows directly proportional to the number of points. + /// Also referred to as exhaustive k nearest neighbor in some databases. + /// + /// + /// High recall accuracy, but slower and more expensive than HNSW. + /// Better with smaller datasets. + /// + public const string Flat = nameof(Flat); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordDataProperty.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordDataProperty.cs new file mode 100644 index 000000000000..9dec25aa4ce1 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordDataProperty.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Defines a data property on a vector store record. +/// +/// +/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +public sealed class VectorStoreRecordDataProperty : VectorStoreRecordProperty +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the property. + /// The type of the property. + public VectorStoreRecordDataProperty(string propertyName, Type propertyType) + : base(propertyName, propertyType) + { + } + + /// + /// Initializes a new instance of the class by cloning the given source. + /// + /// The source to clone + public VectorStoreRecordDataProperty(VectorStoreRecordDataProperty source) + : base(source) + { + this.IsFilterable = source.IsFilterable; + this.IsFullTextSearchable = source.IsFullTextSearchable; + } + + /// + /// Gets or sets a value indicating whether this data property is filterable. + /// + /// + /// Default is . + /// + public bool IsFilterable { get; init; } + + /// + /// Gets or sets a value indicating whether this data property is full text searchable. + /// + /// + /// Default is . 
+ /// + public bool IsFullTextSearchable { get; init; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordDefinition.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordDefinition.cs new file mode 100644 index 000000000000..455bd5842c47 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordDefinition.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// A description of the properties of a record stored in a vector store. +/// +/// +/// Each property contains additional information about how the property will be treated by the vector store. +/// +[Experimental("SKEXP0001")] +public sealed class VectorStoreRecordDefinition +{ + /// Empty static list for initialization purposes. + private static readonly List s_emptyFields = new(); + + /// + /// The list of properties that are stored in the record. + /// + public IReadOnlyList Properties { get; init; } = s_emptyFields; +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordKeyProperty.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordKeyProperty.cs new file mode 100644 index 000000000000..6ba9725e2da4 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordKeyProperty.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Defines a key property on a vector store record. +/// +/// +/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +public sealed class VectorStoreRecordKeyProperty : VectorStoreRecordProperty +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the property. + /// The type of the property. + public VectorStoreRecordKeyProperty(string propertyName, Type propertyType) + : base(propertyName, propertyType) + { + } + + /// + /// Initializes a new instance of the class by cloning the given source. + /// + /// The source to clone + public VectorStoreRecordKeyProperty(VectorStoreRecordKeyProperty source) + : base(source) + { + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordProperty.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordProperty.cs new file mode 100644 index 000000000000..400ae7065355 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordProperty.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Defines a base property class for properties on a vector store record. +/// +/// +/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +public abstract class VectorStoreRecordProperty +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the property on the data model. + /// The type of the property. 
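    // [Editorial note] The constructor below is private protected, so only the
    // Key/Data/Vector property classes defined in this same assembly can derive
    // from VectorStoreRecordProperty; user code composes those concrete classes
    // via a VectorStoreRecordDefinition rather than subclassing this base type.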
+ private protected VectorStoreRecordProperty(string dataModelPropertyName, Type propertyType) + { + Verify.NotNullOrWhiteSpace(dataModelPropertyName); + Verify.NotNull(propertyType); + + this.DataModelPropertyName = dataModelPropertyName; + this.PropertyType = propertyType; + } + + private protected VectorStoreRecordProperty(VectorStoreRecordProperty source) + { + this.DataModelPropertyName = source.DataModelPropertyName; + this.StoragePropertyName = source.StoragePropertyName; + this.PropertyType = source.PropertyType; + } + + /// + /// Gets or sets the name of the property on the data model. + /// + public string DataModelPropertyName { get; private set; } + + /// + /// Gets or sets an optional name to use for the property in storage, if different from the property name. + /// E.g. the property name might be "MyProperty" but the storage name might be "my_property". + /// This property will only be respected by implementations that do not support a well known + /// serialization mechanism like JSON, in which case the attributes used by that seriallization system will + /// be used. + /// + public string? StoragePropertyName { get; init; } + + /// + /// Gets or sets the type of the property. + /// + public Type PropertyType { get; private set; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordVectorProperty.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordVectorProperty.cs new file mode 100644 index 000000000000..4f4b3a1bce0a --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordDefinition/VectorStoreRecordVectorProperty.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Defines a vector property on a vector store record. +/// +/// +/// The characteristics defined here will influence how the property is treated by the vector store. +/// +[Experimental("SKEXP0001")] +public sealed class VectorStoreRecordVectorProperty : VectorStoreRecordProperty +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the property. + /// The type of the property. + public VectorStoreRecordVectorProperty(string propertyName, Type propertyType) + : base(propertyName, propertyType) + { + } + + /// + /// Initializes a new instance of the class by cloning the given source. + /// + /// The source to clone + public VectorStoreRecordVectorProperty(VectorStoreRecordVectorProperty source) + : base(source) + { + this.Dimensions = source.Dimensions; + this.IndexKind = source.IndexKind; + this.DistanceFunction = source.DistanceFunction; + } + + /// + /// Gets or sets the number of dimensions that the vector has. + /// + /// + /// This property is required when creating collections, but may be omitted if not using that functionality. + /// If not provided when trying to create a collection, create will fail. + /// + public int? Dimensions { get; init; } + + /// + /// Gets the kind of index to use. + /// + /// + /// + /// Default varies by database type. See the documentation of your chosen database connector for more information. + /// + public string? IndexKind { get; init; } + + /// + /// Gets the distance function to use when comparing vectors. + /// + /// + /// + /// Default varies by database type. See the documentation of your chosen database connector for more information. + /// + public string? 
DistanceFunction { get; init; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/DeleteRecordOptions.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/DeleteRecordOptions.cs new file mode 100644 index 000000000000..4f034d125a6d --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/DeleteRecordOptions.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Options when calling . +/// +/// +/// This class does not currently include any options, but is added for future extensibility of the API. +/// +[Experimental("SKEXP0001")] +public class DeleteRecordOptions +{ + /// + /// Initializes a new instance of the class. + /// + public DeleteRecordOptions() + { + } + + /// + /// Initializes a new instance of the class by cloning the given options. + /// + /// The options to clone + public DeleteRecordOptions(DeleteRecordOptions source) + { + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/GetRecordOptions.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/GetRecordOptions.cs new file mode 100644 index 000000000000..5330e076acea --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/GetRecordOptions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Options when calling . +/// +[Experimental("SKEXP0001")] +public class GetRecordOptions +{ + /// + /// Initializes a new instance of the class. + /// + public GetRecordOptions() + { + } + + /// + /// Initializes a new instance of the class by cloning the given options. + /// + /// The options to clone + public GetRecordOptions(GetRecordOptions source) + { + this.IncludeVectors = source.IncludeVectors; + } + + /// + /// Gets or sets a value indicating whether to include vectors in the retrieval result. + /// + public bool IncludeVectors { get; init; } = false; +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/UpsertRecordOptions.cs b/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/UpsertRecordOptions.cs new file mode 100644 index 000000000000..c1d9cba35b5d --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/RecordOptions/UpsertRecordOptions.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Options when calling . +/// Reserved for future use. +/// +/// +/// This class does not currently include any options, but is added for future extensibility of the API. +/// +[Experimental("SKEXP0001")] +public class UpsertRecordOptions +{ + /// + /// Initializes a new instance of the class. + /// + public UpsertRecordOptions() + { + } + + /// + /// Initializes a new instance of the class by cloning the given options. + /// + /// The options to clone + public UpsertRecordOptions(UpsertRecordOptions source) + { + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/StorageToDataModelMapperOptions.cs b/dotnet/src/SemanticKernel.Abstractions/Data/StorageToDataModelMapperOptions.cs new file mode 100644 index 000000000000..bdee284b0f14 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/StorageToDataModelMapperOptions.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. 
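// ---------------------------------------------------------------------------
// [Editorial illustration - not part of this change-set]
// A sketch of how the abstractions above are expected to compose: a record
// definition built by hand (the alternative to the attribute-annotated model),
// and a create/upsert/get round trip against any IVectorStore implementation.
// The generic shape GetCollection<TKey, TRecord> and the Task<TRecord?> return
// of GetAsync are assumed from the type constraints shown earlier, since
// generic arguments were stripped from this rendering of the diff; property
// names and the dimension count are hypothetical.
// ---------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Data;

public static class VectorStoreUsageSketch
{
    // Hand-built schema, equivalent to annotating the model with the record attributes.
    public static readonly VectorStoreRecordDefinition HotelDefinition = new()
    {
        Properties = new List<VectorStoreRecordProperty>
        {
            new VectorStoreRecordKeyProperty("HotelId", typeof(string)),
            new VectorStoreRecordDataProperty("Description", typeof(string)) { IsFullTextSearchable = true },
            new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory<float>))
            {
                Dimensions = 1536,
                IndexKind = IndexKind.Hnsw,
                DistanceFunction = DistanceFunction.CosineSimilarity,
            },
        },
    };

    // Create the collection if needed, upsert one record and read it back with vectors included.
    public static async Task<TRecord?> RoundTripAsync<TRecord>(
        IVectorStore store, string collectionName, TRecord record, VectorStoreRecordDefinition? definition = null)
        where TRecord : class
    {
        var collection = store.GetCollection<string, TRecord>(collectionName, definition);
        await collection.CreateCollectionIfNotExistsAsync();
        var key = await collection.UpsertAsync(record);
        return await collection.GetAsync(key, new GetRecordOptions { IncludeVectors = true });
    }
}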
+ +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Options to use with the method. +/// +[Experimental("SKEXP0001")] +public class StorageToDataModelMapperOptions +{ + /// + /// Get or sets a value indicating whether to include vectors in the retrieval result. + /// + public bool IncludeVectors { get; init; } = false; +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreException.cs b/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreException.cs new file mode 100644 index 000000000000..5a0183e85d83 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreException.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Base exception type thrown for any type of failure when using vector stores. +/// +[Experimental("SKEXP0001")] +public abstract class VectorStoreException : KernelException +{ + /// + /// Initializes a new instance of the class. + /// + protected VectorStoreException() + { + } + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The error message that explains the reason for the exception. + protected VectorStoreException(string? message) : base(message) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + protected VectorStoreException(string? message, Exception? innerException) : base(message, innerException) + { + } + + /// + /// Gets or sets the type of vector store that the failing operation was performed on. + /// + public string? VectorStoreType { get; init; } + + /// + /// Gets or sets the name of the vector store collection that the failing operation was performed on. + /// + public string? CollectionName { get; init; } + + /// + /// Gets or sets the name of the vector store operation that failed. + /// + public string? OperationName { get; init; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreOperationException.cs b/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreOperationException.cs new file mode 100644 index 000000000000..2830c1b22646 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreOperationException.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Exception thrown when a vector store command fails, such as upserting a record or deleting a collection. +/// +[Experimental("SKEXP0001")] +public class VectorStoreOperationException : VectorStoreException +{ + /// + /// Initializes a new instance of the class. + /// + public VectorStoreOperationException() + { + } + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public VectorStoreOperationException(string? message) : base(message) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. 
+ /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + public VectorStoreOperationException(string? message, Exception? innerException) : base(message, innerException) + { + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreRecordMappingException.cs b/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreRecordMappingException.cs new file mode 100644 index 000000000000..6b912b233ceb --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Data/VectorStoreRecordMappingException.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Exception thrown when a failure occurs while trying to convert models for storage or retrieval. +/// +[Experimental("SKEXP0001")] +public class VectorStoreRecordMappingException : VectorStoreException +{ + /// + /// Initializes a new instance of the class. + /// + public VectorStoreRecordMappingException() + { + } + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public VectorStoreRecordMappingException(string? message) : base(message) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + public VectorStoreRecordMappingException(string? message, Exception? innerException) : base(message, innerException) + { + } +} diff --git a/dotnet/src/SemanticKernel.Core/Data/KernelBuilderExtensions.cs b/dotnet/src/SemanticKernel.Core/Data/KernelBuilderExtensions.cs new file mode 100644 index 000000000000..251dee88a4f3 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Data/KernelBuilderExtensions.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Extension methods to register Data services on the . +/// +[Experimental("SKEXP0001")] +public static class KernelBuilderExtensions +{ + /// + /// Register a Volatile with the specified service ID. + /// + /// The builder to register the on. + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddVolatileVectorStore(this IKernelBuilder builder, string? serviceId = default) + { + builder.Services.AddVolatileVectorStore(serviceId); + return builder; + } +} diff --git a/dotnet/src/SemanticKernel.Core/Data/ServiceCollectionExtensions.cs b/dotnet/src/SemanticKernel.Core/Data/ServiceCollectionExtensions.cs new file mode 100644 index 000000000000..83aaf7b57af4 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Data/ServiceCollectionExtensions.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Extension methods to register Data services on an . +/// +[Experimental("SKEXP0001")] +public static class ServiceCollectionExtensions +{ + /// + /// Register a Volatile with the specified service ID. 
+ /// + /// The to register the on. + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddVolatileVectorStore(this IServiceCollection services, string? serviceId = default) + { + services.AddKeyedSingleton(serviceId); + return services; + } +} diff --git a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStore.cs b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStore.cs new file mode 100644 index 000000000000..7175e2896978 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStore.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Service for storing and retrieving vector records, and managing vector record collections, that uses an in memory dictionary as the underlying storage. +/// +[Experimental("SKEXP0001")] +public sealed class VolatileVectorStore : IVectorStore +{ + /// Internal storage for the record collection. + private readonly ConcurrentDictionary> _internalCollection; + + /// + /// Initializes a new instance of the class. + /// + public VolatileVectorStore() + { + this._internalCollection = new(); + } + + /// + /// Initializes a new instance of the class. + /// + /// Allows passing in the dictionary used for storage, for testing purposes. + internal VolatileVectorStore(ConcurrentDictionary> internalCollection) + { + this._internalCollection = internalCollection; + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + where TRecord : class + { + var collection = new VolatileVectorStoreRecordCollection(this._internalCollection, name, new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + return collection!; + } + + /// + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) + { + return this._internalCollection.Keys.ToAsyncEnumerable(); + } +} diff --git a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..decfa8ef20ea --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Service for storing and retrieving vector records, that uses an in memory dictionary as the underlying storage. +/// +/// The data type of the record key. +/// The data model to use for adding, updating and retrieving data from storage. +[Experimental("SKEXP0001")] +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class VolatileVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TKey : notnull + where TRecord : class +{ + /// Internal storage for the record collection. 
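    // [Editorial note] The nested dictionary maps collection name -> (record key -> record).
    // The exact generic arguments were stripped from this rendering of the diff; they are
    // most likely object-typed (e.g. ConcurrentDictionary<string, ConcurrentDictionary<object, object>>)
    // so that a single VolatileVectorStore instance can back collections with different
    // TKey/TRecord types, which is also why GetAsync further down uses an "as TRecord" cast.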
+ private readonly ConcurrentDictionary> _internalCollection; + + /// Optional configuration options for this class. + private readonly VolatileVectorStoreRecordCollectionOptions _options; + + /// The name of the collection that this will access. + private readonly string _collectionName; + + /// A property info object that points at the key property for the current model, allowing easy reading and writing of this property. + private readonly PropertyInfo _keyPropertyInfo; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the collection that this will access. + /// Optional configuration options for this class. + public VolatileVectorStoreRecordCollection(string collectionName, VolatileVectorStoreRecordCollectionOptions? options = default) + { + // Verify. + Verify.NotNullOrWhiteSpace(collectionName); + + // Assign. + this._collectionName = collectionName; + this._internalCollection = new(); + this._options = options ?? new VolatileVectorStoreRecordCollectionOptions(); + var vectorStoreRecordDefinition = this._options.VectorStoreRecordDefinition ?? VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(TRecord), true); + + // Get the key property info. + var keyProperty = vectorStoreRecordDefinition.Properties.OfType().FirstOrDefault(); + if (keyProperty is null) + { + throw new ArgumentException($"No Key property found on {typeof(TRecord).Name} or provided via {nameof(VectorStoreRecordDefinition)}"); + } + + this._keyPropertyInfo = typeof(TRecord).GetProperty(keyProperty.DataModelPropertyName) ?? throw new ArgumentException($"Key property {keyProperty.DataModelPropertyName} not found on {typeof(TRecord).Name}"); + } + + /// + /// Initializes a new instance of the class. + /// + /// Allows passing in the dictionary used for storage, for testing purposes. + /// The name of the collection that this will access. + /// Optional configuration options for this class. + internal VolatileVectorStoreRecordCollection(ConcurrentDictionary> internalCollection, string collectionName, VolatileVectorStoreRecordCollectionOptions? options = default) + : this(collectionName, options) + { + this._internalCollection = internalCollection; + } + + /// + public string CollectionName => this._collectionName; + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + return this._internalCollection.ContainsKey(this._collectionName) ? Task.FromResult(true) : Task.FromResult(false); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + this._internalCollection.TryAdd(this._collectionName, new ConcurrentDictionary()); + return Task.CompletedTask; + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + this._internalCollection.TryRemove(this._collectionName, out _); + return Task.CompletedTask; + } + + /// + public Task GetAsync(TKey key, GetRecordOptions? 
options = null, CancellationToken cancellationToken = default) + { + var collectionDictionary = this.GetCollectionDictionary(); + + if (collectionDictionary.TryGetValue(key, out var record)) + { + return Task.FromResult(record as TRecord); + } + + return Task.FromResult(null); + } + + /// + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var key in keys) + { + var record = await this.GetAsync(key, options, cancellationToken).ConfigureAwait(false); + + if (record is not null) + { + yield return record; + } + } + } + + /// + public Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var collectionDictionary = this.GetCollectionDictionary(); + + collectionDictionary.TryRemove(key, out _); + return Task.CompletedTask; + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var collectionDictionary = this.GetCollectionDictionary(); + + foreach (var key in keys) + { + collectionDictionary.TryRemove(key, out _); + } + + return Task.CompletedTask; + } + + /// + public Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var collectionDictionary = this.GetCollectionDictionary(); + + var key = (TKey)this._keyPropertyInfo.GetValue(record)!; + collectionDictionary.AddOrUpdate(key!, record, (key, currentValue) => record); + + return Task.FromResult(key!); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var record in records) + { + yield return await this.UpsertAsync(record, options, cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Get the collection dictionary from the internal storage, throws if it does not exist. + /// + /// The retrieved collection dictionary. + private ConcurrentDictionary GetCollectionDictionary() + { + if (!this._internalCollection.TryGetValue(this._collectionName, out var collectionDictionary)) + { + throw new VectorStoreOperationException($"Call to vector store failed. Collection '{this._collectionName}' does not exist."); + } + + return collectionDictionary; + } +} diff --git a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollectionOptions.cs b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..8732e7efa486 --- /dev/null +++ b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SemanticKernel.Data; + +/// +/// Options when creating a . +/// +[Experimental("SKEXP0001")] +public sealed class VolatileVectorStoreRecordCollectionOptions +{ + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? 
VectorStoreRecordDefinition { get; init; } = null; +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/KernelBuilderExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/KernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..2f1f3923c3c4 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Data/KernelBuilderExtensionsTests.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.UnitTests.Data; + +/// +/// Contains tests for . +/// +public class KernelBuilderExtensionsTests +{ + private readonly IKernelBuilder _kernelBuilder; + + public KernelBuilderExtensionsTests() + { + this._kernelBuilder = Kernel.CreateBuilder(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Act. + this._kernelBuilder.AddVolatileVectorStore(); + + // Assert. + var kernel = this._kernelBuilder.Build(); + var vectorStore = kernel.Services.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/ServiceCollectionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/ServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..9b8e934c11ca --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Data/ServiceCollectionExtensionsTests.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.UnitTests.Data; + +/// +/// Contains tests for the class. +/// +public class ServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection; + + public ServiceCollectionExtensionsTests() + { + this._serviceCollection = new ServiceCollection(); + } + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Act. + this._serviceCollection.AddVolatileVectorStore(); + + // Assert. + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyReaderTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyReaderTests.cs new file mode 100644 index 000000000000..cfddd8437425 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreRecordPropertyReaderTests.cs @@ -0,0 +1,468 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.UnitTests.Data; + +public class VectorStoreRecordPropertyReaderTests +{ + [Fact] + public void SplitDefinitionsAndVerifyReturnsProperties() + { + // Act. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify("testType", this._multiPropsDefinition, true, true); + + // Assert. 
+ Assert.Equal("Key", properties.KeyProperty.DataModelPropertyName); + Assert.Equal(2, properties.DataProperties.Count); + Assert.Equal(2, properties.VectorProperties.Count); + Assert.Equal("Data1", properties.DataProperties[0].DataModelPropertyName); + Assert.Equal("Data2", properties.DataProperties[1].DataModelPropertyName); + Assert.Equal("Vector1", properties.VectorProperties[0].DataModelPropertyName); + Assert.Equal("Vector2", properties.VectorProperties[1].DataModelPropertyName); + } + + [Theory] + [InlineData(false, true, "MultiProps")] + [InlineData(true, true, "NoKey")] + [InlineData(true, true, "MultiKeys")] + [InlineData(false, true, "NoVector")] + [InlineData(true, true, "NoVector")] + public void SplitDefinitionsAndVerifyThrowsForInvalidModel(bool supportsMultipleVectors, bool requiresAtLeastOneVector, string definitionName) + { + // Arrange. + var definition = definitionName switch + { + "MultiProps" => this._multiPropsDefinition, + "NoKey" => this._noKeyDefinition, + "MultiKeys" => this._multiKeysDefinition, + "NoVector" => this._noVectorDefinition, + _ => throw new ArgumentException("Invalid definition.") + }; + + // Act & Assert. + Assert.Throws(() => VectorStoreRecordPropertyReader.SplitDefinitionAndVerify("testType", definition, supportsMultipleVectors, requiresAtLeastOneVector)); + } + + [Theory] + [InlineData(true, false)] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + public void FindPropertiesCanFindAllPropertiesOnSinglePropsModel(bool supportsMultipleVectors, bool useConfig) + { + // Act. + var properties = useConfig ? + VectorStoreRecordPropertyReader.FindProperties(typeof(SinglePropsModel), this._singlePropsDefinition, supportsMultipleVectors) : + VectorStoreRecordPropertyReader.FindProperties(typeof(SinglePropsModel), supportsMultipleVectors); + + // Assert. + Assert.Equal("Key", properties.KeyProperty.Name); + Assert.Single(properties.DataProperties); + Assert.Single(properties.VectorProperties); + Assert.Equal("Data", properties.DataProperties[0].Name); + Assert.Equal("Vector", properties.VectorProperties[0].Name); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void FindPropertiesCanFindAllPropertiesOnMultiPropsModel(bool useConfig) + { + // Act. + var properties = useConfig ? + VectorStoreRecordPropertyReader.FindProperties(typeof(MultiPropsModel), this._multiPropsDefinition, true) : + VectorStoreRecordPropertyReader.FindProperties(typeof(MultiPropsModel), true); + + // Assert. + Assert.Equal("Key", properties.KeyProperty.Name); + Assert.Equal(2, properties.DataProperties.Count); + Assert.Equal(2, properties.VectorProperties.Count); + Assert.Equal("Data1", properties.DataProperties[0].Name); + Assert.Equal("Data2", properties.DataProperties[1].Name); + Assert.Equal("Vector1", properties.VectorProperties[0].Name); + Assert.Equal("Vector2", properties.VectorProperties[1].Name); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void FindPropertiesThrowsForMultipleVectorsWithSingleVectorSupport(bool useConfig) + { + // Act. + var ex = useConfig ? + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(MultiPropsModel), this._multiPropsDefinition, false)) : + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(MultiPropsModel), false)); + + // Assert. + var expectedMessage = useConfig ? 
+ "Multiple vector properties configured for type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+MultiPropsModel while only one is supported." : + "Multiple vector properties found on type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+MultiPropsModel while only one is supported."; + Assert.Equal(expectedMessage, ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void FindPropertiesThrowsOnMultipleKeyProperties(bool useConfig) + { + // Act. + var ex = useConfig ? + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(MultiKeysModel), this._multiKeysDefinition, true)) : + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(MultiKeysModel), true)); + + // Assert. + var expectedMessage = useConfig ? + "Multiple key properties configured for type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+MultiKeysModel." : + "Multiple key properties found on type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+MultiKeysModel."; + Assert.Equal(expectedMessage, ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void FindPropertiesThrowsOnNoKeyProperty(bool useConfig) + { + // Act. + var ex = useConfig ? + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(NoKeyModel), this._noKeyDefinition, true)) : + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(NoKeyModel), true)); + + // Assert. + var expectedMessage = useConfig ? + "No key property configured for type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+NoKeyModel." : + "No key property found on type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+NoKeyModel."; + Assert.Equal(expectedMessage, ex.Message); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void FindPropertiesThrowsOnNoVectorPropertyWithSingleVectorSupport(bool useConfig) + { + // Act. + var ex = useConfig ? + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(NoVectorModel), this._noVectorDefinition, false)) : + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(NoVectorModel), false)); + + // Assert. + var expectedMessage = useConfig ? + "No vector property configured for type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+NoVectorModel." 
: + "No vector property found on type SemanticKernel.UnitTests.Data.VectorStoreRecordPropertyReaderTests+NoVectorModel."; + Assert.Equal(expectedMessage, ex.Message); + } + + [Theory] + [InlineData("Key", "MissingKey")] + [InlineData("Data", "MissingData")] + [InlineData("Vector", "MissingVector")] + public void FindPropertiesUsingConfigThrowsForNotFoundProperties(string propertyType, string propertyName) + { + var missingKeyDefinition = new VectorStoreRecordDefinition { Properties = [new VectorStoreRecordKeyProperty(propertyName, typeof(string))] }; + var missingDataDefinition = new VectorStoreRecordDefinition { Properties = [new VectorStoreRecordDataProperty(propertyName, typeof(string))] }; + var missingVectorDefinition = new VectorStoreRecordDefinition { Properties = [new VectorStoreRecordVectorProperty(propertyName, typeof(ReadOnlyMemory))] }; + + var definition = propertyType switch + { + "Key" => missingKeyDefinition, + "Data" => missingDataDefinition, + "Vector" => missingVectorDefinition, + _ => throw new ArgumentException("Invalid property type.") + }; + + Assert.Throws(() => VectorStoreRecordPropertyReader.FindProperties(typeof(NoKeyModel), definition, false)); + } + + [Fact] + public void CreateVectorStoreRecordDefinitionFromTypeConvertsAllProps() + { + // Act. + var definition = VectorStoreRecordPropertyReader.CreateVectorStoreRecordDefinitionFromType(typeof(MultiPropsModel), true); + + // Assert. + Assert.Equal(5, definition.Properties.Count); + Assert.Equal("Key", definition.Properties[0].DataModelPropertyName); + Assert.Equal("Data1", definition.Properties[1].DataModelPropertyName); + Assert.Equal("Data2", definition.Properties[2].DataModelPropertyName); + Assert.Equal("Vector1", definition.Properties[3].DataModelPropertyName); + Assert.Equal("Vector2", definition.Properties[4].DataModelPropertyName); + + Assert.IsType(definition.Properties[0]); + Assert.IsType(definition.Properties[1]); + Assert.IsType(definition.Properties[2]); + Assert.IsType(definition.Properties[3]); + Assert.IsType(definition.Properties[4]); + + var data1 = (VectorStoreRecordDataProperty)definition.Properties[1]; + var data2 = (VectorStoreRecordDataProperty)definition.Properties[2]; + + Assert.True(data1.IsFilterable); + Assert.False(data2.IsFilterable); + + Assert.True(data1.IsFullTextSearchable); + Assert.False(data2.IsFullTextSearchable); + + Assert.Equal(typeof(string), data1.PropertyType); + Assert.Equal(typeof(string), data2.PropertyType); + + var vector1 = (VectorStoreRecordVectorProperty)definition.Properties[3]; + + Assert.Equal(4, vector1.Dimensions); + } + + [Fact] + public void VerifyPropertyTypesPassForAllowedTypes() + { + // Arrange. + var properties = VectorStoreRecordPropertyReader.FindProperties(typeof(SinglePropsModel), true); + + // Act. + VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.DataProperties, [typeof(string)], "Data"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(this._singlePropsDefinition.Properties.OfType(), [typeof(string)], "Data"); + } + + [Fact] + public void VerifyPropertyTypesPassForAllowedEnumerableTypes() + { + // Arrange. + var properties = VectorStoreRecordPropertyReader.FindProperties(typeof(EnumerablePropsModel), true); + + // Act. 
+ VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.DataProperties, [typeof(string)], "Data", supportEnumerable: true); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(this._enumerablePropsDefinition.Properties.OfType(), [typeof(string)], "Data", supportEnumerable: true); + } + + [Fact] + public void VerifyPropertyTypesFailsForDisallowedTypes() + { + // Arrange. + var properties = VectorStoreRecordPropertyReader.FindProperties(typeof(SinglePropsModel), true); + + // Act. + var ex1 = Assert.Throws(() => VectorStoreRecordPropertyReader.VerifyPropertyTypes(properties.DataProperties, [typeof(int), typeof(float)], "Data")); + var ex2 = Assert.Throws(() => VectorStoreRecordPropertyReader.VerifyPropertyTypes(this._singlePropsDefinition.Properties.OfType(), [typeof(int), typeof(float)], "Data")); + + // Assert. + Assert.Equal("Data properties must be one of the supported types: System.Int32, System.Single. Type of the property 'Data' is System.String.", ex1.Message); + Assert.Equal("Data properties must be one of the supported types: System.Int32, System.Single. Type of the property 'Data' is System.String.", ex2.Message); + } + + [Fact] + public void VerifyStoragePropertyNameMapChecksStorageNameAndFallsBackToPropertyName() + { + // Arrange. + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify("testType", this._multiPropsDefinition, true, true); + + // Act. + var storageNameMap = VectorStoreRecordPropertyReader.BuildPropertyNameToStorageNameMap(properties); + + // Assert. + Assert.Equal(5, storageNameMap.Count); + + // From Property Names. + Assert.Equal("Key", storageNameMap["Key"]); + Assert.Equal("Data1", storageNameMap["Data1"]); + Assert.Equal("Vector1", storageNameMap["Vector1"]); + Assert.Equal("Vector2", storageNameMap["Vector2"]); + + // From storage property name on vector store record data property. + Assert.Equal("data_2", storageNameMap["Data2"]); + } + + [Fact] + public void VerifyGetJsonPropertyNameChecksJsonOptionsAndJsonAttributesAndFallsBackToPropertyName() + { + // Arrange. + var options = new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseUpper }; + var properties = VectorStoreRecordPropertyReader.FindProperties(typeof(MultiPropsModel), true); + var allProperties = (new PropertyInfo[] { properties.KeyProperty }) + .Concat(properties.DataProperties) + .Concat(properties.VectorProperties); + + // Act. + var jsonNameMap = allProperties + .Select(p => new { PropertyName = p.Name, JsonName = VectorStoreRecordPropertyReader.GetJsonPropertyName(options, p) }) + .ToDictionary(p => p.PropertyName, p => p.JsonName); + + // Assert. + Assert.Equal(5, jsonNameMap.Count); + + // From JsonNamingPolicy. + Assert.Equal("KEY", jsonNameMap["Key"]); + Assert.Equal("DATA1", jsonNameMap["Data1"]); + Assert.Equal("DATA2", jsonNameMap["Data2"]); + Assert.Equal("VECTOR1", jsonNameMap["Vector1"]); + + // From JsonPropertyName attribute. + Assert.Equal("vector-2", jsonNameMap["Vector2"]); + } + + [Fact] + public void VerifyBuildPropertyNameToJsonPropertyNameMapChecksJsonAttributesAndJsonOptionsAndFallsbackToPropertyNames() + { + // Arrange. + var options = new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseUpper }; + var properties = VectorStoreRecordPropertyReader.SplitDefinitionAndVerify("testType", this._multiPropsDefinition, true, true); + var propertiesInfo = VectorStoreRecordPropertyReader.FindProperties(typeof(MultiPropsModel), true); + + // Act. 
+ var jsonNameMap1 = VectorStoreRecordPropertyReader.BuildPropertyNameToJsonPropertyNameMap(properties, typeof(MultiPropsModel), options); + var jsonNameMap2 = VectorStoreRecordPropertyReader.BuildPropertyNameToJsonPropertyNameMap(propertiesInfo, typeof(MultiPropsModel), options); + + void assertJsonNameMap(Dictionary jsonNameMap) + { + Assert.Equal(5, jsonNameMap.Count); + + // From JsonNamingPolicy. + Assert.Equal("KEY", jsonNameMap["Key"]); + Assert.Equal("DATA1", jsonNameMap["Data1"]); + Assert.Equal("DATA2", jsonNameMap["Data2"]); + Assert.Equal("VECTOR1", jsonNameMap["Vector1"]); + + // From JsonPropertyName attribute. + Assert.Equal("vector-2", jsonNameMap["Vector2"]); + }; + + // Assert. + assertJsonNameMap(jsonNameMap1); + assertJsonNameMap(jsonNameMap2); + } + +#pragma warning disable CA1812 // Invalid unused classes error, since I am using these for testing purposes above. + + private sealed class NoKeyModel + { + } + + private readonly VectorStoreRecordDefinition _noKeyDefinition = new(); + + private sealed class NoVectorModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + } + + private readonly VectorStoreRecordDefinition _noVectorDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)) + ] + }; + + private sealed class MultiKeysModel + { + [VectorStoreRecordKey] + public string Key1 { get; set; } = string.Empty; + + [VectorStoreRecordKey] + public string Key2 { get; set; } = string.Empty; + } + + private readonly VectorStoreRecordDefinition _multiKeysDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key1", typeof(string)), + new VectorStoreRecordKeyProperty("Key2", typeof(string)) + ] + }; + + private sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector] + public ReadOnlyMemory Vector { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } + + private readonly VectorStoreRecordDefinition _singlePropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data", typeof(string)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) + ] + }; + + private sealed class MultiPropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData(IsFilterable = true, IsFullTextSearchable = true)] + public string Data1 { get; set; } = string.Empty; + + [VectorStoreRecordData] + public string Data2 { get; set; } = string.Empty; + + [VectorStoreRecordVector(4, IndexKind.Flat, DistanceFunction.DotProductSimilarity)] + public ReadOnlyMemory Vector1 { get; set; } + + [VectorStoreRecordVector] + [JsonPropertyName("vector-2")] + public ReadOnlyMemory Vector2 { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } + + private readonly VectorStoreRecordDefinition _multiPropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data1", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("Data2", typeof(string)) { StoragePropertyName = "data_2" }, + new VectorStoreRecordVectorProperty("Vector1", typeof(ReadOnlyMemory)) { Dimensions = 4, IndexKind = IndexKind.Flat, DistanceFunction = 
DistanceFunction.DotProductSimilarity }, + new VectorStoreRecordVectorProperty("Vector2", typeof(ReadOnlyMemory)) + ] + }; + + private sealed class EnumerablePropsModel + { + [VectorStoreRecordKey] + public string Key { get; set; } = string.Empty; + + [VectorStoreRecordData] + public IEnumerable EnumerableData { get; set; } = new List(); + + [VectorStoreRecordData] + public string[] ArrayData { get; set; } = Array.Empty(); + + [VectorStoreRecordData] + public List ListData { get; set; } = new List(); + + [VectorStoreRecordVector] + public ReadOnlyMemory Vector { get; set; } + + public string NotAnnotated { get; set; } = string.Empty; + } + + private readonly VectorStoreRecordDefinition _enumerablePropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("EnumerableData", typeof(IEnumerable)), + new VectorStoreRecordDataProperty("ArrayData", typeof(string[])), + new VectorStoreRecordDataProperty("ListData", typeof(List)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) + ] + }; + +#pragma warning restore CA1812 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..c70382481fbc --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs @@ -0,0 +1,314 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.UnitTests.Data; + +/// +/// Contains tests for the class. 
+/// +public class VolatileVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + private const string TestRecordKey1 = "testid1"; + private const string TestRecordKey2 = "testid2"; + private const int TestRecordIntKey1 = 1; + private const int TestRecordIntKey2 = 2; + + private readonly CancellationToken _testCancellationToken = new(false); + + private readonly ConcurrentDictionary> _collectionStore; + + public VolatileVectorStoreRecordCollectionTests() + { + this._collectionStore = new(); + } + + [Theory] + [InlineData(TestCollectionName, true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange + var collection = new ConcurrentDictionary(); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = new VolatileVectorStoreRecordCollection>( + this._collectionStore, + collectionName); + + // Act + var actual = await sut.CollectionExistsAsync(this._testCancellationToken); + + // Assert + Assert.Equal(expectedExists, actual); + } + + [Fact] + public async Task CanCreateCollectionAsync() + { + // Arrange + var sut = this.CreateRecordCollection(false); + + // Act + await sut.CreateCollectionAsync(this._testCancellationToken); + + // Assert + Assert.True(this._collectionStore.ContainsKey(TestCollectionName)); + } + + [Fact] + public async Task DeleteCollectionRemovesCollectionFromDictionaryAsync() + { + // Arrange + var collection = new ConcurrentDictionary(); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(false); + + // Act + await sut.DeleteCollectionAsync(this._testCancellationToken); + + // Assert + Assert.Empty(this._collectionStore); + } + + [Theory] + [InlineData(true, TestRecordKey1)] + [InlineData(true, TestRecordIntKey1)] + [InlineData(false, TestRecordKey1)] + [InlineData(false, TestRecordIntKey1)] + public async Task CanGetRecordWithVectorsAsync(bool useDefinition, TKey testKey) + where TKey : notnull + { + // Arrange + var record = CreateModel(testKey, withVectors: true); + var collection = new ConcurrentDictionary(); + collection.TryAdd(testKey!, record); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await sut.GetAsync( + testKey, + new() + { + IncludeVectors = true + }, + this._testCancellationToken); + + // Assert + var expectedArgs = new object[] { TestRecordKey1 }; + + Assert.NotNull(actual); + Assert.Equal(testKey, actual.Key); + Assert.Equal($"data {testKey}", actual.Data); + Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.Vector!.Value.ToArray()); + } + + [Theory] + [InlineData(true, TestRecordKey1, TestRecordKey2)] + [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] + [InlineData(false, TestRecordKey1, TestRecordKey2)] + [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] + public async Task CanGetManyRecordsWithVectorsAsync(bool useDefinition, TKey testKey1, TKey testKey2) + where TKey : notnull + { + // Arrange + var record1 = CreateModel(testKey1, withVectors: true); + var record2 = CreateModel(testKey2, withVectors: true); + var collection = new ConcurrentDictionary(); + collection.TryAdd(testKey1!, record1); + collection.TryAdd(testKey2!, record2); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await 
sut.GetBatchAsync( + [testKey1, testKey2], + new() + { + IncludeVectors = true + }, + this._testCancellationToken).ToListAsync(); + + // Assert + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(testKey1, actual[0].Key); + Assert.Equal($"data {testKey1}", actual[0].Data); + Assert.Equal(testKey2, actual[1].Key); + Assert.Equal($"data {testKey2}", actual[1].Data); + } + + [Theory] + [InlineData(true, TestRecordKey1, TestRecordKey2)] + [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] + [InlineData(false, TestRecordKey1, TestRecordKey2)] + [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] + public async Task CanDeleteRecordAsync(bool useDefinition, TKey testKey1, TKey testKey2) + where TKey : notnull + { + // Arrange + var record1 = CreateModel(testKey1, withVectors: true); + var record2 = CreateModel(testKey2, withVectors: true); + var collection = new ConcurrentDictionary(); + collection.TryAdd(testKey1, record1); + collection.TryAdd(testKey2, record2); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act + await sut.DeleteAsync( + testKey1, + cancellationToken: this._testCancellationToken); + + // Assert + Assert.False(collection.ContainsKey(testKey1)); + Assert.True(collection.ContainsKey(testKey2)); + } + + [Theory] + [InlineData(true, TestRecordKey1, TestRecordKey2)] + [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] + [InlineData(false, TestRecordKey1, TestRecordKey2)] + [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] + public async Task CanDeleteManyRecordsWithVectorsAsync(bool useDefinition, TKey testKey1, TKey testKey2) + where TKey : notnull + { + // Arrange + var record1 = CreateModel(testKey1, withVectors: true); + var record2 = CreateModel(testKey2, withVectors: true); + var collection = new ConcurrentDictionary(); + collection.TryAdd(testKey1, record1); + collection.TryAdd(testKey2, record2); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act + await sut.DeleteBatchAsync( + [testKey1, testKey2], + cancellationToken: this._testCancellationToken); + + // Assert + Assert.False(collection.ContainsKey(testKey1)); + Assert.False(collection.ContainsKey(testKey2)); + } + + [Theory] + [InlineData(true, TestRecordKey1)] + [InlineData(true, TestRecordIntKey1)] + [InlineData(false, TestRecordKey1)] + [InlineData(false, TestRecordIntKey1)] + public async Task CanUpsertRecordAsync(bool useDefinition, TKey testKey1) + where TKey : notnull + { + // Arrange + var record1 = CreateModel(testKey1, withVectors: true); + var collection = new ConcurrentDictionary(); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var upsertResult = await sut.UpsertAsync( + record1, + cancellationToken: this._testCancellationToken); + + // Assert + Assert.Equal(testKey1, upsertResult); + Assert.True(collection.ContainsKey(testKey1)); + Assert.IsType>(collection[testKey1]); + Assert.Equal($"data {testKey1}", (collection[testKey1] as SinglePropsModel)!.Data); + } + + [Theory] + [InlineData(true, TestRecordKey1, TestRecordKey2)] + [InlineData(true, TestRecordIntKey1, TestRecordIntKey2)] + [InlineData(false, TestRecordKey1, TestRecordKey2)] + [InlineData(false, TestRecordIntKey1, TestRecordIntKey2)] + public async Task CanUpsertManyRecordsAsync(bool useDefinition, TKey testKey1, TKey testKey2) + where TKey : notnull + { + 
// Arrange + var record1 = CreateModel(testKey1, withVectors: true); + var record2 = CreateModel(testKey2, withVectors: true); + + var collection = new ConcurrentDictionary(); + this._collectionStore.TryAdd(TestCollectionName, collection); + + var sut = this.CreateRecordCollection(useDefinition); + + // Act + var actual = await sut.UpsertBatchAsync( + [record1, record2], + cancellationToken: this._testCancellationToken).ToListAsync(); + + // Assert + Assert.NotNull(actual); + Assert.Equal(2, actual.Count); + Assert.Equal(testKey1, actual[0]); + Assert.Equal(testKey2, actual[1]); + + Assert.True(collection.ContainsKey(testKey1)); + Assert.IsType>(collection[testKey1]); + Assert.Equal($"data {testKey1}", (collection[testKey1] as SinglePropsModel)!.Data); + } + + private static SinglePropsModel CreateModel(TKey key, bool withVectors) + { + return new SinglePropsModel + { + Key = key, + Data = "data " + key, + Vector = withVectors ? new float[] { 1, 2, 3, 4 } : null, + NotAnnotated = null, + }; + } + + private VolatileVectorStoreRecordCollection> CreateRecordCollection(bool useDefinition) + where TKey : notnull + { + return new VolatileVectorStoreRecordCollection>( + this._collectionStore, + TestCollectionName, + new() + { + VectorStoreRecordDefinition = useDefinition ? this._singlePropsDefinition : null + }); + } + + private readonly VectorStoreRecordDefinition _singlePropsDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("Data", typeof(string)), + new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory)) + ] + }; + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreTests.cs new file mode 100644 index 000000000000..694d2239b224 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreTests.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Data; +using Xunit; + +namespace SemanticKernel.UnitTests.Data; + +/// +/// Contains tests for the class. +/// +public class VolatileVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + [Fact] + public void GetCollectionReturnsCollection() + { + // Arrange. + var sut = new VolatileVectorStore(); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionReturnsCollectionWithNonStringKey() + { + // Arrange. + var sut = new VolatileVectorStore(); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public async Task ListCollectionNamesReadsDictionaryAsync() + { + // Arrange. + var collectionStore = new ConcurrentDictionary>(); + collectionStore.TryAdd("collection1", new ConcurrentDictionary()); + collectionStore.TryAdd("collection2", new ConcurrentDictionary()); + var sut = new VolatileVectorStore(collectionStore); + + // Act. 
+        var collectionNames = sut.ListCollectionNamesAsync();
+
+        // Assert.
+        var collectionNamesList = await collectionNames.ToListAsync();
+        Assert.Equal(new[] { "collection1", "collection2" }, collectionNamesList);
+    }
+
+    public sealed class SinglePropsModel<TKey>
+    {
+        [VectorStoreRecordKey]
+        public required TKey Key { get; set; }
+
+        [VectorStoreRecordData]
+        public string Data { get; set; } = string.Empty;
+
+        [VectorStoreRecordVector(4)]
+        public ReadOnlyMemory<float>? Vector { get; set; }
+
+        public string? NotAnnotated { get; set; }
+    }
+}
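To make the intended developer experience concrete, the following is a minimal usage sketch of the volatile (in-memory) store, limited to the API surface exercised by the unit tests above. The `Hotel` model, the `skhotels` collection name and the 4-dimensional embedding are illustrative assumptions and not part of this change.

```cs
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Data;

// Hypothetical record model, annotated the same way as the test models above.
public sealed class Hotel
{
    [VectorStoreRecordKey]
    public string HotelId { get; set; } = string.Empty;

    [VectorStoreRecordData(IsFilterable = true)]
    public string Description { get; set; } = string.Empty;

    [VectorStoreRecordVector(4)]
    public ReadOnlyMemory<float>? DescriptionEmbedding { get; set; }
}

public static class VolatileVectorStoreUsageSketch
{
    public static async Task RunAsync()
    {
        // Create the in-memory store and resolve a typed collection for the Hotel model.
        var vectorStore = new VolatileVectorStore();
        var collection = vectorStore.GetCollection<string, Hotel>("skhotels");

        // Create the backing collection and upsert a record; UpsertAsync returns the record key.
        await collection.CreateCollectionAsync();
        var key = await collection.UpsertAsync(new Hotel
        {
            HotelId = "hotel-1",
            Description = "A mid-range hotel with a rooftop bar.",
            DescriptionEmbedding = new float[] { 1f, 2f, 3f, 4f }
        });

        // Read the record back, including its vector.
        var hotel = await collection.GetAsync(key, new() { IncludeVectors = true });
        Console.WriteLine(hotel?.Description);
    }
}
```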