Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion dotnet/src/VectorData/AzureAISearch/AzureAISearchCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ public override IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord, bool
VectorSearch = new(),
Size = top,
Skip = options.Skip,
Filter = new AzureAISearchFilterTranslator().Translate(filter, this._model),
Filter = new AzureAISearchFilterTranslator().Translate(filter, this._model)
};

// Filter out vector fields if requested.
Expand Down Expand Up @@ -405,6 +405,15 @@ floatVector is null

await foreach (var record in this.SearchAndMapToDataModelAsync(null, searchOptions, options.IncludeVectors, cancellationToken).ConfigureAwait(false))
{
// Azure AI Search threshold filtering is in preview:
// https://learn.microsoft.com/azure/search/vector-search-how-to-query#set-thresholds-to-exclude-low-scoring-results-preview
// See https://github.com/microsoft/semantic-kernel/issues/13500.
// For now, perform post-filtering on the client-side.
if (options.ScoreThreshold.HasValue && record.Score < options.ScoreThreshold.Value)
{
continue;
}

yield return record;
}
}
Expand Down Expand Up @@ -450,6 +459,12 @@ floatVector is null

await foreach (var record in this.SearchAndMapToDataModelAsync(keywordsCombined, searchOptions, options.IncludeVectors, cancellationToken).ConfigureAwait(false))
{
// Azure AI Search returns scores where higher values indicate more relevant results.
if (options.ScoreThreshold.HasValue && record.Score < options.ScoreThreshold.Value)
{
continue;
}

yield return record;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,13 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
ScorePropertyName,
DocumentPropertyName);

BsonDocument[] pipeline = [searchQuery, projectionQuery];
List<BsonDocument> pipeline = [searchQuery, projectionQuery];

// Add score threshold filter as a $match stage if specified
if (options.ScoreThreshold.HasValue)
{
pipeline.Add(CosmosMongoCollectionSearchMapping.GetScoreThresholdMatchQuery(ScorePropertyName, options.ScoreThreshold.Value));
}

const string OperationName = "Aggregate";
var cursor = await this.RunOperationAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,20 @@ public static BsonDocument GetProjectionQuery(string scorePropertyName, string d
}
};
}

