Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion Kerberos.NET/Client/Transport/ClientDomainService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Net.NetworkInformation;
using System.Threading;
using System.Threading.Tasks;
using Kerberos.NET.Configuration;
using Kerberos.NET.Dns;
Expand All @@ -13,6 +15,8 @@ namespace Kerberos.NET.Transport
{
public class ClientDomainService
{
private static readonly Random Random = new();

public ClientDomainService(ILoggerFactory logger)
{
this.logger = logger.CreateLoggerSafe<ClientDomainService>();
Expand Down Expand Up @@ -47,6 +51,13 @@ static ClientDomainService()

public Krb5Config Configuration { get; set; }

public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(2);

public TimeSpan SendTimeout { get; set; } = TimeSpan.FromSeconds(10);

public TimeSpan ReceiveTimeout { get; set; } = TimeSpan.FromSeconds(10);


public void ResetConnections()
{
DomainCache.Clear();
Expand All @@ -59,7 +70,37 @@ public virtual async Task<IEnumerable<DnsRecord>> LocateKdc(string domain, strin
{
var results = await this.Query(domain, servicePrefix, DefaultKerberosPort);

return ParseQuerySrvReply(results);
results = ParseQuerySrvReply(results);

return await WeightResults(results);
}

private async Task<IEnumerable<DnsRecord>> WeightResults(IEnumerable<DnsRecord> results)
{
SortedList<int, DnsRecord> fastest = new();

if (this.Configuration.Defaults.PrioritizeKdcByPing)
{
try
{
using var cts = new CancellationTokenSource(this.ConnectTimeout);

fastest = await results.GetFastestAsync(PingAsync, cts.Token);
}
catch (Exception ex)
{
this.logger.LogWarning(ex, "Ping failed for all found services");
}
}

foreach (var r in results)
{
var speed = fastest.FirstOrDefault(f => string.Equals(f.Value.Target, r.Target, StringComparison.OrdinalIgnoreCase));

r.PingResponseTime = speed.Value != null ? speed.Key : Random.Next(fastest.Count, int.MaxValue);
}

return results;
}

public virtual async Task<IEnumerable<DnsRecord>> LocateKpasswd(string domain, string servicePrefix)
Expand Down Expand Up @@ -153,6 +194,43 @@ protected virtual async Task<IEnumerable<DnsRecord>> Query(string domain, string
return records;
}

protected virtual async Task<DnsRecord> PingAsync(DnsRecord record, CancellationToken cancellationToken)
{
using var ping = new Ping();

cancellationToken.Register(() => ping.SendAsyncCancel());

var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(this.ConnectTimeout.TotalMilliseconds));

return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}");
}

private class DnsRecordComparer : IEqualityComparer<DnsRecord>
{
public static readonly DnsRecordComparer Instance = new();

private DnsRecordComparer()
{
}

public bool Equals(DnsRecord x, DnsRecord y)
{
if (ReferenceEquals(x, y)) return true;
if (x is null) return false;
if (y is null) return false;
if (x.GetType() != y.GetType()) return false;
return x.Target == y.Target && x.Port == y.Port;
}

public int GetHashCode(DnsRecord obj)
{
unchecked
{
return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port;
}
}
}

private async Task QueryDns(string domain, string servicePrefix, List<DnsRecord> records)
{
var lookup = Invariant($"{servicePrefix}.{domain}");
Expand Down
2 changes: 0 additions & 2 deletions Kerberos.NET/Client/Transport/HttpsKerberosTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ namespace Kerberos.NET.Transport
{
public class HttpsKerberosTransport : KerberosTransportBase
{
private static readonly Random Random = new Random();

private readonly ILogger logger;

public HttpsKerberosTransport(ILoggerFactory logger = null)
Expand Down
76 changes: 26 additions & 50 deletions Kerberos.NET/Client/Transport/KerberosTransportBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.NetworkInformation;
using System.Threading;
using System.Threading.Tasks;
using Kerberos.NET.Asn1;
Expand All @@ -20,26 +19,41 @@ namespace Kerberos.NET.Transport
{
public abstract class KerberosTransportBase : IKerberosTransport2, IDisposable
{
protected static readonly Random Random = new();

private bool disposedValue;

protected KerberosTransportBase(ILoggerFactory logger)
{
this.ClientRealmService = new ClientDomainService(logger);
this.Logger = logger.CreateLoggerSafe<KerberosTransportBase>();
}

private bool disposedValue;

private DnsRecord fastest;
protected ILogger Logger { get; }

public virtual bool TransportFailed { get; set; }

public virtual KerberosTransportException LastError { get; set; }

public bool Enabled { get; set; }

public TimeSpan ConnectTimeout { get; set; } = TimeSpan.FromSeconds(2);
public TimeSpan ConnectTimeout
{
get => this.ClientRealmService.ConnectTimeout;
set => this.ClientRealmService.ConnectTimeout = value;
}

public TimeSpan SendTimeout { get; set; } = TimeSpan.FromSeconds(10);
public TimeSpan SendTimeout
{
get => this.ClientRealmService.SendTimeout;
set => this.ClientRealmService.SendTimeout = value;
}

public TimeSpan ReceiveTimeout { get; set; } = TimeSpan.FromSeconds(10);
public TimeSpan ReceiveTimeout
{
get => this.ClientRealmService.ReceiveTimeout;
set => this.ClientRealmService.ReceiveTimeout = value;
}

public int MaximumAttempts { get; set; } = 30;

Expand Down Expand Up @@ -166,58 +180,20 @@ public void Dispose()
protected virtual async Task<DnsRecord> LocatePreferredKdc(string domain, string servicePrefix)
{
var results = await this.LocateKdc(domain, servicePrefix);
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
}

protected virtual async Task<DnsRecord> LocatePreferredKpasswd(string domain, string servicePrefix)
{
var results = await this.LocateKpasswd(domain, servicePrefix);
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
}

protected virtual async Task<DnsRecord> SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
{
if (results.Contains(fastest, DnsRecordComparer.Instance))
{
return fastest;
}

fastest = await results.Where(r => r.Name.StartsWith(servicePrefix)).GetFastestAsync(PingAsync);
return fastest ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
}

private async Task<DnsRecord> PingAsync(DnsRecord record, CancellationToken cancellationToken)
{
using var ping = new Ping();
cancellationToken.Register(() => ping.SendAsyncCancel());
var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(ConnectTimeout.TotalMilliseconds));
return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}");
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
}

private class DnsRecordComparer : IEqualityComparer<DnsRecord>
protected virtual DnsRecord SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
{
public static readonly DnsRecordComparer Instance = new();

private DnsRecordComparer()
{
}
results = results.Where(r => r.Name.StartsWith(servicePrefix)).OrderBy(r => r.PingResponseTime);

public bool Equals(DnsRecord x, DnsRecord y)
{
if (ReferenceEquals(x, y)) return true;
if (x is null) return false;
if (y is null) return false;
if (x.GetType() != y.GetType()) return false;
return x.Target == y.Target && x.Port == y.Port;
}

public int GetHashCode(DnsRecord obj)
{
unchecked
{
return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port;
}
}
return results.FirstOrDefault() ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
}
}
}
7 changes: 7 additions & 0 deletions Kerberos.NET/Configuration/Krb5ConfigDefaults.cs
Original file line number Diff line number Diff line change
Expand Up @@ -346,5 +346,12 @@ public class Krb5ConfigDefaults : Krb5ConfigObject
[DefaultValue(PrincipalNameType.NT_ENTERPRISE)]
[DisplayName("default_name_type")]
public PrincipalNameType DefaultNameType { get; set; }

/// <summary>
/// Indicates whether the client should try to find and sort KDCs by how long it takes for them to respond by ping.
/// </summary>
[DefaultValue(true)]
[DisplayName("prioritize_by_response_time")]
public bool PrioritizeKdcByPing { get; set; }
}
}
2 changes: 2 additions & 0 deletions Kerberos.NET/Dns/DnsRecord.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,7 @@ public string Address
return this.Target;
}
}

