Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ public void setCurrentProjection(
SchemaGetter schemaGetter,
ArrowCompressionInfo compressionInfo,
int[] selectedFieldPositions) {
// Empty projection is currently not supported on the server side: the Arrow metadata length
// estimate disagrees with the serialized length for a zero-field schema, which would
// cause the client to retry indefinitely. Fail fast with a clear, non-retriable error
// instead.
if (selectedFieldPositions != null && selectedFieldPositions.length == 0) {
throw new InvalidColumnProjectionException(
"Empty projection is not supported. "
+ "Please project at least one column. "
+ "For aggregations like COUNT(*) / COUNT(1), "
+ "the connector should push down a projection that selects "
+ "at least one column.");
}
this.tableId = tableId;
this.schemaGetter = schemaGetter;
this.compressionInfo = compressionInfo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,40 @@ void testIllegalSetCurrentProjection() throws Exception {
"The projection indexes should not contain duplicated fields, but is [0, 0, 0]");
}

/**
* When the client pushes down an empty column projection (selectedFieldPositions = []), e.g.
* for {@code COUNT(*)} / {@code COUNT(1)} aggregations, the server should reject it eagerly
* with a clear {@link InvalidColumnProjectionException} instead of failing with an internal
* {@code IllegalStateException("Invalid metadata length")} that callers would retry
* indefinitely.
*/
@Test
void testEmptyProjectionRejectsWithClearError() throws Exception {
long tableId = 1L;
short schemaId = (short) 2;
FileLogRecords recordsOfData2RowType =
createFileLogRecords(
schemaId,
LOG_MAGIC_VALUE_V1,
TestData.DATA2_ROW_TYPE,
TestData.DATA2,
TestData.DATA2);
FileLogProjection projection = new FileLogProjection(new ProjectionPushdownCache());

// Empty projection - emulates Spark COUNT(*)/COUNT(1) optimisation.
assertThatThrownBy(
() ->
doProjection(
tableId,
schemaId,
projection,
recordsOfData2RowType,
new int[] {},
recordsOfData2RowType.sizeInBytes()))
.isInstanceOf(InvalidColumnProjectionException.class)
.hasMessageContaining("Empty projection is not supported");
}

