diff --git a/go/base/context.go b/go/base/context.go
index 1c8b17a..38c82a3 100644
--- a/go/base/context.go
+++ b/go/base/context.go
@@ -82,6 +82,8 @@ type MigrationContext struct {
 	AlterStatement        string
 	AlterStatementOptions string // anything following the 'ALTER TABLE [schema.]table' from AlterStatement
 
+	countMutex               sync.Mutex
+	countTableRowsCancelFunc func()
 	CountTableRows           bool
 	ConcurrentCountTableRows bool
 	AllowedRunningOnMaster   bool
@@ -428,6 +430,36 @@ func (this *MigrationContext) IsTransactionalTable() bool {
 	return false
 }
 
+// SetCountTableRowsCancelFunc sets the cancel function for the CountTableRows query context
+func (this *MigrationContext) SetCountTableRowsCancelFunc(f func()) {
+	this.countMutex.Lock()
+	defer this.countMutex.Unlock()
+
+	this.countTableRowsCancelFunc = f
+}
+
+// IsCountingTableRows returns true if the migration has a table count query running
+func (this *MigrationContext) IsCountingTableRows() bool {
+	this.countMutex.Lock()
+	defer this.countMutex.Unlock()
+
+	return this.countTableRowsCancelFunc != nil
+}
+
+// CancelTableRowsCount cancels the CountTableRows query context. It is safe to
+// call this function even when IsCountingTableRows is false.
+func (this *MigrationContext) CancelTableRowsCount() {
+	this.countMutex.Lock()
+	defer this.countMutex.Unlock()
+
+	if this.countTableRowsCancelFunc == nil {
+		return
+	}
+
+	this.countTableRowsCancelFunc()
+	this.countTableRowsCancelFunc = nil
+}
+
 // ElapsedTime returns time since very beginning of the process
 func (this *MigrationContext) ElapsedTime() time.Duration {
 	return time.Since(this.StartTime)
diff --git a/go/logic/inspect.go b/go/logic/inspect.go
index e66d673..becda18 100644
--- a/go/logic/inspect.go
+++ b/go/logic/inspect.go
@@ -6,6 +6,7 @@
 package logic
 
 import (
+	"context"
 	gosql "database/sql"
 	"fmt"
 	"reflect"
@@ -532,18 +533,61 @@ func (this *Inspector) estimateTableRowsViaExplain() error {
 	return nil
 }
 
+// Kill asks the server to abort the statement running on the connection
+// identified by connectionID (a value obtained from CONNECTION_ID()).
+// database/sql treats extra Exec arguments as bind parameters and does not
+// expand fmt verbs, so the ID is parsed as an integer and formatted into the
+// statement text here.
+// - @amason: this should go somewhere _other_ than `logic`, but I couldn't decide
+//   between `base`, `sql`, or `mysql`.
+func Kill(db *gosql.DB, connectionID string) error {
+	var id uint64
+	if _, err := fmt.Sscanf(connectionID, "%d", &id); err != nil {
+		return fmt.Errorf("invalid connection id %q: %v", connectionID, err)
+	}
+	_, err := db.Exec(fmt.Sprintf(`KILL QUERY %d`, id))
+	return err
+}
+
 // CountTableRows counts exact number of rows on the original table
+// over a dedicated connection. When ctx is cancelled while the count is running,
+// the server-side query is killed and the context's error is returned.
-func (this *Inspector) CountTableRows() error {
+func (this *Inspector) CountTableRows(ctx context.Context) error {
 	atomic.StoreInt64(&this.migrationContext.CountingRowsFlag, 1)
 	defer atomic.StoreInt64(&this.migrationContext.CountingRowsFlag, 0)
 
 	this.migrationContext.Log.Infof("As instructed, I'm issuing a SELECT COUNT(*) on the table. This may take a while")
 
-	query := fmt.Sprintf(`select /* gh-ost */ count(*) as count_rows from %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
-	var rowsEstimate int64
-	if err := this.db.QueryRow(query).Scan(&rowsEstimate); err != nil {
+	conn, err := this.db.Conn(ctx)
+	if err != nil {
 		return err
 	}
+	defer conn.Close()
+
+	// remember which server connection runs the count, so it can be killed on cancellation
+	var connectionID string
+	if err := conn.QueryRowContext(ctx, `SELECT /* gh-ost */ CONNECTION_ID()`).Scan(&connectionID); err != nil {
+		return err
+	}
+
+	query := fmt.Sprintf(`select /* gh-ost */ count(*) as count_rows from %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
+	var rowsEstimate int64
+	if err := conn.QueryRowContext(ctx, query).Scan(&rowsEstimate); err != nil {
+		if ctx.Err() == nil {
+			return err
+		}
+		this.migrationContext.Log.Infof("exact row count cancelled (%s), likely because I'm about to cut over. I'm going to kill that query.", ctx.Err())
+		if killErr := Kill(this.db, connectionID); killErr != nil {
+			return killErr
+		}
+		// report the cancellation, so callers don't treat the count as complete
+		return ctx.Err()
+	}
+
+	// row count query finished. release the query context and clear the cancel func,
+	// so the main migration thread doesn't bother calling it after row copy is done.
+	this.migrationContext.CancelTableRowsCount()
+
 	atomic.StoreInt64(&this.migrationContext.RowsEstimate, rowsEstimate)
 	this.migrationContext.UsedRowsEstimateMethod = base.CountRowsEstimate
 
diff --git a/go/logic/migrator.go b/go/logic/migrator.go
index e1fe7d1..bc2a03f 100644
--- a/go/logic/migrator.go
+++ b/go/logic/migrator.go
@@ -6,6 +6,7 @@
 package logic
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"math"
@@ -295,8 +296,8 @@ func (this *Migrator) countTableRows() (err error) {
 		return nil
 	}
 
-	countRowsFunc := func() error {
-		if err := this.inspector.CountTableRows(); err != nil {
+	countRowsFunc := func(ctx context.Context) error {
+		if err := this.inspector.CountTableRows(ctx); err != nil {
 			return err
 		}
 		if err := this.hooksExecutor.onRowCountComplete(); err != nil {
@@ -306,12 +307,17 @@ func (this *Migrator) countTableRows() (err error) {
 	}
 
 	if this.migrationContext.ConcurrentCountTableRows {
+		// store a cancel func so we can stop this query before a cut over
+		rowCountContext, rowCountCancel := context.WithCancel(context.Background())
+		this.migrationContext.SetCountTableRowsCancelFunc(rowCountCancel)
+
 		this.migrationContext.Log.Infof("As instructed, counting rows in the background; meanwhile I will use an estimated count, and will update it later on")
-		go countRowsFunc()
+		go countRowsFunc(rowCountContext)
+
 		// and we ignore errors, because this turns to be a background job
 		return nil
 	}
-	return countRowsFunc()
+	return countRowsFunc(context.Background())
 }
 
 func (this *Migrator) createFlagFiles() (err error) {
@@ -415,6 +421,10 @@ func (this *Migrator) Migrate() (err error) {
 	}
 	this.printStatus(ForcePrintStatusRule)
 
+	if this.migrationContext.IsCountingTableRows() {
+		this.migrationContext.Log.Info("stopping query for exact row count, because that can accidentally lock out the cut over")
+		this.migrationContext.CancelTableRowsCount()
+	}
 	if err := this.hooksExecutor.onBeforeCutOver(); err != nil {
 		return err
 	}