Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
================================================================================================
RTM stateless kafka-to-kafka
================================================================================================

OpenJDK 64-Bit Server VM 11.0.30+7-post-Ubuntu-1ubuntu120.04 on Linux 5.4.0-1157-aws-fips
Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Kafka to kafka query e2e_latency in milliseconds is
p0: 53
p50: 79
p90: 88
p95: 90
p99: 96
p100: 497

Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010.benchmark

import java.nio.file.Files
import java.util.{Properties, Timer, TimerTask}
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

import scala.concurrent.duration._

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, Producer, ProducerRecord, RecordMetadata}

import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.execution.streaming.RealTimeTrigger
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.kafka010.KafkaTestUtils
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
* Stateless Kafka-to-Kafka RTM benchmark. Reads from an input Kafka topic, applies a
* stateless transformation, and writes results to an output Kafka topic using
* [[RealTimeTrigger]]. After the run it reports e2e latency percentiles.
*
* The benchmark spins up a real local-cluster Spark context and a live embedded Kafka
* broker, so a single run takes several minutes.
*
* To run this benchmark:
* {{{
* 1. without sbt:
* bin/spark-submit --class <this class>
* --jars <spark core test jar>,<spark sql test jar> <spark sql kafka 0-10 test jar>
* 2. build/sbt "sql-kafka-0-10/Test/runMain <this class>"
* 3. generate result:
* SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql-kafka-0-10/Test/runMain <this class>"
* Results will be written to:
* "connector/kafka-0-10-sql/benchmarks/RTMKafkaKafkaBenchmark-results.txt".
* }}}
*
* See `benchmarks/RTMKafkaKafkaBenchmark-results.txt` for a recorded run.
*/
object RTMKafkaKafkaBenchmark extends BenchmarkBase with Logging {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: justify "why not Benchmark.run()".

This is the only BenchmarkBase subclass in the repo that manually emits getJVMOSInfo() / getProcessorName() and doesn't use Benchmark.run() /
Benchmark.addCase(). The reason makes sense (you're measuring a latency distribution across a streaming pipeline, not "run this synchronous function N times"), but a one-line note in the class scaladoc would save future readers a head-scratch:

Unlike most Spark benchmarks, this one does not use Benchmark.run(): the metric of interest is end-to-end latency percentiles across a streaming pipeline, which doesn't fit the Best/Avg/Stdev table format. Environment header is emitted manually for consistency with other result files.


private val topicId = new AtomicInteger(0)
private var spark: SparkSession = _
private var testUtils: KafkaTestUtils = _

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BenchmarkBase.main calls runBenchmarkSuite(args) and only calls afterAll() afterwards; it does not wrap runBenchmarkSuite in try/finally. This benchmark starts embedded Kafka and a local-cluster Spark session in runBenchmarkSuite, then relies on afterAll() for teardown. If benchmark(...) times out, the query fails, getLatencies throws, or setup partially fails after Kafka starts, afterAll() will not run, leaving Kafka/Spark resources behind. Since this benchmark intentionally runs heavyweight local resources, it should handle its own exception path, e.g. wrap setup/run in try/finally or call an idempotent cleanup method on failure.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure will add. Though the resources will not really be leaked as 1) Kafka is run in the same process and 2) workers will shutdown themselves down when the driver is not reachable.

// BenchmarkBase.main does not wrap this call in try/finally, so we must own
// teardown ourselves: partial setup, a timeout, or a getLatencies failure
// would otherwise leak the embedded Kafka broker and local-cluster workers.
testUtils = new KafkaTestUtils(Map.empty)
try {
testUtils.setup()
spark = SparkSession.builder()
.master("local-cluster[3, 5, 1024]")
.appName(this.getClass.getCanonicalName)
.getOrCreate()
runBenchmark("RTM stateless kafka-to-kafka") {
benchmark(60.seconds.toMillis, 4)
}
} finally {
cleanup()
}
}

