Skip to content

Commit 307bc92

Browse files
razdoburdinDmitry Razdoburdin
andauthored
Native serialization for InvertedIndex (#299)
This PR introduces native serialization for the inverted index. The main changes are: a new overload of svs::index::inverted::assemble_from_clustering that accepts a std::istream is introduced, and related tests are added. Co-authored-by: Dmitry Razdoburdin <drazdobu@intel.com>
1 parent a717a63 commit 307bc92

4 files changed

Lines changed: 195 additions & 35 deletions

File tree

include/svs/index/inverted/clustering.h

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,36 @@ template <std::integral I> class Clustering {
571571
// Saving and Loading.
572572
static constexpr lib::Version save_version{0, 0, 0};
573573
static constexpr std::string_view serialization_schema = "clustering";
574+
575+
/// @brief Build the save table describing this clustering.
///
/// The table is tagged with ``serialization_schema`` and ``save_version`` and records the
/// integer type ``I`` together with the number of clusters. No file-related entries are
/// added here.
lib::SaveTable metadata() const {
    lib::SaveTable table(
        serialization_schema,
        save_version,
        {{"integer_type", lib::save(datatype_v<I>)}, {"num_clusters", lib::save(size())}}
    );
    return table;
}
583+
584+
void save(std::ostream& os) const {
585+
for (const auto& [id, cluster] : *this) {
586+
cluster.serialize(os);
587+
}
588+
}
589+
590+
/// @brief Rebuild a clustering from its metadata table and a stream of serialized clusters.
///
/// @param table Supplies the ``"integer_type"`` and ``"num_clusters"`` entries.
/// @param stream Source from which ``num_clusters`` clusters are deserialized in sequence.
///
/// @throws ANNException if the saved integer type differs from ``I``.
static Clustering<I>
load(const lib::ContextFreeLoadTable& table, std::istream& stream) {
    // Refuse to decode clusters with a different integer type than they were saved with.
    const auto stored_type = lib::load_at<DataType>(table, "integer_type");
    if (stored_type != datatype_v<I>) {
        throw ANNEXCEPTION("Clustering was saved using {} but we're trying to reload it using {}!", stored_type, datatype_v<I>);
    }

    Clustering<I> result{};
    const auto cluster_count = lib::load_at<size_t>(table, "num_clusters");
    for (size_t remaining = cluster_count; remaining > 0; --remaining) {
        result.insert(Cluster<I>::deserialize(stream));
    }
    return result;
}
603+
574604
lib::SaveTable save(const lib::SaveContext& ctx) const {
575605
// Serialize all clusters into an auxiliary file.
576606
auto fullpath = ctx.generate_name("clustering", "bin");
@@ -582,48 +612,28 @@ template <std::integral I> class Clustering {
582612
}
583613
}
584614

585-
return lib::SaveTable(
586-
serialization_schema,
587-
save_version,
588-
{{"filepath", lib::save(fullpath.filename())},
589-
SVS_LIST_SAVE(filesize),
590-
{"integer_type", lib::save(datatype_v<I>)},
591-
{"num_clusters", lib::save(size())}}
592-
);
615+
auto table = metadata();
616+
table.insert("filepath", lib::save(fullpath.filename()));
617+
table.insert("filesize", lib::save(filesize));
618+
return table;
619+
620+
return table;
593621
}
594622

595623
/// @brief Reload a clustering from the auxiliary file referenced by ``table``.
///
/// The table must carry the ``"filepath"`` and ``"filesize"`` entries written by
/// ``save(const lib::SaveContext&)``. The on-disk size of the file is checked against the
/// recorded size before anything is read. The clusters themselves are then decoded by the
/// stream-based ``load`` overload.
///
/// @param table Table providing the file location, the expected file size, and the
///     entries required by the stream-based overload.
///
/// @throws ANNException if the file size on disk differs from the recorded size.
static Clustering<I> load(const lib::LoadTable& table) {
    auto expected_filesize = lib::load_at<size_t>(table, "filesize");
    auto file = table.resolve_at("filepath");
    size_t actual_filesize = std::filesystem::file_size(file);
    if (actual_filesize != expected_filesize) {
        // Argument order must follow the message: expected size first, actual size second.
        throw ANNEXCEPTION(
            "Expected cluster file size to be {}. Instead, it is {}!",
            expected_filesize,
            actual_filesize
        );
    }

    auto io = lib::open_read(file);
    return load(table, io);
}
628638

629639
private:

include/svs/index/inverted/memory_based.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,8 @@ template <typename Index, typename Cluster> class InvertedIndex {
497497
index_.save(index_config, graph, data);
498498
}
499499

500+
/// @brief Write the primary index to ``os`` by forwarding to ``index_.save(os)``.
void save_primary_index(std::ostream& os) const { index_.save(os); }
501+
500502
///// Accessors
501503
/// @brief Getter method for logger
502504
svs::logging::logger_ptr get_logger() const { return logger_; }
@@ -655,4 +657,46 @@ auto assemble_from_clustering(
655657
);
656658
}
657659

