Skip to content

Commit 13987de

Browse files
authored
Remove Empty Chunks in Cloudfetch concatenation (#814)
* Removed the unnecessary chunk * Added tests
1 parent 0309e7c commit 13987de

2 files changed

Lines changed: 158 additions & 5 deletions

File tree

src/databricks/sql/result_set.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,21 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
306306
"""
307307
if size < 0:
308308
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
309+
310+
# Hold 0-row chunks aside instead of appending them to ``partial_result_chunks``.
311+
# CloudFetchQueue may return a placeholder empty table whose schema does not
312+
# match the real downloaded chunks; concatenating it would corrupt the result.
313+
partial_result_chunks: List["pyarrow.Table"] = []
314+
zero_row_table: Optional["pyarrow.Table"] = None
315+
n_remaining_rows = size
316+
309317
results = self.results.next_n_rows(size)
310-
partial_result_chunks = [results]
311-
n_remaining_rows = size - results.num_rows
312-
self._next_row_index += results.num_rows
318+
if results.num_rows == 0:
319+
zero_row_table = results
320+
else:
321+
partial_result_chunks.append(results)
322+
n_remaining_rows -= results.num_rows
323+
self._next_row_index += results.num_rows
313324

314325
while (
315326
n_remaining_rows > 0
@@ -318,10 +329,14 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
318329
):
319330
self._fill_results_buffer()
320331
partial_results = self.results.next_n_rows(n_remaining_rows)
332+
if partial_results.num_rows == 0:
333+
continue
321334
partial_result_chunks.append(partial_results)
322335
n_remaining_rows -= partial_results.num_rows
323336
self._next_row_index += partial_results.num_rows
324337

338+
if not partial_result_chunks:
339+
partial_result_chunks.append(zero_row_table)
325340
return concat_table_chunks(partial_result_chunks)
326341

327342
def fetchmany_columnar(self, size: int):
@@ -351,15 +366,30 @@ def fetchmany_columnar(self, size: int):
351366

352367
def fetchall_arrow(self) -> "pyarrow.Table":
353368
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
369+
# Hold 0-row chunks aside instead of appending them to ``partial_result_chunks``.
370+
# CloudFetchQueue may return a placeholder empty table whose schema does not
371+
# match the real downloaded chunks; concatenating it would corrupt the result.
372+
partial_result_chunks: List = []
373+
zero_row_table: Optional["pyarrow.Table"] = None
374+
354375
results = self.results.remaining_rows()
355-
self._next_row_index += results.num_rows
356-
partial_result_chunks = [results]
376+
if results.num_rows == 0:
377+
zero_row_table = results
378+
else:
379+
partial_result_chunks.append(results)
380+
self._next_row_index += results.num_rows
381+
357382
while not self.has_been_closed_server_side and self.has_more_rows:
358383
self._fill_results_buffer()
359384
partial_results = self.results.remaining_rows()
385+
if partial_results.num_rows == 0:
386+
continue
360387
partial_result_chunks.append(partial_results)
361388
self._next_row_index += partial_results.num_rows
362389

390+
if not partial_result_chunks:
391+
partial_result_chunks.append(zero_row_table)
392+
363393
result_table = concat_table_chunks(partial_result_chunks)
364394
# If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
365395
# Valid only for metadata commands result set

tests/unit/test_fetches.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@
1414
from databricks.sql.result_set import ThriftResultSet
1515

1616

17+
class _StubArrowQueue:
18+
"""Minimal queue that hands back a pre-built pyarrow.Table once.
19+
20+
Used to inject a schemaless / wrong-schema placeholder that the real
21+
ArrowQueue would never produce — this is what CloudFetchQueue emits
22+
when ``self.table is None`` and ``schema_bytes`` is missing.
23+
"""
24+
25+
def __init__(self, table):
26+
self._table = table
27+
self._consumed = False
28+
29+
def _take(self):
30+
if self._consumed:
31+
return self._table.slice(0, 0)
32+
self._consumed = True
33+
return self._table
34+
35+
def next_n_rows(self, num_rows):
36+
return self._take()
37+
38+
def remaining_rows(self):
39+
return self._take()
40+
41+
def close(self):
42+
pass
43+
44+
1745
@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
1846
class FetchTests(unittest.TestCase):
1947
"""
@@ -110,6 +138,39 @@ def fetch_results(
110138
)
111139
return rs
112140

141+
@staticmethod
142+
def make_dummy_result_set_from_queue_list(queue_list, description=None):
143+
"""Like make_dummy_result_set_from_batch_list but yields pre-built queues.
144+
145+
Lets tests inject queues whose returned tables have an arbitrary schema
146+
(or no schema at all) — needed to reproduce the CloudFetch placeholder
147+
case that ``ArrowQueue`` would never produce.
148+
"""
149+
queue_index = 0
150+
151+
def fetch_results(**_):
152+
nonlocal queue_index
153+
q = queue_list[queue_index]
154+
queue_index += 1
155+
return q, queue_index < len(queue_list), 0
156+
157+
mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
158+
mock_thrift_backend.fetch_results = fetch_results
159+
160+
rs = ThriftResultSet(
161+
connection=Mock(),
162+
execute_response=ExecuteResponse(
163+
command_id=None,
164+
status=None,
165+
has_been_closed_server_side=False,
166+
description=description or [],
167+
lz4_compressed=True,
168+
is_staging_operation=False,
169+
),
170+
thrift_client=mock_thrift_backend,
171+
)
172+
return rs
173+
113174
def assertEqualRowValues(self, actual, expected):
114175
self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0)
115176
for act, exp in zip(actual, expected):
@@ -267,6 +328,68 @@ def test_fetchone_without_initial_results(self):
267328
dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2)
268329
self.assertEqual(dummy_result_set.fetchone(), None)
269330