/// <summary>Returns a $match stage to filter results by score threshold.</summary>
/// <remarks>
/// Cosmos MongoDB returns a similarity score where higher values mean more similar,
/// so we filter with $gte to keep results at or above the threshold.
/// </remarks>
public static BsonDocument GetScoreThresholdMatchQuery(string scorePropertyName, double scoreThreshold)
=> new()
{
{
"$match", new BsonDocument
{
{ scorePropertyName, new BsonDocument { { "$gte", scoreThreshold } } }
}
}
};
}
4 changes: 4 additions & 0 deletions dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,12 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
null,
this._model,
vectorProperty.StorageName,
vectorProperty.DistanceFunction,
null,
ScorePropertyName,
options.OldFilter,
options.Filter,
options.ScoreThreshold,
top,
options.Skip,
options.IncludeVectors);
Expand Down Expand Up @@ -630,10 +632,12 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn
keywords,
this._model,
vectorProperty.StorageName,
vectorProperty.DistanceFunction,
textProperty.StorageName,
ScorePropertyName,
options.OldFilter,
options.Filter,
options.ScoreThreshold,
top,
options.Skip,
options.IncludeVectors);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
ICollection<string>? keywords,
CollectionModel model,
string vectorPropertyName,
string? distanceFunction,
string? textPropertyName,
string scorePropertyName,
#pragma warning disable CS0618 // Type or member is obsolete
VectorSearchFilter? oldFilter,
#pragma warning restore CS0618 // Type or member is obsolete
Expression<Func<TRecord, bool>>? filter,
double? scoreThreshold,
int top,
int skip,
bool includeVectors)
Expand Down Expand Up @@ -68,7 +70,7 @@ public static QueryDefinition BuildSearchQuery<TRecord>(

#pragma warning disable CS0618 // VectorSearchFilter is obsolete
// Build filter object.
var (whereClause, filterParameters) = (OldFilter: oldFilter, Filter: filter) switch
var (filterClause, filterParameters) = (OldFilter: oldFilter, Filter: filter) switch
{
{ OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"),
{ OldFilter: VectorSearchFilter legacyFilter } => BuildSearchFilter(legacyFilter, model),
Expand All @@ -82,6 +84,24 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
[VectorVariableName] = vector
};

// Add score threshold filter if specified.
// For similarity functions (CosineSimilarity, DotProductSimilarity), higher scores are better, so filter with >=.
// For distance functions (EuclideanDistance), lower scores are better, so filter with <=.
const string ScoreThresholdVariableName = "@scoreThreshold";
string? scoreThresholdClause = null;
if (scoreThreshold.HasValue)
{
var comparisonOperator = distanceFunction switch
{
Microsoft.Extensions.VectorData.DistanceFunction.CosineSimilarity => ">=",
Microsoft.Extensions.VectorData.DistanceFunction.DotProductSimilarity => ">=",
Microsoft.Extensions.VectorData.DistanceFunction.EuclideanDistance => "<=",
_ => throw new NotSupportedException($"Score threshold is not supported for distance function '{distanceFunction}'.")
};
scoreThresholdClause = $"{vectorDistanceArgument} {comparisonOperator} {ScoreThresholdVariableName}";
queryParameters[ScoreThresholdVariableName] = scoreThreshold.Value;
}

// If Offset is not configured, use Top parameter instead of Limit/Offset
// since it's more optimized. Hybrid search doesn't allow top to be passed as a parameter
// so directly add it to the query here.
Expand All @@ -92,9 +112,25 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
builder.AppendLine($"SELECT {topArgument}{selectClauseArguments}");
builder.AppendLine($"FROM {tableVariableName}");

if (whereClause is not null)
if (filterClause is not null || scoreThresholdClause is not null)
{
builder.Append("WHERE ").AppendLine(whereClause);
builder.Append("WHERE ");

if (filterClause is not null)
{
builder.Append(filterClause);
if (scoreThresholdClause is not null)
{
builder.Append(AndConditionDelimiter);
}
}

if (scoreThresholdClause is not null)
{
builder.Append(scoreThresholdClause);
}

builder.AppendLine();
}

builder.AppendLine($"ORDER BY {rankingArgument}");
Expand Down
8 changes: 8 additions & 0 deletions dotnet/src/VectorData/InMemory/InMemoryCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
// Get the non-null results since any record with a null vector results in a null result.
var nonNullResults = results.Where(x => x.HasValue).Select(x => x!.Value);

// Filter by score threshold if specified.
if (options.ScoreThreshold is double scoreThreshold)
{
nonNullResults = InMemoryCollectionSearchMapping.ShouldSortDescending(vectorProperty.DistanceFunction)
? nonNullResults.Where(x => x.score >= scoreThreshold)
: nonNullResults.Where(x => x.score <= scoreThreshold);
}

// Sort the results appropriately for the selected distance function and get the right page of results .
var sortedScoredResults = InMemoryCollectionSearchMapping.ShouldSortDescending(vectorProperty.DistanceFunction) ?
nonNullResults.OrderByDescending(x => x.score) :
Expand Down
18 changes: 15 additions & 3 deletions dotnet/src/VectorData/MongoDB/MongoCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,13 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
ScorePropertyName,
DocumentPropertyName);

BsonDocument[] pipeline = [searchQuery, projectionQuery];
List<BsonDocument> pipeline = [searchQuery, projectionQuery];

// Add score threshold filter as a $match stage if specified
if (options.ScoreThreshold.HasValue)
{
pipeline.Add(MongoCollectionSearchMapping.GetScoreThresholdMatchQuery(ScorePropertyName, options.ScoreThreshold.Value));
}

const string OperationName = "Aggregate";
using var cursor = await this.RunOperationWithRetryAsync(
Expand Down Expand Up @@ -536,7 +542,7 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn

var numCandidates = this._numCandidates ?? itemsAmount * MongoConstants.DefaultNumCandidatesRatio;

BsonDocument[] pipeline = MongoCollectionSearchMapping.GetHybridSearchPipeline(
List<BsonDocument> pipeline = [.. MongoCollectionSearchMapping.GetHybridSearchPipeline(
vectorArray,
keywords,
this.Name,
Expand All @@ -548,7 +554,13 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn
DocumentPropertyName,
itemsAmount,
numCandidates,
filter);
filter)];

// Add score threshold filter as a $match stage if specified
if (options.ScoreThreshold.HasValue)
{
pipeline.Add(MongoCollectionSearchMapping.GetScoreThresholdMatchQuery(ScorePropertyName, options.ScoreThreshold.Value));
}

var results = await this.RunOperationWithRetryAsync(
"KeywordVectorizedHybridSearch",
Expand Down
17 changes: 17 additions & 0 deletions dotnet/src/VectorData/MongoDB/MongoCollectionSearchMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public static BsonDocument GetSearchQuery<TVector>(
int numCandidates,
BsonDocument? filter)
{
// Docs: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage
var searchQuery = new BsonDocument
{
{ "index", indexName },
Expand Down Expand Up @@ -127,6 +128,22 @@ public static BsonDocument GetProjectionQuery(string scorePropertyName, string d
};
}

/// <summary>Returns a $match stage to filter results by score threshold.</summary>
/// <remarks>
/// MongoDB Atlas Vector Search returns a similarity score where higher values mean more similar,
/// so we filter with $gte to keep results at or above the threshold.
/// </remarks>
public static BsonDocument GetScoreThresholdMatchQuery(string scorePropertyName, double scoreThreshold)
=> new()
{
{
"$match", new BsonDocument
{
{ scorePropertyName, new BsonDocument { { "$gte", scoreThreshold } } }
}
}
};

/// <summary>Returns a pipeline for hybrid search using vector search and full text search.</summary>
public static BsonDocument[] GetHybridSearchPipeline<TVector>(
TVector vector,
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/VectorData/PgVector/PostgresCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, BinaryEm
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
options.OldFilter,
#pragma warning restore CS0618 // VectorSearchFilter is obsolete
options.Filter, options.Skip, options.IncludeVectors, top);
options.Filter, options.Skip, options.IncludeVectors, top, options.ScoreThreshold);

using var reader = await connection.ExecuteWithErrorHandlingAsync(
this._collectionMetadata,
Expand Down
35 changes: 34 additions & 1 deletion dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@ internal static StringBuilder AppendIdentifier(this StringBuilder sb, string ide
/// <inheritdoc />
internal static void BuildGetNearestMatchCommand<TRecord>(
NpgsqlCommand command, string schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, object vectorValue,
VectorSearchFilter? legacyFilter, Expression<Func<TRecord, bool>>? newFilter, int? skip, bool includeVectors, int limit)
VectorSearchFilter? legacyFilter, Expression<Func<TRecord, bool>>? newFilter, int? skip, bool includeVectors, int limit,
double? scoreThreshold = null)
{
// Build column list with proper escaping
StringBuilder columns = new();
Expand Down Expand Up @@ -501,6 +502,33 @@ internal static void BuildGetNearestMatchCommand<TRecord>(
commandText = outerSql.ToString();
}

// Apply score threshold filter if specified.
// For similarity functions (higher = more similar), filter out results below the threshold.
// For distance functions (lower = more similar), filter out results above the threshold.
if (scoreThreshold.HasValue)
{
var scoreThresholdParamIndex = parameters.Count + 2;
var comparisonOp = distanceFunction switch
{
DistanceFunction.CosineSimilarity or DistanceFunction.DotProductSimilarity
=> ">=",

DistanceFunction.EuclideanDistance
or DistanceFunction.CosineDistance
or DistanceFunction.ManhattanDistance
or DistanceFunction.HammingDistance
=> "<=",

_ => throw new UnreachableException($"Unexpected distance function: {distanceFunction}")
};

StringBuilder outerSql = new();
outerSql.Append("SELECT * FROM (").Append(commandText).Append(") AS scored WHERE ")
.AppendIdentifier(PostgresConstants.DistanceColumnName).Append(' ').Append(comparisonOp)
.Append(" $").Append(scoreThresholdParamIndex);
commandText = outerSql.ToString();
}

command.CommandText = commandText;

Debug.Assert(command.Parameters.Count == 0);
Expand All @@ -510,6 +538,11 @@ internal static void BuildGetNearestMatchCommand<TRecord>(
{
command.Parameters.Add(new NpgsqlParameter { Value = parameter });
}

if (scoreThreshold.HasValue)
{
command.Parameters.Add(new NpgsqlParameter { Value = scoreThreshold.Value });
}
}

internal static void BuildSelectWhereCommand<TRecord>(
Expand Down
7 changes: 7 additions & 0 deletions dotnet/src/VectorData/Pinecone/PineconeCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
Verify.NotLessThan(top, 1);

options ??= s_defaultVectorSearchOptions;

if (options.IncludeVectors && this._model.EmbeddingGenerationRequired)
{
throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
Expand Down Expand Up @@ -500,6 +501,12 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin

foreach (var record in records)
{
// Pinecone returns similarity scores where higher values indicate more similar results.
if (options.ScoreThreshold.HasValue && record.Score < options.ScoreThreshold.Value)
{
continue;
}

yield return record;
}
}
Expand Down
4 changes: 3 additions & 1 deletion dotnet/src/VectorData/Qdrant/QdrantCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
query: query,
usingVector: this._hasNamedVectors ? vectorProperty.StorageName : null,
filter: filter,
scoreThreshold: (float?)options.ScoreThreshold,
limit: (ulong)top,
offset: (ulong)options.Skip,
vectorsSelector: vectorsSelector,
Expand Down Expand Up @@ -740,8 +741,9 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn
"Query",
() => this._qdrantClient.QueryAsync(
this.Name,
prefetch: new List<PrefetchQuery>() { vectorQuery, keywordQuery },
prefetch: [vectorQuery, keywordQuery],
query: fusionQuery,
scoreThreshold: (float?)options.ScoreThreshold,
limit: (ulong)top,
offset: (ulong)options.Skip,
vectorsSelector: vectorsSelector,
Expand Down
Loading
Loading