660+
template <
661+
typename DataProto,
662+
typename Distance,
663+
StorageStrategy Strategy,
664+
typename ThreadPoolProto>
665+
auto assemble_from_clustering(
666+
std::istream& is,
667+
DataProto data_proto,
668+
Distance distance,
669+
Strategy strategy,
670+
ThreadPoolProto threadpool_proto,
671+
svs::logging::logger_ptr logger = svs::logging::get()
672+
) {
673+
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
674+
auto original = svs::detail::dispatch_load(std::move(data_proto), threadpool);
675+
auto clustering = lib::load_from_stream<Clustering<uint32_t>>(is);
676+
auto ids = clustering.sorted_centroids();
677+
678+
// skip magic
679+
svs::lib::detail::Deserializer::build(is);
680+
auto index = index::vamana::auto_assemble(
681+
is,
682+
lib::Lazy([&]() { return GraphLoader<uint32_t>::return_type::load(is); }),
683+
lib::Lazy([&]() {
684+
using T = typename std::decay_t<decltype(original)>::element_type;
685+
constexpr size_t Ext = std::decay_t<decltype(original)>::extent;
686+
return lib::load_from_stream<data::SimpleData<T, Ext>>(is);
687+
}),
688+
distance,
689+
1,
690+
logger
691+
);
692+
693+
return InvertedIndex(
694+
std::move(index),
695+
strategy(original, clustering, HugepageAllocator<std::byte>()),
696+
std::move(ids),
697+
std::move(threadpool),
698+
std::move(logger)
699+
);
700+
}
701+
658702
} // namespace svs::index::inverted

include/svs/orchestrators/inverted.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class InvertedInterface {
3434
const std::filesystem::path& primary_data,
3535
const std::filesystem::path& primary_graph
3636
) = 0;
37+
38+
///// Saving
39+
/// Stream-based counterpart of the path-based ``save_primary_index`` declared above;
/// takes the destination as a ``std::ostream``.
virtual void save_primary_index(std::ostream& os) = 0;
3740
};
3841

3942
template <lib::TypeList QueryTypes, typename Impl, typename IFace = InvertedInterface>
@@ -72,6 +75,8 @@ class InvertedImpl : public manager::ManagerImpl<QueryTypes, Impl, IFace> {
7275
) override {
7376
impl().save_primary_index(primary_config, primary_data, primary_graph);
7477
}
78+
79+
/// Forwards the stream-based save to the wrapped implementation via ``impl()``.
void save_primary_index(std::ostream& os) override { impl().save_primary_index(os); }
7580
};
7681

7782
/////
@@ -106,6 +111,8 @@ class Inverted : public manager::IndexManager<InvertedInterface> {
106111
impl_->save_primary_index(primary_config, primary_data, primary_graph);
107112
}
108113

114+
/// @brief Stream-based overload: forwards ``os`` to ``impl_->save_primary_index``.
void save_primary_index(std::ostream& os) { impl_->save_primary_index(os); }
115+
109116
///// Building
110117
template <
111118
manager::QueryTypeDefinition QueryTypes,
@@ -168,6 +175,30 @@ class Inverted : public manager::IndexManager<InvertedInterface> {
168175
std::move(threadpool_proto)
169176
)};
170177
}
178+
template <
179+
manager::QueryTypeDefinition QueryTypes,
180+
typename DataProto,
181+
typename Distance,
182+
typename ThreadPoolProto,
183+
typename StorageStrategy = index::inverted::SparseStrategy>
184+
static Inverted assemble_from_clustering(
185+
std::istream& is,
186+
DataProto data_proto,
187+
Distance distance,
188+
ThreadPoolProto threadpool_proto,
189+
StorageStrategy strategy = {}
190+
) {
191+
return Inverted{
192+
std::in_place,
193+
manager::as_typelist<QueryTypes>{},
194+
index::inverted::assemble_from_clustering(
195+
is,
196+
std::move(data_proto),
197+
std::move(distance),
198+
std::move(strategy),
199+
std::move(threadpool_proto)
200+
)};
201+
}
171202
};
172203

