Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/ExpressiveSharp.Generator/Emitter/ExpressionTreeEmitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -753,13 +753,25 @@ private bool TryEmitEnumMethodExpansion(IInvocationOperation invocation, out str
AppendLine($"var {defaultVar} = {Expr}.Default(typeof({returnTypeFqn}));");
}

var receiverTypeFqn = receiverType.ToDisplayString(_fqnFormat);

// Build the ternary chain in reverse so the first member ends up as the outermost (and first-tested) branch.
var currentVar = defaultVar;
foreach (var member in enumMembers.AsEnumerable().Reverse())
{
var enumValueVar = NextVar();
AppendLine($"var {enumValueVar} = {Expr}.Constant({enumTypeFqn}.{member.Name}, typeof({enumTypeFqn}));");

// The MethodInfo is bound on the original receiver type — for an instance method on
// Nullable<TEnum> or an extension whose first param is Nullable<TEnum>, the per-arm
// operand must also be Nullable<TEnum> or Expression.Call rejects the type mismatch.
if (isNullable)
{
var lifted = NextVar();
AppendLine($"var {lifted} = {Expr}.Convert({enumValueVar}, typeof({receiverTypeFqn}));");
enumValueVar = lifted;
}

// Static path passes the enum value as the first arg; instance path uses it as the receiver.
string callVar;
if (originalMethod.IsStatic)
Expand Down Expand Up @@ -800,7 +812,6 @@ private bool TryEmitEnumMethodExpansion(IInvocationOperation invocation, out str
if (isNullable)
{
var nullConst = NextVar();
var receiverTypeFqn = receiverType.ToDisplayString(_fqnFormat);
AppendLine($"var {nullConst} = {Expr}.Constant(null, typeof({receiverTypeFqn}));");

var nullCheck = NextVar();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// <auto-generated/>
#nullable disable

using Foo;

namespace ExpressiveSharp.Generated
{
static partial class Foo_Entity
{
// [Expressive]
// public string TrendLabel => Trend.Describe();
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Entity, string>> TrendLabel_Expression()
{
var p__this = global::System.Linq.Expressions.Expression.Parameter(typeof(global::Foo.Entity), "@this");
var expr_0 = global::System.Linq.Expressions.Expression.Property(p__this, typeof(global::Foo.Entity).GetProperty("Trend", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance)); // Trend
var expr_1 = global::System.Linq.Expressions.Expression.Constant(null, typeof(string));
var expr_2 = global::System.Linq.Expressions.Expression.Constant(global::Foo.Trend.NoChange, typeof(global::Foo.Trend));
var expr_3 = global::System.Linq.Expressions.Expression.Convert(expr_2, typeof(global::Foo.Trend?));
var expr_4 = global::System.Linq.Expressions.Expression.Call(typeof(global::Foo.TrendExtensions).GetMethod("Describe", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static, null, new global::System.Type[] { typeof(global::Foo.Trend?) }, null), new global::System.Linq.Expressions.Expression[] { expr_3 });
var expr_5 = global::System.Linq.Expressions.Expression.Equal(expr_0, expr_3);
var expr_6 = global::System.Linq.Expressions.Expression.Condition(expr_5, expr_4, expr_1, typeof(string));
var expr_7 = global::System.Linq.Expressions.Expression.Constant(global::Foo.Trend.Decrease, typeof(global::Foo.Trend));
var expr_8 = global::System.Linq.Expressions.Expression.Convert(expr_7, typeof(global::Foo.Trend?));
var expr_9 = global::System.Linq.Expressions.Expression.Call(typeof(global::Foo.TrendExtensions).GetMethod("Describe", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static, null, new global::System.Type[] { typeof(global::Foo.Trend?) }, null), new global::System.Linq.Expressions.Expression[] { expr_8 });
var expr_10 = global::System.Linq.Expressions.Expression.Equal(expr_0, expr_8);
var expr_11 = global::System.Linq.Expressions.Expression.Condition(expr_10, expr_9, expr_6, typeof(string));
var expr_12 = global::System.Linq.Expressions.Expression.Constant(global::Foo.Trend.Increase, typeof(global::Foo.Trend));
var expr_13 = global::System.Linq.Expressions.Expression.Convert(expr_12, typeof(global::Foo.Trend?));
var expr_14 = global::System.Linq.Expressions.Expression.Call(typeof(global::Foo.TrendExtensions).GetMethod("Describe", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static, null, new global::System.Type[] { typeof(global::Foo.Trend?) }, null), new global::System.Linq.Expressions.Expression[] { expr_13 });
var expr_15 = global::System.Linq.Expressions.Expression.Equal(expr_0, expr_13);
var expr_16 = global::System.Linq.Expressions.Expression.Condition(expr_15, expr_14, expr_11, typeof(string));
var expr_17 = global::System.Linq.Expressions.Expression.Constant(null, typeof(global::Foo.Trend?));
var expr_18 = global::System.Linq.Expressions.Expression.Equal(expr_0, expr_17);
var expr_19 = global::System.Linq.Expressions.Expression.Condition(expr_18, expr_1, expr_16, typeof(string));
return global::System.Linq.Expressions.Expression.Lambda<global::System.Func<global::Foo.Entity, string>>(expr_19, p__this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// <auto-generated/>
#nullable disable

using Foo;

namespace ExpressiveSharp.Generated
{
static partial class Foo_Entity
{
// [Expressive]
// public string TrendLabel => Trend.ToString() ?? string.Empty;
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Entity, string>> TrendLabel_Expression()
{
var p__this = global::System.Linq.Expressions.Expression.Parameter(typeof(global::Foo.Entity), "@this");
var expr_1 = global::System.Linq.Expressions.Expression.Property(p__this, typeof(global::Foo.Entity).GetProperty("Trend", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance)); // Trend
var expr_2 = global::System.Linq.Expressions.Expression.Constant(null, typeof(string));
var expr_3 = global::System.Linq.Expressions.Expression.Constant(global::Foo.Trend.NoChange, typeof(global::Foo.Trend));
var expr_4 = global::System.Linq.Expressions.Expression.Convert(expr_3, typeof(global::Foo.Trend?));
var expr_5 = global::System.Linq.Expressions.Expression.Call(expr_4, typeof(global::Foo.Trend?).GetMethod("ToString", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance, null, new global::System.Type[] { }, null), global::System.Array.Empty<global::System.Linq.Expressions.Expression>());
var expr_6 = global::System.Linq.Expressions.Expression.Equal(expr_1, expr_4);
var expr_7 = global::System.Linq.Expressions.Expression.Condition(expr_6, expr_5, expr_2, typeof(string));
var expr_8 = global::System.Linq.Expressions.Expression.Constant(global::Foo.Trend.Decrease, typeof(global::Foo.Trend));
var expr_9 = global::System.Linq.Expressions.Expression.Convert(expr_8, typeof(global::Foo.Trend?));
var expr_10 = global::System.Linq.Expressions.Expression.Call(expr_9, typeof(global::Foo.Trend?).GetMethod("ToString", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance, null, new global::System.Type[] { }, null), global::System.Array.Empty<global::System.Linq.Expressions.Expression>());
var expr_11 = global::System.Linq.Expressions.Expression.Equal(expr_1, expr_9);
var expr_12 = global::System.Linq.Expressions.Expression.Condition(expr_11, expr_10, expr_7, typeof(string));
var expr_13 = global::System.Linq.Expressions.Expression.Constant(global::Foo.Trend.Increase, typeof(global::Foo.Trend));
var expr_14 = global::System.Linq.Expressions.Expression.Convert(expr_13, typeof(global::Foo.Trend?));
var expr_15 = global::System.Linq.Expressions.Expression.Call(expr_14, typeof(global::Foo.Trend?).GetMethod("ToString", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance, null, new global::System.Type[] { }, null), global::System.Array.Empty<global::System.Linq.Expressions.Expression>());
var expr_16 = global::System.Linq.Expressions.Expression.Equal(expr_1, expr_14);
var expr_17 = global::System.Linq.Expressions.Expression.Condition(expr_16, expr_15, expr_12, typeof(string));
var expr_18 = global::System.Linq.Expressions.Expression.Constant(null, typeof(global::Foo.Trend?));
var expr_19 = global::System.Linq.Expressions.Expression.Equal(expr_1, expr_18);
var expr_20 = global::System.Linq.Expressions.Expression.Condition(expr_19, expr_2, expr_17, typeof(string));
var expr_21 = global::System.Linq.Expressions.Expression.Field(null, typeof(string).GetField("Empty", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static)); // string.Empty
var expr_0 = global::System.Linq.Expressions.Expression.Coalesce(expr_20, expr_21);
return global::System.Linq.Expressions.Expression.Lambda<global::System.Func<global::Foo.Entity, string>>(expr_0, p__this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -333,4 +333,60 @@ public record Entity

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[TestMethod]
public Task ExpandToStringOnNullableEnum()
{
var compilation = CreateCompilation(
"""
namespace Foo {
public enum Trend { Increase, Decrease, NoChange }

public record Entity
{
public Trend? Trend { get; init; }

[Expressive]
public string TrendLabel => Trend.ToString() ?? string.Empty;
}
}
""");
var result = RunExpressiveGenerator(compilation);

Assert.AreEqual(0, result.Diagnostics.Length);
Assert.AreEqual(1, result.GeneratedTrees.Length);

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[TestMethod]
public Task ExpandExtensionMethodOnNullableEnum()
{
var compilation = CreateCompilation(
"""
namespace Foo {
public enum Trend { Increase, Decrease, NoChange }

public static class TrendExtensions
{
public static string Describe(this Trend? value) =>
value.HasValue ? value.Value.ToString() : "n/a";
}

public record Entity
{
public Trend? Trend { get; init; }

[Expressive]
public string TrendLabel => Trend.Describe();
}
}
""");
var result = RunExpressiveGenerator(compilation);

Assert.AreEqual(0, result.Diagnostics.Length);
Assert.AreEqual(1, result.GeneratedTrees.Length);

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
Comment on lines +362 to +391
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ public void ExpressivePropertyOnEnumComparison_RegistryResolves()

Assert.IsNotNull(registered);
}

[TestMethod]
public void ToStringOnNullableEnum_ExpandExpressives_MaterializesAndEvaluates()
{
var source = new List<EnumComparisonEntity>
{
new() { NullableValue = Bucket.Low },
new() { NullableValue = Bucket.Mid },
new() { NullableValue = Bucket.High },
new() { NullableValue = null },
}.AsQueryable();

Expression<Func<EnumComparisonEntity, string>> expr = e => e.NullableLabel;
var expanded = (Expression<Func<EnumComparisonEntity, string>>)expr.ExpandExpressives();

var labels = source.Select(expanded.Compile()).ToList();

CollectionAssert.AreEqual(new[] { "Low", "Mid", "High", "" }, labels);
}
}

public class EnumComparisonEntity
Expand All @@ -88,6 +107,9 @@ public class EnumComparisonEntity

[Expressive]
public bool IsLowOrMid => NullableValue <= Bucket.Mid;

[Expressive]
public string NullableLabel => NullableValue.ToString() ?? string.Empty;
}

public enum Bucket { Low, Mid, High }
Expand Down
Loading