From a4ee80df13801d7bf725701aef0b48072fbdca86 Mon Sep 17 00:00:00 2001 From: Shlomi Noach Date: Thu, 14 Apr 2016 13:37:56 +0200 Subject: [PATCH] - Building and applying queries from binlog event data! - `INSERT`, `DELETE`, `UPDATE` statements - support for `--noop` - initial support for `--test-on-replica`. Verifying against `--allow-on-master` - Changelog events no longer read from binlog stream, because reading it may be throttled, and we have to be able to keep reading the heartbeat and state events. They are now being read directly from table, mapping already-seen-events to avoid confusion Changlelog listener pools table in 2*frequency of heartbeat injection --- go/base/context.go | 23 +++++---- go/cmd/gh-osc/main.go | 9 +++- go/logic/applier.go | 115 +++++++++++++++++++++++------------------ go/logic/inspect.go | 69 +++++++++++++++++-------- go/logic/migrator.go | 108 ++++++++++++++++++++++++++++---------- go/sql/builder.go | 89 +++++++++++++++++++++++++++---- go/sql/builder_test.go | 104 +++++++++++++++++++++++++++++++++++++ 7 files changed, 397 insertions(+), 120 deletions(-) diff --git a/go/base/context.go b/go/base/context.go index 9611869..5f1876b 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -33,9 +33,13 @@ const ( // MigrationContext has the general, global state of migration. It is used by // all components throughout the migration process. type MigrationContext struct { - DatabaseName string - OriginalTableName string - AlterStatement string + DatabaseName string + OriginalTableName string + AlterStatement string + + Noop bool + TestOnReplica bool + TableEngine string CountTableRows bool RowsEstimate int64 @@ -45,7 +49,7 @@ type MigrationContext struct { OriginalBinlogRowImage string AllowedRunningOnMaster bool InspectorConnectionConfig *mysql.ConnectionConfig - MasterConnectionConfig *mysql.ConnectionConfig + ApplierConnectionConfig *mysql.ConnectionConfig StartTime time.Time RowCopyStartTime time.Time CurrentLag int64 @@ -83,7 +87,7 @@ func newMigrationContext() *MigrationContext { return &MigrationContext{ ChunkSize: 1000, InspectorConnectionConfig: mysql.NewConnectionConfig(), - MasterConnectionConfig: mysql.NewConnectionConfig(), + ApplierConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1000, MaxLoad: make(map[string]int64), throttleMutex: &sync.Mutex{}, @@ -115,10 +119,11 @@ func (this *MigrationContext) RequiresBinlogFormatChange() bool { return this.OriginalBinlogFormat != "ROW" } -// IsRunningOnMaster is `true` when the app connects directly to the master (typically -// it should be executed on replica and infer the master) -func (this *MigrationContext) IsRunningOnMaster() bool { - return this.InspectorConnectionConfig.Equals(this.MasterConnectionConfig) +// InspectorIsAlsoApplier is `true` when the both inspector and applier are the +// same database instance. This would be true when running directly on master or when +// testing on replica. +func (this *MigrationContext) InspectorIsAlsoApplier() bool { + return this.InspectorConnectionConfig.Equals(this.ApplierConnectionConfig) } // HasMigrationRange tells us whether there's a range to iterate for copying rows. diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index 69709e6..09aa337 100644 --- a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -30,6 +30,9 @@ func main() { flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica") + executeFlag := flag.Bool("execute", false, "actually execute the alter & migrate the table. Default is noop: do some tests and exit") + flag.BoolVar(&migrationContext.TestOnReplica, "test-on-replica", false, "Have the migration run on a replica, not on the master. At the end of migration tables are not swapped; gh-osc issues `STOP SLAVE` and you can compare the two tables for building trust") + flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)") if migrationContext.ChunkSize < 100 { migrationContext.ChunkSize = 100 @@ -37,7 +40,7 @@ func main() { if migrationContext.ChunkSize > 100000 { migrationContext.ChunkSize = 100000 } - flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation") + flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1500, "replication lag at which to throttle operation") flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations") maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'") @@ -78,6 +81,10 @@ func main() { if migrationContext.AlterStatement == "" { log.Fatalf("--alter must be provided and statement must not be empty") } + migrationContext.Noop = !(*executeFlag) + if migrationContext.AllowedRunningOnMaster && migrationContext.TestOnReplica { + log.Fatalf("--allow-on-master and --test-on-replica are mutually exclusive") + } if err := migrationContext.ReadMaxLoad(*maxLoad); err != nil { log.Fatale(err) } diff --git a/go/logic/applier.go b/go/logic/applier.go index 60c83d7..5e7c83f 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -20,10 +20,6 @@ import ( "github.com/outbrain/golib/sqlutils" ) -const ( - heartbeatIntervalSeconds = 1 -) - // Applier reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Applier struct { @@ -34,7 +30,7 @@ type Applier struct { func NewApplier() *Applier { return &Applier{ - connectionConfig: base.GetMigrationContext().MasterConnectionConfig, + connectionConfig: base.GetMigrationContext().ApplierConnectionConfig, migrationContext: base.GetMigrationContext(), } } @@ -157,11 +153,20 @@ func (this *Applier) DropGhostTable() error { // WriteChangelog writes a value to the changelog table. // It returns the hint as given, for convenience func (this *Applier) WriteChangelog(hint, value string) (string, error) { + explicitId := 0 + switch hint { + case "heartbeat": + explicitId = 1 + case "state": + explicitId = 2 + case "throttle": + explicitId = 3 + } query := fmt.Sprintf(` insert /* gh-osc */ into %s.%s (id, hint, value) values - (NULL, ?, ?) + (NULLIF(?, 0), ?, ?) on duplicate key update last_update=NOW(), value=VALUES(value) @@ -169,7 +174,7 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) - _, err := sqlutils.Exec(this.db, query, hint, value) + _, err := sqlutils.Exec(this.db, query, explicitId, hint, value) return hint, err } @@ -184,44 +189,30 @@ func (this *Applier) WriteChangelogState(value string) (string, error) { // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // This is done asynchronously -func (this *Applier) InitiateHeartbeat() { - go func() { - numSuccessiveFailures := 0 - query := fmt.Sprintf(` - insert /* gh-osc */ into %s.%s - (id, hint, value) - values - (1, 'heartbeat', ?) - on duplicate key update - last_update=NOW(), - value=VALUES(value) - `, - sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), - ) - injectHeartbeat := func() error { - if _, err := sqlutils.ExecNoPrepare(this.db, query, time.Now().Format(time.RFC3339)); err != nil { - numSuccessiveFailures++ - if numSuccessiveFailures > this.migrationContext.MaxRetries() { - return log.Errore(err) - } - } else { - numSuccessiveFailures = 0 +func (this *Applier) InitiateHeartbeat(heartbeatIntervalMilliseconds int64) { + numSuccessiveFailures := 0 + injectHeartbeat := func() error { + if _, err := this.WriteChangelog("heartbeat", time.Now().Format(time.RFC3339)); err != nil { + numSuccessiveFailures++ + if numSuccessiveFailures > this.migrationContext.MaxRetries() { + return log.Errore(err) } - return nil + } else { + numSuccessiveFailures = 0 } - injectHeartbeat() + return nil + } + injectHeartbeat() - heartbeatTick := time.Tick(time.Duration(heartbeatIntervalSeconds) * time.Second) - for range heartbeatTick { - // Generally speaking, we would issue a goroutine, but I'd actually rather - // have this blocked rather than spam the master in the event something - // goes wrong - if err := injectHeartbeat(); err != nil { - return - } + heartbeatTick := time.Tick(time.Duration(heartbeatIntervalMilliseconds) * time.Millisecond) + for range heartbeatTick { + // Generally speaking, we would issue a goroutine, but I'd actually rather + // have this blocked rather than spam the master in the event something + // goes wrong + if err := injectHeartbeat(); err != nil { + return } - }() + } } // ReadMigrationMinValues @@ -352,17 +343,10 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo hasFurtherRange = true } if !hasFurtherRange { - log.Debugf("Iteration complete: cannot find iteration end") + log.Debugf("Iteration complete: no further range to iterate") return hasFurtherRange, nil } this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues - // log.Debugf( - // "column values: [%s]..[%s]; iteration: %d; chunk-size: %d", - // this.migrationContext.MigrationIterationRangeMinValues, - // this.migrationContext.MigrationIterationRangeMaxValues, - // this.migrationContext.GetIteration(), - // this.migrationContext.ChunkSize, - // ) return hasFurtherRange, nil } @@ -402,6 +386,11 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected // LockTables func (this *Applier) LockTables() error { + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really locking tables") + return nil + } + query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), @@ -437,11 +426,35 @@ func (this *Applier) ShowStatusVariable(variableName string) (result int64, err return result, nil } -func (this *Applier) BuildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (result string, err error) { +func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (query string, args []interface{}, err error) { switch dmlEvent.DML { case binlog.DeleteDML: { + query, uniqueKeyArgs, err := sql.BuildDMLDeleteQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.WhereColumnValues.AbstractValues()) + return query, uniqueKeyArgs, err + } + case binlog.InsertDML: + { + query, sharedArgs, err := sql.BuildDMLInsertQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, dmlEvent.NewColumnValues.AbstractValues()) + return query, sharedArgs, err + } + case binlog.UpdateDML: + { + query, sharedArgs, uniqueKeyArgs, err := sql.BuildDMLUpdateQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues()) + args = append(args, sharedArgs...) + args = append(args, uniqueKeyArgs...) + return query, args, err } } - return result, err + return "", args, fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML) +} + +func (this *Applier) ApplyDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) error { + query, args, err := this.buildDMLEventQuery(dmlEvent) + if err != nil { + return err + } + log.Errorf(query) + _, err = sqlutils.Exec(this.db, query, args...) + return err } diff --git a/go/logic/inspect.go b/go/logic/inspect.go index ddb53f3..902ccbe 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -185,6 +185,12 @@ func (this *Inspector) validateGrants() error { func (this *Inspector) restartReplication() error { log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + masterKey, _ := getMasterKeyFromSlaveStatus(this.connectionConfig) + if masterKey == nil { + // This is not a replica + return nil + } + var stopError, startError error _, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`) _, startError = sqlutils.ExecNoPrepare(this.db, `start slave`) @@ -454,43 +460,62 @@ func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.Colum return sql.NewColumnList(sharedColumnNames) } -func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { - visitedKeys := mysql.NewInstanceKeyMap() - return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys) +func (this *Inspector) readChangelogState() (map[string]string, error) { + query := fmt.Sprintf(` + select hint, value from %s.%s where id <= 255 + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + result := make(map[string]string) + err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { + result[m.GetString("hint")] = m.GetString("value") + return nil + }) + return result, err } -func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, databaseName string, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { - log.Debugf("Looking for master on %+v", connectionConfig.Key) +func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { + visitedKeys := mysql.NewInstanceKeyMap() + return getMasterConnectionConfigSafe(this.connectionConfig, visitedKeys) +} - currentUri := connectionConfig.GetDBUri(databaseName) +func getMasterKeyFromSlaveStatus(connectionConfig *mysql.ConnectionConfig) (masterKey *mysql.InstanceKey, err error) { + currentUri := connectionConfig.GetDBUri("information_schema") db, _, err := sqlutils.GetDB(currentUri) if err != nil { return nil, err } - - hasMaster := false - masterConfig = connectionConfig.Duplicate() err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { - masterKey := mysql.InstanceKey{ + masterKey = &mysql.InstanceKey{ Hostname: rowMap.GetString("Master_Host"), Port: rowMap.GetInt("Master_Port"), } - if masterKey.IsValid() { - masterConfig.Key = masterKey - hasMaster = true - } return nil }) + return masterKey, err +} + +func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { + log.Debugf("Looking for master on %+v", connectionConfig.Key) + + masterKey, err := getMasterKeyFromSlaveStatus(connectionConfig) if err != nil { return nil, err } - if hasMaster { - log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) - if visitedKeys.HasKey(masterConfig.Key) { - return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key) - } - visitedKeys.AddKey(masterConfig.Key) - return getMasterConnectionConfigSafe(masterConfig, databaseName, visitedKeys) + if masterKey == nil { + return connectionConfig, nil } - return masterConfig, nil + if !masterKey.IsValid() { + return connectionConfig, nil + } + masterConfig = connectionConfig.Duplicate() + masterConfig.Key = *masterKey + + log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) + if visitedKeys.HasKey(masterConfig.Key) { + return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key) + } + visitedKeys.AddKey(masterConfig.Key) + return getMasterConnectionConfigSafe(masterConfig, visitedKeys) } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 3dee61d..62fc6e3 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -30,7 +30,8 @@ const ( type tableWriteFunc func() error const ( - applyEventsQueueBuffer = 100 + applyEventsQueueBuffer = 100 + heartbeatIntervalMilliseconds = 1000 ) var ( @@ -52,6 +53,8 @@ type Migrator struct { // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete copyRowsQueue chan tableWriteFunc applyEventsQueue chan tableWriteFunc + + handledChangelogStates map[string]bool } func NewMigrator() *Migrator { @@ -61,8 +64,9 @@ func NewMigrator() *Migrator { rowCopyComplete: make(chan bool), allEventsUpToLockProcessed: make(chan bool), - copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), + copyRowsQueue: make(chan tableWriteFunc), + applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), + handledChangelogStates: make(map[string]bool), } return migrator } @@ -185,12 +189,13 @@ func (this *Migrator) canStopStreaming() bool { return false } -func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - // Hey, I created the changlog table, I know the type of columns it has! - if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" { +func (this *Migrator) onChangelogState(stateValue string) (err error) { + if this.handledChangelogStates[stateValue] { return } - changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) + this.handledChangelogStates[stateValue] = true + + changelogState := ChangelogState(stateValue) switch changelogState { case TablesInPlace: { @@ -209,12 +214,8 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return nil } -func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "heartbeat" { - return nil - } - value := dmlEvent.NewColumnValues.StringColumn(3) - heartbeatTime, err := time.Parse(time.RFC3339, value) +func (this *Migrator) onChangelogHeartbeat(heartbeatValue string) (err error) { + heartbeatTime, err := time.Parse(time.RFC3339, heartbeatValue) if err != nil { return log.Errore(err) } @@ -239,13 +240,25 @@ func (this *Migrator) Migrate() (err error) { return err } // So far so good, table is accessible and valid. - if this.migrationContext.MasterConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { + // Let's get master connection config + if this.migrationContext.ApplierConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { return err } - if this.migrationContext.IsRunningOnMaster() && !this.migrationContext.AllowedRunningOnMaster { + if this.migrationContext.TestOnReplica { + if this.migrationContext.InspectorIsAlsoApplier() { + return fmt.Errorf("Instructed to --test-on-replica, but the server we connect to doesn't seem to be a replica") + } + log.Infof("--test-on-replica given. Will not execute on master %+v but rather on replica %+v itself", + this.migrationContext.ApplierConnectionConfig.Key, this.migrationContext.InspectorConnectionConfig.Key, + ) + this.migrationContext.ApplierConnectionConfig = this.migrationContext.InspectorConnectionConfig.Duplicate() + } else if this.migrationContext.InspectorIsAlsoApplier() && !this.migrationContext.AllowedRunningOnMaster { return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master") } - log.Infof("Master found to be %+v", this.migrationContext.MasterConnectionConfig.Key) + + log.Infof("Master found to be %+v", this.migrationContext.ApplierConnectionConfig.Key) + + go this.initiateChangelogListener() if err := this.initiateStreaming(); err != nil { return err @@ -344,27 +357,54 @@ func (this *Migrator) printStatus() { fmt.Println(status) } +func (this *Migrator) initiateChangelogListener() { + ticker := time.Tick((heartbeatIntervalMilliseconds * time.Millisecond) / 2) + for range ticker { + go func() error { + changelogState, err := this.inspector.readChangelogState() + if err != nil { + return log.Errore(err) + } + for hint, value := range changelogState { + switch hint { + case "state": + { + this.onChangelogState(value) + } + case "heartbeat": + { + this.onChangelogHeartbeat(value) + } + } + } + return nil + }() + } +} + +// initiateStreaming begins treaming of binary log events and registers listeners for such events func (this *Migrator) initiateStreaming() error { this.eventsStreamer = NewEventsStreamer() if err := this.eventsStreamer.InitDBConnections(); err != nil { return err } + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really listening on binlog events") + return nil + } this.eventsStreamer.AddListener( - false, + true, this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), + this.migrationContext.OriginalTableName, func(dmlEvent *binlog.BinlogDMLEvent) error { - return this.onChangelogStateEvent(dmlEvent) - }, - ) - this.eventsStreamer.AddListener( - false, - this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), - func(dmlEvent *binlog.BinlogDMLEvent) error { - return this.onChangelogHeartbeatEvent(dmlEvent) + applyEventFunc := func() error { + return this.applier.ApplyDMLEventQuery(dmlEvent) + } + this.applyEventsQueue <- applyEventFunc + return nil }, ) + go func() { log.Debugf("Beginning streaming") this.eventsStreamer.StreamEvents(func() bool { return this.canStopStreaming() }) @@ -391,7 +431,7 @@ func (this *Migrator) initiateApplier() error { } this.applier.WriteChangelogState(string(TablesInPlace)) - this.applier.InitiateHeartbeat() + go this.applier.InitiateHeartbeat(heartbeatIntervalMilliseconds) return nil } @@ -400,6 +440,14 @@ func (this *Migrator) iterateChunks() error { this.rowCopyComplete <- true return log.Errore(err) } + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really copying data") + return terminateRowIteration(nil) + } + if this.migrationContext.MigrationRangeMinValues == nil { + log.Debugf("No rows found in table. Rowcopy will be implicitly empty") + return terminateRowIteration(nil) + } for { copyRowsFunc := func() error { hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() @@ -423,6 +471,10 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) executeWriteFuncs() error { + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really doing writes") + return nil + } for { this.throttle(nil) // We give higher priority to event processing, then secondary priority to diff --git a/go/sql/builder.go b/go/sql/builder.go index 784b1c9..4b73a5a 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -83,6 +83,17 @@ func BuildEqualsPreparedComparison(columns []string) (result string, err error) return BuildEqualsComparison(columns, values) } +func BuildSetPreparedClause(columns []string) (result string, err error) { + if len(columns) == 0 { + return "", fmt.Errorf("Got 0 columns in BuildSetPreparedClause") + } + setTokens := []string{} + for _, column := range columns { + setTokens = append(setTokens, fmt.Sprintf("%s=?", EscapeName(column))) + } + return strings.Join(setTokens, ", "), nil +} + func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { if len(columns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") @@ -278,17 +289,23 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni return query, nil } -func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { - if len(args) != originalTableColumns.Len() { +func BuildDMLDeleteQuery(databaseName, tableName string, tableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { + if len(args) != tableColumns.Len() { return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery") } + if uniqueKeyColumns.Len() == 0 { + return result, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLDeleteQuery") + } for _, column := range uniqueKeyColumns.Names { - tableOrdinal := originalTableColumns.Ordinals[column] + tableOrdinal := tableColumns.Ordinals[column] uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal]) } databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + if err != nil { + return result, uniqueKeyArgs, err + } result = fmt.Sprintf(` delete /* gh-osc %s.%s */ from @@ -299,14 +316,14 @@ func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, u databaseName, tableName, equalsComparison, ) - return result, uniqueKeyArgs, err + return result, uniqueKeyArgs, nil } -func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { - if len(args) != originalTableColumns.Len() { +func BuildDMLInsertQuery(databaseName, tableName string, tableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { + if len(args) != tableColumns.Len() { return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery") } - if !sharedColumns.IsSubsetOf(originalTableColumns) { + if !sharedColumns.IsSubsetOf(tableColumns) { return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery") } if sharedColumns.Len() == 0 { @@ -316,7 +333,7 @@ func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, s tableName = EscapeName(tableName) for _, column := range sharedColumns.Names { - tableOrdinal := originalTableColumns.Ordinals[column] + tableOrdinal := tableColumns.Ordinals[column] sharedArgs = append(sharedArgs, args[tableOrdinal]) } @@ -337,5 +354,59 @@ func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, s strings.Join(sharedColumnNames, ", "), strings.Join(preparedValues, ", "), ) - return result, sharedArgs, err + return result, sharedArgs, nil +} + +func BuildDMLUpdateQuery(databaseName, tableName string, tableColumns, sharedColumns, uniqueKeyColumns *ColumnList, valueArgs, whereArgs []interface{}) (result string, sharedArgs, uniqueKeyArgs []interface{}, err error) { + if len(valueArgs) != tableColumns.Len() { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("value args count differs from table column count in BuildDMLUpdateQuery") + } + if len(whereArgs) != tableColumns.Len() { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("where args count differs from table column count in BuildDMLUpdateQuery") + } + if !sharedColumns.IsSubsetOf(tableColumns) { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLUpdateQuery") + } + if !uniqueKeyColumns.IsSubsetOf(sharedColumns) { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("unique key columns is not a subset of shared columns in BuildDMLUpdateQuery") + } + if sharedColumns.Len() == 0 { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No shared columns found in BuildDMLUpdateQuery") + } + if uniqueKeyColumns.Len() == 0 { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLUpdateQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + for _, column := range sharedColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + sharedArgs = append(sharedArgs, valueArgs[tableOrdinal]) + } + + for _, column := range uniqueKeyColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + uniqueKeyArgs = append(uniqueKeyArgs, whereArgs[tableOrdinal]) + } + + sharedColumnNames := duplicateNames(sharedColumns.Names) + for i := range sharedColumnNames { + sharedColumnNames[i] = EscapeName(sharedColumnNames[i]) + } + setClause, err := BuildSetPreparedClause(sharedColumnNames) + + equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + result = fmt.Sprintf(` + update /* gh-osc %s.%s */ + %s.%s + set + %s + where + %s + `, databaseName, tableName, + databaseName, tableName, + setClause, + equalsComparison, + ) + return result, sharedArgs, uniqueKeyArgs, nil } diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index 5a57a3d..3e0a8c6 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -77,6 +77,26 @@ func TestBuildEqualsPreparedComparison(t *testing.T) { } } +func TestBuildSetPreparedClause(t *testing.T) { + { + columns := []string{"c1"} + clause, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(clause, "`c1`=?") + } + { + columns := []string{"c1", "c2"} + clause, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(clause, "`c1`=?, `c2`=?") + } + { + columns := []string{} + _, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNotNil(err) + } +} + func TestBuildRangeComparison(t *testing.T) { { columns := []string{"c1"} @@ -375,3 +395,87 @@ func TestBuildDMLInsertQuery(t *testing.T) { test.S(t).ExpectNotNil(err) } } + +func TestBuildDMLUpdateQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + valueArgs := []interface{}{3, "testname", "newval", 17, 23} + whereArgs := []interface{}{3, "testname", "findme", 17, 56} + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"position"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((position = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((age = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age", "position", "id", "name"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((age = ?) and (position = ?) and (id = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56, 17, 3, "testname"})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age", "surprise"}) + _, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNotNil(err) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{}) + _, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNotNil(err) + } +}