Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,20 +700,38 @@ def get_neo4j_config(user_id: str | None = None) -> dict[str, Any]:
@staticmethod
def get_noshared_neo4j_config(user_id) -> dict[str, Any]:
    """Get the non-shared (one-database-per-user) Neo4j configuration.

    Args:
        user_id: User identifier; underscores are replaced with hyphens to
            form the per-user Neo4j database name.

    Returns:
        Dict of Neo4j connection settings. When the
        ``MOS_NEO4J_ENABLE_QDRANT_SYNC`` env flag is truthy (default "true"),
        the dict also carries a ``vec_config`` block describing the Qdrant
        vector store used to mirror graph embeddings.
    """
    config = {
        "uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        "user": os.getenv("NEO4J_USER", "neo4j"),
        # Neo4j database names disallow underscores; map them to hyphens.
        "db_name": user_id.replace("_", "-"),
        "password": os.getenv("NEO4J_PASSWORD", "12345678"),
        "auto_create": True,
        "use_multi_db": True,
        "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
    }

    # Optionally mirror graph embeddings into an external Qdrant collection.
    enable_qdrant_sync = os.getenv("MOS_NEO4J_ENABLE_QDRANT_SYNC", "true").lower() == "true"
    if enable_qdrant_sync:
        config["vec_config"] = {
            "backend": "qdrant",
            "config": {
                "collection_name": "neo4j_vec_db",
                "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
                "distance_metric": "cosine",
                "host": os.getenv("QDRANT_HOST", "localhost"),
                "port": int(os.getenv("QDRANT_PORT", "6333")),
                "path": os.getenv("QDRANT_PATH"),
                "url": os.getenv("QDRANT_URL"),
                "api_key": os.getenv("QDRANT_API_KEY"),
            },
        }

    return config

@staticmethod
def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
"""Get Neo4j configuration."""
return {
config = {
"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
"user": os.getenv("NEO4J_USER", "neo4j"),
"db_name": os.getenv("NEO4J_DB_NAME", "shared-tree-textual-memory"),
Expand All @@ -724,6 +742,24 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
}

enable_qdrant_sync = os.getenv("MOS_NEO4J_ENABLE_QDRANT_SYNC", "true").lower() == "true"
if enable_qdrant_sync:
config["vec_config"] = {
"backend": "qdrant",
"config": {
"collection_name": "neo4j_vec_db",
"vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
"distance_metric": "cosine",
"host": os.getenv("QDRANT_HOST", "localhost"),
"port": int(os.getenv("QDRANT_PORT", "6333")),
"path": os.getenv("QDRANT_PATH"),
"url": os.getenv("QDRANT_URL"),
"api_key": os.getenv("QDRANT_API_KEY"),
},
}

return config

@staticmethod
def get_nebular_config(user_id: str | None = None) -> dict[str, Any]:
"""Get Nebular configuration."""
Expand Down
47 changes: 43 additions & 4 deletions src/memos/api/handlers/add_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
using dependency injection for better modularity and testability.
"""

import os
import threading

from pydantic import validate_call

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.component_init import create_per_db_components
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse
from memos.memories.textual.item import (
list_all_fields,
Expand Down Expand Up @@ -36,6 +40,8 @@ def __init__(self, dependencies: HandlerDependencies):
self._validate_dependencies(
"naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server"
)
self._per_user_cube_cache: dict[str, dict] = {}
self._cache_lock = threading.Lock()

def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
"""
Expand Down Expand Up @@ -113,6 +119,31 @@ def _check_messages(messages: MessageList) -> None:
data=results,
)

@property
def _is_neo4j_multidb(self) -> bool:
"""Return True when using Neo4j enterprise with one-database-per-user mode."""
backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "")).lower()
shared_db = os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true"
return backend == "neo4j" and not shared_db

def _get_per_user_components(self, user_id: str) -> dict:
    """Fetch the cached per-user component bundle, building it lazily.

    The unlocked fast path avoids lock contention on cache hits; the
    locked re-check guarantees the expensive build runs at most once per
    user even under concurrent requests.
    """
    components = self._per_user_cube_cache.get(user_id)
    if components is None:
        with self._cache_lock:
            components = self._per_user_cube_cache.get(user_id)
            if components is None:
                self.logger.info(
                    f"[AddHandler] Creating per-user components for user_id={user_id!r}"
                )
                components = create_per_db_components(
                    db_name=user_id,
                    base_components=vars(self.deps),
                )
                self._per_user_cube_cache[user_id] = components
    return components
Comment on lines +135 to +145