173204
} // namespace svs

tests/svs/index/inverted/memory_based.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
#include "spdlog/sinks/callback_sink.h"
2020
#include "svs-benchmark/datasets.h"
2121
#include "svs/lib/timing.h"
22+
#include "svs/orchestrators/inverted.h"
2223
#include "tests/utils/inverted_reference.h"
2324
#include "tests/utils/test_dataset.h"
2425
#include <filesystem>
26+
#include <sstream>
2527

2628
CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") {
2729
// Vector to store captured log messages
@@ -73,3 +75,76 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") {
7375
CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos);
7476
CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos);
7577
}
78+
79+
namespace {
80+
constexpr size_t NUM_NEIGHBORS = 10;
81+
82+
template <typename Strategy> void test_stream_save_load(Strategy strategy) {
83+
auto distance = svs::DistanceL2();
84+
constexpr auto distance_type = svs::distance_type_v<svs::DistanceL2>;
85+
auto expected_results = test_dataset::inverted::expected_build_results(
86+
distance_type, svsbenchmark::Uncompressed(svs::DataType::float32)
87+
);
88+
auto build_parameters = expected_results.build_parameters_.value();
89+
90+
// Capture the clustering during build.
91+
svs::index::inverted::Clustering<uint32_t> clustering;
92+
auto clustering_op = [&](const auto& c) { clustering = c; };
93+
94+
svs::Inverted index = svs::Inverted::build<float>(
95+
build_parameters,
96+
svs::data::SimpleData<float>::load(test_dataset::data_svs_file()),
97+
distance,
98+
2,
99+
strategy,
100+
svs::index::inverted::PickRandomly{},
101+
clustering_op
102+
);
103+
104+
auto queries = svs::data::SimpleData<float>::load(test_dataset::query_file());
105+
auto parameters = index.get_search_parameters();
106+
auto results = index.search(queries, NUM_NEIGHBORS);
107+
108+
// Serialize to stream.
109+
std::stringstream ss;
110+
svs::lib::save_to_stream(clustering, ss);
111+
index.save_primary_index(ss);
112+
113+
// Load from stream.
114+
svs::Inverted loaded = svs::Inverted::assemble_from_clustering<svs::lib::Types<float>>(
115+
ss,
116+
svs::data::SimpleData<float>::load(test_dataset::data_svs_file()),
117+
distance,
118+
2,
119+
strategy
120+
);
121+
loaded.set_search_parameters(parameters);
122+
123+
// Compare basic properties.
124+
CATCH_REQUIRE(loaded.size() == index.size());
125+
CATCH_REQUIRE(loaded.dimensions() == index.dimensions());
126+
127+
// Compare search results element-wise.
128+
auto loaded_results = loaded.search(queries, NUM_NEIGHBORS);
129+
CATCH_REQUIRE(loaded_results.n_queries() == results.n_queries());
130+
CATCH_REQUIRE(loaded_results.n_neighbors() == results.n_neighbors());
131+
for (size_t q = 0; q < results.n_queries(); ++q) {
132+
for (size_t i = 0; i < NUM_NEIGHBORS; ++i) {
133+
CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i));
134+
CATCH_REQUIRE(
135+
loaded_results.distance(q, i) ==
136+
Catch::Approx(results.distance(q, i)).epsilon(1e-5)
137+
);
138+
}
139+
}
140+
}
141+
} // namespace
142+
143+
// Run the stream round-trip once per storage strategy.
CATCH_TEST_CASE("InvertedIndex Save and Load", "[saveload][inverted][index]") {
    CATCH_SECTION("SparseStrategy") {
        auto sparse = svs::index::inverted::SparseStrategy();
        test_stream_save_load(sparse);
    }
    CATCH_SECTION("DenseStrategy") {
        auto dense = svs::index::inverted::DenseStrategy();
        test_stream_save_load(dense);
    }
}

0 commit comments

Comments
 (0)