Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions packages/graphrag/graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ class SnapshotsDefaults:
raw_graph: bool = False


@dataclass
class EntityResolutionDefaults:
    """Default values for entity resolution.

    These are the fallback values for the ``entity_resolution`` configuration
    section when the user does not override them.
    """

    # Entity resolution is opt-in: the step stays off unless explicitly enabled.
    enabled: bool = False
    # No custom prompt is configured by default.
    # NOTE(review): presumably consumers fall back to a built-in prompt when
    # this is None -- confirm against the config model that reads it.
    prompt: None = None
    # Completion model ID used for entity resolution; defaults to the shared
    # DEFAULT_COMPLETION_MODEL_ID constant.
    completion_model_id: str = DEFAULT_COMPLETION_MODEL_ID
    # Instance name for the resolution model.
    # NOTE(review): appears to be used to partition the model cache -- confirm.
    model_instance_name: str = "entity_resolution"


@dataclass
class SummarizeDescriptionsDefaults:
"""Default values for summarizing descriptions."""
Expand Down Expand Up @@ -359,6 +369,9 @@ class GraphRagConfigDefaults:
chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
entity_resolution: EntityResolutionDefaults = field(
default_factory=EntityResolutionDefaults
)
extract_graph_nlp: ExtractGraphNLPDefaults = field(
default_factory=ExtractGraphNLPDefaults
)
Expand Down
4 changes: 4 additions & 0 deletions packages/graphrag/graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
entity_types: [{",".join(graphrag_config_defaults.extract_graph.entity_types)}]
max_gleanings: {graphrag_config_defaults.extract_graph.max_gleanings}

entity_resolution:
enabled: {graphrag_config_defaults.entity_resolution.enabled}
completion_model_id: {graphrag_config_defaults.entity_resolution.completion_model_id}

summarize_descriptions:
completion_model_id: {graphrag_config_defaults.summarize_descriptions.completion_model_id}
prompt: "prompts/summarize_descriptions.txt"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for entity resolution."""

from dataclasses import dataclass
from pathlib import Path

from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.prompts.index.entity_resolution import ENTITY_RESOLUTION_PROMPT


@dataclass
class EntityResolutionPrompts:
    """Entity resolution prompt templates.

    Simple container for the already-loaded prompt text, so callers do not
    need to know whether it came from a file or from a built-in default.
    """

    # Full text of the prompt template used for the entity resolution call.
    resolution_prompt: str


class EntityResolutionConfig(BaseModel):
    """Settings for the optional LLM-based entity resolution step.

    Every field defaults to the corresponding value on
    ``graphrag_config_defaults.entity_resolution``.
    """

    enabled: bool = Field(
        default=graphrag_config_defaults.entity_resolution.enabled,
        description="Whether to enable LLM-based entity resolution.",
    )
    completion_model_id: str = Field(
        default=graphrag_config_defaults.entity_resolution.completion_model_id,
        description="The model ID to use for entity resolution.",
    )
    model_instance_name: str = Field(
        default=graphrag_config_defaults.entity_resolution.model_instance_name,
        description="The model singleton instance name. This primarily affects the cache storage partitioning.",
    )
    prompt: str | None = Field(
        default=graphrag_config_defaults.entity_resolution.prompt,
        description="The entity resolution prompt to use.",
    )

    def resolved_prompts(self) -> EntityResolutionPrompts:
        """Return the prompt text to use for entity resolution.

        When ``prompt`` is unset (``None`` or empty) the built-in
        ``ENTITY_RESOLUTION_PROMPT`` is used. Otherwise ``prompt`` is treated
        as a file path and its contents are read as UTF-8.
        """
        # Guard clause: nothing configured, so use the packaged template.
        if not self.prompt:
            return EntityResolutionPrompts(resolution_prompt=ENTITY_RESOLUTION_PROMPT)

        prompt_text = Path(self.prompt).read_text(encoding="utf-8")
        return EntityResolutionPrompts(resolution_prompt=prompt_text)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.embed_text_config import EmbedTextConfig
from graphrag.config.models.entity_resolution_config import EntityResolutionConfig
from graphrag.config.models.extract_claims_config import ExtractClaimsConfig
from graphrag.config.models.extract_graph_config import ExtractGraphConfig
from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
Expand Down Expand Up @@ -186,6 +187,12 @@ def _validate_reporting_base_dir(self) -> None:
)
"""The entity extraction configuration to use."""

entity_resolution: EntityResolutionConfig = Field(
description="The entity resolution configuration to use.",
default=EntityResolutionConfig(),
)
"""The entity resolution configuration to use."""

summarize_descriptions: SummarizeDescriptionsConfig = Field(
description="The description summarization configuration to use.",
default=SummarizeDescriptionsConfig(),
Expand Down
2 changes: 2 additions & 0 deletions packages/graphrag/graphrag/data_model/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NODE_DEGREE = "degree"
NODE_FREQUENCY = "frequency"
NODE_DETAILS = "node_details"
ALTERNATIVE_NAMES = "alternative_names"

# POST-PREP EDGE TABLE SCHEMA
EDGE_SOURCE = "source"
Expand Down Expand Up @@ -73,6 +74,7 @@
TITLE,
TYPE,
DESCRIPTION,
ALTERNATIVE_NAMES,
TEXT_UNIT_IDS,
NODE_FREQUENCY,
NODE_DEGREE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def finalize_entities(
final_entities["id"] = final_entities["human_readable_id"].apply(
lambda _x: str(uuid4())
)
# Ensure alternative_names column exists (empty when resolution is disabled)
if "alternative_names" not in final_entities.columns:
final_entities["alternative_names"] = [[] for _ in range(len(final_entities))]
return final_entities.loc[
:,
ENTITIES_FINAL_COLUMNS,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Entity resolution operation package."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LLM-based entity resolution operation.

Identifies entities with different surface forms that refer to the same
real-world entity (e.g. "Ahab" and "Captain Ahab") and unifies their titles.
"""

