batch_writer.go: 76 changes (64 additions, 12 deletions)
@@ -7,6 +7,7 @@ import (

sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/go-mysql-org/go-mysql/schema"
"github.com/sirupsen/logrus"
)

@@ -56,14 +57,65 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
return nil
}

startPaginationKeypos, err := values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
var startPaginationKeypos, endPaginationKeypos PaginationKey
var err error

paginationColumn := batch.TableSchema().GetPaginationColumn()

endPaginationKeypos, err := values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
var startValue, endValue uint64
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
startPaginationKeypos = NewUint64Key(startValue)
endPaginationKeypos = NewUint64Key(endValue)

case schema.TYPE_BINARY, schema.TYPE_STRING:
startValueInterface := values[0][batch.PaginationKeyIndex()]
endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()]

getBytes := func(val interface{}) ([]byte, error) {
switch v := val.(type) {
case []byte:
return v, nil
case string:
return []byte(v), nil
default:
return nil, fmt.Errorf("expected binary/string pagination key, got %T", val)
}
}

startValue, err := getBytes(startValueInterface)
if err != nil {
return err
}

endValue, err := getBytes(endValueInterface)
if err != nil {
return err
}

startPaginationKeypos = NewBinaryKey(startValue)
endPaginationKeypos = NewBinaryKey(endValue)

default:
var startValue, endValue uint64
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
if err != nil {
return err
}
startPaginationKeypos = NewUint64Key(startValue)
endPaginationKeypos = NewUint64Key(endValue)
}

db := batch.TableSchema().Schema
@@ -78,12 +130,12 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {

query, args, err := batch.AsSQLQuery(db, table)
if err != nil {
return fmt.Errorf("during generating sql query at paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, err)
return fmt.Errorf("during generating sql query at paginationKey %s -> %s: %v", startPaginationKeypos.String(), endPaginationKeypos.String(), err)
}

stmt, err := w.stmtCache.StmtFor(w.DB, query)
if err != nil {
return fmt.Errorf("during prepare query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during prepare query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

tx, err := w.DB.Begin()
@@ -94,14 +146,14 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
_, err = tx.Stmt(stmt).Exec(args...)
if err != nil {
tx.Rollback()
return fmt.Errorf("during exec query near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during exec query near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

if w.InlineVerifier != nil {
mismatches, err := w.InlineVerifier.CheckFingerprintInline(tx, db, table, batch, w.EnforceInlineVerification)
if err != nil {
tx.Rollback()
return fmt.Errorf("during fingerprint checking for paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during fingerprint checking for paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

if w.EnforceInlineVerification {
@@ -119,7 +171,7 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
err = tx.Commit()
if err != nil {
tx.Rollback()
return fmt.Errorf("during commit near paginationKey %v -> %v (%s): %v", startPaginationKeypos, endPaginationKeypos, query, err)
return fmt.Errorf("during commit near paginationKey %s -> %s (%s): %v", startPaginationKeypos.String(), endPaginationKeypos.String(), query, err)
}

// Note that the state tracker expects us to track based on the original
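The batch writer changes above rely on a `PaginationKey` abstraction (`NewUint64Key`, `NewBinaryKey`, and a `String()` method) whose definition is not part of this excerpt. Below is a minimal sketch of that abstraction, assuming only what the diff actually uses; the interface shape and the hex rendering of binary keys are illustrative assumptions, not the PR's real definition.

```go
package main

import (
	"encoding/hex"
	"fmt"
	"strconv"
)

// PaginationKey abstracts over numeric and binary pagination cursors.
// The batch writer above only needs String() for its error messages.
type PaginationKey interface {
	String() string
}

type uint64Key uint64

func (k uint64Key) String() string { return strconv.FormatUint(uint64(k), 10) }

// NewUint64Key wraps an integer pagination value (TYPE_NUMBER, TYPE_MEDIUM_INT).
func NewUint64Key(v uint64) PaginationKey { return uint64Key(v) }

type binaryKey []byte

// String renders the key as hex; the PR's actual rendering is not shown in this diff.
func (k binaryKey) String() string { return "0x" + hex.EncodeToString(k) }

// NewBinaryKey wraps a binary/string pagination value (TYPE_BINARY, TYPE_STRING).
func NewBinaryKey(b []byte) PaginationKey { return binaryKey(b) }

func main() {
	fmt.Println(NewUint64Key(42).String())            // "42"
	fmt.Println(NewBinaryKey([]byte("abc")).String()) // "0x616263"
}
```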
compression_verifier.go: 52 changes (37 additions, 15 deletions)
@@ -49,6 +49,7 @@ func (e UnsupportedCompressionError) Error() string {
type CompressionVerifier struct {
logger *logrus.Entry

TableSchemaCache TableSchemaCache
supportedAlgorithms map[string]struct{}
tableColumnCompressions TableColumnCompressionConfig
}
@@ -59,32 +60,52 @@ type CompressionVerifier struct {
// The GetCompressedHashes method checks if the existing table contains compressed data
// and will apply the decompression algorithm to the applicable columns if necessary.
// After the columns are decompressed, the hashes of the data are used to verify equality
func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (map[uint64][]byte, error) {
func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schemaName, tableName, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (map[string][]byte, error) {
c.logger.WithFields(logrus.Fields{
"tag": "compression_verifier",
"table": table,
"table": tableName,
}).Info("decompressing table data before verification")

tableCompression := c.tableColumnCompressions[table]
tableCompression := c.tableColumnCompressions[tableName]

// Extract the raw rows using SQL to be decompressed
rows, err := getRows(db, schema, table, paginationKeyColumn, columns, paginationKeys)
rows, err := getRows(db, schemaName, tableName, paginationKeyColumn, columns, paginationKeys)
if err != nil {
return nil, err
}
defer rows.Close()

// Decompress applicable columns and hash the resulting column values for comparison
resultSet := make(map[uint64][]byte)
table := c.TableSchemaCache.Get(schemaName, tableName)
if table == nil {
return nil, fmt.Errorf("table %s.%s not found in schema cache", schemaName, tableName)
}
paginationColumn := table.GetPaginationColumn()
resultSet := make(map[string][]byte)

for rows.Next() {
rowData, err := ScanByteRow(rows, len(columns)+1)
if err != nil {
return nil, err
}

paginationKey, err := strconv.ParseUint(string(rowData[0]), 10, 64)
if err != nil {
return nil, err
var paginationKeyStr string
switch paginationColumn.Type {
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64)
if err != nil {
return nil, err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()

case schema.TYPE_BINARY, schema.TYPE_STRING:
paginationKeyStr = NewBinaryKey(rowData[0]).String()

default:
paginationKeyUint, err := strconv.ParseUint(string(rowData[0]), 10, 64)
if err != nil {
return nil, err
}
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
}

// Decompress the applicable columns and then hash them together
@@ -95,7 +116,7 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag
for idx, column := range columns {
if algorithm, ok := tableCompression[column.Name]; ok {
// rowData contains the result of "SELECT paginationKeyColumn, * FROM ...", so idx+1 to get each column
decompressedColData, err := c.Decompress(table, column.Name, algorithm, rowData[idx+1])
decompressedColData, err := c.Decompress(tableName, column.Name, algorithm, rowData[idx+1])
if err != nil {
return nil, err
}
@@ -111,20 +132,20 @@ func (c *CompressionVerifier) GetCompressedHashes(db *sql.DB, schema, table, pag
return nil, err
}

resultSet[paginationKey] = decompressedRowHash
resultSet[paginationKeyStr] = decompressedRowHash
}

metrics.Gauge(
"compression_verifier_decompress_rows",
float64(len(resultSet)),
[]MetricTag{{"table", table}},
[]MetricTag{{"table", tableName}},
1.0,
)

logrus.WithFields(logrus.Fields{
"tag": "compression_verifier",
"rows": len(resultSet),
"table": table,
"table": tableName,
}).Debug("decompressed rows will be compared")

return resultSet, nil
@@ -192,12 +213,13 @@ func (c *CompressionVerifier) verifyConfiguredCompression(tableColumnCompression

// NewCompressionVerifier first checks the map for supported compression algorithms before
// initializing and returning the initialized instance.
func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig) (*CompressionVerifier, error) {
func NewCompressionVerifier(tableColumnCompressions TableColumnCompressionConfig, tableSchemaCache TableSchemaCache) (*CompressionVerifier, error) {
supportedAlgorithms := make(map[string]struct{})
supportedAlgorithms[CompressionSnappy] = struct{}{}

compressionVerifier := &CompressionVerifier{
logger: logrus.WithField("tag", "compression_verifier"),
TableSchemaCache: tableSchemaCache,
supportedAlgorithms: supportedAlgorithms,
tableColumnCompressions: tableColumnCompressions,
}
@@ -209,7 +231,7 @@
return compressionVerifier, nil
}

func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []uint64) (*sqlorig.Rows, error) {
func getRows(db *sql.DB, schema, table, paginationKeyColumn string, columns []schema.TableColumn, paginationKeys []interface{}) (*sqlorig.Rows, error) {
quotedPaginationKey := QuoteField(paginationKeyColumn)
sql, args, err := rowSelector(columns, paginationKeyColumn).
From(QuotedTableNameFromString(schema, table)).
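As context for the decompress-then-hash flow that `GetCompressedHashes` performs per row, here is a minimal, self-contained sketch assuming snappy-compressed columns (the only algorithm registered above). The MD5 digest is a stand-in for illustration; the hash actually used by Ghostferry's verifier is defined elsewhere and is not shown in this diff.

```go
package main

import (
	"crypto/md5"
	"fmt"

	"github.com/golang/snappy"
)

// hashDecompressedRow decompresses each snappy-compressed column value and
// hashes the concatenation, mirroring the per-row work in GetCompressedHashes.
func hashDecompressedRow(compressedCols [][]byte) ([]byte, error) {
	h := md5.New()
	for _, col := range compressedCols {
		decompressed, err := snappy.Decode(nil, col)
		if err != nil {
			return nil, err
		}
		h.Write(decompressed)
	}
	return h.Sum(nil), nil
}

func main() {
	compressed := snappy.Encode(nil, []byte("hello, world"))
	sum, err := hashDecompressedRow([][]byte{compressed})
	if err != nil {
		panic(err)
	}
	fmt.Printf("row hash: %x\n", sum)
}
```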
config.go: 16 changes (13 additions, 3 deletions)
@@ -376,12 +376,17 @@ func (c ForceIndexConfig) IndexFor(schemaName, tableName string) string {
// CascadingPaginationColumnConfig configures the pagination columns to be
// used. The term `Cascading` denotes that greater specificity takes
// precedence.
//
// IMPORTANT: All configured pagination columns must contain unique values.
// When specifying a FallbackColumn for tables with composite primary keys,
// ensure the column has a unique constraint to prevent data loss during migration.
type CascadingPaginationColumnConfig struct {
// PerTable has greatest specificity and takes precedence over the other options
PerTable map[string]map[string]string // SchemaName => TableName => ColumnName

// FallbackColumn is a global default to fall back to and is less specific than the
// default, which is the Primary Key
// default, which is the Primary Key.
// This column MUST have unique values (ideally a unique constraint) for data integrity.
FallbackColumn string
}
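For illustration, a hypothetical wiring of this struct is shown below. The schema, table, and column names are made up, only the two fields defined above are used, and the import path assumes the package is consumed as `github.com/Shopify/ghostferry`.

```go
package main

import (
	"fmt"

	"github.com/Shopify/ghostferry"
)

func main() {
	// Prefer a per-table pagination column where one is configured; otherwise
	// fall back to "id", which must hold unique values in every table that uses it.
	paginationConfig := &ghostferry.CascadingPaginationColumnConfig{
		PerTable: map[string]map[string]string{
			"blog": {
				"users": "account_uuid", // e.g. a binary/string pagination key
			},
		},
		FallbackColumn: "id",
	}
	fmt.Printf("%+v\n", paginationConfig)
}
```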

@@ -727,10 +732,15 @@ type Config struct {
//
ForceIndexForVerification ForceIndexConfig

// Ghostferry requires a single numeric column to paginate over tables. Inferring that column is done in the following exact order:
// Ghostferry requires a single numeric or binary column to paginate over tables. Inferring that column is done in the following exact order:
// 1. Use the PerTable pagination column, if configured for a table. Fail if we cannot find this column in the table.
// 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric or is a composite key without a FallbackColumn specified.
// 2. Use the table's primary key column as the pagination column. Fail if the primary key is not numeric/binary or is a composite key without a FallbackColumn specified.
// 3. Use the FallbackColumn pagination column, if configured. Fail if we cannot find this column in the table.
//
// IMPORTANT: The pagination column MUST contain unique values for data integrity.
// When using a FallbackColumn (typically "id") for tables with composite primary keys, this column must have a unique constraint.
// The pagination algorithm uses WHERE pagination_key > last_key ORDER BY pagination_key LIMIT batch_size.
// If duplicate values exist, rows may be skipped during iteration, resulting in data loss during the migration.
CascadingPaginationColumnConfig *CascadingPaginationColumnConfig

// SkipTargetVerification is used to enable or disable target verification during moves.
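To make the uniqueness requirement concrete, here is a sketch of the keyset-pagination loop described in the config.go comment above, against a hypothetical `source_table` with an integer `id` pagination column. Because each batch starts strictly after the last key seen (`>`), any rows sharing that key value with the end of the previous batch are never selected, which is the data-loss case the comment warns about.

```go
package ghostferryexample

import "database/sql"

// copyAllRows sketches the WHERE pagination_key > last_key ORDER BY ... LIMIT
// loop. Table and column names are hypothetical; a real caller would also need
// a registered MySQL driver and target-write logic.
func copyAllRows(db *sql.DB, batchSize int) error {
	lastKey := uint64(0)
	for {
		rows, err := db.Query(
			"SELECT id, data FROM source_table WHERE id > ? ORDER BY id LIMIT ?",
			lastKey, batchSize,
		)
		if err != nil {
			return err
		}

		count := 0
		for rows.Next() {
			var id uint64
			var data []byte
			if err := rows.Scan(&id, &data); err != nil {
				rows.Close()
				return err
			}
			// ... write (id, data) to the target here ...
			lastKey = id // the next batch starts strictly after this value
			count++
		}
		if err := rows.Err(); err != nil {
			rows.Close()
			return err
		}
		rows.Close()

		if count < batchSize {
			return nil // final, partial batch: iteration is complete
		}
	}
}
```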