Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/GetHNSW.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set ( HNSW_GITHUB "https://github.com/manticoresoftware/hnswlib/archive/6568d3b.zip" )
set ( HNSW_GITHUB "https://github.com/manticoresoftware/hnswlib/archive/3c50a66.zip" )
set ( HNSW_BUNDLEZIP "${LIBS_BUNDLE}/hnswlib-0.7.0.tar.gz" )

cmake_minimum_required ( VERSION 3.17 FATAL_ERROR )
Expand Down
73 changes: 45 additions & 28 deletions knn/knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ void HNSWIndex_c::Search ( std::vector<DocDist_t> & dResults, const Span_T<float
const void * pData = dData.begin();
if ( m_pQuantizer )
{
m_pQuantizer->Encode ( 0, dData, dQuantized );
std::vector<uint8_t> dUnusedQuantizedForQuery;
m_pQuantizer->Encode ( 0, dData, dQuantized, dUnusedQuantizedForQuery );
pData = dQuantized.data();
}

Expand Down Expand Up @@ -536,7 +537,8 @@ class HNSWIndexBuilder_i
virtual ~HNSWIndexBuilder_i() = default;

virtual void Train ( const util::Span_T<float> & dData ) = 0;
virtual bool AddDoc ( uint32_t uRowID, const util::Span_T<float> & dData, std::string & sError ) = 0;
virtual bool FinalizeTraining ( std::string & sError ) = 0;
virtual bool AddDoc ( uint32_t uRowID, const util::Span_T<float> & dData, BuildContext_t & tBuildCtx, std::string & sError ) = 0;
virtual void Save ( FileWriter_c & tWriter ) = 0;
virtual const AttrWithSettings_t & GetAttr() const = 0;
virtual const QuantizationSettings_t & GetQuantizationSettings() const = 0;
Expand All @@ -549,16 +551,14 @@ class HNSWIndexBuilder_c : public HNSWIndexBuilder_i, public HNSWDist_c
HNSWIndexBuilder_c ( const AttrWithSettings_t & tAttr, int64_t iNumElements, ScalarQuantizer_i * pQuantizer );

void Train ( const util::Span_T<float> & dData ) override;
bool AddDoc ( uint32_t uRowID, const util::Span_T<float> & dData, std::string & sError ) override;
bool FinalizeTraining ( std::string & sError ) override;
bool AddDoc ( uint32_t uRowID, const util::Span_T<float> & dData, BuildContext_t & tBuildCtx, std::string & sError ) override;
void Save ( FileWriter_c & tWriter ) override;
const AttrWithSettings_t & GetAttr() const override { return m_tAttr; }
const QuantizationSettings_t & GetQuantizationSettings() const override { return m_pQuantizer->GetSettings(); }

private:
AttrWithSettings_t m_tAttr;
bool m_bFirstDoc = true;
SpanResizeable_T<float> m_dNormalized;
std::vector<uint8_t> m_dQuantized;
std::unique_ptr<ScalarQuantizer_i> m_pQuantizer;
std::unique_ptr<hnswlib::HierarchicalNSW<float>> m_pAlg;
};
Expand All @@ -570,7 +570,6 @@ HNSWIndexBuilder_c::HNSWIndexBuilder_c ( const AttrWithSettings_t & tAttr, int64
, m_pQuantizer ( pQuantizer )
{
m_pAlg = std::make_unique<hnswlib::HierarchicalNSW<float>>( m_pSpace.get(), iNumElements, m_tAttr.m_iHNSWM, m_tAttr.m_iHNSWEFConstruction );
m_dNormalized.resize ( tAttr.m_iDims );
}


Expand All @@ -581,36 +580,45 @@ void HNSWIndexBuilder_c::Train ( const util::Span_T<float> & dData )
}


bool HNSWIndexBuilder_c::AddDoc ( uint32_t uRowID, const util::Span_T<float> & dData, std::string & sError )
bool HNSWIndexBuilder_c::FinalizeTraining ( std::string & sError )
{
if ( dData.size()!=m_tAttr.m_iDims )
if ( !m_pQuantizer )
return true;

if ( m_pQuantizer->IsFinalized() )
return true;

if ( !m_pQuantizer->FinalizeTraining ( sError ) )
return false;

m_pSpace->SetQuantizationSettings ( *m_pQuantizer );
return true;
}