import logging
from typing import TYPE_CHECKING

import pandas as pd

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

if TYPE_CHECKING:
from graphrag_llm.completion import LLMCompletion

logger = logging.getLogger(__name__)


def _parse_rename_map(raw: str, titles: list) -> dict:
    """Turn the LLM's duplicate-group lines into a flat alias -> canonical map.

    Each non-comment line of ``raw`` is split on commas and every part is
    reduced to the decimal digits it contains, read as a 1-based index into
    ``titles``. The first valid index on a line names the canonical title and
    the remaining ones its aliases. Groups that share a title are merged, so
    every alias in the result points directly at its final canonical title
    and no title is ever recorded as an alias of itself.
    """
    # alias -> title it was merged into; followed transitively by root_of().
    parent: dict = {}

    def root_of(title):
        # Links are only ever added between two distinct roots, so the chain
        # is acyclic and this loop always terminates.
        while title in parent:
            title = parent[title]
        return title

    for raw_line in raw.splitlines():
        line = raw_line.strip()
        if not line or line.startswith(("#", "Where")):
            continue

        indices: list[int] = []
        for part in line.split(","):
            # isdecimal() only admits characters int() can parse; isdigit()
            # also matches e.g. superscripts, which make int() raise.
            digits = "".join(c for c in part if c.isdecimal())
            if not digits:
                continue
            idx = int(digits) - 1  # 1-indexed -> 0-indexed
            if 0 <= idx < len(titles):
                indices.append(idx)

        if len(indices) < 2:
            continue

        canonical = root_of(titles[indices[0]])
        for alias_idx in indices[1:]:
            alias = root_of(titles[alias_idx])
            # Skip repeats of the canonical itself and already-merged titles.
            if alias != canonical:
                parent[alias] = canonical

    # Flatten chains so a single lookup yields the final canonical title.
    return {alias: root_of(alias) for alias in parent}


