Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
44c0598
refactor SQLTestUtils: set active db via sql.Catalog
Apr 15, 2026
ec178a1
refactor SQLTestUtils: use sql.Catalog to look up udf
Apr 9, 2026
95f5117
refactor ResolverGuardSuite: use equivalent, but sql-visible method
Apr 13, 2026
5321c1a
refactor sql.DatasetHolder: make concrete
Apr 9, 2026
704423d
Add placeholder ClassicSQLTestUtils
Apr 15, 2026
e50e586
Add placeholder SharedClassicSparkSession
Apr 13, 2026
22d6402
TODO: classic-only impl for SQLTestUtils::waitForTasksToFinish
Apr 15, 2026
b78d6f1
TODO: classic-only impl for SQLTestUtils::makeQualifiedPath
Apr 15, 2026
519482c
TODO: Add non-classic uncache table path
Apr 13, 2026
f8b4420
TODO: add cast to SharedClassicSparkSession get classic session from …
Apr 15, 2026
0c65acc
Add AlterTableDirHelper
Apr 15, 2026
2eea2f7
use SharedClassicSparkSession, ClassicSQLTestUtils where needed
Apr 13, 2026
71c8582
SQLTestUtils: move classic-only helpers to ClassicSQLTestUtils
Apr 15, 2026
503f6e3
SQLTestUtils: less classic in testImplicits
Apr 15, 2026
a2ad429
SharedSparkSession: move classic-only helper to SharedClassicSparkSes…
Apr 15, 2026
996cfe1
SharedSparkSession: provide sql.SparkSession
Apr 13, 2026
f164493
fixup
fwc Apr 16, 2026
a2ceaeb
SQLTestUtils: remove classic stuff from testImplicits
fwc Apr 16, 2026
ac8a83d
Add and use classic.SparkSessionProvider
fwc Apr 16, 2026
45a1a56
Revert "SQLTestUtils: remove classic stuff from testImplicits"
Apr 17, 2026
051cf6a
fixup classic.SparkSessionProvider: extend sql.SparkSessionProvider
fwc Apr 18, 2026
201a267
fix import ordering
fwc Apr 18, 2026
dd52430
SQLTestImplicits: fix imports
fwc Apr 18, 2026
283b301
fixup: Use SharedClassicSparkSession
fwc Apr 18, 2026
dfe4c95
remove unused 'import testImplicits'
fwc Apr 18, 2026
722ad67
Add missing casts
Apr 20, 2026
17ba6e7
some more fixes
May 19, 2026
fcf2e40
fixup: restore upstream changes accidentally reverted by --theirs con…
May 19, 2026
df7a613
Rename ClassicSQLTestUtils to ClassicQueryTest
May 20, 2026
7914ed9
Replace ClassicSQLTestUtils usage with ClassicQueryTest
May 20, 2026
011b1a5
fixes from LLM review
May 20, 2026
54d542d
Fix another compile error
May 21, 2026
a8df7ae
Move ClassicQueryTest to test.classic.QueryTest
May 21, 2026
f376df9
Replace ClassicQueryTest usage with test.classic.QueryTest
May 21, 2026
8bf1800
one more fix
May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.kafka.common.TopicPartition

import org.apache.spark.{SparkConf, TestUtils}
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.sql.{DataFrameReader, QueryTest}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.Utils