331+
# Regression tests for fetchmany_arrow / fetchall_arrow handling of
332+
# the schemaless CloudFetch placeholder.
333+
def test_fetchall_arrow_drops_mismatched_empty_placeholder(self):
334+
# First fetch_results() call hands back a 0-row placeholder whose
335+
# schema does not match the real chunks. The second call
336+
# hands back real data.
337+
placeholder = pa.Table.from_pydict(
338+
{"stale_col": []}, schema=pa.schema({"stale_col": pa.string()})
339+
)
340+
_, real_table = self.make_arrow_table([[1], [2], [3]])
341+
rs = self.make_dummy_result_set_from_queue_list(
342+
[_StubArrowQueue(placeholder), _StubArrowQueue(real_table)],
343+
description=[("col0", "integer", None, None, None, None, None)],
344+
)
345+
346+
result = rs.fetchall_arrow()
347+
348+
self.assertEqual(result.num_rows, 3)
349+
self.assertEqual(result.schema.names, ["col0"])
350+
self.assertEqual(result.column(0).to_pylist(), [1, 2, 3])
351+
352+
def test_fetchall_arrow_all_empty_returns_zero_row_table(self):
353+
# Every queue call returns the schemaless placeholder — the
354+
# call site should fall back to zero_row_table without crashing.
355+
placeholder = pa.Table.from_pydict({})
356+
rs = self.make_dummy_result_set_from_queue_list(
357+
[_StubArrowQueue(placeholder)],
358+
)
359+
360+
result = rs.fetchall_arrow()
361+
362+
self.assertIsInstance(result, pa.Table)
363+
self.assertEqual(result.num_rows, 0)
364+
365+
def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self):
366+
# See ``test_fetchall_arrow_drops_mismatched_empty_placeholder``.
367+
placeholder = pa.Table.from_pydict(
368+
{"stale_col": []}, schema=pa.schema({"stale_col": pa.string()})
369+
)
370+
_, real_table = self.make_arrow_table([[1], [2], [3]])
371+
rs = self.make_dummy_result_set_from_queue_list(
372+
[_StubArrowQueue(placeholder), _StubArrowQueue(real_table)],
373+
description=[("col0", "integer", None, None, None, None, None)],
374+
)
375+
376+
result = rs.fetchmany_arrow(3)
377+
378+
self.assertEqual(result.num_rows, 3)
379+
self.assertEqual(result.schema.names, ["col0"])
380+
self.assertEqual(result.column(0).to_pylist(), [1, 2, 3])
381+
382+
def test_fetchmany_arrow_all_empty_returns_zero_row_table(self):
383+
placeholder = pa.Table.from_pydict({})
384+
rs = self.make_dummy_result_set_from_queue_list(
385+
[_StubArrowQueue(placeholder)],
386+
)
387+
388+
result = rs.fetchmany_arrow(10)
389+
390+
self.assertIsInstance(result, pa.Table)
391+
self.assertEqual(result.num_rows, 0)
392+
270393

271394
if __name__ == "__main__":
272395
unittest.main()

0 commit comments

Comments
 (0)