def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]:
"""
Normalize target cube ids from add_req.
Expand All @@ -128,12 +159,20 @@ def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]:
def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
cube_ids = self._resolve_cube_ids(add_req)

if self._is_neo4j_multidb:
per_user = self._get_per_user_components(add_req.user_id)
naive_mem_cube = per_user["naive_mem_cube"]
mem_reader = per_user["mem_reader"]
else:
naive_mem_cube = self.naive_mem_cube
mem_reader = self.mem_reader

if len(cube_ids) == 1:
cube_id = cube_ids[0]
return SingleCubeView(
cube_id=cube_id,
naive_mem_cube=self.naive_mem_cube,
mem_reader=self.mem_reader,
naive_mem_cube=naive_mem_cube,
mem_reader=mem_reader,
mem_scheduler=self.mem_scheduler,
logger=self.logger,
feedback_server=self.feedback_server,
Expand All @@ -143,8 +182,8 @@ def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
single_views = [
SingleCubeView(
cube_id=cube_id,
naive_mem_cube=self.naive_mem_cube,
mem_reader=self.mem_reader,
naive_mem_cube=naive_mem_cube,
mem_reader=mem_reader,
mem_scheduler=self.mem_scheduler,
logger=self.logger,
feedback_server=self.feedback_server,
Expand Down
81 changes: 81 additions & 0 deletions src/memos/api/handlers/component_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,84 @@ def init_server() -> dict[str, Any]:
"nli_client": nli_client,
"memory_history_manager": memory_history_manager,
}


def create_per_db_components(db_name: str, base_components: dict[str, Any]) -> dict[str, Any]:
    """Create a set of per-database components for multi-db isolation.

    Reuses expensive shared singletons (LLM, embedder, reranker, etc.) but builds
    a fresh graph_db, MemoryManager, SimpleTreeTextMemory, NaiveMemCube, and
    searcher for the specified Neo4j database name.

    The returned ``mem_reader`` is a shallow copy of the shared one whose
    ``searcher`` is overridden to point at the new database, so deduplication
    during add operates against the correct graph.

    Args:
        db_name: Target Neo4j database name (auto-created when ``auto_create=True``).
        base_components: Shared component dict returned by :func:`init_server`.

    Returns:
        Dict with keys: ``graph_db``, ``memory_manager``, ``text_mem``,
        ``naive_mem_cube``, ``searcher``, ``mem_reader``.
    """
    # Local imports avoid an import cycle with memos.api.config at module load.
    import copy

    from memos.api.config import APIConfig
    from memos.configs.graph_db import GraphDBConfigFactory
    from memos.graph_dbs.factory import GraphStoreFactory

    # GRAPH_DB_BACKEND takes precedence; NEO4J_BACKEND is the legacy fallback.
    graph_db_backend = os.getenv(
        "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j")
    ).lower()
    neo4j_cfg = APIConfig.get_neo4j_config(user_id=db_name)
    new_graph_db = GraphStoreFactory.from_config(
        GraphDBConfigFactory.model_validate(
            {"backend": graph_db_backend, "config": neo4j_cfg}
        )
    )

    default_cube_config = base_components["default_cube_config"]
    new_memory_manager = MemoryManager(
        new_graph_db,
        base_components["embedder"],
        base_components["llm"],
        memory_size=_get_default_memory_size(default_cube_config),
        is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
    )

    new_text_mem = SimpleTreeTextMemory(
        llm=base_components["llm"],
        embedder=base_components["embedder"],
        mem_reader=base_components["mem_reader"],
        graph_db=new_graph_db,
        reranker=base_components["reranker"],
        memory_manager=new_memory_manager,
        config=default_cube_config.text_mem.config,
        internet_retriever=base_components["internet_retriever"],
        tokenizer=FastTokenizer(),
        # NOTE(review): unlike other flags in this module this comparison is
        # case-sensitive (no .lower()); preserved as-is — confirm whether
        # "True"/"TRUE" should also enable embeddings.
        include_embedding=os.getenv("INCLUDE_EMBEDDING", "false") == "true",
    )

    new_naive_mem_cube = NaiveMemCube(text_mem=new_text_mem, act_mem=None, para_mem=None)

    new_searcher = new_text_mem.get_searcher(
        manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
        moscube=False,
        process_llm=base_components["mem_reader"].llm,
    )

    # Shallow-copy the shared mem_reader and point its searcher at the new database
    # so deduplication reads target the correct graph store.
    new_mem_reader = copy.copy(base_components["mem_reader"])
    new_mem_reader.set_searcher(new_searcher)

    logger.info(f"[create_per_db_components] Created components for db_name={db_name!r}")
    return {
        "graph_db": new_graph_db,
        "memory_manager": new_memory_manager,
        "text_mem": new_text_mem,
        "naive_mem_cube": new_naive_mem_cube,
        "searcher": new_searcher,
        "mem_reader": new_mem_reader,
    }
83 changes: 68 additions & 15 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@

import copy
import math
import os
import threading

from typing import Any

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.component_init import create_per_db_components
from memos.api.handlers.formatters_handler import rerank_knowledge_mem
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.log import get_logger
from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
cosine_similarity_matrix,
)
Expand Down Expand Up @@ -43,6 +47,35 @@ def __init__(self, dependencies: HandlerDependencies):
self._validate_dependencies(
"naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent"
)
# Cache per-database components in Neo4j multi-db mode.
self._per_db_cube_cache: dict[str, dict[str, Any]] = {}
self._cache_lock = threading.Lock()

@property
def _is_neo4j_multidb(self) -> bool:
"""Return True when using Neo4j enterprise with one-database-per-cube mode."""
backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "")).lower()
shared_db = os.getenv("MOS_NEO4J_SHARED_DB", "false").lower() == "true"
return backend == "neo4j" and not shared_db

def _get_per_db_components(self, db_name: str) -> dict[str, Any]:
    """Look up (or lazily build) the component bundle bound to one Neo4j database.

    Double-checked locking: the unlocked lookup is the hot path for cache
    hits; the locked re-check ensures a single build per database name.
    """
    try:
        return self._per_db_cube_cache[db_name]
    except KeyError:
        pass
    with self._cache_lock:
        if db_name not in self._per_db_cube_cache:
            self.logger.info(
                f"[SearchHandler] Creating per-db components for db_name={db_name!r}"
            )
            bundle = create_per_db_components(
                db_name=db_name,
                base_components=vars(self.deps),
            )
            # The deep-search agent is search-specific, so it is attached here
            # on top of the generic per-db bundle.
            bundle["deepsearch_agent"] = DeepSearchMemAgent(
                llm=self.llm,
                memory_retriever=bundle["text_mem"],
            )
            self._per_db_cube_cache[db_name] = bundle
    return self._per_db_cube_cache[db_name]
Comment on lines +61 to +78

def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
"""
Expand Down Expand Up @@ -801,8 +834,28 @@ def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:

