Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion .github/workflows/velox_backend_x86.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
name: Velox Backend (x86)

on:
workflow_dispatch:
pull_request:
paths:
- '.github/workflows/velox_backend_x86.yml'
Expand Down Expand Up @@ -1188,13 +1189,34 @@ jobs:
cd ./cpp/build && ctest -V
- name: Run CPP benchmark test
run: |
$MVN_CMD clean test -Pspark-3.5 -Pbackends-velox -pl backends-velox -am \
$MVN_CMD clean install -Pspark-3.5 -Pbackends-velox -pl backends-velox -am \
-DtagsToInclude="org.apache.gluten.tags.GenerateExample" -Dtest=none -DfailIfNoTests=false -Dexec.skip
# This test depends on files generated by the above mvn test.
./cpp/build/velox/benchmarks/generic_benchmark --with-shuffle --partitioning hash --threads 1 --iterations 1 \
--conf $(realpath backends-velox/generated-native-benchmark/conf_12_0_*.ini) \
--plan $(realpath backends-velox/generated-native-benchmark/plan_12_0_*.json) \
--data $(realpath backends-velox/generated-native-benchmark/data_12_0_*_0.parquet),$(realpath backends-velox/generated-native-benchmark/data_12_0_*_1.parquet)
- name: Run table cache lazy deserialization benchmark
timeout-minutes: 30
run: |
set -o pipefail
mkdir -p benchmarks
export MAVEN_OPTS="-Xss128m -Xmx8g -XX:ReservedCodeCacheSize=2g \
-Dspark.test.home=/opt/shims/spark35/spark_home/ \
-Dspark.gluten.benchmark.rows=5000000 \
-Dspark.gluten.benchmark.partitions=32 \
-Dspark.gluten.benchmark.iterations=3 \
-Dspark.gluten.benchmark.phases=build,read1,read4,readAll,filter"
LD_LIBRARY_PATH=$GITHUB_WORKSPACE/cpp/build/releases \
$MVN_CMD test-compile exec:java -Pspark-3.5 -Pbackends-velox -pl backends-velox \
-Dexec.classpathScope=test \
-Dexec.mainClass=org.apache.spark.sql.execution.benchmark.ColumnarTableCacheLazyDeserBenchmark \
| tee benchmarks/ColumnarTableCacheLazyDeserBenchmark-results.txt
- name: Upload table cache lazy deserialization benchmark results
uses: actions/upload-artifact@v4
with:
name: table-cache-lazy-deserialization-benchmark
path: benchmarks/ColumnarTableCacheLazyDeserBenchmark-results.txt
- name: Run UDF test
run: |
yum install -y java-17-openjdk-devel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,50 @@ object CachedColumnarBatchKryoSerializer {
val STATS_FRAMED_MAGIC: Array[Byte] =
Array[Byte](0xfe.toByte, 0xca.toByte, 0x53.toByte, 0x02.toByte)

// V3 magic: same as V2 but last byte = 0x03.
val STATS_FRAMED_MAGIC_V3: Array[Byte] =
Array[Byte](0xfe.toByte, 0xca.toByte, 0x53.toByte, 0x03.toByte)

private def magicHex(bytes: Array[Byte]): String = {
if (bytes == null || bytes.length < 4) {
"<short>"
} else {
f"0x${bytes(0) & 0xff}%02X${bytes(1) & 0xff}%02X" +
f"${bytes(2) & 0xff}%02X${bytes(3) & 0xff}%02X"
}
}

private[execution] def hasFrameMagic(bytes: Array[Byte], magic: Array[Byte]): Boolean = {
bytes != null && bytes.length >= magic.length && {
var i = 0
while (i < magic.length) {
if (bytes(i) != magic(i)) {
return false
}
i += 1
}
true
}
}

private def requireFrameMagic(bytes: Array[Byte], magic: Array[Byte], version: String): Unit = {
require(
hasFrameMagic(bytes, magic),
s"$version framed bytes magic mismatch: expected ${magicHex(magic)}, got ${magicHex(bytes)}")
}

private def framedMagicVersion(framed: Array[Byte]): Int = {
if (hasFrameMagic(framed, STATS_FRAMED_MAGIC)) {
0x02
} else if (hasFrameMagic(framed, STATS_FRAMED_MAGIC_V3)) {
0x03
} else {
throw new IllegalArgumentException(
s"framed bytes magic mismatch: expected ${magicHex(STATS_FRAMED_MAGIC)}(V2) or " +
s"${magicHex(STATS_FRAMED_MAGIC_V3)}(V3), got ${magicHex(framed)}")
}
}

// Per-column statsBlob layout (LE throughout, matches the cpp emitter in
// VeloxColumnarBatchSerializer.cc):
//
Expand Down Expand Up @@ -605,45 +649,66 @@ object CachedColumnarBatchKryoSerializer {
}

/**
* Parse the JNI `serializeWithStats` framed return into (stats InternalRow, bytesBlob).
*
* Framed layout (matches cpp VeloxColumnarBatchSerializer.cc): `[ STATS_FRAMED_MAGIC: 4B ] [
* statsLen: u32 LE ] [ statsBlob ] [ bytesLen: u32 LE ] [ bytesBlob ]`.
* Parse the JNI `serializeWithStats` framed return into (stats InternalRow, bytesBlob). Routes on
* the full 4-byte magic: V2 -> 0xFECA5302, V3 -> 0xFECA5303.
*
* Eager guards catch corrupt magic / truncated framing before they propagate.
* V2 layout: `[ magic: 4B ] [ statsLen: u32 LE ] [ statsBlob ] [ bytesLen: u32 LE ] [ bytesBlob
* ]` V3 layout: `[ magic: 4B ] [ statsLen: u32 LE ] [ statsBlob ] [ numRows: u32 LE ] [ numCols:
* u32 LE ] [ per-col ]`
*/
private[execution] def parseFramedBytes(
framed: Array[Byte],
schema: StructType): (InternalRow, Array[Byte]) = {
// V2 minimum = 4+4+4=12B; V3 minimum = 4+4+4+4=16B; use 12 for dispatcher guard.
require(
framed != null && framed.length >= 4 + 4 + 4,
framed != null && framed.length >= 12,
s"framed bytes too short: len=${if (framed == null) -1 else framed.length}")
require(
framed(0) == STATS_FRAMED_MAGIC(0) && framed(1) == STATS_FRAMED_MAGIC(1) &&
framed(2) == STATS_FRAMED_MAGIC(2) && framed(3) == STATS_FRAMED_MAGIC(3),
f"framed bytes magic mismatch: expected " +
f"0x${STATS_FRAMED_MAGIC(0) & 0xff}%02X${STATS_FRAMED_MAGIC(1) & 0xff}%02X" +
f"${STATS_FRAMED_MAGIC(2) & 0xff}%02X${STATS_FRAMED_MAGIC(3) & 0xff}%02X, got " +
f"0x${framed(0) & 0xff}%02X${framed(1) & 0xff}%02X" +
f"${framed(2) & 0xff}%02X${framed(3) & 0xff}%02X"
)
framedMagicVersion(framed) match {
case 0x02 => parseV2Frame(framed, schema)
case 0x03 => parseV3Frame(framed, schema)
}
}

/** V2 parse: extract stats + pure Presto bytesBlob. */
private def parseV2Frame(framed: Array[Byte], schema: StructType): (InternalRow, Array[Byte]) = {
requireFrameMagic(framed, STATS_FRAMED_MAGIC, "V2")
val buf = ByteBuffer.wrap(framed).order(ByteOrder.LITTLE_ENDIAN)
buf.position(4) // skip magic
val statsLen = buf.getInt
require(
statsLen >= 0 && statsLen <= buf.remaining() - 4,
s"framed bytes statsLen=$statsLen exceeds remaining buffer ${buf.remaining() - 4}")
s"V2 framed bytes statsLen=$statsLen exceeds remaining buffer ${buf.remaining() - 4}")
val statsBlob = new Array[Byte](statsLen)
buf.get(statsBlob)
val stats = deserializeStats(statsBlob, schema)
val bytesLen = buf.getInt
require(
bytesLen >= 0 && bytesLen == buf.remaining(),
s"framed bytes bytesLen=$bytesLen != remaining ${buf.remaining()} (truncated or trailing)")
s"V2 framed bytes bytesLen=$bytesLen != remaining ${buf.remaining()} (truncated or trailing)")
val bytesBlob = new Array[Byte](bytesLen)
buf.get(bytesBlob)
(stats, bytesBlob)
}