/**
* Idempotent cleanup of the Spark session and embedded Kafka broker. Safe to call
* after any combination of partial setup, normal completion, or exception.
*/
private def cleanup(): Unit = {
if (spark != null) {
try {
spark.stop()
} catch {
case t: Throwable => logWarning("Failed to stop SparkSession during cleanup", t)
}
spark = null
}
if (testUtils != null) {
try {
testUtils.teardown()
} catch {
case t: Throwable => logWarning("Failed to teardown KafkaTestUtils during cleanup", t)
}
testUtils = null
}
}

private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"

def benchmark(longRunningBatchDurationMs: Long, numBatches: Long): Unit = {
val inputTopic = newTopic()
testUtils.createTopic(inputTopic, partitions = 5)

val outputTopic = newTopic()
testUtils.createTopic(outputTopic, partitions = 5)

spark.conf.set(SQLConf.STREAMING_POLLING_DELAY.key, 10)

val kafkaStream = spark.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", inputTopic)
.option("kafka.fetch.max.wait.ms", "10")
.option("kafka.max.partition.fetch.bytes", "10485760") // 10MB
.load()

val currentTimestampUDF = udf(() => System.currentTimeMillis())
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spark's built-in current_timestamp() in a streaming context is evaluated once per batch for determinism — which is the exact opposite of what this benchmark wants (per-row wall-clock timestamp). This is a subtle correctness point: anyone seeing a UDF wrapping System.currentTimeMillis() will be tempted to "clean it up" to the built-in and silently change the semantics. Please add an inline comment, e.g.:

  // UDF instead of current_timestamp(): the built-in is evaluated once per batch
  // for streaming determinism, but we want per-row wall-clock to measure per-record latency.
  val currentTimestampUDF = udf(() => System.currentTimeMillis())


val streamWithObserved = kafkaStream
.withColumn("value", base64(col("value")))
.withColumn(
"headers",
array(
struct(
lit("source-timestamp") as "key",
unix_millis(col("timestamp")).cast("STRING").cast("BINARY") as "value")))
.withColumn("temp-timestamp", currentTimestampUDF())
.withColumn(
"latency",
col("temp-timestamp").cast("long") - unix_millis(col("timestamp")).cast("long"))
.observe(
name = "observedLatency",
avg(col("latency")).as("avg"),
max(col("latency")).as("max"),
percentile_approx(col("latency"), lit(0.99), lit(10000)).as("p99"),
percentile_approx(col("latency"), lit(0.5), lit(10000)).as("p50"))
.drop(col("latency"))
.drop(col("temp-timestamp"))
.drop(col("timestamp"))

Comment on lines +141 to +150
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The observe(...) + drop chain is effectively dead.

The observed metrics are computed and then dropped — they don't go into the result file. They do surface in the Spark UI / log output via observe, but a reader can't tell that from the code. Either remove this section (if it's not needed) or add a comment stating that observed metrics are emitted to UI/log on purpose and are not part of the recorded result file.

val query = streamWithObserved.writeStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("topic", outputTopic)
.option("checkpointLocation", Files.createTempDirectory("rtm-benchmark").toString)
Copy link
Copy Markdown
Member

@viirya viirya May 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: checkpoint directory isn't cleaned up.

The temp dir leaks across runs. Capture the path and Utils.deleteRecursively it in cleanup().

.option("kafka.buffer.memory", "67108864") // 64MB
.option("kafka.compression.type", "snappy")
.outputMode("update")
.queryName("rtm-kafka-kafka")
.trigger(RealTimeTrigger.apply(s"${longRunningBatchDurationMs} milliseconds"))
.start()

val dataGenThread = new Thread(() => {
genData(testUtils.brokerAddress, inputTopic, 1000)
})
dataGenThread.start()

val latch = new CountDownLatch(1)
val listener = new StreamingQueryListener {
override def onQueryStarted(
event: StreamingQueryListener.QueryStartedEvent): Unit = {}

override def onQueryTerminated(
event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
if (event.progress.batchId == numBatches - 1) {
latch.countDown()
}
}
}
Comment on lines +176 to +181
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: latch logic depends on batchId numbering.

This assumes batchId starts at 0 and increases monotonically by 1. True today for a fresh query, but a counter would be more robust and more obviously correct:

  private val batchesCompleted = new AtomicLong(0)
  override def onQueryProgress(event: ...): Unit = {
    if (batchesCompleted.incrementAndGet() >= numBatches) latch.countDown()
  }

spark.streams.addListener(listener)

val timeoutMs = numBatches * longRunningBatchDurationMs * 2 + 60 * 1000
val completed = try {
latch.await(timeoutMs, TimeUnit.MILLISECONDS)
} finally {
spark.streams.removeListener(listener)
query.stop()
dataGenThread.interrupt()
dataGenThread.join(30 * 1000)
}
if (!completed) {
throw new RuntimeException(
s"Benchmark timed out waiting for $numBatches batches to complete after ${timeoutMs}ms.")
}

getLatencies(longRunningBatchDurationMs, numBatches, outputTopic)
}

private def genData(url: String, topicName: String, throughput: Long): Unit = {
Comment thread
jerrypeng marked this conversation as resolved.
logInfo(s"Producing to $url topic $topicName at $throughput records / sec")

val props: Properties = new Properties()
props.put("bootstrap.servers", url)
props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer")
props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer")

val producer: Producer[String, String] = new KafkaProducer[String, String](props)
Comment thread
jerrypeng marked this conversation as resolved.
val success = new AtomicLong(0)
val timer = new Timer()

try {
timer.scheduleAtFixedRate(
new TimerTask() {
override def run(): Unit = {
logInfo("Throughput: " + success.getAndSet(0) + " requests/sec")
}
},
1000,
1000
)

var i = 0L
val startTime = System.nanoTime
val delay = (Math.pow(10, 9) / throughput).asInstanceOf[Long]
var nextDeadline = startTime + delay
while (true) {
var currentTime = System.nanoTime
if (currentTime >= nextDeadline) {
i += 1
nextDeadline = startTime + (i * delay)
producer.send(
new ProducerRecord[String, String](
topicName,
java.lang.Long.toString(i),
java.lang.Long.toString(System.currentTimeMillis())
),
Copy link
Copy Markdown
Member

@viirya viirya May 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: genData writes a producer-side timestamp into the record value but it's never read.

The benchmark uses the Kafka record-level timestamp (broker-assigned) as the source timestamp instead, so this value is effectively unused. Two options:

  • Switch to using this value as source-timestamp — it's a more accurate "when did the producer hand this off" measurement and is robust to scenarios where
    input and output topics might live on different brokers in future variants of this benchmark.
  • Or, if Kafka's record timestamp is the intended source-of-truth, drop the System.currentTimeMillis() write and just send a counter, so the code doesn't
    mislead readers.

new Callback {
override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
if (e != null) {
logError("Got exception producing to kafka", e)
} else {
success.incrementAndGet()
}
}
}
)
currentTime = System.nanoTime

val sleepTimeNs =
if ((nextDeadline - currentTime) > 0) nextDeadline - currentTime
else 0
if (sleepTimeNs > 0) {
val sleepTimeMs = sleepTimeNs.nanoseconds.toMillis
val sleepTimeNano = (sleepTimeNs - sleepTimeMs.milliseconds.toNanos).toInt
Thread.sleep(sleepTimeMs, sleepTimeNano)
}
}
}
} catch {
case _: InterruptedException => // expected on shutdown
} finally {
timer.cancel()
producer.close()
}
}

