diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 8dc9910..865814d 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -7,6 +7,7 @@ package logic import ( "context" + "errors" "fmt" "io" "math" @@ -22,6 +23,10 @@ import ( "github.com/github/gh-ost/go/sql" ) +var ( + ErrMigratorUnsupportedRenameAlter = errors.New("ALTER statement seems to RENAME the table. This is not supported, and you should run your RENAME outside gh-ost.") +) + type ChangelogState string const ( @@ -223,28 +228,22 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er case Migrated, ReadMigrationRangeValues: // no-op event case GhostTableMigrated: - { - this.ghostTableMigrated <- true - } + this.ghostTableMigrated <- true case AllEventsUpToLockProcessed: - { - var applyEventFunc tableWriteFunc = func() error { - this.allEventsUpToLockProcessed <- changelogStateString - return nil - } - // at this point we know all events up to lock have been read from the streamer, - // because the streamer works sequentially. So those events are either already handled, - // or have event functions in applyEventsQueue. - // So as not to create a potential deadlock, we write this func to applyEventsQueue - // asynchronously, understanding it doesn't really matter. - go func() { - this.applyEventsQueue <- newApplyEventStructByFunc(&applyEventFunc) - }() + var applyEventFunc tableWriteFunc = func() error { + this.allEventsUpToLockProcessed <- changelogStateString + return nil } + // at this point we know all events up to lock have been read from the streamer, + // because the streamer works sequentially. So those events are either already handled, + // or have event functions in applyEventsQueue. + // So as not to create a potential deadlock, we write this func to applyEventsQueue + // asynchronously, understanding it doesn't really matter. 
+ go func() { + this.applyEventsQueue <- newApplyEventStructByFunc(&applyEventFunc) + }() default: - { - return fmt.Errorf("Unknown changelog state: %+v", changelogState) - } + return fmt.Errorf("Unknown changelog state: %+v", changelogState) } this.migrationContext.Log.Infof("Handled changelog state %s", changelogState) return nil @@ -268,13 +267,13 @@ func (this *Migrator) listenOnPanicAbort() { this.migrationContext.Log.Fatale(err) } -// validateStatement validates the `alter` statement meets criteria. +// validateAlterStatement validates the `alter` statement meets criteria. // At this time this means: // - column renames are approved // - no table rename allowed -func (this *Migrator) validateStatement() (err error) { +func (this *Migrator) validateAlterStatement() (err error) { if this.parser.IsRenameTable() { - return fmt.Errorf("ALTER statement seems to RENAME the table. This is not supported, and you should run your RENAME outside gh-ost.") + return ErrMigratorUnsupportedRenameAlter } if this.parser.HasNonTrivialRenames() && !this.migrationContext.SkipRenamedColumns { this.migrationContext.ColumnRenameMap = this.parser.GetNonTrivialRenames() @@ -352,7 +351,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.parser.ParseAlterStatement(this.migrationContext.AlterStatement); err != nil { return err } - if err := this.validateStatement(); err != nil { + if err := this.validateAlterStatement(); err != nil { return err } @@ -903,6 +902,94 @@ func (this *Migrator) printMigrationStatusHint(writers ...io.Writer) { } } +// getProgressPercent returns an estimate of migration progress as a percent. 
+func (this *Migrator) getProgressPercent(rowsEstimate int64) (progressPct float64) {
+	progressPct = 100.0
+	if rowsEstimate > 0 {
+		progressPct *= float64(this.migrationContext.GetTotalRowsCopied()) / float64(rowsEstimate)
+	}
+	return progressPct
+}
+
+// getMigrationETA returns the estimated time left for the row copy, as a printable string and as a duration.
+func (this *Migrator) getMigrationETA(rowsEstimate int64) (eta string, duration time.Duration) {
+	duration = time.Duration(base.ETAUnknown)
+	progressPct := this.getProgressPercent(rowsEstimate)
+	if progressPct >= 100.0 {
+		duration = 0
+	} else if progressPct >= 0.1 {
+		totalRowsCopied := this.migrationContext.GetTotalRowsCopied()
+		elapsedRowCopySeconds := this.migrationContext.ElapsedRowCopyTime().Seconds()
+		totalExpectedSeconds := elapsedRowCopySeconds * float64(rowsEstimate) / float64(totalRowsCopied)
+		etaSeconds := totalExpectedSeconds - elapsedRowCopySeconds
+		if etaSeconds >= 0 {
+			duration = time.Duration(etaSeconds) * time.Second
+		} else {
+			duration = 0
+		}
+	}
+
+	switch duration {
+	case 0:
+		eta = "due"
+	case time.Duration(base.ETAUnknown):
+		eta = "N/A"
+	default:
+		eta = base.PrettifyDurationOutput(duration)
+	}
+
+	return eta, duration
+}
+
+// getMigrationStateAndETA returns the state and eta of the migration.
+func (this *Migrator) getMigrationStateAndETA(rowsEstimate int64) (state, eta string, etaDuration time.Duration) {
+	eta, etaDuration = this.getMigrationETA(rowsEstimate)
+	state = "migrating"
+	if atomic.LoadInt64(&this.migrationContext.CountingRowsFlag) > 0 && !this.migrationContext.ConcurrentCountTableRows {
+		state = "counting rows"
+	} else if atomic.LoadInt64(&this.migrationContext.IsPostponingCutOver) > 0 {
+		eta = "due"
+		state = "postponing cut-over"
+	} else if isThrottled, throttleReason, _ := this.migrationContext.IsThrottled(); isThrottled {
+		state = fmt.Sprintf("throttled, %s", throttleReason)
+	}
+	return state, eta, etaDuration
+}
+
+// shouldPrintStatus returns true when the migrator is due to print status info.
+// Any rule other than HeuristicPrintStatusRule always prints; the heuristic prints less often over time.
+func (this *Migrator) shouldPrintStatus(rule PrintStatusRule, elapsedSeconds int64, etaDuration time.Duration) (shouldPrint bool) {
+	if rule != HeuristicPrintStatusRule {
+		return true
+	}
+
+	// Treat an unknown ETA as infinitely far away so it never looks imminent.
+	etaSeconds := math.MaxFloat64
+	if etaDuration != time.Duration(base.ETAUnknown) {
+		etaSeconds = etaDuration.Seconds()
+	}
+	switch {
+	case elapsedSeconds <= 60, etaSeconds <= 60:
+		return true
+	case etaSeconds <= 180, elapsedSeconds <= 180, this.migrationContext.TimeSincePointOfInterest().Seconds() <= 60:
+		return elapsedSeconds%5 == 0
+	default:
+		return elapsedSeconds%30 == 0
+	}
+}
+
+// shouldPrintMigrationStatusHint returns true when the migrator is due to print the migration status hint:
+// always for ForcePrintStatusAndHintRule, never for ForcePrintStatusOnlyRule, else every 600 elapsed seconds.
+func (this *Migrator) shouldPrintMigrationStatusHint(rule PrintStatusRule, elapsedSeconds int64) (shouldPrint bool) {
+	switch rule {
+	case ForcePrintStatusAndHintRule:
+		return true
+	case ForcePrintStatusOnlyRule:
+		return false
+	}
+	return elapsedSeconds%600 == 0
+}
+
 // printStatus prints the progress status, and optionally additionally detailed
 // dump of configuration.
 // `rule` indicates the type of output expected.
@@ -923,81 +1010,21 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { // and there is no further need to keep updating the value. rowsEstimate = totalRowsCopied } - var progressPct float64 - if rowsEstimate == 0 { - progressPct = 100.0 - } else { - progressPct = 100.0 * float64(totalRowsCopied) / float64(rowsEstimate) - } + // we take the opportunity to update migration context with progressPct + progressPct := this.getProgressPercent(rowsEstimate) this.migrationContext.SetProgressPct(progressPct) + // Before status, let's see if we should print a nice reminder for what exactly we're doing here. - shouldPrintMigrationStatusHint := (elapsedSeconds%600 == 0) - if rule == ForcePrintStatusAndHintRule { - shouldPrintMigrationStatusHint = true - } - if rule == ForcePrintStatusOnlyRule { - shouldPrintMigrationStatusHint = false - } - if shouldPrintMigrationStatusHint { + if this.shouldPrintMigrationStatusHint(rule, elapsedSeconds) { this.printMigrationStatusHint(writers...) 
} - var etaSeconds float64 = math.MaxFloat64 - var etaDuration = time.Duration(base.ETAUnknown) - if progressPct >= 100.0 { - etaDuration = 0 - } else if progressPct >= 0.1 { - elapsedRowCopySeconds := this.migrationContext.ElapsedRowCopyTime().Seconds() - totalExpectedSeconds := elapsedRowCopySeconds * float64(rowsEstimate) / float64(totalRowsCopied) - etaSeconds = totalExpectedSeconds - elapsedRowCopySeconds - if etaSeconds >= 0 { - etaDuration = time.Duration(etaSeconds) * time.Second - } else { - etaDuration = 0 - } - } + // Get state + ETA + state, eta, etaDuration := this.getMigrationStateAndETA(rowsEstimate) this.migrationContext.SetETADuration(etaDuration) - var eta string - switch etaDuration { - case 0: - eta = "due" - case time.Duration(base.ETAUnknown): - eta = "N/A" - default: - eta = base.PrettifyDurationOutput(etaDuration) - } - state := "migrating" - if atomic.LoadInt64(&this.migrationContext.CountingRowsFlag) > 0 && !this.migrationContext.ConcurrentCountTableRows { - state = "counting rows" - } else if atomic.LoadInt64(&this.migrationContext.IsPostponingCutOver) > 0 { - eta = "due" - state = "postponing cut-over" - } else if isThrottled, throttleReason, _ := this.migrationContext.IsThrottled(); isThrottled { - state = fmt.Sprintf("throttled, %s", throttleReason) - } - - var shouldPrintStatus bool - if rule == HeuristicPrintStatusRule { - if elapsedSeconds <= 60 { - shouldPrintStatus = true - } else if etaSeconds <= 60 { - shouldPrintStatus = true - } else if etaSeconds <= 180 { - shouldPrintStatus = (elapsedSeconds%5 == 0) - } else if elapsedSeconds <= 180 { - shouldPrintStatus = (elapsedSeconds%5 == 0) - } else if this.migrationContext.TimeSincePointOfInterest().Seconds() <= 60 { - shouldPrintStatus = (elapsedSeconds%5 == 0) - } else { - shouldPrintStatus = (elapsedSeconds%30 == 0) - } - } else { - // Not heuristic - shouldPrintStatus = true - } - if !shouldPrintStatus { + if !this.shouldPrintStatus(rule, elapsedSeconds, etaDuration) { return } @@ 
-1016,7 +1043,7 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { ) this.applier.WriteChangelog( fmt.Sprintf("copy iteration %d at %d", this.migrationContext.GetIteration(), time.Now().Unix()), - status, + state, ) w := io.MultiWriter(writers...) fmt.Fprintln(w, status) diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go new file mode 100644 index 0000000..242a749 --- /dev/null +++ b/go/logic/migrator_test.go @@ -0,0 +1,256 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package logic + +import ( + "errors" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/openark/golib/tests" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/binlog" + "github.com/github/gh-ost/go/sql" +) + +func TestMigratorOnChangelogEvent(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + + t.Run("heartbeat", func(t *testing.T) { + columnValues := sql.ToColumnValues([]interface{}{ + 123, + time.Now().Unix(), + "heartbeat", + "2022-08-16T00:45:10.52Z", + }) + tests.S(t).ExpectNil(migrator.onChangelogEvent(&binlog.BinlogDMLEvent{ + DatabaseName: "test", + DML: binlog.InsertDML, + NewColumnValues: columnValues, + })) + }) + + t.Run("state-AllEventsUpToLockProcessed", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1) + go func(wg *sync.WaitGroup) { + defer wg.Done() + es := <-migrator.applyEventsQueue + tests.S(t).ExpectNotNil(es) + tests.S(t).ExpectNotNil(es.writeFunc) + }(&wg) + + columnValues := sql.ToColumnValues([]interface{}{ + 123, + time.Now().Unix(), + "state", + AllEventsUpToLockProcessed, + }) + tests.S(t).ExpectNil(migrator.onChangelogEvent(&binlog.BinlogDMLEvent{ + DatabaseName: "test", + DML: binlog.InsertDML, + NewColumnValues: columnValues, + })) + wg.Wait() + }) + + t.Run("state-GhostTableMigrated", func(t *testing.T) { + go func() { + 
tests.S(t).ExpectTrue(<-migrator.ghostTableMigrated) + }() + + columnValues := sql.ToColumnValues([]interface{}{ + 123, + time.Now().Unix(), + "state", + GhostTableMigrated, + }) + tests.S(t).ExpectNil(migrator.onChangelogEvent(&binlog.BinlogDMLEvent{ + DatabaseName: "test", + DML: binlog.InsertDML, + NewColumnValues: columnValues, + })) + }) + + t.Run("state-Migrated", func(t *testing.T) { + columnValues := sql.ToColumnValues([]interface{}{ + 123, + time.Now().Unix(), + "state", + Migrated, + }) + tests.S(t).ExpectNil(migrator.onChangelogEvent(&binlog.BinlogDMLEvent{ + DatabaseName: "test", + DML: binlog.InsertDML, + NewColumnValues: columnValues, + })) + }) + + t.Run("state-ReadMigrationRangeValues", func(t *testing.T) { + columnValues := sql.ToColumnValues([]interface{}{ + 123, + time.Now().Unix(), + "state", + ReadMigrationRangeValues, + }) + tests.S(t).ExpectNil(migrator.onChangelogEvent(&binlog.BinlogDMLEvent{ + DatabaseName: "test", + DML: binlog.InsertDML, + NewColumnValues: columnValues, + })) + }) +} + +func TestMigratorValidateStatement(t *testing.T) { + t.Run("add-column", func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test ADD test_new VARCHAR(64) NOT NULL`)) + + tests.S(t).ExpectNil(migrator.validateAlterStatement()) + tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 0) + }) + + t.Run("drop-column", func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test DROP abc`)) + + tests.S(t).ExpectNil(migrator.validateAlterStatement()) + tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 1) + _, exists := migrator.migrationContext.DroppedColumnsMap["abc"] + tests.S(t).ExpectTrue(exists) + }) + + t.Run("rename-column", func(t 
*testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) + + err := migrator.validateAlterStatement() + tests.S(t).ExpectNotNil(err) + tests.S(t).ExpectTrue(strings.HasPrefix(err.Error(), "gh-ost believes the ALTER statement renames columns")) + tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 0) + }) + + t.Run("rename-column-approved", func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + migrator.migrationContext.ApproveRenamedColumns = true + tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) + + tests.S(t).ExpectNil(migrator.validateAlterStatement()) + tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 0) + }) + + t.Run("rename-table", func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + tests.S(t).ExpectNil(migrator.parser.ParseAlterStatement(`ALTER TABLE test RENAME TO test_new`)) + + err := migrator.validateAlterStatement() + tests.S(t).ExpectNotNil(err) + tests.S(t).ExpectTrue(errors.Is(err, ErrMigratorUnsupportedRenameAlter)) + tests.S(t).ExpectEquals(len(migrator.migrationContext.DroppedColumnsMap), 0) + }) +} + +func TestMigratorCreateFlagFiles(t *testing.T) { + tmpdir, err := os.MkdirTemp("", t.Name()) + if err != nil { + panic(err) + } + defer os.RemoveAll(tmpdir) + + migrationContext := base.NewMigrationContext() + migrationContext.PostponeCutOverFlagFile = filepath.Join(tmpdir, "cut-over.flag") + migrator := NewMigrator(migrationContext, "1.2.3") + tests.S(t).ExpectNil(migrator.createFlagFiles()) + tests.S(t).ExpectNil(migrator.createFlagFiles()) // twice to test already-exists + + _, err = 
os.Stat(migrationContext.PostponeCutOverFlagFile) + tests.S(t).ExpectNil(err) +} + +func TestMigratorGetProgressPercent(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + + { + tests.S(t).ExpectEquals(migrator.getProgressPercent(0), float64(100.0)) + } + { + migrationContext.TotalRowsCopied = 250 + tests.S(t).ExpectEquals(migrator.getProgressPercent(1000), float64(25.0)) + } +} + +func TestMigratorGetMigrationStateAndETA(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + now := time.Now() + migrationContext.RowCopyStartTime = now.Add(-time.Minute) + migrationContext.RowCopyEndTime = now + + { + migrationContext.TotalRowsCopied = 456 + state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) + tests.S(t).ExpectEquals(state, "migrating") + tests.S(t).ExpectEquals(eta, "4h29m44s") + tests.S(t).ExpectEquals(etaDuration.String(), "4h29m44s") + } + { + migrationContext.TotalRowsCopied = 456 + state, eta, etaDuration := migrator.getMigrationStateAndETA(456) + tests.S(t).ExpectEquals(state, "migrating") + tests.S(t).ExpectEquals(eta, "due") + tests.S(t).ExpectEquals(etaDuration.String(), "0s") + } + { + migrationContext.TotalRowsCopied = 123456 + state, eta, etaDuration := migrator.getMigrationStateAndETA(456) + tests.S(t).ExpectEquals(state, "migrating") + tests.S(t).ExpectEquals(eta, "due") + tests.S(t).ExpectEquals(etaDuration.String(), "0s") + } + { + atomic.StoreInt64(&migrationContext.CountingRowsFlag, 1) + state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) + tests.S(t).ExpectEquals(state, "counting rows") + tests.S(t).ExpectEquals(eta, "due") + tests.S(t).ExpectEquals(etaDuration.String(), "0s") + } + { + atomic.StoreInt64(&migrationContext.CountingRowsFlag, 0) + atomic.StoreInt64(&migrationContext.IsPostponingCutOver, 1) + state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) + 
tests.S(t).ExpectEquals(state, "postponing cut-over") + tests.S(t).ExpectEquals(eta, "due") + tests.S(t).ExpectEquals(etaDuration.String(), "0s") + } +} + +func TestMigratorShouldPrintStatus(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.2.3") + + tests.S(t).ExpectTrue(migrator.shouldPrintStatus(NoPrintStatusRule, 10, time.Second)) // test 'rule != HeuristicPrintStatusRule' return + tests.S(t).ExpectTrue(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 10, time.Second)) // test 'etaDuration.Seconds() <= 60' + tests.S(t).ExpectTrue(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 90, time.Second)) // test 'etaDuration.Seconds() <= 60' again + tests.S(t).ExpectTrue(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 90, time.Minute)) // test 'etaDuration.Seconds() <= 180' + tests.S(t).ExpectTrue(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 60, 90*time.Second)) // test 'elapsedSeconds <= 180' + tests.S(t).ExpectFalse(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 61, 90*time.Second)) // test 'elapsedSeconds <= 180' + tests.S(t).ExpectFalse(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 99, 210*time.Second)) // test 'elapsedSeconds <= 180' + tests.S(t).ExpectFalse(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 12345, 86400*time.Second)) // test 'else' + tests.S(t).ExpectTrue(migrator.shouldPrintStatus(HeuristicPrintStatusRule, 30030, 86400*time.Second)) // test 'else' again +}