/**
* V3 parse: extract stats; bytes = the full V3 framed array (C++ deserializeV3 starts at magic).
* Invariant: returned bytes[0..3] == V3 magic; C++ deserializeV3 re-validates.
*/
private def parseV3Frame(framed: Array[Byte], schema: StructType): (InternalRow, Array[Byte]) = {
require(framed.length >= 16, s"V3 framed bytes too short (min 16B): len=${framed.length}")
requireFrameMagic(framed, STATS_FRAMED_MAGIC_V3, "V3")
val buf = ByteBuffer.wrap(framed).order(ByteOrder.LITTLE_ENDIAN)
buf.position(4) // skip magic
val statsLen = buf.getInt
require(
statsLen >= 0 && statsLen <= buf.remaining() - 8, // 8 = numRows(4)+numCols(4)
s"V3 framed bytes statsLen=$statsLen invalid")
val statsBlob = new Array[Byte](statsLen)
buf.get(statsBlob)
val stats = if (statsLen == 0) null else deserializeStats(statsBlob, schema)
// Return full framed bytes; C++ deserializeV3 will skip magic+stats and per-col.
(stats, framed)
}
}

/**
Expand Down Expand Up @@ -750,6 +815,7 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
val structSchema = StructType(
schema.map(a => StructField(a.name, a.dataType, a.nullable)))
val backendName = BackendsApiManager.getBackendName
// Hoist partition-level configs: GlutenConfig.get allocates a fresh object on each call.
val partitionStatsEnabled =
GlutenConfig.get.getConf(GlutenConfig.COLUMNAR_TABLE_CACHE_PARTITION_STATS_ENABLED)
val jni = ColumnarBatchSerializerJniWrapper.create(
Expand All @@ -772,13 +838,31 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
stats = null,
schema = null)
}
// Route through serializeWithStats when the partition-stats conf is enabled and the
// JNI extension is linked in libgluten.so. Capability is detected lazily at the
// call site: a new Gluten jar paired with an older native library will throw
// UnsatisfiedLinkError on the first invocation; we catch it once, cache the
// result, and fall back to the legacy serialize() path emitting stats=null. The
// buildFilter wrapper directs such batches through without pruning.
if (partitionStatsEnabled && ColumnarCachedBatchSerializer.statsExtAvailable) {
def statsOrLegacySerializeInline(): CachedBatch = {
if (partitionStatsEnabled && ColumnarCachedBatchSerializer.statsExtAvailable) {
ColumnarCachedBatchSerializer.serializeOneBatchWithStats(
jni,
handle,
batch.numRows(),
structSchema,
() => legacySerializeInline())
} else {
legacySerializeInline()
}
}
// V3 is the default cache format for Velox table cache: it stores each column
// independently so reads can materialize only requested columns. Partition stats are
// an optional V3 payload used for pruning, not a prerequisite for lazy reads.
if (ColumnarCachedBatchSerializer.statsExtV3Available) {
ColumnarCachedBatchSerializer.serializeOneBatchV3(
jni,
handle,
batch.numRows(),
structSchema,
includeStats = partitionStatsEnabled,
fallbackToV2OrLegacy = () => statsOrLegacySerializeInline())
} else if (partitionStatsEnabled && ColumnarCachedBatchSerializer.statsExtAvailable) {
// V2 stats path.
ColumnarCachedBatchSerializer.serializeOneBatchWithStats(
jni,
handle,
Expand Down Expand Up @@ -835,21 +919,37 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer

override def next(): ColumnarBatch = {
val cachedBatch = it.next().asInstanceOf[CachedColumnarBatch]
val batchHandle =
jniWrapper
.deserialize(deserializerHandle, cachedBatch.bytes)
val batch = ColumnarBatches.create(batchHandle)
if (shouldSelectAttributes) {
try {
ColumnarBatches.select(
BackendsApiManager.getBackendName,
batch,
requestedColumnIndices.toArray)
} finally {
batch.close()
}
// V3 bytes are ALWAYS routed to deserializeWithProjection.
// V3 framed bytes must NOT go to jni.deserialize() (expects Presto format).
if (isV3Format(cachedBatch.bytes)) {
// C++ returns the requested M-column batch; LazyVector loads those columns
// on first access instead of eagerly decoding the full cached schema.
val reqIndices: Array[Int] =
if (cacheAttributes == selectedAttributes) null // all cols: C++ loadAll
else if (requestedColumnIndices.isEmpty) Array.empty[Int] // count(*): 0 cols
else requestedColumnIndices.toArray // projection: M cols
val batchHandle = jniWrapper.deserializeWithProjection(
deserializerHandle,
cachedBatch.bytes,
reqIndices)
ColumnarBatches.create(batchHandle)
// No ColumnarBatches.select(): C++ returns M-column batch.
} else {
batch
// V2 path (original logic).
val batchHandle = jniWrapper.deserialize(deserializerHandle, cachedBatch.bytes)
val batch = ColumnarBatches.create(batchHandle)
if (shouldSelectAttributes) {
try {
ColumnarBatches.select(
BackendsApiManager.getBackendName,
batch,
requestedColumnIndices.toArray)
} finally {
batch.close()
}
} else {
batch
}
}
}
})
Expand Down Expand Up @@ -898,6 +998,12 @@ class ColumnarCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer
}
}

/** True iff bytes starts with V3 magic (0xFE 0xCA 0x53 0x03). */
private def isV3Format(bytes: Array[Byte]): Boolean =
CachedColumnarBatchKryoSerializer.hasFrameMagic(
bytes,
CachedColumnarBatchKryoSerializer.STATS_FRAMED_MAGIC_V3)

