Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 63 additions & 22 deletions src/Abstractions/TaskFailureDetails.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ namespace Microsoft.DurableTask;
/// <param name="Properties">Additional properties associated with the exception.</param>
public record TaskFailureDetails(string ErrorType, string ErrorMessage, string? StackTrace, TaskFailureDetails? InnerFailure, IDictionary<string, object?>? Properties)
{
static readonly char[] TypeNameBoundaryChars = new[] { ',', '[' };

Type? loadedExceptionType;

/// <summary>
Expand All @@ -33,9 +35,22 @@ public override string ToString()
/// Returns <c>true</c> if the task failure was provided by the specified exception type.
/// </summary>
/// <remarks>
/// <para>
/// This method allows checking if a task failed due to an exception of a specific type by attempting
/// to load the type specified in <see cref="ErrorType"/>. If the exception type cannot be loaded
/// for any reason, this method will return <c>false</c>. Base types are supported.
/// to resolve the type specified in <see cref="ErrorType"/> against already-loaded assemblies. If the
/// exception type cannot be resolved for any reason, this method will return <c>false</c>. Base types
/// are supported.
/// </para>
/// <para>
/// <see cref="ErrorType"/> is rehydrated from the durable task hub backing store and may have been
/// written by an actor with storage write access. To prevent a malicious <see cref="ErrorType"/>
/// value from triggering an arbitrary assembly load (which would execute module initializers and
/// static constructors of the loaded assembly), this method intentionally never calls
/// <see cref="Type.GetType(string)"/> on <see cref="ErrorType"/> and only resolves names against
/// assemblies that are already loaded into the current <see cref="AppDomain"/>. Any assembly
/// qualifier or generic-argument payload embedded in <see cref="ErrorType"/> is stripped before
/// resolution; consequently, generic exception types are not recognized by this method.
/// </para>
/// </remarks>
/// <typeparam name="T">The type of exception to test against.</typeparam>
/// <exception cref="AmbiguousMatchException">If multiple exception types with the same name are found.</exception>
Expand Down Expand Up @@ -66,36 +81,52 @@ public bool IsCausedBy(Type targetBaseExceptionType)
return false;
}

// This check works for .NET exception types defined in System.Core.PrivateLib (aka mscorelib.dll)
this.loadedExceptionType ??= Type.GetType(this.ErrorType, throwOnError: false);

// For exception types defined in the same assembly as the target exception type.
this.loadedExceptionType ??= targetBaseExceptionType.Assembly.GetType(this.ErrorType, throwOnError: false);

// For custom exception types defined in the app's assembly
this.loadedExceptionType ??= Assembly.GetCallingAssembly().GetType(this.ErrorType);

if (this.loadedExceptionType is null)
{
// This last check works for exception types defined in any loaded assembly (e.g. NuGet packages, etc.).
// This is a fallback that should rarely be needed except in obscure cases.
List<Type> matchingExceptionTypes = AppDomain.CurrentDomain.GetAssemblies()
.Select(a => a.GetType(this.ErrorType, throwOnError: false))
.Where(t => t is not null)
.ToList();
if (matchingExceptionTypes.Count == 1)
// Defense-in-depth: ErrorType is rehydrated from durable storage and may be
// attacker-influenced. Strip any assembly-qualified suffix and any generic-argument
// payload so that the value we pass to Assembly.GetType is a plain dotted type name
// - which cannot trigger Assembly.Load on an attacker-supplied assembly. Resolution
// is performed only against assemblies that are already loaded.
string safeTypeName = ExtractPlainTypeName(this.ErrorType);
if (safeTypeName.Length == 0)
{
this.loadedExceptionType = matchingExceptionTypes[0];
return false;
}
else if (matchingExceptionTypes.Count > 1)

// For exception types defined in the same assembly as the target exception type
// (covers System.Private.CoreLib when the caller queries a built-in exception type).
Type? resolved = targetBaseExceptionType.Assembly.GetType(safeTypeName, throwOnError: false);

// For custom exception types defined in the app's assembly.
resolved ??= Assembly.GetCallingAssembly().GetType(safeTypeName, throwOnError: false);

if (resolved is null)
{
throw new AmbiguousMatchException($"Multiple exception types with the name '{this.ErrorType}' were found.");
// Fallback: scan all already-loaded assemblies (covers third-party NuGets, etc.).
// Assembly.GetType against a plain (unqualified) type name does not trigger
// Assembly.Load, so this scan does not create a code-execution primitive.
List<Type> matches = AppDomain.CurrentDomain.GetAssemblies()
.Select(a => a.GetType(safeTypeName, throwOnError: false))
.Where(t => t is not null)
.Distinct()
.ToList()!;
Comment on lines +109 to +113
if (matches.Count == 1)
{
resolved = matches[0];
}
else if (matches.Count > 1)
{
throw new AmbiguousMatchException($"Multiple exception types with the name '{safeTypeName}' were found.");
}
}

this.loadedExceptionType = resolved;
}