bool HNSWIndexBuilder_c::AddDoc ( uint32_t uRowID, const util::Span_T<float> & dData, BuildContext_t & tBuildCtx, std::string & sError )
{
if ( dData.size()!=(size_t)m_tAttr.m_iDims )
{
sError = FormatStr ( "HNSW error: data has %llu values, index '%s' needs %d values", dData.size(), m_tAttr.m_sName.c_str(), m_tAttr.m_iDims );
return false;
}

assert ( !m_pQuantizer || m_pQuantizer->IsFinalized() );

Span_T<float> dToAdd = dData;
if ( m_tAttr.m_eHNSWSimilarity==HNSWSimilarity_e::COSINE )
{
memcpy ( m_dNormalized.data(), dData.data(), dData.size()*sizeof(dData[0] ) );
VecNormalize(m_dNormalized);
dToAdd = m_dNormalized;
tBuildCtx.m_dNormalized.resize ( dData.size() );
memcpy ( tBuildCtx.m_dNormalized.data(), dData.data(), dData.size()*sizeof(dData[0] ) );
VecNormalize ( tBuildCtx.m_dNormalized );
dToAdd = tBuildCtx.m_dNormalized;
}

if ( m_pQuantizer )
{
if ( m_bFirstDoc )
{
m_bFirstDoc = false;

if ( !m_pQuantizer->FinalizeTraining(sError) )
return false;

m_pSpace->SetQuantizationSettings ( *m_pQuantizer );
}

m_pQuantizer->Encode ( uRowID, dToAdd, m_dQuantized );
m_pAlg->addPoint ( (void*)m_dQuantized.data(), (size_t)uRowID );
m_pQuantizer->Encode ( uRowID, dToAdd, tBuildCtx.m_dQuantized, tBuildCtx.m_dQuantizedForQuery );
m_pAlg->addPoint ( (void*)tBuildCtx.m_dQuantized.data(), (size_t)uRowID );
}
else
m_pAlg->addPoint ( (void*)dToAdd.data(), (size_t)uRowID );
Expand All @@ -634,14 +642,13 @@ class HNSWBuilder_c : public Builder_i
public:
HNSWBuilder_c ( const Schema_t & tSchema, int64_t iNumElements, const std::string & sTmpFilename );

void Train ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData ) override { m_dIndexes[iAttr]->Train(dData); }
bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData ) override { return m_dIndexes[iAttr]->AddDoc ( uRowID, dData, m_sError ); }
void Train ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData ) override { m_dIndexes[iAttr]->Train(dData); }
bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData, BuildContext_t & tBuildCtx ) override { return m_dIndexes[iAttr]->AddDoc ( uRowID, dData, tBuildCtx, tBuildCtx.m_sError ); }
bool FinalizeTraining ( std::string & sError ) override;
bool Save ( const std::string & sFilename, size_t tBufferSize, std::string & sError ) override;
const std::string & GetError() const override { return m_sError; }

private:
std::vector<std::unique_ptr<HNSWIndexBuilder_i>> m_dIndexes;
std::string m_sError;
};


Expand All @@ -653,6 +660,16 @@ HNSWBuilder_c::HNSWBuilder_c ( const Schema_t & tSchema, int64_t iNumElements, c
}


bool HNSWBuilder_c::FinalizeTraining ( std::string & sError )
{
for ( auto & i : m_dIndexes )
if ( !i->FinalizeTraining(sError) )
return false;

return true;
}