override def buildFilter(
predicates: Seq[Expression],
cachedAttributes: Seq[Attribute])
Expand Down Expand Up @@ -1029,4 +1135,67 @@ object ColumnarCachedBatchSerializer extends Logging {
)
}
}

// Visible for testing: reset the capability flag so a unit test can re-exercise the
// probe-once semantics.
private[execution] def resetStatsExtAvailableForTesting(): Unit = {
statsExtAvailableFlag = true
}

// V3 lazy deserialization support

// Separate capability latch for the V3 JNI symbols
// (framedSerializeV3 / framedSerializeWithStatsV3).
@volatile private var statsExtV3AvailableFlag: Boolean = true

def statsExtV3Available: Boolean = statsExtV3AvailableFlag

def markStatsExtV3Unavailable(cause: Throwable): Unit = {
if (statsExtV3AvailableFlag) {
statsExtV3AvailableFlag = false
logWarning(
"serializeWithStatsV3 JNI returned null (backend not supported or symbol missing); " +
"disabling V3 per-column lazy deserialization for this JVM. " +
"This typically indicates a Gluten jar / native library version mismatch.",
cause
)
}
}

// V3 per-batch serialization: identical two-arm catch structure to serializeOneBatchWithStats.
// null return from JNI = non-Velox backend; treated as one-shot latch, not corrupt frame.
private[execution] def serializeOneBatchV3(
jni: ColumnarBatchSerializerJniWrapper,
handle: Long,
numRows: Int,
structSchema: StructType,
includeStats: Boolean,
fallbackToV2OrLegacy: () => CachedBatch): CachedBatch = {
try {
val framed =
if (includeStats) jni.serializeWithStatsV3(handle)
else jni.serializeV3(handle)
if (framed == null) {
// Non-Velox backend returns null; set latch and fall back.
markStatsExtV3Unavailable(
new RuntimeException("framedSerializeV3 returned null (backend not supported)"))
return fallbackToV2OrLegacy()
}
val (stats, _) = CachedColumnarBatchKryoSerializer.parseFramedBytes(framed, structSchema)
// bytes = full V3 frame (C++ deserializeV3 parses from byte 0 including magic).
CachedColumnarBatch(
numRows,
framed.length,
framed,
stats,
schema = if (stats == null) null else structSchema)
} catch {
case e: UnsatisfiedLinkError =>
markStatsExtV3Unavailable(e)
fallbackToV2OrLegacy()
case NonFatal(e) =>
warnCorruptStatsFrame(e) // count against shared corrupt-frame cap
fallbackToV2OrLegacy()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -509,4 +509,35 @@ class ColumnarCachedBatchE2ESuite
}
}
}

// V3 lazy deserialization smoke tests

test("V3 default: cache + equality filter produces correct result") {
val cached = cacheRange()
try {
cached.count()
val result = cached.filter(col("k") === pivot).count()
assert(result == 1L, s"V3: expected 1 row matching k=$pivot, got $result")
} finally {
cached.unpersist()
}
}

test("V3 default: multi-column cache, partial projection, no crash") {
val cached = spark
.range(N)
.selectExpr(
"cast(id as bigint) as a",
"cast(id*2 as bigint) as b",
"cast(id+1 as bigint) as c")
.repartitionByRange(P, col("a"))
.cache()
try {
cached.count()
val result = cached.filter(col("a") === pivot).select("a", "c").count()
assert(result == 1L, s"V3 projection: expected 1 row, got $result")
} finally {
cached.unpersist()
}
}
}
Loading
Loading