Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/memos/api/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Krolik middleware extensions for MemOS."""

from .auth import verify_api_key, require_scope, require_admin, require_read, require_write
from .auth import require_admin, require_read, require_scope, require_write, verify_api_key
from .rate_limit import RateLimitMiddleware


__all__ = [
"verify_api_key",
"require_scope",
"RateLimitMiddleware",
"require_admin",
"require_read",
"require_scope",
"require_write",
"RateLimitMiddleware",
"verify_api_key",
]
2 changes: 1 addition & 1 deletion src/memos/api/utils/api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"""

import hashlib
import os
import secrets

from dataclasses import dataclass
from datetime import datetime, timedelta

Expand Down
32 changes: 31 additions & 1 deletion src/memos/graph_dbs/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
import re

from abc import ABC, abstractmethod
from typing import Any, Literal


# Pattern for valid field names: alphanumeric and underscores, must start with letter or underscore
_VALID_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


class BaseGraphDB(ABC):
"""
Abstract base class for a graph database interface used in a memory-augmented RAG system.
"""

@staticmethod
def _validate_return_fields(return_fields: list[str] | None) -> list[str]:
"""Validate and sanitize return_fields to prevent query injection.

Only allows alphanumeric characters and underscores in field names.
Silently drops invalid field names.

Args:
return_fields: List of field names to validate.

Returns:
List of valid field names.
"""
if not return_fields:
return []
return [f for f in return_fields if _VALID_FIELD_NAME_RE.match(f)]

# Node (Memory) Management
@abstractmethod
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
Expand Down Expand Up @@ -144,16 +167,23 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:

