Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public static IServiceCollection AddEntityFrameworkClickHouse(this IServiceColle
.TryAdd<IExecutionStrategyFactory, ClickHouseExecutionStrategyFactory>()
.TryAdd<IQueryableMethodTranslatingExpressionVisitorFactory, ClickHouseQueryableMethodTranslatingExpressionVisitorFactory>()
.TryAdd<IMethodCallTranslatorProvider, ClickHouseMethodCallTranslatorProvider>()
.TryAdd<IAggregateMethodCallTranslatorProvider, ClickHouseAggregateMethodCallTranslatorProvider>()
.TryAdd<IMemberTranslatorProvider, ClickHouseMemberTranslatorProvider>()
.TryAdd<IEvaluatableExpressionFilter, ClickHouseEvaluatableExpressionFilter>()
.TryAdd<IQuerySqlGeneratorFactory, ClickHouseQuerySqlGeneratorFactory>()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using Microsoft.EntityFrameworkCore.Query;

namespace ClickHouse.EntityFrameworkCore.Query.ExpressionTranslators.Internal;

public class ClickHouseAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider
{
public ClickHouseAggregateMethodCallTranslatorProvider(
RelationalAggregateMethodCallTranslatorProviderDependencies dependencies)
: base(dependencies)
{
var sqlExpressionFactory = dependencies.SqlExpressionFactory;

AddTranslators(
[
new ClickHouseQueryableAggregateMethodTranslator(sqlExpressionFactory),
]);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using System.Reflection;

namespace ClickHouse.EntityFrameworkCore.Query.ExpressionTranslators.Internal;

/// <summary>
/// Translates grouped LINQ aggregate methods (Count, Sum, Average, Min, Max)
/// into ClickHouse SQL aggregate function calls.
///
/// Scalar aggregates (without GROUP BY) are handled by the base EF Core classes;
/// this translator is needed for grouped aggregates produced by
/// <c>GroupBy().Select(g => g.Count())</c> and similar patterns.
/// </summary>
public class ClickHouseQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public ClickHouseQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}

public SqlExpression? Translate(
MethodInfo method,
EnumerableExpression source,
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.DeclaringType != typeof(Queryable))
return null;

var methodInfo = method.IsGenericMethod
? method.GetGenericMethodDefinition()
: method;

switch (methodInfo.Name)
{
case nameof(Queryable.Average)
when (QueryableMethods.IsAverageWithoutSelector(methodInfo)
|| QueryableMethods.IsAverageWithSelector(methodInfo))
&& source.Selector is SqlExpression averageSqlExpression:
{
// ClickHouse avg() on integer columns returns 0 for empty groups;
// avgOrNull() returns NULL instead, matching LINQ/SQL Server semantics.
// Cast int/long to double first so avg doesn't do integer division.
var averageInputType = averageSqlExpression.Type;
if (averageInputType == typeof(int) || averageInputType == typeof(long))
{
averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(averageSqlExpression, typeof(double)));
}

averageSqlExpression = CombineTerms(source, averageSqlExpression);

return _sqlExpressionFactory.Function(
"avgOrNull",
[averageSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
typeof(double));
}

case nameof(Queryable.Count)
when methodInfo == QueryableMethods.CountWithoutPredicate
|| methodInfo == QueryableMethods.CountWithPredicate:
{
var countSqlExpression = (source.Selector as SqlExpression)
?? _sqlExpressionFactory.Fragment("*");
countSqlExpression = CombineTerms(source, countSqlExpression);

return _sqlExpressionFactory.Function(
"COUNT",
[countSqlExpression],
nullable: false,
argumentsPropagateNullability: [false],
typeof(int));
}

case nameof(Queryable.LongCount)
when methodInfo == QueryableMethods.LongCountWithoutPredicate
|| methodInfo == QueryableMethods.LongCountWithPredicate:
{
var longCountSqlExpression = (source.Selector as SqlExpression)
?? _sqlExpressionFactory.Fragment("*");
longCountSqlExpression = CombineTerms(source, longCountSqlExpression);

return _sqlExpressionFactory.Function(
"COUNT",
[longCountSqlExpression],
nullable: false,
argumentsPropagateNullability: [false],
typeof(long));
}

case nameof(Queryable.Max)
when (methodInfo == QueryableMethods.MaxWithoutSelector
|| methodInfo == QueryableMethods.MaxWithSelector)
&& source.Selector is SqlExpression maxSqlExpression:
{
maxSqlExpression = CombineTerms(source, maxSqlExpression);

return _sqlExpressionFactory.Function(
"MAX",
[maxSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
maxSqlExpression.Type,
maxSqlExpression.TypeMapping);
}

case nameof(Queryable.Min)
when (methodInfo == QueryableMethods.MinWithoutSelector
|| methodInfo == QueryableMethods.MinWithSelector)
&& source.Selector is SqlExpression minSqlExpression:
{
minSqlExpression = CombineTerms(source, minSqlExpression);

return _sqlExpressionFactory.Function(
"MIN",
[minSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
minSqlExpression.Type,
minSqlExpression.TypeMapping);
}

case nameof(Queryable.Sum)
when (QueryableMethods.IsSumWithoutSelector(methodInfo)
|| QueryableMethods.IsSumWithSelector(methodInfo))
&& source.Selector is SqlExpression sumSqlExpression:
{
sumSqlExpression = CombineTerms(source, sumSqlExpression);

return _sqlExpressionFactory.Function(
"SUM",
[sumSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
sumSqlExpression.Type,
sumSqlExpression.TypeMapping);
}
}

return null;
}

/// <summary>
/// Wraps the aggregate operand to handle predicate filtering and DISTINCT.
///
/// When a predicate is present (e.g. <c>g.Count(x => x.IsActive)</c>), the operand
/// is wrapped in <c>CASE WHEN predicate THEN expr ELSE NULL END</c> so that only
/// matching rows contribute to the aggregate. If the operand is <c>*</c> (a fragment),
/// it's replaced with the constant <c>1</c> since <c>CASE WHEN ... THEN * END</c>
/// isn't valid SQL.
///
/// When DISTINCT is requested, the operand is wrapped in a <see cref="DistinctExpression"/>
/// so the SQL generator emits <c>COUNT(DISTINCT expr)</c> etc.
/// </summary>
private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression)
{
if (enumerableExpression.Predicate != null)
{
if (sqlExpression is SqlFragmentExpression)
{
sqlExpression = _sqlExpressionFactory.Constant(1);
}

sqlExpression = _sqlExpressionFactory.Case(
[new CaseWhenClause(enumerableExpression.Predicate, sqlExpression)],
elseResult: null);
}

if (enumerableExpression.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

return sqlExpression;
}
}
168 changes: 168 additions & 0 deletions test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
using Microsoft.EntityFrameworkCore;
using Xunit;

namespace EFCore.ClickHouse.Tests;

public class GroupByAggregateTests : IClassFixture<ClickHouseFixture>
{
private readonly ClickHouseFixture _fixture;

public GroupByAggregateTests(ClickHouseFixture fixture)
{
_fixture = fixture;
}

[Fact]
public async Task GroupBy_Count_ReturnsCorrectCounts()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Select(g => new { IsActive = g.Key, Count = g.Count() })
.OrderBy(x => x.IsActive)
.AsNoTracking()
.ToListAsync();

Assert.Equal(2, results.Count);
// false group: Charlie, Eve, Hank, Jack = 4
Assert.False(results[0].IsActive);
Assert.Equal(4, results[0].Count);
// true group: Alice, Bob, Diana, Frank, Grace, Ivy = 6
Assert.True(results[1].IsActive);
Assert.Equal(6, results[1].Count);
}

[Fact]
public async Task GroupBy_Sum_ReturnsCorrectSums()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Select(g => new { IsActive = g.Key, TotalAge = g.Sum(e => e.Age) })
.OrderBy(x => x.IsActive)
.AsNoTracking()
.ToListAsync();

Assert.Equal(2, results.Count);
// false: 35 + 22 + 27 + 29 = 113
Assert.Equal(113, results[0].TotalAge);
// true: 30 + 25 + 28 + 40 + 33 + 31 = 187
Assert.Equal(187, results[1].TotalAge);
}

[Fact]
public async Task GroupBy_Average_ReturnsCorrectAverages()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Select(g => new { IsActive = g.Key, AvgAge = g.Average(e => e.Age) })
.OrderBy(x => x.IsActive)
.AsNoTracking()
.ToListAsync();

Assert.Equal(2, results.Count);
// false: 113 / 4 = 28.25
Assert.Equal(28.25, results[0].AvgAge);
// true: 187 / 6 ≈ 31.1667
Assert.Equal(31.1667, results[1].AvgAge, 3);
}

[Fact]
public async Task GroupBy_MinMax_ReturnsCorrectValues()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Select(g => new { IsActive = g.Key, MinAge = g.Min(e => e.Age), MaxAge = g.Max(e => e.Age) })
.OrderBy(x => x.IsActive)
.AsNoTracking()
.ToListAsync();

Assert.Equal(2, results.Count);
// false: min=22(Eve), max=35(Charlie)
Assert.Equal(22, results[0].MinAge);
Assert.Equal(35, results[0].MaxAge);
// true: min=25(Bob), max=40(Frank)
Assert.Equal(25, results[1].MinAge);
Assert.Equal(40, results[1].MaxAge);
}

