Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pylib/cqlshlib/copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,18 @@ def format_value(self, val, cqltype):
float_precision=cqltype.precision, nullval=self.nullval, quote=False,
decimal_sep=self.decimal_sep, thousands_sep=self.thousands_sep,
boolean_styles=self.boolean_styles)

# Python 3.10 fixed bpo-12178: csv.writer now properly escapes the escapechar ('\')
# in all fields, including unquoted ones in QUOTE_MINIMAL mode. Before 3.10,
# csv.writer would silently write bare backslashes, relying on the pre-doubling
# performed by format_value_text/format_value_default (val.replace('\\', '\\\\'))
# so that csv.reader could restore them. In Python 3.10+, csv.writer also escapes
# those backslashes, resulting in quadruple-backslash sequences that csv.reader only
# halves — causing backslash counts to double on every COPY TO/FROM round-trip.
# Undo the pre-doubling here so that csv.writer's own escaping is the sole layer.
if sys.version_info >= (3, 10):
formatted = formatted.replace('\\\\', '\\')

return formatted

def close(self):
Expand Down
113 changes: 112 additions & 1 deletion pylib/cqlshlib/test/test_copyutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
# and $CQL_TEST_PORT to the associated port.


import csv
import io
import sys
import unittest

from cassandra.metadata import MIN_LONG, Murmur3Token
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
from unittest.mock import Mock

from cqlshlib.copyutil import ExportTask
from cqlshlib.copyutil import ExportProcess, ExportTask, ImportConversion


class CopyTaskTest(unittest.TestCase):
Expand Down Expand Up @@ -114,3 +117,111 @@ def test_get_ranges_murmur3(self):
(None, MIN_LONG + 1): {'hosts': ('10.0.0.2', '10.0.0.3', '10.0.0.4'), 'attempts': 0, 'rows': 0, 'workerno': -1}
}
self._test_get_ranges_murmur3_base({'endtoken': MIN_LONG + 1}, expected_ranges)


class TestCopyBackslashRoundtrip(unittest.TestCase):
    """
    Verifies that exporting a text value with COPY TO and reading it back with
    COPY FROM yields the original string when that string contains backslashes.
    """

    # Default COPY dialect (mirrors copyutil.py parse_options defaults)
    DEFAULT_DIALECT = {'quotechar': '"', 'escapechar': '\\', 'doublequote': False}

    def _make_export_process(self, dialect=None):
        """
        Create a bare ExportProcess instance without running its constructor,
        so no child process is spawned.

        Only the attributes read by write_rows_to_csv/format_value are
        populated. When no dialect is given, DEFAULT_DIALECT is used.
        """
        effective_dialect = dialect if dialect else self.DEFAULT_DIALECT
        attributes = {
            'formatters': {},
            'float_precision': 5,
            'double_precision': 12,
            'nullval': '',
            'encoding': 'utf-8',
            'date_time_format': Mock(),
            'decimal_sep': '.',
            'thousands_sep': '',
            'boolean_styles': ('True', 'False'),
            'options': Mock(dialect=effective_dialect),
            'report_error': Mock(),
        }

        # object.__new__ bypasses ExportProcess.__init__ entirely.
        proc = object.__new__(ExportProcess)
        for attr_name, attr_value in attributes.items():
            setattr(proc, attr_name, attr_value)
        return proc

    def _text_cqltype(self):
        """
        Return a minimal stand-in for the CQL text type accepted by format_value.

        With formatter=None, get_formatter falls back to a lookup keyed on
        type(val).__name__, which resolves to format_value_text for str values.
        """
        return Mock(formatter=None, type_name='text')

    def _export(self, proc, val):
        """
        Run the real write_rows_to_csv for a single one-column row and return
        the CSV text it produced.

        send() would normally write to a multiprocessing pipe; here it is
        replaced by a list append so the payload can be inspected.
        """
        captured = []
        proc.send = captured.append
        proc.write_rows_to_csv(token_range=None,
                               rows=[[val]],
                               cql_types=[self._text_cqltype()])
        self.assertFalse(proc.report_error.called, 'write_rows_to_csv raised an error')

        # Each captured message is (token_range, (csv_text, row_count)).
        payload = captured[0][1]
        csv_content, _row_count = payload
        return csv_content

    def _import(self, csv_content, dialect=None):
        """
        Parse the CSV text back into a text column value using the real
        import-side functions:
        csv.reader -> ImportConversion.unprotect -> convert_text
        """
        effective_dialect = dialect if dialect else self.DEFAULT_DIALECT
        reader = csv.reader(io.StringIO(csv_content), **effective_dialect)
        parsed_rows = list(reader)
        first_field = parsed_rows[0][0]
        # str(...) mirrors ImportConversion._get_converter
        return str(ImportConversion.unprotect(first_field))

    def _roundtrip(self, original, dialect=None):
        """Push a value through COPY TO -> CSV text -> COPY FROM using real functions."""
        exporter = self._make_export_process(dialect)
        exported_csv = self._export(exporter, original)
        return self._import(exported_csv, dialect)

    def test_no_backslash(self):
        plain = 'hello world'
        self.assertEqual(plain, self._roundtrip(plain))

    def test_single_backslash(self):
        value = 'a\\b'
        self.assertEqual(value, self._roundtrip(value))

    def test_url_with_backslashes(self):
        url = 'https:\\/\\/apache.org'
        result = self._roundtrip(url)
        self.assertEqual(url, result)

    def test_multiple_consecutive_backslashes(self):
        value = 'a\\\\b'
        self.assertEqual(value, self._roundtrip(value))

    def test_backslash_at_start(self):
        value = '\\hello'
        self.assertEqual(value, self._roundtrip(value))

    def test_backslash_at_end(self):
        value = 'hello\\'
        self.assertEqual(value, self._roundtrip(value))

    def test_only_backslashes(self):
        value = '\\\\\\'
        self.assertEqual(value, self._roundtrip(value))

    def test_empty_string(self):
        value = ''
        self.assertEqual(value, self._roundtrip(value))

    def test_backslash_before_delimiter(self):
        """A backslash directly before a comma must survive the round-trip."""
        value = 'a\\,b'
        self.assertEqual(value, self._roundtrip(value))

    def test_backslash_before_quotechar(self):
        """A backslash directly before the quotechar must not corrupt the field."""
        value = 'say \\"hi\\"'
        self.assertEqual(value, self._roundtrip(value))

    def test_roundtrip_is_idempotent(self):
        """Running COPY TO/FROM twice in a row must give the same value both times."""
        original = 'https:\\/\\/example.com\\path'
        first_pass = self._roundtrip(original)
        second_pass = self._roundtrip(first_pass)
        self.assertEqual(original, first_pass,
                         'First round-trip changed the value')
        self.assertEqual(first_pass, second_pass,
                         'Second round-trip changed the value (non-idempotent)')