Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions packages/graphrag/graphrag/query/context_builder/local_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def build_covariates_context(
current_tokens = tokenizer.num_tokens(current_context_text)

all_context_records = [header]
# Build index dict for O(1) lookups instead of scanning all covariates per entity
cov_by_subject: dict[str, list[Covariate]] = defaultdict(list)
for cov in covariates:
cov_by_subject[cov.subject_id].append(cov)
for entity in selected_entities:
selected_covariates.extend([
cov for cov in covariates if cov.subject_id == entity.title
])
selected_covariates.extend(cov_by_subject.get(entity.title, []))

for covariate in selected_covariates:
new_context = [
Expand Down Expand Up @@ -255,7 +257,7 @@ def _filter_relationships(

# within out-of-network relationships, prioritize mutual relationships
# (i.e. relationships with out-network entities that are shared with multiple selected entities)
selected_entity_names = [entity.title for entity in selected_entities]
selected_entity_names = {entity.title for entity in selected_entities}
out_network_source_names = [
relationship.source
for relationship in out_network_relationships
Expand All @@ -269,19 +271,19 @@ def _filter_relationships(
out_network_entity_names = list(
set(out_network_source_names + out_network_target_names)
)

# Build index dicts for O(1) lookups instead of scanning all relationships per entity
by_source = defaultdict(list)
by_target = defaultdict(list)
for rel in out_network_relationships:
by_source[rel.source].append(rel.target)
by_target[rel.target].append(rel.source)

out_network_entity_links = defaultdict(int)
for entity_name in out_network_entity_names:
targets = [
relationship.target
for relationship in out_network_relationships
if relationship.source == entity_name
]
sources = [
relationship.source
for relationship in out_network_relationships
if relationship.target == entity_name
]
out_network_entity_links[entity_name] = len(set(targets + sources))
out_network_entity_links[entity_name] = len(
set(by_source.get(entity_name, []) + by_target.get(entity_name, []))
)

# sort out-network relationships by number of links and rank_attributes
for rel in out_network_relationships:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def get_candidate_communities(
selected_community_ids = [
entity.community_ids for entity in selected_entities if entity.community_ids
]
selected_community_ids = [
selected_community_ids_set = {
item for sublist in selected_community_ids for item in sublist
]
}
selected_reports = [
community
for community in community_reports
if community.id in selected_community_ids
if community.id in selected_community_ids_set
]
return to_community_report_dataframe(
reports=selected_reports,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_candidate_covariates(
covariates: list[Covariate],
) -> list[Covariate]:
"""Get all covariates that are related to selected entities."""
selected_entity_names = [entity.title for entity in selected_entities]
selected_entity_names = {entity.title for entity in selected_entities}
return [
covariate
for covariate in covariates
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_in_network_relationships(
ranking_attribute: str = "rank",
) -> list[Relationship]:
"""Get all directed relationships between selected entities, sorted by ranking_attribute."""
selected_entity_names = [entity.title for entity in selected_entities]
selected_entity_names = {entity.title for entity in selected_entities}
selected_relationships = [
relationship
for relationship in relationships
Expand All @@ -37,7 +37,7 @@ def get_out_network_relationships(
ranking_attribute: str = "rank",
) -> list[Relationship]:
"""Get relationships from selected entities to other entities that are not within the selected entities, sorted by ranking_attribute."""
selected_entity_names = [entity.title for entity in selected_entities]
selected_entity_names = {entity.title for entity in selected_entities}
source_relationships = [
relationship
for relationship in relationships
Expand All @@ -59,7 +59,7 @@ def get_candidate_relationships(
relationships: list[Relationship],
) -> list[Relationship]:
"""Get all relationships that are associated with the selected entities."""
selected_entity_names = [entity.title for entity in selected_entities]
selected_entity_names = {entity.title for entity in selected_entities}
return [
relationship
for relationship in relationships
Expand All @@ -72,9 +72,9 @@ def get_entities_from_relationships(
relationships: list[Relationship], entities: list[Entity]
) -> list[Entity]:
"""Get all entities that are associated with the selected relationships."""
selected_entity_names = [relationship.source for relationship in relationships] + [
selected_entity_names = {relationship.source for relationship in relationships} | {
relationship.target for relationship in relationships
]
}
return [entity for entity in entities if entity.title in selected_entity_names]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def get_candidate_text_units(
selected_text_ids = [
entity.text_unit_ids for entity in selected_entities if entity.text_unit_ids
]
selected_text_ids = [item for sublist in selected_text_ids for item in sublist]
selected_text_units = [unit for unit in text_units if unit.id in selected_text_ids]
selected_text_ids_set = {item for sublist in selected_text_ids for item in sublist}
selected_text_units = [unit for unit in text_units if unit.id in selected_text_ids_set]
return to_text_unit_dataframe(selected_text_units)


Expand Down