diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb index 581f5cef64..4a89d95305 100644 --- a/docs/examples_notebooks/index_migration_to_v1.ipynb +++ b/docs/examples_notebooks/index_migration_to_v1.ipynb @@ -204,14 +204,13 @@ "source": [ "from graphrag.cache.factory import CacheFactory\n", "from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n", - "from graphrag.config.get_vector_store_settings import get_vector_store_settings\n", "from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n", + "from graphrag.language_model.manager import ModelManager\n", + "from graphrag.tokenizer.get_tokenizer import get_tokenizer\n", "\n", "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n", "# We'll construct the context and run this function flow directly to avoid everything else\n", "\n", - "\n", - "vector_store_config = get_vector_store_settings(config)\n", "model_config = config.get_language_model_config(config.embed_text.model_id)\n", "callbacks = NoopWorkflowCallbacks()\n", "cache_config = config.cache.model_dump() # type: ignore\n", @@ -219,6 +218,15 @@ " cache_type=cache_config[\"type\"], # type: ignore\n", " **cache_config,\n", ")\n", + "model = ModelManager().get_or_create_embedding_model(\n", + " name=\"text_embedding\",\n", + " model_type=model_config.type,\n", + " config=model_config,\n", + " callbacks=callbacks,\n", + " cache=cache,\n", + ")\n", + "\n", + "tokenizer = get_tokenizer(model_config)\n", "\n", "await generate_text_embeddings(\n", " documents=None,\n", @@ -227,11 +235,12 @@ " entities=final_entities,\n", " community_reports=final_community_reports,\n", " callbacks=callbacks,\n", - " cache=cache,\n", - " model_config=model_config,\n", + " model=model,\n", + " tokenizer=tokenizer,\n", " batch_size=config.embed_text.batch_size,\n", " batch_max_tokens=config.embed_text.batch_max_tokens,\n", - " vector_store_config=vector_store_config,\n", + " num_threads=model_config.concurrent_requests,\n", + " vector_store_config=config.vector_store,\n", " embedded_fields=config.embed_text.names,\n", ")" ] diff --git a/docs/examples_notebooks/multi_index_search.ipynb b/docs/examples_notebooks/multi_index_search.ipynb deleted file mode 100644 index 2e70ed5086..0000000000 --- a/docs/examples_notebooks/multi_index_search.ipynb +++ /dev/null @@ -1,558 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) 2024 Microsoft Corporation.\n", - "# Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Multi Index Search\n", - "This notebook demonstrates multi-index search using the GraphRAG API.\n", - "\n", - "Indexes created from Wikipedia state articles for Alaska, California, DC, Maryland, NY and Washington are used." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import asyncio\n", - "\n", - "import pandas as pd\n", - "\n", - "from graphrag.api.query import (\n", - " multi_index_basic_search,\n", - " multi_index_drift_search,\n", - " multi_index_global_search,\n", - " multi_index_local_search,\n", - ")\n", - "from graphrag.config.create_graphrag_config import create_graphrag_config\n", - "\n", - "indexes = [\"alaska\", \"california\", \"dc\", \"maryland\", \"ny\", \"washington\"]\n", - "indexes = sorted(indexes)\n", - "\n", - "print(indexes)\n", - "\n", - "vector_store_configs = {\n", - " index: {\n", - " \"type\": \"lancedb\",\n", - " \"db_uri\": f\"inputs/{index}/lancedb\",\n", - " \"container_name\": \"default\",\n", - " \"overwrite\": True,\n", - " \"index_name\": f\"{index}\",\n", - " }\n", - " for index in indexes\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config_data = {\n", - " \"models\": {\n", - " \"default_chat_model\": {\n", - " \"model_supports_json\": True,\n", - " \"parallelization_num_threads\": 50,\n", - " \"parallelization_stagger\": 0.3,\n", - " \"async_mode\": \"threaded\",\n", - " \"type\": \"azure_openai_chat\",\n", - " \"model\": \"gpt-4o\",\n", - " \"auth_type\": \"azure_managed_identity\",\n", - " \"api_base\": \"\",\n", - " \"api_version\": \"2024-02-15-preview\",\n", - " \"deployment_name\": \"gpt-4o\",\n", - " },\n", - " \"default_embedding_model\": {\n", - " \"parallelization_num_threads\": 50,\n", - " \"parallelization_stagger\": 0.3,\n", - " \"async_mode\": \"threaded\",\n", - " \"type\": \"azure_openai_embedding\",\n", - " \"model\": \"text-embedding-3-large\",\n", - " \"auth_type\": \"azure_managed_identity\",\n", - " \"api_base\": \"\",\n", - " \"api_version\": \"2024-02-15-preview\",\n", - " \"deployment_name\": \"text-embedding-3-large\",\n", - " },\n", - " },\n", - " \"vector_store\": vector_store_configs,\n", - " \"local_search\": {\n", - " \"prompt\": \"prompts/local_search_system_prompt.txt\",\n", - " \"llm_max_tokens\": 12000,\n", - " },\n", - " \"global_search\": {\n", - " \"map_prompt\": \"prompts/global_search_map_system_prompt.txt\",\n", - " \"reduce_prompt\": \"prompts/global_search_reduce_system_prompt.txt\",\n", - " \"knowledge_prompt\": \"prompts/global_search_knowledge_system_prompt.txt\",\n", - " },\n", - " \"drift_search\": {\n", - " \"prompt\": \"prompts/drift_search_system_prompt.txt\",\n", - " \"reduce_prompt\": \"prompts/drift_search_reduce_prompt.txt\",\n", - " },\n", - " \"basic_search\": {\"prompt\": \"prompts/basic_search_system_prompt.txt\"},\n", - "}\n", - "parameters = create_graphrag_config(config_data, \".\")\n", - "loop = asyncio.get_event_loop()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multi-index Global Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n", - "communities = [\n", - " pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n", - "]\n", - "community_reports = [\n", - " pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_global_search(\n", - " parameters,\n", - " entities,\n", - " communities,\n", - " community_reports,\n", - " indexes,\n", - " 1,\n", - " 
False,\n", - " \"Multiple Paragraphs\",\n", - " False,\n", - " \"Describe this dataset.\",\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:\n", - " index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(report_id, index_name, index_id)\n", - " index_reports = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_community_reports.parquet\"\n", - " )\n", - " print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n", - " print(\n", - " index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n", - " 0\n", - " ]\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Multi-index Local Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n", - "communities = [\n", - " pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n", - "]\n", - "community_reports = [\n", - " pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n", - "]\n", - "covariates = [\n", - " pd.read_parquet(f\"inputs/{index}/covariates.parquet\") for index in indexes\n", - "]\n", - "text_units = [\n", - " pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n", - "]\n", - "relationships = [\n", - " pd.read_parquet(f\"inputs/{index}/relationships.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_local_search(\n", - " parameters,\n", - " entities,\n", - " communities,\n", - " community_reports,\n", - " text_units,\n", - " relationships,\n", - " covariates,\n", - " indexes,\n", - " 1,\n", - " \"Multiple Paragraphs\",\n", - " False,\n", - " \"weather\",\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for report_id in [47, 213]:\n", - " index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(report_id, index_name, index_id)\n", - " index_reports = pd.read_parquet(\n", - " 
f\"inputs/{index_name}/create_final_community_reports.parquet\"\n", - " )\n", - " print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n", - " print(\n", - " index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n", - " 0\n", - " ]\n", - " )\n", - "for entity_id in [500, 502, 506, 1960, 1961, 1962]:\n", - " index_name = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(entity_id, index_name, index_id)\n", - " index_entities = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_entities.parquet\"\n", - " )\n", - " print(\n", - " [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n", - " \"description\"\n", - " ][:100]\n", - " )\n", - " print(\n", - " index_entities[index_entities[\"human_readable_id\"] == int(index_id)][\n", - " \"description\"\n", - " ].to_numpy()[0][:100]\n", - " )\n", - "for relationship_id in [1805, 1806]:\n", - " index_name = [ # noqa: RUF015\n", - " i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n", - " ][0][\"index_name\"]\n", - " index_id = [ # noqa: RUF015\n", - " i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n", - " ][0][\"index_id\"]\n", - " print(relationship_id, index_name, index_id)\n", - " index_relationships = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_relationships.parquet\"\n", - " )\n", - " print(\n", - " [i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)][0][ # noqa: RUF015\n", - " \"description\"\n", - " ]\n", - " )\n", - " print(\n", - " index_relationships[index_relationships[\"human_readable_id\"] == int(index_id)][\n", - " \"description\"\n", - " ].to_numpy()[0]\n", - " )\n", - "for claim_id in [100]:\n", - " index_name = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(relationship_id, index_name, index_id)\n", - " index_claims = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_covariates.parquet\"\n", - " )\n", - " print(\n", - " [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][\"description\"] # noqa: RUF015\n", - " )\n", - " print(\n", - " index_claims[index_claims[\"human_readable_id\"] == int(index_id)][\n", - " \"description\"\n", - " ].to_numpy()[0]\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multi-index Drift Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n", - "communities = [\n", - " pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n", - "]\n", - "community_reports = [\n", - " pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n", - "]\n", - "text_units = [\n", - " pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n", - "]\n", - "relationships = [\n", - " pd.read_parquet(f\"inputs/{index}/relationships.parquet\") for index in indexes\n", - "]\n", - 
"\n", - "task = loop.create_task(\n", - " multi_index_drift_search(\n", - " parameters,\n", - " entities,\n", - " communities,\n", - " community_reports,\n", - " text_units,\n", - " relationships,\n", - " indexes,\n", - " 1,\n", - " \"Multiple Paragraphs\",\n", - " False,\n", - " \"agriculture\",\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for report_id in [47, 236]:\n", - " for question in results[1]:\n", - " resq = results[1][question]\n", - " if len(resq[\"reports\"]) == 0:\n", - " continue\n", - " if len([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)]) == 0:\n", - " continue\n", - " index_name = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(question, report_id, index_name, index_id)\n", - " index_reports = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_community_reports.parquet\"\n", - " )\n", - " print([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n", - " print(\n", - " index_reports[index_reports[\"community\"] == int(index_id)][\n", - " \"title\"\n", - " ].to_numpy()[0]\n", - " )\n", - " break\n", - "for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:\n", - " for question in results[1]:\n", - " resq = results[1][question]\n", - " if len(resq[\"sources\"]) == 0:\n", - " continue\n", - " if len([i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)]) == 0:\n", - " continue\n", - " index_name = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(question, source_id, index_name, index_id)\n", - " index_sources = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_text_units.parquet\"\n", - " )\n", - " print(\n", - " [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][\"text\"][:250] # noqa: RUF015\n", - " )\n", - " print(index_sources.loc[int(index_id)][\"text\"][:250])\n", - " break" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multi-index Basic Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "text_units = [\n", - " pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_basic_search(\n", - " parameters, text_units, indexes, False, \"industry in maryland\"\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show 
context links back to original text\n", - "\n", - "Note that original index name is not saved in context data for basic search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for source_id in [0, 1]:\n", - " print(results[1][\"sources\"][source_id][\"text\"][:250])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/query/multi_index_search.md b/docs/query/multi_index_search.md deleted file mode 100644 index 6b6ff2b41a..0000000000 --- a/docs/query/multi_index_search.md +++ /dev/null @@ -1,20 +0,0 @@ -# Multi Index Search 🔎 - -## Multi Dataset Reasoning - -GraphRAG takes in unstructured data contained in text documents and uses large language models to “read” the documents in a targeted fashion and create a knowledge graph. This knowledge graph, or index, contains information about specific entities in the data, how the entities relate to one another, and high-level reports about communities and topics found in the data. Indexes can be searched by users to get meaningful information about the underlying data, including reports with citations that point back to the original unstructured text. - -Multi-index search is a new capability that has been added to the GraphRAG python library to query multiple knowledge stores at once. Multi-index search allows for many new search scenarios, including: - -- Combining knowledge from different domains – Many documents contain similar types of entities: person, place, thing. But GraphRAG can be tuned for highly specialized domains, such as science and engineering. With the recent updates to search, GraphRAG can now simultaneously query multiple datasets with completely different schemas and entity definitions. - -- Combining knowledge with different access levels – Not all datasets are accessible to all people, even within an organization. Some datasets are publicly available. Some datasets, such as internal financial information or intellectual property, may only be accessible by a small number of employees at a company. Multi-index search allows multiple sources with different access controls to be queried at the same time, creating more nuanced and informative reports. Internal R&D findings can be seamlessly combined with open-source scientific publications. - -- Combining knowledge in different locations – With multi-index search, indexes do not need to be in the same location or type of storage to be queried. Indexes in the cloud in Azure Storage can be queried at the same time as indexes stored on a personal computer. Multi-index search makes these types of data joins easy and accessible. - -To search across multiple datasets, the underlying contexts from each index, based on the user query, are combined in-memory at query time, saving on computation and allowing the joint querying of indexes that can’t be joined inherently, either due to access controls or differing schemas. Multi-index search automatically keeps track of provenance information, so that any references can be traced back to the correct indexes and correct original documents.
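The deleted implementations later in this diff show how that in-memory combination kept provenance: each index's locally-numbered ids are shifted by a running offset, and a `links` table records which index and original id every shifted id came from. Below is a minimal sketch of that bookkeeping, using hypothetical two-index report data; the real code applied the same pattern to entities, relationships, text units, and covariates.

```python
import pandas as pd

# Hypothetical community reports from two indexes, each numbered from 0 locally.
reports_a = pd.DataFrame({"community": [0, 1], "title": ["A-0", "A-1"]})
reports_b = pd.DataFrame({"community": [0, 1], "title": ["B-0", "B-1"]})

links: dict[int, dict[str, str]] = {}  # merged id -> provenance
offset = 0
merged = []
for index_name, df in [("alaska", reports_a), ("california", reports_b)]:
    df = df.copy()
    # Record provenance before shifting: merged id maps back to (index, local id).
    for i in df["community"]:
        links[i + offset] = {"index_name": index_name, "id": str(i)}
    df["community"] += offset
    offset = int(df["community"].max()) + 1
    merged.append(df)

combined = pd.concat(merged, ignore_index=True)
print(links[2])  # {'index_name': 'california', 'id': '0'}
```

Any citation produced against `combined` can then be resolved back to its source index through `links`, which is what the removed `update_context_data` helper did for the `multi_index_*` APIs.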
- - -## How to Use - -An example of a global search scenario can be found in the following [notebook](../examples_notebooks/multi_index_search.ipynb). \ No newline at end of file diff --git a/graphrag/api/__init__.py b/graphrag/api/__init__.py index a3a06033bc..05692c4182 100644 --- a/graphrag/api/__init__.py +++ b/graphrag/api/__init__.py @@ -18,10 +18,6 @@ global_search_streaming, local_search, local_search_streaming, - multi_index_basic_search, - multi_index_drift_search, - multi_index_global_search, - multi_index_local_search, ) from graphrag.prompt_tune.types import DocSelectionType @@ -37,10 +33,6 @@ "drift_search_streaming", "basic_search", "basic_search_streaming", - "multi_index_basic_search", - "multi_index_drift_search", - "multi_index_global_search", - "multi_index_local_search", # prompt tuning API "DocSelectionType", "generate_indexing_prompts", diff --git a/graphrag/api/query.py b/graphrag/api/query.py index b92cd8145b..e49a0976df 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -52,7 +52,6 @@ get_embedding_store, load_search_prompt, truncate, - update_context_data, ) from graphrag.utils.cli import redact @@ -192,152 +191,6 @@ def global_search_streaming( return search_engine.stream_search(query=query) -@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_global_search( - config: GraphRagConfig, - entities_list: list[pd.DataFrame], - communities_list: list[pd.DataFrame], - community_reports_list: list[pd.DataFrame], - index_names: list[str], - community_level: int | None, - dynamic_community_selection: bool, - response_type: str, - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a global search across multiple indexes and return the context data and response. - - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet) - - communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet) - - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet) - - index_names (list[str]): A list of index names. - - community_level (int): The community level to search at. - - dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search. - - response_type (str): The type of response to return. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3."
- ) - - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_global_search" - raise NotImplementedError(message) - - links = { - "communities": {}, - "community_reports": {}, - "entities": {}, - } - max_vals = { - "communities": -1, - "community_reports": -1, - "entities": -1, - } - - communities_dfs = [] - community_reports_dfs = [] - entities_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's community reports dataframe for merging - community_reports_df = community_reports_list[idx] - community_reports_df["community"] = community_reports_df["community"].astype( - int - ) - for i in community_reports_df["community"]: - links["community_reports"][i + max_vals["community_reports"] + 1] = { - "index_name": index_name, - "id": str(i), - } - community_reports_df["community"] += max_vals["community_reports"] + 1 - community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 - max_vals["community_reports"] = int(community_reports_df["community"].max()) - community_reports_dfs.append(community_reports_df) - - # Prepare each index's communities dataframe for merging - communities_df = communities_list[idx] - communities_df["community"] = communities_df["community"].astype(int) - communities_df["parent"] = communities_df["parent"].astype(int) - for i in communities_df["community"]: - links["communities"][i + max_vals["communities"] + 1] = { - "index_name": index_name, - "id": str(i), - } - communities_df["community"] += max_vals["communities"] + 1 - communities_df["parent"] = communities_df["parent"].apply( - lambda x: x if x == -1 else x + max_vals["communities"] + 1 - ) - communities_df["human_readable_id"] += max_vals["communities"] + 1 - # concat the index name to the entity_ids, since this is used for joining later - communities_df["entity_ids"] = communities_df["entity_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["communities"] = int(communities_df["community"].max()) - communities_dfs.append(communities_df) - - # Prepare each index's entities dataframe for merging - entities_df = entities_list[idx] - for i in entities_df["human_readable_id"]: - links["entities"][i + max_vals["entities"] + 1] = { - "index_name": index_name, - "id": i, - } - entities_df["human_readable_id"] += max_vals["entities"] + 1 - entities_df["title"] = entities_df["title"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["entities"] = int(entities_df["human_readable_id"].max()) - entities_dfs.append(entities_df) - - # Merge the dataframes - community_reports_combined = pd.concat( - community_reports_dfs, axis=0, ignore_index=True, sort=False - ) - entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) - communities_combined = pd.concat( - communities_dfs, axis=0, ignore_index=True, sort=False - ) - - logger.debug("Executing multi-index global search query: %s", query) - result = await global_search( - config, - entities=entities_combined, - communities=communities_combined, - community_reports=community_reports_combined, - community_level=community_level, - dynamic_community_selection=dynamic_community_selection, - response_type=response_type, - query=query, - callbacks=callbacks, - ) - - # Update the context data by linking index names and community ids - context = 
update_context_data(result[1], links) - - logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore - return (result[0], context) - - @validate_call(config={"arbitrary_types_allowed": True}) async def local_search( config: GraphRagConfig, @@ -441,14 +294,11 @@ def local_search_streaming( """ init_loggers(config=config, verbose=verbose, filename="query.log") - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - msg = f"Vector Store Args: {redact(vector_store_args)}" + msg = f"Vector Store Args: {redact(config.vector_store.model_dump())}" logger.debug(msg) description_embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=entity_description_embedding, ) @@ -472,238 +322,6 @@ def local_search_streaming( return search_engine.stream_search(query=query) -@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_local_search( - config: GraphRagConfig, - entities_list: list[pd.DataFrame], - communities_list: list[pd.DataFrame], - community_reports_list: list[pd.DataFrame], - text_units_list: list[pd.DataFrame], - relationships_list: list[pd.DataFrame], - covariates_list: list[pd.DataFrame] | None, - index_names: list[str], - community_level: int, - response_type: str, - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a local search across multiple indexes and return the context data and response. - - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet) - - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet) - - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet) - - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet) - - covariates_list (list[pd.DataFrame]): [Optional] A list of DataFrames containing the final covariates (from covariates.parquet) - - index_names (list[str]): A list of index names. - - community_level (int): The community level to search at. - - response_type (str): The response type to return. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." 
- ) - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_index_local_search" - raise NotImplementedError(message) - - links = { - "community_reports": {}, - "communities": {}, - "entities": {}, - "text_units": {}, - "relationships": {}, - "covariates": {}, - } - max_vals = { - "community_reports": -1, - "communities": -1, - "entities": -1, - "text_units": 0, - "relationships": -1, - "covariates": 0, - } - community_reports_dfs = [] - communities_dfs = [] - entities_dfs = [] - relationships_dfs = [] - text_units_dfs = [] - covariates_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's communities dataframe for merging - communities_df = communities_list[idx] - communities_df["community"] = communities_df["community"].astype(int) - for i in communities_df["community"]: - links["communities"][i + max_vals["communities"] + 1] = { - "index_name": index_name, - "id": str(i), - } - communities_df["community"] += max_vals["communities"] + 1 - communities_df["human_readable_id"] += max_vals["communities"] + 1 - # concat the index name to the entity_ids, since this is used for joining later - communities_df["entity_ids"] = communities_df["entity_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["communities"] = int(communities_df["community"].max()) - communities_dfs.append(communities_df) - - # Prepare each index's community reports dataframe for merging - community_reports_df = community_reports_list[idx] - community_reports_df["community"] = community_reports_df["community"].astype( - int - ) - for i in community_reports_df["community"]: - links["community_reports"][i + max_vals["community_reports"] + 1] = { - "index_name": index_name, - "id": str(i), - } - community_reports_df["community"] += max_vals["community_reports"] + 1 - community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 - max_vals["community_reports"] = int(community_reports_df["community"].max()) - community_reports_dfs.append(community_reports_df) - - # Prepare each index's entities dataframe for merging - entities_df = entities_list[idx] - for i in entities_df["human_readable_id"]: - links["entities"][i + max_vals["entities"] + 1] = { - "index_name": index_name, - "id": i, - } - entities_df["human_readable_id"] += max_vals["entities"] + 1 - entities_df["title"] = entities_df["title"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["id"] = entities_df["id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["entities"] = int(entities_df["human_readable_id"].max()) - entities_dfs.append(entities_df) - - # Prepare each index's relationships dataframe for merging - relationships_df = relationships_list[idx] - for i in relationships_df["human_readable_id"].astype(int): - links["relationships"][i + max_vals["relationships"] + 1] = { - "index_name": index_name, - "id": i, - } - if max_vals["relationships"] != -1: - col = ( - relationships_df["human_readable_id"].astype(int) - + max_vals["relationships"] - + 1 - ) - relationships_df["human_readable_id"] = col.astype(str) - relationships_df["source"] = relationships_df["source"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["target"] = relationships_df["target"].apply( - lambda x, index_name=index_name: 
x + f"-{index_name}" - ) - relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["relationships"] = int(relationships_df["human_readable_id"].max()) - relationships_dfs.append(relationships_df) - - # Prepare each index's text units dataframe for merging - text_units_df = text_units_list[idx] - for i in range(text_units_df.shape[0]): - links["text_units"][i + max_vals["text_units"]] = { - "index_name": index_name, - "id": i, - } - text_units_df["id"] = text_units_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - text_units_df["human_readable_id"] = ( - text_units_df["human_readable_id"] + max_vals["text_units"] - ) - max_vals["text_units"] += text_units_df.shape[0] - text_units_dfs.append(text_units_df) - - # If presents, prepare each index's covariates dataframe for merging - if covariates_list is not None: - covariates_df = covariates_list[idx] - for i in covariates_df["human_readable_id"].astype(int): - links["covariates"][i + max_vals["covariates"]] = { - "index_name": index_name, - "id": i, - } - covariates_df["id"] = covariates_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - covariates_df["human_readable_id"] = ( - covariates_df["human_readable_id"] + max_vals["covariates"] - ) - covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - covariates_df["subject_id"] = covariates_df["subject_id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - max_vals["covariates"] += covariates_df.shape[0] - covariates_dfs.append(covariates_df) - - # Merge the dataframes - communities_combined = pd.concat( - communities_dfs, axis=0, ignore_index=True, sort=False - ) - community_reports_combined = pd.concat( - community_reports_dfs, axis=0, ignore_index=True, sort=False - ) - entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) - relationships_combined = pd.concat( - relationships_dfs, axis=0, ignore_index=True, sort=False - ) - text_units_combined = pd.concat( - text_units_dfs, axis=0, ignore_index=True, sort=False - ) - covariates_combined = None - if len(covariates_dfs) > 0: - covariates_combined = pd.concat( - covariates_dfs, axis=0, ignore_index=True, sort=False - ) - logger.debug("Executing multi-index local search query: %s", query) - result = await local_search( - config, - entities=entities_combined, - communities=communities_combined, - community_reports=community_reports_combined, - text_units=text_units_combined, - relationships=relationships_combined, - covariates=covariates_combined, - community_level=community_level, - response_type=response_type, - query=query, - callbacks=callbacks, - ) - - # Update the context data by linking index names and community ids - context = update_context_data(result[1], links) - - logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore - return (result[0], context) - - @validate_call(config={"arbitrary_types_allowed": True}) async def drift_search( config: GraphRagConfig, @@ -801,19 +419,16 @@ def drift_search_streaming( """ init_loggers(config=config, verbose=verbose, filename="query.log") - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - msg = f"Vector Store Args: {redact(vector_store_args)}" + msg = f"Vector Store Args: {redact(config.vector_store.model_dump())}" 
logger.debug(msg) description_embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=entity_description_embedding, ) full_content_embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=community_full_content_embedding, ) @@ -841,222 +456,11 @@ def drift_search_streaming( return search_engine.stream_search(query=query) -@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_drift_search( - config: GraphRagConfig, - entities_list: list[pd.DataFrame], - communities_list: list[pd.DataFrame], - community_reports_list: list[pd.DataFrame], - text_units_list: list[pd.DataFrame], - relationships_list: list[pd.DataFrame], - index_names: list[str], - community_level: int, - response_type: str, - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a DRIFT search across multiple indexes and return the context data and response. - - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet) - - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet) - - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet) - - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet) - - index_names (list[str]): A list of index names. - - community_level (int): The community level to search at. - - response_type (str): The response type to return. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." 
- ) - - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_drift_search" - raise NotImplementedError(message) - - links = { - "community_reports": {}, - "communities": {}, - "entities": {}, - "text_units": {}, - "relationships": {}, - } - max_vals = { - "community_reports": -1, - "communities": -1, - "entities": -1, - "text_units": 0, - "relationships": -1, - } - - communities_dfs = [] - community_reports_dfs = [] - entities_dfs = [] - relationships_dfs = [] - text_units_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's communities dataframe for merging - communities_df = communities_list[idx] - communities_df["community"] = communities_df["community"].astype(int) - for i in communities_df["community"]: - links["communities"][i + max_vals["communities"] + 1] = { - "index_name": index_name, - "id": str(i), - } - communities_df["community"] += max_vals["communities"] + 1 - communities_df["human_readable_id"] += max_vals["communities"] + 1 - # concat the index name to the entity_ids, since this is used for joining later - communities_df["entity_ids"] = communities_df["entity_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["communities"] = int(communities_df["community"].max()) - communities_dfs.append(communities_df) - - # Prepare each index's community reports dataframe for merging - community_reports_df = community_reports_list[idx] - community_reports_df["community"] = community_reports_df["community"].astype( - int - ) - for i in community_reports_df["community"]: - links["community_reports"][i + max_vals["community_reports"] + 1] = { - "index_name": index_name, - "id": str(i), - } - community_reports_df["community"] += max_vals["community_reports"] + 1 - community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 - community_reports_df["id"] = community_reports_df["id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - max_vals["community_reports"] = int(community_reports_df["community"].max()) - community_reports_dfs.append(community_reports_df) - - # Prepare each index's entities dataframe for merging - entities_df = entities_list[idx] - for i in entities_df["human_readable_id"]: - links["entities"][i + max_vals["entities"] + 1] = { - "index_name": index_name, - "id": i, - } - entities_df["human_readable_id"] += max_vals["entities"] + 1 - entities_df["title"] = entities_df["title"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["id"] = entities_df["id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["entities"] = int(entities_df["human_readable_id"].max()) - entities_dfs.append(entities_df) - - # Prepare each index's relationships dataframe for merging - relationships_df = relationships_list[idx] - for i in relationships_df["human_readable_id"].astype(int): - links["relationships"][i + max_vals["relationships"] + 1] = { - "index_name": index_name, - "id": i, - } - if max_vals["relationships"] != -1: - col = ( - relationships_df["human_readable_id"].astype(int) - + max_vals["relationships"] - + 1 - ) - relationships_df["human_readable_id"] = col.astype(str) - relationships_df["source"] = relationships_df["source"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["target"] = 
relationships_df["target"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["relationships"] = int( - relationships_df["human_readable_id"].astype(int).max() - ) - - relationships_dfs.append(relationships_df) - - # Prepare each index's text units dataframe for merging - text_units_df = text_units_list[idx] - for i in range(text_units_df.shape[0]): - links["text_units"][i + max_vals["text_units"]] = { - "index_name": index_name, - "id": i, - } - text_units_df["id"] = text_units_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - text_units_df["human_readable_id"] = ( - text_units_df["human_readable_id"] + max_vals["text_units"] - ) - max_vals["text_units"] += text_units_df.shape[0] - text_units_dfs.append(text_units_df) - - # Merge the dataframes - communities_combined = pd.concat( - communities_dfs, axis=0, ignore_index=True, sort=False - ) - community_reports_combined = pd.concat( - community_reports_dfs, axis=0, ignore_index=True, sort=False - ) - entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) - relationships_combined = pd.concat( - relationships_dfs, axis=0, ignore_index=True, sort=False - ) - text_units_combined = pd.concat( - text_units_dfs, axis=0, ignore_index=True, sort=False - ) - - logger.debug("Executing multi-index drift search query: %s", query) - result = await drift_search( - config, - entities=entities_combined, - communities=communities_combined, - community_reports=community_reports_combined, - text_units=text_units_combined, - relationships=relationships_combined, - community_level=community_level, - response_type=response_type, - query=query, - callbacks=callbacks, - ) - - # Update the context data by linking index names and community ids - context = {} - if type(result[1]) is dict: - for key in result[1]: - context[key] = update_context_data(result[1][key], links) - else: - context = result[1] - - logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore - return (result[0], context) - - @validate_call(config={"arbitrary_types_allowed": True}) async def basic_search( config: GraphRagConfig, text_units: pd.DataFrame, + response_type: str, query: str, callbacks: list[QueryCallbacks] | None = None, verbose: bool = False, @@ -1094,6 +498,7 @@ def on_context(context: Any) -> None: async for chunk in basic_search_streaming( config=config, text_units=text_units, + response_type=response_type, query=query, callbacks=callbacks, ): @@ -1106,6 +511,7 @@ def on_context(context: Any) -> None: def basic_search_streaming( config: GraphRagConfig, text_units: pd.DataFrame, + response_type: str, query: str, callbacks: list[QueryCallbacks] | None = None, verbose: bool = False, @@ -1124,14 +530,11 @@ def basic_search_streaming( """ init_loggers(config=config, verbose=verbose, filename="query.log") - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - msg = f"Vector Store Args: {redact(vector_store_args)}" + msg = f"Vector Store Args: {redact(config.vector_store.model_dump())}" logger.debug(msg) embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=text_unit_text_embedding, ) @@ -1142,85 +545,8 @@ def basic_search_streaming( config=config, text_units=read_indexer_text_units(text_units), 
text_unit_embeddings=embedding_store, + response_type=response_type, system_prompt=prompt, callbacks=callbacks, ) return search_engine.stream_search(query=query) - - -@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_basic_search( - config: GraphRagConfig, - text_units_list: list[pd.DataFrame], - index_names: list[str], - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a basic search across multiple indexes and return the context data and response. - - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet) - - index_names (list[str]): A list of index names. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." - ) - - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_basic_search" - raise NotImplementedError(message) - - links = { - "text_units": {}, - } - max_vals = { - "text_units": 0, - } - - text_units_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's text units dataframe for merging - text_units_df = text_units_list[idx] - for i in range(text_units_df.shape[0]): - links["text_units"][i + max_vals["text_units"]] = { - "index_name": index_name, - "id": i, - } - text_units_df["id"] = text_units_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - text_units_df["human_readable_id"] = ( - text_units_df["human_readable_id"] + max_vals["text_units"] - ) - max_vals["text_units"] += text_units_df.shape[0] - text_units_dfs.append(text_units_df) - - # Merge the dataframes - text_units_combined = pd.concat( - text_units_dfs, axis=0, ignore_index=True, sort=False - ) - - logger.debug("Executing multi-index basic search query: %s", query) - return await basic_search( - config, - text_units=text_units_combined, - query=query, - callbacks=callbacks, - ) diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index bc0e9f39ac..7b4d3bb0cd 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -537,6 +537,7 @@ def _query_cli( config_filepath=config, data_dir=data, root_dir=root, + response_type=response_type, streaming=streaming, query=query, verbose=verbose, diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index 0075734859..914196b865 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -52,35 +52,9 @@ def run_global_search( optional_list=[], ) - # Call the Multi-Index Global Search API - if dataframe_dict["multi-index"]: - final_entities_list = dataframe_dict["entities"] - final_communities_list = dataframe_dict["communities"] - final_community_reports_list = dataframe_dict["community_reports"] - index_names = dataframe_dict["index_names"] - - response, context_data = asyncio.run( - api.multi_index_global_search( - config=config, - entities_list=final_entities_list, - communities_list=final_communities_list, - community_reports_list=final_community_reports_list, - index_names=index_names, - 
community_level=community_level, - dynamic_community_selection=dynamic_community_selection, - response_type=response_type, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - return response, context_data - - # Otherwise, call the Single-Index Global Search API - final_entities: pd.DataFrame = dataframe_dict["entities"] - final_communities: pd.DataFrame = dataframe_dict["communities"] - final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] + entities: pd.DataFrame = dataframe_dict["entities"] + communities: pd.DataFrame = dataframe_dict["communities"] + community_reports: pd.DataFrame = dataframe_dict["community_reports"] if streaming: @@ -97,9 +71,9 @@ def on_context(context: Any) -> None: async for stream_chunk in api.global_search_streaming( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, + entities=entities, + communities=communities, + community_reports=community_reports, community_level=community_level, dynamic_community_selection=dynamic_community_selection, response_type=response_type, @@ -118,9 +92,9 @@ def on_context(context: Any) -> None: response, context_data = asyncio.run( api.global_search( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, + entities=entities, + communities=communities, + community_reports=community_reports, community_level=community_level, dynamic_community_selection=dynamic_community_selection, response_type=response_type, @@ -166,49 +140,13 @@ def run_local_search( "covariates", ], ) - # Call the Multi-Index Local Search API - if dataframe_dict["multi-index"]: - final_entities_list = dataframe_dict["entities"] - final_communities_list = dataframe_dict["communities"] - final_community_reports_list = dataframe_dict["community_reports"] - final_text_units_list = dataframe_dict["text_units"] - final_relationships_list = dataframe_dict["relationships"] - index_names = dataframe_dict["index_names"] - - # If any covariates tables are missing from any index, set the covariates list to None - if len(dataframe_dict["covariates"]) != dataframe_dict["num_indexes"]: - final_covariates_list = None - else: - final_covariates_list = dataframe_dict["covariates"] - - response, context_data = asyncio.run( - api.multi_index_local_search( - config=config, - entities_list=final_entities_list, - communities_list=final_communities_list, - community_reports_list=final_community_reports_list, - text_units_list=final_text_units_list, - relationships_list=final_relationships_list, - covariates_list=final_covariates_list, - index_names=index_names, - community_level=community_level, - response_type=response_type, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - return response, context_data - - # Otherwise, call the Single-Index Local Search API - final_communities: pd.DataFrame = dataframe_dict["communities"] - final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] - final_text_units: pd.DataFrame = dataframe_dict["text_units"] - final_relationships: pd.DataFrame = dataframe_dict["relationships"] - final_entities: pd.DataFrame = dataframe_dict["entities"] - final_covariates: pd.DataFrame | None = dataframe_dict["covariates"] + communities: pd.DataFrame = dataframe_dict["communities"] + community_reports: pd.DataFrame = dataframe_dict["community_reports"] + text_units: pd.DataFrame = dataframe_dict["text_units"] + relationships: 
pd.DataFrame = dataframe_dict["relationships"] + entities: pd.DataFrame = dataframe_dict["entities"] + covariates: pd.DataFrame | None = dataframe_dict["covariates"] if streaming: @@ -225,12 +163,12 @@ def on_context(context: Any) -> None: async for stream_chunk in api.local_search_streaming( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, - covariates=final_covariates, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, + covariates=covariates, community_level=community_level, response_type=response_type, query=query, @@ -248,12 +186,12 @@ def on_context(context: Any) -> None: response, context_data = asyncio.run( api.local_search( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, - covariates=final_covariates, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, + covariates=covariates, community_level=community_level, response_type=response_type, query=query, @@ -296,41 +234,11 @@ def run_drift_search( ], ) - # Call the Multi-Index Drift Search API - if dataframe_dict["multi-index"]: - final_entities_list = dataframe_dict["entities"] - final_communities_list = dataframe_dict["communities"] - final_community_reports_list = dataframe_dict["community_reports"] - final_text_units_list = dataframe_dict["text_units"] - final_relationships_list = dataframe_dict["relationships"] - index_names = dataframe_dict["index_names"] - - response, context_data = asyncio.run( - api.multi_index_drift_search( - config=config, - entities_list=final_entities_list, - communities_list=final_communities_list, - community_reports_list=final_community_reports_list, - text_units_list=final_text_units_list, - relationships_list=final_relationships_list, - index_names=index_names, - community_level=community_level, - response_type=response_type, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - - return response, context_data - - # Otherwise, call the Single-Index Drift Search API - final_communities: pd.DataFrame = dataframe_dict["communities"] - final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] - final_text_units: pd.DataFrame = dataframe_dict["text_units"] - final_relationships: pd.DataFrame = dataframe_dict["relationships"] - final_entities: pd.DataFrame = dataframe_dict["entities"] + communities: pd.DataFrame = dataframe_dict["communities"] + community_reports: pd.DataFrame = dataframe_dict["community_reports"] + text_units: pd.DataFrame = dataframe_dict["text_units"] + relationships: pd.DataFrame = dataframe_dict["relationships"] + entities: pd.DataFrame = dataframe_dict["entities"] if streaming: @@ -347,11 +255,11 @@ def on_context(context: Any) -> None: async for stream_chunk in api.drift_search_streaming( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, community_level=community_level, response_type=response_type, 
@@ -296,41 +234,11 @@ def run_drift_search(
         ],
     )
 
-    # Call the Multi-Index Drift Search API
-    if dataframe_dict["multi-index"]:
-        final_entities_list = dataframe_dict["entities"]
-        final_communities_list = dataframe_dict["communities"]
-        final_community_reports_list = dataframe_dict["community_reports"]
-        final_text_units_list = dataframe_dict["text_units"]
-        final_relationships_list = dataframe_dict["relationships"]
-        index_names = dataframe_dict["index_names"]
-
-        response, context_data = asyncio.run(
-            api.multi_index_drift_search(
-                config=config,
-                entities_list=final_entities_list,
-                communities_list=final_communities_list,
-                community_reports_list=final_community_reports_list,
-                text_units_list=final_text_units_list,
-                relationships_list=final_relationships_list,
-                index_names=index_names,
-                community_level=community_level,
-                response_type=response_type,
-                streaming=streaming,
-                query=query,
-                verbose=verbose,
-            )
-        )
-        print(response)
-
-        return response, context_data
-
-    # Otherwise, call the Single-Index Drift Search API
-    final_communities: pd.DataFrame = dataframe_dict["communities"]
-    final_community_reports: pd.DataFrame = dataframe_dict["community_reports"]
-    final_text_units: pd.DataFrame = dataframe_dict["text_units"]
-    final_relationships: pd.DataFrame = dataframe_dict["relationships"]
-    final_entities: pd.DataFrame = dataframe_dict["entities"]
+    communities: pd.DataFrame = dataframe_dict["communities"]
+    community_reports: pd.DataFrame = dataframe_dict["community_reports"]
+    text_units: pd.DataFrame = dataframe_dict["text_units"]
+    relationships: pd.DataFrame = dataframe_dict["relationships"]
+    entities: pd.DataFrame = dataframe_dict["entities"]
 
     if streaming:
@@ -347,11 +255,11 @@ def on_context(context: Any) -> None:
             async for stream_chunk in api.drift_search_streaming(
                 config=config,
-                entities=final_entities,
-                communities=final_communities,
-                community_reports=final_community_reports,
-                text_units=final_text_units,
-                relationships=final_relationships,
+                entities=entities,
+                communities=communities,
+                community_reports=community_reports,
+                text_units=text_units,
+                relationships=relationships,
                 community_level=community_level,
                 response_type=response_type,
                 query=query,
@@ -370,11 +278,11 @@ def on_context(context: Any) -> None:
     response, context_data = asyncio.run(
         api.drift_search(
             config=config,
-            entities=final_entities,
-            communities=final_communities,
-            community_reports=final_community_reports,
-            text_units=final_text_units,
-            relationships=final_relationships,
+            entities=entities,
+            communities=communities,
+            community_reports=community_reports,
+            text_units=text_units,
+            relationships=relationships,
             community_level=community_level,
             response_type=response_type,
             query=query,
@@ -390,6 +298,7 @@ def run_basic_search(
     config_filepath: Path | None,
     data_dir: Path | None,
     root_dir: Path,
+    response_type: str,
     streaming: bool,
     query: str,
     verbose: bool,
@@ -411,27 +320,7 @@ def run_basic_search(
         ],
     )
 
-    # Call the Multi-Index Basic Search API
-    if dataframe_dict["multi-index"]:
-        final_text_units_list = dataframe_dict["text_units"]
-        index_names = dataframe_dict["index_names"]
-
-        response, context_data = asyncio.run(
-            api.multi_index_basic_search(
-                config=config,
-                text_units_list=final_text_units_list,
-                index_names=index_names,
-                streaming=streaming,
-                query=query,
-                verbose=verbose,
-            )
-        )
-        print(response)
-
-        return response, context_data
-
-    # Otherwise, call the Single-Index Basic Search API
-    final_text_units: pd.DataFrame = dataframe_dict["text_units"]
+    text_units: pd.DataFrame = dataframe_dict["text_units"]
 
     if streaming:
@@ -448,7 +337,8 @@ def on_context(context: Any) -> None:
         async for stream_chunk in api.basic_search_streaming(
             config=config,
-            text_units=final_text_units,
+            text_units=text_units,
+            response_type=response_type,
             query=query,
             callbacks=[callbacks],
             verbose=verbose,
@@ -464,7 +354,8 @@ def on_context(context: Any) -> None:
     response, context_data = asyncio.run(
         api.basic_search(
             config=config,
-            text_units=final_text_units,
+            text_units=text_units,
+            response_type=response_type,
             query=query,
             verbose=verbose,
         )
@@ -481,40 +372,6 @@ def _resolve_output_files(
 ) -> dict[str, Any]:
     """Read indexing output files to a dataframe dict."""
     dataframe_dict = {}
-
-    # Loading output files for multi-index search
-    if config.outputs:
-        dataframe_dict["multi-index"] = True
-        dataframe_dict["num_indexes"] = len(config.outputs)
-        dataframe_dict["index_names"] = config.outputs.keys()
-        for output in config.outputs.values():
-            storage_obj = create_storage_from_config(output)
-            for name in output_list:
-                if name not in dataframe_dict:
-                    dataframe_dict[name] = []
-                df_value = asyncio.run(
-                    load_table_from_storage(name=name, storage=storage_obj)
-                )
-                dataframe_dict[name].append(df_value)
-
-            # for optional output files, do not append if the dataframe does not exist
-            if optional_list:
-                for optional_file in optional_list:
-                    if optional_file not in dataframe_dict:
-                        dataframe_dict[optional_file] = []
-                    file_exists = asyncio.run(
-                        storage_has_table(optional_file, storage_obj)
-                    )
-                    if file_exists:
-                        df_value = asyncio.run(
-                            load_table_from_storage(
-                                name=optional_file, storage=storage_obj
-                            )
-                        )
-                        dataframe_dict[optional_file].append(df_value)
-        return dataframe_dict
-
-    # Loading output files for single-index search
-    dataframe_dict["multi-index"] = False
     storage_obj = create_storage_from_config(config.output)
     for name in output_list:
         df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj))
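`run_basic_search` now threads an explicit `response_type` through to the API. A hedged sketch of the resulting call, with an assumed loader and illustrative paths and query:

```python
# Sketch only: basic search after the response_type change.
import asyncio
from pathlib import Path

import pandas as pd

import graphrag.api as api
from graphrag.config.load_config import load_config  # assumed config loader

config = load_config(Path("."))
text_units = pd.read_parquet("output/text_units.parquet")

response, context_data = asyncio.run(
    api.basic_search(
        config=config,
        text_units=text_units,
        response_type="multiple paragraphs",  # now an explicit pass-through
        query="Who owns Company Y?",
        verbose=False,
    )
)
print(response)
```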
"text-embedding-3-large" DEFAULT_MODEL_PROVIDER = "openai" -DEFAULT_VECTOR_STORE_ID = "default_vector_store" - ENCODING_MODEL = "o200k_base" COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default" @@ -167,7 +165,6 @@ class EmbedTextDefaults: batch_max_tokens: int = 8191 names: list[str] = field(default_factory=lambda: default_embeddings) strategy: None = None - vector_store_id: str = DEFAULT_VECTOR_STORE_ID @dataclass @@ -418,7 +415,6 @@ class GraphRagConfigDefaults: reporting: ReportingDefaults = field(default_factory=ReportingDefaults) storage: StorageDefaults = field(default_factory=StorageDefaults) output: OutputDefaults = field(default_factory=OutputDefaults) - outputs: None = None update_index_output: UpdateIndexOutputDefaults = field( default_factory=UpdateIndexOutputDefaults ) @@ -444,8 +440,8 @@ class GraphRagConfigDefaults: global_search: GlobalSearchDefaults = field(default_factory=GlobalSearchDefaults) drift_search: DriftSearchDefaults = field(default_factory=DriftSearchDefaults) basic_search: BasicSearchDefaults = field(default_factory=BasicSearchDefaults) - vector_store: dict[str, VectorStoreDefaults] = field( - default_factory=lambda: {DEFAULT_VECTOR_STORE_ID: VectorStoreDefaults()} + vector_store: VectorStoreDefaults = field( + default_factory=lambda: VectorStoreDefaults() ) workflows: None = None diff --git a/graphrag/config/get_vector_store_settings.py b/graphrag/config/get_vector_store_settings.py deleted file mode 100644 index 3771d65ff3..0000000000 --- a/graphrag/config/get_vector_store_settings.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing get_vector_store_settings.""" - -from graphrag.config.models.graph_rag_config import GraphRagConfig - - -def get_vector_store_settings( - settings: GraphRagConfig, - vector_store_params: dict | None = None, -) -> dict: - """Transform GraphRAG config into settings for workflows.""" - vector_store_settings = settings.get_vector_store_config( - settings.embed_text.vector_store_id - ).model_dump() - - # - # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. - # settings.vector_store.base contains connection information, or may be undefined - # settings.vector_store. 
diff --git a/graphrag/config/get_vector_store_settings.py b/graphrag/config/get_vector_store_settings.py
deleted file mode 100644
index 3771d65ff3..0000000000
--- a/graphrag/config/get_vector_store_settings.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing get_vector_store_settings."""
-
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-
-
-def get_vector_store_settings(
-    settings: GraphRagConfig,
-    vector_store_params: dict | None = None,
-) -> dict:
-    """Transform GraphRAG config into settings for workflows."""
-    vector_store_settings = settings.get_vector_store_config(
-        settings.embed_text.vector_store_id
-    ).model_dump()
-
-    #
-    # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
-    # settings.vector_store.base contains connection information, or may be undefined
-    # settings.vector_store.<vector_store_id> contains the specific settings for this embedding
-    #
-    return {
-        **(vector_store_params or {}),
-        **(vector_store_settings),
-    }
diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py
index aadacf8f38..69b05015fb 100644
--- a/graphrag/config/init_content.py
+++ b/graphrag/config/init_content.py
@@ -75,16 +75,14 @@
   base_dir: "{graphrag_config_defaults.reporting.base_dir}"
 
 vector_store:
-  {defs.DEFAULT_VECTOR_STORE_ID}:
-    type: {vector_store_defaults.type}
-    db_uri: {vector_store_defaults.db_uri}
-    container_name: {vector_store_defaults.container_name}
+  type: {vector_store_defaults.type}
+  db_uri: {vector_store_defaults.db_uri}
+  container_name: {vector_store_defaults.container_name}
 
 ### Workflow settings ###
 
 embed_text:
   model_id: {graphrag_config_defaults.embed_text.model_id}
-  vector_store_id: {graphrag_config_defaults.embed_text.vector_store_id}
 
 extract_graph:
   model_id: {graphrag_config_defaults.extract_graph.model_id}
diff --git a/graphrag/config/models/embed_text_config.py b/graphrag/config/models/embed_text_config.py
index f785bf6eed..5e5963811c 100644
--- a/graphrag/config/models/embed_text_config.py
+++ b/graphrag/config/models/embed_text_config.py
@@ -15,10 +15,6 @@ class EmbedTextConfig(BaseModel):
         description="The model ID to use for text embeddings.",
         default=graphrag_config_defaults.embed_text.model_id,
     )
-    vector_store_id: str = Field(
-        description="The vector store ID to use for text embeddings.",
-        default=graphrag_config_defaults.embed_text.vector_store_id,
-    )
     batch_size: int = Field(
         description="The batch size to use.",
         default=graphrag_config_defaults.embed_text.batch_size,
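A quick check of the slimmed-down embed-text model; the field values here are illustrative.

```python
from graphrag.config.models.embed_text_config import EmbedTextConfig

cfg = EmbedTextConfig(model_id="default_embedding_model", batch_size=16)
print(cfg.model_dump())  # no vector_store_id key anymore
```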
diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py
index 2b5961321a..c8cdca819c 100644
--- a/graphrag/config/models/graph_rag_config.py
+++ b/graphrag/config/models/graph_rag_config.py
@@ -3,7 +3,6 @@
 
 """Parameterization settings for the default configuration."""
 
-from dataclasses import asdict
 from pathlib import Path
 
 from devtools import pformat
@@ -179,23 +178,6 @@ def _validate_output_base_dir(self) -> None:
             (Path(self.root_dir) / self.output.base_dir).resolve()
         )
 
-    outputs: dict[str, StorageConfig] | None = Field(
-        description="A list of output configurations used for multi-index query.",
-        default=graphrag_config_defaults.outputs,
-    )
-
-    def _validate_multi_output_base_dirs(self) -> None:
-        """Validate the outputs dict base directories."""
-        if self.outputs:
-            for output in self.outputs.values():
-                if output.type == defs.StorageType.file:
-                    if output.base_dir.strip() == "":
-                        msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
-                        raise ValueError(msg)
-                    output.base_dir = str(
-                        (Path(self.root_dir) / output.base_dir).resolve()
-                    )
-
     update_index_output: StorageConfig = Field(
         description="The output configuration for the updated index.",
         default=StorageConfig(
@@ -234,12 +216,8 @@ def _validate_reporting_base_dir(self) -> None:
             (Path(self.root_dir) / self.reporting.base_dir).resolve()
         )
 
-    vector_store: dict[str, VectorStoreConfig] = Field(
-        description="The vector store configuration.",
-        default_factory=lambda: {
-            k: VectorStoreConfig(**asdict(v))
-            for k, v in graphrag_config_defaults.vector_store.items()
-        },
+    vector_store: VectorStoreConfig = Field(
+        description="The vector store configuration.", default=VectorStoreConfig()
     )
     """The vector store configuration."""
@@ -327,12 +305,12 @@ def _validate_reporting_base_dir(self) -> None:
 
     def _validate_vector_store_db_uri(self) -> None:
         """Validate the vector store configuration."""
-        for store in self.vector_store.values():
-            if store.type == VectorStoreType.LanceDB:
-                if not store.db_uri or store.db_uri.strip == "":
-                    msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
-                    raise ValueError(msg)
-                store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
+        store = self.vector_store
+        if store.type == VectorStoreType.LanceDB:
+            if not store.db_uri or store.db_uri.strip() == "":
+                msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
+                raise ValueError(msg)
+            store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
 
     def _validate_factories(self) -> None:
         """Validate the factories used in the configuration."""
@@ -363,30 +341,6 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
 
         return self.models[model_id]
 
-    def get_vector_store_config(self, vector_store_id: str) -> VectorStoreConfig:
-        """Get a vector store configuration by ID.
-
-        Parameters
-        ----------
-        vector_store_id : str
-            The ID of the vector store to get. Should match an ID in the vector_store list.
-
-        Returns
-        -------
-        VectorStoreConfig
-            The vector store configuration if found.
-
-        Raises
-        ------
-        ValueError
-            If the vector store ID is not found in the configuration.
-        """
-        if vector_store_id not in self.vector_store:
-            err_msg = f"Vector Store ID {vector_store_id} not found in configuration. Please rerun `graphrag init` and set the vector store configuration."
-            raise ValueError(err_msg)
-
-        return self.vector_store[vector_store_id]
-
     @model_validator(mode="after")
     def _validate_model(self):
         """Validate the model configuration."""
@@ -396,7 +350,6 @@ def _validate_model(self):
         self._validate_input_base_dir()
         self._validate_reporting_base_dir()
         self._validate_output_base_dir()
-        self._validate_multi_output_base_dirs()
         self._validate_update_index_output_base_dir()
         self._validate_vector_store_db_uri()
         self._validate_factories()
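Callers that previously resolved a store by id now read the single field directly. A sketch, again assuming the `load_config` loader:

```python
from pathlib import Path

from graphrag.config.load_config import load_config  # assumed config loader

config = load_config(Path("."))

# Before: config.get_vector_store_config(config.embed_text.vector_store_id)
store = config.vector_store
print(store.type, store.db_uri)
```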
diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py
index 04099f8109..f01682604f 100644
--- a/graphrag/index/operations/embed_text/embed_text.py
+++ b/graphrag/index/operations/embed_text/embed_text.py
@@ -9,13 +9,10 @@
 import pandas as pd
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.embeddings import create_index_name
-from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.index.operations.embed_text.run_embed_text import run_embed_text
 from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.tokenizer.tokenizer import Tokenizer
 from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
-from graphrag.vector_stores.factory import VectorStoreFactory
 
 logger = logging.getLogger(__name__)
@@ -26,86 +23,14 @@ async def embed_text(
     model: EmbeddingModel,
     tokenizer: Tokenizer,
     embed_column: str,
-    embedding_name: str,
     batch_size: int,
     batch_max_tokens: int,
     num_threads: int,
-    vector_store_config: dict,
+    vector_store: BaseVectorStore,
     id_column: str = "id",
     title_column: str | None = None,
 ):
     """Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector."""
-    if vector_store_config:
-        index_name = _get_index_name(vector_store_config, embedding_name)
-        vector_store: BaseVectorStore = _create_vector_store(
-            vector_store_config, index_name, embedding_name
-        )
-        vector_store_workflow_config = vector_store_config.get(
-            embedding_name, vector_store_config
-        )
-        return await _text_embed_with_vector_store(
-            input=input,
-            callbacks=callbacks,
-            model=model,
-            tokenizer=tokenizer,
-            embed_column=embed_column,
-            vector_store=vector_store,
-            vector_store_config=vector_store_workflow_config,
-            batch_size=batch_size,
-            batch_max_tokens=batch_max_tokens,
-            num_threads=num_threads,
-            id_column=id_column,
-            title_column=title_column,
-        )
-
-    return await _text_embed_in_memory(
-        input=input,
-        callbacks=callbacks,
-        model=model,
-        tokenizer=tokenizer,
-        embed_column=embed_column,
-        batch_size=batch_size,
-        batch_max_tokens=batch_max_tokens,
-        num_threads=num_threads,
-    )
-
-
-async def _text_embed_in_memory(
-    input: pd.DataFrame,
-    callbacks: WorkflowCallbacks,
-    model: EmbeddingModel,
-    tokenizer: Tokenizer,
-    embed_column: str,
-    batch_size: int,
-    batch_max_tokens: int,
-    num_threads: int,
-):
-    texts: list[str] = input[embed_column].tolist()
-    result = await run_embed_text(
-        texts, callbacks, model, tokenizer, batch_size, batch_max_tokens, num_threads
-    )
-
-    return result.embeddings
-
-
-async def _text_embed_with_vector_store(
-    input: pd.DataFrame,
-    callbacks: WorkflowCallbacks,
-    model: EmbeddingModel,
-    tokenizer: Tokenizer,
-    embed_column: str,
-    vector_store: BaseVectorStore,
-    vector_store_config: dict,
-    batch_size: int,
-    batch_max_tokens: int,
-    num_threads: int,
-    id_column: str,
-    title_column: str | None = None,
-):
-    # Get vector-storage configuration
-
-    overwrite: bool = vector_store_config.get("overwrite", True)
-
     if embed_column not in input.columns:
         msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}"
         raise ValueError(msg)
@@ -168,51 +93,8 @@ async def _text_embed_with_vector_store(
             )
             documents.append(document)
 
-        vector_store.load_documents(documents, overwrite and i == 0)
+        vector_store.load_documents(documents, i == 0)
         starting_index += len(documents)
         i += 1
 
     return all_results
-
-
-def _create_vector_store(
-    vector_store_config: dict, index_name: str, embedding_name: str | None = None
-) -> BaseVectorStore:
-    vector_store_type: str = str(vector_store_config.get("type"))
-
-    embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
-        "embeddings_schema", {}
-    )
-    single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
-
-    if (
-        embeddings_schema is not None
-        and embedding_name is not None
-        and embedding_name in embeddings_schema
-    ):
-        raw_config = embeddings_schema[embedding_name]
-        if isinstance(raw_config, dict):
-            single_embedding_config = VectorStoreSchemaConfig(**raw_config)
-        else:
-            single_embedding_config = raw_config
-
-    if single_embedding_config.index_name is None:
-        single_embedding_config.index_name = index_name
-
-    vector_store = VectorStoreFactory().create_vector_store(
-        vector_store_schema_config=single_embedding_config,
-        vector_store_type=vector_store_type,
-        **vector_store_config,
-    )
-
-    vector_store.connect(**vector_store_config)
-    return vector_store
-
-
-def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
-    container_name = vector_store_config.get("container_name", "default")
-    index_name = create_index_name(container_name, embedding_name)
-
-    msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
-    logger.info(msg)
-    return index_name
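`embed_text` now receives a connected store instead of a raw config dict; the first batch creates the index table and later batches append. A sketch of preparing such a store up front; the factory kwargs and index name are assumptions mirroring the setup code that moved into `generate_text_embeddings`:

```python
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.factory import VectorStoreFactory

# Build and connect a store, then pass it to embed_text(..., vector_store=store).
schema = VectorStoreSchemaConfig(index_name="default-entity-description")
store = VectorStoreFactory().create_vector_store(
    vector_store_schema_config=schema,
    vector_store_type="lancedb",
    db_uri="output/lancedb",
)
store.connect(db_uri="output/lancedb")
```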
diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py
index 3a0ae41259..848b57d292 100644
--- a/graphrag/index/workflows/generate_text_embeddings.py
+++ b/graphrag/index/workflows/generate_text_embeddings.py
@@ -12,14 +12,16 @@
     community_full_content_embedding,
     community_summary_embedding,
     community_title_embedding,
+    create_index_name,
     document_text_embedding,
     entity_description_embedding,
     entity_title_embedding,
     relationship_description_embedding,
     text_unit_text_embedding,
 )
-from graphrag.config.get_vector_store_settings import get_vector_store_settings
 from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.config.models.vector_store_config import VectorStoreConfig
+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.index.operations.embed_text.embed_text import embed_text
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.workflow import WorkflowFunctionOutput
@@ -31,6 +33,8 @@
     load_table_from_storage,
     write_table_to_storage,
 )
+from graphrag.vector_stores.base import BaseVectorStore
+from graphrag.vector_stores.factory import VectorStoreFactory
 
 logger = logging.getLogger(__name__)
@@ -70,8 +74,6 @@ async def run_workflow(
         "community_reports", context.output_storage
     )
 
-    vector_store_config = get_vector_store_settings(config)
-
     model_config = config.get_language_model_config(config.embed_text.model_id)
 
     model = ModelManager().get_or_create_embedding_model(
@@ -96,7 +98,7 @@
         batch_size=config.embed_text.batch_size,
         batch_max_tokens=config.embed_text.batch_max_tokens,
         num_threads=model_config.concurrent_requests,
-        vector_store_config=vector_store_config,
+        vector_store_config=config.vector_store,
         embedded_fields=embedded_fields,
     )
@@ -124,7 +126,7 @@ async def generate_text_embeddings(
     batch_size: int,
     batch_max_tokens: int,
     num_threads: int,
-    vector_store_config: dict,
+    vector_store_config: VectorStoreConfig,
     embedded_fields: list[str],
 ) -> dict[str, pd.DataFrame]:
     """All the steps to generate all embeddings."""
@@ -208,20 +210,69 @@ async def _run_embeddings(
     batch_size: int,
     batch_max_tokens: int,
     num_threads: int,
-    vector_store_config: dict,
+    vector_store_config: VectorStoreConfig,
 ) -> pd.DataFrame:
     """All the steps to generate single embedding."""
+    index_name = _get_index_name(vector_store_config, name)
+    vector_store = _create_vector_store(vector_store_config, index_name, name)
+
     data["embedding"] = await embed_text(
         input=data,
         callbacks=callbacks,
         model=model,
         tokenizer=tokenizer,
         embed_column=embed_column,
-        embedding_name=name,
         batch_size=batch_size,
         batch_max_tokens=batch_max_tokens,
         num_threads=num_threads,
-        vector_store_config=vector_store_config,
+        vector_store=vector_store,
     )
 
     return data.loc[:, ["id", "embedding"]]
+
+
+def _create_vector_store(
+    vector_store_config: VectorStoreConfig,
+    index_name: str,
+    embedding_name: str | None = None,
+) -> BaseVectorStore:
+    vector_store_type: str = str(vector_store_config.type)
+
+    embeddings_schema: dict[str, VectorStoreSchemaConfig] = (
+        vector_store_config.embeddings_schema
+    )
+
+    single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
+
+    if (
+        embeddings_schema is not None
+        and embedding_name is not None
+        and embedding_name in embeddings_schema
+    ):
+        raw_config = embeddings_schema[embedding_name]
+        if isinstance(raw_config, dict):
+            single_embedding_config = VectorStoreSchemaConfig(**raw_config)
+        else:
+            single_embedding_config = raw_config
+
+    if single_embedding_config.index_name is None:
+        single_embedding_config.index_name = index_name
+
+    args = vector_store_config.model_dump()
+    vector_store = VectorStoreFactory().create_vector_store(
+        vector_store_schema_config=single_embedding_config,
+        vector_store_type=vector_store_type,
+        **args,
+    )
+
+    vector_store.connect(**args)
+    return vector_store
+
+
+def _get_index_name(vector_store_config: VectorStoreConfig, embedding_name: str) -> str:
+    container_name = vector_store_config.container_name
+    index_name = create_index_name(container_name, embedding_name)
+
+    msg = f"using vector store {vector_store_config.type} with container_name {container_name} for embedding {embedding_name}: {index_name}"
+    logger.info(msg)
+    return index_name
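The relocated `_create_vector_store` still honors per-embedding schema overrides; any embedding without one gets an index name derived via `create_index_name(container_name, embedding_name)`. A sketch, with an illustrative embedding key and index name:

```python
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig

cfg = VectorStoreConfig(
    type="lancedb",
    db_uri="output/lancedb",
    container_name="default",
    embeddings_schema={
        # This embedding gets a custom index; the rest use the derived name.
        "entity.description": VectorStoreSchemaConfig(index_name="my-entity-index"),
    },
)
```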
diff --git a/graphrag/index/workflows/update_text_embeddings.py b/graphrag/index/workflows/update_text_embeddings.py
index 4b349f5809..0fc5f1fdeb 100644
--- a/graphrag/index/workflows/update_text_embeddings.py
+++ b/graphrag/index/workflows/update_text_embeddings.py
@@ -5,7 +5,6 @@
 
 import logging
 
-from graphrag.config.get_vector_store_settings import get_vector_store_settings
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.run.utils import get_update_storages
 from graphrag.index.typing.context import PipelineRunContext
@@ -37,7 +36,6 @@ async def run_workflow(
     ]
 
     embedded_fields = config.embed_text.names
-    vector_store_config = get_vector_store_settings(config)
 
     model_config = config.get_language_model_config(config.embed_text.model_id)
 
@@ -63,7 +61,7 @@ async def run_workflow(
         batch_size=config.embed_text.batch_size,
         batch_max_tokens=config.embed_text.batch_max_tokens,
         num_threads=model_config.concurrent_requests,
-        vector_store_config=vector_store_config,
+        vector_store_config=config.vector_store,
         embedded_fields=embedded_fields,
     )
     if config.snapshots.embeddings:
diff --git a/graphrag/prompts/query/basic_search_system_prompt.py b/graphrag/prompts/query/basic_search_system_prompt.py
index a20fb6ad10..bc37c70d98 100644
--- a/graphrag/prompts/query/basic_search_system_prompt.py
+++ b/graphrag/prompts/query/basic_search_system_prompt.py
@@ -27,7 +27,7 @@
 
 "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
 
-where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
+where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "id" column in the provided tables.
 
 Do not include information where the supporting evidence for it is not provided.
@@ -60,7 +60,7 @@
 
 "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
 
-where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
+where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "id" column in the provided tables.
 
 Do not include information where the supporting evidence for it is not provided.
diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py
index 10dfd3ba1f..0c4a2ef0b7 100644
--- a/graphrag/query/factory.py
+++ b/graphrag/query/factory.py
@@ -251,8 +251,8 @@ def get_basic_search_engine(
     text_units: list[TextUnit],
     text_unit_embeddings: BaseVectorStore,
     config: GraphRagConfig,
+    response_type: str,
     system_prompt: str | None = None,
-    response_type: str = "multiple paragraphs",
     callbacks: list[QueryCallbacks] | None = None,
 ) -> BasicSearch:
     """Create a basic search engine based on data + configuration."""
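`response_type` is now a required argument of `get_basic_search_engine` rather than a default. A sketch; the inputs are placeholders for objects loaded elsewhere:

```python
from graphrag.query.factory import get_basic_search_engine

engine = get_basic_search_engine(
    text_units=text_units,          # list[TextUnit] from the index (placeholder)
    text_unit_embeddings=store,     # a connected BaseVectorStore (placeholder)
    config=config,                  # a GraphRagConfig (placeholder)
    response_type="multiple paragraphs",  # no longer defaulted
)
```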
diff --git a/graphrag/query/structured_search/basic_search/basic_context.py b/graphrag/query/structured_search/basic_search/basic_context.py
index 215b26a31e..b7390017fc 100644
--- a/graphrag/query/structured_search/basic_search/basic_context.py
+++ b/graphrag/query/structured_search/basic_search/basic_context.py
@@ -38,7 +38,6 @@ def __init__(
         self.text_units = text_units
         self.text_unit_embeddings = text_unit_embeddings
         self.embedding_vectorstore_key = embedding_vectorstore_key
-        self.text_id_map = self._map_ids()
 
     def build_context(
         self,
@@ -48,7 +47,7 @@
         max_context_tokens: int = 12_000,
         context_name: str = "Sources",
         column_delimiter: str = "|",
-        text_id_col: str = "source_id",
+        text_id_col: str = "id",
         text_col: str = "text",
         **kwargs,
     ) -> ContextBuilderResult:
@@ -63,7 +62,7 @@
         text_unit_ids = {t.document.id for t in related_texts}
         text_units_filtered = []
         text_units_filtered = [
-            {text_id_col: t.id, text_col: t.text}
+            {text_id_col: t.short_id, text_col: t.text}
             for t in self.text_units or []
             if t.id in text_unit_ids
         ]
@@ -102,13 +101,5 @@
 
         return ContextBuilderResult(
             context_chunks=final_text,
-            context_records={context_name: final_text_df},
+            context_records={context_name.lower(): final_text_df},
         )
-
-    def _map_ids(self) -> dict[str, str]:
-        """Map id to short id in the text units."""
-        id_map = {}
-        text_units = self.text_units or []
-        for unit in text_units:
-            id_map[unit.id] = unit.short_id
-        return id_map
diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py
index db3d94790d..16a5f9ed52 100644
--- a/graphrag/utils/api.py
+++ b/graphrag/utils/api.py
@@ -12,133 +12,51 @@
 from graphrag.config.models.cache_config import CacheConfig
 from graphrag.config.models.storage_config import StorageConfig
 from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
-from graphrag.data_model.types import TextEmbedder
 from graphrag.storage.factory import StorageFactory
 from graphrag.storage.pipeline_storage import PipelineStorage
 from graphrag.vector_stores.base import (
     BaseVectorStore,
-    VectorStoreDocument,
-    VectorStoreSearchResult,
 )
 from graphrag.vector_stores.factory import VectorStoreFactory
 
 
-class MultiVectorStore(BaseVectorStore):
-    """Multi Vector Store wrapper implementation."""
-
-    def __init__(
-        self,
-        embedding_stores: list[BaseVectorStore],
-        index_names: list[str],
-    ):
-        self.embedding_stores = embedding_stores
-        self.index_names = index_names
-
-    def load_documents(
-        self, documents: list[VectorStoreDocument], overwrite: bool = True
-    ) -> None:
-        """Load documents into the vector store."""
-        msg = "load_documents method not implemented"
-        raise NotImplementedError(msg)
-
-    def connect(self, **kwargs: Any) -> Any:
-        """Connect to vector storage."""
-        msg = "connect method not implemented"
-        raise NotImplementedError(msg)
-
-    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
-        """Build a query filter to filter documents by id."""
-        msg = "filter_by_id method not implemented"
-        raise NotImplementedError(msg)
-
-    def search_by_id(self, id: str) -> VectorStoreDocument:
-        """Search for a document by id."""
-        search_index_id = id.split("-")[0]
-        search_index_name = id.split("-")[1]
-        for index_name, embedding_store in zip(
-            self.index_names, self.embedding_stores, strict=False
-        ):
-            if index_name == search_index_name:
-                return embedding_store.search_by_id(search_index_id)
-        else:
-            message = f"Index {search_index_name} not found."
-            raise ValueError(message)
-
-    def similarity_search_by_vector(
-        self, query_embedding: list[float], k: int = 10, **kwargs: Any
-    ) -> list[VectorStoreSearchResult]:
-        """Perform a vector-based similarity search."""
-        all_results = []
-        for index_name, embedding_store in zip(
-            self.index_names, self.embedding_stores, strict=False
-        ):
-            results = embedding_store.similarity_search_by_vector(
-                query_embedding=query_embedding, k=k
-            )
-            mod_results = []
-            for r in results:
-                r.document.id = str(r.document.id) + f"-{index_name}"
-                mod_results += [r]
-            all_results += mod_results
-        return sorted(all_results, key=lambda x: x.score, reverse=True)[:k]
-
-    def similarity_search_by_text(
-        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
-    ) -> list[VectorStoreSearchResult]:
-        """Perform a text-based similarity search."""
-        query_embedding = text_embedder(text)
-        if query_embedding:
-            return self.similarity_search_by_vector(
-                query_embedding=query_embedding, k=k
-            )
-        return []
-
-
 def get_embedding_store(
-    config_args: dict[str, dict],
+    store: dict[str, Any],
     embedding_name: str,
 ) -> BaseVectorStore:
     """Get the embedding description store."""
-    num_indexes = len(config_args)
-    embedding_stores = []
-    index_names = []
-    for index, store in config_args.items():
-        vector_store_type = store["type"]
-        index_name = create_index_name(
-            store.get("container_name", "default"), embedding_name
-        )
+    vector_store_type = store["type"]
+    index_name = create_index_name(
+        store.get("container_name", "default"), embedding_name
+    )
 
-        embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
-            "embeddings_schema", {}
-        )
-        single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
+    embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
+        "embeddings_schema", {}
+    )
+    embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
 
-        if (
-            embeddings_schema is not None
-            and embedding_name is not None
-            and embedding_name in embeddings_schema
-        ):
-            raw_config = embeddings_schema[embedding_name]
-            if isinstance(raw_config, dict):
-                single_embedding_config = VectorStoreSchemaConfig(**raw_config)
-            else:
-                single_embedding_config = raw_config
+    if (
+        embeddings_schema is not None
+        and embedding_name is not None
+        and embedding_name in embeddings_schema
+    ):
+        raw_config = embeddings_schema[embedding_name]
+        if isinstance(raw_config, dict):
+            embedding_config = VectorStoreSchemaConfig(**raw_config)
+        else:
+            embedding_config = raw_config
 
-        if single_embedding_config.index_name is None:
-            single_embedding_config.index_name = index_name
+    if embedding_config.index_name is None:
+        embedding_config.index_name = index_name
 
-        embedding_store = VectorStoreFactory().create_vector_store(
-            vector_store_type=vector_store_type,
-            vector_store_schema_config=single_embedding_config,
-            **store,
-        )
-        embedding_store.connect(**store)
-        # If there is only a single index, return the embedding store directly
-        if num_indexes == 1:
-            return embedding_store
-        embedding_stores.append(embedding_store)
-        index_names.append(index)
-    return MultiVectorStore(embedding_stores, index_names)
+    embedding_store = VectorStoreFactory().create_vector_store(
+        vector_store_type=vector_store_type,
+        vector_store_schema_config=embedding_config,
+        **store,
+    )
+    embedding_store.connect(**store)
+
+    return embedding_store
 
 
 def reformat_context_data(context_data: dict) -> dict:
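`get_embedding_store` now takes one store mapping, e.g. `config.vector_store.model_dump()`, and always returns a plain store rather than a `MultiVectorStore` wrapper. A sketch with illustrative values:

```python
from graphrag.utils.api import get_embedding_store

description_store = get_embedding_store(
    store={
        "type": "lancedb",
        "db_uri": "output/lancedb",
        "container_name": "default",
    },
    embedding_name="entity.description",
)
print(type(description_store))
```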
@@ -172,81 +90,6 @@
     return final_format
 
 
-def update_context_data(
-    context_data: Any,
-    links: dict[str, Any],
-) -> Any:
-    """
-    Update context data with the links dict so that it contains both the index name and community id.
-
-    Parameters
-    ----------
-    - context_data (str | list[pd.DataFrame] | dict[str, pd.DataFrame]): The context data to update.
-    - links (dict[str, Any]): A dictionary of links to the original dataframes.
-
-    Returns
-    -------
-    str | list[pd.DataFrame] | dict[str, pd.DataFrame]: The updated context data.
-    """
-    updated_context_data = {}
-    for key in context_data:
-        entries = context_data[key].to_dict(orient="records")
-        updated_entry = []
-        if key == "reports":
-            updated_entry = [
-                dict(
-                    entry,
-                    index_name=links["community_reports"][int(entry["id"])][
-                        "index_name"
-                    ],
-                    index_id=links["community_reports"][int(entry["id"])]["id"],
-                )
-                for entry in entries
-            ]
-        if key == "entities":
-            updated_entry = [
-                dict(
-                    entry,
-                    entity=entry["entity"].split("-")[0],
-                    index_name=links["entities"][int(entry["id"])]["index_name"],
-                    index_id=links["entities"][int(entry["id"])]["id"],
-                )
-                for entry in entries
-            ]
-        if key == "relationships":
-            updated_entry = [
-                dict(
-                    entry,
-                    source=entry["source"].split("-")[0],
-                    target=entry["target"].split("-")[0],
-                    index_name=links["relationships"][int(entry["id"])]["index_name"],
-                    index_id=links["relationships"][int(entry["id"])]["id"],
-                )
-                for entry in entries
-            ]
-        if key == "claims":
-            updated_entry = [
-                dict(
-                    entry,
-                    entity=entry["entity"].split("-")[0],
-                    index_name=links["covariates"][int(entry["id"])]["index_name"],
-                    index_id=links["covariates"][int(entry["id"])]["id"],
-                )
-                for entry in entries
-            ]
-        if key == "sources":
-            updated_entry = [
-                dict(
-                    entry,
-                    index_name=links["text_units"][int(entry["id"])]["index_name"],
-                    index_id=links["text_units"][int(entry["id"])]["id"],
-                )
-                for entry in entries
-            ]
-        updated_context_data[key] = updated_entry
-    return updated_context_data
-
-
 def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
     """
     Load the search prompt from disk if configured.
diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py
index f230692560..7588335a87 100644
--- a/tests/unit/config/utils.py
+++ b/tests/unit/config/utils.py
@@ -105,23 +105,18 @@ def assert_language_model_configs(
 
 
 def assert_vector_store_configs(
-    actual: dict[str, VectorStoreConfig],
-    expected: dict[str, VectorStoreConfig],
+    actual: VectorStoreConfig,
+    expected: VectorStoreConfig,
 ):
     assert type(actual) is type(expected)
-    assert len(actual) == len(expected)
-    for (index_a, store_a), (index_e, store_e) in zip(
-        actual.items(), expected.items(), strict=True
-    ):
-        assert index_a == index_e
-        assert store_a.type == store_e.type
-        assert store_a.db_uri == store_e.db_uri
-        assert store_a.url == store_e.url
-        assert store_a.api_key == store_e.api_key
-        assert store_a.audience == store_e.audience
-        assert store_a.container_name == store_e.container_name
-        assert store_a.overwrite == store_e.overwrite
-        assert store_a.database_name == store_e.database_name
+    assert actual.type == expected.type
+    assert actual.db_uri == expected.db_uri
+    assert actual.url == expected.url
+    assert actual.api_key == expected.api_key
+    assert actual.audience == expected.audience
+    assert actual.container_name == expected.container_name
+    assert actual.overwrite == expected.overwrite
+    assert actual.database_name == expected.database_name
 
 
 def assert_reporting_configs(
@@ -187,7 +182,6 @@ def assert_text_embedding_configs(
     assert actual.batch_max_tokens == expected.batch_max_tokens
     assert actual.names == expected.names
     assert actual.model_id == expected.model_id
-    assert actual.vector_store_id == expected.vector_store_id
 
 
 def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
@@ -371,14 +365,6 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) ->
     assert_reporting_configs(actual.reporting, expected.reporting)
     assert_output_configs(actual.output, expected.output)
 
-    if expected.outputs is not None:
-        assert actual.outputs is not None
-        assert len(actual.outputs) == len(expected.outputs)
-        for a, e in zip(actual.outputs.keys(), expected.outputs.keys(), strict=True):
-            assert_output_configs(actual.outputs[a], expected.outputs[e])
-    else:
-        assert actual.outputs is None
-
     assert_update_output_configs(
         actual.update_index_output, expected.update_index_output
     )
diff --git a/tests/unit/query/context_builder/test_entity_extraction.py b/tests/unit/query/context_builder/test_entity_extraction.py
index aa2ff29b09..a0b7b77c52 100644
--- a/tests/unit/query/context_builder/test_entity_extraction.py
+++ b/tests/unit/query/context_builder/test_entity_extraction.py
@@ -55,9 +55,6 @@ def similarity_search_by_text(
             key=lambda x: x.score,
         )[:k]
 
-    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
-        return [document for document in self.documents if document.id in include_ids]
-
     def search_by_id(self, id: str) -> VectorStoreDocument:
         result = self.documents[0]
         result.id = id
@@ -92,27 +89,6 @@ def test_map_query_to_entities():
         ),
     ]
 
-    assert map_query_to_entities(
-        query="t22",
-        text_embedding_vectorstore=MockBaseVectorStore([
-            VectorStoreDocument(id=entity.id, vector=None) for entity in entities
-        ]),
-        text_embedder=ModelManager().get_or_create_embedding_model(
-            model_type="mock_embedding", name="mock"
-        ),
-        all_entities_dict={entity.id: entity for entity in entities},
-        embedding_vectorstore_key=EntityVectorStoreKey.ID,
-        k=1,
-        oversample_scaler=1,
-    ) == [
-        Entity(
-            id="c4f93564-4507-4ee4-b102-98add401a965",
-            short_id="sid2",
-            title="t22",
-            rank=4,
-        )
-    ]
-
     assert map_query_to_entities(
         query="t22",
         text_embedding_vectorstore=MockBaseVectorStore([
@@ -134,32 +110,6 @@ def test_map_query_to_entities():
         )
     ]
 
-    assert map_query_to_entities(
-        query="",
-        text_embedding_vectorstore=MockBaseVectorStore([
-            VectorStoreDocument(id=entity.id, vector=None) for entity in entities
-        ]),
-        text_embedder=ModelManager().get_or_create_embedding_model(
-            model_type="mock_embedding", name="mock"
-        ),
-        all_entities_dict={entity.id: entity for entity in entities},
-        embedding_vectorstore_key=EntityVectorStoreKey.ID,
-        k=2,
-    ) == [
-        Entity(
-            id="c4f93564-4507-4ee4-b102-98add401a965",
-            short_id="sid2",
-            title="t22",
-            rank=4,
-        ),
-        Entity(
-            id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
-            short_id="sid4",
-            title="t4444",
-            rank=3,
-        ),
-    ]
-
     assert map_query_to_entities(
         query="",
         text_embedding_vectorstore=MockBaseVectorStore([