Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions internal/mysqldump/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"regexp"
"slices"
"strings"
"sync"

Expand Down Expand Up @@ -38,6 +39,8 @@ type Dumper struct {
mapBins map[string][]string
mapExclusionColumns map[string][]string
mapMu sync.RWMutex
// schemaCache stores prefetched table schemas
schemaCache map[string]*TableSchema
}

const (
Expand Down Expand Up @@ -81,6 +84,17 @@ func (d *Dumper) Dump(ctx context.Context, w io.Writer) error {
d.filterMap[strings.ToLower(table)] = IgnoreMapPlacement
}

tablesToDump := make([]string, 0, len(tables))
for _, table := range tables {
if d.filterMap[strings.ToLower(table)] != IgnoreMapPlacement {
tablesToDump = append(tablesToDump, table)
}
}

if err := d.prefetchAllSchemas(ctx, tablesToDump); err != nil {
return err
}

if _, err = fmt.Fprintln(w, dump); err != nil {
return err
}
Expand Down Expand Up @@ -115,15 +129,12 @@ func (d *Dumper) dumpTablesSequential(ctx context.Context, w io.Writer, tables [
}

skipData := d.filterMap[strings.ToLower(table)] == NoDataMapPlacement
tmp, err := d.getCreateTableStatement(ctx, table)
createStmt, err := d.getCreateTableStatement(table)
if err != nil {
return err
}

tmp = d.excludeGeneratedColumns(table, tmp)
d.parseBinaryRelations(table, tmp)

if _, err = fmt.Fprintln(w, tmp); err != nil {
if _, err = fmt.Fprintln(w, createStmt); err != nil {
return err
}

Expand Down Expand Up @@ -169,7 +180,7 @@ func (d *Dumper) dumpTablesParallel(ctx context.Context, w io.Writer, tables []s
result.index = index

skipData := d.filterMap[strings.ToLower(table)] == NoDataMapPlacement
tmp, err := d.getCreateTableStatement(ctx, table)
createStmt, err := d.getCreateTableStatement(table)
if err != nil {
result.err = err
mu.Lock()
Expand All @@ -178,11 +189,8 @@ func (d *Dumper) dumpTablesParallel(ctx context.Context, w io.Writer, tables []s
return
}

tmp = d.excludeGeneratedColumns(table, tmp)
d.parseBinaryRelations(table, tmp)

var sb strings.Builder
sb.WriteString(tmp)
sb.WriteString(createStmt)

if !skipData {
if err := d.dumpTableDataToWriter(ctx, &sb, table); err != nil {
Expand Down Expand Up @@ -345,11 +353,7 @@ func (d *Dumper) isColumnBinary(table, columnName string) bool {
val, ok := d.mapBins[table]
d.mapMu.RUnlock()
if ok {
for _, b := range val {
if b == columnName {
return true
}
}
return slices.Contains(val, columnName)
}

return false
Expand All @@ -360,11 +364,7 @@ func (d *Dumper) isColumnExcluded(table, columnName string) bool {
val, ok := d.mapExclusionColumns[table]
d.mapMu.RUnlock()
if ok {
for _, b := range val {
if b == columnName {
return true
}
}
return slices.Contains(val, columnName)
}

return false
Expand Down Expand Up @@ -634,15 +634,22 @@ func (d *Dumper) rowCount(ctx context.Context, table string) (count uint64, err
return count, nil
}

func (d *Dumper) getCreateTableStatement(ctx context.Context, table string) (string, error) {
func (d *Dumper) getCreateTableStatement(table string) (string, error) {
s := fmt.Sprintf("\n--\n-- Structure for table `%s`\n--\n\n", table)
s += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table)
row := d.useTransactionOrDBQueryRow(ctx, fmt.Sprintf("SHOW CREATE TABLE `%s`", table))
var tname, ddl string
if err := row.Scan(&tname, &ddl); err != nil {
return "", err

schema, err := d.fetchTableSchema(table)
if err != nil {
return "", fmt.Errorf("fetch table schema: %w", err)
}
s += fmt.Sprintf("%s;\n", ddl)

// Populate binary and generated column maps for data export
d.mapMu.Lock()
d.mapBins[table] = schema.GetBinaryColumns()
d.mapExclusionColumns[table] = schema.GetGeneratedColumns()
d.mapMu.Unlock()

s += schema.BuildCreateTableSQL() + ";\n"
return s, nil
}

Expand Down
105 changes: 90 additions & 15 deletions internal/mysqldump/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,33 @@ func getInternalMySQLInstance(db *sql.DB) *Dumper {
return NewMySQLDumper(db)
}

func mockPrefetchSchemas(mock sqlmock.Sqlmock) {
mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.TABLES.*TABLE_TYPE = 'BASE TABLE'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME", "ENGINE", "TABLE_COLLATION", "TABLE_COMMENT", "ROW_FORMAT", "AUTO_INCREMENT"}))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.COLUMNS.*WHERE TABLE_SCHEMA = DATABASE()").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "COLUMN_NAME", "COLUMN_TYPE", "CHARACTER_SET_NAME", "IS_NULLABLE", "COLUMN_DEFAULT",
"EXTRA", "COLLATION_NAME", "COLUMN_COMMENT", "GENERATION_EXPRESSION",
}))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.STATISTICS.*WHERE TABLE_SCHEMA = DATABASE()").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "INDEX_NAME", "COLUMN_NAME", "NON_UNIQUE", "INDEX_TYPE", "SUB_PART", "COLLATION", "INDEX_COMMENT", "SEQ_IN_INDEX",
}))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE.*").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "CONSTRAINT_NAME", "COLUMN_NAME", "REFERENCED_TABLE_NAME",
"REFERENCED_COLUMN_NAME", "UPDATE_RULE", "DELETE_RULE", "ORDINAL_POSITION",
}))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS.*CHECK_CONSTRAINTS.*").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "CONSTRAINT_NAME", "CHECK_CLAUSE",
}))
}