if (this.loadedExceptionType is null)
{
// The actual exception type could not be loaded, so we cannot determine if it matches the target type.
// The actual exception type could not be resolved, so we cannot determine if it matches the target type.
return false;
}

Expand Down Expand Up @@ -190,4 +221,14 @@ internal CoreFailureDetails ToCoreFailureDetails()
FromCoreFailureDetailsRecursive(coreFailureDetails.InnerFailure),
coreFailureDetails.Properties);
}

// Returns the leading plain type name from an ErrorType string, stripping any assembly
// qualifier and any generic-argument payload. This guarantees the returned value, when
// passed to Assembly.GetType, cannot trigger Assembly.Load.
static string ExtractPlainTypeName(string errorType)
{
int boundary = errorType.IndexOfAny(TypeNameBoundaryChars);
string trimmed = boundary < 0 ? errorType : errorType.Substring(0, boundary);
return trimmed.Trim();
}
}
233 changes: 233 additions & 0 deletions test/Abstractions.Tests/TaskFailureDetailsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DurableTask.Tests;

public class TaskFailureDetailsTests
{
[Fact]
public void IsCausedBy_NullTargetType_Throws()
{
TaskFailureDetails details = Create(typeof(InvalidOperationException).FullName!);

Action act = () => details.IsCausedBy(null!);

act.Should().Throw<ArgumentNullException>();
}

[Fact]
public void IsCausedBy_NonExceptionTargetType_Throws()
{
TaskFailureDetails details = Create(typeof(InvalidOperationException).FullName!);

Action act = () => details.IsCausedBy(typeof(string));

act.Should().Throw<ArgumentException>();
}

[Theory]
[InlineData(null)]
[InlineData("")]
public void IsCausedBy_NullOrEmptyErrorType_ReturnsFalse(string? errorType)
{
TaskFailureDetails details = Create(errorType!);

details.IsCausedBy<Exception>().Should().BeFalse();
}

[Fact]
public void IsCausedBy_ExactSystemExceptionMatch_ReturnsTrue()
{
TaskFailureDetails details = Create(typeof(InvalidOperationException).FullName!);

details.IsCausedBy<InvalidOperationException>().Should().BeTrue();
}

[Fact]
public void IsCausedBy_StoredDerived_TargetBase_ReturnsTrue()
{
// The documented and tested semantic: a stored derived exception type
// satisfies IsCausedBy against any of its base exception types.
TaskFailureDetails details = Create(typeof(InvalidOperationException).FullName!);

details.IsCausedBy<Exception>().Should().BeTrue();
details.IsCausedBy<SystemException>().Should().BeTrue();
}

[Fact]
public void IsCausedBy_StoredBase_TargetDerived_ReturnsFalse()
{
TaskFailureDetails details = Create(typeof(Exception).FullName!);

details.IsCausedBy<InvalidOperationException>().Should().BeFalse();
}

[Fact]
public void IsCausedBy_UnrelatedExceptionType_ReturnsFalse()
{
TaskFailureDetails details = Create(typeof(ArgumentException).FullName!);

details.IsCausedBy<InvalidOperationException>().Should().BeFalse();
}

[Fact]
public void IsCausedBy_UnresolvableErrorType_ReturnsFalse()
{
TaskFailureDetails details = Create("Some.Bogus.Type.That.Does.Not.Exist");

details.IsCausedBy<Exception>().Should().BeFalse();
}

[Fact]
public void IsCausedBy_CustomExceptionInCallingAssembly_Resolves()
{
// The custom exception is defined in this (calling) test assembly.
// The resolution path that picks it up is Assembly.GetCallingAssembly().GetType(...).
TaskFailureDetails details = Create(typeof(TestCustomException).FullName!);

details.IsCausedBy<TestCustomException>().Should().BeTrue();
details.IsCausedBy<Exception>().Should().BeTrue();
}

[Fact]
public void IsCausedBy_AssemblyQualifiedName_Resolves()
{
TaskFailureDetails details = Create(typeof(InvalidOperationException).AssemblyQualifiedName!);

details.IsCausedBy<InvalidOperationException>().Should().BeTrue();
}

[Fact]
public void IsCausedBy_CalledTwice_ReturnsSameResult()
{
// Exercises the internal type cache (loadedExceptionType).
TaskFailureDetails details = Create(typeof(InvalidOperationException).FullName!);

details.IsCausedBy<InvalidOperationException>().Should().BeTrue();
details.IsCausedBy<InvalidOperationException>().Should().BeTrue();
}

[Fact]
public void IsCausedByGeneric_DelegatesToTypedOverload()
{
TaskFailureDetails details = Create(typeof(InvalidOperationException).FullName!);

details.IsCausedBy<InvalidOperationException>().Should().BeTrue();
details.IsCausedBy<ArgumentException>().Should().BeFalse();
}

[Fact]
public void ToString_IncludesErrorTypeAndMessage()
{
TaskFailureDetails details = new(
ErrorType: "My.ErrorType",
ErrorMessage: "boom",
StackTrace: null,
InnerFailure: null,
Properties: null);

details.ToString().Should().Be("My.ErrorType: boom");
}

// --- Defense-in-depth tests: poisoned ErrorType must not trigger Assembly.Load ---

[Fact]
public void IsCausedBy_AssemblyQualifiedNameForNonLoadedAssembly_DoesNotAttemptAssemblyResolution()
{
// Simulates the MSRC PoC payload: an ErrorType naming an assembly that is not loaded.
// Hook AssemblyResolve - this is the same mechanism the MSRC PoC uses to turn
// Assembly.Load("Evil") into a code-execution primitive. The hardened implementation
// must never reach Assembly.Load on the ErrorType string, so this handler must never
// be invoked for the attacker-supplied name.
const string poisonedAssemblyName = "DT_TFD_Test_PoisonedAssembly_DoesNotExist";
List<string> resolveRequests = new();
ResolveEventHandler handler = (_, args) =>
{
resolveRequests.Add(args.Name);
return null;
Comment on lines +143 to +147
};
AppDomain.CurrentDomain.AssemblyResolve += handler;
try
{
TaskFailureDetails details = Create($"Some.Bogus.PwnedException, {poisonedAssemblyName}");

details.IsCausedBy<Exception>().Should().BeFalse();

resolveRequests.Should().NotContain(
n => n.StartsWith(poisonedAssemblyName, StringComparison.Ordinal),
"IsCausedBy must never trigger Assembly.Load on attacker-controlled ErrorType");
}
finally
{
AppDomain.CurrentDomain.AssemblyResolve -= handler;
}
}

[Fact]
public void IsCausedBy_QualifiedNameForKnownType_StripsQualifierAndResolves()
{
// After stripping the assembly qualifier, the plain type name must still resolve via the
// target's assembly (System.Private.CoreLib here) and the inheritance semantic must hold.
TaskFailureDetails details = Create("System.InvalidOperationException, Some.Wrong.AssemblyName");

details.IsCausedBy<InvalidOperationException>().Should().BeTrue();
details.IsCausedBy<Exception>().Should().BeTrue();
}

[Fact]
public void IsCausedBy_GenericTypePayloadWithNestedAssemblyName_DoesNotAttemptAssemblyResolution()
{
// Generic-argument payloads (anything inside '[' ... ']') can carry assembly-qualified
// inner type names that would otherwise trigger Assembly.Load. The hardened
// implementation must strip them - so the AssemblyResolve handler must not see the
// attacker name.
const string poisonedAssemblyName = "DT_TFD_Test_NestedPoisonedAssembly_DoesNotExist";
List<string> resolveRequests = new();
ResolveEventHandler handler = (_, args) =>
{
resolveRequests.Add(args.Name);
return null;
Comment on lines +185 to +189
};
AppDomain.CurrentDomain.AssemblyResolve += handler;
try
{
TaskFailureDetails details = Create(
$"System.Collections.Generic.List`1[[Some.Bogus.PwnedException, {poisonedAssemblyName}]]");

details.IsCausedBy<Exception>().Should().BeFalse();

resolveRequests.Should().NotContain(
n => n.StartsWith(poisonedAssemblyName, StringComparison.Ordinal),
"IsCausedBy must never trigger Assembly.Load on attacker-controlled ErrorType");
}
finally
{
AppDomain.CurrentDomain.AssemblyResolve -= handler;
}
}

[Fact]
public void IsCausedBy_ErrorTypeWithOnlyWhitespaceBeforeComma_ReturnsFalse()
{
// After stripping the qualifier, the remaining type name is empty - must return false
// rather than attempt to resolve "".
TaskFailureDetails details = Create(" , SomeAssembly");

details.IsCausedBy<Exception>().Should().BeFalse();
}

static TaskFailureDetails Create(string errorType) => new(
ErrorType: errorType,
ErrorMessage: "test failure",
StackTrace: null,
InnerFailure: null,
Properties: null);

sealed class TestCustomException : Exception
{
public TestCustomException()
: base("test")
{
}
}
}
Loading