|
import logging

from ydb.tests.datashard.lib.vector_base import VectorBase
| 2 | + |
| 3 | + |
class TestKnn(VectorBase):
    """Exercise the ``Knn::*`` distance and similarity functions in ORDER BY ... LIMIT queries.

    Fixture data: ten rows in ``TestTable`` whose ``emb`` column holds two
    one-byte components followed by a trailing ``0x02`` byte (presumably the
    Knn vector-format tag for uint8 vectors -- confirm against the Knn docs).
    Decoded components per pk:

        0: (3, 48)    1: (19, 49)   2: (35, 50)   3: (83, 51)   4: (67, 52)
        5: (80, 96)   6: (97, 17)   7: (18, 98)   8: (117, 118) 9: (118, 118)

    ``cls.pool`` and ``self.query`` are not defined here; they are assumed to
    be provided by ``VectorBase``.
    """

    table_name = "TestTable"
    target_table_name = "TargetVectors"
    # Target embedding: 0x67, 0x71 (103, 113)
    target_embedding_hex = "677102"

    @classmethod
    def setup_class(cls):
        """Run the base-class setup, then create and populate the test tables."""
        super().setup_class()
        cls._create_tables()

    @classmethod
    def teardown_class(cls):
        """Drop the test tables, then always run the base-class teardown."""
        try:
            cls._cleanup_tables()
        finally:
            # The base teardown must run even if cleanup fails unexpectedly.
            super().teardown_class()

    @classmethod
    def _cleanup_tables(cls):
        """Drop both test tables, ignoring (but logging) any failure.

        Cleanup is best effort: it is also called before the tables are
        created, so a failing DROP is expected on a fresh database.
        """
        log = logging.getLogger(__name__)
        for name in (cls.table_name, cls.target_table_name):
            try:
                cls.pool.execute_with_retries(f"DROP TABLE `{name}`")
            except Exception:
                # Keep the original swallow-everything behaviour, but leave a
                # trace so that unexpected failures are not completely hidden.
                log.debug("DROP TABLE `%s` failed; ignoring", name, exc_info=True)

    @classmethod
    def _create_tables(cls):
        """(Re)create the embeddings table and the target-vector table and fill them."""
        cls._cleanup_tables()

        cls.pool.execute_with_retries(f"""
            CREATE TABLE `{cls.table_name}` (
                pk Int64 NOT NULL,
                emb String NOT NULL,
                data String NOT NULL,
                PRIMARY KEY (pk)
            );
        """)

        cls.pool.execute_with_retries(f"""
            UPSERT INTO `{cls.table_name}` (pk, emb, data) VALUES
            (0, Unwrap(String::HexDecode("033002")), "0"),
            (1, Unwrap(String::HexDecode("133102")), "1"),
            (2, Unwrap(String::HexDecode("233202")), "2"),
            (3, Unwrap(String::HexDecode("533302")), "3"),
            (4, Unwrap(String::HexDecode("433402")), "4"),
            (5, Unwrap(String::HexDecode("506002")), "5"),
            (6, Unwrap(String::HexDecode("611102")), "6"),
            (7, Unwrap(String::HexDecode("126202")), "7"),
            (8, Unwrap(String::HexDecode("757602")), "8"),
            (9, Unwrap(String::HexDecode("767602")), "9");
        """)

        # target_emb is nullable and stored without Unwrap, so the subquery
        # test reads an optional value.
        cls.pool.execute_with_retries(f"""
            CREATE TABLE `{cls.target_table_name}` (
                id Int64 NOT NULL,
                target_emb String,
                PRIMARY KEY (id)
            )
        """)

        cls.pool.execute_with_retries(f"""
            UPSERT INTO `{cls.target_table_name}` (id, target_emb) VALUES
            (1, String::HexDecode("{cls.target_embedding_hex}"))
        """)

    @staticmethod
    def _assert_top_pks(result, expected):
        """Assert that ``result`` rows carry exactly the ``expected`` pks, in order.

        A bare row-count check is not enough: every query here uses LIMIT 3
        on a ten-row table, so the count is 3 no matter how rows are ranked.
        """
        assert len(result) == len(expected)
        pks = [row['pk'] for row in result]
        assert pks == expected, f"Expected PKs {expected}, got {pks}"

    def test_knn_cosine_distance(self):
        """Test KNN with CosineDistance; same ranking as test_knn_verify_results."""
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk FROM `{self.table_name}`
            ORDER BY Knn::CosineDistance(emb, $TargetEmbedding)
            LIMIT 3
        """)
        self._assert_top_pks(result, [8, 5, 9])

    def test_knn_cosine_similarity(self):
        """Test KNN with CosineSimilarity (DESC order).

        Descending similarity yields the same ranking as ascending cosine
        distance, assuming similarity decreases as distance increases.
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk FROM `{self.table_name}`
            ORDER BY Knn::CosineSimilarity(emb, $TargetEmbedding) DESC
            LIMIT 3
        """)
        self._assert_top_pks(result, [8, 5, 9])

    def test_knn_inner_product_similarity(self):
        """Test KNN with InnerProductSimilarity (DESC order).

        Dot products with the target (103, 113):
            pk=9 (118, 118): 25488 - largest
            pk=8 (117, 118): 25385
            pk=5 (80, 96):   19088
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk FROM `{self.table_name}`
            ORDER BY Knn::InnerProductSimilarity(emb, $TargetEmbedding) DESC
            LIMIT 3
        """)
        self._assert_top_pks(result, [9, 8, 5])

    def test_knn_manhattan_distance(self):
        """Test KNN with ManhattanDistance.

        L1 distances to the target (103, 113):
            pk=8 (117, 118): 19 - closest
            pk=9 (118, 118): 20
            pk=5 (80, 96):   40
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk FROM `{self.table_name}`
            ORDER BY Knn::ManhattanDistance(emb, $TargetEmbedding)
            LIMIT 3
        """)
        self._assert_top_pks(result, [8, 9, 5])

    def test_knn_euclidean_distance(self):
        """Test KNN with EuclideanDistance.

        Squared L2 distances to the target (103, 113):
            pk=8 (117, 118): 221 - closest
            pk=9 (118, 118): 250
            pk=5 (80, 96):   818
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk FROM `{self.table_name}`
            ORDER BY Knn::EuclideanDistance(emb, $TargetEmbedding)
            LIMIT 3
        """)
        self._assert_top_pks(result, [8, 9, 5])

    def test_knn_verify_results(self):
        """
        Verify actual results - check that top 3 PKs are correct.
        Target vector is 0x67, 0x71 (103, 113)
        Expected cosine distances:
        pk=8 (117, 118): 0.000882 - closest
        pk=5 (80, 96): 0.000985
        pk=9 (118, 118): 0.001070
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");
            SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance
            FROM `{self.table_name}`
            ORDER BY distance
            LIMIT 3
        """)

        self._assert_top_pks(result, [8, 5, 9])

    def test_knn_two_stage_query(self):
        """
        Test the two-stage query shape: a top-k pk search followed by a
        re-ranking SELECT restricted to those pks.

        Both stages read the same ``emb`` column here, so the final ranking
        must match the single-stage cosine-distance ranking.
        """
        result = self.query(f"""
            $TargetEmbedding = String::HexDecode("{self.target_embedding_hex}");

            $Pks = SELECT pk
            FROM `{self.table_name}`
            ORDER BY Knn::CosineDistance(emb, $TargetEmbedding)
            LIMIT 3;

            SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance
            FROM `{self.table_name}`
            WHERE pk IN $Pks
            ORDER BY distance
            LIMIT 3;
        """)
        self._assert_top_pks(result, [8, 5, 9])

    def test_knn_subquery_target(self):
        """Test KNN with target vector from subquery.

        The stored target is the same vector as ``target_embedding_hex``, so
        the ranking must match the literal-target cosine-distance ranking.
        """
        result = self.query(f"""
            $TargetEmbedding = (SELECT target_emb FROM `{self.target_table_name}` WHERE id = 1);
            SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance
            FROM `{self.table_name}`
            ORDER BY distance
            LIMIT 3
        """)
        self._assert_top_pks(result, [8, 5, 9])
0 commit comments