func TestMySQLFlushTable(t *testing.T) {
db, mock := getDB(t)
dumper := getInternalMySQLInstance(db)
Expand Down Expand Up @@ -74,32 +101,65 @@ func TestMySQLGetTablesHandlingErrorWhenScanningRow(t *testing.T) {
}

func TestMySQLDumpCreateTable(t *testing.T) {
var ddl = "CREATE TABLE `table` (" +
"`id` bigint(20) NOT NULL AUTO_INCREMENT, " +
"`name` varchar(255) NOT NULL, " +
"PRIMARY KEY (`id`), KEY `idx_name` (`name`) " +
") ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8"
db, mock := getDB(t)
dumper := getInternalMySQLInstance(db)
mock.ExpectQuery("SHOW CREATE TABLE `table`").WillReturnRows(
sqlmock.NewRows([]string{"Table", "Create Table"}).
AddRow("table", ddl),
)
str, err := dumper.getCreateTableStatement(t.Context(), "table")

// Mock batch prefetch queries
mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.TABLES.*TABLE_TYPE = 'BASE TABLE'").
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME", "ENGINE", "TABLE_COLLATION", "TABLE_COMMENT", "ROW_FORMAT", "AUTO_INCREMENT"}).
AddRow("table", "InnoDB", "utf8mb4_unicode_ci", "", "Dynamic", nil))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.COLUMNS.*WHERE TABLE_SCHEMA = DATABASE()").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "COLUMN_NAME", "COLUMN_TYPE", "CHARACTER_SET_NAME", "IS_NULLABLE", "COLUMN_DEFAULT",
"EXTRA", "COLLATION_NAME", "COLUMN_COMMENT", "GENERATION_EXPRESSION",
}).
AddRow("table", "id", "bigint(20)", nil, "NO", nil, "AUTO_INCREMENT", nil, "", nil).
AddRow("table", "name", "varchar(255)", "utf8mb4", "NO", nil, "", "utf8mb4_unicode_ci", "", nil))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.STATISTICS.*WHERE TABLE_SCHEMA = DATABASE()").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "INDEX_NAME", "COLUMN_NAME", "NON_UNIQUE", "INDEX_TYPE", "SUB_PART", "COLLATION", "INDEX_COMMENT", "SEQ_IN_INDEX",
}).
AddRow("table", "PRIMARY", "id", 0, "BTREE", nil, "A", "", 1).
AddRow("table", "idx_name", "name", 1, "BTREE", nil, "A", "", 1))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE.*").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "CONSTRAINT_NAME", "COLUMN_NAME", "REFERENCED_TABLE_NAME",
"REFERENCED_COLUMN_NAME", "UPDATE_RULE", "DELETE_RULE", "ORDINAL_POSITION",
}))

mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS.*CHECK_CONSTRAINTS.*").
WillReturnRows(sqlmock.NewRows([]string{
"TABLE_NAME", "CONSTRAINT_NAME", "CHECK_CLAUSE",
}))

// Prefetch schemas first
err := dumper.prefetchAllSchemas(t.Context(), []string{"table"})
assert.Nil(t, err)

str, err := dumper.getCreateTableStatement("table")

assert.Nil(t, err)
assert.Contains(t, str, "DROP TABLE IF EXISTS `table`")
assert.Contains(t, str, ddl)
assert.Contains(t, str, "CREATE TABLE `table`")
assert.Contains(t, str, "`id` bigint(20) NOT NULL AUTO_INCREMENT")
assert.Contains(t, str, "`name` varchar(255) NOT NULL")
assert.Contains(t, str, "PRIMARY KEY (`id`)")
assert.Contains(t, str, "KEY `idx_name` (`name`)")
assert.Contains(t, str, "ENGINE=InnoDB")
}

func TestMySQLDumpCreateTableHandlingErrorWhenScanningRows(t *testing.T) {
db, mock := getDB(t)
dumper := getInternalMySQLInstance(db)
mock.ExpectQuery("SHOW CREATE TABLE `table`").WillReturnRows(
sqlmock.NewRows([]string{"Table", "Create Table"}).AddRow("table", nil),
)

_, err := dumper.getCreateTableStatement(t.Context(), "table")
// Return an error from the first batch query (table metadata)
mock.ExpectQuery("SELECT.*FROM INFORMATION_SCHEMA.TABLES.*TABLE_TYPE = 'BASE TABLE'").
WillReturnError(errors.New("table not found"))

err := dumper.prefetchAllSchemas(t.Context(), []string{"table"})
assert.NotNil(t, err)
}

Expand Down Expand Up @@ -671,6 +731,9 @@ func Test_mySQL_ignoresTable(t *testing.T) {
AddRow("OLD_table", "BASE TABLE"),
)

// Mock batch prefetch queries (table is ignored, so no schemas to prefetch)
mockPrefetchSchemas(mock)

// Expect SHOW FULL TABLES query for views
mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows(
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
Expand Down Expand Up @@ -706,6 +769,9 @@ func Test_mySQL_dumpsTriggers(t *testing.T) {
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
)

// Mock batch prefetch queries (no tables)
mockPrefetchSchemas(mock)

// Expect SHOW FULL TABLES query for views
mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows(
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
Expand Down Expand Up @@ -755,6 +821,9 @@ func Test_mySQL_dumpsTriggersIgnoresDefiners(t *testing.T) {
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
)

// Mock batch prefetch queries (no tables)
mockPrefetchSchemas(mock)

mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows(
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
)
Expand Down Expand Up @@ -855,6 +924,9 @@ func Test_mySQL_dumpsViews(t *testing.T) {
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
)

// Mock batch prefetch queries (no tables)
mockPrefetchSchemas(mock)

mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows(
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}).
AddRow("user_view", "VIEW"),
Expand Down Expand Up @@ -900,6 +972,9 @@ func Test_mySQL_dumpsViewsIgnoresDefiners(t *testing.T) {
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}),
)

// Mock batch prefetch queries (no tables)
mockPrefetchSchemas(mock)

mock.ExpectQuery("SHOW FULL TABLES").WillReturnRows(
sqlmock.NewRows([]string{"Tables_in_database", "Table_type"}).
AddRow("user_view", "VIEW"),
Expand Down
Loading