def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCubeView:
cube_ids = self._resolve_cube_ids(search_req)
searcher_to_use = searcher if searcher is not None else self.searcher
if self._is_neo4j_multidb:
single_views = []
for cube_id in cube_ids:
per_db = self._get_per_db_components(cube_id)
searcher_to_use = searcher if searcher is not None else per_db["searcher"]
single_views.append(
SingleCubeView(
cube_id=cube_id,
naive_mem_cube=per_db["naive_mem_cube"],
mem_reader=per_db["mem_reader"],
mem_scheduler=self.mem_scheduler,
logger=self.logger,
searcher=searcher_to_use,
deepsearch_agent=per_db["deepsearch_agent"],
)
)

if len(single_views) == 1:
return single_views[0]
return CompositeCubeView(cube_views=single_views, logger=self.logger)

searcher_to_use = searcher if searcher is not None else self.searcher
if len(cube_ids) == 1:
cube_id = cube_ids[0]
return SingleCubeView(
Expand All @@ -814,17 +867,17 @@ def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCu
searcher=searcher_to_use,
deepsearch_agent=self.deepsearch_agent,
)
else:
single_views = [
SingleCubeView(
cube_id=cube_id,
naive_mem_cube=self.naive_mem_cube,
mem_reader=self.mem_reader,
mem_scheduler=self.mem_scheduler,
logger=self.logger,
searcher=searcher_to_use,
deepsearch_agent=self.deepsearch_agent,
)
for cube_id in cube_ids
]
return CompositeCubeView(cube_views=single_views, logger=self.logger)

single_views = [
SingleCubeView(
cube_id=cube_id,
naive_mem_cube=self.naive_mem_cube,
mem_reader=self.mem_reader,
mem_scheduler=self.mem_scheduler,
logger=self.logger,
searcher=searcher_to_use,
deepsearch_agent=self.deepsearch_agent,
)
for cube_id in cube_ids
]
return CompositeCubeView(cube_views=single_views, logger=self.logger)
7 changes: 7 additions & 0 deletions src/memos/configs/graph_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ class Neo4jGraphDBConfig(BaseGraphDBConfig):
)

embedding_dimension: int = Field(default=768, description="Dimension of vector embedding")
vec_config: VectorDBConfigFactory | None = Field(
default=None,
description=(
"Optional external vector DB config for syncing embeddings (e.g., Qdrant). "
"When provided, graph writes can also sync to vector storage."
),
)

@model_validator(mode="after")
def validate_config(self):
Expand Down
Loading