[Fact]
public async Task GroupBy_Having_FiltersGroups()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

// Only groups where count > 4
var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Where(g => g.Count() > 4)
.Select(g => new { IsActive = g.Key, Count = g.Count() })
.AsNoTracking()
.ToListAsync();

// Only active group has 6, inactive has 4
Assert.Single(results);
Assert.True(results[0].IsActive);
Assert.Equal(6, results[0].Count);
}

[Fact]
public async Task GroupBy_MultipleAggregates_ReturnsAll()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Select(g => new
{
IsActive = g.Key,
Count = g.Count(),
TotalAge = g.Sum(e => e.Age),
MinAge = g.Min(e => e.Age),
MaxAge = g.Max(e => e.Age),
})
.OrderBy(x => x.IsActive)
.AsNoTracking()
.ToListAsync();

Assert.Equal(2, results.Count);

// Inactive group
Assert.Equal(4, results[0].Count);
Assert.Equal(113, results[0].TotalAge);
Assert.Equal(22, results[0].MinAge);
Assert.Equal(35, results[0].MaxAge);

// Active group
Assert.Equal(6, results[1].Count);
Assert.Equal(187, results[1].TotalAge);
Assert.Equal(25, results[1].MinAge);
Assert.Equal(40, results[1].MaxAge);
}

[Fact]
public async Task GroupBy_OrderByAggregate_Sorts()
{
await using var context = new TestDbContext(_fixture.ConnectionString);

var results = await context.TestEntities
.GroupBy(e => e.IsActive)
.Select(g => new { IsActive = g.Key, Count = g.Count() })
.OrderByDescending(x => x.Count)
.AsNoTracking()
.ToListAsync();

Assert.Equal(2, results.Count);
// Active group (6) should come first
Assert.True(results[0].IsActive);
Assert.Equal(6, results[0].Count);
// Inactive group (4) second
Assert.False(results[1].IsActive);
Assert.Equal(4, results[1].Count);
}
}
Loading