diff --git a/go/base/context.go b/go/base/context.go index acfca5a..4dc09ad 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -14,9 +14,12 @@ import ( "sync/atomic" "time" + gosql "database/sql" "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" + "github.com/outbrain/golib/sqlutils" + "gopkg.in/gcfg.v1" gcfgscanner "gopkg.in/gcfg.v1/scanner" ) @@ -197,6 +200,9 @@ type MigrationContext struct { recentBinlogCoordinates mysql.BinlogCoordinates CanStopStreaming func() bool + + knownDBs map[string]*gosql.DB + knownDBsMutex *sync.Mutex } type ContextConfig struct { @@ -230,6 +236,8 @@ func NewMigrationContext() *MigrationContext { pointOfInterestTimeMutex: &sync.Mutex{}, ColumnRenameMap: make(map[string]string), PanicAbort: make(chan error), + knownDBsMutex: &sync.Mutex{}, + knownDBs: make(map[string]*gosql.DB), } } @@ -242,6 +250,23 @@ func getSafeTableName(baseName string, suffix string) string { return fmt.Sprintf("_%s_%s", baseName[0:len(baseName)-extraCharacters], suffix) } +// GetDB returns a DB instance based on uri. +// bool result indicates whether the DB was returned from cache; err +func (this *MigrationContext) GetDB(mysql_uri string) (*gosql.DB, bool, error) { + this.knownDBsMutex.Lock() + defer this.knownDBsMutex.Unlock() + + var exists bool + if _, exists = this.knownDBs[mysql_uri]; !exists { + if db, err := sqlutils.GetDB(mysql_uri); err == nil { + this.knownDBs[mysql_uri] = db + } else { + return db, exists, err + } + } + return this.knownDBs[mysql_uri], exists, nil +} + // GetGhostTableName generates the name of ghost table, based on original table name // or a given table name func (this *MigrationContext) GetGhostTableName() string { diff --git a/go/logic/applier.go b/go/logic/applier.go index 3791154..77e8381 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -47,11 +47,11 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier { func (this *Applier) InitDBConnections() (err error) { applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(applierUri); err != nil { + if this.db, _, err = this.migrationContext.GetDB(applierUri); err != nil { return err } singletonApplierUri := fmt.Sprintf("%s?timeout=0", applierUri) - if this.singletonDB, _, err = sqlutils.GetDB(singletonApplierUri); err != nil { + if this.singletonDB, _, err = this.migrationContext.GetDB(singletonApplierUri); err != nil { return err } this.singletonDB.SetMaxOpenConns(1) diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 729800c..a4fda77 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -26,9 +26,10 @@ const startSlavePostWaitMilliseconds = 500 * time.Millisecond // Inspector reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Inspector struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - migrationContext *base.MigrationContext + connectionConfig *mysql.ConnectionConfig + db *gosql.DB + informationSchemaDb *gosql.DB + migrationContext *base.MigrationContext } func NewInspector(migrationContext *base.MigrationContext) *Inspector { @@ -40,9 +41,15 @@ func NewInspector(migrationContext *base.MigrationContext) *Inspector { func (this *Inspector) InitDBConnections() (err error) { inspectorUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(inspectorUri); err != nil { + if this.db, _, err = this.migrationContext.GetDB(inspectorUri); err != nil { return err } + + informationSchemaUri := this.connectionConfig.GetDBUri("information_schema") + if this.informationSchemaDb, _, err = this.migrationContext.GetDB(informationSchemaUri); err != nil { + return err + } + if err := this.validateConnection(); err != nil { return err } @@ -755,6 +762,7 @@ func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.Connect func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err error) { replicationLag, err = mysql.GetReplicationLag( + this.informationSchemaDb, this.migrationContext.InspectorConnectionConfig, ) return replicationLag, err @@ -762,5 +770,6 @@ func (this *Inspector) getReplicationLag() (replicationLag time.Duration, err er func (this *Inspector) Teardown() { this.db.Close() + this.informationSchemaDb.Close() return } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 21383e9..5a0e62d 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -20,7 +20,6 @@ import ( "github.com/github/gh-ost/go/sql" "github.com/outbrain/golib/log" - "github.com/outbrain/golib/sqlutils" ) type ChangelogState string @@ -1248,6 +1247,4 @@ func (this *Migrator) teardown() { log.Infof("Tearing down streamer") this.eventsStreamer.Teardown() } - - sqlutils.ResetDBCache() } diff --git a/go/logic/streamer.go b/go/logic/streamer.go index 921891d..6ba9c10 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -104,7 +104,7 @@ func (this *EventsStreamer) notifyListeners(binlogEvent *binlog.BinlogDMLEvent) func (this *EventsStreamer) InitDBConnections() (err error) { EventsStreamerUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = sqlutils.GetDB(EventsStreamerUri); err != nil { + if this.db, _, err = this.migrationContext.GetDB(EventsStreamerUri); err != nil { return err } if err := this.validateConnection(); err != nil { diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 8a549bb..f3d4645 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -16,7 +16,6 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" "github.com/outbrain/golib/log" - "github.com/outbrain/golib/sqlutils" ) var ( @@ -140,7 +139,7 @@ func (this *Throttler) collectReplicationLag(firstThrottlingCollected chan<- boo // when running on replica, the heartbeat injection is also done on the replica. // This means we will always get a good heartbeat value. // When runnign on replica, we should instead check the `SHOW SLAVE STATUS` output. - if lag, err := mysql.GetReplicationLag(this.inspector.connectionConfig); err != nil { + if lag, err := mysql.GetReplicationLag(this.inspector.informationSchemaDb, this.inspector.connectionConfig); err != nil { return log.Errore(err) } else { atomic.StoreInt64(&this.migrationContext.CurrentLag, int64(lag)) @@ -182,7 +181,7 @@ func (this *Throttler) collectControlReplicasLag() { dbUri := connectionConfig.GetDBUri("information_schema") var heartbeatValue string - if db, _, err := sqlutils.GetDB(dbUri); err != nil { + if db, _, err := this.migrationContext.GetDB(dbUri); err != nil { return lag, err } else if err = db.QueryRow(replicationLagQuery).Scan(&heartbeatValue); err != nil { return lag, err diff --git a/go/mysql/utils.go b/go/mysql/utils.go index 514ff84..96e7fbc 100644 --- a/go/mysql/utils.go +++ b/go/mysql/utils.go @@ -35,15 +35,8 @@ func (this *ReplicationLagResult) HasLag() bool { // GetReplicationLag returns replication lag for a given connection config; either by explicit query // or via SHOW SLAVE STATUS -func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { - dbUri := connectionConfig.GetDBUri("information_schema") - var db *gosql.DB - if db, _, err = sqlutils.GetDB(dbUri); err != nil { - return replicationLag, err - } - defer db.Close() - - err = sqlutils.QueryRowsMap(db, `show slave status`, func(m sqlutils.RowMap) error { +func GetReplicationLag(informationSchemaDb *gosql.DB, connectionConfig *ConnectionConfig) (replicationLag time.Duration, err error) { + err = sqlutils.QueryRowsMap(informationSchemaDb, `show slave status`, func(m sqlutils.RowMap) error { slaveIORunning := m.GetString("Slave_IO_Running") slaveSQLRunning := m.GetString("Slave_SQL_Running") secondsBehindMaster := m.GetNullInt64("Seconds_Behind_Master") @@ -59,7 +52,10 @@ func GetReplicationLag(connectionConfig *ConnectionConfig) (replicationLag time. func GetMasterKeyFromSlaveStatus(connectionConfig *ConnectionConfig) (masterKey *InstanceKey, err error) { currentUri := connectionConfig.GetDBUri("information_schema") - db, _, err := sqlutils.GetDB(currentUri) + // This function is only called once, okay to not have a cached connection pool + db, err := sqlutils.GetDB(currentUri) + defer db.Close() + if err != nil { return nil, err } diff --git a/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go b/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go index 17b948f..77cc441 100644 --- a/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go +++ b/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go @@ -25,7 +25,6 @@ import ( "github.com/outbrain/golib/log" "strconv" "strings" - "sync" ) // RowMap represents one row in a result set. Its objective is to allow @@ -121,34 +120,13 @@ func (this *RowMap) GetBool(key string) bool { return this.GetInt(key) != 0 } -// knownDBs is a DB cache by uri -var knownDBs map[string]*sql.DB = make(map[string]*sql.DB) -var knownDBsMutex = &sync.Mutex{} - // GetDB returns a DB instance based on uri. -// bool result indicates whether the DB was returned from cache; err -func GetDB(mysql_uri string) (*sql.DB, bool, error) { - knownDBsMutex.Lock() - defer knownDBsMutex.Unlock() - - var exists bool - if _, exists = knownDBs[mysql_uri]; !exists { - if db, err := sql.Open("mysql", mysql_uri); err == nil { - knownDBs[mysql_uri] = db - } else { - return db, exists, err - } +func GetDB(mysql_uri string) (*sql.DB, error) { + if db, err := sql.Open("mysql", mysql_uri); err == nil { + return db, nil + } else { + return db, err } - return knownDBs[mysql_uri], exists, nil -} - -// Resets the knownDBs cache, used when the DB connections have been closed, -// and new connections are needed to access the DB -func ResetDBCache() { - knownDBsMutex.Lock() - defer knownDBsMutex.Unlock() - - knownDBs = make(map[string]*sql.DB) } // RowToArray is a convenience function, typically not called directly, which maps a