abstract class KafkaRelationSuiteBase extends SharedSparkSession with KafkaTest {
abstract class KafkaRelationSuiteBase
extends QueryTest
with SharedClassicSparkSession
with KafkaTest {

import testImplicits._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ import org.apache.spark.annotation.Stable
* @since 1.6.0
*/
@Stable
abstract class DatasetHolder[T] {
class DatasetHolder[T](ds: Dataset[T]) {

// This is declared with parentheses to prevent the Scala compiler from treating
// `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
def toDS(): Dataset[T]
def toDS(): Dataset[T] = ds

// This is declared with parentheses to prevent the Scala compiler from treating
// `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
def toDF(): DataFrame
def toDF(): DataFrame = ds.toDF()

def toDF(colNames: String*): DataFrame
def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@ abstract class SQLImplicits extends EncoderImplicits with Serializable {
* Creates a [[Dataset]] from a local Seq.
* @since 1.6.0
*/
implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T]
implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] =
new DatasetHolder[T](session.createDataset(s))

/**
* Creates a [[Dataset]] from an RDD.
*
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T]
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T] =
new DatasetHolder[T](session.createDataset(rdd))

/**
* An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ abstract class SQLImplicits private[sql] (override val session: SparkSession)
new DatasetHolder[T](session.createDataset(rdd))
}

class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U] {
class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U](ds) {
override def toDS(): Dataset[U] = ds
override def toDF(): DataFrame = ds.toDF()
override def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connector.catalog.{CatalogManager, Column, Identifier, InMemoryChangelogCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -70,7 +70,8 @@ import org.apache.spark.util.Utils
* compatibility.
*/
// scalastyle:on
class ProtoToParsedPlanTestSuite extends SharedSparkSession with ResourceHelper {
class ProtoToParsedPlanTestSuite
extends SharedClassicSparkSession with ResourceHelper {

private val cleanOrphanedGoldenFiles: Boolean =
System.getenv("SPARK_CLEAN_ORPHANED_GOLDEN_FILES") == "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey, SparkConnectService}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.util.CloseableIterator

/**
* Base class and utilities for a test suite that starts and tests the real SparkConnectService
* with a real SparkConnectClient, communicating over RPC, but both in-process.
*/
trait SparkConnectServerTest extends SharedSparkSession {
trait SparkConnectServerTest extends SharedClassicSparkSession {

// Server port
val serverPort: Int =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimeType}
import org.apache.spark.unsafe.types.UTF8String

/**
* Testing trait for SparkConnect tests with some helper methods to make it easier to create new
* test cases.
*/
trait SparkConnectPlanTest extends SharedSparkSession {
trait SparkConnectPlanTest extends SharedClassicSparkSession {
def transform(rel: proto.Relation): logical.LogicalPlan = {
SparkConnectPlannerTestUtils.transform(spark, rel)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteKey, ExecuteStatus, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted}
import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
Expand All @@ -61,7 +61,7 @@ import org.apache.spark.util.Utils
* Testing Connect Service implementation.
*/
class SparkConnectServiceSuite
extends SharedSparkSession
extends SharedClassicSparkSession
with MockitoSugar
with Logging
with SparkConnectPlanTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class StreamingForeachBatchHelperSuite extends SharedSparkSession with MockitoSugar {
class StreamingForeachBatchHelperSuite extends SharedClassicSparkSession with MockitoSugar {

private def mockQuery(): StreamingQuery = {
val query = mock[StreamingQuery]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest}
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class DummyPlugin extends RelationPlugin {
override def transform(
Expand Down Expand Up @@ -119,7 +119,7 @@ class ExampleCommandPlugin extends CommandPlugin {
}
}

class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConnectPlanTest {
class SparkConnectPluginRegistrySuite extends SharedClassicSparkSession with SparkConnectPlanTest {

override def beforeEach(): Unit = {
if (SparkEnv.get.conf.contains(Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.sql.connect.ResourceHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.{ThreadUtils, Utils}

class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
class AddArtifactsHandlerSuite extends SharedClassicSparkSession with ResourceHelper {

private val CHUNK_SIZE: Int = 32 * 1024

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ArtifactStatusesResponse
import org.apache.spark.network.util.JavaUtils.sha256Hex
import org.apache.spark.sql.connect.ResourceHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.ThreadUtils

private class DummyStreamObserver(p: Promise[ArtifactStatusesResponse])
Expand All @@ -38,7 +38,7 @@ private class DummyStreamObserver(p: Promise[ArtifactStatusesResponse])
override def onCompleted(): Unit = {}
}

class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelper {
class ArtifactStatusesHandlerSuite extends SharedClassicSparkSession with ResourceHelper {

val sessionId = UUID.randomUUID().toString

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.GetStatusResponse
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.plugin.{GetStatusPlugin, SparkConnectPluginRegistry}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.ThreadUtils

/**
Expand Down Expand Up @@ -104,7 +104,7 @@ class FailingGetStatusPlugin extends GetStatusPlugin {
throw new RuntimeException("operation plugin failure")
}

class GetStatusHandlerSuite extends SharedSparkSession {
class GetStatusHandlerSuite extends SharedClassicSparkSession {

// Default userId matching SparkConnectTestUtils.createDummySessionHolder default
private val defaultUserId = "testUser"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ package org.apache.spark.sql.connect.service
import java.util.UUID

import org.apache.spark.SparkSQLException
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class SparkConnectCloneSessionSuite extends SharedSparkSession {
class SparkConnectCloneSessionSuite extends SharedClassicSparkSession {

override def beforeEach(): Unit = {
super.beforeEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package org.apache.spark.sql.connect.service

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

/**
* Test suite for SparkConnectExecutionManager.
*/
class SparkConnectExecutionManagerSuite extends SharedSparkSession {
class SparkConnectExecutionManagerSuite extends SharedClassicSparkSession {

protected override def afterEach(): Unit = {
super.afterEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,19 @@ import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto.{Command, ExecutePlanResponse}
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.execution.ExecuteResponseObserver
import org.apache.spark.sql.connect.planner.SparkConnectStreamingQueryListenerHandler
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
import org.apache.spark.sql.streaming.Trigger.ProcessingTime
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class SparkConnectListenerBusListenerSuite extends SharedSparkSession with MockitoSugar {
class SparkConnectListenerBusListenerSuite
extends SparkFunSuite
with SharedClassicSparkSession
with MockitoSugar {

override def afterEach(): Unit = {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.util.ArrayImplicits._

class SparkConnectSessionHolderSuite extends SharedSparkSession {
class SparkConnectSessionHolderSuite extends SharedClassicSparkSession {

test("DataFrame cache: Successful put and get") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ import org.apache.spark.SparkSQLException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession

class SparkConnectSessionManagerSuite extends SharedSparkSession {
class SparkConnectSessionManagerSuite extends SharedClassicSparkSession {

override def beforeEach(): Unit = {
super.beforeEach()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ abstract class SQLImplicits extends sql.SQLImplicits {
new DatasetHolder[T](session.createDataset(rdd))
}

class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U] {
class DatasetHolder[U](ds: Dataset[U]) extends sql.DatasetHolder[U](ds) {
override def toDS(): Dataset[U] = ds
override def toDF(): DataFrame = ds.toDF()
override def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.storage.StorageLevel.{MEMORY_AND_DISK_2, MEMORY_ONLY}
Expand All @@ -64,9 +64,10 @@ import org.apache.spark.util.{AccumulatorContext, Utils}
private case class BigData(s: String)

@SlowSQLTest
class CachedTableSuite extends SharedSparkSession
class CachedTableSuite extends QueryTest
with SharedClassicSparkSession
with AdaptiveSparkPlanHelper {
import testImplicits._
import classicTestImplicits._

override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.plans.AsOfJoinDirection
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.tags.SlowSQLTest

@SlowSQLTest
class DataFrameAsOfJoinSuite extends SharedSparkSession
class DataFrameAsOfJoinSuite extends QueryTest
with SharedClassicSparkSession
with AdaptiveSparkPlanHelper {

def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.tags.ExtendedSQLTest

/**
* Test suite for functions in [[org.apache.spark.sql.functions]].
*/
@ExtendedSQLTest
class DataFrameFunctionsSuite extends SharedSparkSession {
import testImplicits._
class DataFrameFunctionsSuite extends QueryTest with SharedClassicSparkSession {
import classicTestImplicits._

test("DataFrame function and SQL function parity") {
// This test compares the available list of DataFrame functions in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import org.apache.spark.sql.catalyst.plans.{NearestByDirection, NearestByJoinMod
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.tags.SlowSQLTest

@SlowSQLTest
class DataFrameNearestByJoinSuite extends QueryTest with SharedSparkSession {
class DataFrameNearestByJoinSuite extends QueryTest with SharedClassicSparkSession {

private def prepareForNearestByJoin(): (classic.DataFrame, classic.DataFrame) = {
val users = spark.createDataFrame(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import org.apache.spark.sql.classic.{Dataset => DatasetImpl}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, explode, sum, year}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SharedClassicSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

class DataFrameSelfJoinSuite extends SharedSparkSession {
import testImplicits._
class DataFrameSelfJoinSuite extends QueryTest with SharedClassicSparkSession {
import classicTestImplicits._

test("join - join using self join") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
Expand Down
Loading