diff --git a/go/base/context.go b/go/base/context.go index 71bb3d0..23fe6f6 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -14,6 +14,8 @@ import ( "sync/atomic" "time" + "github.com/satori/go.uuid" + "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" @@ -71,6 +73,8 @@ func NewThrottleCheckResult(throttle bool, reason string, reasonHint ThrottleRea // MigrationContext has the general, global state of migration. It is used by // all components throughout the migration process. type MigrationContext struct { + Uuid string + DatabaseName string OriginalTableName string AlterStatement string @@ -195,8 +199,6 @@ type MigrationContext struct { ForceTmpTableName string recentBinlogCoordinates mysql.BinlogCoordinates - - CanStopStreaming func() bool } type ContextConfig struct { @@ -212,14 +214,9 @@ type ContextConfig struct { } } -var context *MigrationContext - -func init() { - context = newMigrationContext() -} - -func newMigrationContext() *MigrationContext { +func NewMigrationContext() *MigrationContext { return &MigrationContext{ + Uuid: uuid.NewV4().String(), defaultNumRetries: 60, ChunkSize: 1000, InspectorConnectionConfig: mysql.NewConnectionConfig(), @@ -239,11 +236,6 @@ func newMigrationContext() *MigrationContext { } } -// GetMigrationContext -func GetMigrationContext() *MigrationContext { - return context -} - func getSafeTableName(baseName string, suffix string) string { name := fmt.Sprintf("_%s_%s", baseName, suffix) if len(name) <= mysql.MaxTableNameLength { diff --git a/go/base/context_test.go b/go/base/context_test.go index b3a98eb..8a9c6a5 100644 --- a/go/base/context_test.go +++ b/go/base/context_test.go @@ -19,27 +19,27 @@ func init() { func TestGetTableNames(t *testing.T) { { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "some_table" test.S(t).ExpectEquals(context.GetOldTableName(), "_some_table_del") test.S(t).ExpectEquals(context.GetGhostTableName(), "_some_table_gho") test.S(t).ExpectEquals(context.GetChangelogTableName(), "_some_table_ghc") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "a123456789012345678901234567890123456789012345678901234567890" test.S(t).ExpectEquals(context.GetOldTableName(), "_a1234567890123456789012345678901234567890123456789012345678_del") test.S(t).ExpectEquals(context.GetGhostTableName(), "_a1234567890123456789012345678901234567890123456789012345678_gho") test.S(t).ExpectEquals(context.GetChangelogTableName(), "_a1234567890123456789012345678901234567890123456789012345678_ghc") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "a123456789012345678901234567890123456789012345678901234567890123" oldTableName := context.GetOldTableName() test.S(t).ExpectEquals(oldTableName, "_a1234567890123456789012345678901234567890123456789012345678_del") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "a123456789012345678901234567890123456789012345678901234567890123" context.TimestampOldTable = true longForm := "Jan 2, 2006 at 3:04pm (MST)" @@ -48,7 +48,7 @@ func TestGetTableNames(t *testing.T) { test.S(t).ExpectEquals(oldTableName, "_a1234567890123456789012345678901234567890123_20130203195400_del") } { - context = newMigrationContext() + context := NewMigrationContext() context.OriginalTableName = "foo_bar_baz" context.ForceTmpTableName = "tmp" test.S(t).ExpectEquals(context.GetOldTableName(), "_tmp_del") diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index e218823..6016f81 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -26,28 +26,26 @@ type GoMySQLReader struct { currentCoordinates mysql.BinlogCoordinates currentCoordinatesMutex *sync.Mutex LastAppliedRowsEventHint mysql.BinlogCoordinates - MigrationContext *base.MigrationContext } -func NewGoMySQLReader(connectionConfig *mysql.ConnectionConfig) (binlogReader *GoMySQLReader, err error) { +func NewGoMySQLReader(migrationContext *base.MigrationContext) (binlogReader *GoMySQLReader, err error) { binlogReader = &GoMySQLReader{ - connectionConfig: connectionConfig, + connectionConfig: migrationContext.InspectorConnectionConfig, currentCoordinates: mysql.BinlogCoordinates{}, currentCoordinatesMutex: &sync.Mutex{}, binlogSyncer: nil, binlogStreamer: nil, - MigrationContext: base.GetMigrationContext(), } - serverId := uint32(binlogReader.MigrationContext.ReplicaServerId) + serverId := uint32(migrationContext.ReplicaServerId) binlogSyncerConfig := &replication.BinlogSyncerConfig{ ServerID: serverId, Flavor: "mysql", - Host: connectionConfig.Key.Hostname, - Port: uint16(connectionConfig.Key.Port), - User: connectionConfig.User, - Password: connectionConfig.Password, + Host: binlogReader.connectionConfig.Key.Hostname, + Port: uint16(binlogReader.connectionConfig.Key.Port), + User: binlogReader.connectionConfig.User, + Password: binlogReader.connectionConfig.Password, } binlogReader.binlogSyncer = replication.NewBinlogSyncer(binlogSyncerConfig) @@ -160,10 +158,6 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha } func (this *GoMySQLReader) Close() error { - // Historically there was a: - // this.binlogSyncer.Close() - // here. A new go-mysql version closes the binlog syncer connection independently. - // I will go against the sacred rules of comments and just leave this here. - // This is the year 2017. Let's see what year these comments get deleted. + this.binlogSyncer.Close() return nil } diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 04aa316..74d0f05 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -43,7 +43,7 @@ func acceptSignals(migrationContext *base.MigrationContext) { // main is the application's entry point. It will either spawn a CLI or HTTP interfaces. func main() { - migrationContext := base.GetMigrationContext() + migrationContext := base.NewMigrationContext() flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unable to determine the master") @@ -242,7 +242,7 @@ func main() { log.Infof("starting gh-ost %+v", AppVersion) acceptSignals(migrationContext) - migrator := logic.NewMigrator() + migrator := logic.NewMigrator(migrationContext) err := migrator.Migrate() if err != nil { migrator.ExecOnFailureHook() diff --git a/go/logic/applier.go b/go/logic/applier.go index 971f068..227b59e 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -52,26 +52,29 @@ func newDmlBuildResultError(err error) *dmlBuildResult { // Applier is the one to actually write row data and apply binlog events onto the ghost table. // It is where the ghost & changelog tables get created. It is where the cut-over phase happens. type Applier struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - singletonDB *gosql.DB - migrationContext *base.MigrationContext + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + singletonDB *gosql.DB + migrationContext *base.MigrationContext + finishedMigrating int64 } -func NewApplier() *Applier { +func NewApplier(migrationContext *base.MigrationContext) *Applier { return &Applier{ - connectionConfig: base.GetMigrationContext().ApplierConnectionConfig, - migrationContext: base.GetMigrationContext(), + connectionConfig: migrationContext.ApplierConnectionConfig, + migrationContext: migrationContext, + finishedMigrating: 0, } } func (this *Applier) InitDBConnections() (err error) { + applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(applierUri); err != nil { + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { return err } singletonApplierUri := fmt.Sprintf("%s?timeout=0", applierUri) - if this.singletonDB, _, err = sqlutils.GetDB(singletonApplierUri); err != nil { + if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil { return err } this.singletonDB.SetMaxOpenConns(1) @@ -320,6 +323,9 @@ func (this *Applier) InitiateHeartbeat() { heartbeatTick := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) for range heartbeatTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } // Generally speaking, we would issue a goroutine, but I'd actually rather // have this block the loop rather than spam the master in the event something // goes wrong @@ -1074,3 +1080,10 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) log.Debugf("ApplyDMLEventQueries() applied %d events in one transaction", len(dmlEvents)) return nil } + +func (this *Applier) Teardown() { + log.Debugf("Tearing down...") + this.db.Close() + this.singletonDB.Close() + atomic.StoreInt64(&this.finishedMigrating, 1) +} diff --git a/go/logic/hooks.go b/go/logic/hooks.go index 58825ee..1fdfd5c 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -37,9 +37,9 @@ type HooksExecutor struct { migrationContext *base.MigrationContext } -func NewHooksExecutor() *HooksExecutor { +func NewHooksExecutor(migrationContext *base.MigrationContext) *HooksExecutor { return &HooksExecutor{ - migrationContext: base.GetMigrationContext(), + migrationContext: migrationContext, } } diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 3222481..31c81dc 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -26,23 +26,30 @@ const startSlavePostWaitMilliseconds = 500 * time.Millisecond // Inspector reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Inspector struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - migrationContext *base.MigrationContext + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + informationSchemaDb *gosql.DB + migrationContext *base.MigrationContext } -func NewInspector() *Inspector { +func NewInspector(migrationContext *base.MigrationContext) *Inspector { return &Inspector{ - connectionConfig: base.GetMigrationContext().InspectorConnectionConfig, - migrationContext: base.GetMigrationContext(), + connectionConfig: migrationContext.InspectorConnectionConfig, + migrationContext: migrationContext, } } func (this *Inspector) InitDBConnections() (err error) { inspectorUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(inspectorUri); err != nil { + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, inspectorUri); err != nil { return err } + + informationSchemaUri := this.connectionConfig.GetDBUri("information_schema") + if this.informationSchemaDb, _, err = mysql.GetDB(this.migrationContext.Uuid, informationSchemaUri); err != nil { + return err + } + if err := this.validateConnection(); err != nil { return err } @@ -749,7 +756,14 @@ func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.Connect func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err error) { replicationLag, err = mysql.GetReplicationLag( + this.informationSchemaDb, this.migrationContext.InspectorConnectionConfig, ) return replicationLag, err } + +func (this *Inspector) Teardown() { + this.db.Close() + this.informationSchemaDb.Close() + return +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index e9f4cb8..3937a45 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -83,11 +83,13 @@ type Migrator struct { applyEventsQueue chan *applyEventStruct handledChangelogStates map[string]bool + + finishedMigrating int64 } -func NewMigrator() *Migrator { +func NewMigrator(context *base.MigrationContext) *Migrator { migrator := &Migrator{ - migrationContext: base.GetMigrationContext(), + migrationContext: context, parser: sql.NewParser(), ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 3), @@ -97,13 +99,14 @@ func NewMigrator() *Migrator { copyRowsQueue: make(chan tableWriteFunc), applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), handledChangelogStates: make(map[string]bool), + finishedMigrating: 0, } return migrator } // initiateHooksExecutor func (this *Migrator) initiateHooksExecutor() (err error) { - this.hooksExecutor = NewHooksExecutor() + this.hooksExecutor = NewHooksExecutor(this.migrationContext) if err := this.hooksExecutor.initHooks(); err != nil { return err } @@ -299,6 +302,11 @@ func (this *Migrator) Migrate() (err error) { if err := this.validateStatement(); err != nil { return err } + + // After this point, we'll need to teardown anything that's been started + // so we don't leave things hanging around + defer this.teardown() + if err := this.initiateInspector(); err != nil { return err } @@ -653,7 +661,7 @@ func (this *Migrator) initiateServer() (err error) { var f printStatusFunc = func(rule PrintStatusRule, writer io.Writer) { this.printStatus(rule, writer) } - this.server = NewServer(this.hooksExecutor, f) + this.server = NewServer(this.migrationContext, this.hooksExecutor, f) if err := this.server.BindSocketFile(); err != nil { return err } @@ -673,7 +681,7 @@ func (this *Migrator) initiateServer() (err error) { // - heartbeat // When `--allow-on-master` is supplied, the inspector is actually the master. func (this *Migrator) initiateInspector() (err error) { - this.inspector = NewInspector() + this.inspector = NewInspector(this.migrationContext) if err := this.inspector.InitDBConnections(); err != nil { return err } @@ -733,6 +741,9 @@ func (this *Migrator) initiateStatus() error { this.printStatus(ForcePrintStatusAndHintRule) statusTick := time.Tick(1 * time.Second) for range statusTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } go this.printStatus(HeuristicPrintStatusRule) } @@ -932,7 +943,7 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { // initiateStreaming begins streaming of binary log events and registers listeners for such events func (this *Migrator) initiateStreaming() error { - this.eventsStreamer = NewEventsStreamer() + this.eventsStreamer = NewEventsStreamer(this.migrationContext) if err := this.eventsStreamer.InitDBConnections(); err != nil { return err } @@ -957,6 +968,9 @@ func (this *Migrator) initiateStreaming() error { go func() { ticker := time.Tick(1 * time.Second) for range ticker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } this.migrationContext.SetRecentBinlogCoordinates(*this.eventsStreamer.GetCurrentBinlogCoordinates()) } }() @@ -980,7 +994,7 @@ func (this *Migrator) addDMLEventsListener() error { // initiateThrottler kicks in the throttling collection and the throttling checks. func (this *Migrator) initiateThrottler() error { - this.throttler = NewThrottler(this.applier, this.inspector) + this.throttler = NewThrottler(this.migrationContext, this.applier, this.inspector) go this.throttler.initiateThrottlerCollection(this.firstThrottlingCollected) log.Infof("Waiting for first throttle metrics to be collected") @@ -994,7 +1008,7 @@ func (this *Migrator) initiateThrottler() error { } func (this *Migrator) initiateApplier() error { - this.applier = NewApplier() + this.applier = NewApplier(this.migrationContext) if err := this.applier.InitDBConnections(); err != nil { return err } @@ -1147,6 +1161,10 @@ func (this *Migrator) executeWriteFuncs() error { return nil } for { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } + this.throttler.throttle(nil) // We give higher priority to event processing, then secondary priority to @@ -1226,3 +1244,27 @@ func (this *Migrator) finalCleanup() error { return nil } + +func (this *Migrator) teardown() { + atomic.StoreInt64(&this.finishedMigrating, 1) + + if this.inspector != nil { + log.Infof("Tearing down inspector") + this.inspector.Teardown() + } + + if this.applier != nil { + log.Infof("Tearing down applier") + this.applier.Teardown() + } + + if this.eventsStreamer != nil { + log.Infof("Tearing down streamer") + this.eventsStreamer.Teardown() + } + + if this.throttler != nil { + log.Infof("Tearing down throttler") + this.throttler.Teardown() + } +} diff --git a/go/logic/server.go b/go/logic/server.go index d16e582..b9903e2 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -30,9 +30,9 @@ type Server struct { printStatus printStatusFunc } -func NewServer(hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { +func NewServer(migrationContext *base.MigrationContext, hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { return &Server{ - migrationContext: base.GetMigrationContext(), + migrationContext: migrationContext, hooksExecutor: hooksExecutor, printStatus: printStatus, } diff --git a/go/logic/streamer.go b/go/logic/streamer.go index 14ac6ab..37e7195 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -45,10 +45,10 @@ type EventsStreamer struct { binlogReader *binlog.GoMySQLReader } -func NewEventsStreamer() *EventsStreamer { +func NewEventsStreamer(migrationContext *base.MigrationContext) *EventsStreamer { return &EventsStreamer{ - connectionConfig: base.GetMigrationContext().InspectorConnectionConfig, - migrationContext: base.GetMigrationContext(), + connectionConfig: migrationContext.InspectorConnectionConfig, + migrationContext: migrationContext, listeners: [](*BinlogEventListener){}, listenersMutex: &sync.Mutex{}, eventsChannel: make(chan *binlog.BinlogEntry, EventsChannelBufferSize), @@ -104,7 +104,7 @@ func (this *EventsStreamer) notifyListeners(binlogEvent *binlog.BinlogDMLEvent) func (this *EventsStreamer) InitDBConnections() (err error) { EventsStreamerUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(EventsStreamerUri); err != nil { + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, EventsStreamerUri); err != nil { return err } if _, err := base.ValidateConnection(this.db, this.connectionConfig); err != nil { @@ -122,7 +122,7 @@ func (this *EventsStreamer) InitDBConnections() (err error) { // initBinlogReader creates and connects the reader: we hook up to a MySQL server as a replica func (this *EventsStreamer) initBinlogReader(binlogCoordinates *mysql.BinlogCoordinates) error { - goMySQLReader, err := binlog.NewGoMySQLReader(this.migrationContext.InspectorConnectionConfig) + goMySQLReader, err := binlog.NewGoMySQLReader(this.migrationContext) if err != nil { return err } @@ -178,7 +178,14 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { var successiveFailures int64 var lastAppliedRowsEventHint mysql.BinlogCoordinates for { + if canStopStreaming() { + return nil + } if err := this.binlogReader.StreamEvents(canStopStreaming, this.eventsChannel); err != nil { + if canStopStreaming() { + return nil + } + log.Infof("StreamEvents encountered unexpected error: %+v", err) this.migrationContext.MarkPointOfInterest() time.Sleep(ReconnectStreamerSleepSeconds * time.Second) @@ -209,3 +216,8 @@ func (this *EventsStreamer) Close() (err error) { log.Infof("Closed streamer connection. err=%+v", err) return err } + +func (this *EventsStreamer) Teardown() { + this.db.Close() + return +} diff --git a/go/logic/throttler.go b/go/logic/throttler.go index f1cb898..624956a 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -16,7 +16,6 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" "github.com/outbrain/golib/log" - "github.com/outbrain/golib/sqlutils" ) var ( @@ -42,16 +41,18 @@ const frenoMagicHint = "freno" // Throttler collects metrics related to throttling and makes informed decision // whether throttling should take place. type Throttler struct { - migrationContext *base.MigrationContext - applier *Applier - inspector *Inspector + migrationContext *base.MigrationContext + applier *Applier + inspector *Inspector + finishedMigrating int64 } -func NewThrottler(applier *Applier, inspector *Inspector) *Throttler { +func NewThrottler(migrationContext *base.MigrationContext, applier *Applier, inspector *Inspector) *Throttler { return &Throttler{ - migrationContext: base.GetMigrationContext(), - applier: applier, - inspector: inspector, + migrationContext: migrationContext, + applier: applier, + inspector: inspector, + finishedMigrating: 0, } } @@ -139,8 +140,8 @@ func (this *Throttler) collectReplicationLag(firstThrottlingCollected chan<- boo if this.migrationContext.TestOnReplica || this.migrationContext.MigrateOnReplica { // when running on replica, the heartbeat injection is also done on the replica. // This means we will always get a good heartbeat value. - // When running on replica, we should instead check the `SHOW SLAVE STATUS` output. - if lag, err := mysql.GetReplicationLag(this.inspector.connectionConfig); err != nil { + // When runnign on replica, we should instead check the `SHOW SLAVE STATUS` output. + if lag, err := mysql.GetReplicationLag(this.inspector.informationSchemaDb, this.inspector.connectionConfig); err != nil { return log.Errore(err) } else { atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) @@ -160,6 +161,9 @@ func (this *Throttler) collectReplicationLag(firstThrottlingCollected chan<- boo ticker := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) for range ticker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } go collectFunc() } } @@ -182,11 +186,12 @@ func (this *Throttler) collectControlReplicasLag() { dbUri := connectionConfig.GetDBUri("information_schema") var heartbeatValue string - if db, _, err := sqlutils.GetDB(dbUri); err != nil { + if db, _, err := mysql.GetDB(this.migrationContext.Uuid, dbUri); err != nil { return lag, err } else if err = db.QueryRow(replicationLagQuery).Scan(&heartbeatValue); err != nil { return lag, err } + lag, err = parseChangelogHeartbeat(heartbeatValue) return lag, err } @@ -233,6 +238,9 @@ func (this *Throttler) collectControlReplicasLag() { shouldReadLagAggressively := false for range aggressiveTicker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } if counter%relaxedFactor == 0 { // we only check if we wish to be aggressive once per second. The parameters for being aggressive // do not typically change at all throughout the migration, but nonetheless we check them. @@ -285,6 +293,10 @@ func (this *Throttler) collectThrottleHTTPStatus(firstThrottlingCollected chan<- ticker := time.Tick(100 * time.Millisecond) for range ticker { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } + if sleep, _ := collectFunc(); sleep { time.Sleep(1 * time.Second) } @@ -393,6 +405,10 @@ func (this *Throttler) initiateThrottlerCollection(firstThrottlingCollected chan throttlerMetricsTick := time.Tick(1 * time.Second) for range throttlerMetricsTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return + } + this.collectGeneralThrottleMetrics() } }() @@ -419,6 +435,9 @@ func (this *Throttler) initiateThrottlerChecks() error { } throttlerFunction() for range throttlerTick { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } throttlerFunction() } @@ -440,3 +459,8 @@ func (this *Throttler) throttle(onThrottled func()) { time.Sleep(250 * time.Millisecond) } } + +func (this *Throttler) Teardown() { + log.Debugf("Tearing down...") + atomic.StoreInt64(&this.finishedMigrating, 1) +} diff --git a/go/mysql/utils.go b/go/mysql/utils.go index b670921..532cbb4 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -8,6 +8,7 @@ package mysql import ( gosql "database/sql" "fmt" + "sync" "time" "github.com/github/gh-ost/go/sql" @@ -33,16 +34,33 @@ func (this *ReplicationLagResult) HasLag() bool { return this.Lag > 0 } +// knownDBs is a DB cache by uri +var knownDBs map[string]*gosql.DB = make(map[string]*gosql.DB) +var knownDBsMutex = &sync.Mutex{} + +func GetDB(migrationUuid string, mysql_uri string) (*gosql.DB, bool, error) { + cacheKey := migrationUuid + ":" + mysql_uri + + knownDBsMutex.Lock() + defer func() { + knownDBsMutex.Unlock() + }() + + var exists bool + if _, exists = knownDBs[cacheKey]; !exists { + if db, err := gosql.Open("mysql", mysql_uri); err == nil { + knownDBs[cacheKey] = db + } else { + return db, exists, err + } + } + return knownDBs[cacheKey], exists, nil +} + // GetReplicationLag returns replication lag for a given connection config; either by explicit query // or via SHOW SLAVE STATUS -func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { - dbUri := connectionConfig.GetDBUri("information_schema") - var db *gosql.DB - if db, _, err = sqlutils.GetDB(dbUri); err != nil { - return replicationLag, err - } - - err = sqlutils.QueryRowsMap(db, `show slave status`, func(m sqlutils.RowMap) error { +func GetReplicationLag(informationSchemaDb *gosql.DB, connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { + err = sqlutils.QueryRowsMap(informationSchemaDb, `show slave status`, func(m sqlutils.RowMap) error { slaveIORunning := m.GetString("Slave_IO_Running") slaveSQLRunning := m.GetString("Slave_SQL_Running") secondsBehindMaster := m.GetNullInt64("Seconds_Behind_Master") @@ -52,12 +70,19 @@ func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time. replicationLag = time.Duration(secondsBehindMaster.Int64) * time.Second return nil }) + return replicationLag, err } func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey *InstanceKey, err error) { currentUri := connectionConfig.GetDBUri("information_schema") - db, _, err := sqlutils.GetDB(currentUri) + // This function is only called once, okay to not have a cached connection pool + db, err := gosql.Open("mysql", currentUri) + if err != nil { + return nil, err + } + defer db.Close() + if err != nil { return nil, err }