Skip to content

Commit 35ba1c4

Browse files
committed
Update RollbackAll func
1 parent eb49b7d commit 35ba1c4

File tree

1 file changed

+25
-51
lines changed

1 file changed

+25
-51
lines changed

pkg/dbmate/db.go

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -442,65 +442,31 @@ func (db *DB) MigrateNext() error {
442442

443443
// RollbackAll rolls back every applied migration (latest-first) until none remain
444444
func (db *DB) RollbackAll() error {
445-
drv, err := db.Driver()
446-
if err != nil {
447-
return err
448-
}
449-
450-
sqlDB, err := db.openDatabaseForMigration(drv)
451-
if err != nil {
452-
return err
453-
}
454-
defer dbutil.MustClose(sqlDB)
455-
456445
migrations, err := db.FindMigrations()
457446
if err != nil {
458447
return err
459448
}
460-
461-
applied := make([]*Migration, 0, len(migrations))
449+
var appliedVersions []string
462450
for i := len(migrations) - 1; i >= 0; i-- {
463451
if migrations[i].Applied {
464-
applied = append(applied, &migrations[i])
452+
appliedVersions = append(appliedVersions, migrations[i].Version)
465453
}
466454
}
467-
if len(applied) == 0 {
455+
if len(appliedVersions) == 0 {
468456
return ErrNoRollback
469457
}
458+
origAuto := db.AutoDumpSchema
459+
db.AutoDumpSchema = false
460+
defer func() { db.AutoDumpSchema = origAuto }()
470461

471-
for _, mig := range applied {
472-
fmt.Fprintf(db.Log, "Rolling back: %s\n", mig.FileName)
473-
start := time.Now()
474-
parsed, err := mig.Parse()
475-
if err != nil {
476-
return err
477-
}
478-
execMigration := func(tx dbutil.Transaction) error {
479-
result, err := tx.Exec(parsed.Down)
480-
if err != nil {
481-
return drv.QueryError(parsed.Down, err)
482-
} else if db.Verbose {
483-
db.printVerbose(result)
484-
}
485-
return drv.DeleteMigration(tx, mig.Version)
486-
}
487-
if parsed.DownOptions.Transaction() {
488-
err = doTransaction(sqlDB, execMigration)
489-
} else {
490-
err = execMigration(sqlDB)
491-
}
492-
elapsed := time.Since(start)
493-
fmt.Fprintf(db.Log, "Rolled back: %s in %s\n", mig.FileName, elapsed)
494-
495-
if err != nil {
462+
for _, v := range appliedVersions {
463+
if err := db.RollbackOnly(migrations, v); err != nil {
496464
return err
497465
}
498466
}
499-
500-
if db.AutoDumpSchema {
467+
if origAuto {
501468
_ = db.DumpSchema()
502469
}
503-
504470
return nil
505471
}
506472

@@ -656,7 +622,7 @@ func (db *DB) MigrateOnly(migrations []Migration, version string) error {
656622
if err != nil {
657623
return err
658624
}
659-
exec := func(tx dbutil.Transaction) error {
625+
execMigration := func(tx dbutil.Transaction) error {
660626
res, err := tx.Exec(parsed.Up)
661627
if err != nil {
662628
return drv.QueryError(parsed.Up, err)
@@ -667,15 +633,17 @@ func (db *DB) MigrateOnly(migrations []Migration, version string) error {
667633
}
668634

669635
if parsed.UpOptions.Transaction() {
670-
err = doTransaction(sqlDB, exec)
636+
err = doTransaction(sqlDB, execMigration)
671637
} else {
672-
err = exec(sqlDB)
638+
err = execMigration(sqlDB)
673639
}
674640

675641
fmt.Fprintf(db.Log, "Applied: %s in %s\n", target.FileName, time.Since(start))
676642
if err != nil {
677643
return err
678644
}
645+
646+
// automatically update schema file, silence errors
679647
if db.AutoDumpSchema {
680648
_ = db.DumpSchema()
681649
}
@@ -716,26 +684,32 @@ func (db *DB) RollbackOnly(migrations []Migration, version string) error {
716684
if err != nil {
717685
return err
718686
}
719-
exec := func(tx dbutil.Transaction) error {
720-
res, err := tx.Exec(parsed.Down)
687+
execMigration := func(tx dbutil.Transaction) error {
688+
result, err := tx.Exec(parsed.Down)
721689
if err != nil {
722690
return drv.QueryError(parsed.Down, err)
723691
} else if db.Verbose {
724-
db.printVerbose(res)
692+
db.printVerbose(result)
725693
}
694+
695+
// remove migration record
726696
return drv.DeleteMigration(tx, target.Version)
727697
}
728698

729699
if parsed.DownOptions.Transaction() {
730-
err = doTransaction(sqlDB, exec)
700+
// begin transaction
701+
err = doTransaction(sqlDB, execMigration)
731702
} else {
732-
err = exec(sqlDB)
703+
// run outside of transaction
704+
err = execMigration(sqlDB)
733705
}
734706

735707
fmt.Fprintf(db.Log, "Rolled back: %s in %s\n", target.FileName, time.Since(start))
736708
if err != nil {
737709
return err
738710
}
711+
712+
// automatically update schema file, silence errors
739713
if db.AutoDumpSchema {
740714
_ = db.DumpSchema()
741715
}

0 commit comments

Comments
 (0)