Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions application/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,57 +567,76 @@ def gap_analysis(self, name_1, name_2):
denylist = ["Cross-cutting concerns"]
from datetime import datetime

t1 = datetime.now()

path_records_all, _ = db.cypher_query(
# Tier 1: Strong Links (LINKED_TO, SAME, AUTOMATICALLY_LINKED_TO)
path_records, _ = db.cypher_query(
"""
MATCH (BaseStandard:NeoStandard {name: $name1})
MATCH (CompareStandard:NeoStandard {name: $name2})
MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard))
MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME)*..20]-(CompareStandard))
WITH p
WHERE length(p) > 1 AND ALL (n in NODES(p) where (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist)
WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist)
RETURN p
""",
# """
# OPTIONAL MATCH (BaseStandard:NeoStandard {name: $name1})
# OPTIONAL MATCH (CompareStandard:NeoStandard {name: $name2})
# OPTIONAL MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard))
# WITH p
# WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist)
# RETURN p
# """,
{"name1": name_1, "name2": name_2, "denylist": denylist},
resolve_objects=True,
)
t2 = datetime.now()

# If strict strong links found, return early (Pruning)
if path_records and len(path_records) > 0:
logger.info(
f"Gap Analysis: Tier 1 (Strong) found {len(path_records)} paths. Pruning remainder."
)
# Helper to format and return
return self._format_gap_analysis_response(base_standard, path_records)

# Tier 2: Medium Links (Add CONTAINS to the mix)
path_records, _ = db.cypher_query(
"""
MATCH (BaseStandard:NeoStandard {name: $name1})
MATCH (CompareStandard:NeoStandard {name: $name2})
MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|CONTAINS)*..20]-(CompareStandard))
MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME|CONTAINS)*..20]-(CompareStandard))
WITH p
WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist)
RETURN p
""",
# """
# OPTIONAL MATCH (BaseStandard:NeoStandard {name: $name1})
# OPTIONAL MATCH (CompareStandard:NeoStandard {name: $name2})
# OPTIONAL MATCH p = allShortestPaths((BaseStandard)-[:(LINKED_TO|AUTOMATICALLY_LINKED_TO|CONTAINS)*..20]-(CompareStandard))
# WITH p
# WHERE length(p) > 1 AND ALL(n in NODES(p) WHERE (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist)
# RETURN p
# """,
{"name1": name_1, "name2": name_2, "denylist": denylist},
resolve_objects=True,
)
t3 = datetime.now()

if path_records and len(path_records) > 0:
logger.info(
f"Gap Analysis: Tier 2 (Medium) found {len(path_records)} paths. Pruning remainder."
)
return self._format_gap_analysis_response(base_standard, path_records)

# Tier 3: Weak/All Links (Wildcard - The original expensive query)
logger.info(
"Gap Analysis: Tiers 1 & 2 empty. Executing Tier 3 (Wildcard search)."
)
path_records_all, _ = db.cypher_query(
"""
MATCH (BaseStandard:NeoStandard {name: $name1})
MATCH (CompareStandard:NeoStandard {name: $name2})
MATCH p = allShortestPaths((BaseStandard)-[*..20]-(CompareStandard))
WITH p
WHERE length(p) > 1 AND ALL (n in NODES(p) where (n:NeoCRE or n = BaseStandard or n = CompareStandard) AND NOT n.name in $denylist)
RETURN p
""",
{"name1": name_1, "name2": name_2, "denylist": denylist},
resolve_objects=True,
)

return self._format_gap_analysis_response(base_standard, path_records_all)

@classmethod
def _format_gap_analysis_response(self, base_standard, path_records):
def format_segment(seg: StructuredRel, nodes):
relation_map = {
RelatedRel: "RELATED",
ContainsRel: "CONTAINS",
LinkedToRel: "LINKED_TO",
AutoLinkedToRel: "AUTOMATICALLY_LINKED_TO",
SameRel: "SAME",
}
start_node = [
node for node in nodes if node.element_id == seg._start_node_element_id
Expand All @@ -626,10 +645,13 @@ def format_segment(seg: StructuredRel, nodes):
node for node in nodes if node.element_id == seg._end_node_element_id
][0]

# Default to RELATED if relation unknown (though mostly governed by class type)
rtype = relation_map.get(type(seg), "RELATED")

return {
"start": NEO_DB.parse_node_no_links(start_node),
"end": NEO_DB.parse_node_no_links(end_node),
"relationship": relation_map[type(seg)],
"relationship": rtype,
}

def format_path_record(rec):
Expand All @@ -640,7 +662,7 @@ def format_path_record(rec):
}