public int PingResponseTime { get; set; } = int.MaxValue;
}
}
29 changes: 24 additions & 5 deletions Kerberos.NET/TaskExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// -----------------------------------------------------------------------
// -----------------------------------------------------------------------
// Licensed to The .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// -----------------------------------------------------------------------
Expand All @@ -11,31 +11,50 @@

internal static class TaskExtensions
{
public static async Task<TResult> GetFastestAsync<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, CancellationToken, Task<TResult>> task, CancellationToken cancellationToken = default)
public static async Task<SortedList<int, TResult>> GetFastestAsync<TSource, TResult>(
this IEnumerable<TSource> source,
Func<TSource, CancellationToken, Task<TResult>> task,
CancellationToken cancellationToken = default
)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var tasks = new HashSet<Task<TResult>>(source.Select(e => task(e, cts.Token)));

if (tasks.Count == 0)
{
return default;
return new();
}

int next = 0;
SortedList<int, TResult> results = new();

var exceptions = new List<Exception>();

do
{
var completedTask = await Task.WhenAny(tasks);

if (completedTask.Status == TaskStatus.RanToCompletion)
{
cts.Cancel();
return completedTask.Result;

results.Add(++next, completedTask.Result);
}

if (completedTask.Exception != null)
{
exceptions.AddRange(completedTask.Exception.InnerExceptions);
}

tasks.Remove(completedTask);
} while (tasks.Count > 0);

}
while (tasks.Count > 0);

if (results.Count > 0)
{
return results;
}

throw new AggregateException(exceptions);
}
Expand Down
Loading
Loading