async def resolve_entities(
    entities: pd.DataFrame,
    relationships: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    model: "LLMCompletion",
    prompt: str,
    num_threads: int,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Identify and merge duplicate entities with different surface forms.

    Sends all unique entity titles to the LLM in a single call, parses the
    response to build a rename mapping, then applies it to entity titles and
    relationship source/target columns. Each canonical entity receives an
    ``alternative_names`` column listing all of its aliases.

    Duplicate groups reported by the LLM that overlap (for example ``1, 2``
    followed by ``2, 3``) are merged into one group whose canonical title is
    the first title named for that group.

    Parameters
    ----------
    entities : pd.DataFrame
        Entity DataFrame with at least a ``title`` column.
    relationships : pd.DataFrame
        Relationship DataFrame with ``source`` and ``target`` columns.
    callbacks : WorkflowCallbacks
        Progress callbacks (currently unused).
    model : LLMCompletion
        The LLM completion model to use.
    prompt : str
        The entity resolution prompt template (must contain ``{entity_list}``).
    num_threads : int
        Concurrency limit for LLM calls (reserved for future use).

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        Updated ``(entities, relationships)`` with unified titles and an
        ``alternative_names`` column on entities. The inputs are returned
        unchanged (and without the extra column) when there is no ``title``
        column, fewer than two unique titles, the LLM call fails, or no
        duplicates are reported.
    """
    if "title" not in entities.columns:
        return entities, relationships

    titles = entities["title"].dropna().unique().tolist()
    if len(titles) < 2:
        return entities, relationships

    logger.info(
        "Running LLM entity resolution on %d unique entity names...", len(titles)
    )

    # Build numbered entity list for the prompt
    entity_list = "\n".join(
        f"{number}. {name}" for number, name in enumerate(titles, start=1)
    )
    formatted_prompt = prompt.format(entity_list=entity_list)

    # Resolution is best-effort: any failure of the LLM call leaves the graph
    # untouched rather than aborting the workflow.
    try:
        response = await model.completion_async(messages=formatted_prompt)
        raw = (response.content or "").strip()
    except Exception as e:
        logger.warning(
            "Entity resolution LLM call failed, skipping resolution: %s",
            e,
            exc_info=True,
        )
        return entities, relationships

    if "NO_DUPLICATES" in raw:
        logger.info("Entity resolution: no duplicates found")
        return entities, relationships

    rename_map = _parse_rename_map(raw, titles)  # alias -> canonical
    if not rename_map:
        logger.info("Entity resolution complete: no duplicates found")
        return entities, relationships

    alternatives: dict = {}  # canonical -> {aliases}
    for alias, canonical in rename_map.items():
        alternatives.setdefault(canonical, set()).add(alias)
        logger.info(" Entity resolution: '%s' → '%s'", alias, canonical)

    logger.info("Entity resolution: merging %d duplicate names", len(rename_map))

    # Apply renames to entity titles
    entities = entities.copy()
    entities["title"] = entities["title"].map(lambda t: rename_map.get(t, t))

    # Add alternative_names column
    entities["alternative_names"] = entities["title"].map(
        lambda t: sorted(alternatives.get(t, set()))
    )

    # Apply renames to relationship source/target
    if not relationships.empty:
        relationships = relationships.copy()
        for column in ("source", "target"):
            if column in relationships.columns:
                relationships[column] = relationships[column].map(
                    lambda name: rename_map.get(name, name)
                )

    return entities, relationships
8 changes: 8 additions & 0 deletions packages/graphrag/graphrag/index/update/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def _group_and_resolve_entities(
delta_entities_df["human_readable_id"] = np.arange(
initial_id, initial_id + len(delta_entities_df)
)
# Ensure alternative_names column exists (may be absent in older indexes)
for df in [old_entities_df, delta_entities_df]:
if "alternative_names" not in df.columns:
df["alternative_names"] = [[] for _ in range(len(df))]

# Concat A and B
combined = pd.concat(
[old_entities_df, delta_entities_df], ignore_index=True, copy=False
Expand All @@ -60,6 +65,9 @@ def _group_and_resolve_entities(
"description": lambda x: list(x.astype(str)), # Ensure str
# Concatenate nd.array into a single list
"text_unit_ids": lambda x: list(itertools.chain(*x.tolist())),
"alternative_names": lambda x: sorted(
set(itertools.chain(*x.tolist()))
),
"degree": "first", # todo: we could probably re-compute this with the entire new graph
})
.reset_index()
Expand Down
42 changes: 41 additions & 1 deletion packages/graphrag/graphrag/index/workflows/extract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from graphrag.index.operations.extract_graph.extract_graph import (
extract_graph as extractor,
)
from graphrag.index.operations.resolve_entities.resolve_entities import (
resolve_entities,
)
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
summarize_descriptions,
)
Expand Down Expand Up @@ -58,6 +61,24 @@ async def run_workflow(
cache_key_creator=cache_key_creator,
)

# Entity resolution model (optional)
resolution_enabled = config.entity_resolution.enabled
resolution_model = None
resolution_prompt = ""
if resolution_enabled:
resolution_model_config = config.get_completion_model_config(
config.entity_resolution.completion_model_id
)
resolution_prompts = config.entity_resolution.resolved_prompts()
resolution_prompt = resolution_prompts.resolution_prompt
resolution_model = create_completion(
resolution_model_config,
cache=context.cache.child(
config.entity_resolution.model_instance_name
),
cache_key_creator=cache_key_creator,
)

entities, relationships, raw_entities, raw_relationships = await extract_graph(
text_units=text_units,
callbacks=context.callbacks,
Expand All @@ -72,6 +93,10 @@ async def run_workflow(
max_input_tokens=config.summarize_descriptions.max_input_tokens,
summarization_prompt=summarization_prompts.summarize_prompt,
summarization_num_threads=config.concurrent_requests,
resolution_enabled=resolution_enabled,
resolution_model=resolution_model,
resolution_prompt=resolution_prompt,
resolution_num_threads=config.concurrent_requests,
)

await context.output_table_provider.write_dataframe("entities", entities)
Expand Down Expand Up @@ -108,6 +133,10 @@ async def extract_graph(
max_input_tokens: int,
summarization_prompt: str,
summarization_num_threads: int,
resolution_enabled: bool = False,
resolution_model: "LLMCompletion | None" = None,
resolution_prompt: str = "",
resolution_num_threads: int = 1,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
Expand Down Expand Up @@ -136,10 +165,21 @@ async def extract_graph(
logger.error(error_msg)
raise ValueError(error_msg)

# copy these as is before any summarization
# copy these as is before any resolution or summarization
raw_entities = extracted_entities.copy()
raw_relationships = extracted_relationships.copy()

# Resolve duplicate entity names before grouping by title
if resolution_enabled and resolution_model is not None:
extracted_entities, extracted_relationships = await resolve_entities(
entities=extracted_entities,
relationships=extracted_relationships,
callbacks=callbacks,
model=resolution_model,
prompt=resolution_prompt,
num_threads=resolution_num_threads,
)

entities, relationships = await get_summarized_entities_relationships(
extracted_entities=extracted_entities,
extracted_relationships=extracted_relationships,
Expand Down
Loading