return [NEO_DB.parse_node_no_links(rec) for rec in base_standard], [
format_path_record(rec[0]) for rec in (path_records + path_records_all)
format_path_record(rec[0]) for rec in path_records
]

@classmethod
Expand Down
67 changes: 67 additions & 0 deletions application/tests/gap_analysis_db_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest
from unittest.mock import MagicMock, patch
from application.database import db
from application.defs import cre_defs as defs


class TestGapAnalysisPruning(unittest.TestCase):
    """Check that ``gap_analysis`` stops at the first tier that yields paths.

    The queries are told apart by the relationship pattern embedded in the
    Cypher text handed to the patched ``cypher_query``.
    """

    # Identifiers returned by ``classify_tier``.
    TIER_STRONG = 1
    TIER_MEDIUM = 2
    TIER_BROAD = 3

    @staticmethod
    def classify_tier(query):
        """Return the tier a Cypher query string belongs to, or ``None``.

        ``CONTAINS`` must be tested first: the medium query lists the strong
        relationship types as well, so a plain substring test for the strong
        marker would also match the medium query.
        """
        if "CONTAINS" in query:
            return TestGapAnalysisPruning.TIER_MEDIUM
        if "LINKED_TO|AUTOMATICALLY_LINKED_TO|SAME" in query:
            return TestGapAnalysisPruning.TIER_STRONG
        if "[*..20]" in query:
            return TestGapAnalysisPruning.TIER_BROAD
        return None

    def setUp(self):
        # Patch the entire Class to avoid descriptor issues with .nodes
        self.mock_NeoStandard = patch("application.database.db.NeoStandard").start()
        self.mock_cypher = patch("application.database.db.db.cypher_query").start()
        self.addCleanup(patch.stopall)

    def test_tiered_execution_optimization(self):
        """
        Verify that if Tier 1 (Strong) returns results, we DO NOT execute
        Tier 2 (Medium) or Tier 3 (Broad).
        """
        strong_path_mock = [MagicMock()]

        # NeoStandard.nodes.filter(...) should return a list
        self.mock_NeoStandard.nodes.filter.return_value = []

        def cypher_side_effect(query, params=None, resolve_objects=True):
            # Only the strong tier yields paths; every other query is empty.
            if self.classify_tier(query) == self.TIER_STRONG:
                return strong_path_mock, None
            return [], None

        self.mock_cypher.side_effect = cypher_side_effect

        db.NEO_DB.gap_analysis("StandardA", "StandardB")

        # Collect the tier of every query that reached cypher_query.
        executed_tiers = set()
        for call in self.mock_cypher.call_args_list:
            args, kwargs = call
            # The query is normally positional; tolerate a keyword form too.
            query_str = args[0] if args else kwargs.get("query", "")
            executed_tiers.add(self.classify_tier(query_str))

        self.assertIn(
            self.TIER_STRONG,
            executed_tiers,
            "Tier 1 query should have been executed",
        )
        self.assertNotIn(
            self.TIER_MEDIUM,
            executed_tiers,
            "Tier 2 (Medium) query should NOT have been executed because Tier 1 found paths",
        )
        self.assertNotIn(
            self.TIER_BROAD,
            executed_tiers,
            "Tier 3 (Wildcard) query should NOT have been executed because Tier 1 found paths",
        )


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion application/utils/gap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from rq import Queue, job, exceptions
from typing import List, Dict
from application.utils import redis
from application.database import db
from flask import json as flask_json
import json
from application.defs import cre_defs as defs
Expand Down Expand Up @@ -62,6 +61,8 @@ def get_next_id(step, previous_id):

# database is of type Node_collection, cannot annotate due to circular import
def schedule(standards: List[str], database):
from application.database import db

standards_hash = make_resources_key(standards)
if database.gap_analysis_exists(
standards_hash
Expand Down