bool HNSWBuilder_c::Save ( const std::string & sFilename, size_t tBufferSize, std::string & sError )
{
FileWriter_c tWriter;
Expand Down
15 changes: 12 additions & 3 deletions knn/knn.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
namespace knn
{

static const int LIB_VERSION = 13;
static const int LIB_VERSION = 14;
static const uint32_t STORAGE_VERSION = 3;

enum class HNSWSimilarity_e
Expand Down Expand Up @@ -122,15 +122,24 @@ class KNN_i
virtual bool ShouldUseFullscan ( const std::string & sName, int64_t iResults, int iEf, int64_t iFilterCount ) = 0;
};

// passed via SetAttr so the builder itself holds no per-row mutable state
struct BuildContext_t
{
util::SpanResizeable_T<float> m_dNormalized;
std::vector<uint8_t> m_dQuantized;
std::vector<uint8_t> m_dQuantizedForQuery; // 4-bit transposed representation, produced only by the BIT1 binary quantizer during BUILD mode
std::string m_sError;
};

class Builder_i
{
public:
virtual ~Builder_i() = default;

virtual void Train ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData ) = 0;
virtual bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData ) = 0;
virtual bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T<float> & dData, BuildContext_t & tBuildCtx ) = 0;
virtual bool FinalizeTraining ( std::string & sError ) = 0;
virtual bool Save ( const std::string & sFilename, size_t tBufferSize, std::string & sError ) = 0;
virtual const std::string & GetError() const = 0;
};

class TextToEmbeddings_i
Expand Down
56 changes: 36 additions & 20 deletions knn/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class ScalarQuantizer8Bit_c : public ScalarQuantizer_i

void Train ( const Span_T<float> & dPoint ) override;
bool FinalizeTraining ( std::string & sError ) override;
void Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized ) override;
bool IsFinalized() const override { return m_bFinalized; }
void Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & dQuantizedForQuery ) override;
void FinalizeEncoding() override {}
const QuantizationSettings_t & GetSettings() override;
std::function<const uint8_t * (uint32_t)> GetPoolFetcher() const override { return nullptr; }
Expand Down Expand Up @@ -138,7 +139,7 @@ bool ScalarQuantizer8Bit_c::FinalizeTraining ( std::string & sError )
}


void ScalarQuantizer8Bit_c::Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized )
void ScalarQuantizer8Bit_c::Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & /*dQuantizedForQuery*/ )
{
assert(m_bFinalized);

Expand Down Expand Up @@ -183,11 +184,11 @@ class ScalarQuantizer1Bit_c : public ScalarQuantizer8Bit_c
using ScalarQuantizer8Bit_c::ScalarQuantizer8Bit_c;

public:
void Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized ) override;
void Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & dQuantizedForQuery ) override;
};


void ScalarQuantizer1Bit_c::Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized )
void ScalarQuantizer1Bit_c::Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & /*dQuantizedForQuery*/ )
{
assert(m_bFinalized);

Expand Down Expand Up @@ -647,7 +648,8 @@ class ScalarQuantizerBinary_T : public ScalarQuantizer_i

void Train ( const Span_T<float> & dPoint ) override;
bool FinalizeTraining ( std::string & sError ) override;
void Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized ) override;
bool IsFinalized() const override { return m_bFinalized; }
void Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & dQuantizedForQuery ) override;
void FinalizeEncoding() override;
const QuantizationSettings_t & GetSettings() override;

Expand All @@ -659,7 +661,6 @@ class ScalarQuantizerBinary_T : public ScalarQuantizer_i
HNSWSimilarity_e m_eSimilarity = HNSWSimilarity_e::COSINE;
std::string m_sTmpFilename;
std::vector<double> m_dCentroid64;
std::vector<uint8_t> m_dQuantizedForQuery;
MappedBuffer_T<uint8_t> m_tBuffer4Bit;
size_t m_uDim = 0;
bool m_bFinalized = false;
Expand Down Expand Up @@ -708,16 +709,17 @@ void ScalarQuantizerBinary_T<BUILD>::Train ( const Span_T<float> & dPoint )
}

