Skip to content

Commit be591d3

Browse files
azevaykin and Copilot authored
Knn python tests (#30328)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9e88859 commit be591d3

File tree

5 files changed

+294
-0
lines changed

5 files changed

+294
-0
lines changed

ydb/tests/compatibility/test_vector_index.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,30 @@ def _write_data(self, name, vector_type, table_name):
8585
with ydb.QuerySessionPool(self.driver) as session_pool:
8686
session_pool.execute_with_retries(sql_upsert)
8787

88+
def _get_knn_queries(self):
    """Build KNN search queries that do not use a vector index (brute force with pushdown).

    One query is generated for every combination of table flavour (plain
    and ``_pfx``), vector type, target kind and target function taken from
    ``self.vector_types`` and ``self.targets``.

    Returns:
        A list of ``[True, sql]`` pairs. The leading flag is passed through
        unchanged to whoever consumes the list (presumably "query is
        expected to succeed" — confirm against ``_do_queries``).
    """
    queries = []
    for prefixed in ['', '_pfx']:
        # Tables with the "_pfx" suffix are queried with a filter on `user`.
        where = "WHERE user=1" if prefixed else ""
        for vector_type, to_binary_string in self.vector_types.items():
            for distance, functions in self.targets.items():
                # Every target kind except "similarity" is sorted ascending.
                order = "ASC" if distance != "similarity" else "DESC"
                for distance_func, knn_function in functions.items():
                    table_name = f"{vector_type}_{distance}_{distance_func}{prefixed}"
                    # Requested once per query, exactly as many times as before.
                    vector = self.get_vector(f"{vector_type}Vector", 1)
                    queries.append([
                        True, f"""
                        $Target = {to_binary_string}(Cast([{vector}] AS List<{vector_type}>));
                        SELECT key, vec, {knn_function}(vec, $Target) as target
                        FROM `{table_name}`
                        {where}
                        ORDER BY {knn_function}(vec, $Target) {order}
                        LIMIT {self.rows_count};"""
                    ])
    return queries
111+
88112
def _get_queries(self):
89113
queries = []
90114
for prefixed in ['', '_pfx']:
@@ -159,6 +183,12 @@ def select_from_index_without_roll(self):
159183
queries = self._get_queries()
160184
self._do_queries(queries)
161185

186+
def knn_search(self):
    """Run the index-free KNN queries once per step of the rolling upgrade/downgrade.

    The query list is built a single time up front and replayed through
    ``self._do_queries`` for every item yielded by ``self.roll()``.
    """
    knn_queries = self._get_knn_queries()
    for _step in self.roll():
        self._do_queries(knn_queries)
191+
162192
def create_table(self, table_name):
163193
query = f"""
164194
CREATE TABLE {table_name} (
@@ -189,5 +219,6 @@ def test_vector_index(self):
189219
target=f"{distance}={distance_func}",
190220
prefixed=prefixed,
191221
)
222+
self.knn_search()
192223
self.wait_index_ready()
193224
self.select_from_index()

ydb/tests/datashard/knn/test_knn.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from ydb.tests.datashard.lib.vector_base import VectorBase
2+
3+
4+
class TestKnn(VectorBase):
    """Exact (index-free) KNN search tests built on the ``Knn::*`` functions.

    ``setup_class`` fills ``table_name`` with ten rows whose ``emb`` value
    decodes to the two-component vectors below (first two hex bytes; the
    trailing ``02`` byte is presumably the vector-format tag — TODO confirm):

        pk=0 (3, 48)     pk=1 (19, 49)    pk=2 (35, 50)    pk=3 (83, 51)
        pk=4 (67, 52)    pk=5 (80, 96)    pk=6 (97, 17)    pk=7 (18, 98)
        pk=8 (117, 118)  pk=9 (118, 118)

    ``target_table_name`` holds one row (id=1) with the target embedding.
    ``pool`` and ``query`` are not defined here; they are expected to be
    provided by ``VectorBase``.
    """
    table_name = "TestTable"
    target_table_name = "TargetVectors"
    # Target embedding: 0x67, 0x71 (103, 113)
    target_embedding_hex = "677102"

    @classmethod
    def setup_class(cls):
        """Run the base-class setup, then (re)create and populate the test tables."""
        super().setup_class()
        cls._create_tables()

    @classmethod
    def teardown_class(cls):
        """Drop the test tables, then run the base-class teardown."""
        cls._cleanup_tables()
        super().teardown_class()

    @classmethod
    def _cleanup_tables(cls):
        """Drop both test tables on a best-effort basis."""
        for name in (cls.table_name, cls.target_table_name):
            try:
                cls.pool.execute_with_retries(f"DROP TABLE `{name}`")
            except Exception:
                # Deliberately ignored: this also runs before the tables are
                # created, so a failing DROP must not abort setup/teardown.
                pass

    @classmethod
    def _create_tables(cls):
        """Create both tables from scratch and insert the fixed test data."""
        cls._cleanup_tables()

        cls.pool.execute_with_retries(f"""
            CREATE TABLE `{cls.table_name}` (
                pk Int64 NOT NULL,
                emb String NOT NULL,
                data String NOT NULL,
                PRIMARY KEY (pk)
            );
        """)

        cls.pool.execute_with_retries(f"""
            UPSERT INTO `{cls.table_name}` (pk, emb, data) VALUES
            (0, Unwrap(String::HexDecode("033002")), "0"),
            (1, Unwrap(String::HexDecode("133102")), "1"),
            (2, Unwrap(String::HexDecode("233202")), "2"),
            (3, Unwrap(String::HexDecode("533302")), "3"),
            (4, Unwrap(String::HexDecode("433402")), "4"),
            (5, Unwrap(String::HexDecode("506002")), "5"),
            (6, Unwrap(String::HexDecode("611102")), "6"),
            (7, Unwrap(String::HexDecode("126202")), "7"),
            (8, Unwrap(String::HexDecode("757602")), "8"),
            (9, Unwrap(String::HexDecode("767602")), "9");
        """)

        cls.pool.execute_with_retries(f"""
            CREATE TABLE `{cls.target_table_name}` (
                id Int64 NOT NULL,
                target_emb String,
                PRIMARY KEY (id)
            )
        """)

        cls.pool.execute_with_retries(f"""
            UPSERT INTO `{cls.target_table_name}` (id, target_emb) VALUES
            (1, String::HexDecode("{cls.target_embedding_hex}"))
        """)

    def _knn_top_pks(self, knn_function, descending=False):
        """Return the pks of the three rows ranked first by ``Knn::<knn_function>``.

        Args:
            knn_function: Function name without the ``Knn::`` prefix.
            descending: Sort with ``DESC`` (used for similarity functions).
        """
        order_by = f"Knn::{knn_function}(emb, $TargetEmbedding)"
        if descending:
            order_by += " DESC"
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk FROM `{self.table_name}`
            ORDER BY {order_by}
            LIMIT 3
        """)
        assert len(result) == 3
        return [row['pk'] for row in result]

    def test_knn_cosine_distance(self):
        """Test KNN with CosineDistance."""
        # Smallest angle to (103, 113): pk=8, then pk=5, then pk=9.
        pks = self._knn_top_pks("CosineDistance")
        assert pks == [8, 5, 9], f"Expected PKs [8, 5, 9], got {pks}"

    def test_knn_cosine_similarity(self):
        """Test KNN with CosineSimilarity (DESC order)."""
        # Descending similarity is the same ranking as ascending cosine distance.
        pks = self._knn_top_pks("CosineSimilarity", descending=True)
        assert pks == [8, 5, 9], f"Expected PKs [8, 5, 9], got {pks}"

    def test_knn_inner_product_similarity(self):
        """Test KNN with InnerProductSimilarity (DESC order)."""
        # Dot products with (103, 113): pk=9 -> 25488, pk=8 -> 25385, pk=5 -> 19088.
        pks = self._knn_top_pks("InnerProductSimilarity", descending=True)
        assert pks == [9, 8, 5], f"Expected PKs [9, 8, 5], got {pks}"

    def test_knn_manhattan_distance(self):
        """Test KNN with ManhattanDistance."""
        # L1 distances to (103, 113): pk=8 -> 19, pk=9 -> 20, pk=5 -> 40.
        pks = self._knn_top_pks("ManhattanDistance")
        assert pks == [8, 9, 5], f"Expected PKs [8, 9, 5], got {pks}"

    def test_knn_euclidean_distance(self):
        """Test KNN with EuclideanDistance."""
        # Squared L2 distances to (103, 113): pk=8 -> 221, pk=9 -> 250, pk=5 -> 818.
        pks = self._knn_top_pks("EuclideanDistance")
        assert pks == [8, 9, 5], f"Expected PKs [8, 9, 5], got {pks}"

    def test_knn_verify_results(self):
        """
        Verify actual results - check that top 3 PKs are correct.
        Target vector is 0x67, 0x71 (103, 113)
        Expected cosine distances:
        pk=8 (117, 118): 0.000882 - closest
        pk=5 (80, 96): 0.000985
        pk=9 (118, 118): 0.001070
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance
            FROM `{self.table_name}`
            ORDER BY distance
            LIMIT 3
        """)

        assert len(result) == 3
        pks = [row['pk'] for row in result]
        assert pks == [8, 5, 9], f"Expected PKs [8, 5, 9], got {pks}"

    def test_knn_two_stage_query(self):
        """
        Test two-stage query: a first KNN pass selects candidate pks, a second
        pass re-ranks only those candidates by distance.
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");

            $Pks = SELECT pk
            FROM `{self.table_name}`
            ORDER BY Knn::CosineDistance(emb, $TargetEmbedding)
            LIMIT 3;

            SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance
            FROM `{self.table_name}`
            WHERE pk IN $Pks
            ORDER BY distance
            LIMIT 3;
        """)
        assert len(result) == 3
        # Both stages use the same metric and target, so the final ranking
        # must match the single-stage cosine ranking.
        pks = [row['pk'] for row in result]
        assert pks == [8, 5, 9], f"Expected PKs [8, 5, 9], got {pks}"

    def test_knn_subquery_target(self):
        """Test KNN with target vector from subquery."""
        result = self.query(f"""
            $TargetEmbedding = (SELECT target_emb FROM `{self.target_table_name}` WHERE id = 1);
            SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance
            FROM `{self.table_name}`
            ORDER BY distance
            LIMIT 3
        """)
        assert len(result) == 3
        # The stored target is the same embedding the literal-target tests use.
        pks = [row['pk'] for row in result]
        assert pks == [8, 5, 9], f"Expected PKs [8, 5, 9], got {pks}"

