Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import random
import re
import textwrap
import threading
import time
Expand All @@ -20,6 +21,34 @@
logger = get_logger(__name__)


def _sanitize_tsquery_words(query_words: list[str]) -> list[str]:
"""Sanitize query words for safe use with PostgreSQL to_tsquery().

Strips tsquery operator characters and other special symbols that can
cause parsing errors when mixed content (e.g. message IDs with
underscores, Chinese text) is passed to ``to_tsquery``. Each word is
reduced to its alphanumeric/CJK core so that the jieba text-search
configuration can tokenize it correctly.

Returns a de-duplicated list of non-empty sanitized words.
"""
# Keep word characters (letters, digits, underscore) and CJK unified ideographs.
valid_chars_re = re.compile(
r"[^\w\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]",
)
Comment on lines +35 to +38
Comment on lines +25 to +38
sanitized: list[str] = []
seen: set[str] = set()
for w in query_words:
# Strip surrounding single quotes that callers may have added for tsquery
w = w.strip().strip("'")
# Remove characters that are not word-characters or CJK
cleaned = valid_chars_re.sub("", w)
if cleaned and cleaned not in seen:
seen.add(cleaned)
sanitized.append(cleaned)
return sanitized


def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
node_id = item["id"]
memory = item["memory"]
Expand Down Expand Up @@ -1653,8 +1682,11 @@ def search_by_keywords_tfidf(
filter_conditions = self._build_filter_conditions_sql(filter)
where_clauses.extend(filter_conditions)
# Add fulltext search condition
# Convert query_text to OR query format: "word1 | word2 | word3"
tsquery_string = " | ".join(query_words)
# Sanitize and convert query_text to OR query format: "word1 | word2 | word3"
safe_words = _sanitize_tsquery_words(query_words)
Comment on lines +1685 to +1686
if not safe_words:
return []
tsquery_string = " | ".join(safe_words)

where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")

Expand Down Expand Up @@ -1768,7 +1800,10 @@ def search_by_fulltext(
filter_conditions = self._build_filter_conditions_sql(filter)

where_clauses.extend(filter_conditions)
tsquery_string = " | ".join(query_words)
safe_words = _sanitize_tsquery_words(query_words)
if not safe_words:
return []
tsquery_string = " | ".join(safe_words)

where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")

Expand Down
86 changes: 86 additions & 0 deletions tests/graph_dbs/test_sanitize_tsquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for _sanitize_tsquery_words — standalone, no heavy imports."""

import re


# ---------------------------------------------------------------------------
# Inline the function under test to avoid pulling in the full memos import
# chain (which requires a running logging backend). The canonical copy lives
# in ``memos.graph_dbs.polardb._sanitize_tsquery_words``.
# ---------------------------------------------------------------------------


def _sanitize_tsquery_words(query_words: list[str]) -> list[str]:
valid_chars_re = re.compile(
r"[^\w\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]",
)
sanitized: list[str] = []
seen: set[str] = set()
for w in query_words:
w = w.strip().strip("'")
cleaned = valid_chars_re.sub("", w)
if cleaned and cleaned not in seen:
seen.add(cleaned)
sanitized.append(cleaned)
return sanitized


Comment on lines +1 to +27
class TestSanitizeTsqueryWords:
    """Unit tests for FTS query word sanitization."""

    def test_plain_english_words(self):
        assert _sanitize_tsquery_words(["hello", "world"]) == ["hello", "world"]

    def test_chinese_text(self):
        assert _sanitize_tsquery_words(["我要", "测试"]) == ["我要", "测试"]

    def test_mixed_content_message_id_and_chinese(self):
        """Reproduce the original bug: mixed IDs + Chinese text."""
        words = ["message_id", "om_x100b544a390604b8c3e1b7d8641f08e", "我要测试"]
        out = _sanitize_tsquery_words(words)
        assert len(out) == 3
        # Each input survives sanitization unchanged.
        for expected in words:
            assert expected in out

    def test_single_quoted_words_are_stripped(self):
        assert _sanitize_tsquery_words(["'hello'", "'world'"]) == ["hello", "world"]

    def test_special_characters_removed(self):
        assert _sanitize_tsquery_words(["hello!", "world@#$"]) == ["hello", "world"]

    def test_empty_words_filtered(self):
        assert _sanitize_tsquery_words(["", " ", "hello", ""]) == ["hello"]

    def test_deduplication(self):
        assert _sanitize_tsquery_words(["hello", "hello", "world"]) == ["hello", "world"]

    def test_empty_input(self):
        assert not _sanitize_tsquery_words([])

    def test_all_special_chars_returns_empty(self):
        assert _sanitize_tsquery_words(["!@#", "$%^"]) == []

    def test_underscores_preserved(self):
        words = ["message_id", "user_name"]
        assert _sanitize_tsquery_words(words) == words

    def test_tsquery_operators_stripped(self):
        """Tsquery operators like & | ! should be stripped from within words."""
        out = _sanitize_tsquery_words(["hello & world", "foo | bar"])
        # Spaces and operators removed; alphanumeric parts merge.
        assert "helloworld" in out
        assert "foobar" in out
Loading