diff --git a/go/logic/applier.go b/go/logic/applier.go
index bfa9807..3791154 100644
--- a/go/logic/applier.go
+++ b/go/logic/applier.go
@@ -34,14 +34,14 @@ type Applier struct {
 	db                *gosql.DB
 	singletonDB       *gosql.DB
 	migrationContext  *base.MigrationContext
-	finishedMigrating bool
+	finishedMigrating int64
 }
 
 func NewApplier(migrationContext *base.MigrationContext) *Applier {
 	return &Applier{
 		connectionConfig:  migrationContext.ApplierConnectionConfig,
 		migrationContext:  migrationContext,
-		finishedMigrating: false,
+		finishedMigrating: 0,
 	}
 }
 
@@ -312,7 +312,7 @@ func (this *Applier) InitiateHeartbeat() {
 
 	heartbeatTick := time.Tick(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond)
 	for range heartbeatTick {
-		if this.finishedMigrating {
+		if atomic.LoadInt64(&this.finishedMigrating) > 0 {
 			return
 		}
 		// Generally speaking, we would issue a goroutine, but I'd actually rather
@@ -1049,5 +1049,5 @@ func (this *Applier) Teardown() {
 	log.Debugf("Tearing down...")
 	this.db.Close()
 	this.singletonDB.Close()
-	this.finishedMigrating = true
+	atomic.StoreInt64(&this.finishedMigrating, 1)
 }
diff --git a/go/logic/migrator.go b/go/logic/migrator.go
index a3d17c5..21383e9 100644
--- a/go/logic/migrator.go
+++ b/go/logic/migrator.go
@@ -85,7 +85,7 @@ type Migrator struct {
 
 	handledChangelogStates map[string]bool
 
-	finishedMigrating bool
+	finishedMigrating int64
 }
 
 func NewMigrator(context *base.MigrationContext) *Migrator {
@@ -100,7 +100,7 @@ func NewMigrator(context *base.MigrationContext) *Migrator {
 		copyRowsQueue:          make(chan tableWriteFunc),
 		applyEventsQueue:       make(chan *applyEventStruct, base.MaxEventsBatchSize),
 		handledChangelogStates: make(map[string]bool),
-		finishedMigrating:      false,
+		finishedMigrating:      0,
 	}
 	return migrator
 }
@@ -727,7 +727,7 @@ func (this *Migrator) initiateStatus() error {
 	this.printStatus(ForcePrintStatusAndHintRule)
 	statusTick := time.Tick(1 * time.Second)
 	for range statusTick {
-		if this.finishedMigrating {
+		if atomic.LoadInt64(&this.finishedMigrating) > 0 {
 			return nil
 		}
 		go this.printStatus(HeuristicPrintStatusRule)
@@ -954,7 +954,7 @@ func (this *Migrator) initiateStreaming() error {
 	go func() {
 		ticker := time.Tick(1 * time.Second)
 		for range ticker {
-			if this.finishedMigrating {
+			if atomic.LoadInt64(&this.finishedMigrating) > 0 {
 				return
 			}
 			this.migrationContext.SetRecentBinlogCoordinates(*this.eventsStreamer.GetCurrentBinlogCoordinates())
@@ -1147,7 +1147,7 @@ func (this *Migrator) executeWriteFuncs() error {
 		return nil
 	}
 	for {
-		if this.finishedMigrating {
+		if atomic.LoadInt64(&this.finishedMigrating) > 0 {
 			return nil
 		}
 
@@ -1232,7 +1232,7 @@ func (this *Migrator) finalCleanup() error {
 }
 
 func (this *Migrator) teardown() {
-	this.finishedMigrating = true
+	atomic.StoreInt64(&this.finishedMigrating, 1)
 
 	if this.inspector != nil {
 		log.Infof("Tearing down inspector")
diff --git a/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go b/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go
index 46b60ce..17b948f 100644
--- a/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go
+++ b/vendor/github.com/outbrain/golib/sqlutils/sqlutils.go
@@ -129,9 +129,7 @@ var knownDBsMutex = &sync.Mutex{}
 // bool result indicates whether the DB was returned from cache; err
 func GetDB(mysql_uri string) (*sql.DB, bool, error) {
 	knownDBsMutex.Lock()
-	defer func() {
-		knownDBsMutex.Unlock()
-	}()
+	defer knownDBsMutex.Unlock()
 
 	var exists bool
 	if _, exists = knownDBs[mysql_uri]; !exists {
@@ -148,9 +146,7 @@
 // and new connections are needed to access the DB
 func ResetDBCache() {
 	knownDBsMutex.Lock()
-	defer func() {
-		knownDBsMutex.Unlock()
-	}()
+	defer knownDBsMutex.Unlock()
 
 	knownDBs = make(map[string]*sql.DB)
 }
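
Note (not part of the patch): the change above follows the atomic-flag pattern. A plain bool that is read by the ticker/heartbeat goroutines and written by teardown is a data race under the Go memory model, so the flag becomes an int64 touched only through sync/atomic. The sketch below is a minimal, standalone illustration of that pattern under those assumptions; the worker type and its finished field are hypothetical names, not code from the patched repository. On Go 1.19+ an atomic.Bool would express the same intent; the patch uses int64 with LoadInt64/StoreInt64, as shown in the hunks.

// atomicflag_sketch.go - illustrative only, not part of the repository.
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type worker struct {
	// 0 = running, 1 = finished; read and written only via sync/atomic.
	finished int64
}

// loop polls the flag on every tick, mirroring the ticker loops in the patch.
func (w *worker) loop() {
	tick := time.Tick(10 * time.Millisecond)
	for range tick {
		if atomic.LoadInt64(&w.finished) > 0 {
			fmt.Println("loop: flag set, exiting")
			return
		}
		// ... periodic work would go here ...
	}
}

// teardown sets the flag once, mirroring Teardown()/teardown() in the patch.
func (w *worker) teardown() {
	atomic.StoreInt64(&w.finished, 1)
}

func main() {
	w := &worker{}
	go w.loop()
	time.Sleep(50 * time.Millisecond)
	w.teardown()
	time.Sleep(50 * time.Millisecond) // give the loop time to observe the flag
}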