From 8574a19c0bbbf482df8e2e3f47e18f072fd3fd74 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 13:25:13 -0700 Subject: [PATCH 01/10] Remove multi-search from CLI --- graphrag/cli/query.py | 236 ++++++++---------------------------------- 1 file changed, 45 insertions(+), 191 deletions(-) diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index 007573485..b7c12d1e8 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -52,35 +52,9 @@ def run_global_search( optional_list=[], ) - # Call the Multi-Index Global Search API - if dataframe_dict["multi-index"]: - final_entities_list = dataframe_dict["entities"] - final_communities_list = dataframe_dict["communities"] - final_community_reports_list = dataframe_dict["community_reports"] - index_names = dataframe_dict["index_names"] - - response, context_data = asyncio.run( - api.multi_index_global_search( - config=config, - entities_list=final_entities_list, - communities_list=final_communities_list, - community_reports_list=final_community_reports_list, - index_names=index_names, - community_level=community_level, - dynamic_community_selection=dynamic_community_selection, - response_type=response_type, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - return response, context_data - - # Otherwise, call the Single-Index Global Search API - final_entities: pd.DataFrame = dataframe_dict["entities"] - final_communities: pd.DataFrame = dataframe_dict["communities"] - final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] + entities: pd.DataFrame = dataframe_dict["entities"] + communities: pd.DataFrame = dataframe_dict["communities"] + community_reports: pd.DataFrame = dataframe_dict["community_reports"] if streaming: @@ -97,9 +71,9 @@ def on_context(context: Any) -> None: async for stream_chunk in api.global_search_streaming( config=config, - entities=final_entities, - communities=final_communities, - 
community_reports=final_community_reports, + entities=entities, + communities=communities, + community_reports=community_reports, community_level=community_level, dynamic_community_selection=dynamic_community_selection, response_type=response_type, @@ -118,9 +92,9 @@ def on_context(context: Any) -> None: response, context_data = asyncio.run( api.global_search( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, + entities=entities, + communities=communities, + community_reports=community_reports, community_level=community_level, dynamic_community_selection=dynamic_community_selection, response_type=response_type, @@ -166,49 +140,13 @@ def run_local_search( "covariates", ], ) - # Call the Multi-Index Local Search API - if dataframe_dict["multi-index"]: - final_entities_list = dataframe_dict["entities"] - final_communities_list = dataframe_dict["communities"] - final_community_reports_list = dataframe_dict["community_reports"] - final_text_units_list = dataframe_dict["text_units"] - final_relationships_list = dataframe_dict["relationships"] - index_names = dataframe_dict["index_names"] - - # If any covariates tables are missing from any index, set the covariates list to None - if len(dataframe_dict["covariates"]) != dataframe_dict["num_indexes"]: - final_covariates_list = None - else: - final_covariates_list = dataframe_dict["covariates"] - - response, context_data = asyncio.run( - api.multi_index_local_search( - config=config, - entities_list=final_entities_list, - communities_list=final_communities_list, - community_reports_list=final_community_reports_list, - text_units_list=final_text_units_list, - relationships_list=final_relationships_list, - covariates_list=final_covariates_list, - index_names=index_names, - community_level=community_level, - response_type=response_type, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - - return response, context_data - # 
Otherwise, call the Single-Index Local Search API - final_communities: pd.DataFrame = dataframe_dict["communities"] - final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] - final_text_units: pd.DataFrame = dataframe_dict["text_units"] - final_relationships: pd.DataFrame = dataframe_dict["relationships"] - final_entities: pd.DataFrame = dataframe_dict["entities"] - final_covariates: pd.DataFrame | None = dataframe_dict["covariates"] + communities: pd.DataFrame = dataframe_dict["communities"] + community_reports: pd.DataFrame = dataframe_dict["community_reports"] + text_units: pd.DataFrame = dataframe_dict["text_units"] + relationships: pd.DataFrame = dataframe_dict["relationships"] + entities: pd.DataFrame = dataframe_dict["entities"] + covariates: pd.DataFrame | None = dataframe_dict["covariates"] if streaming: @@ -225,12 +163,12 @@ def on_context(context: Any) -> None: async for stream_chunk in api.local_search_streaming( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, - covariates=final_covariates, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, + covariates=covariates, community_level=community_level, response_type=response_type, query=query, @@ -248,12 +186,12 @@ def on_context(context: Any) -> None: response, context_data = asyncio.run( api.local_search( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, - covariates=final_covariates, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, + covariates=covariates, community_level=community_level, response_type=response_type, query=query, 
@@ -296,41 +234,11 @@ def run_drift_search( ], ) - # Call the Multi-Index Drift Search API - if dataframe_dict["multi-index"]: - final_entities_list = dataframe_dict["entities"] - final_communities_list = dataframe_dict["communities"] - final_community_reports_list = dataframe_dict["community_reports"] - final_text_units_list = dataframe_dict["text_units"] - final_relationships_list = dataframe_dict["relationships"] - index_names = dataframe_dict["index_names"] - - response, context_data = asyncio.run( - api.multi_index_drift_search( - config=config, - entities_list=final_entities_list, - communities_list=final_communities_list, - community_reports_list=final_community_reports_list, - text_units_list=final_text_units_list, - relationships_list=final_relationships_list, - index_names=index_names, - community_level=community_level, - response_type=response_type, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - - return response, context_data - - # Otherwise, call the Single-Index Drift Search API - final_communities: pd.DataFrame = dataframe_dict["communities"] - final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] - final_text_units: pd.DataFrame = dataframe_dict["text_units"] - final_relationships: pd.DataFrame = dataframe_dict["relationships"] - final_entities: pd.DataFrame = dataframe_dict["entities"] + communities: pd.DataFrame = dataframe_dict["communities"] + community_reports: pd.DataFrame = dataframe_dict["community_reports"] + text_units: pd.DataFrame = dataframe_dict["text_units"] + relationships: pd.DataFrame = dataframe_dict["relationships"] + entities: pd.DataFrame = dataframe_dict["entities"] if streaming: @@ -347,11 +255,11 @@ def on_context(context: Any) -> None: async for stream_chunk in api.drift_search_streaming( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - 
relationships=final_relationships, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, community_level=community_level, response_type=response_type, query=query, @@ -370,11 +278,11 @@ def on_context(context: Any) -> None: response, context_data = asyncio.run( api.drift_search( config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, community_level=community_level, response_type=response_type, query=query, @@ -411,27 +319,7 @@ def run_basic_search( ], ) - # Call the Multi-Index Basic Search API - if dataframe_dict["multi-index"]: - final_text_units_list = dataframe_dict["text_units"] - index_names = dataframe_dict["index_names"] - - response, context_data = asyncio.run( - api.multi_index_basic_search( - config=config, - text_units_list=final_text_units_list, - index_names=index_names, - streaming=streaming, - query=query, - verbose=verbose, - ) - ) - print(response) - - return response, context_data - - # Otherwise, call the Single-Index Basic Search API - final_text_units: pd.DataFrame = dataframe_dict["text_units"] + text_units: pd.DataFrame = dataframe_dict["text_units"] if streaming: @@ -448,7 +336,7 @@ def on_context(context: Any) -> None: async for stream_chunk in api.basic_search_streaming( config=config, - text_units=final_text_units, + text_units=text_units, query=query, callbacks=[callbacks], verbose=verbose, @@ -464,7 +352,7 @@ def on_context(context: Any) -> None: response, context_data = asyncio.run( api.basic_search( config=config, - text_units=final_text_units, + text_units=text_units, query=query, verbose=verbose, ) @@ -481,40 +369,6 @@ def _resolve_output_files( ) -> dict[str, Any]: """Read 
indexing output files to a dataframe dict.""" dataframe_dict = {} - - # Loading output files for multi-index search - if config.outputs: - dataframe_dict["multi-index"] = True - dataframe_dict["num_indexes"] = len(config.outputs) - dataframe_dict["index_names"] = config.outputs.keys() - for output in config.outputs.values(): - storage_obj = create_storage_from_config(output) - for name in output_list: - if name not in dataframe_dict: - dataframe_dict[name] = [] - df_value = asyncio.run( - load_table_from_storage(name=name, storage=storage_obj) - ) - dataframe_dict[name].append(df_value) - - # for optional output files, do not append if the dataframe does not exist - if optional_list: - for optional_file in optional_list: - if optional_file not in dataframe_dict: - dataframe_dict[optional_file] = [] - file_exists = asyncio.run( - storage_has_table(optional_file, storage_obj) - ) - if file_exists: - df_value = asyncio.run( - load_table_from_storage( - name=optional_file, storage=storage_obj - ) - ) - dataframe_dict[optional_file].append(df_value) - return dataframe_dict - # Loading output files for single-index search - dataframe_dict["multi-index"] = False storage_obj = create_storage_from_config(config.output) for name in output_list: df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj)) From 9f98fca08704ae9363007e1780df44cbb6390e8e Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 13:42:18 -0700 Subject: [PATCH 02/10] Remove multi-search from API --- graphrag/api/__init__.py | 8 - graphrag/api/query.py | 669 --------------------------------------- graphrag/utils/api.py | 215 ++----------- 3 files changed, 29 insertions(+), 863 deletions(-) diff --git a/graphrag/api/__init__.py b/graphrag/api/__init__.py index a3a06033b..05692c418 100644 --- a/graphrag/api/__init__.py +++ b/graphrag/api/__init__.py @@ -18,10 +18,6 @@ global_search_streaming, local_search, local_search_streaming, - multi_index_basic_search, - 
multi_index_drift_search, - multi_index_global_search, - multi_index_local_search, ) from graphrag.prompt_tune.types import DocSelectionType @@ -37,10 +33,6 @@ "drift_search_streaming", "basic_search", "basic_search_streaming", - "multi_index_basic_search", - "multi_index_drift_search", - "multi_index_global_search", - "multi_index_local_search", # prompt tuning API "DocSelectionType", "generate_indexing_prompts", diff --git a/graphrag/api/query.py b/graphrag/api/query.py index b92cd8145..b9613a3b2 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -52,7 +52,6 @@ get_embedding_store, load_search_prompt, truncate, - update_context_data, ) from graphrag.utils.cli import redact @@ -192,152 +191,6 @@ def global_search_streaming( return search_engine.stream_search(query=query) -@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_global_search( - config: GraphRagConfig, - entities_list: list[pd.DataFrame], - communities_list: list[pd.DataFrame], - community_reports_list: list[pd.DataFrame], - index_names: list[str], - community_level: int | None, - dynamic_community_selection: bool, - response_type: str, - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a global search across multiple indexes and return the context data and response. 
- - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet) - - communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet) - - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet) - - index_names (list[str]): A list of index names. - - community_level (int): The community level to search at. - - dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level cap the maximum level to search. - - response_type (str): The type of response to return. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." 
- ) - - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_global_search" - raise NotImplementedError(message) - - links = { - "communities": {}, - "community_reports": {}, - "entities": {}, - } - max_vals = { - "communities": -1, - "community_reports": -1, - "entities": -1, - } - - communities_dfs = [] - community_reports_dfs = [] - entities_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's community reports dataframe for merging - community_reports_df = community_reports_list[idx] - community_reports_df["community"] = community_reports_df["community"].astype( - int - ) - for i in community_reports_df["community"]: - links["community_reports"][i + max_vals["community_reports"] + 1] = { - "index_name": index_name, - "id": str(i), - } - community_reports_df["community"] += max_vals["community_reports"] + 1 - community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 - max_vals["community_reports"] = int(community_reports_df["community"].max()) - community_reports_dfs.append(community_reports_df) - - # Prepare each index's communities dataframe for merging - communities_df = communities_list[idx] - communities_df["community"] = communities_df["community"].astype(int) - communities_df["parent"] = communities_df["parent"].astype(int) - for i in communities_df["community"]: - links["communities"][i + max_vals["communities"] + 1] = { - "index_name": index_name, - "id": str(i), - } - communities_df["community"] += max_vals["communities"] + 1 - communities_df["parent"] = communities_df["parent"].apply( - lambda x: x if x == -1 else x + max_vals["communities"] + 1 - ) - communities_df["human_readable_id"] += max_vals["communities"] + 1 - # concat the index name to the entity_ids, since this is used for joining later - communities_df["entity_ids"] = communities_df["entity_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - 
max_vals["communities"] = int(communities_df["community"].max()) - communities_dfs.append(communities_df) - - # Prepare each index's entities dataframe for merging - entities_df = entities_list[idx] - for i in entities_df["human_readable_id"]: - links["entities"][i + max_vals["entities"] + 1] = { - "index_name": index_name, - "id": i, - } - entities_df["human_readable_id"] += max_vals["entities"] + 1 - entities_df["title"] = entities_df["title"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["entities"] = int(entities_df["human_readable_id"].max()) - entities_dfs.append(entities_df) - - # Merge the dataframes - community_reports_combined = pd.concat( - community_reports_dfs, axis=0, ignore_index=True, sort=False - ) - entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) - communities_combined = pd.concat( - communities_dfs, axis=0, ignore_index=True, sort=False - ) - - logger.debug("Executing multi-index global search query: %s", query) - result = await global_search( - config, - entities=entities_combined, - communities=communities_combined, - community_reports=community_reports_combined, - community_level=community_level, - dynamic_community_selection=dynamic_community_selection, - response_type=response_type, - query=query, - callbacks=callbacks, - ) - - # Update the context data by linking index names and community ids - context = update_context_data(result[1], links) - - logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore - return (result[0], context) - - @validate_call(config={"arbitrary_types_allowed": True}) async def local_search( config: GraphRagConfig, @@ -472,238 +325,6 @@ def local_search_streaming( return search_engine.stream_search(query=query) -@validate_call(config={"arbitrary_types_allowed": True}) -async def 
multi_index_local_search( - config: GraphRagConfig, - entities_list: list[pd.DataFrame], - communities_list: list[pd.DataFrame], - community_reports_list: list[pd.DataFrame], - text_units_list: list[pd.DataFrame], - relationships_list: list[pd.DataFrame], - covariates_list: list[pd.DataFrame] | None, - index_names: list[str], - community_level: int, - response_type: str, - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a local search across multiple indexes and return the context data and response. - - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet) - - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet) - - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet) - - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet) - - covariates_list (list[pd.DataFrame]): [Optional] A list of DataFrames containing the final covariates (from covariates.parquet) - - index_names (list[str]): A list of index names. - - community_level (int): The community level to search at. - - response_type (str): The response type to return. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." 
- ) - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_index_local_search" - raise NotImplementedError(message) - - links = { - "community_reports": {}, - "communities": {}, - "entities": {}, - "text_units": {}, - "relationships": {}, - "covariates": {}, - } - max_vals = { - "community_reports": -1, - "communities": -1, - "entities": -1, - "text_units": 0, - "relationships": -1, - "covariates": 0, - } - community_reports_dfs = [] - communities_dfs = [] - entities_dfs = [] - relationships_dfs = [] - text_units_dfs = [] - covariates_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's communities dataframe for merging - communities_df = communities_list[idx] - communities_df["community"] = communities_df["community"].astype(int) - for i in communities_df["community"]: - links["communities"][i + max_vals["communities"] + 1] = { - "index_name": index_name, - "id": str(i), - } - communities_df["community"] += max_vals["communities"] + 1 - communities_df["human_readable_id"] += max_vals["communities"] + 1 - # concat the index name to the entity_ids, since this is used for joining later - communities_df["entity_ids"] = communities_df["entity_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["communities"] = int(communities_df["community"].max()) - communities_dfs.append(communities_df) - - # Prepare each index's community reports dataframe for merging - community_reports_df = community_reports_list[idx] - community_reports_df["community"] = community_reports_df["community"].astype( - int - ) - for i in community_reports_df["community"]: - links["community_reports"][i + max_vals["community_reports"] + 1] = { - "index_name": index_name, - "id": str(i), - } - community_reports_df["community"] += max_vals["community_reports"] + 1 - community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 - max_vals["community_reports"] = 
int(community_reports_df["community"].max()) - community_reports_dfs.append(community_reports_df) - - # Prepare each index's entities dataframe for merging - entities_df = entities_list[idx] - for i in entities_df["human_readable_id"]: - links["entities"][i + max_vals["entities"] + 1] = { - "index_name": index_name, - "id": i, - } - entities_df["human_readable_id"] += max_vals["entities"] + 1 - entities_df["title"] = entities_df["title"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["id"] = entities_df["id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["entities"] = int(entities_df["human_readable_id"].max()) - entities_dfs.append(entities_df) - - # Prepare each index's relationships dataframe for merging - relationships_df = relationships_list[idx] - for i in relationships_df["human_readable_id"].astype(int): - links["relationships"][i + max_vals["relationships"] + 1] = { - "index_name": index_name, - "id": i, - } - if max_vals["relationships"] != -1: - col = ( - relationships_df["human_readable_id"].astype(int) - + max_vals["relationships"] - + 1 - ) - relationships_df["human_readable_id"] = col.astype(str) - relationships_df["source"] = relationships_df["source"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["target"] = relationships_df["target"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["relationships"] = int(relationships_df["human_readable_id"].max()) - relationships_dfs.append(relationships_df) - - # Prepare each index's text units dataframe for merging - text_units_df = text_units_list[idx] - for i in 
range(text_units_df.shape[0]): - links["text_units"][i + max_vals["text_units"]] = { - "index_name": index_name, - "id": i, - } - text_units_df["id"] = text_units_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - text_units_df["human_readable_id"] = ( - text_units_df["human_readable_id"] + max_vals["text_units"] - ) - max_vals["text_units"] += text_units_df.shape[0] - text_units_dfs.append(text_units_df) - - # If presents, prepare each index's covariates dataframe for merging - if covariates_list is not None: - covariates_df = covariates_list[idx] - for i in covariates_df["human_readable_id"].astype(int): - links["covariates"][i + max_vals["covariates"]] = { - "index_name": index_name, - "id": i, - } - covariates_df["id"] = covariates_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - covariates_df["human_readable_id"] = ( - covariates_df["human_readable_id"] + max_vals["covariates"] - ) - covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - covariates_df["subject_id"] = covariates_df["subject_id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - max_vals["covariates"] += covariates_df.shape[0] - covariates_dfs.append(covariates_df) - - # Merge the dataframes - communities_combined = pd.concat( - communities_dfs, axis=0, ignore_index=True, sort=False - ) - community_reports_combined = pd.concat( - community_reports_dfs, axis=0, ignore_index=True, sort=False - ) - entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) - relationships_combined = pd.concat( - relationships_dfs, axis=0, ignore_index=True, sort=False - ) - text_units_combined = pd.concat( - text_units_dfs, axis=0, ignore_index=True, sort=False - ) - covariates_combined = None - if len(covariates_dfs) > 0: - covariates_combined = pd.concat( - covariates_dfs, axis=0, ignore_index=True, sort=False - ) - logger.debug("Executing 
multi-index local search query: %s", query) - result = await local_search( - config, - entities=entities_combined, - communities=communities_combined, - community_reports=community_reports_combined, - text_units=text_units_combined, - relationships=relationships_combined, - covariates=covariates_combined, - community_level=community_level, - response_type=response_type, - query=query, - callbacks=callbacks, - ) - - # Update the context data by linking index names and community ids - context = update_context_data(result[1], links) - - logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore - return (result[0], context) - - @validate_call(config={"arbitrary_types_allowed": True}) async def drift_search( config: GraphRagConfig, @@ -841,218 +462,6 @@ def drift_search_streaming( return search_engine.stream_search(query=query) -@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_drift_search( - config: GraphRagConfig, - entities_list: list[pd.DataFrame], - communities_list: list[pd.DataFrame], - community_reports_list: list[pd.DataFrame], - text_units_list: list[pd.DataFrame], - relationships_list: list[pd.DataFrame], - index_names: list[str], - community_level: int, - response_type: str, - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a DRIFT search across multiple indexes and return the context data and response. 
- - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet) - - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet) - - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet) - - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet) - - index_names (list[str]): A list of index names. - - community_level (int): The community level to search at. - - response_type (str): The response type to return. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." 
- ) - - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_drift_search" - raise NotImplementedError(message) - - links = { - "community_reports": {}, - "communities": {}, - "entities": {}, - "text_units": {}, - "relationships": {}, - } - max_vals = { - "community_reports": -1, - "communities": -1, - "entities": -1, - "text_units": 0, - "relationships": -1, - } - - communities_dfs = [] - community_reports_dfs = [] - entities_dfs = [] - relationships_dfs = [] - text_units_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's communities dataframe for merging - communities_df = communities_list[idx] - communities_df["community"] = communities_df["community"].astype(int) - for i in communities_df["community"]: - links["communities"][i + max_vals["communities"] + 1] = { - "index_name": index_name, - "id": str(i), - } - communities_df["community"] += max_vals["communities"] + 1 - communities_df["human_readable_id"] += max_vals["communities"] + 1 - # concat the index name to the entity_ids, since this is used for joining later - communities_df["entity_ids"] = communities_df["entity_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["communities"] = int(communities_df["community"].max()) - communities_dfs.append(communities_df) - - # Prepare each index's community reports dataframe for merging - community_reports_df = community_reports_list[idx] - community_reports_df["community"] = community_reports_df["community"].astype( - int - ) - for i in community_reports_df["community"]: - links["community_reports"][i + max_vals["community_reports"] + 1] = { - "index_name": index_name, - "id": str(i), - } - community_reports_df["community"] += max_vals["community_reports"] + 1 - community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1 - community_reports_df["id"] = community_reports_df["id"].apply( - lambda x, index_name=index_name: x + 
f"-{index_name}" - ) - max_vals["community_reports"] = int(community_reports_df["community"].max()) - community_reports_dfs.append(community_reports_df) - - # Prepare each index's entities dataframe for merging - entities_df = entities_list[idx] - for i in entities_df["human_readable_id"]: - links["entities"][i + max_vals["entities"] + 1] = { - "index_name": index_name, - "id": i, - } - entities_df["human_readable_id"] += max_vals["entities"] + 1 - entities_df["title"] = entities_df["title"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["id"] = entities_df["id"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["entities"] = int(entities_df["human_readable_id"].max()) - entities_dfs.append(entities_df) - - # Prepare each index's relationships dataframe for merging - relationships_df = relationships_list[idx] - for i in relationships_df["human_readable_id"].astype(int): - links["relationships"][i + max_vals["relationships"] + 1] = { - "index_name": index_name, - "id": i, - } - if max_vals["relationships"] != -1: - col = ( - relationships_df["human_readable_id"].astype(int) - + max_vals["relationships"] - + 1 - ) - relationships_df["human_readable_id"] = col.astype(str) - relationships_df["source"] = relationships_df["source"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["target"] = relationships_df["target"].apply( - lambda x, index_name=index_name: x + f"-{index_name}" - ) - relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply( - lambda x, index_name=index_name: [i + f"-{index_name}" for i in x] - ) - max_vals["relationships"] = int( - relationships_df["human_readable_id"].astype(int).max() - ) - - relationships_dfs.append(relationships_df) - - # Prepare each index's text units dataframe for merging - 
text_units_df = text_units_list[idx] - for i in range(text_units_df.shape[0]): - links["text_units"][i + max_vals["text_units"]] = { - "index_name": index_name, - "id": i, - } - text_units_df["id"] = text_units_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - text_units_df["human_readable_id"] = ( - text_units_df["human_readable_id"] + max_vals["text_units"] - ) - max_vals["text_units"] += text_units_df.shape[0] - text_units_dfs.append(text_units_df) - - # Merge the dataframes - communities_combined = pd.concat( - communities_dfs, axis=0, ignore_index=True, sort=False - ) - community_reports_combined = pd.concat( - community_reports_dfs, axis=0, ignore_index=True, sort=False - ) - entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False) - relationships_combined = pd.concat( - relationships_dfs, axis=0, ignore_index=True, sort=False - ) - text_units_combined = pd.concat( - text_units_dfs, axis=0, ignore_index=True, sort=False - ) - - logger.debug("Executing multi-index drift search query: %s", query) - result = await drift_search( - config, - entities=entities_combined, - communities=communities_combined, - community_reports=community_reports_combined, - text_units=text_units_combined, - relationships=relationships_combined, - community_level=community_level, - response_type=response_type, - query=query, - callbacks=callbacks, - ) - - # Update the context data by linking index names and community ids - context = {} - if type(result[1]) is dict: - for key in result[1]: - context[key] = update_context_data(result[1][key], links) - else: - context = result[1] - - logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore - return (result[0], context) - - @validate_call(config={"arbitrary_types_allowed": True}) async def basic_search( config: GraphRagConfig, @@ -1146,81 +555,3 @@ def basic_search_streaming( callbacks=callbacks, ) return search_engine.stream_search(query=query) - - 
-@validate_call(config={"arbitrary_types_allowed": True}) -async def multi_index_basic_search( - config: GraphRagConfig, - text_units_list: list[pd.DataFrame], - index_names: list[str], - streaming: bool, - query: str, - callbacks: list[QueryCallbacks] | None = None, - verbose: bool = False, -) -> tuple[ - str | dict[str, Any] | list[dict[str, Any]], - str | list[pd.DataFrame] | dict[str, pd.DataFrame], -]: - """Perform a basic search across multiple indexes and return the context data and response. - - Parameters - ---------- - - config (GraphRagConfig): A graphrag configuration (from settings.yaml) - - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet) - - index_names (list[str]): A list of index names. - - streaming (bool): Whether to stream the results or not. - - query (str): The user query to search for. - - Returns - ------- - TODO: Document the search response type and format. - """ - init_loggers(config=config, verbose=verbose, filename="query.log") - - logger.warning( - "Multi-index search is deprecated and will be removed in GraphRAG v3." 
- ) - - # Streaming not supported yet - if streaming: - message = "Streaming not yet implemented for multi_basic_search" - raise NotImplementedError(message) - - links = { - "text_units": {}, - } - max_vals = { - "text_units": 0, - } - - text_units_dfs = [] - - for idx, index_name in enumerate(index_names): - # Prepare each index's text units dataframe for merging - text_units_df = text_units_list[idx] - for i in range(text_units_df.shape[0]): - links["text_units"][i + max_vals["text_units"]] = { - "index_name": index_name, - "id": i, - } - text_units_df["id"] = text_units_df["id"].apply( - lambda x, index_name=index_name: f"{x}-{index_name}" - ) - text_units_df["human_readable_id"] = ( - text_units_df["human_readable_id"] + max_vals["text_units"] - ) - max_vals["text_units"] += text_units_df.shape[0] - text_units_dfs.append(text_units_df) - - # Merge the dataframes - text_units_combined = pd.concat( - text_units_dfs, axis=0, ignore_index=True, sort=False - ) - - logger.debug("Executing multi-index basic search query: %s", query) - return await basic_search( - config, - text_units=text_units_combined, - query=query, - callbacks=callbacks, - ) diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py index db3d94790..ca7b4b4b2 100644 --- a/graphrag/utils/api.py +++ b/graphrag/utils/api.py @@ -4,7 +4,6 @@ """API functions for the GraphRAG module.""" from pathlib import Path -from typing import Any from graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache @@ -12,133 +11,52 @@ from graphrag.config.models.cache_config import CacheConfig from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.data_model.types import TextEmbedder from graphrag.storage.factory import StorageFactory from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.vector_stores.base import ( BaseVectorStore, - 
VectorStoreDocument, - VectorStoreSearchResult, ) from graphrag.vector_stores.factory import VectorStoreFactory -class MultiVectorStore(BaseVectorStore): - """Multi Vector Store wrapper implementation.""" - - def __init__( - self, - embedding_stores: list[BaseVectorStore], - index_names: list[str], - ): - self.embedding_stores = embedding_stores - self.index_names = index_names - - def load_documents( - self, documents: list[VectorStoreDocument], overwrite: bool = True - ) -> None: - """Load documents into the vector store.""" - msg = "load_documents method not implemented" - raise NotImplementedError(msg) - - def connect(self, **kwargs: Any) -> Any: - """Connect to vector storage.""" - msg = "connect method not implemented" - raise NotImplementedError(msg) - - def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: - """Build a query filter to filter documents by id.""" - msg = "filter_by_id method not implemented" - raise NotImplementedError(msg) - - def search_by_id(self, id: str) -> VectorStoreDocument: - """Search for a document by id.""" - search_index_id = id.split("-")[0] - search_index_name = id.split("-")[1] - for index_name, embedding_store in zip( - self.index_names, self.embedding_stores, strict=False - ): - if index_name == search_index_name: - return embedding_store.search_by_id(search_index_id) - else: - message = f"Index {search_index_name} not found." 
- raise ValueError(message) - - def similarity_search_by_vector( - self, query_embedding: list[float], k: int = 10, **kwargs: Any - ) -> list[VectorStoreSearchResult]: - """Perform a vector-based similarity search.""" - all_results = [] - for index_name, embedding_store in zip( - self.index_names, self.embedding_stores, strict=False - ): - results = embedding_store.similarity_search_by_vector( - query_embedding=query_embedding, k=k - ) - mod_results = [] - for r in results: - r.document.id = str(r.document.id) + f"-{index_name}" - mod_results += [r] - all_results += mod_results - return sorted(all_results, key=lambda x: x.score, reverse=True)[:k] - - def similarity_search_by_text( - self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any - ) -> list[VectorStoreSearchResult]: - """Perform a text-based similarity search.""" - query_embedding = text_embedder(text) - if query_embedding: - return self.similarity_search_by_vector( - query_embedding=query_embedding, k=k - ) - return [] - - def get_embedding_store( config_args: dict[str, dict], embedding_name: str, ) -> BaseVectorStore: """Get the embedding description store.""" - num_indexes = len(config_args) - embedding_stores = [] - index_names = [] - for index, store in config_args.items(): - vector_store_type = store["type"] - index_name = create_index_name( - store.get("container_name", "default"), embedding_name - ) + store = next(iter(config_args.values())) + vector_store_type = store["type"] + index_name = create_index_name( + store.get("container_name", "default"), embedding_name + ) - embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get( - "embeddings_schema", {} - ) - single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() + embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get( + "embeddings_schema", {} + ) + embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() - if ( - embeddings_schema is not None - and embedding_name is not 
None - and embedding_name in embeddings_schema - ): - raw_config = embeddings_schema[embedding_name] - if isinstance(raw_config, dict): - single_embedding_config = VectorStoreSchemaConfig(**raw_config) - else: - single_embedding_config = raw_config + if ( + embeddings_schema is not None + and embedding_name is not None + and embedding_name in embeddings_schema + ): + raw_config = embeddings_schema[embedding_name] + if isinstance(raw_config, dict): + embedding_config = VectorStoreSchemaConfig(**raw_config) + else: + embedding_config = raw_config - if single_embedding_config.index_name is None: - single_embedding_config.index_name = index_name + if embedding_config.index_name is None: + embedding_config.index_name = index_name - embedding_store = VectorStoreFactory().create_vector_store( - vector_store_type=vector_store_type, - vector_store_schema_config=single_embedding_config, - **store, - ) - embedding_store.connect(**store) - # If there is only a single index, return the embedding store directly - if num_indexes == 1: - return embedding_store - embedding_stores.append(embedding_store) - index_names.append(index) - return MultiVectorStore(embedding_stores, index_names) + embedding_store = VectorStoreFactory().create_vector_store( + vector_store_type=vector_store_type, + vector_store_schema_config=embedding_config, + **store, + ) + embedding_store.connect(**store) + + return embedding_store def reformat_context_data(context_data: dict) -> dict: @@ -172,81 +90,6 @@ def reformat_context_data(context_data: dict) -> dict: return final_format -def update_context_data( - context_data: Any, - links: dict[str, Any], -) -> Any: - """ - Update context data with the links dict so that it contains both the index name and community id. - - Parameters - ---------- - - context_data (str | list[pd.DataFrame] | dict[str, pd.DataFrame]): The context data to update. - - links (dict[str, Any]): A dictionary of links to the original dataframes. 
- - Returns - ------- - str | list[pd.DataFrame] | dict[str, pd.DataFrame]: The updated context data. - """ - updated_context_data = {} - for key in context_data: - entries = context_data[key].to_dict(orient="records") - updated_entry = [] - if key == "reports": - updated_entry = [ - dict( - entry, - index_name=links["community_reports"][int(entry["id"])][ - "index_name" - ], - index_id=links["community_reports"][int(entry["id"])]["id"], - ) - for entry in entries - ] - if key == "entities": - updated_entry = [ - dict( - entry, - entity=entry["entity"].split("-")[0], - index_name=links["entities"][int(entry["id"])]["index_name"], - index_id=links["entities"][int(entry["id"])]["id"], - ) - for entry in entries - ] - if key == "relationships": - updated_entry = [ - dict( - entry, - source=entry["source"].split("-")[0], - target=entry["target"].split("-")[0], - index_name=links["relationships"][int(entry["id"])]["index_name"], - index_id=links["relationships"][int(entry["id"])]["id"], - ) - for entry in entries - ] - if key == "claims": - updated_entry = [ - dict( - entry, - entity=entry["entity"].split("-")[0], - index_name=links["covariates"][int(entry["id"])]["index_name"], - index_id=links["covariates"][int(entry["id"])]["id"], - ) - for entry in entries - ] - if key == "sources": - updated_entry = [ - dict( - entry, - index_name=links["text_units"][int(entry["id"])]["index_name"], - index_id=links["text_units"][int(entry["id"])]["id"], - ) - for entry in entries - ] - updated_context_data[key] = updated_entry - return updated_context_data - - def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: """ Load the search prompt from disk if configured. 
From 9ca2b4742d931310da38828c64902a87fe5796e4 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 13:57:23 -0700 Subject: [PATCH 03/10] Flatten vector_store config --- graphrag/api/query.py | 23 +++------- graphrag/config/defaults.py | 7 +-- graphrag/config/get_vector_store_settings.py | 4 +- graphrag/config/init_content.py | 8 ++-- graphrag/config/models/embed_text_config.py | 4 -- graphrag/config/models/graph_rag_config.py | 45 ++++---------------- graphrag/utils/api.py | 4 +- tests/unit/config/utils.py | 26 +++++------ 8 files changed, 33 insertions(+), 88 deletions(-) diff --git a/graphrag/api/query.py b/graphrag/api/query.py index b9613a3b2..f9923675c 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -294,14 +294,11 @@ def local_search_streaming( """ init_loggers(config=config, verbose=verbose, filename="query.log") - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - msg = f"Vector Store Args: {redact(vector_store_args)}" + msg = f"Vector Store Args: {redact(config.vector_store.model_dump())}" logger.debug(msg) description_embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=entity_description_embedding, ) @@ -422,19 +419,16 @@ def drift_search_streaming( """ init_loggers(config=config, verbose=verbose, filename="query.log") - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - msg = f"Vector Store Args: {redact(vector_store_args)}" + msg = f"Vector Store Args: {redact(config.vector_store.model_dump())}" logger.debug(msg) description_embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=entity_description_embedding, ) full_content_embedding_store = get_embedding_store( - config_args=vector_store_args, + 
store=config.vector_store.model_dump(), embedding_name=community_full_content_embedding, ) @@ -533,14 +527,11 @@ def basic_search_streaming( """ init_loggers(config=config, verbose=verbose, filename="query.log") - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - msg = f"Vector Store Args: {redact(vector_store_args)}" + msg = f"Vector Store Args: {redact(config.vector_store.model_dump())}" logger.debug(msg) embedding_store = get_embedding_store( - config_args=vector_store_args, + store=config.vector_store.model_dump(), embedding_name=text_unit_text_embedding, ) diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index eadafd986..048459056 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -55,8 +55,6 @@ DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large" DEFAULT_MODEL_PROVIDER = "openai" -DEFAULT_VECTOR_STORE_ID = "default_vector_store" - ENCODING_MODEL = "o200k_base" COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default" @@ -167,7 +165,6 @@ class EmbedTextDefaults: batch_max_tokens: int = 8191 names: list[str] = field(default_factory=lambda: default_embeddings) strategy: None = None - vector_store_id: str = DEFAULT_VECTOR_STORE_ID @dataclass @@ -444,8 +441,8 @@ class GraphRagConfigDefaults: global_search: GlobalSearchDefaults = field(default_factory=GlobalSearchDefaults) drift_search: DriftSearchDefaults = field(default_factory=DriftSearchDefaults) basic_search: BasicSearchDefaults = field(default_factory=BasicSearchDefaults) - vector_store: dict[str, VectorStoreDefaults] = field( - default_factory=lambda: {DEFAULT_VECTOR_STORE_ID: VectorStoreDefaults()} + vector_store: VectorStoreDefaults = field( + default_factory=lambda: VectorStoreDefaults() ) workflows: None = None diff --git a/graphrag/config/get_vector_store_settings.py b/graphrag/config/get_vector_store_settings.py index 3771d65ff..779f8a7f1 100644 --- 
a/graphrag/config/get_vector_store_settings.py +++ b/graphrag/config/get_vector_store_settings.py @@ -11,9 +11,7 @@ def get_vector_store_settings( vector_store_params: dict | None = None, ) -> dict: """Transform GraphRAG config into settings for workflows.""" - vector_store_settings = settings.get_vector_store_config( - settings.embed_text.vector_store_id - ).model_dump() + vector_store_settings = settings.vector_store.model_dump() # # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index aadacf8f3..69b05015f 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -75,16 +75,14 @@ base_dir: "{graphrag_config_defaults.reporting.base_dir}" vector_store: - {defs.DEFAULT_VECTOR_STORE_ID}: - type: {vector_store_defaults.type} - db_uri: {vector_store_defaults.db_uri} - container_name: {vector_store_defaults.container_name} + type: {vector_store_defaults.type} + db_uri: {vector_store_defaults.db_uri} + container_name: {vector_store_defaults.container_name} ### Workflow settings ### embed_text: model_id: {graphrag_config_defaults.embed_text.model_id} - vector_store_id: {graphrag_config_defaults.embed_text.vector_store_id} extract_graph: model_id: {graphrag_config_defaults.extract_graph.model_id} diff --git a/graphrag/config/models/embed_text_config.py b/graphrag/config/models/embed_text_config.py index f785bf6ee..5e5963811 100644 --- a/graphrag/config/models/embed_text_config.py +++ b/graphrag/config/models/embed_text_config.py @@ -15,10 +15,6 @@ class EmbedTextConfig(BaseModel): description="The model ID to use for text embeddings.", default=graphrag_config_defaults.embed_text.model_id, ) - vector_store_id: str = Field( - description="The vector store ID to use for text embeddings.", - default=graphrag_config_defaults.embed_text.vector_store_id, - ) batch_size: int = Field( description="The batch size to 
use.", default=graphrag_config_defaults.embed_text.batch_size, diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index 2b5961321..e430c2fa2 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -3,7 +3,6 @@ """Parameterization settings for the default configuration.""" -from dataclasses import asdict from pathlib import Path from devtools import pformat @@ -234,12 +233,8 @@ def _validate_reporting_base_dir(self) -> None: (Path(self.root_dir) / self.reporting.base_dir).resolve() ) - vector_store: dict[str, VectorStoreConfig] = Field( - description="The vector store configuration.", - default_factory=lambda: { - k: VectorStoreConfig(**asdict(v)) - for k, v in graphrag_config_defaults.vector_store.items() - }, + vector_store: VectorStoreConfig = Field( + description="The vector store configuration.", default=VectorStoreConfig() ) """The vector store configuration.""" @@ -327,12 +322,12 @@ def _validate_reporting_base_dir(self) -> None: def _validate_vector_store_db_uri(self) -> None: """Validate the vector store configuration.""" - for store in self.vector_store.values(): - if store.type == VectorStoreType.LanceDB: - if not store.db_uri or store.db_uri.strip == "": - msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration." - raise ValueError(msg) - store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve()) + store = self.vector_store + if store.type == VectorStoreType.LanceDB: + if not store.db_uri or store.db_uri.strip == "": + msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration." 
+ raise ValueError(msg) + store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve()) def _validate_factories(self) -> None: """Validate the factories used in the configuration.""" @@ -363,30 +358,6 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig: return self.models[model_id] - def get_vector_store_config(self, vector_store_id: str) -> VectorStoreConfig: - """Get a vector store configuration by ID. - - Parameters - ---------- - vector_store_id : str - The ID of the vector store to get. Should match an ID in the vector_store list. - - Returns - ------- - VectorStoreConfig - The vector store configuration if found. - - Raises - ------ - ValueError - If the vector store ID is not found in the configuration. - """ - if vector_store_id not in self.vector_store: - err_msg = f"Vector Store ID {vector_store_id} not found in configuration. Please rerun `graphrag init` and set the vector store configuration." - raise ValueError(err_msg) - - return self.vector_store[vector_store_id] - @model_validator(mode="after") def _validate_model(self): """Validate the model configuration.""" diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py index ca7b4b4b2..16a5f9ed5 100644 --- a/graphrag/utils/api.py +++ b/graphrag/utils/api.py @@ -4,6 +4,7 @@ """API functions for the GraphRAG module.""" from pathlib import Path +from typing import Any from graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache @@ -20,11 +21,10 @@ def get_embedding_store( - config_args: dict[str, dict], + store: dict[str, Any], embedding_name: str, ) -> BaseVectorStore: """Get the embedding description store.""" - store = next(iter(config_args.values())) vector_store_type = store["type"] index_name = create_index_name( store.get("container_name", "default"), embedding_name diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index f23069256..7e15ef1bc 100644 --- a/tests/unit/config/utils.py +++ 
b/tests/unit/config/utils.py @@ -105,23 +105,18 @@ def assert_language_model_configs( def assert_vector_store_configs( - actual: dict[str, VectorStoreConfig], - expected: dict[str, VectorStoreConfig], + actual: VectorStoreConfig, + expected: VectorStoreConfig, ): assert type(actual) is type(expected) - assert len(actual) == len(expected) - for (index_a, store_a), (index_e, store_e) in zip( - actual.items(), expected.items(), strict=True - ): - assert index_a == index_e - assert store_a.type == store_e.type - assert store_a.db_uri == store_e.db_uri - assert store_a.url == store_e.url - assert store_a.api_key == store_e.api_key - assert store_a.audience == store_e.audience - assert store_a.container_name == store_e.container_name - assert store_a.overwrite == store_e.overwrite - assert store_a.database_name == store_e.database_name + assert actual.type == expected.type + assert actual.db_uri == expected.db_uri + assert actual.url == expected.url + assert actual.api_key == expected.api_key + assert actual.audience == expected.audience + assert actual.container_name == expected.container_name + assert actual.overwrite == expected.overwrite + assert actual.database_name == expected.database_name def assert_reporting_configs( @@ -187,7 +182,6 @@ def assert_text_embedding_configs( assert actual.batch_max_tokens == expected.batch_max_tokens assert actual.names == expected.names assert actual.model_id == expected.model_id - assert actual.vector_store_id == expected.vector_store_id def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None: From c64fb2f9837684d92ffd0d7d366c441e42cd9b21 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 14:24:07 -0700 Subject: [PATCH 04/10] Push hydrated vector store down to embed_text --- graphrag/config/get_vector_store_settings.py | 24 ---- .../index/operations/embed_text/embed_text.py | 122 +----------------- .../workflows/generate_text_embeddings.py | 67 ++++++++-- 
.../index/workflows/update_text_embeddings.py | 4 +- 4 files changed, 62 insertions(+), 155 deletions(-) delete mode 100644 graphrag/config/get_vector_store_settings.py diff --git a/graphrag/config/get_vector_store_settings.py b/graphrag/config/get_vector_store_settings.py deleted file mode 100644 index 779f8a7f1..000000000 --- a/graphrag/config/get_vector_store_settings.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing get_vector_store_settings.""" - -from graphrag.config.models.graph_rag_config import GraphRagConfig - - -def get_vector_store_settings( - settings: GraphRagConfig, - vector_store_params: dict | None = None, -) -> dict: - """Transform GraphRAG config into settings for workflows.""" - vector_store_settings = settings.vector_store.model_dump() - - # - # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. - # settings.vector_store.base contains connection information, or may be undefined - # settings.vector_store. 
contains the specific settings for this embedding - # - return { - **(vector_store_params or {}), - **(vector_store_settings), - } diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index 04099f810..f01682604 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -9,13 +9,10 @@ import pandas as pd from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.embeddings import create_index_name -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.index.operations.embed_text.run_embed_text import run_embed_text from graphrag.language_model.protocol.base import EmbeddingModel from graphrag.tokenizer.tokenizer import Tokenizer from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument -from graphrag.vector_stores.factory import VectorStoreFactory logger = logging.getLogger(__name__) @@ -26,86 +23,14 @@ async def embed_text( model: EmbeddingModel, tokenizer: Tokenizer, embed_column: str, - embedding_name: str, batch_size: int, batch_max_tokens: int, num_threads: int, - vector_store_config: dict, + vector_store: BaseVectorStore, id_column: str = "id", title_column: str | None = None, ): """Embed a piece of text into a vector space. 
The operation outputs a new column containing a mapping between doc_id and vector.""" - if vector_store_config: - index_name = _get_index_name(vector_store_config, embedding_name) - vector_store: BaseVectorStore = _create_vector_store( - vector_store_config, index_name, embedding_name - ) - vector_store_workflow_config = vector_store_config.get( - embedding_name, vector_store_config - ) - return await _text_embed_with_vector_store( - input=input, - callbacks=callbacks, - model=model, - tokenizer=tokenizer, - embed_column=embed_column, - vector_store=vector_store, - vector_store_config=vector_store_workflow_config, - batch_size=batch_size, - batch_max_tokens=batch_max_tokens, - num_threads=num_threads, - id_column=id_column, - title_column=title_column, - ) - - return await _text_embed_in_memory( - input=input, - callbacks=callbacks, - model=model, - tokenizer=tokenizer, - embed_column=embed_column, - batch_size=batch_size, - batch_max_tokens=batch_max_tokens, - num_threads=num_threads, - ) - - -async def _text_embed_in_memory( - input: pd.DataFrame, - callbacks: WorkflowCallbacks, - model: EmbeddingModel, - tokenizer: Tokenizer, - embed_column: str, - batch_size: int, - batch_max_tokens: int, - num_threads: int, -): - texts: list[str] = input[embed_column].tolist() - result = await run_embed_text( - texts, callbacks, model, tokenizer, batch_size, batch_max_tokens, num_threads - ) - - return result.embeddings - - -async def _text_embed_with_vector_store( - input: pd.DataFrame, - callbacks: WorkflowCallbacks, - model: EmbeddingModel, - tokenizer: Tokenizer, - embed_column: str, - vector_store: BaseVectorStore, - vector_store_config: dict, - batch_size: int, - batch_max_tokens: int, - num_threads: int, - id_column: str, - title_column: str | None = None, -): - # Get vector-storage configuration - - overwrite: bool = vector_store_config.get("overwrite", True) - if embed_column not in input.columns: msg = f"Column {embed_column} not found in input dataframe with columns 
{input.columns}" raise ValueError(msg) @@ -168,51 +93,8 @@ async def _text_embed_with_vector_store( ) documents.append(document) - vector_store.load_documents(documents, overwrite and i == 0) + vector_store.load_documents(documents, True) starting_index += len(documents) i += 1 return all_results - - -def _create_vector_store( - vector_store_config: dict, index_name: str, embedding_name: str | None = None -) -> BaseVectorStore: - vector_store_type: str = str(vector_store_config.get("type")) - - embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get( - "embeddings_schema", {} - ) - single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() - - if ( - embeddings_schema is not None - and embedding_name is not None - and embedding_name in embeddings_schema - ): - raw_config = embeddings_schema[embedding_name] - if isinstance(raw_config, dict): - single_embedding_config = VectorStoreSchemaConfig(**raw_config) - else: - single_embedding_config = raw_config - - if single_embedding_config.index_name is None: - single_embedding_config.index_name = index_name - - vector_store = VectorStoreFactory().create_vector_store( - vector_store_schema_config=single_embedding_config, - vector_store_type=vector_store_type, - **vector_store_config, - ) - - vector_store.connect(**vector_store_config) - return vector_store - - -def _get_index_name(vector_store_config: dict, embedding_name: str) -> str: - container_name = vector_store_config.get("container_name", "default") - index_name = create_index_name(container_name, embedding_name) - - msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}" - logger.info(msg) - return index_name diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index 3a0ae4125..848b57d29 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ 
b/graphrag/index/workflows/generate_text_embeddings.py @@ -12,14 +12,16 @@ community_full_content_embedding, community_summary_embedding, community_title_embedding, + create_index_name, document_text_embedding, entity_description_embedding, entity_title_embedding, relationship_description_embedding, text_unit_text_embedding, ) -from graphrag.config.get_vector_store_settings import get_vector_store_settings from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.vector_store_config import VectorStoreConfig +from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -31,6 +33,8 @@ load_table_from_storage, write_table_to_storage, ) +from graphrag.vector_stores.base import BaseVectorStore +from graphrag.vector_stores.factory import VectorStoreFactory logger = logging.getLogger(__name__) @@ -70,8 +74,6 @@ async def run_workflow( "community_reports", context.output_storage ) - vector_store_config = get_vector_store_settings(config) - model_config = config.get_language_model_config(config.embed_text.model_id) model = ModelManager().get_or_create_embedding_model( @@ -96,7 +98,7 @@ async def run_workflow( batch_size=config.embed_text.batch_size, batch_max_tokens=config.embed_text.batch_max_tokens, num_threads=model_config.concurrent_requests, - vector_store_config=vector_store_config, + vector_store_config=config.vector_store, embedded_fields=embedded_fields, ) @@ -124,7 +126,7 @@ async def generate_text_embeddings( batch_size: int, batch_max_tokens: int, num_threads: int, - vector_store_config: dict, + vector_store_config: VectorStoreConfig, embedded_fields: list[str], ) -> dict[str, pd.DataFrame]: """All the steps to generate all embeddings.""" @@ -208,20 +210,69 @@ async def _run_embeddings( batch_size: int, 
batch_max_tokens: int, num_threads: int, - vector_store_config: dict, + vector_store_config: VectorStoreConfig, ) -> pd.DataFrame: """All the steps to generate single embedding.""" + index_name = _get_index_name(vector_store_config, name) + vector_store = _create_vector_store(vector_store_config, index_name, name) + data["embedding"] = await embed_text( input=data, callbacks=callbacks, model=model, tokenizer=tokenizer, embed_column=embed_column, - embedding_name=name, batch_size=batch_size, batch_max_tokens=batch_max_tokens, num_threads=num_threads, - vector_store_config=vector_store_config, + vector_store=vector_store, ) return data.loc[:, ["id", "embedding"]] + + +def _create_vector_store( + vector_store_config: VectorStoreConfig, + index_name: str, + embedding_name: str | None = None, +) -> BaseVectorStore: + vector_store_type: str = str(vector_store_config.type) + + embeddings_schema: dict[str, VectorStoreSchemaConfig] = ( + vector_store_config.embeddings_schema + ) + + single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() + + if ( + embeddings_schema is not None + and embedding_name is not None + and embedding_name in embeddings_schema + ): + raw_config = embeddings_schema[embedding_name] + if isinstance(raw_config, dict): + single_embedding_config = VectorStoreSchemaConfig(**raw_config) + else: + single_embedding_config = raw_config + + if single_embedding_config.index_name is None: + single_embedding_config.index_name = index_name + + args = vector_store_config.model_dump() + vector_store = VectorStoreFactory().create_vector_store( + vector_store_schema_config=single_embedding_config, + vector_store_type=vector_store_type, + **args, + ) + + vector_store.connect(**args) + return vector_store + + +def _get_index_name(vector_store_config: VectorStoreConfig, embedding_name: str) -> str: + container_name = vector_store_config.container_name + index_name = create_index_name(container_name, embedding_name) + + msg = f"using vector store 
{vector_store_config.type} with container_name {container_name} for embedding {embedding_name}: {index_name}" + logger.info(msg) + return index_name diff --git a/graphrag/index/workflows/update_text_embeddings.py b/graphrag/index/workflows/update_text_embeddings.py index 4b349f580..0fc5f1fde 100644 --- a/graphrag/index/workflows/update_text_embeddings.py +++ b/graphrag/index/workflows/update_text_embeddings.py @@ -5,7 +5,6 @@ import logging -from graphrag.config.get_vector_store_settings import get_vector_store_settings from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext @@ -37,7 +36,6 @@ async def run_workflow( ] embedded_fields = config.embed_text.names - vector_store_config = get_vector_store_settings(config) model_config = config.get_language_model_config(config.embed_text.model_id) @@ -63,7 +61,7 @@ async def run_workflow( batch_size=config.embed_text.batch_size, batch_max_tokens=config.embed_text.batch_max_tokens, num_threads=model_config.concurrent_requests, - vector_store_config=vector_store_config, + vector_store_config=config.vector_store, embedded_fields=embedded_fields, ) if config.snapshots.embeddings: From 94bbeb6eac42c8314b7d6e05db2e1e1b75bc099b Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 14:26:50 -0700 Subject: [PATCH 05/10] Remove outputs from config --- graphrag/config/defaults.py | 1 - graphrag/config/models/graph_rag_config.py | 18 ------------------ tests/unit/config/utils.py | 8 -------- 3 files changed, 27 deletions(-) diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 048459056..e81c5ac31 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -415,7 +415,6 @@ class GraphRagConfigDefaults: reporting: ReportingDefaults = field(default_factory=ReportingDefaults) storage: StorageDefaults = field(default_factory=StorageDefaults) output: 
OutputDefaults = field(default_factory=OutputDefaults) - outputs: None = None update_index_output: UpdateIndexOutputDefaults = field( default_factory=UpdateIndexOutputDefaults ) diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index e430c2fa2..c8cdca819 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -178,23 +178,6 @@ def _validate_output_base_dir(self) -> None: (Path(self.root_dir) / self.output.base_dir).resolve() ) - outputs: dict[str, StorageConfig] | None = Field( - description="A list of output configurations used for multi-index query.", - default=graphrag_config_defaults.outputs, - ) - - def _validate_multi_output_base_dirs(self) -> None: - """Validate the outputs dict base directories.""" - if self.outputs: - for output in self.outputs.values(): - if output.type == defs.StorageType.file: - if output.base_dir.strip() == "": - msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." 
- raise ValueError(msg) - output.base_dir = str( - (Path(self.root_dir) / output.base_dir).resolve() - ) - update_index_output: StorageConfig = Field( description="The output configuration for the updated index.", default=StorageConfig( @@ -367,7 +350,6 @@ def _validate_model(self): self._validate_input_base_dir() self._validate_reporting_base_dir() self._validate_output_base_dir() - self._validate_multi_output_base_dirs() self._validate_update_index_output_base_dir() self._validate_vector_store_db_uri() self._validate_factories() diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 7e15ef1bc..7588335a8 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -365,14 +365,6 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) -> assert_reporting_configs(actual.reporting, expected.reporting) assert_output_configs(actual.output, expected.output) - if expected.outputs is not None: - assert actual.outputs is not None - assert len(actual.outputs) == len(expected.outputs) - for a, e in zip(actual.outputs.keys(), expected.outputs.keys(), strict=True): - assert_output_configs(actual.outputs[a], expected.outputs[e]) - else: - assert actual.outputs is None - assert_update_output_configs( actual.update_index_output, expected.update_index_output ) From 2d9ea5a70fcca9ef5497f35477fa07d1f80fd1be Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 14:29:44 -0700 Subject: [PATCH 06/10] Remove multi-search notebook/docs --- .../multi_index_search.ipynb | 558 ------------------ docs/query/multi_index_search.md | 20 - 2 files changed, 578 deletions(-) delete mode 100644 docs/examples_notebooks/multi_index_search.ipynb delete mode 100644 docs/query/multi_index_search.md diff --git a/docs/examples_notebooks/multi_index_search.ipynb b/docs/examples_notebooks/multi_index_search.ipynb deleted file mode 100644 index 2e70ed508..000000000 --- a/docs/examples_notebooks/multi_index_search.ipynb +++ /dev/null @@ 
-1,558 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) 2024 Microsoft Corporation.\n", - "# Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Multi Index Search\n", - "This notebook demonstrates multi-index search using the GraphRAG API.\n", - "\n", - "Indexes created from Wikipedia state articles for Alaska, California, DC, Maryland, NY and Washington are used." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import asyncio\n", - "\n", - "import pandas as pd\n", - "\n", - "from graphrag.api.query import (\n", - " multi_index_basic_search,\n", - " multi_index_drift_search,\n", - " multi_index_global_search,\n", - " multi_index_local_search,\n", - ")\n", - "from graphrag.config.create_graphrag_config import create_graphrag_config\n", - "\n", - "indexes = [\"alaska\", \"california\", \"dc\", \"maryland\", \"ny\", \"washington\"]\n", - "indexes = sorted(indexes)\n", - "\n", - "print(indexes)\n", - "\n", - "vector_store_configs = {\n", - " index: {\n", - " \"type\": \"lancedb\",\n", - " \"db_uri\": f\"inputs/{index}/lancedb\",\n", - " \"container_name\": \"default\",\n", - " \"overwrite\": True,\n", - " \"index_name\": f\"{index}\",\n", - " }\n", - " for index in indexes\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config_data = {\n", - " \"models\": {\n", - " \"default_chat_model\": {\n", - " \"model_supports_json\": True,\n", - " \"parallelization_num_threads\": 50,\n", - " \"parallelization_stagger\": 0.3,\n", - " \"async_mode\": \"threaded\",\n", - " \"type\": \"azure_openai_chat\",\n", - " \"model\": \"gpt-4o\",\n", - " \"auth_type\": \"azure_managed_identity\",\n", - " \"api_base\": \"\",\n", - " \"api_version\": \"2024-02-15-preview\",\n", - " 
\"deployment_name\": \"gpt-4o\",\n", - " },\n", - " \"default_embedding_model\": {\n", - " \"parallelization_num_threads\": 50,\n", - " \"parallelization_stagger\": 0.3,\n", - " \"async_mode\": \"threaded\",\n", - " \"type\": \"azure_openai_embedding\",\n", - " \"model\": \"text-embedding-3-large\",\n", - " \"auth_type\": \"azure_managed_identity\",\n", - " \"api_base\": \"\",\n", - " \"api_version\": \"2024-02-15-preview\",\n", - " \"deployment_name\": \"text-embedding-3-large\",\n", - " },\n", - " },\n", - " \"vector_store\": vector_store_configs,\n", - " \"local_search\": {\n", - " \"prompt\": \"prompts/local_search_system_prompt.txt\",\n", - " \"llm_max_tokens\": 12000,\n", - " },\n", - " \"global_search\": {\n", - " \"map_prompt\": \"prompts/global_search_map_system_prompt.txt\",\n", - " \"reduce_prompt\": \"prompts/global_search_reduce_system_prompt.txt\",\n", - " \"knowledge_prompt\": \"prompts/global_search_knowledge_system_prompt.txt\",\n", - " },\n", - " \"drift_search\": {\n", - " \"prompt\": \"prompts/drift_search_system_prompt.txt\",\n", - " \"reduce_prompt\": \"prompts/drift_search_reduce_prompt.txt\",\n", - " },\n", - " \"basic_search\": {\"prompt\": \"prompts/basic_search_system_prompt.txt\"},\n", - "}\n", - "parameters = create_graphrag_config(config_data, \".\")\n", - "loop = asyncio.get_event_loop()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multi-index Global Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n", - "communities = [\n", - " pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n", - "]\n", - "community_reports = [\n", - " pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_global_search(\n", - " parameters,\n", - 
" entities,\n", - " communities,\n", - " community_reports,\n", - " indexes,\n", - " 1,\n", - " False,\n", - " \"Multiple Paragraphs\",\n", - " False,\n", - " \"Describe this dataset.\",\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:\n", - " index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(report_id, index_name, index_id)\n", - " index_reports = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_community_reports.parquet\"\n", - " )\n", - " print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n", - " print(\n", - " index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n", - " 0\n", - " ]\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Multi-index Local Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n", - "communities = [\n", - " pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n", - "]\n", - "community_reports = [\n", - " pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in 
indexes\n", - "]\n", - "covariates = [\n", - " pd.read_parquet(f\"inputs/{index}/covariates.parquet\") for index in indexes\n", - "]\n", - "text_units = [\n", - " pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n", - "]\n", - "relationships = [\n", - " pd.read_parquet(f\"inputs/{index}/relationships.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_local_search(\n", - " parameters,\n", - " entities,\n", - " communities,\n", - " community_reports,\n", - " text_units,\n", - " relationships,\n", - " covariates,\n", - " indexes,\n", - " 1,\n", - " \"Multiple Paragraphs\",\n", - " False,\n", - " \"weather\",\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for report_id in [47, 213]:\n", - " index_name = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(report_id, index_name, index_id)\n", - " index_reports = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_community_reports.parquet\"\n", - " )\n", - " print([i for i in results[1][\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n", - " print(\n", - " index_reports[index_reports[\"community\"] == int(index_id)][\"title\"].to_numpy()[\n", - " 0\n", - " ]\n", - " )\n", - "for entity_id in [500, 502, 506, 1960, 1961, 1962]:\n", - " 
index_name = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(entity_id, index_name, index_id)\n", - " index_entities = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_entities.parquet\"\n", - " )\n", - " print(\n", - " [i for i in results[1][\"entities\"] if i[\"id\"] == str(entity_id)][0][ # noqa: RUF015\n", - " \"description\"\n", - " ][:100]\n", - " )\n", - " print(\n", - " index_entities[index_entities[\"human_readable_id\"] == int(index_id)][\n", - " \"description\"\n", - " ].to_numpy()[0][:100]\n", - " )\n", - "for relationship_id in [1805, 1806]:\n", - " index_name = [ # noqa: RUF015\n", - " i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n", - " ][0][\"index_name\"]\n", - " index_id = [ # noqa: RUF015\n", - " i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)\n", - " ][0][\"index_id\"]\n", - " print(relationship_id, index_name, index_id)\n", - " index_relationships = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_relationships.parquet\"\n", - " )\n", - " print(\n", - " [i for i in results[1][\"relationships\"] if i[\"id\"] == str(relationship_id)][0][ # noqa: RUF015\n", - " \"description\"\n", - " ]\n", - " )\n", - " print(\n", - " index_relationships[index_relationships[\"human_readable_id\"] == int(index_id)][\n", - " \"description\"\n", - " ].to_numpy()[0]\n", - " )\n", - "for claim_id in [100]:\n", - " index_name = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(relationship_id, index_name, index_id)\n", - " index_claims = 
pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_covariates.parquet\"\n", - " )\n", - " print(\n", - " [i for i in results[1][\"claims\"] if i[\"id\"] == str(claim_id)][0][\"description\"] # noqa: RUF015\n", - " )\n", - " print(\n", - " index_claims[index_claims[\"human_readable_id\"] == int(index_id)][\n", - " \"description\"\n", - " ].to_numpy()[0]\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multi-index Drift Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "entities = [pd.read_parquet(f\"inputs/{index}/entities.parquet\") for index in indexes]\n", - "communities = [\n", - " pd.read_parquet(f\"inputs/{index}/communities.parquet\") for index in indexes\n", - "]\n", - "community_reports = [\n", - " pd.read_parquet(f\"inputs/{index}/community_reports.parquet\") for index in indexes\n", - "]\n", - "text_units = [\n", - " pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n", - "]\n", - "relationships = [\n", - " pd.read_parquet(f\"inputs/{index}/relationships.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_drift_search(\n", - " parameters,\n", - " entities,\n", - " communities,\n", - " community_reports,\n", - " text_units,\n", - " relationships,\n", - " indexes,\n", - " 1,\n", - " \"Multiple Paragraphs\",\n", - " False,\n", - " \"agriculture\",\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original index" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for 
report_id in [47, 236]:\n", - " for question in results[1]:\n", - " resq = results[1][question]\n", - " if len(resq[\"reports\"]) == 0:\n", - " continue\n", - " if len([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)]) == 0:\n", - " continue\n", - " index_name = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(question, report_id, index_name, index_id)\n", - " index_reports = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_community_reports.parquet\"\n", - " )\n", - " print([i for i in resq[\"reports\"] if i[\"id\"] == str(report_id)][0][\"title\"]) # noqa: RUF015\n", - " print(\n", - " index_reports[index_reports[\"community\"] == int(index_id)][\n", - " \"title\"\n", - " ].to_numpy()[0]\n", - " )\n", - " break\n", - "for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:\n", - " for question in results[1]:\n", - " resq = results[1][question]\n", - " if len(resq[\"sources\"]) == 0:\n", - " continue\n", - " if len([i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)]) == 0:\n", - " continue\n", - " index_name = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n", - " \"index_name\"\n", - " ]\n", - " index_id = [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][ # noqa: RUF015\n", - " \"index_id\"\n", - " ]\n", - " print(question, source_id, index_name, index_id)\n", - " index_sources = pd.read_parquet(\n", - " f\"inputs/{index_name}/create_final_text_units.parquet\"\n", - " )\n", - " print(\n", - " [i for i in resq[\"sources\"] if i[\"id\"] == str(source_id)][0][\"text\"][:250] # noqa: RUF015\n", - " )\n", - " print(index_sources.loc[int(index_id)][\"text\"][:250])\n", - " break" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Multi-index Basic 
Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "text_units = [\n", - " pd.read_parquet(f\"inputs/{index}/text_units.parquet\") for index in indexes\n", - "]\n", - "\n", - "task = loop.create_task(\n", - " multi_index_basic_search(\n", - " parameters, text_units, indexes, False, \"industry in maryland\"\n", - " )\n", - ")\n", - "results = await task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Print report" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(results[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Show context links back to original text\n", - "\n", - "Note that original index name is not saved in context data for basic search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for source_id in [0, 1]:\n", - " print(results[1][\"sources\"][source_id][\"text\"][:250])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/query/multi_index_search.md b/docs/query/multi_index_search.md deleted file mode 100644 index 6b6ff2b41..000000000 --- a/docs/query/multi_index_search.md +++ /dev/null @@ -1,20 +0,0 @@ -# Multi Index Search 🔎 - -## Multi Dataset Reasoning - -GraphRAG takes in unstructured data contained in text documents and uses large languages models to “read” the documents in a targeted fashion and create a knowledge graph. 
This knowledge graph, or index, contains information about specific entities in the data, how the entities relate to one another, and high-level reports about communities and topics found in the data. Indexes can be searched by users to get meaningful information about the underlying data, including reports with citations that point back to the original unstructured text. - -Multi-index search is a new capability that has been added to the GraphRAG python library to query multiple knowledge stores at once. Multi-index search allows for many new search scenarios, including: - -- Combining knowledge from different domains – Many documents contain similar types of entities: person, place, thing. But GraphRAG can be tuned for highly specialized domains, such as science and engineering. With the recent updates to search, GraphRAG can now simultaneously query multiple datasets with completely different schemas and entity definitions. - -- Combining knowledge with different access levels – Not all datasets are accessible to all people, even within an organization. Some datasets are publicly available. Some datasets, such as internal financial information or intellectual property, may only be accessible by a small number of employees at a company. Multi-index search allows multiple sources with different access controls to be queried at the same time, creating more nuanced and informative reports. Internal R&D findings can be seamlessly combined with open-source scientific publications. - -- Combining knowledge in different locations – With multi-index search, indexes do not need to be in the same location or type of storage to be queried. Indexes in the cloud in Azure Storage can be queried at the same time as indexes stored on a personal computer. Multi-index search makes these types of data joins easy and accessible. 
- -To search across multiple datasets, the underlying contexts from each index, based on the user query, are combined in-memory at query time, saving on computation and allowing the joint querying of indexes that can’t be joined inherently, either do access controls or differing schemas. Multi-index search automatically keeps track of provenance information, so that any references can be traced back to the correct indexes and correct original documents. - - -## How to Use - -An example of a global search scenario can be found in the following [notebook](../examples_notebooks/multi_index_search.ipynb). \ No newline at end of file From 4b202f2fe997cbf42febccc4feae2d08a83b2656 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 14:50:30 -0700 Subject: [PATCH 07/10] Add missing response_type in basic search API --- graphrag/api/query.py | 4 ++++ graphrag/cli/main.py | 1 + graphrag/cli/query.py | 3 +++ graphrag/query/factory.py | 2 +- 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/graphrag/api/query.py b/graphrag/api/query.py index f9923675c..e49a0976d 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -460,6 +460,7 @@ def drift_search_streaming( async def basic_search( config: GraphRagConfig, text_units: pd.DataFrame, + response_type: str, query: str, callbacks: list[QueryCallbacks] | None = None, verbose: bool = False, @@ -497,6 +498,7 @@ def on_context(context: Any) -> None: async for chunk in basic_search_streaming( config=config, text_units=text_units, + response_type=response_type, query=query, callbacks=callbacks, ): @@ -509,6 +511,7 @@ def on_context(context: Any) -> None: def basic_search_streaming( config: GraphRagConfig, text_units: pd.DataFrame, + response_type: str, query: str, callbacks: list[QueryCallbacks] | None = None, verbose: bool = False, @@ -542,6 +545,7 @@ def basic_search_streaming( config=config, text_units=read_indexer_text_units(text_units), text_unit_embeddings=embedding_store, + 
response_type=response_type, system_prompt=prompt, callbacks=callbacks, ) diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py index bc0e9f39a..7b4d3bb0c 100644 --- a/graphrag/cli/main.py +++ b/graphrag/cli/main.py @@ -537,6 +537,7 @@ def _query_cli( config_filepath=config, data_dir=data, root_dir=root, + response_type=response_type, streaming=streaming, query=query, verbose=verbose, diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py index b7c12d1e8..914196b86 100644 --- a/graphrag/cli/query.py +++ b/graphrag/cli/query.py @@ -298,6 +298,7 @@ def run_basic_search( config_filepath: Path | None, data_dir: Path | None, root_dir: Path, + response_type: str, streaming: bool, query: str, verbose: bool, @@ -337,6 +338,7 @@ def on_context(context: Any) -> None: async for stream_chunk in api.basic_search_streaming( config=config, text_units=text_units, + response_type=response_type, query=query, callbacks=[callbacks], verbose=verbose, @@ -353,6 +355,7 @@ def on_context(context: Any) -> None: api.basic_search( config=config, text_units=text_units, + response_type=response_type, query=query, verbose=verbose, ) diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index 10dfd3ba1..0c4a2ef0b 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -251,8 +251,8 @@ def get_basic_search_engine( text_units: list[TextUnit], text_unit_embeddings: BaseVectorStore, config: GraphRagConfig, + response_type: str, system_prompt: str | None = None, - response_type: str = "multiple paragraphs", callbacks: list[QueryCallbacks] | None = None, ) -> BasicSearch: """Create a basic search engine based on data + configuration.""" From 063c018c04d0e2ec97719da150fdc12518c78eee Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 15:12:48 -0700 Subject: [PATCH 08/10] Fix basic search context and id mapping --- .../prompts/query/basic_search_system_prompt.py | 4 ++-- .../basic_search/basic_context.py | 15 +++------------ 2 files changed, 5 
insertions(+), 14 deletions(-) diff --git a/graphrag/prompts/query/basic_search_system_prompt.py b/graphrag/prompts/query/basic_search_system_prompt.py index a20fb6ad1..bc37c70d9 100644 --- a/graphrag/prompts/query/basic_search_system_prompt.py +++ b/graphrag/prompts/query/basic_search_system_prompt.py @@ -27,7 +27,7 @@ "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]" -where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables. +where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "id" column in the provided tables. Do not include information where the supporting evidence for it is not provided. @@ -60,7 +60,7 @@ "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]" -where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables. +where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "id" column in the provided tables. Do not include information where the supporting evidence for it is not provided. 
diff --git a/graphrag/query/structured_search/basic_search/basic_context.py b/graphrag/query/structured_search/basic_search/basic_context.py index 215b26a31..b7390017f 100644 --- a/graphrag/query/structured_search/basic_search/basic_context.py +++ b/graphrag/query/structured_search/basic_search/basic_context.py @@ -38,7 +38,6 @@ def __init__( self.text_units = text_units self.text_unit_embeddings = text_unit_embeddings self.embedding_vectorstore_key = embedding_vectorstore_key - self.text_id_map = self._map_ids() def build_context( self, @@ -48,7 +47,7 @@ def build_context( max_context_tokens: int = 12_000, context_name: str = "Sources", column_delimiter: str = "|", - text_id_col: str = "source_id", + text_id_col: str = "id", text_col: str = "text", **kwargs, ) -> ContextBuilderResult: @@ -63,7 +62,7 @@ def build_context( text_unit_ids = {t.document.id for t in related_texts} text_units_filtered = [] text_units_filtered = [ - {text_id_col: t.id, text_col: t.text} + {text_id_col: t.short_id, text_col: t.text} for t in self.text_units or [] if t.id in text_unit_ids ] @@ -102,13 +101,5 @@ def build_context( return ContextBuilderResult( context_chunks=final_text, - context_records={context_name: final_text_df}, + context_records={context_name.lower(): final_text_df}, ) - - def _map_ids(self) -> dict[str, str]: - """Map id to short id in the text units.""" - id_map = {} - text_units = self.text_units or [] - for unit in text_units: - id_map[unit.id] = unit.short_id - return id_map From 2ffdc600db135e52a0afb3851638c5125f3ea3f0 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 15:48:47 -0700 Subject: [PATCH 09/10] Fix v1 migration notebook --- .../index_migration_to_v1.ipynb | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb index 581f5cef6..4a89d9530 100644 --- a/docs/examples_notebooks/index_migration_to_v1.ipynb 
+++ b/docs/examples_notebooks/index_migration_to_v1.ipynb @@ -204,14 +204,13 @@ "source": [ "from graphrag.cache.factory import CacheFactory\n", "from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n", - "from graphrag.config.get_vector_store_settings import get_vector_store_settings\n", "from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n", + "from graphrag.language_model.manager import ModelManager\n", + "from graphrag.tokenizer.get_tokenizer import get_tokenizer\n", "\n", "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n", "# We'll construct the context and run this function flow directly to avoid everything else\n", "\n", - "\n", - "vector_store_config = get_vector_store_settings(config)\n", "model_config = config.get_language_model_config(config.embed_text.model_id)\n", "callbacks = NoopWorkflowCallbacks()\n", "cache_config = config.cache.model_dump() # type: ignore\n", @@ -219,6 +218,15 @@ " cache_type=cache_config[\"type\"], # type: ignore\n", " **cache_config,\n", ")\n", + "model = ModelManager().get_or_create_embedding_model(\n", + " name=\"text_embedding\",\n", + " model_type=model_config.type,\n", + " config=model_config,\n", + " callbacks=callbacks,\n", + " cache=cache,\n", + ")\n", + "\n", + "tokenizer = get_tokenizer(model_config)\n", "\n", "await generate_text_embeddings(\n", " documents=None,\n", @@ -227,11 +235,12 @@ " entities=final_entities,\n", " community_reports=final_community_reports,\n", " callbacks=callbacks,\n", - " cache=cache,\n", - " model_config=model_config,\n", + " model=model,\n", + " tokenizer=tokenizer,\n", " batch_size=config.embed_text.batch_size,\n", " batch_max_tokens=config.embed_text.batch_max_tokens,\n", - " vector_store_config=vector_store_config,\n", + " num_threads=model_config.concurrent_requests,\n", + " vector_store_config=config.vector_store,\n", " 
embedded_fields=config.embed_text.names,\n", ")" ] From eb8e63fe9d0e42323da7392e2e3738737235bb9d Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 10 Oct 2025 16:02:44 -0700 Subject: [PATCH 10/10] Fix query entity search tests --- .../context_builder/test_entity_extraction.py | 50 ------------------- 1 file changed, 50 deletions(-) diff --git a/tests/unit/query/context_builder/test_entity_extraction.py b/tests/unit/query/context_builder/test_entity_extraction.py index aa2ff29b0..a0b7b77c5 100644 --- a/tests/unit/query/context_builder/test_entity_extraction.py +++ b/tests/unit/query/context_builder/test_entity_extraction.py @@ -55,9 +55,6 @@ def similarity_search_by_text( key=lambda x: x.score, )[:k] - def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: - return [document for document in self.documents if document.id in include_ids] - def search_by_id(self, id: str) -> VectorStoreDocument: result = self.documents[0] result.id = id @@ -92,27 +89,6 @@ def test_map_query_to_entities(): ), ] - assert map_query_to_entities( - query="t22", - text_embedding_vectorstore=MockBaseVectorStore([ - VectorStoreDocument(id=entity.id, vector=None) for entity in entities - ]), - text_embedder=ModelManager().get_or_create_embedding_model( - model_type="mock_embedding", name="mock" - ), - all_entities_dict={entity.id: entity for entity in entities}, - embedding_vectorstore_key=EntityVectorStoreKey.ID, - k=1, - oversample_scaler=1, - ) == [ - Entity( - id="c4f93564-4507-4ee4-b102-98add401a965", - short_id="sid2", - title="t22", - rank=4, - ) - ] - assert map_query_to_entities( query="t22", text_embedding_vectorstore=MockBaseVectorStore([ @@ -134,32 +110,6 @@ def test_map_query_to_entities(): ) ] - assert map_query_to_entities( - query="", - text_embedding_vectorstore=MockBaseVectorStore([ - VectorStoreDocument(id=entity.id, vector=None) for entity in entities - ]), - text_embedder=ModelManager().get_or_create_embedding_model( - model_type="mock_embedding", 
name="mock" - ), - all_entities_dict={entity.id: entity for entity in entities}, - embedding_vectorstore_key=EntityVectorStoreKey.ID, - k=2, - ) == [ - Entity( - id="c4f93564-4507-4ee4-b102-98add401a965", - short_id="sid2", - title="t22", - rank=4, - ), - Entity( - id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83", - short_id="sid4", - title="t4444", - rank=3, - ), - ] - assert map_query_to_entities( query="", text_embedding_vectorstore=MockBaseVectorStore([