8 changes: 7 additions & 1 deletion python/docs/source/tutorial/sql/python_data_source.rst
@@ -309,7 +309,13 @@ This is the same dummy streaming reader that generates 2 rows every batch implem
def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
"""
Takes the start offset as input and returns an iterator of tuples and
the start offset of next read.
the end offset (the start offset for the next read). The end offset must
advance past the start offset whenever data is returned; otherwise Spark
raises a validation exception.
For example, returning 2 records from start_idx 0 means end should
be {"offset": 2} (i.e. start + 2).
When there is no data to read, the end offset may equal the start offset,
but the returned iterator must be empty.
"""
start_idx = start["offset"]
it = iter([(i,) for i in range(start_idx, start_idx + 2)])
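As a hedged illustration of the contract described in the docstring above (CountingStreamReader and MAX_OFFSET are placeholder names, not part of this change): a reader that emits N records advances the end offset by N, and returns the start offset unchanged only together with an empty iterator.

from typing import Iterator, Tuple

from pyspark.sql.datasource import SimpleDataSourceStreamReader

MAX_OFFSET = 10  # pretend the underlying source only ever holds 10 rows


class CountingStreamReader(SimpleDataSourceStreamReader):
    """Illustrative reader that emits up to 2 rows per batch and advances the offset accordingly."""

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        start_idx = start["offset"]
        end_idx = min(start_idx + 2, MAX_OFFSET)
        rows = [(i,) for i in range(start_idx, end_idx)]
        if not rows:
            # No new data: end may equal start, but only with an empty iterator.
            return iter([]), start
        # Data returned: the end offset points just past the last emitted record.
        return iter(rows), {"offset": end_idx}

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        # Deterministic replay of the records between two offsets.
        return iter([(i,) for i in range(start["offset"], end["offset"])])

    def commit(self, end: dict) -> None:
        pass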
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -1185,6 +1185,11 @@
"SparkContext or SparkSession should be created first."
]
},
"SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE": {
"message": [
"SimpleDataSourceStreamReader.read() returned a non-empty batch but the end offset: <end_offset> did not advance past the start offset: <start_offset>. The end offset must represent the position after the last record returned."
]
},
"SLICE_WITH_STEP": {
"message": [
"Slice with step is not supported."
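For context, a rough sketch of how this new error condition surfaces through the existing PySparkException API (the offset values below are invented for illustration); the <start_offset> and <end_offset> placeholders are filled from messageParameters:

from pyspark.errors import PySparkException

exc = PySparkException(
    errorClass="SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE",
    messageParameters={"start_offset": '{"offset": 0}', "end_offset": '{"offset": 0}'},
)
print(exc.getCondition())          # SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE
print(exc.getMessageParameters())  # {'start_offset': '{"offset": 0}', 'end_offset': '{"offset": 0}'}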
26 changes: 25 additions & 1 deletion python/pyspark/sql/datasource_internal.py
@@ -93,14 +93,38 @@ def getDefaultReadLimit(self) -> ReadLimit:
# We do not consider providing different read limit on simple stream reader.
return ReadAllAvailable()

def add_result_to_cache(self, start: dict, end: dict, it: Iterator[Tuple]) -> None:
"""
Validates that read() did not return a non-empty batch with an end offset
equal to the start offset, which would cause the same batch to be processed
repeatedly. When end != start, the result is appended to the cache; when
end == start with an empty iterator, nothing is cached (avoiding unbounded cache growth).
"""
start_str = json.dumps(start)
end_str = json.dumps(end)
if end_str != start_str:
self.cache.append(PrefetchedCacheEntry(start, end, it))
return
try:
next(it)
except StopIteration:
return
raise PySparkException(
errorClass="SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE",
messageParameters={
"start_offset": start_str,
"end_offset": end_str,
},
)

def latestOffset(self, start: dict, limit: ReadLimit) -> dict:
assert start is not None, "start offset should not be None"
assert isinstance(
limit, ReadAllAvailable
), "simple stream reader does not support read limit"

(iter, end) = self.simple_reader.read(start)
self.cache.append(PrefetchedCacheEntry(start, end, iter))
self.add_result_to_cache(start, end, iter)
return end

def commit(self, end: dict) -> None:
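A minimal standalone sketch of the peek-based check used in add_result_to_cache above (offsets_advanced_correctly is a hypothetical helper, not part of this change): comparing the JSON-serialized offsets and probing the iterator with next() distinguishes a legitimately empty batch from one that failed to advance the offset. Probing consumes one element, which is acceptable in the wrapper because the batch is discarded when the exception is raised.

import json
from typing import Iterator, Tuple


def offsets_advanced_correctly(start: dict, end: dict, it: Iterator[Tuple]) -> bool:
    """Return True if (start, end, it) satisfies the simple stream reader contract."""
    if json.dumps(end) != json.dumps(start):
        return True   # offset advanced: always acceptable
    try:
        next(it)      # peek: does the supposedly empty batch contain data?
    except StopIteration:
        return True   # end == start with no data: allowed
    return False      # end == start but data present: contract violation


# Illustrative checks:
assert offsets_advanced_correctly({"offset": 0}, {"offset": 2}, iter([(0,), (1,)]))
assert offsets_advanced_correctly({"offset": 0}, {"offset": 0}, iter([]))
assert not offsets_advanced_correctly({"offset": 0}, {"offset": 0}, iter([(0,)]))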
55 changes: 55 additions & 0 deletions python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -41,6 +41,7 @@
have_pyarrow,
pyarrow_requirement_message,
)
from pyspark.errors import PySparkException
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.utils import eventually
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -509,6 +510,60 @@ def check_batch(df, batch_id):
q.awaitTermination(timeout=30)
self.assertIsNone(q.exception(), "No exception should be propagated.")

def test_simple_stream_reader_offset_did_not_advance_raises(self):
"""Validate that returning end == start with non-empty data raises SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE."""
from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper

class BuggySimpleStreamReader(SimpleDataSourceStreamReader):
def initialOffset(self):
return {"offset": 0}

def read(self, start: dict):
# Bug: return same offset as end despite returning data
start_idx = start["offset"]
it = iter([(i,) for i in range(start_idx, start_idx + 3)])
return (it, start)

def readBetweenOffsets(self, start: dict, end: dict):
return iter([])

def commit(self, end: dict):
pass

reader = BuggySimpleStreamReader()
wrapper = _SimpleStreamReaderWrapper(reader)
with self.assertRaises(PySparkException) as cm:
wrapper.latestOffset({"offset": 0}, ReadAllAvailable())
self.assertEqual(
cm.exception.getCondition(),
"SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE",
)

def test_simple_stream_reader_empty_iterator_start_equals_end_allowed(self):
"""When read() returns end == start with an empty iterator, no exception and no cache entry."""
from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper

class EmptyBatchReader(SimpleDataSourceStreamReader):
def initialOffset(self):
return {"offset": 0}

def read(self, start: dict):
# Valid: same offset as end but empty iterator (no data)
return (iter([]), start)

def readBetweenOffsets(self, start: dict, end: dict):
return iter([])

def commit(self, end: dict):
pass

reader = EmptyBatchReader()
wrapper = _SimpleStreamReaderWrapper(reader)
start = {"offset": 0}
end = wrapper.latestOffset(start, ReadAllAvailable())
self.assertEqual(end, start)
self.assertEqual(len(wrapper.cache), 0)

def test_stream_writer(self):
input_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_input")
output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")