4 changes: 4 additions & 0 deletions .semversioner/next-release/major-20251009203808375389.json
@@ -0,0 +1,4 @@
{
"type": "major",
"description": "Simplify internal args with stronger types and firmer boundaries."
}
35 changes: 17 additions & 18 deletions docs/examples_notebooks/index_migration_to_v1.ipynb
@@ -202,45 +202,44 @@
"metadata": {},
"outputs": [],
"source": [
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"from graphrag.cache.factory import CacheFactory\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings\n",
"from graphrag.config.get_vector_store_settings import get_vector_store_settings\n",
"from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
"# We'll construct the context and run this function flow directly to avoid everything else\n",
"\n",
"\n",
"embedded_fields = get_embedded_fields(config)\n",
"text_embed = get_embedding_settings(config)\n",
"vector_store_config = get_vector_store_settings(config)\n",
"model_config = config.get_language_model_config(config.embed_text.model_id)\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache_config = config.cache.model_dump() # type: ignore\n",
"cache = CacheFactory().create_cache(\n",
" cache_type=cache_config[\"type\"], # type: ignore\n",
" root_dir=PROJECT_DIRECTORY,\n",
" kwargs=cache_config,\n",
" **cache_config,\n",
")\n",
"\n",
"await generate_text_embeddings(\n",
" final_documents=None,\n",
" final_relationships=None,\n",
" final_text_units=final_text_units,\n",
" final_entities=final_entities,\n",
" final_community_reports=final_community_reports,\n",
" documents=None,\n",
" relationships=None,\n",
" text_units=final_text_units,\n",
" entities=final_entities,\n",
" community_reports=final_community_reports,\n",
" callbacks=callbacks,\n",
" cache=cache,\n",
" storage=storage,\n",
" text_embed_config=text_embed,\n",
" embedded_fields=embedded_fields,\n",
" snapshot_embeddings_enabled=False,\n",
" model_config=model_config,\n",
" batch_size=config.embed_text.batch_size,\n",
" batch_max_tokens=config.embed_text.batch_max_tokens,\n",
" vector_store_config=vector_store_config,\n",
" embedded_fields=config.embed_text.names,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@@ -254,7 +253,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,
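For clarity, here is the migrated notebook cell assembled in one piece, as a minimal sketch: it assumes `config`, `storage`, `PROJECT_DIRECTORY`, and the `final_*` dataframes loaded earlier in the notebook, and takes all keyword names from the diff above.

```python
from graphrag.cache.factory import CacheFactory
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.get_vector_store_settings import get_vector_store_settings
from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings

# The cache config is now spread as keyword arguments instead of being
# passed as a single `kwargs` dict.
cache_config = config.cache.model_dump()
cache = CacheFactory().create_cache(
    cache_type=cache_config["type"],
    root_dir=PROJECT_DIRECTORY,
    **cache_config,
)

# The workflow now takes plain typed arguments (model config, batch sizes,
# vector store settings) in place of the bundled `text_embed_config` strategy.
await generate_text_embeddings(
    documents=None,
    relationships=None,
    text_units=final_text_units,
    entities=final_entities,
    community_reports=final_community_reports,
    callbacks=NoopWorkflowCallbacks(),
    cache=cache,
    storage=storage,
    model_config=config.get_language_model_config(config.embed_text.model_id),
    batch_size=config.embed_text.batch_size,
    batch_max_tokens=config.embed_text.batch_max_tokens,
    vector_store_config=get_vector_store_settings(config),
    embedded_fields=config.embed_text.names,
)
```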
5 changes: 2 additions & 3 deletions graphrag/config/defaults.py
@@ -60,6 +60,7 @@
ENCODING_MODEL = "o200k_base"
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"

DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]

DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
"native": NativeRetry,
@@ -125,7 +126,6 @@ class CommunityReportDefaults:
text_prompt: None = None
max_length: int = 2000
max_input_length: int = 8000
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID


@@ -162,10 +162,9 @@ class DriftSearchDefaults:
class EmbedTextDefaults:
"""Default values for embedding text."""

model: str = "text-embedding-3-small"
model_id: str = DEFAULT_EMBEDDING_MODEL_ID
batch_size: int = 16
batch_max_tokens: int = 8191
model_id: str = DEFAULT_EMBEDDING_MODEL_ID
names: list[str] = field(default_factory=lambda: default_embeddings)
strategy: None = None
vector_store_id: str = DEFAULT_VECTOR_STORE_ID
@@ -1,19 +1,16 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing get_embedding_settings."""
"""A module containing get_vector_store_settings."""

from graphrag.config.models.graph_rag_config import GraphRagConfig


def get_embedding_settings(
def get_vector_store_settings(
settings: GraphRagConfig,
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
embeddings_llm_settings = settings.get_language_model_config(
settings.embed_text.model_id
)
vector_store_settings = settings.get_vector_store_config(
settings.embed_text.vector_store_id
).model_dump()
@@ -23,16 +20,7 @@ def get_embedding_settings(
# settings.vector_store.base contains connection information, or may be undefined
# settings.vector_store.<vector_name> contains the specific settings for this embedding
#
strategy = settings.embed_text.resolved_strategy(
embeddings_llm_settings
) # get the default strategy
strategy.update({
"vector_store": {
**(vector_store_params or {}),
**(vector_store_settings),
}
}) # update the default strategy with the vector store settings
# This ensures the vector store config is part of the strategy and not the global config
return {
"strategy": strategy,
**(vector_store_params or {}),
**(vector_store_settings),
}
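In effect, the renamed helper drops the embedding strategy entirely and returns only the merged vector store settings. A minimal usage sketch, assuming a loaded `GraphRagConfig` named `config`; the `container_name` override is a hypothetical example:

```python
from graphrag.config.get_vector_store_settings import get_vector_store_settings

# Overrides are merged first, so the named vector store config from
# settings wins on any key collision.
vector_store_config = get_vector_store_settings(
    config,
    vector_store_params={"container_name": "entities"},  # hypothetical override
)
```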
3 changes: 0 additions & 3 deletions graphrag/config/init_content.py
@@ -28,7 +28,6 @@
# api_version: 2024-05-01-preview
model_supports_json: true # recommended if this is available for your model.
concurrent_requests: {language_model_defaults.concurrent_requests}
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: {language_model_defaults.retry_strategy}
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: null
@@ -42,7 +41,6 @@
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
concurrent_requests: {language_model_defaults.concurrent_requests}
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: {language_model_defaults.retry_strategy}
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: null
@@ -102,7 +100,6 @@
extract_graph_nlp:
text_analyzer:
extractor_type: {graphrag_config_defaults.extract_graph_nlp.text_analyzer.extractor_type.value} # [regex_english, syntactic_parser, cfg]
async_mode: {graphrag_config_defaults.extract_graph_nlp.async_mode.value} # or asyncio

cluster_graph:
max_cluster_size: {graphrag_config_defaults.cluster_graph.max_cluster_size}
46 changes: 21 additions & 25 deletions graphrag/config/models/community_reports_config.py
@@ -3,12 +3,24 @@

"""Parameterization settings for the default configuration."""

from dataclasses import dataclass
from pathlib import Path

from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
from graphrag.prompts.index.community_report_text_units import (
COMMUNITY_REPORT_TEXT_PROMPT,
)


@dataclass
class CommunityReportPrompts:
"""Community report prompt templates."""

graph_prompt: str
text_prompt: str


class CommunityReportsConfig(BaseModel):
@@ -34,32 +46,16 @@ class CommunityReportsConfig(BaseModel):
description="The maximum input length in tokens to use when generating reports.",
default=graphrag_config_defaults.community_reports.max_input_length,
)
strategy: dict | None = Field(
description="The override strategy to use.",
default=graphrag_config_defaults.community_reports.strategy,
)

def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved community report extraction strategy."""
from graphrag.index.operations.summarize_communities.typing import (
CreateCommunityReportsStrategyType,
)

return self.strategy or {
"type": CreateCommunityReportsStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
def resolved_prompts(self, root_dir: str) -> CommunityReportPrompts:
"""Get the resolved community report extraction prompts."""
return CommunityReportPrompts(
graph_prompt=(Path(root_dir) / self.graph_prompt).read_text(
encoding="utf-8"
)
if self.graph_prompt
else None,
"text_prompt": (Path(root_dir) / self.text_prompt).read_text(
encoding="utf-8"
)
else COMMUNITY_REPORT_PROMPT,
text_prompt=(Path(root_dir) / self.text_prompt).read_text(encoding="utf-8")
if self.text_prompt
else None,
"max_report_length": self.max_length,
"max_input_length": self.max_input_length,
}
else COMMUNITY_REPORT_TEXT_PROMPT,
)
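A minimal usage sketch of the new API, assuming a `CommunityReportsConfig` instance named `reports_config` and a project `root_dir`:

```python
prompts = reports_config.resolved_prompts(root_dir)
# Both fields are now always plain strings: either the file contents read
# from root_dir, or the packaged COMMUNITY_REPORT_* default templates.
print(prompts.graph_prompt[:120])
print(prompts.text_prompt[:120])
```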
@@ -6,10 +6,9 @@
from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.language_model_config import LanguageModelConfig


class TextEmbeddingConfig(BaseModel):
class EmbedTextConfig(BaseModel):
"""Configuration section for text embeddings."""

model_id: str = Field(
@@ -32,21 +31,3 @@ class TextEmbeddingConfig(BaseModel):
description="The specific embeddings to perform.",
default=graphrag_config_defaults.embed_text.names,
)
strategy: dict | None = Field(
description="The override strategy to use.",
default=graphrag_config_defaults.embed_text.strategy,
)

def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
"""Get the resolved text embedding strategy."""
from graphrag.index.operations.embed_text.embed_text import (
TextEmbedStrategyType,
)

return self.strategy or {
"type": TextEmbedStrategyType.openai,
"llm": model_config.model_dump(),
"num_threads": model_config.concurrent_requests,
"batch_size": self.batch_size,
"batch_max_tokens": self.batch_max_tokens,
}
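With `resolved_strategy` removed, callers look up the model configuration themselves. A minimal sketch of the replacement pattern, assuming a loaded `GraphRagConfig` named `config`:

```python
# The embedding model config comes from the language model registry rather
# than a bundled embed-text "strategy" dict.
embed_config = config.embed_text  # an EmbedTextConfig
model_config = config.get_language_model_config(embed_config.model_id)

batch_size = embed_config.batch_size              # 16 by default
batch_max_tokens = embed_config.batch_max_tokens  # 8191 by default
```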
35 changes: 16 additions & 19 deletions graphrag/config/models/extract_claims_config.py
@@ -3,15 +3,23 @@

"""Parameterization settings for the default configuration."""

from dataclasses import dataclass
from pathlib import Path

from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT


class ClaimExtractionConfig(BaseModel):
@dataclass
class ClaimExtractionPrompts:
"""Claim extraction prompt templates."""

extraction_prompt: str


class ExtractClaimsConfig(BaseModel):
"""Configuration section for claim extraction."""

enabled: bool = Field(
@@ -34,22 +42,11 @@ class ClaimExtractionConfig(BaseModel):
description="The maximum number of entity gleanings to use.",
default=graphrag_config_defaults.extract_claims.max_gleanings,
)
strategy: dict | None = Field(
description="The override strategy to use.",
default=graphrag_config_defaults.extract_claims.strategy,
)

def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved claim extraction strategy."""
return self.strategy or {
"llm": model_config.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
def resolved_prompts(self, root_dir: str) -> ClaimExtractionPrompts:
"""Get the resolved claim extraction prompts."""
return ClaimExtractionPrompts(
extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8")
if self.prompt
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
}
else EXTRACT_CLAIMS_PROMPT,
)
37 changes: 15 additions & 22 deletions graphrag/config/models/extract_graph_config.py
@@ -3,12 +3,20 @@

"""Parameterization settings for the default configuration."""

from dataclasses import dataclass
from pathlib import Path

from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT


@dataclass
class ExtractGraphPrompts:
"""Graph extraction prompt templates."""

extraction_prompt: str


class ExtractGraphConfig(BaseModel):
@@ -30,26 +38,11 @@ class ExtractGraphConfig(BaseModel):
description="The maximum number of entity gleanings to use.",
default=graphrag_config_defaults.extract_graph.max_gleanings,
)
strategy: dict | None = Field(
description="Override the default entity extraction strategy",
default=graphrag_config_defaults.extract_graph.strategy,
)

def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved entity extraction strategy."""
from graphrag.index.operations.extract_graph.typing import (
ExtractEntityStrategyType,
)

return self.strategy or {
"type": ExtractEntityStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
def resolved_prompts(self, root_dir: str) -> ExtractGraphPrompts:
"""Get the resolved graph extraction prompts."""
return ExtractGraphPrompts(
extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8")
if self.prompt
else None,
"max_gleanings": self.max_gleanings,
}
else GRAPH_EXTRACTION_PROMPT,
)
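The claim and graph extraction configs now share this pattern. A minimal sketch of the fallback behavior, assuming the `prompt` field defaults to unset:

```python
from graphrag.config.models.extract_graph_config import ExtractGraphConfig

config = ExtractGraphConfig()  # assumes `prompt` defaults to unset
prompts = config.resolved_prompts(root_dir=".")
# With no prompt file configured, the packaged GRAPH_EXTRACTION_PROMPT is
# returned, so extraction_prompt is always a string, never None.
print(prompts.extraction_prompt[:120])
```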