private def printLatenciesTable(viewName: String, colName: String): Unit = {
val results = spark.sqlContext
.sql(s"""SELECT percentile_approx($colName, Array(0.0, 0.5, 0.9, 0.95, 0.99, 1.0), 10000)
| FROM $viewName""".stripMargin)
.collect()(0)(0)

if (results == null) {
throw new RuntimeException(
s"No results found in table $viewName when trying to print latency for $colName. " +
s"The benchmark may need more batches or a longer duration to produce enough data."
)
}

val latencies = results.asInstanceOf[scala.collection.Seq[_]]

val percentiles = Array("p0", "p50", "p90", "p95", "p99", "p100")
val latenciesTable = percentiles
.zip(latencies)
.map(pair => pair._1 + ": " + pair._2)
.mkString("\n")

// Include JVM/OS/processor info so result files are comparable across runs, matching
// the header that org.apache.spark.benchmark.Benchmark.run() emits.
val envHeader =
s"${Benchmark.getJVMOSInfo()}\n${Benchmark.getProcessorName()}\n"
val message =
envHeader + s"Kafka to kafka query ${colName} in milliseconds is\n" + latenciesTable + "\n"

output match {
case Some(out) => out.write(message.getBytes)
case None => logInfo("\n" + message)
}
}

private def getLatencies(
longRunningBatchDurationMs: Long,
numBatches: Long,
outputTopic: String): Unit = {
val kafkaSinkData = spark.read
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("subscribe", outputTopic)
.option("includeHeaders", "true")
.load()
.withColumn("headers-map", map_from_entries(col("headers")))
.withColumn("source-timestamp",
col("headers-map.source-timestamp").cast("STRING").cast("BIGINT"))
.withColumn("sink-timestamp", unix_millis(col("timestamp")))
Copy link
Copy Markdown
Member

@viirya viirya May 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the "single broker" assumption behind the latency formula.

  // source side (input topic read):
  unix_millis(col("timestamp"))  // ← input topic's broker-assigned timestamp, stored in header

  // sink side (output topic read):
  .withColumn("sink-timestamp", unix_millis(col("timestamp")))  // ← output topic's broker-assigned timestamp

source-timestamp and sink-timestamp come from two different Kafka read contexts — they're the record timestamps stamped by the broker on the input and output topics respectively. The sink - source formula only measures true latency because KafkaTestUtils runs a single embedded broker, so both timestamps share the same wall clock.

If someone later adapts this benchmark to a multi-broker setup (e.g. to measure cross-cluster shuffle latency), the formula will silently start including clock skew between brokers and there's nothing in the code to flag it. Worth a comment near the sink-timestamp definition, and/or switching to the producer-side System.currentTimeMillis() that's already being written into the record value — that one is clock-skew-immune by construction.


val numRecordsInSink = kafkaSinkData.count()
val minimumSourceTimestamp =
kafkaSinkData.agg(min("source-timestamp")).collect()(0)(0).asInstanceOf[Long]

val numBatchesToFilter = 2
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: default numBatches = 4 with numBatchesToFilter = 2 only leaves 2 effective batches

That puts the default very close to the failure threshold (filteredSink.count() == 0 → RuntimeException) under any kind of jitter. Suggest bumping the default numBatches to 8 or 10 so the common case has more headroom; the error message already tells users to increase it but it would be friendlier not to hit it by default.

val timeFilterThresholdMs = longRunningBatchDurationMs * numBatchesToFilter
val filteredSink = kafkaSinkData
.withColumn("time", col("source-timestamp") - minimumSourceTimestamp)
.filter(col("time") > timeFilterThresholdMs)

if (filteredSink.count() == 0) {
if (numRecordsInSink > 0) {
throw new RuntimeException(
s"There were ${numRecordsInSink} records in the Kafka sink topic $outputTopic, " +
s"but none remained after filtering the first ${numBatchesToFilter} batch(es) " +
s"(${timeFilterThresholdMs} ms). Run more batches (current: ${numBatches})."
)
} else {
throw new RuntimeException(
s"No results were found in the Kafka sink topic $outputTopic. " +
s"The query may not have produced results or the sink topic was incorrect."
)
}
}

val sinkWithLatencies = filteredSink
.withColumn("e2e_latency", col("sink-timestamp") - col("source-timestamp"))
sinkWithLatencies.createOrReplaceTempView("sink_with_latencies")

printLatenciesTable("sink_with_latencies", "e2e_latency")
}

private def unix_millis(column: Column): Column = {
(column.cast("timestamp").cast("double") * 1000).cast("long")
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shadows the built-in SQL unix_millis function, which is confusing — a reader naturally assumes they should just use the built-in. Suggest:

  • Rename to toUnixMillis (or similar) to avoid the name collision.
  • Add a comment stating why a custom helper is needed (e.g. import conflict, or a specific cast-path requirement) and noting the precision implication of
    double * 1000 → long for timestamps with sub-millisecond precision.

}