diff --git a/go/base/context.go b/go/base/context.go index 6e28cb0..5f1876b 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -7,7 +7,9 @@ package base import ( "fmt" + "strconv" "strings" + "sync" "sync/atomic" "time" @@ -31,9 +33,13 @@ const ( // MigrationContext has the general, global state of migration. It is used by // all components throughout the migration process. type MigrationContext struct { - DatabaseName string - OriginalTableName string - AlterStatement string + DatabaseName string + OriginalTableName string + AlterStatement string + + Noop bool + TestOnReplica bool + TableEngine string CountTableRows bool RowsEstimate int64 @@ -43,21 +49,31 @@ type MigrationContext struct { OriginalBinlogRowImage string AllowedRunningOnMaster bool InspectorConnectionConfig *mysql.ConnectionConfig - MasterConnectionConfig *mysql.ConnectionConfig - MigrationRangeMinValues *sql.ColumnValues - MigrationRangeMaxValues *sql.ColumnValues - Iteration int64 - MigrationIterationRangeMinValues *sql.ColumnValues - MigrationIterationRangeMaxValues *sql.ColumnValues - UniqueKey *sql.UniqueKey + ApplierConnectionConfig *mysql.ConnectionConfig StartTime time.Time RowCopyStartTime time.Time CurrentLag int64 MaxLagMillisecondsThrottleThreshold int64 ThrottleFlagFile string + ThrottleAdditionalFlagFile string TotalRowsCopied int64 + isThrottled bool + throttleReason string + throttleMutex *sync.Mutex + MaxLoad map[string]int64 + + OriginalTableColumns *sql.ColumnList + OriginalTableUniqueKeys [](*sql.UniqueKey) + GhostTableColumns *sql.ColumnList + GhostTableUniqueKeys [](*sql.UniqueKey) + UniqueKey *sql.UniqueKey + SharedColumns *sql.ColumnList + MigrationRangeMinValues *sql.ColumnValues + MigrationRangeMaxValues *sql.ColumnValues + Iteration int64 + MigrationIterationRangeMinValues *sql.ColumnValues + MigrationIterationRangeMaxValues *sql.ColumnValues - IsThrottled func() bool CanStopStreaming func() bool } @@ -71,8 +87,10 @@ func newMigrationContext() *MigrationContext { return &MigrationContext{ ChunkSize: 1000, InspectorConnectionConfig: mysql.NewConnectionConfig(), - MasterConnectionConfig: mysql.NewConnectionConfig(), + ApplierConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1000, + MaxLoad: make(map[string]int64), + throttleMutex: &sync.Mutex{}, } } @@ -86,6 +104,11 @@ func (this *MigrationContext) GetGhostTableName() string { return fmt.Sprintf("_%s_New", this.OriginalTableName) } +// GetOldTableName generates the name of the "old" table, into which the original table is renamed. +func (this *MigrationContext) GetOldTableName() string { + return fmt.Sprintf("_%s_Old", this.OriginalTableName) +} + // GetChangelogTableName generates the name of changelog table, based on original table name func (this *MigrationContext) GetChangelogTableName() string { return fmt.Sprintf("_%s_OSC", this.OriginalTableName) @@ -96,10 +119,11 @@ func (this *MigrationContext) RequiresBinlogFormatChange() bool { return this.OriginalBinlogFormat != "ROW" } -// IsRunningOnMaster is `true` when the app connects directly to the master (typically -// it should be executed on replica and infer the master) -func (this *MigrationContext) IsRunningOnMaster() bool { - return this.InspectorConnectionConfig.Equals(this.MasterConnectionConfig) +// InspectorIsAlsoApplier is `true` when the both inspector and applier are the +// same database instance. This would be true when running directly on master or when +// testing on replica. 
+func (this *MigrationContext) InspectorIsAlsoApplier() bool { + return this.InspectorConnectionConfig.Equals(this.ApplierConnectionConfig) } // HasMigrationRange tells us whether there's a range to iterate for copying rows. @@ -141,3 +165,42 @@ func (this *MigrationContext) ElapsedRowCopyTime() time.Duration { func (this *MigrationContext) GetTotalRowsCopied() int64 { return atomic.LoadInt64(&this.TotalRowsCopied) } + +func (this *MigrationContext) GetIteration() int64 { + return atomic.LoadInt64(&this.Iteration) +} + +func (this *MigrationContext) SetThrottled(throttle bool, reason string) { + this.throttleMutex.Lock() + defer func() { this.throttleMutex.Unlock() }() + this.isThrottled = throttle + this.throttleReason = reason +} + +func (this *MigrationContext) IsThrottled() (bool, string) { + this.throttleMutex.Lock() + defer func() { this.throttleMutex.Unlock() }() + return this.isThrottled, this.throttleReason +} + +func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error { + if maxLoadList == "" { + return nil + } + maxLoadConditions := strings.Split(maxLoadList, ",") + for _, maxLoadCondition := range maxLoadConditions { + maxLoadTokens := strings.Split(maxLoadCondition, "=") + if len(maxLoadTokens) != 2 { + return fmt.Errorf("Error parsing max-load condition: %s", maxLoadCondition) + } + if maxLoadTokens[0] == "" { + return fmt.Errorf("Error parsing status variable in max-load condition: %s", maxLoadCondition) + } + if n, err := strconv.ParseInt(maxLoadTokens[1], 10, 0); err != nil { + return fmt.Errorf("Error parsing numeric value in max-load condition: %s", maxLoadCondition) + } else { + this.MaxLoad[maxLoadTokens[0]] = n + } + } + return nil +} diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index de7f976..09aa337 100644 --- a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -19,11 +19,6 @@ import ( func main() { migrationContext := base.GetMigrationContext() - // mysqlBasedir := flag.String("mysql-basedir", "", "the --basedir config for MySQL (auto-detected if not given)") - // mysqlDatadir := flag.String("mysql-datadir", "", "the --datadir config for MySQL (auto-detected if not given)") - internalExperiment := flag.Bool("internal-experiment", false, "issue an internal experiment") - binlogFile := flag.String("binlog-file", "", "Name of binary log file") - flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)") flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)") flag.StringVar(&migrationContext.InspectorConnectionConfig.User, "user", "root", "MySQL user") @@ -35,9 +30,20 @@ func main() { flag.BoolVar(&migrationContext.CountTableRows, "exact-rowcount", false, "actually count table rows as opposed to estimate them (results in more accurate progress estimation)") flag.BoolVar(&migrationContext.AllowedRunningOnMaster, "allow-on-master", false, "allow this migration to run directly on master. Preferably it would run on a replica") - flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration") - flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists") + executeFlag := flag.Bool("execute", false, "actually execute the alter & migrate the table. 
Default is noop: do some tests and exit") + flag.BoolVar(&migrationContext.TestOnReplica, "test-on-replica", false, "Have the migration run on a replica, not on the master. At the end of migration tables are not swapped; gh-osc issues `STOP SLAVE` and you can compare the two tables for building trust") + flag.Int64Var(&migrationContext.ChunkSize, "chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 100-100,000)") + if migrationContext.ChunkSize < 100 { + migrationContext.ChunkSize = 100 + } + if migrationContext.ChunkSize > 100000 { + migrationContext.ChunkSize = 100000 + } + flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1500, "replication lag at which to throttle operation") + flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") + flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations") + maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'") quiet := flag.Bool("quiet", false, "quiet") verbose := flag.Bool("verbose", false, "verbose") debug := flag.Bool("debug", false, "debug mode (very verbose)") @@ -75,23 +81,16 @@ func main() { if migrationContext.AlterStatement == "" { log.Fatalf("--alter must be provided and statement must not be empty") } + migrationContext.Noop = !(*executeFlag) + if migrationContext.AllowedRunningOnMaster && migrationContext.TestOnReplica { + log.Fatalf("--allow-on-master and --test-on-replica are mutually exclusive") + } + if err := migrationContext.ReadMaxLoad(*maxLoad); err != nil { + log.Fatale(err) + } log.Info("starting gh-osc") - if *internalExperiment { - log.Debug("starting experiment with %+v", *binlogFile) - - //binlogReader = binlog.NewMySQLBinlogReader(*mysqlBasedir, *mysqlDatadir) - // binlogReader, err := binlog.NewGoMySQLReader(migrationContext.InspectorConnectionConfig) - // if err != nil { - // log.Fatale(err) - // } - // if err := binlogReader.ConnectBinlogStreamer(mysql.BinlogCoordinates{LogFile: *binlogFile, LogPos: 0}); err != nil { - // log.Fatale(err) - // } - // binlogReader.StreamEvents(func() bool { return false }) - // return - } migrator := logic.NewMigrator() err := migrator.Migrate() if err != nil { diff --git a/go/logic/applier.go b/go/logic/applier.go index b4f63a9..5e7c83f 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -12,6 +12,7 @@ import ( "time" "github.com/github/gh-osc/go/base" + "github.com/github/gh-osc/go/binlog" "github.com/github/gh-osc/go/mysql" "github.com/github/gh-osc/go/sql" @@ -19,10 +20,6 @@ import ( "github.com/outbrain/golib/sqlutils" ) -const ( - heartbeatIntervalSeconds = 1 -) - // Applier reads data from the read-MySQL-server (typically a replica, but can be the master) // It is used for gaining initial status and structure, and later also follow up on progress and changelog type Applier struct { @@ -33,7 +30,7 @@ type Applier struct { func NewApplier() *Applier { return &Applier{ - connectionConfig: base.GetMigrationContext().MasterConnectionConfig, + connectionConfig: base.GetMigrationContext().ApplierConnectionConfig, migrationContext: base.GetMigrationContext(), } } @@ -63,7 +60,7 @@ func (this *Applier) validateConnection() error { return nil } -// 
CreateGhostTable creates the ghost table on the master +// CreateGhostTable creates the ghost table on the applier host func (this *Applier) CreateGhostTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), @@ -82,7 +79,7 @@ func (this *Applier) CreateGhostTable() error { return nil } -// CreateGhostTable creates the ghost table on the master +// AlterGhost applies `alter` statement on ghost table func (this *Applier) AlterGhost() error { query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`, sql.EscapeName(this.migrationContext.DatabaseName), @@ -101,16 +98,16 @@ func (this *Applier) AlterGhost() error { return nil } -// CreateChangelogTable creates the changelog table on the master +// CreateChangelogTable creates the changelog table on the applier host func (this *Applier) CreateChangelogTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s ( - id int auto_increment, + id bigint auto_increment, last_update timestamp not null DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, hint varchar(64) charset ascii not null, - value varchar(64) charset ascii not null, + value varchar(255) charset ascii not null, primary key(id), unique key hint_uidx(hint) - ) auto_increment=2 + ) auto_increment=256 `, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), @@ -126,31 +123,50 @@ func (this *Applier) CreateChangelogTable() error { return nil } -// DropChangelogTable drops the changelog table on the master -func (this *Applier) DropChangelogTable() error { +// dropTable drops a given table on the applied host +func (this *Applier) dropTable(tableName string) error { query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), + sql.EscapeName(tableName), ) - log.Infof("Droppping changelog table %s.%s", + log.Infof("Droppping table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), + sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Changelog table dropped") + log.Infof("Table dropped") return nil } +// DropChangelogTable drops the changelog table on the applier host +func (this *Applier) DropChangelogTable() error { + return this.dropTable(this.migrationContext.GetChangelogTableName()) +} + +// DropGhostTable drops the ghost table on the applier host +func (this *Applier) DropGhostTable() error { + return this.dropTable(this.migrationContext.GetGhostTableName()) +} + // WriteChangelog writes a value to the changelog table. // It returns the hint as given, for convenience func (this *Applier) WriteChangelog(hint, value string) (string, error) { + explicitId := 0 + switch hint { + case "heartbeat": + explicitId = 1 + case "state": + explicitId = 2 + case "throttle": + explicitId = 3 + } query := fmt.Sprintf(` insert /* gh-osc */ into %s.%s (id, hint, value) values - (NULL, ?, ?) + (NULLIF(?, 0), ?, ?) 
on duplicate key update last_update=NOW(), value=VALUES(value) @@ -158,56 +174,51 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) - _, err := sqlutils.Exec(this.db, query, hint, value) + _, err := sqlutils.Exec(this.db, query, explicitId, hint, value) return hint, err } +func (this *Applier) WriteAndLogChangelog(hint, value string) (string, error) { + this.WriteChangelog(hint, value) + return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value) +} + +func (this *Applier) WriteChangelogState(value string) (string, error) { + return this.WriteAndLogChangelog("state", value) +} + // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // This is done asynchronously -func (this *Applier) InitiateHeartbeat() { - go func() { - numSuccessiveFailures := 0 - query := fmt.Sprintf(` - insert /* gh-osc */ into %s.%s - (id, hint, value) - values - (1, 'heartbeat', ?) - on duplicate key update - last_update=NOW(), - value=VALUES(value) - `, - sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), - ) - injectHeartbeat := func() error { - if _, err := sqlutils.ExecNoPrepare(this.db, query, time.Now().Format(time.RFC3339)); err != nil { - numSuccessiveFailures++ - if numSuccessiveFailures > this.migrationContext.MaxRetries() { - return log.Errore(err) - } - } else { - numSuccessiveFailures = 0 +func (this *Applier) InitiateHeartbeat(heartbeatIntervalMilliseconds int64) { + numSuccessiveFailures := 0 + injectHeartbeat := func() error { + if _, err := this.WriteChangelog("heartbeat", time.Now().Format(time.RFC3339)); err != nil { + numSuccessiveFailures++ + if numSuccessiveFailures > this.migrationContext.MaxRetries() { + return log.Errore(err) } - return nil + } else { + numSuccessiveFailures = 0 } - injectHeartbeat() + return nil + } + injectHeartbeat() - heartbeatTick := time.Tick(time.Duration(heartbeatIntervalSeconds) * time.Second) - for range heartbeatTick { - // Generally speaking, we would issue a goroutine, but I'd actually rather - // have this blocked rather than spam the master in the event something - // goes wrong - if err := injectHeartbeat(); err != nil { - return - } + heartbeatTick := time.Tick(time.Duration(heartbeatIntervalMilliseconds) * time.Millisecond) + for range heartbeatTick { + // Generally speaking, we would issue a goroutine, but I'd actually rather + // have this blocked rather than spam the master in the event something + // goes wrong + if err := injectHeartbeat(); err != nil { + return } - }() + } } // ReadMigrationMinValues func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names) if err != nil { return err } @@ -216,7 +227,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { return err } for rows.Next() { - this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns)) + this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(uniqueKey.Len()) if 
err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil { return err } @@ -228,7 +239,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { // ReadMigrationMinValues func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names) if err != nil { return err } @@ -237,7 +248,7 @@ func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { return err } for rows.Next() { - this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns)) + this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(uniqueKey.Len()) if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil { return err } @@ -266,12 +277,12 @@ func (this *Applier) __unused_IterationIsComplete() (bool, error) { return false, nil } args := sqlutils.Args() - compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) + compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) if err != nil { return false, err } args = append(args, explodedArgs...) 
- compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) + compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) if err != nil { return false, err } @@ -311,11 +322,11 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, - this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), this.migrationContext.ChunkSize, - fmt.Sprintf("iteration:%d", this.migrationContext.Iteration), + fmt.Sprintf("iteration:%d", this.migrationContext.GetIteration()), ) if err != nil { return hasFurtherRange, err @@ -324,7 +335,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo if err != nil { return hasFurtherRange, err } - iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns)) + iterationRangeMaxValues := sql.NewColumnValues(this.migrationContext.UniqueKey.Len()) for rows.Next() { if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { return hasFurtherRange, err @@ -332,17 +343,10 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo hasFurtherRange = true } if !hasFurtherRange { - log.Debugf("Iteration complete: cannot find iteration end") + log.Debugf("Iteration complete: no further range to iterate") return hasFurtherRange, nil } this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues - log.Debugf( - "column values: [%s]..[%s]; iteration: %d; chunk-size: %d", - this.migrationContext.MigrationIterationRangeMinValues, - this.migrationContext.MigrationIterationRangeMaxValues, - this.migrationContext.Iteration, - this.migrationContext.ChunkSize, - ) return hasFurtherRange, nil } @@ -354,12 +358,12 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.GetGhostTableName(), - this.migrationContext.UniqueKey.Columns, + this.migrationContext.SharedColumns.Names, this.migrationContext.UniqueKey.Name, - this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), - this.migrationContext.Iteration == 0, + this.migrationContext.GetIteration() == 0, this.migrationContext.IsTransactionalTable(), ) if err != nil { @@ -371,21 +375,22 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected } rowsAffected, _ = sqlResult.RowsAffected() duration = time.Now().Sub(startTime) - this.WriteChangelog( - fmt.Sprintf("copy iteration %d", this.migrationContext.Iteration), - fmt.Sprintf("chunk: %d; affected: %d; duration: %d", chunkSize, rowsAffected, duration), - ) log.Debugf( "Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d", this.migrationContext.MigrationIterationRangeMinValues, 
this.migrationContext.MigrationIterationRangeMaxValues, - this.migrationContext.Iteration, + this.migrationContext.GetIteration(), chunkSize) return chunkSize, rowsAffected, duration, nil } // LockTables func (this *Applier) LockTables() error { + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really locking tables") + return nil + } + query := fmt.Sprintf(`lock /* gh-osc */ tables %s.%s write, %s.%s write, %s.%s write`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), @@ -412,3 +417,44 @@ func (this *Applier) UnlockTables() error { log.Infof("Tables unlocked") return nil } + +func (this *Applier) ShowStatusVariable(variableName string) (result int64, err error) { + query := fmt.Sprintf(`show global status like '%s'`, variableName) + if err := this.db.QueryRow(query).Scan(&variableName, &result); err != nil { + return 0, err + } + return result, nil +} + +func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (query string, args []interface{}, err error) { + switch dmlEvent.DML { + case binlog.DeleteDML: + { + query, uniqueKeyArgs, err := sql.BuildDMLDeleteQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.WhereColumnValues.AbstractValues()) + return query, uniqueKeyArgs, err + } + case binlog.InsertDML: + { + query, sharedArgs, err := sql.BuildDMLInsertQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, dmlEvent.NewColumnValues.AbstractValues()) + return query, sharedArgs, err + } + case binlog.UpdateDML: + { + query, sharedArgs, uniqueKeyArgs, err := sql.BuildDMLUpdateQuery(dmlEvent.DatabaseName, this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, &this.migrationContext.UniqueKey.Columns, dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues()) + args = append(args, sharedArgs...) + args = append(args, uniqueKeyArgs...) + return query, args, err + } + } + return "", args, fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML) +} + +func (this *Applier) ApplyDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) error { + query, args, err := this.buildDMLEventQuery(dmlEvent) + if err != nil { + return err + } + log.Errorf(query) + _, err = sqlutils.Exec(this.db, query, args...) 
+ return err +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 31f74ec..902ccbe 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -44,6 +44,9 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.validateGrants(); err != nil { return err } + if err := this.restartReplication(); err != nil { + return err + } if err := this.validateBinlogs(); err != nil { return err } @@ -69,15 +72,54 @@ func (this *Inspector) ValidateOriginalTable() (err error) { return nil } -func (this *Inspector) InspectOriginalTable() (uniqueKeys [](*sql.UniqueKey), err error) { - uniqueKeys, err = this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName) +func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { + uniqueKeys, err = this.getCandidateUniqueKeys(tableName) if err != nil { - return uniqueKeys, err + return columns, uniqueKeys, err } if len(uniqueKeys) == 0 { - return uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") + return columns, uniqueKeys, fmt.Errorf("No PRIMARY nor UNIQUE key found in table! Bailing out") } - return uniqueKeys, err + columns, err = this.getTableColumns(this.migrationContext.DatabaseName, tableName) + if err != nil { + return columns, uniqueKeys, err + } + + return columns, uniqueKeys, nil +} + +func (this *Inspector) InspectOriginalTable() (err error) { + this.migrationContext.OriginalTableColumns, this.migrationContext.OriginalTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.OriginalTableName) + if err == nil { + return err + } + return nil +} + +func (this *Inspector) InspectOriginalAndGhostTables() (err error) { + this.migrationContext.GhostTableColumns, this.migrationContext.GhostTableUniqueKeys, err = this.InspectTableColumnsAndUniqueKeys(this.migrationContext.GetGhostTableName()) + if err != nil { + return err + } + sharedUniqueKeys, err := this.getSharedUniqueKeys(this.migrationContext.OriginalTableUniqueKeys, this.migrationContext.GhostTableUniqueKeys) + if err != nil { + return err + } + if len(sharedUniqueKeys) == 0 { + return fmt.Errorf("No shared unique key can be found after ALTER! Bailing out") + } + this.migrationContext.UniqueKey = sharedUniqueKeys[0] + log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) + if !this.migrationContext.UniqueKey.IsPrimary() { + if this.migrationContext.OriginalBinlogRowImage != "full" { + return fmt.Errorf("binlog_row_image is '%s' and chosen key is %s, which is not the primary key. This operation cannot proceed. You may `set global binlog_row_image='full'` and try again") + } + } + + this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns) + log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) + // By fact that a non-empty unique key exists we also know the shared columns are non-empty + return nil } // validateConnection issues a simple can-connect to MySQL @@ -136,6 +178,32 @@ func (this *Inspector) validateGrants() error { return log.Errorf("User has insufficient privileges for migration.") } +// restartReplication is required so that we are _certain_ the binlog format and +// row image settings have actually been applied to the replication thread. 
+// It is entriely possible, for example, that the replication is using 'STATEMENT' +// binlog format even as the variable says 'ROW' +func (this *Inspector) restartReplication() error { + log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + + masterKey, _ := getMasterKeyFromSlaveStatus(this.connectionConfig) + if masterKey == nil { + // This is not a replica + return nil + } + + var stopError, startError error + _, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`) + _, startError = sqlutils.ExecNoPrepare(this.db, `start slave`) + if stopError != nil { + return stopError + } + if startError != nil { + return startError + } + log.Debugf("Replication restarted") + return nil +} + // validateBinlogs checks that binary log configuration is good to go func (this *Inspector) validateBinlogs() error { query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format` @@ -264,27 +332,28 @@ func (this *Inspector) countTableRows() error { return nil } -func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) { +func (this *Inspector) getTableColumns(databaseName, tableName string) (*sql.ColumnList, error) { query := fmt.Sprintf(` show columns from %s.%s `, sql.EscapeName(databaseName), sql.EscapeName(tableName), ) - err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { - columns = append(columns, rowMap.GetString("Field")) + columnNames := []string{} + err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { + columnNames = append(columnNames, rowMap.GetString("Field")) return nil }) if err != nil { - return columns, err + return nil, err } - if len(columns) == 0 { - return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out", + if len(columnNames) == 0 { + return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", sql.EscapeName(databaseName), sql.EscapeName(tableName), ) } - return columns, nil + return sql.NewColumnList(columnNames), nil } // getCandidateUniqueKeys investigates a table and returns the list of unique keys @@ -361,17 +430,9 @@ func (this *Inspector) getCandidateUniqueKeys(tableName string) (uniqueKeys [](* return uniqueKeys, nil } -// getCandidateUniqueKeys investigates a table and returns the list of unique keys -// candidate for chunking -func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err error) { - originalUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.OriginalTableName) - if err != nil { - return uniqueKeys, err - } - ghostUniqueKeys, err := this.getCandidateUniqueKeys(this.migrationContext.GetGhostTableName()) - if err != nil { - return uniqueKeys, err - } +// getSharedUniqueKeys returns the intersection of two given unique keys, +// testing by list of columns +func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [](*sql.UniqueKey)) (uniqueKeys [](*sql.UniqueKey), err error) { // We actually do NOT rely on key name, just on the set of columns. This is because maybe // the ALTER is on the name itself... 
for _, originalUniqueKey := range originalUniqueKeys { @@ -384,43 +445,77 @@ func (this *Inspector) getSharedUniqueKeys() (uniqueKeys [](*sql.UniqueKey), err return uniqueKeys, nil } -func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { - visitedKeys := mysql.NewInstanceKeyMap() - return getMasterConnectionConfigSafe(this.connectionConfig, this.migrationContext.DatabaseName, visitedKeys) +// getSharedColumns returns the intersection of two lists of columns in same order as the first list +func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList) *sql.ColumnList { + columnsInGhost := make(map[string]bool) + for _, ghostColumn := range ghostColumns.Names { + columnsInGhost[ghostColumn] = true + } + sharedColumnNames := []string{} + for _, originalColumn := range originalColumns.Names { + if columnsInGhost[originalColumn] { + sharedColumnNames = append(sharedColumnNames, originalColumn) + } + } + return sql.NewColumnList(sharedColumnNames) } -func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, databaseName string, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { - log.Debugf("Looking for master on %+v", connectionConfig.Key) +func (this *Inspector) readChangelogState() (map[string]string, error) { + query := fmt.Sprintf(` + select hint, value from %s.%s where id <= 255 + `, + sql.EscapeName(this.migrationContext.DatabaseName), + sql.EscapeName(this.migrationContext.GetChangelogTableName()), + ) + result := make(map[string]string) + err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { + result[m.GetString("hint")] = m.GetString("value") + return nil + }) + return result, err +} - currentUri := connectionConfig.GetDBUri(databaseName) +func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { + visitedKeys := mysql.NewInstanceKeyMap() + return getMasterConnectionConfigSafe(this.connectionConfig, visitedKeys) +} + +func getMasterKeyFromSlaveStatus(connectionConfig *mysql.ConnectionConfig) (masterKey *mysql.InstanceKey, err error) { + currentUri := connectionConfig.GetDBUri("information_schema") db, _, err := sqlutils.GetDB(currentUri) if err != nil { return nil, err } - - hasMaster := false - masterConfig = connectionConfig.Duplicate() err = sqlutils.QueryRowsMap(db, `show slave status`, func(rowMap sqlutils.RowMap) error { - masterKey := mysql.InstanceKey{ + masterKey = &mysql.InstanceKey{ Hostname: rowMap.GetString("Master_Host"), Port: rowMap.GetInt("Master_Port"), } - if masterKey.IsValid() { - masterConfig.Key = masterKey - hasMaster = true - } return nil }) + return masterKey, err +} + +func getMasterConnectionConfigSafe(connectionConfig *mysql.ConnectionConfig, visitedKeys *mysql.InstanceKeyMap) (masterConfig *mysql.ConnectionConfig, err error) { + log.Debugf("Looking for master on %+v", connectionConfig.Key) + + masterKey, err := getMasterKeyFromSlaveStatus(connectionConfig) if err != nil { return nil, err } - if hasMaster { - log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) - if visitedKeys.HasKey(masterConfig.Key) { - return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. 
Bailing out", masterConfig.Key) - } - visitedKeys.AddKey(masterConfig.Key) - return getMasterConnectionConfigSafe(masterConfig, databaseName, visitedKeys) + if masterKey == nil { + return connectionConfig, nil } - return masterConfig, nil + if !masterKey.IsValid() { + return connectionConfig, nil + } + masterConfig = connectionConfig.Duplicate() + masterConfig.Key = *masterKey + + log.Debugf("Master of %+v is %+v", connectionConfig.Key, masterConfig.Key) + if visitedKeys.HasKey(masterConfig.Key) { + return nil, fmt.Errorf("There seems to be a master-master setup at %+v. This is unsupported. Bailing out", masterConfig.Key) + } + visitedKeys.AddKey(masterConfig.Key) + return getMasterConnectionConfigSafe(masterConfig, visitedKeys) } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 451cb96..62fc6e3 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -8,7 +8,10 @@ package logic import ( "fmt" "os" + "os/signal" + "regexp" "sync/atomic" + "syscall" "time" "github.com/github/gh-osc/go/base" @@ -27,7 +30,12 @@ const ( type tableWriteFunc func() error const ( - applyEventsQueueBuffer = 100 + applyEventsQueueBuffer = 100 + heartbeatIntervalMilliseconds = 1000 +) + +var ( + prettifyDurationRegexp = regexp.MustCompile("([.][0-9]+)") ) // Migrator is the main schema migration flow manager. @@ -45,6 +53,8 @@ type Migrator struct { // excessive work happens at the end of the iteration as new copy-jobs arrive befroe realizing the copy is complete copyRowsQueue chan tableWriteFunc applyEventsQueue chan tableWriteFunc + + handledChangelogStates map[string]bool } func NewMigrator() *Migrator { @@ -54,40 +64,138 @@ func NewMigrator() *Migrator { rowCopyComplete: make(chan bool), allEventsUpToLockProcessed: make(chan bool), - copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), - } - migrator.migrationContext.IsThrottled = func() bool { - return migrator.shouldThrottle() + copyRowsQueue: make(chan tableWriteFunc), + applyEventsQueue: make(chan tableWriteFunc, applyEventsQueueBuffer), + handledChangelogStates: make(map[string]bool), } return migrator } -func (this *Migrator) shouldThrottle() bool { +func prettifyDurationOutput(d time.Duration) string { + if d < time.Second { + return "0s" + } + result := fmt.Sprintf("%s", d) + result = prettifyDurationRegexp.ReplaceAllString(result, "") + return result +} + +// acceptSignals registers for OS signals +func (this *Migrator) acceptSignals() { + c := make(chan os.Signal, 1) + + signal.Notify(c, syscall.SIGHUP) + go func() { + for sig := range c { + switch sig { + case syscall.SIGHUP: + log.Debugf("Received SIGHUP. Reloading configuration") + } + } + }() +} + +func (this *Migrator) shouldThrottle() (result bool, reason string) { lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) - shouldThrottle := false if time.Duration(lag) > time.Duration(this.migrationContext.MaxLagMillisecondsThrottleThreshold)*time.Millisecond { - shouldThrottle = true - } else if this.migrationContext.ThrottleFlagFile != "" { + return true, fmt.Sprintf("lag=%fs", time.Duration(lag).Seconds()) + } + if this.migrationContext.ThrottleFlagFile != "" { if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil { - //Throttle file defined and exists! - shouldThrottle = true + // Throttle file defined and exists! 
+ return true, "flag-file" } } - return shouldThrottle + if this.migrationContext.ThrottleAdditionalFlagFile != "" { + if _, err := os.Stat(this.migrationContext.ThrottleAdditionalFlagFile); err == nil { + // 2nd Throttle file defined and exists! + return true, "flag-file" + } + } + + for variableName, threshold := range this.migrationContext.MaxLoad { + value, err := this.applier.ShowStatusVariable(variableName) + if err != nil { + return true, fmt.Sprintf("%s %s", variableName, err) + } + if value > threshold { + return true, fmt.Sprintf("%s=%d", variableName, value) + } + } + + return false, "" +} + +func (this *Migrator) initiateThrottler() error { + throttlerTick := time.Tick(1 * time.Second) + + throttlerFunction := func() { + alreadyThrottling, currentReason := this.migrationContext.IsThrottled() + shouldThrottle, throttleReason := this.shouldThrottle() + if shouldThrottle && !alreadyThrottling { + // New throttling + this.applier.WriteAndLogChangelog("throttle", throttleReason) + } else if shouldThrottle && alreadyThrottling && (currentReason != throttleReason) { + // Change of reason + this.applier.WriteAndLogChangelog("throttle", throttleReason) + } else if alreadyThrottling && !shouldThrottle { + // End of throttling + this.applier.WriteAndLogChangelog("throttle", "done throttling") + } + this.migrationContext.SetThrottled(shouldThrottle, throttleReason) + } + throttlerFunction() + for range throttlerTick { + throttlerFunction() + } + + return nil +} + +// throttle initiates a throttling event, if need be, updates the Context and +// calls callback functions, if any +func (this *Migrator) throttle(onThrottled func()) { + for { + if shouldThrottle, _ := this.migrationContext.IsThrottled(); !shouldThrottle { + return + } + if onThrottled != nil { + onThrottled() + } + time.Sleep(time.Second) + } +} + +// retryOperation attempts up to `count` attempts at running given function, +// exiting as soon as it returns with non-error. +func (this *Migrator) retryOperation(operation func() error) (err error) { + maxRetries := this.migrationContext.MaxRetries() + for i := 0; i < maxRetries; i++ { + if i != 0 { + // sleep after previous iteration + time.Sleep(1 * time.Second) + } + err = operation() + if err == nil { + return nil + } + // there's an error. Let's try again. + } + return err } func (this *Migrator) canStopStreaming() bool { return false } -func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - // Hey, I created the changlog table, I know the type of columns it has! 
- if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "state" { +func (this *Migrator) onChangelogState(stateValue string) (err error) { + if this.handledChangelogStates[stateValue] { return } - changelogState := ChangelogState(dmlEvent.NewColumnValues.StringColumn(3)) + this.handledChangelogStates[stateValue] = true + + changelogState := ChangelogState(stateValue) switch changelogState { case TablesInPlace: { @@ -102,16 +210,12 @@ func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (er return fmt.Errorf("Unknown changelog state: %+v", changelogState) } } - log.Debugf("---- - - - - - state %+v", changelogState) + log.Debugf("Received state %+v", changelogState) return nil } -func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { - if hint := dmlEvent.NewColumnValues.StringColumn(2); hint != "heartbeat" { - return nil - } - value := dmlEvent.NewColumnValues.StringColumn(3) - heartbeatTime, err := time.Parse(time.RFC3339, value) +func (this *Migrator) onChangelogHeartbeat(heartbeatValue string) (err error) { + heartbeatTime, err := time.Parse(time.RFC3339, heartbeatValue) if err != nil { return log.Errore(err) } @@ -132,18 +236,29 @@ func (this *Migrator) Migrate() (err error) { if err := this.inspector.ValidateOriginalTable(); err != nil { return err } - uniqueKeys, err := this.inspector.InspectOriginalTable() - if err != nil { + if err := this.inspector.InspectOriginalTable(); err != nil { return err } // So far so good, table is accessible and valid. - if this.migrationContext.MasterConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { + // Let's get master connection config + if this.migrationContext.ApplierConnectionConfig, err = this.inspector.getMasterConnectionConfig(); err != nil { return err } - if this.migrationContext.IsRunningOnMaster() && !this.migrationContext.AllowedRunningOnMaster { + if this.migrationContext.TestOnReplica { + if this.migrationContext.InspectorIsAlsoApplier() { + return fmt.Errorf("Instructed to --test-on-replica, but the server we connect to doesn't seem to be a replica") + } + log.Infof("--test-on-replica given. Will not execute on master %+v but rather on replica %+v itself", + this.migrationContext.ApplierConnectionConfig.Key, this.migrationContext.InspectorConnectionConfig.Key, + ) + this.migrationContext.ApplierConnectionConfig = this.migrationContext.InspectorConnectionConfig.Duplicate() + } else if this.migrationContext.InspectorIsAlsoApplier() && !this.migrationContext.AllowedRunningOnMaster { return fmt.Errorf("It seems like this migration attempt to run directly on master. Preferably it would be executed on a replica (and this reduces load from the master). To proceed please provide --allow-on-master") } - log.Infof("Master found to be %+v", this.migrationContext.MasterConnectionConfig.Key) + + log.Infof("Master found to be %+v", this.migrationContext.ApplierConnectionConfig.Key) + + go this.initiateChangelogListener() if err := this.initiateStreaming(); err != nil { return err @@ -159,33 +274,30 @@ func (this *Migrator) Migrate() (err error) { // When running on replica, this means the replica has those tables. When running // on master this is always true, of course, and yet it also implies this knowledge // is in the binlogs. + if err := this.inspector.InspectOriginalAndGhostTables(); err != nil { + return err + } - this.migrationContext.UniqueKey = uniqueKeys[0] // TODO. 
Need to wait on replica till the ghost table exists and get shared keys if err := this.applier.ReadMigrationRangeValues(); err != nil { return err } - go this.initiateStatus() + go this.initiateThrottler() go this.executeWriteFuncs() go this.iterateChunks() + this.migrationContext.RowCopyStartTime = time.Now() + go this.initiateStatus() log.Debugf("Operating until row copy is complete") <-this.rowCopyComplete log.Debugf("Row copy complete") this.printStatus() - throttleMigration( - this.migrationContext, - func() { - log.Debugf("throttling before LOCK TABLES") - }, - nil, - func() { - log.Debugf("done throttling") - }, - ) + this.throttle(func() { + log.Debugf("throttling on LOCK TABLES") + }) // TODO retries!! this.applier.LockTables() - this.applier.WriteChangelog("state", string(AllEventsUpToLockProcessed)) + this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed)) log.Debugf("Waiting for events up to lock") <-this.allEventsUpToLockProcessed log.Debugf("Done waiting for events up to lock") @@ -228,34 +340,71 @@ func (this *Migrator) printStatus() { return } - status := fmt.Sprintf("Copy: %d/%d %.1f%% Backlog: %d/%d Elapsed: %+v(copy), %+v(total) ETA: N/A", + eta := "N/A" + if isThrottled, throttleReason := this.migrationContext.IsThrottled(); isThrottled { + eta = fmt.Sprintf("throttled, %s", throttleReason) + } + status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s", totalRowsCopied, rowsEstimate, progressPct, len(this.applyEventsQueue), cap(this.applyEventsQueue), - this.migrationContext.ElapsedRowCopyTime(), elapsedTime) + prettifyDurationOutput(this.migrationContext.ElapsedRowCopyTime()), prettifyDurationOutput(elapsedTime), + eta, + ) + this.applier.WriteChangelog( + fmt.Sprintf("copy iteration %d at %d", this.migrationContext.GetIteration(), time.Now().Unix()), + status, + ) fmt.Println(status) } +func (this *Migrator) initiateChangelogListener() { + ticker := time.Tick((heartbeatIntervalMilliseconds * time.Millisecond) / 2) + for range ticker { + go func() error { + changelogState, err := this.inspector.readChangelogState() + if err != nil { + return log.Errore(err) + } + for hint, value := range changelogState { + switch hint { + case "state": + { + this.onChangelogState(value) + } + case "heartbeat": + { + this.onChangelogHeartbeat(value) + } + } + } + return nil + }() + } +} + +// initiateStreaming begins treaming of binary log events and registers listeners for such events func (this *Migrator) initiateStreaming() error { this.eventsStreamer = NewEventsStreamer() if err := this.eventsStreamer.InitDBConnections(); err != nil { return err } + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really listening on binlog events") + return nil + } this.eventsStreamer.AddListener( - false, + true, this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), + this.migrationContext.OriginalTableName, func(dmlEvent *binlog.BinlogDMLEvent) error { - return this.onChangelogStateEvent(dmlEvent) - }, - ) - this.eventsStreamer.AddListener( - false, - this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), - func(dmlEvent *binlog.BinlogDMLEvent) error { - return this.onChangelogHeartbeatEvent(dmlEvent) + applyEventFunc := func() error { + return this.applier.ApplyDMLEventQuery(dmlEvent) + } + this.applyEventsQueue <- applyEventFunc + return nil }, ) + go func() { log.Debugf("Beginning streaming") this.eventsStreamer.StreamEvents(func() bool { return 
this.canStopStreaming() }) @@ -281,17 +430,24 @@ func (this *Migrator) initiateApplier() error { return err } - this.applier.WriteChangelog("state", string(TablesInPlace)) - this.applier.InitiateHeartbeat() + this.applier.WriteChangelogState(string(TablesInPlace)) + go this.applier.InitiateHeartbeat(heartbeatIntervalMilliseconds) return nil } func (this *Migrator) iterateChunks() error { - this.migrationContext.RowCopyStartTime = time.Now() terminateRowIteration := func(err error) error { this.rowCopyComplete <- true return log.Errore(err) } + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really copying data") + return terminateRowIteration(nil) + } + if this.migrationContext.MigrationRangeMinValues == nil { + log.Debugf("No rows found in table. Rowcopy will be implicitly empty") + return terminateRowIteration(nil) + } for { copyRowsFunc := func() error { hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues() @@ -306,7 +462,7 @@ func (this *Migrator) iterateChunks() error { return terminateRowIteration(err) } atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) - this.migrationContext.Iteration++ + atomic.AddInt64(&this.migrationContext.Iteration, 1) return nil } this.copyRowsQueue <- copyRowsFunc @@ -315,30 +471,29 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) executeWriteFuncs() error { + if this.migrationContext.Noop { + log.Debugf("Noop operation; not really doing writes") + return nil + } for { - throttleMigration( - this.migrationContext, - func() { - log.Debugf("throttling writes") - }, - nil, - func() { - log.Debugf("done throttling writes") - }, - ) + this.throttle(nil) // We give higher priority to event processing, then secondary priority to // rowcopy select { case applyEventFunc := <-this.applyEventsQueue: { - retryOperation(applyEventFunc, this.migrationContext.MaxRetries()) + if err := this.retryOperation(applyEventFunc); err != nil { + return log.Errore(err) + } } default: { select { case copyRowsFunc := <-this.copyRowsQueue: { - retryOperation(copyRowsFunc, this.migrationContext.MaxRetries()) + if err := this.retryOperation(copyRowsFunc); err != nil { + return log.Errore(err) + } } default: { diff --git a/go/sql/builder.go b/go/sql/builder.go index b8f0358..4b73a5a 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -32,6 +32,20 @@ func EscapeName(name string) string { return fmt.Sprintf("`%s`", name) } +func buildPreparedValues(length int) []string { + values := make([]string, length, length) + for i := 0; i < length; i++ { + values[i] = "?" 
+ } + return values +} + +func duplicateNames(names []string) []string { + duplicate := make([]string, len(names), len(names)) + copy(duplicate, names) + return duplicate +} + func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) { if column == "" { return "", fmt.Errorf("Empty column in GetValueComparison") @@ -64,6 +78,22 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er return result, nil } +func BuildEqualsPreparedComparison(columns []string) (result string, err error) { + values := buildPreparedValues(len(columns)) + return BuildEqualsComparison(columns, values) +} + +func BuildSetPreparedClause(columns []string) (result string, err error) { + if len(columns) == 0 { + return "", fmt.Errorf("Got 0 columns in BuildSetPreparedClause") + } + setTokens := []string{} + for _, column := range columns { + setTokens = append(setTokens, fmt.Sprintf("%s=?", EscapeName(column))) + } + return strings.Join(setTokens, ", "), nil +} + func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { if len(columns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") @@ -121,10 +151,7 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, } func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { - values := make([]string, len(columns), len(columns)) - for i := range columns { - values[i] = "?" - } + values := buildPreparedValues(len(columns)) return BuildRangeComparison(columns, values, args, comparisonSign) } @@ -135,6 +162,7 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin databaseName = EscapeName(databaseName) originalTableName = EscapeName(originalTableName) ghostTableName = EscapeName(ghostTableName) + sharedColumns = duplicateNames(sharedColumns) for i := range sharedColumns { sharedColumns[i] = EscapeName(sharedColumns[i]) } @@ -171,12 +199,8 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin } func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { - rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - rangeStartValues[i] = "?" - rangeEndValues[i] = "?" - } + rangeStartValues := buildPreparedValues(len(uniqueKeyColumns)) + rangeEndValues := buildPreparedValues(len(uniqueKeyColumns)) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) } @@ -198,6 +222,7 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK } explodedArgs = append(explodedArgs, rangeExplodedArgs...) 
+ uniqueKeyColumns = duplicateNames(uniqueKeyColumns) uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { @@ -244,6 +269,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) + uniqueKeyColumns = duplicateNames(uniqueKeyColumns) uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) @@ -262,3 +288,125 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni ) return query, nil } + +func BuildDMLDeleteQuery(databaseName, tableName string, tableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { + if len(args) != tableColumns.Len() { + return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery") + } + if uniqueKeyColumns.Len() == 0 { + return result, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLDeleteQuery") + } + for _, column := range uniqueKeyColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal]) + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + if err != nil { + return result, uniqueKeyArgs, err + } + result = fmt.Sprintf(` + delete /* gh-osc %s.%s */ + from + %s.%s + where + %s + `, databaseName, tableName, + databaseName, tableName, + equalsComparison, + ) + return result, uniqueKeyArgs, nil +} + +func BuildDMLInsertQuery(databaseName, tableName string, tableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { + if len(args) != tableColumns.Len() { + return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery") + } + if !sharedColumns.IsSubsetOf(tableColumns) { + return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery") + } + if sharedColumns.Len() == 0 { + return result, args, fmt.Errorf("No shared columns found in BuildDMLInsertQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + for _, column := range sharedColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + sharedArgs = append(sharedArgs, args[tableOrdinal]) + } + + sharedColumnNames := duplicateNames(sharedColumns.Names) + for i := range sharedColumnNames { + sharedColumnNames[i] = EscapeName(sharedColumnNames[i]) + } + preparedValues := buildPreparedValues(sharedColumns.Len()) + + result = fmt.Sprintf(` + replace /* gh-osc %s.%s */ into + %s.%s + (%s) + values + (%s) + `, databaseName, tableName, + databaseName, tableName, + strings.Join(sharedColumnNames, ", "), + strings.Join(preparedValues, ", "), + ) + return result, sharedArgs, nil +} + +func BuildDMLUpdateQuery(databaseName, tableName string, tableColumns, sharedColumns, uniqueKeyColumns *ColumnList, valueArgs, whereArgs []interface{}) (result string, sharedArgs, uniqueKeyArgs []interface{}, err error) { + if len(valueArgs) != tableColumns.Len() { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("value args count differs from table column count in 
BuildDMLUpdateQuery") + } + if len(whereArgs) != tableColumns.Len() { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("where args count differs from table column count in BuildDMLUpdateQuery") + } + if !sharedColumns.IsSubsetOf(tableColumns) { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLUpdateQuery") + } + if !uniqueKeyColumns.IsSubsetOf(sharedColumns) { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("unique key columns is not a subset of shared columns in BuildDMLUpdateQuery") + } + if sharedColumns.Len() == 0 { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No shared columns found in BuildDMLUpdateQuery") + } + if uniqueKeyColumns.Len() == 0 { + return result, sharedArgs, uniqueKeyArgs, fmt.Errorf("No unique key columns found in BuildDMLUpdateQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + for _, column := range sharedColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + sharedArgs = append(sharedArgs, valueArgs[tableOrdinal]) + } + + for _, column := range uniqueKeyColumns.Names { + tableOrdinal := tableColumns.Ordinals[column] + uniqueKeyArgs = append(uniqueKeyArgs, whereArgs[tableOrdinal]) + } + + sharedColumnNames := duplicateNames(sharedColumns.Names) + for i := range sharedColumnNames { + sharedColumnNames[i] = EscapeName(sharedColumnNames[i]) + } + setClause, err := BuildSetPreparedClause(sharedColumnNames) + + equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + result = fmt.Sprintf(` + update /* gh-osc %s.%s */ + %s.%s + set + %s + where + %s + `, databaseName, tableName, + databaseName, tableName, + setClause, + equalsComparison, + ) + return result, sharedArgs, uniqueKeyArgs, nil +} diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index b101bbf..3e0a8c6 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -68,6 +68,35 @@ func TestBuildEqualsComparison(t *testing.T) { } } +func TestBuildEqualsPreparedComparison(t *testing.T) { + { + columns := []string{"c1", "c2"} + comparison, err := BuildEqualsPreparedComparison(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` = ?) 
and (`c2` = ?))") + } +} + +func TestBuildSetPreparedClause(t *testing.T) { + { + columns := []string{"c1"} + clause, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(clause, "`c1`=?") + } + { + columns := []string{"c1", "c2"} + clause, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(clause, "`c1`=?, `c2`=?") + } + { + columns := []string{} + _, err := BuildSetPreparedClause(columns) + test.S(t).ExpectNotNil(err) + } +} + func TestBuildRangeComparison(t *testing.T) { { columns := []string{"c1"} @@ -143,7 +172,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { rangeStartArgs := []interface{}{3} rangeEndArgs := []interface{}{103} - query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -162,7 +191,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} - query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -186,13 +215,13 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} - query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true, true) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) (select id, name, position from mydb.tbl force index (name_position_uidx) where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) 
and (position = ?)))) - ) + lock in share mode ) ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) @@ -202,7 +231,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { databaseName := "mydb" originalTableName := "tbl" - chunkSize := 500 + var chunkSize int64 = 500 { uniqueKeyColumns := []string{"name", "position"} rangeStartArgs := []interface{}{3, 17} @@ -262,3 +291,191 @@ func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) { test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) } } + +func TestBuildDMLDeleteQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + args := []interface{}{3, "testname", "first", 17, 23} + { + uniqueKeyColumns := NewColumnList([]string{"position"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17})) + } + { + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((name = ?) and (position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{"testname", 17})) + } + { + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((position = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"})) + } + { + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + args := []interface{}{"first", 17} + + _, _, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNotNil(err) + } +} + +func TestBuildDMLInsertQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + args := []interface{}{3, "testname", "first", 17, 23} + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNil(err) + expected := ` + replace /* gh-osc mydb.tbl */ + into mydb.tbl + (id, name, position, age) + values + (?, ?, ?, ?) 
+ ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + } + { + sharedColumns := NewColumnList([]string{"position", "name", "age", "id"}) + query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNil(err) + expected := ` + replace /* gh-osc mydb.tbl */ + into mydb.tbl + (position, name, age, id) + values + (?, ?, ?, ?) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{17, "testname", 23, 3})) + } + { + sharedColumns := NewColumnList([]string{"position", "name", "surprise", "id"}) + _, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNotNil(err) + } + { + sharedColumns := NewColumnList([]string{}) + _, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNotNil(err) + } +} + +func TestBuildDMLUpdateQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + valueArgs := []interface{}{3, "testname", "newval", 17, 23} + whereArgs := []interface{}{3, "testname", "findme", 17, 56} + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"position"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((position = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? 
+ where + ((age = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age", "position", "id", "name"}) + query, sharedArgs, uniqueKeyArgs, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNil(err) + expected := ` + update /* gh-osc mydb.tbl */ + mydb.tbl + set id=?, name=?, position=?, age=? + where + ((age = ?) and (position = ?) and (id = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{56, 17, 3, "testname"})) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{"age", "surprise"}) + _, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNotNil(err) + } + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + uniqueKeyColumns := NewColumnList([]string{}) + _, _, _, err := BuildDMLUpdateQuery(databaseName, tableName, tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs) + test.S(t).ExpectNotNil(err) + } +} diff --git a/go/sql/types.go b/go/sql/types.go index 942bd37..e82720c 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -11,21 +11,62 @@ import ( "strings" ) +// ColumnsMap maps a column onto its ordinal position +type ColumnsMap map[string]int + +func NewColumnsMap(orderedNames []string) ColumnsMap { + columnsMap := make(map[string]int) + for i, column := range orderedNames { + columnsMap[column] = i + } + return ColumnsMap(columnsMap) +} + // ColumnList makes for a named list of columns -type ColumnList []string +type ColumnList struct { + Names []string + Ordinals ColumnsMap +} + +// NewColumnList creates an object given ordered list of column names +func NewColumnList(names []string) *ColumnList { + result := &ColumnList{ + Names: names, + } + result.Ordinals = NewColumnsMap(result.Names) + return result +} // ParseColumnList parses a comma delimited list of column names func ParseColumnList(columns string) *ColumnList { - result := ColumnList(strings.Split(columns, ",")) - return &result + result := &ColumnList{ + Names: strings.Split(columns, ","), + } + result.Ordinals = NewColumnsMap(result.Names) + return result } func (this *ColumnList) String() string { - return strings.Join(*this, ",") + return strings.Join(this.Names, ",") } func (this *ColumnList) Equals(other *ColumnList) bool { - return reflect.DeepEqual(*this, *other) + return reflect.DeepEqual(this.Names, other.Names) +} + +// IsSubsetOf returns 'true' when column names of this list are a subset of +// another list, in arbitrary order (order agnostic) +func (this *ColumnList) IsSubsetOf(other *ColumnList) bool { + for _, column := range this.Names { + if _, exists := other.Ordinals[column]; !exists { + return false + } + } + return true +} + +func (this *ColumnList) Len() int { + return len(this.Names) } // UniqueKey is the combination of a key's name and columns @@ -40,6 +81,10 @@ func (this 
*UniqueKey) IsPrimary() bool { return this.Name == "PRIMARY" } +func (this *UniqueKey) Len() int { + return this.Columns.Len() +} + func (this *UniqueKey) String() string { return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable) } diff --git a/go/sql/types_test.go b/go/sql/types_test.go new file mode 100644 index 0000000..177b8cf --- /dev/null +++ b/go/sql/types_test.go @@ -0,0 +1,29 @@ +/* + Copyright 2016 GitHub Inc. + See https://github.com/github/gh-osc/blob/master/LICENSE +*/ + +package sql + +import ( + "testing" + + "github.com/outbrain/golib/log" + test "github.com/outbrain/golib/tests" + "reflect" +) + +func init() { + log.SetLevel(log.ERROR) +} + +func TestParseColumnList(t *testing.T) { + names := "id,category,max_len" + + columnList := ParseColumnList(names) + test.S(t).ExpectEquals(columnList.Len(), 3) + test.S(t).ExpectTrue(reflect.DeepEqual(columnList.Names, []string{"id", "category", "max_len"})) + test.S(t).ExpectEquals(columnList.Ordinals["id"], 0) + test.S(t).ExpectEquals(columnList.Ordinals["category"], 1) + test.S(t).ExpectEquals(columnList.Ordinals["max_len"], 2) +}
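For reference, a minimal usage sketch (not part of this diff) of how the new DML builders line up with a prepared statement: the set-clause args bind first, then the unique-key (WHERE) args. The import path, the `_tbl_New` ghost table name and the column lists below are illustrative assumptions; in the migration itself they would come from the inspector and the decoded binlog row images.

package main

import (
	stdsql "database/sql"

	ghsql "github.com/github/gh-osc/go/sql"
)

// applyRowUpdate sketches applying one UPDATE row-event onto the ghost table.
// valueArgs and whereArgs must each carry one value per original-table column,
// in table-column order, as BuildDMLUpdateQuery validates.
func applyRowUpdate(db *stdsql.DB, valueArgs, whereArgs []interface{}) error {
	tableColumns := ghsql.NewColumnList([]string{"id", "name", "position"})
	sharedColumns := ghsql.NewColumnList([]string{"id", "name", "position"})
	uniqueKeyColumns := ghsql.NewColumnList([]string{"id"})

	query, sharedArgs, uniqueKeyArgs, err := ghsql.BuildDMLUpdateQuery(
		"mydb", "_tbl_New", tableColumns, sharedColumns, uniqueKeyColumns, valueArgs, whereArgs)
	if err != nil {
		return err
	}
	// Bind set-clause placeholders first, then the WHERE (unique key) placeholders.
	_, err = db.Exec(query, append(sharedArgs, uniqueKeyArgs...)...)
	return err
}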