Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dotnet/src/VectorData/PgVector/PostgresCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT

if (!extensionAlreadyExisted)
{
await connection.ReloadTypesAsync().ConfigureAwait(false);
await PostgresUtils.ReloadTypesAsyncCompat(connection, cancellationToken).ConfigureAwait(false);
}

batch.BatchCommands.Clear();
Expand Down
37 changes: 37 additions & 0 deletions dotnet/src/VectorData/PgVector/PostgresUtils.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.VectorData;
using Npgsql;
Expand All @@ -10,6 +15,8 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector;

internal static class PostgresUtils
{
private static readonly ConcurrentDictionary<Type, MethodInfo?> s_reloadTypesAsyncWithTokenMethods = new();

/// <summary>
/// Wraps an <see cref="IAsyncEnumerable{T}"/> in an <see cref="IAsyncEnumerable{T}"/> that will throw a <see cref="VectorStoreException"/>
/// if an exception is thrown while iterating over the original enumerator.
Expand Down Expand Up @@ -72,4 +79,34 @@ internal static NpgsqlDataSource CreateDataSource(string connectionString)
sourceBuilder.UseVector();
return sourceBuilder.Build();
}

internal static Task ReloadTypesAsyncCompat(object connection, CancellationToken cancellationToken = default)
{
Verify.NotNull(connection);

MethodInfo? reloadTypesWithToken = GetReloadTypesAsyncMethod(connection.GetType());
if (reloadTypesWithToken is null)
{
throw new MissingMethodException("No compatible ReloadTypesAsync overload found.");
}

object?[] parameters = reloadTypesWithToken.GetParameters().Length == 0 ? [] : [cancellationToken];
return (Task)reloadTypesWithToken.Invoke(connection, parameters)!;
}

private static MethodInfo? GetReloadTypesAsyncMethod(Type connectionType)
=> s_reloadTypesAsyncWithTokenMethods.GetOrAdd(connectionType, static type => CreateReloadTypesAsyncMethod(type));

[UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Best-effort runtime compatibility probe for optional API shape.")]
private static MethodInfo? CreateReloadTypesAsyncMethod(Type connectionType)
{
MethodInfo? method = connectionType.GetMethod("ReloadTypesAsync", [typeof(CancellationToken)]);
if (method is not null && method.ReturnType == typeof(Task))
{
return method;
}

method = connectionType.GetMethod("ReloadTypesAsync", Type.EmptyTypes);
return method is not null && method.ReturnType == typeof(Task) ? method : null;
}
}
Loading