76 changes: 74 additions & 2 deletions src/Microsoft.Data.Analysis/GroupBy.cs
@@ -9,21 +9,59 @@

namespace Microsoft.Data.Analysis
{
/// <summary>
/// A record identifying the row being aggregated, used to decide whether to include it in the aggregation.
/// </summary>
public record GroupByPredicateInput
{
/// <summary>
/// The name of the column that is being aggregated
/// </summary>
public string ColumnName { get; set; }
[Minor — Design] Two small things on this record:

  1. Mutable set accessors on a record. Records are designed for value semantics and immutability. Using { get; init; } instead of { get; set; } would prevent accidental mutation after construction and make storing instances in collections (like the HashSet in CountDistinct) safe by design.

  2. object-typed properties cause boxing. Every value-type cell (Int32, Double, DateTime, etc.) gets boxed into object GroupKey / object RowValue. If perf matters, consider a generic variant like GroupByPredicateInput<TKey, TValue> — though that adds API complexity, so object may be an acceptable trade-off if documented.
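
For illustration, the init-only variant described in point 1 could look like this (same property names as the PR's record; only the accessors change — a sketch of the suggestion, not the PR's code):

```csharp
/// <summary>
/// Immutable variant: init-only setters prevent mutation after construction,
/// making instances safe to store in hash-based collections.
/// </summary>
public record GroupByPredicateInput
{
    public string ColumnName { get; init; }
    public object GroupKey { get; init; }
    public object RowValue { get; init; }
}
```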


/// <summary>
/// The value from the GroupBy column that this group is grouped on
/// </summary>
public object GroupKey { get; set; }

/// <summary>
/// The value of this row within the column that is being aggregated
/// </summary>
public object RowValue { get; set; }
}

/// <summary>
/// A GroupBy class that is typically the result of a DataFrame.GroupBy call.
/// It holds information to perform typical aggregation ops on it.
/// </summary>
public abstract class GroupBy
{
/// <summary>
/// Compute the number of non-null values in each group
/// </summary>
/// <param name="columnNames">The columns within which to compute the number of non-null values in each group. A default value includes all columns.</param>
/// <returns></returns>
public abstract DataFrame Count(params string[] columnNames);

/// <summary>
/// Compute the number of values in each group that match a custom predicate
/// </summary>
/// <param name="predicate">A function that takes in the column name, group key, and row value and returns true to include that row in the group count or false to exclude it.</param>
/// <param name="columnNames">The columns within which to compute the number of values in each group that match the predicate. A default value includes all columns.</param>
/// <returns></returns>
[Minor — Docs] <returns></returns> is empty here and on CountDistinct. Consider filling these in, e.g.:

/// <returns>A <see cref="DataFrame"/> with one row per group and one column per aggregated column, containing the matching counts.</returns>

public abstract DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames);

[Major — Breaking change] Adding abstract methods to a public non-sealed class is a binary-breaking change. Any downstream code that subclasses GroupBy will fail to compile (or fail at load time for pre-compiled assemblies) until it implements CountIf and CountDistinct.

Options to avoid the break:

  1. Make them virtual with a default implementation (e.g. throw new NotSupportedException()).
  2. Add them only on GroupBy<TKey> (which is the concrete class) rather than the abstract base.
  3. Seal GroupBy if external subclassing is not intended (coordinate with maintainers).
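
Option 1 could be sketched roughly as follows (the throwing default bodies are illustrative, not the PR's code):

```csharp
// Non-breaking alternative: virtual members with throwing defaults, so
// existing external GroupBy subclasses continue to compile and load.
public virtual DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames)
    => throw new NotSupportedException();

public virtual DataFrame CountDistinct(params string[] columnNames)
    => throw new NotSupportedException();
```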


/// <summary>
/// Compute the number of distinct non-null values in each group
/// </summary>
/// <param name="columnNames">The columns within which to compute the number of distinct non-null values in each group. A default value includes all columns.</param>
/// <returns></returns>
public abstract DataFrame CountDistinct(params string[] columnNames);

/// <summary>
/// Return the first value in each group
/// </summary>
/// <param name="columnNames">Names of the columns to aggregate</param>
/// <returns></returns>
public abstract DataFrame First(params string[] columnNames);

@@ -140,6 +178,11 @@ private void EnumerateColumnsWithRows(GroupByColumnDelegate groupByColumnDelegat
}

public override DataFrame Count(params string[] columnNames)
{
return CountIf(input => input.RowValue != null, columnNames);
[Major — Performance regression] Count() was previously allocation-free — a tight if (column[row] != null) count++ loop. By delegating through CountIf, every row now:

  1. Boxes the cell value into object RowValue (Int32, Double, DateTime, etc.).
  2. Allocates a GroupByPredicateInput per (group, column) pair.
  3. Invokes a delegate per row.

Count() is the most commonly-used GroupBy aggregation; on a 1M-row x 10-column DataFrame this adds millions of boxing allocations.

Suggested fix: Keep the original inlined null-check implementation for Count(). CountIf can coexist as a separate method without Count needing to delegate to it.
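
The allocation-free shape the comment refers to, sketched from the loop structure visible in this diff (`column` and `rowEnumerable` as used in CountIf):

```csharp
// Inlined null check: no boxing, no per-row delegate invocation.
long count = 0;
foreach (long row in rowEnumerable)
{
    if (column[row] != null)
        count++;
}
```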

}

public override DataFrame CountIf(Func<GroupByPredicateInput, bool> predicate, params string[] columnNames)
{
DataFrame ret = new DataFrame();
PrimitiveDataFrameColumn<long> empty = new PrimitiveDataFrameColumn<long>("Empty");
@@ -156,10 +199,19 @@ public override DataFrame Count(params string[] columnNames)
return;
DataFrameColumn column = _dataFrame.Columns[columnIndex];
long count = 0;
var groupByPredicateInput = new GroupByPredicateInput
{
ColumnName = column.Name,
GroupKey = firstColumn[rowIndex]
};
foreach (long row in rowEnumerable)
{
if (column[row] != null)
groupByPredicateInput.RowValue = column[row];

if (predicate(groupByPredicateInput))
{
count++;
}
}
DataFrameColumn retColumn;
if (firstGroup)
@@ -182,6 +234,26 @@ public override DataFrame Count(params string[] columnNames)
return ret;
}

public override DataFrame CountDistinct(params string[] columnNames)
{
HashSet<GroupByPredicateInput> seenValues = [];
@rokonec (Member) commented Mar 20, 2026:
[Critical — Correctness] HashSet<GroupByPredicateInput> with a mutable, reused object produces wrong results.

CountDistinct captures a single HashSet<GroupByPredicateInput> in its lambda. Inside CountIf, a single GroupByPredicateInput instance is created per (group, column) and then reused across all rows in the inner loop:

// line ~202: created once per (group, column)
var groupByPredicateInput = new GroupByPredicateInput { ColumnName = column.Name, GroupKey = firstColumn[rowIndex] };

foreach (long row in rowEnumerable)
{
    groupByPredicateInput.RowValue = column[row];  // ← line ~209: mutated each iteration
    if (predicate(groupByPredicateInput))           // CountDistinct does .Contains() then .Add()
        count++;
}

CountDistinct's predicate calls seenValues.Add(input), storing the same reference that line 209 mutates on the next iteration. This violates the HashSet contract: after insertion, the object's hash code changes, so it lands in the wrong bucket.

How it breaks concretely: When the HashSet resizes (e.g. at the 3rd distinct value, capacity jumps from 3→7), it rehashes all entries. Since they're all the same reference with the same current RowValue, they all collapse into one bucket. After that, any previously-seen value that hashes to a different bucket won't be found — so it gets double-counted as "distinct."

The test data [1, 1, 2, 3, 4] for Group 2 avoids this because the duplicate 1 appears at index 1 (before the resize at index 2). If the data were [1, 2, 3, 4, 1] instead, the final 1 would be miscounted as distinct (returning 5 instead of 4).

Additionally, the HashSet is never cleared between groups/columns, so it accumulates stale entries — a memory leak proportional to total rows × columns.

Suggested fix: Use a HashSet<object> keyed on just the row value, scoped per (group, column). The simplest approach is to implement CountDistinct directly (like the original Count) rather than routing through CountIf:

public override DataFrame CountDistinct(params string[] columnNames)
{
    // Same structure as Count, but with a per-(group, column) HashSet<object>
    // to track distinct non-null values.
}
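
Filling in that stub, the per-(group, column) loop could look like this (a sketch only; `column` and `rowEnumerable` as in the surrounding CountIf implementation):

```csharp
long count = 0;
var seen = new HashSet<object>();          // fresh per (group, column): no cross-group leakage
foreach (long row in rowEnumerable)
{
    object value = column[row];
    if (value != null && seen.Add(value))  // Add returns false for already-seen values
        count++;
}
```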


return CountIf(
input =>
{
if (input.RowValue == null || seenValues.Contains(input))
{
return false;
}

seenValues.Add(input);

return true;
},
columnNames
);
}

public override DataFrame First(params string[] columnNames)
{
DataFrame ret = new DataFrame();
44 changes: 44 additions & 0 deletions test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
@@ -429,6 +429,50 @@ public void TestGroupBy()
Assert.Equal(2, firstDecimalColumn.Rows.Count);
Assert.Equal((decimal)0, firstDecimalColumn.Columns["Decimal"][0]);
Assert.Equal((decimal)1, firstDecimalColumn.Columns["Decimal"][1]);

var dfWithDuplicates = new DataFrame(
new Int32DataFrameColumn("Group", [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
new Int32DataFrameColumn("Int", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
new DoubleDataFrameColumn("Double", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
new StringDataFrameColumn("String", ["1", "2", "3", "4", null, "1", "1", "2", "3", "4"]),
new DateTimeDataFrameColumn("DateTime", [
new DateTime(2026, 1, 1, 0, 0, 0),
new DateTime(2026, 1, 1, 0, 0, 1),
new DateTime(2026, 1, 1, 0, 0, 2),
new DateTime(2026, 1, 1, 0, 0, 3),
null,
new DateTime(2026, 1, 1, 0, 0, 0),
new DateTime(2026, 1, 1, 0, 0, 0),
new DateTime(2026, 1, 1, 0, 0, 1),
new DateTime(2026, 1, 1, 0, 0, 2),
new DateTime(2026, 1, 1, 0, 0, 3)
])
);

DataFrame countDistinct = dfWithDuplicates.GroupBy("Group").CountDistinct();
Assert.Equal(5, countDistinct.Columns.Count);
Assert.Equal(2, countDistinct.Rows.Count);

foreach (var columnName in countDistinct.Columns.Select(c => c.Name))
{
if (columnName == "Group")
{
continue;
}

var column = (PrimitiveDataFrameColumn<long>)countDistinct[columnName];

for (int row = 0; row < countDistinct.Rows.Count; row++)
{
Assert.Equal(4, column[row]);
}
}

DataFrame countIf = dfWithDuplicates.GroupBy("Group").CountIf((GroupByPredicateInput input) => input.RowValue is int and < 3, "Int");

[Minor — Test coverage] CountIf is only tested with a single column ("Int"). Consider adding:

  • A call with no column names (default all-columns path) to verify it works across all column types.
  • An always-true predicate — result should match Count() output.
  • An always-false predicate — all counts should be 0.
  • Edge cases: empty groups, all-null columns, single-row groups.
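
As one example, the "match Count()" case could be asserted like this (a hypothetical addition reusing dfWithDuplicates from the test above; note Count() skips nulls, so the equivalent predicate is a null check rather than an always-true lambda):

```csharp
DataFrame baseline = dfWithDuplicates.GroupBy("Group").Count();
DataFrame viaCountIf = dfWithDuplicates.GroupBy("Group").CountIf(input => input.RowValue != null);

// The null-check predicate should reproduce Count() exactly, column by column.
for (int row = 0; row < baseline.Rows.Count; row++)
{
    Assert.Equal(baseline["Int"][row], viaCountIf["Int"][row]);
}
```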

Assert.Equal(2, countIf.Columns.Count);
Assert.Equal(2, countIf.Rows.Count);
Assert.Equal(2L, countIf["Int"][0]);
Assert.Equal(3L, countIf["Int"][1]);
}

[Fact]