# Search / recall operations
@abstractmethod
def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]:
def search_by_embedding(
self, vector: list[float], top_k: int = 5, return_fields: list[str] | None = None, **kwargs
) -> list[dict]:
"""
Retrieve node IDs based on vector similarity.

Args:
vector (list[float]): The embedding vector representing query semantics.
top_k (int): Number of top similar nodes to retrieve.
return_fields (list[str], optional): Additional node fields to include in results
(e.g., ["memory", "status", "tags"]). When provided, each result dict will
contain these fields in addition to 'id' and 'score'.
Defaults to None (only 'id' and 'score' are returned).

Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
If return_fields is specified, each dict also includes the requested fields.

Notes:
- This method may internally call a VecDB (e.g., Qdrant) or store embeddings in the graph DB itself.
Expand Down
27 changes: 25 additions & 2 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ def search_by_embedding(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
"""
Expand All @@ -832,9 +833,14 @@ def search_by_embedding(
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
search_filter (dict, optional): Additional metadata filters for search results.
Keys should match node properties, values are the expected values.
return_fields (list[str], optional): Additional node fields to include in results
(e.g., ["memory", "status", "tags"]). When provided, each result
dict will contain these fields in addition to 'id' and 'score'.
Defaults to None (only 'id' and 'score' are returned).

Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
If return_fields is specified, each dict also includes the requested fields.

Notes:
- This method uses Neo4j native vector indexing to search for similar nodes.
Expand Down Expand Up @@ -886,11 +892,20 @@ def search_by_embedding(
if where_clauses:
where_clause = "WHERE " + " AND ".join(where_clauses)

return_clause = "RETURN node.id AS id, score"
if return_fields:
validated_fields = self._validate_return_fields(return_fields)
extra_fields = ", ".join(
f"node.{field} AS {field}" for field in validated_fields if field != "id"
)
if extra_fields:
return_clause = f"RETURN node.id AS id, score, {extra_fields}"

query = f"""
CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding)
YIELD node, score
{where_clause}
RETURN node.id AS id, score
{return_clause}
"""

parameters = {"embedding": vector, "k": top_k}
Expand Down Expand Up @@ -920,7 +935,15 @@ def search_by_embedding(
print(f"[search_by_embedding] query: {query},parameters: {parameters}")
with self.driver.session(database=self.db_name) as session:
result = session.run(query, parameters)
records = [{"id": record["id"], "score": record["score"]} for record in result]
records = []
for record in result:
item = {"id": record["id"], "score": record["score"]}
if return_fields:
record_keys = record.keys()
for field in return_fields:
if field != "id" and field in record_keys:
item[field] = record[field]
records.append(item)

# Threshold filtering after retrieval
if threshold is not None:
Expand Down
87 changes: 80 additions & 7 deletions src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,39 @@ def get_children_with_embeddings(

return child_nodes

def _fetch_return_fields(
self,
ids: list[str],
score_map: dict[str, float],
return_fields: list[str],
) -> list[dict]:
"""Fetch additional fields from Neo4j for given node IDs."""
validated_fields = self._validate_return_fields(return_fields)
extra_fields = ", ".join(
f"n.{field} AS {field}" for field in validated_fields if field != "id"
)
return_clause = "RETURN n.id AS id"
if extra_fields:
return_clause = f"RETURN n.id AS id, {extra_fields}"

query = f"""
MATCH (n:Memory)
WHERE n.id IN $ids
{return_clause}
"""
with self.driver.session(database=self.db_name) as session:
neo4j_results = session.run(query, {"ids": ids})
results = []
for record in neo4j_results:
node_id = record["id"]
item = {"id": node_id, "score": score_map.get(node_id)}
record_keys = record.keys()
for field in return_fields:
if field != "id" and field in record_keys:
item[field] = record[field]
results.append(item)
return results

# Search / recall operations
def search_by_embedding(
self,
Expand All @@ -258,6 +291,7 @@ def search_by_embedding(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
"""
Expand All @@ -273,9 +307,14 @@ def search_by_embedding(
filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results.
Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]}
knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by.
return_fields (list[str], optional): Additional node fields to include in results
(e.g., ["memory", "status", "tags"]). When provided, each result dict will
contain these fields in addition to 'id' and 'score'.
Defaults to None (only 'id' and 'score' are returned).

Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
If return_fields is specified, each dict also includes the requested fields.

Notes:
- This method uses an external vector database (not Neo4j) to perform the search.
Expand Down Expand Up @@ -320,7 +359,14 @@ def search_by_embedding(

# If no filter or knowledgebase_ids provided, return vector search results directly
if not filter and not knowledgebase_ids:
return [{"id": r.id, "score": r.score} for r in vec_results]
if not return_fields:
return [{"id": r.id, "score": r.score} for r in vec_results]
# Need to fetch additional fields from Neo4j
vec_ids = [r.id for r in vec_results]
if not vec_ids:
return []
score_map = {r.id: r.score for r in vec_results}
return self._fetch_return_fields(vec_ids, score_map, return_fields)

# Extract IDs from vector search results
vec_ids = [r.id for r in vec_results]
Expand Down Expand Up @@ -363,22 +409,49 @@ def search_by_embedding(
if filter_params:
params.update(filter_params)

# Build RETURN clause with optional extra fields
return_clause = "RETURN n.id AS id"
if return_fields:
validated_fields = self._validate_return_fields(return_fields)
extra_fields = ", ".join(
f"n.{field} AS {field}" for field in validated_fields if field != "id"
)
if extra_fields:
return_clause = f"RETURN n.id AS id, {extra_fields}"

# Query Neo4j to filter results
query = f"""
MATCH (n:Memory)
{where_clause}
RETURN n.id AS id
{return_clause}
"""
logger.info(f"[search_by_embedding] query: {query}, params: {params}")

with self.driver.session(database=self.db_name) as session:
neo4j_results = session.run(query, params)
filtered_ids = {record["id"] for record in neo4j_results}
if return_fields:
# Build a map of id -> extra fields from Neo4j results
neo4j_data = {}
for record in neo4j_results:
node_id = record["id"]
record_keys = record.keys()
neo4j_data[node_id] = {
field: record[field]
for field in return_fields
if field != "id" and field in record_keys
}
filtered_ids = set(neo4j_data.keys())
else:
filtered_ids = {record["id"] for record in neo4j_results}

# Filter vector results by Neo4j filtered IDs and return with scores
filtered_results = [
{"id": r.id, "score": r.score} for r in vec_results if r.id in filtered_ids
]
filtered_results = []
for r in vec_results:
if r.id in filtered_ids:
item = {"id": r.id, "score": r.score}
if return_fields and r.id in neo4j_data:
item.update(neo4j_data[r.id])
filtered_results.append(item)

return filtered_results

Expand Down Expand Up @@ -1102,7 +1175,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]
# Merge embeddings into parsed nodes
for parsed_node in parsed_nodes:
node_id = parsed_node["id"]
parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None)
parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id)

return parsed_nodes

Expand Down
Loading