@Test
void testProjectionOldDataWithNewSchema() throws Exception {
// Currently, we only support add column at last.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ abstract class FlussBatch(

protected def projection: Array[Int] = {
val columnNameToIndex = tableInfo.getSchema.getColumnNames.asScala.zipWithIndex.toMap
readSchema.fields.map {
val projected = readSchema.fields.map {
field =>
columnNameToIndex.getOrElse(
field.name,
throw new IllegalArgumentException(s"Invalid field name: ${field.name}"))
}
// The Fluss server does not currently support empty Arrow projections. Fall back to projecting
// the first column so the row count is preserved while still avoiding fetching unnecessary columns.
if (projected.isEmpty) Array(0) else projected
}

protected def createUpsertPartitions(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.fluss.spark.read

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory}

/**
* A trivial single-row [[Batch]] that emits one [[InternalRow]] containing only the precomputed row
* count. Used as the fast path when Spark pushes down a `COUNT(*)` / `COUNT(1)` aggregate to a
* Fluss primary key (non-lake) table and the precomputed table statistics are available on the
* server side.
*/
class FlussCountBatch(rowCount: Long) extends Batch {

override def planInputPartitions(): Array[InputPartition] =
Array(FlussCountInputPartition(rowCount))

override def createReaderFactory(): PartitionReaderFactory = FlussCountPartitionReaderFactory
}

case class FlussCountInputPartition(rowCount: Long) extends InputPartition

object FlussCountPartitionReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
partition match {
case FlussCountInputPartition(rowCount) => new FlussCountPartitionReader(rowCount)
case other =>
throw new IllegalArgumentException(
s"Unexpected partition type for FlussCountBatch: ${other.getClass.getName}")
}
}

class FlussCountPartitionReader(rowCount: Long) extends PartitionReader[InternalRow] {
private var consumed: Boolean = false

override def next(): Boolean = {
if (consumed) {
false
} else {
consumed = true
true
}
}

override def get(): InternalRow = InternalRow(rowCount)

override def close(): Unit = ()
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,15 @@ abstract class FlussMicroBatchStream(

protected def projection: Array[Int] = {
val columnNameToIndex = tableInfo.getSchema.getColumnNames.asScala.zipWithIndex.toMap
readSchema.fields.map {
val projected = readSchema.fields.map {
field =>
columnNameToIndex.getOrElse(
field.name,
throw new IllegalArgumentException(s"Invalid field name: ${field.name}"))
}
// See FlussBatch.projection: empty projection from Spark (e.g. COUNT(*)) is not
// supported by the Fluss server, so fall back to projecting the first column.
if (projected.isEmpty) Array(0) else projected
}

override def close(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,29 @@ case class FlussUpsertScan(
}
}

/** Fluss Count Scan: fast path that returns a precomputed row count from server-side table stats. */
case class FlussCountScan(
tablePath: TablePath,
tableInfo: TableInfo,
rowCount: Long,
requiredSchema: Option[StructType])
extends FlussScan {

override protected val scanType: String = "Count"

// A single LONG column matching the one-aggregate-no-group-by shape Spark expects.
override def readSchema(): StructType =
new StructType().add("count", org.apache.spark.sql.types.LongType, nullable = false)

override def toBatch: Batch = new FlussCountBatch(rowCount)

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
throw new UnsupportedOperationException(
"FlussCountScan does not support micro-batch streaming reads.")

override def description(): String = s"${super.description()} [RowCount: $rowCount]"
}

/** Fluss Lake Upsert Scan for lake-enabled primary key tables. */
case class FlussLakeUpsertScan(
tablePath: TablePath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@

package org.apache.fluss.spark.read

import org.apache.fluss.client.ConnectionFactory
import org.apache.fluss.config.{Configuration => FlussConfiguration}
import org.apache.fluss.exception.UnsupportedVersionException
import org.apache.fluss.metadata.{LogFormat, TableInfo, TablePath}
import org.apache.fluss.predicate.{Predicate => FlussPredicate}
import org.apache.fluss.spark.read.lake.{FlussLakeBatch, FlussLakeUtils}
import org.apache.fluss.spark.utils.{SparkPartitionPredicate, SparkPredicateConverter}

import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

Expand Down Expand Up @@ -116,6 +120,104 @@ trait FlussLakeSupportsPushDownV2Filters extends FlussSupportsPushDownV2Filters
}
}

/**
* Aggregation pushdown mixin: converts a single `COUNT(*)` / `COUNT(1)` / `COUNT(non_null_col)`
* without `GROUP BY` into a server-side row count. Only safe for PK (non-lake) tables, because KV
* upsert writes make the server-side row count match a full scan; log / lake-tiered tables can
* drift. Mirrors the Flink-side gating in `FlinkTableSource#applyAggregates`.
*/
trait FlussSupportsPushDownAggregates
extends FlussSupportsPushDownPartitionFilters
with SupportsPushDownAggregates {

def tablePath: TablePath

def flussConfig: FlussConfiguration

protected var pushedRowCount: Option[Long] = None

override def pushAggregation(aggregation: Aggregation): Boolean = {
if (!isPushable(aggregation)) {
return false
}
fetchRowCount() match {
case Some(count) =>
pushedRowCount = Some(count)
true
case None =>
false
}
}

override def supportCompletePushDown(aggregation: Aggregation): Boolean = isPushable(aggregation)

private def isPushable(aggregation: Aggregation): Boolean = {
if (aggregation.groupByExpressions().nonEmpty) {
return false
}
val aggs = aggregation.aggregateExpressions()
if (aggs.length != 1) {
return false
}
aggs.head match {
case _: CountStar => true
case c: Count if !c.isDistinct => isCountOnNonNullableColumn(c)
case _ => false
}
}

private def isCountOnNonNullableColumn(count: Count): Boolean = {
count.column() match {
case ref: NamedReference =>
val parts = ref.fieldNames()
// Only top-level column references are pushable.
if (parts.length != 1) {
return false
}
val colName = parts(0)
val rowType = tableInfo.getRowType
if (rowType.getFieldIndex(colName) < 0) {
return false
}
!rowType.getField(colName).getType.isNullable
case _ => false
}
}

// Probes server-side table stats; returns None on UnsupportedVersionException so Spark falls back.
private def fetchRowCount(): Option[Long] = {
val conn = ConnectionFactory.createConnection(flussConfig)
try {
val admin = conn.getAdmin
try {
Some(admin.getTableStats(tablePath).get().getRowCount)
} catch {
case e: Throwable =>
if (isUnsupportedVersion(e)) {
None
} else {
throw e
}
} finally {
admin.close()
}
} finally {
conn.close()
}
}

private def isUnsupportedVersion(t: Throwable): Boolean = {
var cur: Throwable = t
while (cur != null) {
if (cur.isInstanceOf[UnsupportedVersionException]) {
return true
}
cur = cur.getCause
}
false
}
}

/** Fluss Append Scan Builder. */
class FlussAppendScanBuilder(
tablePath: TablePath,
Expand Down Expand Up @@ -159,14 +261,24 @@ class FlussLakeAppendScanBuilder(

/** Fluss Upsert Scan Builder. */
class FlussUpsertScanBuilder(
tablePath: TablePath,
val tablePath: TablePath,
val tableInfo: TableInfo,
options: CaseInsensitiveStringMap,
val flussConfig: FlussConfiguration)
extends FlussSupportsPushDownPartitionFilters {
extends FlussSupportsPushDownAggregates {

override def build(): Scan = {
FlussUpsertScan(tablePath, tableInfo, requiredSchema, partitionPredicate, options, flussConfig)
pushedRowCount match {
case Some(count) => FlussCountScan(tablePath, tableInfo, count, requiredSchema)
case None =>
FlussUpsertScan(
tablePath,
tableInfo,
requiredSchema,
partitionPredicate,
options,
flussConfig)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
Row(1000L, 25L) ::
Row(1100L, 260L) :: Nil
)

// Insert one row with a NULL nullable column to make the COUNT(nullable_col)
// assertion below actually discriminate between row count and non-null count.
sql(s"""
|INSERT INTO $DEFAULT_DATABASE.t VALUES
|(1200L, 270L, 607, NULL)
|""".stripMargin)

// COUNT(*) / COUNT(1) — empty projection case (regression test for #2724)
checkAnswer(sql(s"SELECT COUNT(*) FROM $DEFAULT_DATABASE.t"), Row(9L) :: Nil)
checkAnswer(sql(s"SELECT COUNT(1) FROM $DEFAULT_DATABASE.t"), Row(9L) :: Nil)

// COUNT(*) with filter — verifies the non-empty-projection path is unaffected
checkAnswer(
sql(s"SELECT COUNT(*) FROM $DEFAULT_DATABASE.t WHERE orderId >= 900"),
Row(5L) :: Nil
)

// COUNT on a nullable column — address is nullable STRING and one row has NULL,
// so COUNT(address) must be strictly less than COUNT(*).
checkAnswer(
sql(s"SELECT COUNT(address) FROM $DEFAULT_DATABASE.t"),
Row(8L) :: Nil
)
}
}

Expand Down Expand Up @@ -130,6 +154,15 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
Row(800L, "addr3", "2026-01-02") ::
Row(900L, "addr4", "2026-01-02") :: Nil
)

// COUNT(*) on partitioned log table
checkAnswer(sql(s"SELECT COUNT(*) FROM $DEFAULT_DATABASE.t"), Row(5L) :: Nil)

// COUNT(*) with partition filter
checkAnswer(
sql(s"SELECT COUNT(*) FROM $DEFAULT_DATABASE.t WHERE dt = '2026-01-01'"),
Row(2L) :: Nil
)
}
}

Expand Down
Loading