From a1a34b81506bfed206ddb56cc5c3e8a4576121ba Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Fri, 8 Apr 2016 14:35:06 +0200 Subject: [PATCH 1/5] ongoing development: - accepts --max-load - accepts multiple conditions in --max-load - throttle includes reason - chunk-size sanity check - change log state writes both in appending (history) mode and in replacing (current) mode - more atomic checks - inspecting ghost table columns, unique key - comparing unique keys between tables; sanity - intersecting columns between tables - prettify status - refactored throttle() and retries() --- go/base/context.go | 63 ++++++++++++++++--- go/cmd/gh-osc/main.go | 35 ++++------- go/logic/applier.go | 44 +++++++------ go/logic/inspect.go | 73 +++++++++++++++++----- go/logic/migrator.go | 139 ++++++++++++++++++++++++++++++++++-------- go/sql/types.go | 11 ++++ 6 files changed, 277 insertions(+), 88 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 6e28cb0..8185805 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -7,6 +7,7 @@ package base import ( "fmt" + "strconv" "strings" "sync/atomic" "time" @@ -44,20 +45,29 @@ type MigrationContext struct { AllowedRunningOnMaster bool InspectorConnectionConfig *mysql.ConnectionConfig MasterConnectionConfig *mysql.ConnectionConfig - MigrationRangeMinValues *sql.ColumnValues - MigrationRangeMaxValues *sql.ColumnValues - Iteration int64 - MigrationIterationRangeMinValues *sql.ColumnValues - MigrationIterationRangeMaxValues *sql.ColumnValues - UniqueKey *sql.UniqueKey StartTime time.Time RowCopyStartTime time.Time CurrentLag int64 MaxLagMillisecondsThrottleThreshold int64 ThrottleFlagFile string TotalRowsCopied int64 + isThrottled int64 + ThrottleReason string + MaxLoad map[string]int64 + + OriginalTableColumns sql.ColumnList + OriginalTableColumnsMap sql.ColumnsMap + OriginalTableUniqueKeys [](*sql.UniqueKey) + GhostTableColumns sql.ColumnList + GhostTableUniqueKeys [](*sql.UniqueKey) + UniqueKey *sql.UniqueKey + 
SharedColumns sql.ColumnList + MigrationRangeMinValues *sql.ColumnValues + MigrationRangeMaxValues *sql.ColumnValues + Iteration int64 + MigrationIterationRangeMinValues *sql.ColumnValues + MigrationIterationRangeMaxValues *sql.ColumnValues - IsThrottled func() bool CanStopStreaming func() bool } @@ -73,6 +83,7 @@ func newMigrationContext() *MigrationContext { InspectorConnectionConfig: mysql.NewConnectionConfig(), MasterConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1000, + MaxLoad: make(map[string]int64), } } @@ -141,3 +152,41 @@ func (this *MigrationContext) ElapsedRowCopyTime() time.Duration { func (this *MigrationContext) GetTotalRowsCopied() int64 { return atomic.LoadInt64(&this.TotalRowsCopied) } + +func (this *MigrationContext) GetIteration() int64 { + return atomic.LoadInt64(&this.Iteration) +} + +func (this *MigrationContext) SetThrottled(throttle bool) { + if throttle { + atomic.StoreInt64(&this.isThrottled, 1) + } else { + atomic.StoreInt64(&this.isThrottled, 0) + } +} + +func (this *MigrationContext) IsThrottled() bool { + return atomic.LoadInt64(&this.isThrottled) != 0 +} + +func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error { + if maxLoadList == "" { + return nil + } + maxLoadConditions := strings.Split(maxLoadList, ",") + for _, maxLoadCondition := range maxLoadConditions { + maxLoadTokens := strings.Split(maxLoadCondition, "=") + if len(maxLoadTokens) != 2 { + return fmt.Errorf("Error parsing max-load condition: %s", maxLoadCondition) + } + if maxLoadTokens[0] == "" { + return fmt.Errorf("Error parsing status variable in max-load condition: %s", maxLoadCondition) + } + if n, err := strconv.ParseInt(maxLoadTokens[1], 10, 0); err != nil { + return fmt.Errorf("Error parsing numeric value in max-load condition: %s", maxLoadCondition) + } else { + this.MaxLoad[maxLoadTokens[0]] = n + } + } + return nil +} diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index de7f976..335d419 100644 --- 
a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -19,11 +19,6 @@ import ( func main() { migrationContext := base.GetMigrationContext() - // mysqlBasedir := flag.String("mysql-basedir", "", "the --basedir config for MySQL (auto-detected if not given)") - // mysqlDatadir := flag.String("mysql-datadir", "", "the --datadir config for MySQL (auto-detected if not given)") - internalExperiment := flag.Bool("internal-experiment", false, "issue an internal experiment") - binlogFile := flag.String("binlog-file", "", "Name of binary log file") - flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") flag.StringVar(&migrationContext.InspectorConnectionConfig.User, "user", "root", "MySQL user") @@ -35,9 +30,16 @@ func main() { flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. 
Preferably it would run on a replica") - flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration") - flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists") - + flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)") + if migrationContext.ChunkSize < 100 { + migrationContext.ChunkSize = 100 + } + if migrationContext.ChunkSize > 100000 { + migrationContext.ChunkSize = 100000 + } + flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation") + flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists") + maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'") quiet := flag.Bool("quiet", false, "quiet") verbose := flag.Bool("verbose", false, "verbose") debug := flag.Bool("debug", false, "debug mode (very verbose)") @@ -75,23 +77,12 @@ func main() { if migrationContext.AlterStatement == "" { log.Fatalf("--alter must be provided and statement must not be empty") } + if err := migrationContext.ReadMaxLoad(*maxLoad); err != nil { + log.Fatale(err) + } log.Info("starting gh-osc") - if *internalExperiment { - log.Debug("starting experiment with %+v", *binlogFile) - - //binlogReader = binlog.NewMySQLBinlogReader(*mysqlBasedir, *mysqlDatadir) - // binlogReader, err := binlog.NewGoMySQLReader(migrationContext.InspectorConnectionConfig) - // if err != nil { - // log.Fatale(err) - // } - // if err := binlogReader.ConnectBinlogStreamer(mysql.BinlogCoordinates{LogFile: *binlogFile, LogPos: 0}); err != nil { - // log.Fatale(err) - // } - // binlogReader.StreamEvents(func() bool { return false }) - // return - } migrator := logic.NewMigrator() err := 
migrator.Migrate() if err != nil { diff --git a/go/logic/applier.go b/go/logic/applier.go index b4f63a9..fb23440 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -104,10 +104,10 @@ func (this *Applier) AlterGhost() error { // CreateChangelogTable creates the changelog table on the master func (this *Applier) CreateChangelogTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s ( - id int auto_increment, + id bigint auto_increment, last_update timestamp not null DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, hint varchar(64) charset ascii not null, - value varchar(64) charset ascii not null, + value varchar(255) charset ascii not null, primary key(id), unique key hint_uidx(hint) ) auto_increment=2 @@ -162,6 +162,12 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { return hint, err } +func (this *Applier) WriteChangelogState(value string) (string, error) { + hint := "state" + this.WriteChangelog(hint, value) + return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value) +} + // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. 
// This is done asynchronously func (this *Applier) InitiateHeartbeat() { @@ -315,7 +321,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), this.migrationContext.ChunkSize, - fmt.Sprintf("iteration:%d", this.migrationContext.Iteration), + fmt.Sprintf("iteration:%d", this.migrationContext.GetIteration()), ) if err != nil { return hasFurtherRange, err @@ -336,13 +342,13 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo return hasFurtherRange, nil } this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues - log.Debugf( - "column values: [%s]..[%s]; iteration: %d; chunk-size: %d", - this.migrationContext.MigrationIterationRangeMinValues, - this.migrationContext.MigrationIterationRangeMaxValues, - this.migrationContext.Iteration, - this.migrationContext.ChunkSize, - ) + // log.Debugf( + // "column values: [%s]..[%s]; iteration: %d; chunk-size: %d", + // this.migrationContext.MigrationIterationRangeMinValues, + // this.migrationContext.MigrationIterationRangeMaxValues, + // this.migrationContext.GetIteration(), + // this.migrationContext.ChunkSize, + // ) return hasFurtherRange, nil } @@ -354,12 +360,12 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.GetGhostTableName(), - this.migrationContext.UniqueKey.Columns, + this.migrationContext.SharedColumns, this.migrationContext.UniqueKey.Name, this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), - this.migrationContext.Iteration == 0, + this.migrationContext.GetIteration() == 0, this.migrationContext.IsTransactionalTable(), ) if err != 
nil { @@ -371,15 +377,11 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected } rowsAffected, _ = sqlResult.RowsAffected() duration = time.Now().Sub(startTime) - this.WriteChangelog( - fmt.Sprintf("copy iteration %d", this.migrationContext.Iteration), - fmt.Sprintf("chunk: %d; affected: %d; duration: %d", chunkSize, rowsAffected, duration), - ) log.Debugf( "Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d", this.migrationContext.MigrationIterationRangeMinValues, this.migrationContext.MigrationIterationRangeMaxValues, - this.migrationContext.Iteration, + this.migrationContext.GetIteration(), chunkSize) return chunkSize, rowsAffected, duration, nil } @@ -412,3 +414,11 @@ func (this *Applier) UnlockTables() error { log.Infof("Tables unlocked") return nil } + +func (this *Applier) ShowStatusVariable(variableName string) (result int64, err error) { + query := fmt.Sprintf(`show global status like '%s'`, variableName) + if err := this.db.QueryRow(query).Scan(&variableName, &result); err != nil { + return 0, err + } + return result, nil +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 31f74ec..42e9105 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -69,15 +69,50 @@ func (this *Inspector) ValidateOriginalTable() (err error) { return nil } -func (this *Inspector) InspectOriginalTable() (uniqueKeys [](*sql.UniqueKey), err error) { - uniqueKeys, err = this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName) +func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { + uniqueKeys, err = this.getCandidateUniqueKeys(tableName) if err != nil { - return uniqueKeys, err + return columns, uniqueKeys, err } if len(uniqueKeys) == 0 { - return uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") + return columns, uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! 
Bailing out") } - return uniqueKeys, err + columns, err = this.getTableColumns(this.migrationContext.DatabaseName, tableName) + if err != nil { + return columns, uniqueKeys, err + } + + return columns, uniqueKeys, nil +} + +func (this *Inspector) InspectOriginalTable() (err error) { + this.migrationContext.OriginalTableColumns, this.migrationContext.OriginalTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.OriginalTableName) + if err != nil { + return err + } + this.migrationContext.OriginalTableColumnsMap = sql.NewColumnsMap(this.migrationContext.OriginalTableColumns) + return nil +} + +func (this *Inspector) InspectOriginalAndGhostTables() (err error) { + this.migrationContext.GhostTableColumns, this.migrationContext.GhostTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.GetGhostTableName()) + if err != nil { + return err + } + sharedUniqueKeys, err := this.getSharedUniqueKeys(this.migrationContext.OriginalTableUniqueKeys, this.migrationContext.GhostTableUniqueKeys) + if err != nil { + return err + } + if len(sharedUniqueKeys) == 0 { + return fmt.Errorf("No shared unique key can be found after ALTER! 
Bailing out") + } + this.migrationContext.UniqueKey = sharedUniqueKeys[0] + log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) + + this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns) + log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) + // By fact that a non-empty unique key exists we also know the shared columns are non-empty + return nil } // validateConnection issues a simple can-connect to MySQL @@ -361,17 +396,9 @@ func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](* return uniqueKeys, nil } -// getCandidateUniqueKeys investigates a table and returns the list of unique keys -// candidate for chunking -func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err error) { - originalUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName) - if err != nil { - return uniqueKeys, err - } - ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GetGhostTableName()) - if err != nil { - return uniqueKeys, err - } +// getSharedUniqueKeys returns the intersection of two given unique keys, +// testing by list of columns +func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [](*sql.UniqueKey)) (uniqueKeys [](*sql.UniqueKey), err error) { // We actually do NOT rely on key name, just on the set of columns. This is because maybe // the ALTER is on the name itself... 
for _, originalUniqueKey := range originalUniqueKeys { @@ -384,6 +411,20 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err return uniqueKeys, nil } +// getSharedColumns returns the intersection of two lists of columns in same order as the first list +func (this *Inspector) getSharedColumns(originalColumns, ghostColumns sql.ColumnList) (sharedColumns sql.ColumnList) { + columnsInGhost := make(map[string]bool) + for _, ghostColumn := range ghostColumns { + columnsInGhost[ghostColumn] = true + } + for _, originalColumn := range originalColumns { + if columnsInGhost[originalColumn] { + sharedColumns = append(sharedColumns, originalColumn) + } + } + return sharedColumns +} + func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { visitedKeys := mysql.NewInstanceKeyMap() return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 451cb96..6bf6772 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -8,6 +8,7 @@ package logic import ( "fmt" "os" + "regexp" "sync/atomic" "time" @@ -30,6 +31,10 @@ const ( applyEventsQueueBuffer = 100 ) +var ( + prettifyDurationRegexp = regexp.MustCompile("([.][0-9]+)") +) + // Migrator is the main schema migration flow manager. 
type Migrator struct { inspector *Inspector @@ -57,25 +62,94 @@ func NewMigrator() *Migrator { copyRowsQueue: make(chan tableWriteFunc), applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), } - migrator.migrationContext.IsThrottled = func() bool { - return migrator.shouldThrottle() - } return migrator } -func (this *Migrator) shouldThrottle() bool { +func prettifyDurationOutput(d time.Duration) string { + if d < time.Second { + return "0s" + } + result := fmt.Sprintf("%s", d) + result = prettifyDurationRegexp.ReplaceAllString(result, "") + return result +} + +func (this *Migrator) shouldThrottle() (result bool, reason string) { lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) - shouldThrottle := false if time.Duration(lag) > time.Duration(this.migrationContext.MaxLagMillisecondsThrottleThreshold)*time.Millisecond { - shouldThrottle = true - } else if this.migrationContext.ThrottleFlagFile != "" { + return true, fmt.Sprintf("lag=%fs", time.Duration(lag).Seconds()) + } + if this.migrationContext.ThrottleFlagFile != "" { if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil { //Throttle file defined and exists! 
- shouldThrottle = true + return true, "flag-file" } } - return shouldThrottle + + for variableName, threshold := range this.migrationContext.MaxLoad { + value, err := this.applier.ShowStatusVariable(variableName) + if err != nil { + return true, fmt.Sprintf("%s %s", variableName, err) + } + if value > threshold { + return true, fmt.Sprintf("%s=%d", variableName, value) + } + } + + return false, "" +} + +// throttle initiates a throttling event, if need be, updates the Context and +// calls callback functions, if any +func (this *Migrator) throttle( + onStartThrottling func(), + onContinuousThrottling func(), + onEndThrottling func(), +) { + hasThrottledYet := false + for { + shouldThrottle, reason := this.shouldThrottle() + if !shouldThrottle { + break + } + this.migrationContext.ThrottleReason = reason + if !hasThrottledYet { + hasThrottledYet = true + if onStartThrottling != nil { + onStartThrottling() + } + this.migrationContext.SetThrottled(true) + } + time.Sleep(time.Second) + if onContinuousThrottling != nil { + onContinuousThrottling() + } + } + if hasThrottledYet { + if onEndThrottling != nil { + onEndThrottling() + } + this.migrationContext.SetThrottled(false) + } +} + +// retryOperation makes up to `migrationContext.MaxRetries()` attempts at running given function, +// exiting as soon as it returns with non-error. 
+ } + return err } func (this *Migrator) canStopStreaming() bool { @@ -102,7 +176,7 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return fmt.Errorf("Unknown changelog state: %+v", changelogState) } } - log.Debugf("---- - - - - - state %+v", changelogState) + log.Debugf("Received state %+v", changelogState) return nil } @@ -132,8 +206,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.inspector.ValidateOriginalTable(); err != nil { return err } - uniqueKeys, err := this.inspector.InspectOriginalTable() - if err != nil { + if err := this.inspector.InspectOriginalTable(); err != nil { return err } // So far so good, table is accessible and valid. @@ -159,22 +232,24 @@ func (this *Migrator) Migrate() (err error) { // When running on replica, this means the replica has those tables. When running // on master this is always true, of course, and yet it also implies this knowledge // is in the binlogs. + if err := this.inspector.InspectOriginalAndGhostTables(); err != nil { + return err + } - this.migrationContext.UniqueKey = uniqueKeys[0] // TODO. Need to wait on replica till the ghost table exists and get shared keys if err := this.applier.ReadMigrationRangeValues(); err != nil { return err } - go this.initiateStatus() go this.executeWriteFuncs() go this.iterateChunks() + this.migrationContext.RowCopyStartTime = time.Now() + go this.initiateStatus() log.Debugf("Operating until row copy is complete") <-this.rowCopyComplete log.Debugf("Row copy complete") this.printStatus() - throttleMigration( - this.migrationContext, + this.throttle( func() { log.Debugf("throttling before LOCK TABLES") }, @@ -185,7 +260,7 @@ func (this *Migrator) Migrate() (err error) { ) // TODO retries!! 
this.applier.LockTables() - this.applier.WriteChangelog("state", string(AllEventsUpToLockProcessed)) + this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed)) log.Debugf("Waiting for events up to lock") <-this.allEventsUpToLockProcessed log.Debugf("Done waiting for events up to lock") @@ -228,10 +303,20 @@ func (this *Migrator) printStatus() { return } - status := fmt.Sprintf("Copy: %d/%d %.1f%% Backlog: %d/%d Elapsed: %+v(copy), %+v(total) ETA: N/A", + eta := "N/A" + if this.migrationContext.IsThrottled() { + eta = fmt.Sprintf("throttled, %s", this.migrationContext.ThrottleReason) + } + status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s", totalRowsCopied, rowsEstimate, progressPct, len(this.applyEventsQueue), cap(this.applyEventsQueue), - this.migrationContext.ElapsedRowCopyTime(), elapsedTime) + prettifyDurationOutput(this.migrationContext.ElapsedRowCopyTime()), prettifyDurationOutput(elapsedTime), + eta, + ) + this.applier.WriteChangelog( + fmt.Sprintf("copy iteration %d at %d", this.migrationContext.GetIteration(), time.Now().Unix()), + status, + ) fmt.Println(status) } @@ -281,13 +366,12 @@ func (this *Migrator) initiateApplier() error { return err } - this.applier.WriteChangelog("state", string(TablesInPlace)) + this.applier.WriteChangelogState(string(TablesInPlace)) this.applier.InitiateHeartbeat() return nil } func (this *Migrator) iterateChunks() error { - this.migrationContext.RowCopyStartTime = time.Now() terminateRowIteration := func(err error) error { this.rowCopyComplete <- true return log.Errore(err) @@ -306,7 +390,7 @@ func (this *Migrator) iterateChunks() error { return terminateRowIteration(err) } atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) - this.migrationContext.Iteration++ + atomic.AddInt64(&this.migrationContext.Iteration, 1) return nil } this.copyRowsQueue <- copyRowsFunc @@ -316,8 +400,7 @@ func (this *Migrator) iterateChunks() error { func (this 
*Migrator) executeWriteFuncs() error { for { - throttleMigration( - this.migrationContext, + this.throttle( func() { log.Debugf("throttling writes") }, @@ -331,14 +414,18 @@ func (this *Migrator) executeWriteFuncs() error { select { case applyEventFunc := <-this.applyEventsQueue: { - retryOperation(applyEventFunc, this.migrationContext.MaxRetries()) + if err := this.retryOperation(applyEventFunc); err != nil { + return log.Errore(err) + } } default: { select { case copyRowsFunc := <-this.copyRowsQueue: { - retryOperation(copyRowsFunc, this.migrationContext.MaxRetries()) + if err := this.retryOperation(copyRowsFunc); err != nil { + return log.Errore(err) + } } default: { diff --git a/go/sql/types.go b/go/sql/types.go index 942bd37..cda02d3 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -28,6 +28,17 @@ func (this *ColumnList) Equals(other *ColumnList) bool { return reflect.DeepEqual(*this, *other) } +// ColumnsMap maps a column onto its ordinal position +type ColumnsMap map[string]int + +func NewColumnsMap(columnList ColumnList) ColumnsMap { + columnsMap := make(map[string]int) + for i, column := range columnList { + columnsMap[column] = i + } + return ColumnsMap(columnsMap) +} + // UniqueKey is the combination of a key's name and columns type UniqueKey struct { Name string From 80163b35b6faabd96a07704ba4c4c35a04a541d9 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Fri, 8 Apr 2016 14:44:36 +0200 Subject: [PATCH 2/5] extracted on-throttle functions outside loop --- go/logic/migrator.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 6bf6772..4df0e4a 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -399,16 +399,14 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) executeWriteFuncs() error { + onStartThrottling := func() { + log.Debugf("throttling writes") + } + onEndThrottling := func() { + log.Debugf("done throttling writes") + } for { - 
this.throttle( - func() { - log.Debugf("throttling writes") - }, - nil, - func() { - log.Debugf("done throttling writes") - }, - ) + this.throttle(onStartThrottling, nil, onEndThrottling) // We give higher priority to event processing, then secondary priority to // rowcopy select { From 04525887f3a9c59deecaddb2617c8b8486c1a222 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 11 Apr 2016 17:27:16 +0200 Subject: [PATCH 3/5] - Throttling-check is now an async routine running once per second - Throttling variables protected by mutex - Added `--throttle-additional-flag-file`: `operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations` - ColumnList is now a `struct` which contains ordinal mapping - More implicit write changelog + audit changelog - builder now builds `DELETE` and `INSERT` queries from data it will eventually get from DML event - Sanity check for binlog_row_image - Restarting replication to be sure binlog settings apply - Prepare for accepting `SIGHUP` (reloading configuration) --- go/base/context.go | 39 ++++++++----- go/cmd/gh-osc/main.go | 3 +- go/logic/applier.go | 67 ++++++++++++++-------- go/logic/inspect.go | 57 ++++++++++++++----- go/logic/migrator.go | 106 ++++++++++++++++++++--------------- go/sql/builder.go | 97 ++++++++++++++++++++++++++++---- go/sql/builder_test.go | 123 +++++++++++++++++++++++++++++++++++++++-- go/sql/types.go | 72 +++++++++++++++++------- 8 files changed, 434 insertions(+), 130 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 8185805..9611869 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -9,6 +9,7 @@ import ( "fmt" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -50,18 +51,19 @@ type MigrationContext struct { CurrentLag int64 MaxLagMillisecondsThrottleThreshold int64 ThrottleFlagFile string + ThrottleAdditionalFlagFile string TotalRowsCopied int64 - isThrottled int64 - ThrottleReason string + isThrottled bool + throttleReason 
string + throttleMutex *sync.Mutex MaxLoad map[string]int64 - OriginalTableColumns sql.ColumnList - OriginalTableColumnsMap sql.ColumnsMap + OriginalTableColumns *sql.ColumnList OriginalTableUniqueKeys [](*sql.UniqueKey) - GhostTableColumns sql.ColumnList + GhostTableColumns *sql.ColumnList GhostTableUniqueKeys [](*sql.UniqueKey) UniqueKey *sql.UniqueKey - SharedColumns sql.ColumnList + SharedColumns *sql.ColumnList MigrationRangeMinValues *sql.ColumnValues MigrationRangeMaxValues *sql.ColumnValues Iteration int64 @@ -83,7 +85,8 @@ func newMigrationContext() *MigrationContext { InspectorConnectionConfig: mysql.NewConnectionConfig(), MasterConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1000, - MaxLoad: make(map[string]int64), + MaxLoad: make(map[string]int64), + throttleMutex: &sync.Mutex{}, } } @@ -97,6 +100,11 @@ func (this *MigrationContext) GetGhostTableName() string { return fmt.Sprintf("_%s_New", this.OriginalTableName) } +// GetOldTableName generates the name of the "old" table, into which the original table is renamed. 
+func (this *MigrationContext) GetOldTableName() string { + return fmt.Sprintf("_%s_Old", this.OriginalTableName) +} + // GetChangelogTableName generates the name of changelog table, based on original table name func (this *MigrationContext) GetChangelogTableName() string { return fmt.Sprintf("_%s_OSC", this.OriginalTableName) @@ -157,16 +165,17 @@ func (this *MigrationContext) GetIteration() int64 { return atomic.LoadInt64(&this.Iteration) } -func (this *MigrationContext) SetThrottled(throttle bool) { - if throttle { - atomic.StoreInt64(&this.isThrottled, 1) - } else { - atomic.StoreInt64(&this.isThrottled, 0) - } +func (this *MigrationContext) SetThrottled(throttle bool, reason string) { + this.throttleMutex.Lock() + defer func() { this.throttleMutex.Unlock() }() + this.isThrottled = throttle + this.throttleReason = reason } -func (this *MigrationContext) IsThrottled() bool { - return atomic.LoadInt64(&this.isThrottled) != 0 +func (this *MigrationContext) IsThrottled() (bool, string) { + this.throttleMutex.Lock() + defer func() { this.throttleMutex.Unlock() }() + return this.isThrottled, this.throttleReason } func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error { diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index 335d419..69709e6 100644 --- a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -38,7 +38,8 @@ func main() { migrationContext.ChunkSize = 100000 } flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation") - flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists") + flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") + flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", 
"/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations") maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'") quiet := flag.Bool("quiet", false, "quiet") verbose := flag.Bool("verbose", false, "verbose") diff --git a/go/logic/applier.go b/go/logic/applier.go index fb23440..60c83d7 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -12,6 +12,7 @@ import ( "time" "github.com/github/gh-osc/go/base" + "github.com/github/gh-osc/go/binlog" "github.com/github/gh-osc/go/mysql" "github.com/github/gh-osc/go/sql" @@ -63,7 +64,7 @@ func (this *Applier) validateConnection() error { return nil } -// CreateGhostTable creates the ghost table on the master +// CreateGhostTable creates the ghost table on the applier host func (this *Applier) CreateGhostTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), @@ -82,7 +83,7 @@ func (this *Applier) CreateGhostTable() error { return nil } -// CreateGhostTable creates the ghost table on the master +// AlterGhost applies `alter` statement on ghost table func (this *Applier) AlterGhost() error { query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`, sql.EscapeName(this.migrationContext.DatabaseName), @@ -101,7 +102,7 @@ func (this *Applier) AlterGhost() error { return nil } -// CreateChangelogTable creates the changelog table on the master +// CreateChangelogTable creates the changelog table on the applier host func (this *Applier) CreateChangelogTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s ( id bigint auto_increment, @@ -110,7 +111,7 @@ func (this *Applier) CreateChangelogTable() error { value varchar(255) charset ascii not null, primary key(id), unique key hint_uidx(hint) - ) auto_increment=2 + ) auto_increment=256 `, sql.EscapeName(this.migrationContext.DatabaseName), 
sql.EscapeName(this.migrationContext.GetChangelogTableName()), @@ -126,23 +127,33 @@ func (this *Applier) CreateChangelogTable() error { return nil } -// DropChangelogTable drops the changelog table on the master -func (this *Applier) DropChangelogTable() error { +// dropTable drops a given table on the applier host +func (this *Applier) dropTable(tableName string) error { query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), + sql.EscapeName(tableName), ) - log.Infof("Droppping changelog table %s.%s", + log.Infof("Droppping table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), + sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Changelog table dropped") + log.Infof("Table dropped") return nil } +// DropChangelogTable drops the changelog table on the applier host +func (this *Applier) DropChangelogTable() error { + return this.dropTable(this.migrationContext.GetChangelogTableName()) +} + +// DropGhostTable drops the ghost table on the applier host +func (this *Applier) DropGhostTable() error { + return this.dropTable(this.migrationContext.GetGhostTableName()) +} + // WriteChangelog writes a value to the changelog table. 
// It returns the hint as given, for convenience func (this *Applier) WriteChangelog(hint, value string) (string, error) { @@ -162,12 +173,15 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { return hint, err } -func (this *Applier) WriteChangelogState(value string) (string, error) { - hint := "state" +func (this *Applier) WriteAndLogChangelog(hint, value string) (string, error) { this.WriteChangelog(hint, value) return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value) } +func (this *Applier) WriteChangelogState(value string) (string, error) { + return this.WriteAndLogChangelog("state", value) +} + // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // This is done asynchronously func (this *Applier) InitiateHeartbeat() { @@ -213,7 +227,7 @@ func (this *Applier) InitiateHeartbeat() { // ReadMigrationMinValues func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names) if err != nil { return err } @@ -222,7 +236,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { return err } for rows.Next() { - this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns)) + this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(uniqueKey.Len()) if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil { return err } @@ -234,7 +248,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { // ReadMigrationMinValues func (this *Applier) ReadMigrationMaxValues(uniqueKey 
*sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names) if err != nil { return err } @@ -243,7 +257,7 @@ func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { return err } for rows.Next() { - this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns)) + this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(uniqueKey.Len()) if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil { return err } @@ -272,12 +286,12 @@ func (this *Applier) __unused_IterationIsComplete() (bool, error) { return false, nil } args := sqlutils.Args() - compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) + compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) if err != nil { return false, err } args = append(args, explodedArgs...) 
- compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) + compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) if err != nil { return false, err } @@ -317,7 +331,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, - this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), this.migrationContext.ChunkSize, @@ -330,7 +344,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo if err != nil { return hasFurtherRange, err } - iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns)) + iterationRangeMaxValues := sql.NewColumnValues(this.migrationContext.UniqueKey.Len()) for rows.Next() { if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { return hasFurtherRange, err @@ -360,9 +374,9 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.GetGhostTableName(), - this.migrationContext.SharedColumns, + this.migrationContext.SharedColumns.Names, this.migrationContext.UniqueKey.Name, - this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), 
this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), this.migrationContext.GetIteration() == 0, @@ -422,3 +436,12 @@ func (this *Applier) ShowStatusVariable(variableName string) (result int64, err } return result, nil } + +func (this *Applier) BuildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (result string, err error) { + switch dmlEvent.DML { + case binlog.DeleteDML: + { + } + } + return result, err +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 42e9105..ddb53f3 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -44,6 +44,9 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.validateGrants(); err != nil { return err } + if err := this.restartReplication(); err != nil { + return err + } if err := this.validateBinlogs(); err != nil { return err } @@ -69,7 +72,7 @@ func (this *Inspector) ValidateOriginalTable() (err error) { return nil } -func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { +func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { uniqueKeys, err = this.getCandidateUniqueKeys(tableName) if err != nil { return columns, uniqueKeys, err @@ -90,7 +93,6 @@ func (this *Inspector) InspectOriginalTable() (err error) { if err == nil { return err } - this.migrationContext.OriginalTableColumnsMap = sql.NewColumnsMap(this.migrationContext.OriginalTableColumns) return nil } @@ -108,6 +110,11 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) { } this.migrationContext.UniqueKey = sharedUniqueKeys[0] log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) + if !this.migrationContext.UniqueKey.IsPrimary() { + if this.migrationContext.OriginalBinlogRowImage != "full" { + return fmt.Errorf("binlog_row_image is '%s' and chosen key is %s, which is not the primary key. 
This operation cannot proceed. You may `set global binlog_row_image='full'` and try again") + } + } this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns) log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) @@ -171,6 +178,26 @@ func (this *Inspector) validateGrants() error { return log.Errorf("User has insufficient privileges for migration.") } +// restartReplication is required so that we are _certain_ the binlog format and +// row image settings have actually been applied to the replication thread. +// It is entriely possible, for example, that the replication is using 'STATEMENT' +// binlog format even as the variable says 'ROW' +func (this *Inspector) restartReplication() error { + log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + + var stopError, startError error + _, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`) + _, startError = sqlutils.ExecNoPrepare(this.db, `start slave`) + if stopError != nil { + return stopError + } + if startError != nil { + return startError + } + log.Debugf("Replication restarted") + return nil +} + // validateBinlogs checks that binary log configuration is good to go func (this *Inspector) validateBinlogs() error { query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format` @@ -299,27 +326,28 @@ func (this *Inspector) countTableRows() error { return nil } -func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) { +func (this *Inspector) getTableColumns(databaseName, tableName string) (*sql.ColumnList, error) { query := fmt.Sprintf(` show columns from %s.%s `, sql.EscapeName(databaseName), sql.EscapeName(tableName), ) - err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { - columns = 
append(columns, rowMap.GetString("Field")) + columnNames := []string{} + err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { + columnNames = append(columnNames, rowMap.GetString("Field")) return nil }) if err != nil { - return columns, err + return nil, err } - if len(columns) == 0 { - return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out", + if len(columnNames) == 0 { + return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", sql.EscapeName(databaseName), sql.EscapeName(tableName), ) } - return columns, nil + return sql.NewColumnList(columnNames), nil } // getCandidateUniqueKeys investigates a table and returns the list of unique keys @@ -412,17 +440,18 @@ func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [ } // getSharedColumns returns the intersection of two lists of columns in same order as the first list -func (this *Inspector) getSharedColumns(originalColumns, ghostColumns sql.ColumnList) (sharedColumns sql.ColumnList) { +func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList) *sql.ColumnList { columnsInGhost := make(map[string]bool) - for _, ghostColumn := range ghostColumns { + for _, ghostColumn := range ghostColumns.Names { columnsInGhost[ghostColumn] = true } - for _, originalColumn := range originalColumns { + sharedColumnNames := []string{} + for _, originalColumn := range originalColumns.Names { if columnsInGhost[originalColumn] { - sharedColumns = append(sharedColumns, originalColumn) + sharedColumnNames = append(sharedColumnNames, originalColumn) } } - return sharedColumns + return sql.NewColumnList(sharedColumnNames) } func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 4df0e4a..3dee61d 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -8,8 +8,10 @@ package logic import ( "fmt" "os" + "os/signal" "regexp" 
"sync/atomic" + "syscall" "time" "github.com/github/gh-osc/go/base" @@ -74,6 +76,21 @@ func prettifyDurationOutput(d time.Duration) string { return result } +// acceptSignals registers for OS signals +func (this *Migrator) acceptSignals() { + c := make(chan os.Signal, 1) + + signal.Notify(c, syscall.SIGHUP) + go func() { + for sig := range c { + switch sig { + case syscall.SIGHUP: + log.Debugf("Received SIGHUP. Reloading configuration") + } + } + }() +} + func (this *Migrator) shouldThrottle() (result bool, reason string) { lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) @@ -82,7 +99,13 @@ func (this *Migrator) shouldThrottle() (result bool, reason string) { } if this.migrationContext.ThrottleFlagFile != "" { if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil { - //Throttle file defined and exists! + // Throttle file defined and exists! + return true, "flag-file" + } + } + if this.migrationContext.ThrottleAdditionalFlagFile != "" { + if _, err := os.Stat(this.migrationContext.ThrottleAdditionalFlagFile); err == nil { + // 2nd Throttle file defined and exists! 
return true, "flag-file" } } @@ -100,37 +123,43 @@ func (this *Migrator) shouldThrottle() (result bool, reason string) { return false, "" } +func (this *Migrator) initiateThrottler() error { + throttlerTick := time.Tick(1 * time.Second) + + throttlerFunction := func() { + alreadyThrottling, currentReason := this.migrationContext.IsThrottled() + shouldThrottle, throttleReason := this.shouldThrottle() + if shouldThrottle && !alreadyThrottling { + // New throttling + this.applier.WriteAndLogChangelog("throttle", throttleReason) + } else if shouldThrottle && alreadyThrottling && (currentReason != throttleReason) { + // Change of reason + this.applier.WriteAndLogChangelog("throttle", throttleReason) + } else if alreadyThrottling && !shouldThrottle { + // End of throttling + this.applier.WriteAndLogChangelog("throttle", "done throttling") + } + this.migrationContext.SetThrottled(shouldThrottle, throttleReason) + } + throttlerFunction() + for range throttlerTick { + throttlerFunction() + } + + return nil +} + // throttle initiates a throttling event, if need be, updates the Context and // calls callback functions, if any -func (this *Migrator) throttle( - onStartThrottling func(), - onContinuousThrottling func(), - onEndThrottling func(), -) { - hasThrottledYet := false +func (this *Migrator) throttle(onThrottled func()) { for { - shouldThrottle, reason := this.shouldThrottle() - if !shouldThrottle { - break + if shouldThrottle, _ := this.migrationContext.IsThrottled(); !shouldThrottle { + return } - this.migrationContext.ThrottleReason = reason - if !hasThrottledYet { - hasThrottledYet = true - if onStartThrottling != nil { - onStartThrottling() - } - this.migrationContext.SetThrottled(true) + if onThrottled != nil { + onThrottled() } time.Sleep(time.Second) - if onContinuousThrottling != nil { - onContinuousThrottling() - } - } - if hasThrottledYet { - if onEndThrottling != nil { - onEndThrottling() - } - this.migrationContext.SetThrottled(false) } } @@ -239,6 +268,7 @@ 
func (this *Migrator) Migrate() (err error) { if err := this.applier.ReadMigrationRangeValues(); err != nil { return err } + go this.initiateThrottler() go this.executeWriteFuncs() go this.iterateChunks() this.migrationContext.RowCopyStartTime = time.Now() @@ -249,15 +279,9 @@ func (this *Migrator) Migrate() (err error) { log.Debugf("Row copy complete") this.printStatus() - this.throttle( - func() { - log.Debugf("throttling before LOCK TABLES") - }, - nil, - func() { - log.Debugf("done throttling") - }, - ) + this.throttle(func() { + log.Debugf("throttling on LOCK TABLES") + }) // TODO retries!! this.applier.LockTables() this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed)) @@ -304,8 +328,8 @@ func (this *Migrator) printStatus() { } eta := "N/A" - if this.migrationContext.IsThrottled() { - eta = fmt.Sprintf("throttled, %s", this.migrationContext.ThrottleReason) + if isThrottled, throttleReason := this.migrationContext.IsThrottled(); isThrottled { + eta = fmt.Sprintf("throttled, %s", throttleReason) } status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s", totalRowsCopied, rowsEstimate, progressPct, @@ -399,14 +423,8 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) executeWriteFuncs() error { - onStartThrottling := func() { - log.Debugf("throttling writes") - } - onEndThrottling := func() { - log.Debugf("done throttling writes") - } for { - this.throttle(onStartThrottling, nil, onEndThrottling) + this.throttle(nil) // We give higher priority to event processing, then secondary priority to // rowcopy select { diff --git a/go/sql/builder.go b/go/sql/builder.go index b8f0358..784b1c9 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -32,6 +32,20 @@ func EscapeName(name string) string { return fmt.Sprintf("`%s`", name) } +func buildPreparedValues(length int) []string { + values := make([]string, length, length) + for i := 0; i < length; i++ { + values[i] = "?" 
+ } + return values +} + +func duplicateNames(names []string) []string { + duplicate := make([]string, len(names), len(names)) + copy(duplicate, names) + return duplicate +} + func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) { if column == "" { return "", fmt.Errorf("Empty column in GetValueComparison") @@ -64,6 +78,11 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er return result, nil } +func BuildEqualsPreparedComparison(columns []string) (result string, err error) { + values := buildPreparedValues(len(columns)) + return BuildEqualsComparison(columns, values) +} + func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { if len(columns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") @@ -121,10 +140,7 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, } func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { - values := make([]string, len(columns), len(columns)) - for i := range columns { - values[i] = "?" 
- } + values := buildPreparedValues(len(columns)) return BuildRangeComparison(columns, values, args, comparisonSign) } @@ -135,6 +151,7 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin databaseName = EscapeName(databaseName) originalTableName = EscapeName(originalTableName) ghostTableName = EscapeName(ghostTableName) + sharedColumns = duplicateNames(sharedColumns) for i := range sharedColumns { sharedColumns[i] = EscapeName(sharedColumns[i]) } @@ -171,12 +188,8 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin } func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { - rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - rangeStartValues[i] = "?" - rangeEndValues[i] = "?" - } + rangeStartValues := buildPreparedValues(len(uniqueKeyColumns)) + rangeEndValues := buildPreparedValues(len(uniqueKeyColumns)) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) } @@ -198,6 +211,7 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK } explodedArgs = append(explodedArgs, rangeExplodedArgs...) 
+ uniqueKeyColumns = duplicateNames(uniqueKeyColumns) uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { @@ -244,6 +258,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) + uniqueKeyColumns = duplicateNames(uniqueKeyColumns) uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) @@ -262,3 +277,65 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni ) return query, nil } + +func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { + if len(args) != originalTableColumns.Len() { + return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery") + } + for _, column := range uniqueKeyColumns.Names { + tableOrdinal := originalTableColumns.Ordinals[column] + uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal]) + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + result = fmt.Sprintf(` + delete /* gh-osc %s.%s */ + from + %s.%s + where + %s + `, databaseName, tableName, + databaseName, tableName, + equalsComparison, + ) + return result, uniqueKeyArgs, err +} + +func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { + if len(args) != originalTableColumns.Len() { + return result, args, fmt.Errorf("args count differs from table column count in 
BuildDMLInsertQuery") + } + if !sharedColumns.IsSubsetOf(originalTableColumns) { + return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery") + } + if sharedColumns.Len() == 0 { + return result, args, fmt.Errorf("No shared columns found in BuildDMLInsertQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + for _, column := range sharedColumns.Names { + tableOrdinal := originalTableColumns.Ordinals[column] + sharedArgs = append(sharedArgs, args[tableOrdinal]) + } + + sharedColumnNames := duplicateNames(sharedColumns.Names) + for i := range sharedColumnNames { + sharedColumnNames[i] = EscapeName(sharedColumnNames[i]) + } + preparedValues := buildPreparedValues(sharedColumns.Len()) + + result = fmt.Sprintf(` + replace /* gh-osc %s.%s */ into + %s.%s + (%s) + values + (%s) + `, databaseName, tableName, + databaseName, tableName, + strings.Join(sharedColumnNames, ", "), + strings.Join(preparedValues, ", "), + ) + return result, sharedArgs, err +} diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index b101bbf..5a57a3d 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -68,6 +68,15 @@ func TestBuildEqualsComparison(t *testing.T) { } } +func TestBuildEqualsPreparedComparison(t *testing.T) { + { + columns := []string{"c1", "c2"} + comparison, err := BuildEqualsPreparedComparison(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` = ?) 
and (`c2` = ?))") + } +} + func TestBuildRangeComparison(t *testing.T) { { columns := []string{"c1"} @@ -143,7 +152,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { rangeStartArgs := []interface{}{3} rangeEndArgs := []interface{}{103} - query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -162,7 +171,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} - query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -186,13 +195,13 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} - query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, 
rangeEndArgs, true, true) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) (select id, name, position from mydb.tbl force index (name_position_uidx) where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?)))) - ) + lock in share mode ) ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) @@ -202,7 +211,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { databaseName := "mydb" originalTableName := "tbl" - chunkSize := 500 + var chunkSize int64 = 500 { uniqueKeyColumns := []string{"name", "position"} rangeStartArgs := []interface{}{3, 17} @@ -262,3 +271,107 @@ func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) { test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) } } + +func TestBuildDMLDeleteQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + args := []interface{}{3, "testname", "first", 17, 23} + { + uniqueKeyColumns := NewColumnList([]string{"position"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17})) + } + { + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + 
expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((name = ?) and (position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{"testname", 17})) + } + { + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((position = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"})) + } + { + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + args := []interface{}{"first", 17} + + _, _, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNotNil(err) + } +} + +func TestBuildDMLInsertQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + args := []interface{}{3, "testname", "first", 17, 23} + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNil(err) + expected := ` + replace /* gh-osc mydb.tbl */ + into mydb.tbl + (id, name, position, age) + values + (?, ?, ?, ?) 
+ ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + } + { + sharedColumns := NewColumnList([]string{"position", "name", "age", "id"}) + query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNil(err) + expected := ` + replace /* gh-osc mydb.tbl */ + into mydb.tbl + (position, name, age, id) + values + (?, ?, ?, ?) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{17, "testname", 23, 3})) + } + { + sharedColumns := NewColumnList([]string{"position", "name", "surprise", "id"}) + _, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNotNil(err) + } + { + sharedColumns := NewColumnList([]string{}) + _, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNotNil(err) + } +} diff --git a/go/sql/types.go b/go/sql/types.go index cda02d3..e82720c 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -11,34 +11,64 @@ import ( "strings" ) -// ColumnList makes for a named list of columns -type ColumnList []string - -// ParseColumnList parses a comma delimited list of column names -func ParseColumnList(columns string) *ColumnList { - result := ColumnList(strings.Split(columns, ",")) - return &result -} - -func (this *ColumnList) String() string { - return strings.Join(*this, ",") -} - -func (this *ColumnList) Equals(other *ColumnList) bool { - return reflect.DeepEqual(*this, *other) -} - // ColumnsMap maps a column onto its ordinal position type ColumnsMap map[string]int -func NewColumnsMap(columnList ColumnList) ColumnsMap { +func NewColumnsMap(orderedNames []string) ColumnsMap { columnsMap := make(map[string]int) - for i, column := range columnList { + for i, column := range orderedNames 
{ columnsMap[column] = i } return ColumnsMap(columnsMap) } +// ColumnList makes for a named list of columns +type ColumnList struct { + Names []string + Ordinals ColumnsMap +} + +// NewColumnList creates an object given ordered list of column names +func NewColumnList(names []string) *ColumnList { + result := &ColumnList{ + Names: names, + } + result.Ordinals = NewColumnsMap(result.Names) + return result +} + +// ParseColumnList parses a comma delimited list of column names +func ParseColumnList(columns string) *ColumnList { + result := &ColumnList{ + Names: strings.Split(columns, ","), + } + result.Ordinals = NewColumnsMap(result.Names) + return result +} + +func (this *ColumnList) String() string { + return strings.Join(this.Names, ",") +} + +func (this *ColumnList) Equals(other *ColumnList) bool { + return reflect.DeepEqual(this.Names, other.Names) +} + +// IsSubsetOf returns 'true' when column names of this list are a subset of +// another list, in arbitrary order (order agnostic) +func (this *ColumnList) IsSubsetOf(other *ColumnList) bool { + for _, column := range this.Names { + if _, exists := other.Ordinals[column]; !exists { + return false + } + } + return true +} + +func (this *ColumnList) Len() int { + return len(this.Names) +} + // UniqueKey is the combination of a key's name and columns type UniqueKey struct { Name string @@ -51,6 +81,10 @@ func (this *UniqueKey) IsPrimary() bool { return this.Name == "PRIMARY" } +func (this *UniqueKey) Len() int { + return this.Columns.Len() +} + func (this *UniqueKey) String() string { return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable) } From 0d25d11b40d9638a9f709294abc2a9c6b447040b Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Mon, 11 Apr 2016 19:06:47 +0200 Subject: [PATCH 4/5] added types_test --- go/sql/types_test.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 go/sql/types_test.go diff --git a/go/sql/types_test.go 
b/go/sql/types_test.go new file mode 100644 index 0000000..177b8cf --- /dev/null +++ b/go/sql/types_test.go @@ -0,0 +1,29 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package sql + +import ( + "testing" + + "github.com/outbrain/golib/log" + test "github.com/outbrain/golib/tests" + "reflect" +) + +func init() { + log.SetLevel(log.ERROR) +} + +func TestParseColumnList(t *testing.T) { + names := "id,category,max_len" + + columnList := ParseColumnList(names) + test.S(t).ExpectEquals(columnList.Len(), 3) + test.S(t).ExpectTrue(reflect.DeepEqual(columnList.Names, []string{"id", "category", "max_len"})) + test.S(t).ExpectEquals(columnList.Ordinals["id"], 0) + test.S(t).ExpectEquals(columnList.Ordinals["category"], 1) + test.S(t).ExpectEquals(columnList.Ordinals["max_len"], 2) +} From a4ee80df13801d7bf725701aef0b48072fbdca86 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 14 Apr 2016 13:37:56 +0200 Subject: [PATCH 5/5] - Building and applying queries from binlog event data! - `INSERT`, `DELETE`, `UPDATE` statements - support for `--noop` - initial support for `--test-on-replica`. Verifying against `--allow-on-master` - Changelog events no longer read from binlog stream, because reading it may be throttled, and we have to be able to keep reading the heartbeat and state events. 
They are now being read directly from table, mapping already-seen-events to avoid confusion Changelog listener polls table at 2*frequency of heartbeat injection --- go/base/context.go | 23 +++++---- go/cmd/gh-osc/main.go | 9 +++- go/logic/applier.go | 115 +++++++++++++++++++++++------------------ go/logic/inspect.go | 69 +++++++++++++++++-------- go/logic/migrator.go | 108 ++++++++++++++++++++++++++++---------- go/sql/builder.go | 89 +++++++++++++++++++++++++++---- go/sql/builder_test.go | 104 +++++++++++++++++++++++++++++++++++++ 7 files changed, 397 insertions(+), 120 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 9611869..5f1876b 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -33,9 +33,13 @@ const ( // MigrationContext has the general, global state of migration. It is used by // all components throughout the migration process. type MigrationContext struct { - DatabaseName string - OriginalTableName string - AlterStatement string + DatabaseName string + OriginalTableName string + AlterStatement string + + Noop bool + TestOnReplica bool + TableEngine string CountTableRows bool RowsEstimate int64 @@ -45,7 +49,7 @@ type MigrationContext struct { OriginalBinlogRowImage string AllowedRunningOnMaster bool InspectorConnectionConfig *mysql.ConnectionConfig - MasterConnectionConfig *mysql.ConnectionConfig + ApplierConnectionConfig *mysql.ConnectionConfig StartTime time.Time RowCopyStartTime time.Time CurrentLag int64 @@ -83,7 +87,7 @@ func newMigrationContext() *MigrationContext { return &MigrationContext{ ChunkSize: 1000, InspectorConnectionConfig: mysql.NewConnectionConfig(), - MasterConnectionConfig: mysql.NewConnectionConfig(), + ApplierConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1000, MaxLoad: make(map[string]int64), throttleMutex: &sync.Mutex{}, @@ -115,10 +119,11 @@ func (this *MigrationContext) RequiresBinlogFormatChange() bool { return this.OriginalBinlogFormat != "ROW" } -//
IsRunningOnMaster is `true` when the app connects directly to the master (typically -// it should be executed on replica and infer the master) -func (this *MigrationContext) IsRunningOnMaster() bool { - return this.InspectorConnectionConfig.Equals(this.MasterConnectionConfig) +// InspectorIsAlsoApplier is `true` when both the inspector and the applier are the +// same database instance. This would be true when running directly on master or when +// testing on replica. +func (this *MigrationContext) InspectorIsAlsoApplier() bool { + return this.InspectorConnectionConfig.Equals(this.ApplierConnectionConfig) } // HasMigrationRange tells us whether there's a range to iterate for copying rows. diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index 69709e6..09aa337 100644 --- a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -30,6 +30,9 @@ func main() { flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica") + executeFlag := flag.Bool("execute", false, "actually execute the alter & migrate the table. Default is noop: do some tests and exit") + flag.BoolVar(&migrationContext.TestOnReplica, "test-on-replica", false, "Have the migration run on a replica, not on the master.
At the end of migration tables are not swapped; gh-osc issues `STOP SLAVE` and you can compare the two tables for building trust") + flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)") if migrationContext.ChunkSize < 100 { migrationContext.ChunkSize = 100 @@ -37,7 +40,7 @@ func main() { if migrationContext.ChunkSize > 100000 { migrationContext.ChunkSize = 100000 } - flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation") + flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1500, "replication lag at which to throttle operation") flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations") maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. 
e.g: 'Threads_running=100,Threads_connected=500'") @@ -78,6 +81,10 @@ func main() { if migrationContext.AlterStatement == "" { log.Fatalf("--alter must be provided and statement must not be empty") } + migrationContext.Noop = !(*executeFlag) + if migrationContext.AllowedRunningOnMaster && migrationContext.TestOnReplica { + log.Fatalf("--allow-on-master and --test-on-replica are mutually exclusive") + } if err := migrationContext.ReadMaxLoad(*maxLoad); err != nil { log.Fatale(err) } diff --git a/go/logic/applier.go b/go/logic/applier.go index 60c83d7..5e7c83f 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -20,10 +20,6 @@ import ( "github.com/outbrain/golib/sqlutils" ) -const ( - heartbeatIntervalSeconds = 1 -) - // Applier reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Applier struct { @@ -34,7 +30,7 @@ type Applier struct { func NewApplier() *Applier { return &Applier{ - connectionConfig: base.GetMigrationContext().MasterConnectionConfig, + connectionConfig: base.GetMigrationContext().ApplierConnectionConfig, migrationContext: base.GetMigrationContext(), } } @@ -157,11 +153,20 @@ func (this *Applier) DropGhostTable() error { // WriteChangelog writes a value to the changelog table. // It returns the hint as given, for convenience func (this *Applier) WriteChangelog(hint, value string) (string, error) { + explicitId := 0 + switch hint { + case "heartbeat": + explicitId = 1 + case "state": + explicitId = 2 + case "throttle": + explicitId = 3 + } query := fmt.Sprintf(` insert /* gh-osc */ into %s.%s (id, hint, value) values - (NULL, ?, ?) + (NULLIF(?, 0), ?, ?) 
on duplicate key update last_update=NOW(), value=VALUES(value) @@ -169,7 +174,7 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) - _, err := sqlutils.Exec(this.db, query, hint, value) + _, err := sqlutils.Exec(this.db, query, explicitId, hint, value) return hint, err } @@ -184,44 +189,30 @@ func (this *Applier) WriteChangelogState(value string) (string, error) { // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // This is done asynchronously -func (this *Applier) InitiateHeartbeat() { - go func() { - numSuccessiveFailures := 0 - query := fmt.Sprintf(` - insert /* gh-osc */ into %s.%s - (id, hint, value) - values - (1, 'heartbeat', ?) - on duplicate key update - last_update=NOW(), - value=VALUES(value) - `, - sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), - ) - injectHeartbeat := func() error { - if _, err := sqlutils.ExecNoPrepare(this.db, query, time.Now().Format(time.RFC3339)); err != nil { - numSuccessiveFailures++ - if numSuccessiveFailures > this.migrationContext.MaxRetries() { - return log.Errore(err) - } - } else { - numSuccessiveFailures = 0 +func (this *Applier) InitiateHeartbeat(heartbeatIntervalMilliseconds int64) { + numSuccessiveFailures := 0 + injectHeartbeat := func() error { + if _, err := this.WriteChangelog("heartbeat", time.Now().Format(time.RFC3339)); err != nil { + numSuccessiveFailures++ + if numSuccessiveFailures > this.migrationContext.MaxRetries() { + return log.Errore(err) } - return nil + } else { + numSuccessiveFailures = 0 } - injectHeartbeat() + return nil + } + injectHeartbeat() - heartbeatTick := time.Tick(time.Duration(heartbeatIntervalSeconds) * time.Second) - for range heartbeatTick { - // Generally speaking, we would issue a goroutine, but I'd actually rather - // have this blocked rather 
than spam the master in the event something - // goes wrong - if err := injectHeartbeat(); err != nil { - return - } + heartbeatTick := time.Tick(time.Duration(heartbeatIntervalMilliseconds) * time.Millisecond) + for range heartbeatTick { + // Generally speaking, we would issue a goroutine, but I'd actually rather + // have this blocked rather than spam the master in the event something + // goes wrong + if err := injectHeartbeat(); err != nil { + return } - }() + } } // ReadMigrationMinValues @@ -352,17 +343,10 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo hasFurtherRange = true } if !hasFurtherRange { - log.Debugf("Iteration complete: cannot find iteration end") + log.Debugf("Iteration complete: no further range to iterate") return hasFurtherRange, nil } this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues - // log.Debugf( - // "column values: [%s]..[%s]; iteration: %d; chunk-size: %d", - // this.migrationContext.MigrationIterationRangeMinValues, - // this.migrationContext.MigrationIterationRangeMaxValues, - // this.migrationContext.GetIteration(), - // this.migrationContext.ChunkSize, - // ) return hasFurtherRange, nil } @@ -402,6 +386,11 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected // LockTables func (this *Applier) LockTables() error { + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really locking tables") + return nil + } + query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), @@ -437,11 +426,35 @@ func (this *Applier) ShowStatusVariable(variableName string) (result int64, err return result, nil } -func (this *Applier) BuildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (result string, err error) { +func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (query string, args 
[]interface{}, err error) { switch dmlEvent.DML { case binlog.DeleteDML: { + query, uniqueKeyArgs, err := sql.BuildDMLDeleteQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.WhereColumnValues.AbstractValues()) + return query, uniqueKeyArgs, err + } + case binlog.InsertDML: + { + query, sharedArgs, err := sql.BuildDMLInsertQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, dmlEvent.NewColumnValues.AbstractValues()) + return query, sharedArgs, err + } + case binlog.UpdateDML: + { + query, sharedArgs, uniqueKeyArgs, err := sql.BuildDMLUpdateQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues()) + args = append(args, sharedArgs...) + args = append(args, uniqueKeyArgs...) + return query, args, err } } - return result, err + return "", args, fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML) +} + +func (this *Applier) ApplyDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) error { + query, args, err := this.buildDMLEventQuery(dmlEvent) + if err != nil { + return err + } + log.Errorf(query) + _, err = sqlutils.Exec(this.db, query, args...) 
+ return err } diff --git a/go/logic/inspect.go b/go/logic/inspect.go index ddb53f3..902ccbe 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -185,6 +185,12 @@ func (this *Inspector) validateGrants() error { func (this *Inspector) restartReplication() error { log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + masterKey, _ := getMasterKeyFromSlaveStatus(this.connectionConfig) + if masterKey == nil { + // This is not a replica + return nil + } + var stopError, startError error _, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`) _, startError = sqlutils.ExecNoPrepare(this.db, `start slave`) @@ -454,43 +460,62 @@ func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.Colum return sql.NewColumnList(sharedColumnNames) } -func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { - visitedKeys := mysql.NewInstanceKeyMap() - return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys) +func (this *Inspector) readChangelogState() (map[string]string, error) { + query := fmt.Sprintf(` + select hint, value from %s.%s where id <= 255 + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + result := make(map[string]string) + err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { + result[m.GetString("hint")] = m.GetString("value") + return nil + }) + return result, err } -func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, databaseName string, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { - log.Debugf("Looking for master on %+v", connectionConfig.Key) +func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { + visitedKeys 
:= mysql.NewInstanceKeyMap() + return getMasterConnectionConfigSafe(this.connectionConfig, visitedKeys) +} - currentUri := connectionConfig.GetDBUri(databaseName) +func getMasterKeyFromSlaveStatus(connectionConfig *mysql.ConnectionConfig) (masterKey *mysql.InstanceKey, err error) { + currentUri := connectionConfig.GetDBUri("information_schema") db, _, err := sqlutils.GetDB(currentUri) if err != nil { return nil, err } - - hasMaster := false - masterConfig = connectionConfig.Duplicate() err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { - masterKey := mysql.InstanceKey{ + masterKey = &mysql.InstanceKey{ Hostname: rowMap.GetString("Master_Host"), Port: rowMap.GetInt("Master_Port"), } - if masterKey.IsValid() { - masterConfig.Key = masterKey - hasMaster = true - } return nil }) + return masterKey, err +} + +func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { + log.Debugf("Looking for master on %+v", connectionConfig.Key) + + masterKey, err := getMasterKeyFromSlaveStatus(connectionConfig) if err != nil { return nil, err } - if hasMaster { - log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) - if visitedKeys.HasKey(masterConfig.Key) { - return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key) - } - visitedKeys.AddKey(masterConfig.Key) - return getMasterConnectionConfigSafe(masterConfig, databaseName, visitedKeys) + if masterKey == nil { + return connectionConfig, nil } - return masterConfig, nil + if !masterKey.IsValid() { + return connectionConfig, nil + } + masterConfig = connectionConfig.Duplicate() + masterConfig.Key = *masterKey + + log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) + if visitedKeys.HasKey(masterConfig.Key) { + return nil, fmt.Errorf("There seems to be a master-master setup at %+v. 
This is unsupported. Bailing out", masterConfig.Key) + } + visitedKeys.AddKey(masterConfig.Key) + return getMasterConnectionConfigSafe(masterConfig, visitedKeys) } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 3dee61d..62fc6e3 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -30,7 +30,8 @@ const ( type tableWriteFunc func() error const ( - applyEventsQueueBuffer = 100 + applyEventsQueueBuffer = 100 + heartbeatIntervalMilliseconds = 1000 ) var ( @@ -52,6 +53,8 @@ type Migrator struct { // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete copyRowsQueue chan tableWriteFunc applyEventsQueue chan tableWriteFunc + + handledChangelogStates map[string]bool } func NewMigrator() *Migrator { @@ -61,8 +64,9 @@ func NewMigrator() *Migrator { rowCopyComplete: make(chan bool), allEventsUpToLockProcessed: make(chan bool), - copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), + copyRowsQueue: make(chan tableWriteFunc), + applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), + handledChangelogStates: make(map[string]bool), } return migrator } @@ -185,12 +189,13 @@ func (this *Migrator) canStopStreaming() bool { return false } -func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - // Hey, I created the changlog table, I know the type of columns it has! 
- if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" { +func (this *Migrator) onChangelogState(stateValue string) (err error) { + if this.handledChangelogStates[stateValue] { return } - changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) + this.handledChangelogStates[stateValue] = true + + changelogState := ChangelogState(stateValue) switch changelogState { case TablesInPlace: { @@ -209,12 +214,8 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return nil } -func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "heartbeat" { - return nil - } - value := dmlEvent.NewColumnValues.StringColumn(3) - heartbeatTime, err := time.Parse(time.RFC3339, value) +func (this *Migrator) onChangelogHeartbeat(heartbeatValue string) (err error) { + heartbeatTime, err := time.Parse(time.RFC3339, heartbeatValue) if err != nil { return log.Errore(err) } @@ -239,13 +240,25 @@ func (this *Migrator) Migrate() (err error) { return err } // So far so good, table is accessible and valid. - if this.migrationContext.MasterConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { + // Let's get master connection config + if this.migrationContext.ApplierConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { return err } - if this.migrationContext.IsRunningOnMaster() && !this.migrationContext.AllowedRunningOnMaster { + if this.migrationContext.TestOnReplica { + if this.migrationContext.InspectorIsAlsoApplier() { + return fmt.Errorf("Instructed to --test-on-replica, but the server we connect to doesn't seem to be a replica") + } + log.Infof("--test-on-replica given. 
Will not execute on master %+v but rather on replica %+v itself", + this.migrationContext.ApplierConnectionConfig.Key, this.migrationContext.InspectorConnectionConfig.Key, + ) + this.migrationContext.ApplierConnectionConfig = this.migrationContext.InspectorConnectionConfig.Duplicate() + } else if this.migrationContext.InspectorIsAlsoApplier() && !this.migrationContext.AllowedRunningOnMaster { return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master") } - log.Infof("Master found to be %+v", this.migrationContext.MasterConnectionConfig.Key) + + log.Infof("Master found to be %+v", this.migrationContext.ApplierConnectionConfig.Key) + + go this.initiateChangelogListener() if err := this.initiateStreaming(); err != nil { return err @@ -344,27 +357,54 @@ func (this *Migrator) printStatus() { fmt.Println(status) } +func (this *Migrator) initiateChangelogListener() { + ticker := time.Tick((heartbeatIntervalMilliseconds * time.Millisecond) / 2) + for range ticker { + go func() error { + changelogState, err := this.inspector.readChangelogState() + if err != nil { + return log.Errore(err) + } + for hint, value := range changelogState { + switch hint { + case "state": + { + this.onChangelogState(value) + } + case "heartbeat": + { + this.onChangelogHeartbeat(value) + } + } + } + return nil + }() + } +} + +// initiateStreaming begins streaming of binary log events and registers listeners for such events func (this *Migrator) initiateStreaming() error { this.eventsStreamer = NewEventsStreamer() if err := this.eventsStreamer.InitDBConnections(); err != nil { return err } + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really listening on binlog events") + return nil + } this.eventsStreamer.AddListener( - false, + true, this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), +
this.migrationContext.OriginalTableName, func(dmlEvent *binlog.BinlogDMLEvent) error { - return this.onChangelogStateEvent(dmlEvent) - }, - ) - this.eventsStreamer.AddListener( - false, - this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), - func(dmlEvent *binlog.BinlogDMLEvent) error { - return this.onChangelogHeartbeatEvent(dmlEvent) + applyEventFunc := func() error { + return this.applier.ApplyDMLEventQuery(dmlEvent) + } + this.applyEventsQueue <- applyEventFunc + return nil }, ) + go func() { log.Debugf("Beginning streaming") this.eventsStreamer.StreamEvents(func() bool { return this.canStopStreaming() }) @@ -391,7 +431,7 @@ func (this *Migrator) initiateApplier() error { } this.applier.WriteChangelogState(string(TablesInPlace)) - this.applier.InitiateHeartbeat() + go this.applier.InitiateHeartbeat(heartbeatIntervalMilliseconds) return nil } @@ -400,6 +440,14 @@ func (this *Migrator) iterateChunks() error { this.rowCopyComplete <- true return log.Errore(err) } + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really copying data") + return terminateRowIteration(nil) + } + if this.migrationContext.MigrationRangeMinValues == nil { + log.Debugf("No rows found in table. 
Rowcopy will be implicitly empty") + return terminateRowIteration(nil) + } for { copyRowsFunc := func() error { hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() @@ -423,6 +471,10 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) executeWriteFuncs() error { + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really doing writes") + return nil + } for { this.throttle(nil) // We give higher priority to event processing, then secondary priority to diff --git a/go/sql/builder.go b/go/sql/builder.go index 784b1c9..4b73a5a 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -83,6 +83,17 @@ func BuildEqualsPreparedComparison(columns []string) (result string, err error) return BuildEqualsComparison(columns, values) } +func BuildSetPreparedClause(columns []string) (result string, err error) { + if len(columns) == 0 { + return "", fmt.Errorf("Got 0 columns in BuildSetPreparedClause") + } + setTokens := []string{} + for _, column := range columns { + setTokens = append(setTokens, fmt.Sprintf("%s=?", EscapeName(column))) + } + return strings.Join(setTokens, ", "), nil +} + func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { if len(columns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") @@ -278,17 +289,23 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni return query, nil } -func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { - if len(args) != originalTableColumns.Len() { +func BuildDMLDeleteQuery(databaseName, tableName string, tableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { + if len(args) != tableColumns.Len() { 
return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery") } + if uniqueKeyColumns.Len() == 0 { + return result, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLDeleteQuery") + } for _, column := range uniqueKeyColumns.Names { - tableOrdinal := originalTableColumns.Ordinals[column] + tableOrdinal := tableColumns.Ordinals[column] uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal]) } databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + if err != nil { + return result, uniqueKeyArgs, err + } result = fmt.Sprintf(` delete /* gh-osc %s.%s */ from @@ -299,14 +316,14 @@ func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, u databaseName, tableName, equalsComparison, ) - return result, uniqueKeyArgs, err + return result, uniqueKeyArgs, nil } -func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { - if len(args) != originalTableColumns.Len() { +func BuildDMLInsertQuery(databaseName, tableName string, tableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { + if len(args) != tableColumns.Len() { return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery") } - if !sharedColumns.IsSubsetOf(originalTableColumns) { + if !sharedColumns.IsSubsetOf(tableColumns) { return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery") } if sharedColumns.Len() == 0 { @@ -316,7 +333,7 @@ func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, s tableName = EscapeName(tableName) for _, column := range sharedColumns.Names { - tableOrdinal := originalTableColumns.Ordinals[column] + tableOrdinal := 
tableColumns.Ordinals[column] sharedArgs = append(sharedArgs, args[tableOrdinal]) } @@ -337,5 +354,59 @@ func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, s strings.Join(sharedColumnNames, ", "), strings.Join(preparedValues, ", "), ) - return result, sharedArgs, err + return result, sharedArgs, nil +} + +func BuildDMLUpdateQuery(databaseName, tableName string, tableColumns, sharedColumns, uniqueKeyColumns *ColumnList, valueArgs, whereArgs []interface{}) (result string, sharedArgs, uniqueKeyArgs []interface{}, err error) { + if len(valueArgs) != tableColumns.Len() { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("value args count differs from table column count in BuildDMLUpdateQuery") + } + if len(whereArgs) != tableColumns.Len() { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("where args count differs from table column count in BuildDMLUpdateQuery") + } + if !sharedColumns.IsSubsetOf(tableColumns) { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLUpdateQuery") + } + if !uniqueKeyColumns.IsSubsetOf(sharedColumns) { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("unique key columns is not a subset of shared columns in BuildDMLUpdateQuery") + } + if sharedColumns.Len() == 0 { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No shared columns found in BuildDMLUpdateQuery") + } + if uniqueKeyColumns.Len() == 0 { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLUpdateQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + for _, column := range sharedColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + sharedArgs = append(sharedArgs, valueArgs[tableOrdinal]) + } + + for _, column := range uniqueKeyColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + uniqueKeyArgs = append(uniqueKeyArgs, whereArgs[tableOrdinal]) + } + + sharedColumnNames 
:= duplicateNames(sharedColumns.Names) + for i := range sharedColumnNames { + sharedColumnNames[i] = EscapeName(sharedColumnNames[i]) + } + setClause, err := BuildSetPreparedClause(sharedColumnNames) + + equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + result = fmt.Sprintf(` + update /* gh-osc %s.%s */ + %s.%s + set + %s + where + %s + `, databaseName, tableName, + databaseName, tableName, + setClause, + equalsComparison, + ) + return result, sharedArgs, uniqueKeyArgs, nil } diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index 5a57a3d..3e0a8c6 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -77,6 +77,26 @@ func TestBuildEqualsPreparedComparison(t *testing.T) { } } +func TestBuildSetPreparedClause(t *testing.T) { + { + columns := []string{"c1"} + clause, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(clause, "`c1`=?") + } + { + columns := []string{"c1", "c2"} + clause, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(clause, "`c1`=?, `c2`=?") + } + { + columns := []string{} + _, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNotNil(err) + } +} + func TestBuildRangeComparison(t *testing.T) { { columns := []string{"c1"} @@ -375,3 +395,87 @@ func TestBuildDMLInsertQuery(t *testing.T) { test.S(t).ExpectNotNil(err) } } + +func TestBuildDMLUpdateQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + valueArgs := []interface{}{3, "testname", "newval", 17, 23} + whereArgs := []interface{}{3, "testname", "findme", 17, 56} + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"position"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + 
test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((position = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? 
+ where + ((age = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age", "position", "id", "name"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((age = ?) and (position = ?) and (id = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56, 17, 3, "testname"})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age", "surprise"}) + _, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNotNil(err) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{}) + _, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNotNil(err) + } +}