@@ -31,19 +31,20 @@ import org.apache.spark.internal.LogKeys
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, DeduplicateWithinWatermark, Distinct, FlatMapGroupsInPandasWithState, FlatMapGroupsWithState, GlobalLimit, Join, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan, TransformWithState, TransformWithStateInPySpark}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, DeduplicateWithinWatermark, Distinct, Filter, FlatMapGroupsInPandasWithState, FlatMapGroupsWithState, GlobalLimit, Join, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan, TransformWithState, TransformWithStateInPySpark}
import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, Unassigned, WriteToStream}
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.classic.{Dataset, SparkSession}
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability, TransactionalCatalogPlugin}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsRealTimeMode, SupportsTriggerAvailableNow}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, PushDownUtils, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper}
import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2}
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter}
@@ -64,7 +65,8 @@ class MicroBatchExecution(
plan: WriteToStream)
extends StreamExecution(
sparkSession, plan.name, plan.resolvedCheckpointLocation, plan.inputQuery, plan.sink, trigger,
triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) with AsyncLogPurge {
triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) with AsyncLogPurge
with PredicateHelper {

/**
* Keeps track of the latest execution context
@@ -211,6 +213,65 @@ class MicroBatchExecution(
source, output, dataSourceV1.catalogTable, sourceIdentifyingName)(sparkSession)
})

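// When a Filter sits directly on top of a streaming DSv2 relation, push its conjuncts
// into the source's ScanBuilder before the micro-batch scan is planned, and re-apply
// anything the source could not accept after the scan.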
case f @ Filter(condition,
s @ StreamingRelationV2(_, _, table,
options, output, catalog, identifier, _, _)) =>
val scanBuilder = table.asReadable.newScanBuilder(options)
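// Split the condition into conjuncts and normalize attribute references against the
// relation output so the source can match them.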
val filters = splitConjunctivePredicates(condition)
val normalizedFilters =
DataSourceStrategy.normalizeExprs(filters, s.output)
val partitionPredicateFields = PushDownUtils.getPartitionPredicateSchema(table, s.output)
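// Filters containing subqueries cannot be pushed into the source; keep them aside for
// evaluation after the scan.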
val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
normalizedFilters.partition(SubqueryExpression.hasSubquery)
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
scanBuilder, normalizedFiltersWithoutSubquery, partitionPredicateFields)
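// The pushed filters come back either as data source V1 Filters or as V2 Predicates;
// keep the V2 predicates and render a string for logging in either case.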
var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]
val pushedFiltersStr = pushedFilters match {
case Left(v1Filters) => v1Filters.mkString(", ")
case Right(v2Predicates) =>
pushedPredicates = v2Predicates
v2Predicates.mkString(", ")
}

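// Filters the source did not accept, plus the subquery filters, must still be applied
// on top of the scan.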
val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery

logInfo(
log"""
|Pushing operators to ${MDC(RELATION_NAME, s.table.name())}
|Pushed Filters: ${MDC(PUSHED_FILTERS, pushedFiltersStr)}
|Post-Scan Filters: ${MDC(POST_SCAN_FILTERS, postScanFilters.mkString(","))}
""".stripMargin)

val filterCondition = postScanFilters.reduceLeftOption(And)

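// Build the scan with the pushed filters, then create the micro-batch stream under a
// per-source metadata path beneath the checkpoint root.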
val scan = scanBuilder.build()
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
nextSourceId.incrementAndGet()

val stream = scan.toMicroBatchStream(metadataPath)

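// Rebuild the streaming relation and scan relation around the pushdown-aware scan,
// carrying the real-time trigger interval when one is configured.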
val relation = StreamingDataSourceV2Relation(
table, output, catalog, identifier, options, metadataPath,
trigger match {
case RealTimeTrigger(duration) => Some(duration)
case _ => None
}
)

val scanRelation = StreamingDataSourceV2ScanRelation(relation, scan, output, stream)

// Wrap the scan relation with the filters that could not be pushed down.
if (filterCondition.nonEmpty) {
Filter(filterCondition.get, scanRelation)
} else {
scanRelation
}

case s @ StreamingRelationV2(src, srcName, table: SupportsRead, options, output,
catalog, identifier, v1, sourceIdentifyingName) =>
val dsStr = if (src.nonEmpty) s"[${src.get}]" else ""
@@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{streaming, InputPartition, PartitionReaderFactory, Scan, ScanBuilder, SupportsPushDownFilters}
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2ScanRelation
import org.apache.spark.sql.execution.streaming.runtime.LongOffset
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils

class PushDownPredicateInMicroBatchExecutionSuite extends StreamTest with SharedSparkSession {

test("Filter pushdown should propagate to MicroBatchStream in DSv2") {
val checkpointLoc = Utils.createTempDir(namePrefix = "streaming-checkpoint").getCanonicalPath
val providerClass = classOf[FakeStreamingProvider].getName

spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
spark.sql(s"""
CREATE TABLE testcat.default.streaming_test_table (id BIGINT, part STRING)
USING $providerClass
PARTITIONED BY (part)
""")

val df = spark.readStream
.format(providerClass)
.load()
.filter("part = 'active'")

testStream(df)(
StartStream(checkpointLocation = checkpointLoc),
// Force a trigger and wait for it to complete
Execute { q => q.processAllAvailable() },
Execute { query =>
// If lastExecution is still null here, the engine hasn't reached the
// planning phase of the first batch yet.
assert(query.lastExecution != null,
"Execution should be initialized after processAllAvailable")

val logicalPlan = query.lastExecution.logical

// inspect the plan to confirm filter pushdown
val scanRelations = logicalPlan.collect {
case s: StreamingDataSourceV2ScanRelation => s
}

assert(scanRelations.nonEmpty,
"Logical plan should contain a StreamingDataSourceV2ScanRelation")

val scan = scanRelations.head.scan
val pushed = scan.asInstanceOf[FakeScan].pushedFilters
assert(pushed.exists(_.toString.contains("part")),
"The 'part' filter was missing from pushed predicates")
},
StopStream
)
}
}

/**
* Minimal fake streaming data source whose scan builder records the filters that are
* pushed down, so the test can inspect them on the resulting scan.
*/
class FakeStreamingProvider extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: java.util.Map[String, String]): Table = {
new FakeTable()
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
StructType(Seq(
StructField("id", LongType),
StructField("part", StringType)
))
}
}

class FakeTable extends SupportsRead {
override def name(): String = "fake_table"

// Table needs a concrete schema; mirror the schema inferred by the provider.
override def schema(): StructType = StructType(Seq(
StructField("id", LongType),
StructField("part", StringType)
))

override def capabilities(): java.util.Set[TableCapability] = {
java.util.EnumSet.of(TableCapability.MICRO_BATCH_READ)
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new FakeScanBuilder()
}
}

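// Scan that records which filters were pushed so the test can assert on them later.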
class FakeScan(val pushedFilters: Array[Filter]) extends Scan {
override def readSchema(): StructType =
StructType(Seq(StructField("id", LongType), StructField("part", StringType)))
override def toMicroBatchStream(path: String): MicroBatchStream = new FakeMicroBatchStream()
}

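// Scan builder that claims to push every filter it is offered and hands them to the FakeScan.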
class FakeScanBuilder extends ScanBuilder with SupportsPushDownFilters {
private var pushed: Array[Filter] = Array.empty
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
pushed = filters
Array.empty // all filters are pushed
}
override def pushedFilters(): Array[Filter] = pushed
override def build(): Scan = new FakeScan(pushed)
}

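// Bare-bones micro-batch stream: a single offset range and no input partitions, just
// enough for micro-batch planning to run.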
class FakeMicroBatchStream extends MicroBatchStream {
override def initialOffset(): Offset = LongOffset(0)
override def deserializeOffset(json: String): Offset = LongOffset(json.toLong)
override def latestOffset(): Offset = LongOffset(1)
override def createReaderFactory(): PartitionReaderFactory = null
override def stop(): Unit = {}
override def planInputPartitions(start: streaming.Offset,
end: streaming.Offset): Array[InputPartition] = Array.empty
override def commit(end: streaming.Offset): Unit = {}
}