ydb/tests/datashard/knn/ya.make

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
PY3TEST()
2+
INCLUDE(${ARCADIA_ROOT}/ydb/tests/harness_dep.inc)
3+
4+
FORK_SUBTESTS()
5+
SPLIT_FACTOR(10)
6+
7+
IF (SANITIZER_TYPE)
8+
SIZE(LARGE)
9+
INCLUDE(${ARCADIA_ROOT}/ydb/tests/large.inc)
10+
ELSE()
11+
SIZE(MEDIUM)
12+
ENDIF()
13+
14+
TEST_SRCS(
15+
test_knn.py
16+
)
17+
18+
PEERDIR(
19+
ydb/tests/datashard/lib
20+
)
21+
22+
DEPENDS(
23+
ydb/apps/ydb
24+
)
25+
26+
END()
27+

ydb/tests/datashard/ya.make

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ RECURSE(
44
dml
55
dump_restore
66
copy_table
7+
knn
78
lib
89
partitioning
910
parametrized_queries

ydb/tests/stress/oltp_workload/workload/type/vector_index.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,61 @@ def _select(self, index_name, table_path, vector_type, vector_dimension, distanc
143143
"""
144144
return self.client.query(select_sql, False)
145145

146+
def _knn_search(self, table_path, vector_type, vector_dimension, distance, similarity, prefixed=False):
    """Run a KNN query straight against the table, without a vector index.

    The target function comes from ``targets["distance"]`` when ``distance``
    is given, otherwise from ``targets["similarity"]``; distance queries are
    sorted ascending and similarity queries descending. The query vector is
    obtained from ``self._get_random_vector``. With ``prefixed`` set, only
    rows with ``user=1`` are searched.

    Returns:
        Whatever ``self.client.query`` returns for the generated statement.
    """
    by_distance = distance is not None
    knn_function = targets["distance"][distance] if by_distance else targets["similarity"][similarity]
    query_vector = self._get_random_vector(vector_type, vector_dimension)
    to_binary = to_binary_string_converters[vector_type]
    direction = "ASC" if by_distance else "DESC"
    row_filter = "WHERE user=1" if prefixed else ""
    select_sql = f"""
        $Target = {to_binary.name}(Cast([{query_vector}] AS List<{to_binary.data_type}>));
        SELECT pk, embedding, {knn_function}(embedding, $Target) as target
        FROM `{table_path}`
        {row_filter}
        ORDER BY {knn_function}(embedding, $Target) {direction}
        LIMIT {self.limit};
    """
    return self.client.query(select_sql, False)
170+
171+
def _knn_search_check(self, table_path, vector_type, vector_dimension, distance, similarity, prefixed=False):
    """Run an index-free KNN search and check that the rows come back sorted.

    The ``target`` column of the first result set must be non-decreasing
    when ``distance`` is given and non-increasing otherwise.

    Raises:
        Exception: if the query yields no result sets, or if two adjacent
            rows violate the expected ordering.
    """
    logger.info(f"KNN search: vector_type={vector_type}, distance={distance}, similarity={similarity}")
    result_set = self._knn_search(
        table_path=table_path,
        vector_type=vector_type,
        vector_dimension=vector_dimension,
        distance=distance,
        similarity=similarity,
        prefixed=prefixed,
    )
    if len(result_set) == 0:
        raise Exception(
            f"KNN query returned an empty set for vector_type={vector_type}, distance={distance}, similarity={similarity}, prefixed={prefixed}"
        )
    rows = result_set[0].rows
    logger.info(f"KNN search returned {len(rows)} rows")

    ascending = distance is not None
    # Walk adjacent row pairs; zero or one row yields no pairs at all.
    for earlier, later in zip(rows, rows[1:]):
        prev, cur = earlier['target'], later['target']
        if ascending:
            in_order = prev <= cur
        else:
            in_order = prev >= cur
        if not in_order:
            raise Exception(
                f"KNN results not properly ordered: prev={prev}, cur={cur}"
            )
    logger.info("KNN search completed successfully")
200+
146201
def _select_top(self, index_name, table_path, vector_type, vector_dimension, distance, similarity, prefixed=False):
147202
logger.info("Select values from table")
148203
result_set = self._select(
@@ -196,6 +251,15 @@ def _wait_index_ready(self, index_name, table_path, vector_type, vector_dimensio
196251
raise Exception("Error getting index status")
197252

198253
def _check_loop(self, table_path, vector_type, vector_dimension, levels, clusters, distance=None, similarity=None, prefixed=False):
254+
self._knn_search_check(
255+
table_path=table_path,
256+
vector_type=vector_type,
257+
vector_dimension=vector_dimension,
258+
distance=distance,
259+
similarity=similarity,
260+
prefixed=prefixed,
261+
)
262+
199263
index_name = f"{self.index_name_prefix}_{vector_type}_{vector_dimension}_{levels}_{clusters}_{distance}_{similarity}_{str(prefixed)}"
200264
self._create_index(
201265
index_name=index_name,

0 commit comments

Comments
 (0)