template <bool BUILD>
void ScalarQuantizerBinary_T<BUILD>::Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized )
void ScalarQuantizerBinary_T<BUILD>::Encode ( uint32_t uRowID, const Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & dQuantizedForQuery )
{
assert(m_bFinalized);

m_pQuantizer->Quantize4Bit ( dPoint, m_tSettings.m_dCentroid, BUILD ? m_dQuantizedForQuery : dQuantized );
m_pQuantizer->Quantize4Bit ( dPoint, m_tSettings.m_dCentroid, BUILD ? dQuantizedForQuery : dQuantized );
if constexpr ( !BUILD )
return;

int64_t iOffset = (int64_t)uRowID * m_dQuantizedForQuery.size();
memcpy ( m_tBuffer4Bit.data() + iOffset, m_dQuantizedForQuery.data(), m_dQuantizedForQuery.size() );
assert ( dQuantizedForQuery.size() == m_uQuantized4BitEntrySize );
int64_t iOffset = (int64_t)uRowID * dQuantizedForQuery.size();
memcpy ( m_tBuffer4Bit.data() + iOffset, dQuantizedForQuery.data(), dQuantizedForQuery.size() );

m_pQuantizer->Quantize1Bit ( dPoint, m_tSettings.m_dCentroid, dQuantized );
}
Expand Down Expand Up @@ -766,20 +768,30 @@ bool ScalarQuantizerBinary_T<BUILD>::FinalizeTraining ( std::string & sError )
if ( m_bFinalized )
return true;

m_bFinalized = true;

if ( !m_uTrainedVecs )
{
m_bFinalized = true;
return true;
}

for ( auto & i : m_dCentroid64 )
m_tSettings.m_dCentroid.push_back ( i/m_uTrainedVecs );
if ( m_tSettings.m_dCentroid.empty() )
{
m_tSettings.m_dCentroid.reserve ( m_dCentroid64.size() );
for ( auto & i : m_dCentroid64 )
m_tSettings.m_dCentroid.push_back ( i/m_uTrainedVecs );
}

m_pQuantizer = std::make_unique<BinaryQuantizer_c> ( m_uDim, m_eSimilarity );
if ( !m_pQuantizer )
m_pQuantizer = std::make_unique<BinaryQuantizer_c> ( m_uDim, m_eSimilarity );

// quantize a fake vector to get quantized size
std::vector<float> dTmp ( m_uDim, 0.0f );
m_pQuantizer->Quantize4Bit ( dTmp, m_tSettings.m_dCentroid, m_dQuantizedForQuery );
m_uQuantized4BitEntrySize = m_dQuantizedForQuery.size();
if ( !m_uQuantized4BitEntrySize )
{
// quantize a fake vector to get quantized size
std::vector<float> dTmp ( m_uDim, 0.0f );
std::vector<uint8_t> dSizeProbe;
m_pQuantizer->Quantize4Bit ( dTmp, m_tSettings.m_dCentroid, dSizeProbe );
m_uQuantized4BitEntrySize = dSizeProbe.size();
}

FILE * pFile = fopen ( m_sTmpFilename.c_str(), "wb" );
if ( !pFile )
Expand All @@ -793,7 +805,11 @@ bool ScalarQuantizerBinary_T<BUILD>::FinalizeTraining ( std::string & sError )
fwrite ( "", 1, 1, pFile );
fclose ( pFile );

return m_tBuffer4Bit.Open ( m_sTmpFilename.c_str(), true, sError );
if ( !m_tBuffer4Bit.Open ( m_sTmpFilename.c_str(), true, sError ) )
return false;

m_bFinalized = true;
return true;
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
3 changes: 2 additions & 1 deletion knn/quantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ class ScalarQuantizer_i

virtual void Train ( const util::Span_T<float> & dPoint ) = 0;
virtual bool FinalizeTraining ( std::string & sError ) = 0;
virtual void Encode ( uint32_t uRowID, const util::Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized ) = 0;
virtual bool IsFinalized () const = 0;
virtual void Encode ( uint32_t uRowID, const util::Span_T<float> & dPoint, std::vector<uint8_t> & dQuantized, std::vector<uint8_t> & dQuantizedForQuery ) = 0;
virtual void FinalizeEncoding() = 0;
virtual const QuantizationSettings_t & GetSettings() = 0;

Expand Down
Loading