diff --git a/go/base/context.go b/go/base/context.go index 8185805..9611869 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -9,6 +9,7 @@ import ( "fmt" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -50,18 +51,19 @@ type MigrationContext struct { CurrentLag int64 MaxLagMillisecondsThrottleThreshold int64 ThrottleFlagFile string + ThrottleAdditionalFlagFile string TotalRowsCopied int64 - isThrottled int64 - ThrottleReason string + isThrottled bool + throttleReason string + throttleMutex *sync.Mutex MaxLoad map[string]int64 - OriginalTableColumns sql.ColumnList - OriginalTableColumnsMap sql.ColumnsMap + OriginalTableColumns *sql.ColumnList OriginalTableUniqueKeys [](*sql.UniqueKey) - GhostTableColumns sql.ColumnList + GhostTableColumns *sql.ColumnList GhostTableUniqueKeys [](*sql.UniqueKey) UniqueKey *sql.UniqueKey - SharedColumns sql.ColumnList + SharedColumns *sql.ColumnList MigrationRangeMinValues *sql.ColumnValues MigrationRangeMaxValues *sql.ColumnValues Iteration int64 @@ -83,7 +85,8 @@ func newMigrationContext() *MigrationContext { InspectorConnectionConfig: mysql.NewConnectionConfig(), MasterConnectionConfig: mysql.NewConnectionConfig(), MaxLagMillisecondsThrottleThreshold: 1000, - MaxLoad: make(map[string]int64), + MaxLoad: make(map[string]int64), + throttleMutex: &sync.Mutex{}, } } @@ -97,6 +100,11 @@ func (this *MigrationContext) GetGhostTableName() string { return fmt.Sprintf("_%s_New", this.OriginalTableName) } +// GetOldTableName generates the name of the "old" table, into which the original table is renamed. +func (this *MigrationContext) GetOldTableName() string { + return fmt.Sprintf("_%s_Old", this.OriginalTableName) +} + // GetChangelogTableName generates the name of changelog table, based on original table name func (this *MigrationContext) GetChangelogTableName() string { return fmt.Sprintf("_%s_OSC", this.OriginalTableName) @@ -157,16 +165,17 @@ func (this *MigrationContext) GetIteration() int64 { return atomic.LoadInt64(&this.Iteration) } -func (this *MigrationContext) SetThrottled(throttle bool) { - if throttle { - atomic.StoreInt64(&this.isThrottled, 1) - } else { - atomic.StoreInt64(&this.isThrottled, 0) - } +func (this *MigrationContext) SetThrottled(throttle bool, reason string) { + this.throttleMutex.Lock() + defer func() { this.throttleMutex.Unlock() }() + this.isThrottled = throttle + this.throttleReason = reason } -func (this *MigrationContext) IsThrottled() bool { - return atomic.LoadInt64(&this.isThrottled) != 0 +func (this *MigrationContext) IsThrottled() (bool, string) { + this.throttleMutex.Lock() + defer func() { this.throttleMutex.Unlock() }() + return this.isThrottled, this.throttleReason } func (this *MigrationContext) ReadMaxLoad(maxLoadList string) error { diff --git a/go/cmd/gh-osc/main.go b/go/cmd/gh-osc/main.go index 335d419..69709e6 100644 --- a/go/cmd/gh-osc/main.go +++ b/go/cmd/gh-osc/main.go @@ -38,7 +38,8 @@ func main() { migrationContext.ChunkSize = 100000 } flag.Int64Var(&migrationContext.MaxLagMillisecondsThrottleThreshold, "max-lag-millis", 1000, "replication lag at which to throttle operation") - flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists") + flag.StringVar(&migrationContext.ThrottleFlagFile, "throttle-flag-file", "", "operation pauses when this file exists; hint: use a file that is specific to the table being altered") + flag.StringVar(&migrationContext.ThrottleAdditionalFlagFile, "throttle-additional-flag-file", "/tmp/gh-osc.throttle", "operation pauses when this file exists; hint: keep default, use for throttling multiple gh-osc operations") maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'") quiet := flag.Bool("quiet", false, "quiet") verbose := flag.Bool("verbose", false, "verbose") diff --git a/go/logic/applier.go b/go/logic/applier.go index fb23440..60c83d7 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -12,6 +12,7 @@ import ( "time" "github.com/github/gh-osc/go/base" + "github.com/github/gh-osc/go/binlog" "github.com/github/gh-osc/go/mysql" "github.com/github/gh-osc/go/sql" @@ -63,7 +64,7 @@ func (this *Applier) validateConnection() error { return nil } -// CreateGhostTable creates the ghost table on the master +// CreateGhostTable creates the ghost table on the applier host func (this *Applier) CreateGhostTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s like %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), @@ -82,7 +83,7 @@ func (this *Applier) CreateGhostTable() error { return nil } -// CreateGhostTable creates the ghost table on the master +// AlterGhost applies `alter` statement on ghost table func (this *Applier) AlterGhost() error { query := fmt.Sprintf(`alter /* gh-osc */ table %s.%s %s`, sql.EscapeName(this.migrationContext.DatabaseName), @@ -101,7 +102,7 @@ func (this *Applier) AlterGhost() error { return nil } -// CreateChangelogTable creates the changelog table on the master +// CreateChangelogTable creates the changelog table on the applier host func (this *Applier) CreateChangelogTable() error { query := fmt.Sprintf(`create /* gh-osc */ table %s.%s ( id bigint auto_increment, @@ -110,7 +111,7 @@ func (this *Applier) CreateChangelogTable() error { value varchar(255) charset ascii not null, primary key(id), unique key hint_uidx(hint) - ) auto_increment=2 + ) auto_increment=256 `, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetChangelogTableName()), @@ -126,23 +127,33 @@ func (this *Applier) CreateChangelogTable() error { return nil } -// DropChangelogTable drops the changelog table on the master -func (this *Applier) DropChangelogTable() error { +// dropTable drops a given table on the applied host +func (this *Applier) dropTable(tableName string) error { query := fmt.Sprintf(`drop /* gh-osc */ table if exists %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), + sql.EscapeName(tableName), ) - log.Infof("Droppping changelog table %s.%s", + log.Infof("Droppping table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), - sql.EscapeName(this.migrationContext.GetChangelogTableName()), + sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { return err } - log.Infof("Changelog table dropped") + log.Infof("Table dropped") return nil } +// DropChangelogTable drops the changelog table on the applier host +func (this *Applier) DropChangelogTable() error { + return this.dropTable(this.migrationContext.GetChangelogTableName()) +} + +// DropGhostTable drops the ghost table on the applier host +func (this *Applier) DropGhostTable() error { + return this.dropTable(this.migrationContext.GetGhostTableName()) +} + // WriteChangelog writes a value to the changelog table. // It returns the hint as given, for convenience func (this *Applier) WriteChangelog(hint, value string) (string, error) { @@ -162,12 +173,15 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { return hint, err } -func (this *Applier) WriteChangelogState(value string) (string, error) { - hint := "state" +func (this *Applier) WriteAndLogChangelog(hint, value string) (string, error) { this.WriteChangelog(hint, value) return this.WriteChangelog(fmt.Sprintf("%s at %d", hint, time.Now().UnixNano()), value) } +func (this *Applier) WriteChangelogState(value string) (string, error) { + return this.WriteAndLogChangelog("state", value) +} + // InitiateHeartbeat creates a heartbeat cycle, writing to the changelog table. // This is done asynchronously func (this *Applier) InitiateHeartbeat() { @@ -213,7 +227,7 @@ func (this *Applier) InitiateHeartbeat() { // ReadMigrationMinValues func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + query, err := sql.BuildUniqueKeyMinValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names) if err != nil { return err } @@ -222,7 +236,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { return err } for rows.Next() { - this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(len(uniqueKey.Columns)) + this.migrationContext.MigrationRangeMinValues = sql.NewColumnValues(uniqueKey.Len()) if err = rows.Scan(this.migrationContext.MigrationRangeMinValues.ValuesPointers...); err != nil { return err } @@ -234,7 +248,7 @@ func (this *Applier) ReadMigrationMinValues(uniqueKey *sql.UniqueKey) error { // ReadMigrationMinValues func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { log.Debugf("Reading migration range according to key: %s", uniqueKey.Name) - query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns) + query, err := sql.BuildUniqueKeyMaxValuesPreparedQuery(this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, uniqueKey.Columns.Names) if err != nil { return err } @@ -243,7 +257,7 @@ func (this *Applier) ReadMigrationMaxValues(uniqueKey *sql.UniqueKey) error { return err } for rows.Next() { - this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(len(uniqueKey.Columns)) + this.migrationContext.MigrationRangeMaxValues = sql.NewColumnValues(uniqueKey.Len()) if err = rows.Scan(this.migrationContext.MigrationRangeMaxValues.ValuesPointers...); err != nil { return err } @@ -272,12 +286,12 @@ func (this *Applier) __unused_IterationIsComplete() (bool, error) { return false, nil } args := sqlutils.Args() - compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) + compareWithIterationRangeStart, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), sql.GreaterThanOrEqualsComparisonSign) if err != nil { return false, err } args = append(args, explodedArgs...) - compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) + compareWithRangeEnd, explodedArgs, err := sql.BuildRangePreparedComparison(this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationRangeMaxValues.AbstractValues(), sql.LessThanComparisonSign) if err != nil { return false, err } @@ -317,7 +331,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo query, explodedArgs, err := sql.BuildUniqueKeyRangeEndPreparedQuery( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, - this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), this.migrationContext.ChunkSize, @@ -330,7 +344,7 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo if err != nil { return hasFurtherRange, err } - iterationRangeMaxValues := sql.NewColumnValues(len(this.migrationContext.UniqueKey.Columns)) + iterationRangeMaxValues := sql.NewColumnValues(this.migrationContext.UniqueKey.Len()) for rows.Next() { if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { return hasFurtherRange, err @@ -360,9 +374,9 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, this.migrationContext.GetGhostTableName(), - this.migrationContext.SharedColumns, + this.migrationContext.SharedColumns.Names, this.migrationContext.UniqueKey.Name, - this.migrationContext.UniqueKey.Columns, + this.migrationContext.UniqueKey.Columns.Names, this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), this.migrationContext.GetIteration() == 0, @@ -422,3 +436,12 @@ func (this *Applier) ShowStatusVariable(variableName string) (result int64, err } return result, nil } + +func (this *Applier) BuildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) (result string, err error) { + switch dmlEvent.DML { + case binlog.DeleteDML: + { + } + } + return result, err +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 42e9105..ddb53f3 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -44,6 +44,9 @@ func (this *Inspector) InitDBConnections() (err error) { if err := this.validateGrants(); err != nil { return err } + if err := this.restartReplication(); err != nil { + return err + } if err := this.validateBinlogs(); err != nil { return err } @@ -69,7 +72,7 @@ func (this *Inspector) ValidateOriginalTable() (err error) { return nil } -func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { +func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) { uniqueKeys, err = this.getCandidateUniqueKeys(tableName) if err != nil { return columns, uniqueKeys, err @@ -90,7 +93,6 @@ func (this *Inspector) InspectOriginalTable() (err error) { if err == nil { return err } - this.migrationContext.OriginalTableColumnsMap = sql.NewColumnsMap(this.migrationContext.OriginalTableColumns) return nil } @@ -108,6 +110,11 @@ func (this *Inspector) InspectOriginalAndGhostTables() (err error) { } this.migrationContext.UniqueKey = sharedUniqueKeys[0] log.Infof("Chosen shared unique key is %s", this.migrationContext.UniqueKey.Name) + if !this.migrationContext.UniqueKey.IsPrimary() { + if this.migrationContext.OriginalBinlogRowImage != "full" { + return fmt.Errorf("binlog_row_image is '%s' and chosen key is %s, which is not the primary key. This operation cannot proceed. You may `set global binlog_row_image='full'` and try again") + } + } this.migrationContext.SharedColumns = this.getSharedColumns(this.migrationContext.OriginalTableColumns, this.migrationContext.GhostTableColumns) log.Infof("Shared columns are %s", this.migrationContext.SharedColumns) @@ -171,6 +178,26 @@ func (this *Inspector) validateGrants() error { return log.Errorf("User has insufficient privileges for migration.") } +// restartReplication is required so that we are _certain_ the binlog format and +// row image settings have actually been applied to the replication thread. +// It is entriely possible, for example, that the replication is using 'STATEMENT' +// binlog format even as the variable says 'ROW' +func (this *Inspector) restartReplication() error { + log.Infof("Restarting replication on %s:%d to make sure binlog settings apply to replication thread", this.connectionConfig.Key.Hostname, this.connectionConfig.Key.Port) + + var stopError, startError error + _, stopError = sqlutils.ExecNoPrepare(this.db, `stop slave`) + _, startError = sqlutils.ExecNoPrepare(this.db, `start slave`) + if stopError != nil { + return stopError + } + if startError != nil { + return startError + } + log.Debugf("Replication restarted") + return nil +} + // validateBinlogs checks that binary log configuration is good to go func (this *Inspector) validateBinlogs() error { query := `select @@global.log_bin, @@global.log_slave_updates, @@global.binlog_format` @@ -299,27 +326,28 @@ func (this *Inspector) countTableRows() error { return nil } -func (this *Inspector) getTableColumns(databaseName, tableName string) (columns sql.ColumnList, err error) { +func (this *Inspector) getTableColumns(databaseName, tableName string) (*sql.ColumnList, error) { query := fmt.Sprintf(` show columns from %s.%s `, sql.EscapeName(databaseName), sql.EscapeName(tableName), ) - err = sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { - columns = append(columns, rowMap.GetString("Field")) + columnNames := []string{} + err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { + columnNames = append(columnNames, rowMap.GetString("Field")) return nil }) if err != nil { - return columns, err + return nil, err } - if len(columns) == 0 { - return columns, log.Errorf("Found 0 columns on %s.%s. Bailing out", + if len(columnNames) == 0 { + return nil, log.Errorf("Found 0 columns on %s.%s. Bailing out", sql.EscapeName(databaseName), sql.EscapeName(tableName), ) } - return columns, nil + return sql.NewColumnList(columnNames), nil } // getCandidateUniqueKeys investigates a table and returns the list of unique keys @@ -412,17 +440,18 @@ func (this *Inspector) getSharedUniqueKeys(originalUniqueKeys, ghostUniqueKeys [ } // getSharedColumns returns the intersection of two lists of columns in same order as the first list -func (this *Inspector) getSharedColumns(originalColumns, ghostColumns sql.ColumnList) (sharedColumns sql.ColumnList) { +func (this *Inspector) getSharedColumns(originalColumns, ghostColumns *sql.ColumnList) *sql.ColumnList { columnsInGhost := make(map[string]bool) - for _, ghostColumn := range ghostColumns { + for _, ghostColumn := range ghostColumns.Names { columnsInGhost[ghostColumn] = true } - for _, originalColumn := range originalColumns { + sharedColumnNames := []string{} + for _, originalColumn := range originalColumns.Names { if columnsInGhost[originalColumn] { - sharedColumns = append(sharedColumns, originalColumn) + sharedColumnNames = append(sharedColumnNames, originalColumn) } } - return sharedColumns + return sql.NewColumnList(sharedColumnNames) } func (this *Inspector) getMasterConnectionConfig() (masterConfig *mysql.ConnectionConfig, err error) { diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 4df0e4a..3dee61d 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -8,8 +8,10 @@ package logic import ( "fmt" "os" + "os/signal" "regexp" "sync/atomic" + "syscall" "time" "github.com/github/gh-osc/go/base" @@ -74,6 +76,21 @@ func prettifyDurationOutput(d time.Duration) string { return result } +// acceptSignals registers for OS signals +func (this *Migrator) acceptSignals() { + c := make(chan os.Signal, 1) + + signal.Notify(c, syscall.SIGHUP) + go func() { + for sig := range c { + switch sig { + case syscall.SIGHUP: + log.Debugf("Received SIGHUP. Reloading configuration") + } + } + }() +} + func (this *Migrator) shouldThrottle() (result bool, reason string) { lag := atomic.LoadInt64(&this.migrationContext.CurrentLag) @@ -82,7 +99,13 @@ func (this *Migrator) shouldThrottle() (result bool, reason string) { } if this.migrationContext.ThrottleFlagFile != "" { if _, err := os.Stat(this.migrationContext.ThrottleFlagFile); err == nil { - //Throttle file defined and exists! + // Throttle file defined and exists! + return true, "flag-file" + } + } + if this.migrationContext.ThrottleAdditionalFlagFile != "" { + if _, err := os.Stat(this.migrationContext.ThrottleAdditionalFlagFile); err == nil { + // 2nd Throttle file defined and exists! return true, "flag-file" } } @@ -100,37 +123,43 @@ func (this *Migrator) shouldThrottle() (result bool, reason string) { return false, "" } +func (this *Migrator) initiateThrottler() error { + throttlerTick := time.Tick(1 * time.Second) + + throttlerFunction := func() { + alreadyThrottling, currentReason := this.migrationContext.IsThrottled() + shouldThrottle, throttleReason := this.shouldThrottle() + if shouldThrottle && !alreadyThrottling { + // New throttling + this.applier.WriteAndLogChangelog("throttle", throttleReason) + } else if shouldThrottle && alreadyThrottling && (currentReason != throttleReason) { + // Change of reason + this.applier.WriteAndLogChangelog("throttle", throttleReason) + } else if alreadyThrottling && !shouldThrottle { + // End of throttling + this.applier.WriteAndLogChangelog("throttle", "done throttling") + } + this.migrationContext.SetThrottled(shouldThrottle, throttleReason) + } + throttlerFunction() + for range throttlerTick { + throttlerFunction() + } + + return nil +} + // throttle initiates a throttling event, if need be, updates the Context and // calls callback functions, if any -func (this *Migrator) throttle( - onStartThrottling func(), - onContinuousThrottling func(), - onEndThrottling func(), -) { - hasThrottledYet := false +func (this *Migrator) throttle(onThrottled func()) { for { - shouldThrottle, reason := this.shouldThrottle() - if !shouldThrottle { - break + if shouldThrottle, _ := this.migrationContext.IsThrottled(); !shouldThrottle { + return } - this.migrationContext.ThrottleReason = reason - if !hasThrottledYet { - hasThrottledYet = true - if onStartThrottling != nil { - onStartThrottling() - } - this.migrationContext.SetThrottled(true) + if onThrottled != nil { + onThrottled() } time.Sleep(time.Second) - if onContinuousThrottling != nil { - onContinuousThrottling() - } - } - if hasThrottledYet { - if onEndThrottling != nil { - onEndThrottling() - } - this.migrationContext.SetThrottled(false) } } @@ -239,6 +268,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.applier.ReadMigrationRangeValues(); err != nil { return err } + go this.initiateThrottler() go this.executeWriteFuncs() go this.iterateChunks() this.migrationContext.RowCopyStartTime = time.Now() @@ -249,15 +279,9 @@ func (this *Migrator) Migrate() (err error) { log.Debugf("Row copy complete") this.printStatus() - this.throttle( - func() { - log.Debugf("throttling before LOCK TABLES") - }, - nil, - func() { - log.Debugf("done throttling") - }, - ) + this.throttle(func() { + log.Debugf("throttling on LOCK TABLES") + }) // TODO retries!! this.applier.LockTables() this.applier.WriteChangelogState(string(AllEventsUpToLockProcessed)) @@ -304,8 +328,8 @@ func (this *Migrator) printStatus() { } eta := "N/A" - if this.migrationContext.IsThrottled() { - eta = fmt.Sprintf("throttled, %s", this.migrationContext.ThrottleReason) + if isThrottled, throttleReason := this.migrationContext.IsThrottled(); isThrottled { + eta = fmt.Sprintf("throttled, %s", throttleReason) } status := fmt.Sprintf("Copy: %d/%d %.1f%%; Backlog: %d/%d; Elapsed: %+v(copy), %+v(total); ETA: %s", totalRowsCopied, rowsEstimate, progressPct, @@ -399,14 +423,8 @@ func (this *Migrator) iterateChunks() error { } func (this *Migrator) executeWriteFuncs() error { - onStartThrottling := func() { - log.Debugf("throttling writes") - } - onEndThrottling := func() { - log.Debugf("done throttling writes") - } for { - this.throttle(onStartThrottling, nil, onEndThrottling) + this.throttle(nil) // We give higher priority to event processing, then secondary priority to // rowcopy select { diff --git a/go/sql/builder.go b/go/sql/builder.go index b8f0358..784b1c9 100644 --- a/go/sql/builder.go +++ b/go/sql/builder.go @@ -32,6 +32,20 @@ func EscapeName(name string) string { return fmt.Sprintf("`%s`", name) } +func buildPreparedValues(length int) []string { + values := make([]string, length, length) + for i := 0; i < length; i++ { + values[i] = "?" + } + return values +} + +func duplicateNames(names []string) []string { + duplicate := make([]string, len(names), len(names)) + copy(duplicate, names) + return duplicate +} + func BuildValueComparison(column string, value string, comparisonSign ValueComparisonSign) (result string, err error) { if column == "" { return "", fmt.Errorf("Empty column in GetValueComparison") @@ -64,6 +78,11 @@ func BuildEqualsComparison(columns []string, values []string) (result string, er return result, nil } +func BuildEqualsPreparedComparison(columns []string) (result string, err error) { + values := buildPreparedValues(len(columns)) + return BuildEqualsComparison(columns, values) +} + func BuildRangeComparison(columns []string, values []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { if len(columns) == 0 { return "", explodedArgs, fmt.Errorf("Got 0 columns in GetRangeComparison") @@ -121,10 +140,7 @@ func BuildRangeComparison(columns []string, values []string, args []interface{}, } func BuildRangePreparedComparison(columns []string, args []interface{}, comparisonSign ValueComparisonSign) (result string, explodedArgs []interface{}, err error) { - values := make([]string, len(columns), len(columns)) - for i := range columns { - values[i] = "?" - } + values := buildPreparedValues(len(columns)) return BuildRangeComparison(columns, values, args, comparisonSign) } @@ -135,6 +151,7 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin databaseName = EscapeName(databaseName) originalTableName = EscapeName(originalTableName) ghostTableName = EscapeName(ghostTableName) + sharedColumns = duplicateNames(sharedColumns) for i := range sharedColumns { sharedColumns[i] = EscapeName(sharedColumns[i]) } @@ -171,12 +188,8 @@ func BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName strin } func BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName string, sharedColumns []string, uniqueKey string, uniqueKeyColumns []string, rangeStartArgs, rangeEndArgs []interface{}, includeRangeStartValues bool, transactionalTable bool) (result string, explodedArgs []interface{}, err error) { - rangeStartValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - rangeEndValues := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) - for i := range uniqueKeyColumns { - rangeStartValues[i] = "?" - rangeEndValues[i] = "?" - } + rangeStartValues := buildPreparedValues(len(uniqueKeyColumns)) + rangeEndValues := buildPreparedValues(len(uniqueKeyColumns)) return BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, includeRangeStartValues, transactionalTable) } @@ -198,6 +211,7 @@ func BuildUniqueKeyRangeEndPreparedQuery(databaseName, tableName string, uniqueK } explodedArgs = append(explodedArgs, rangeExplodedArgs...) + uniqueKeyColumns = duplicateNames(uniqueKeyColumns) uniqueKeyColumnAscending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) uniqueKeyColumnDescending := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { @@ -244,6 +258,7 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni databaseName = EscapeName(databaseName) tableName = EscapeName(tableName) + uniqueKeyColumns = duplicateNames(uniqueKeyColumns) uniqueKeyColumnOrder := make([]string, len(uniqueKeyColumns), len(uniqueKeyColumns)) for i := range uniqueKeyColumns { uniqueKeyColumns[i] = EscapeName(uniqueKeyColumns[i]) @@ -262,3 +277,65 @@ func buildUniqueKeyMinMaxValuesPreparedQuery(databaseName, tableName string, uni ) return query, nil } + +func BuildDMLDeleteQuery(databaseName, tableName string, originalTableColumns, uniqueKeyColumns *ColumnList, args []interface{}) (result string, uniqueKeyArgs []interface{}, err error) { + if len(args) != originalTableColumns.Len() { + return result, uniqueKeyArgs, fmt.Errorf("args count differs from table column count in BuildDMLDeleteQuery") + } + for _, column := range uniqueKeyColumns.Names { + tableOrdinal := originalTableColumns.Ordinals[column] + uniqueKeyArgs = append(uniqueKeyArgs, args[tableOrdinal]) + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + equalsComparison, err := BuildEqualsPreparedComparison(uniqueKeyColumns.Names) + result = fmt.Sprintf(` + delete /* gh-osc %s.%s */ + from + %s.%s + where + %s + `, databaseName, tableName, + databaseName, tableName, + equalsComparison, + ) + return result, uniqueKeyArgs, err +} + +func BuildDMLInsertQuery(databaseName, tableName string, originalTableColumns, sharedColumns *ColumnList, args []interface{}) (result string, sharedArgs []interface{}, err error) { + if len(args) != originalTableColumns.Len() { + return result, args, fmt.Errorf("args count differs from table column count in BuildDMLInsertQuery") + } + if !sharedColumns.IsSubsetOf(originalTableColumns) { + return result, args, fmt.Errorf("shared columns is not a subset of table columns in BuildDMLInsertQuery") + } + if sharedColumns.Len() == 0 { + return result, args, fmt.Errorf("No shared columns found in BuildDMLInsertQuery") + } + databaseName = EscapeName(databaseName) + tableName = EscapeName(tableName) + + for _, column := range sharedColumns.Names { + tableOrdinal := originalTableColumns.Ordinals[column] + sharedArgs = append(sharedArgs, args[tableOrdinal]) + } + + sharedColumnNames := duplicateNames(sharedColumns.Names) + for i := range sharedColumnNames { + sharedColumnNames[i] = EscapeName(sharedColumnNames[i]) + } + preparedValues := buildPreparedValues(sharedColumns.Len()) + + result = fmt.Sprintf(` + replace /* gh-osc %s.%s */ into + %s.%s + (%s) + values + (%s) + `, databaseName, tableName, + databaseName, tableName, + strings.Join(sharedColumnNames, ", "), + strings.Join(preparedValues, ", "), + ) + return result, sharedArgs, err +} diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go index b101bbf..5a57a3d 100644 --- a/go/sql/builder_test.go +++ b/go/sql/builder_test.go @@ -68,6 +68,15 @@ func TestBuildEqualsComparison(t *testing.T) { } } +func TestBuildEqualsPreparedComparison(t *testing.T) { + { + columns := []string{"c1", "c2"} + comparison, err := BuildEqualsPreparedComparison(columns) + test.S(t).ExpectNil(err) + test.S(t).ExpectEquals(comparison, "((`c1` = ?) and (`c2` = ?))") + } +} + func TestBuildRangeComparison(t *testing.T) { { columns := []string{"c1"} @@ -143,7 +152,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { rangeStartArgs := []interface{}{3} rangeEndArgs := []interface{}{103} - query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -162,7 +171,7 @@ func TestBuildRangeInsertQuery(t *testing.T) { rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} - query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartValues, rangeEndValues, rangeStartArgs, rangeEndArgs, true, false) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) @@ -186,13 +195,13 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { rangeStartArgs := []interface{}{3, 17} rangeEndArgs := []interface{}{103, 117} - query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true) + query, explodedArgs, err := BuildRangeInsertPreparedQuery(databaseName, originalTableName, ghostTableName, sharedColumns, uniqueKey, uniqueKeyColumns, rangeStartArgs, rangeEndArgs, true, true) test.S(t).ExpectNil(err) expected := ` insert /* gh-osc mydb.tbl */ ignore into mydb.ghost (id, name, position) (select id, name, position from mydb.tbl force index (name_position_uidx) where (((name > ?) or (((name = ?)) AND (position > ?)) or ((name = ?) and (position = ?))) and ((name < ?) or (((name = ?)) AND (position < ?)) or ((name = ?) and (position = ?)))) - ) + lock in share mode ) ` test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) test.S(t).ExpectTrue(reflect.DeepEqual(explodedArgs, []interface{}{3, 3, 17, 3, 17, 103, 103, 117, 103, 117})) @@ -202,7 +211,7 @@ func TestBuildRangeInsertPreparedQuery(t *testing.T) { func TestBuildUniqueKeyRangeEndPreparedQuery(t *testing.T) { databaseName := "mydb" originalTableName := "tbl" - chunkSize := 500 + var chunkSize int64 = 500 { uniqueKeyColumns := []string{"name", "position"} rangeStartArgs := []interface{}{3, 17} @@ -262,3 +271,107 @@ func TestBuildUniqueKeyMinValuesPreparedQuery(t *testing.T) { test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) } } + +func TestBuildDMLDeleteQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + args := []interface{}{3, "testname", "first", 17, 23} + { + uniqueKeyColumns := NewColumnList([]string{"position"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17})) + } + { + uniqueKeyColumns := NewColumnList([]string{"name", "position"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((name = ?) and (position = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{"testname", 17})) + } + { + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + + query, uniqueKeyArgs, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNil(err) + expected := ` + delete /* gh-osc mydb.tbl */ + from + mydb.tbl + where + ((position = ?) and (name = ?)) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(uniqueKeyArgs, []interface{}{17, "testname"})) + } + { + uniqueKeyColumns := NewColumnList([]string{"position", "name"}) + args := []interface{}{"first", 17} + + _, _, err := BuildDMLDeleteQuery(databaseName, tableName, tableColumns, uniqueKeyColumns, args) + test.S(t).ExpectNotNil(err) + } +} + +func TestBuildDMLInsertQuery(t *testing.T) { + databaseName := "mydb" + tableName := "tbl" + tableColumns := NewColumnList([]string{"id", "name", "rank", "position", "age"}) + args := []interface{}{3, "testname", "first", 17, 23} + { + sharedColumns := NewColumnList([]string{"id", "name", "position", "age"}) + query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNil(err) + expected := ` + replace /* gh-osc mydb.tbl */ + into mydb.tbl + (id, name, position, age) + values + (?, ?, ?, ?) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{3, "testname", 17, 23})) + } + { + sharedColumns := NewColumnList([]string{"position", "name", "age", "id"}) + query, sharedArgs, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNil(err) + expected := ` + replace /* gh-osc mydb.tbl */ + into mydb.tbl + (position, name, age, id) + values + (?, ?, ?, ?) + ` + test.S(t).ExpectEquals(normalizeQuery(query), normalizeQuery(expected)) + test.S(t).ExpectTrue(reflect.DeepEqual(sharedArgs, []interface{}{17, "testname", 23, 3})) + } + { + sharedColumns := NewColumnList([]string{"position", "name", "surprise", "id"}) + _, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNotNil(err) + } + { + sharedColumns := NewColumnList([]string{}) + _, _, err := BuildDMLInsertQuery(databaseName, tableName, tableColumns, sharedColumns, args) + test.S(t).ExpectNotNil(err) + } +} diff --git a/go/sql/types.go b/go/sql/types.go index cda02d3..e82720c 100644 --- a/go/sql/types.go +++ b/go/sql/types.go @@ -11,34 +11,64 @@ import ( "strings" ) -// ColumnList makes for a named list of columns -type ColumnList []string - -// ParseColumnList parses a comma delimited list of column names -func ParseColumnList(columns string) *ColumnList { - result := ColumnList(strings.Split(columns, ",")) - return &result -} - -func (this *ColumnList) String() string { - return strings.Join(*this, ",") -} - -func (this *ColumnList) Equals(other *ColumnList) bool { - return reflect.DeepEqual(*this, *other) -} - // ColumnsMap maps a column onto its ordinal position type ColumnsMap map[string]int -func NewColumnsMap(columnList ColumnList) ColumnsMap { +func NewColumnsMap(orderedNames []string) ColumnsMap { columnsMap := make(map[string]int) - for i, column := range columnList { + for i, column := range orderedNames { columnsMap[column] = i } return ColumnsMap(columnsMap) } +// ColumnList makes for a named list of columns +type ColumnList struct { + Names []string + Ordinals ColumnsMap +} + +// NewColumnList creates an object given ordered list of column names +func NewColumnList(names []string) *ColumnList { + result := &ColumnList{ + Names: names, + } + result.Ordinals = NewColumnsMap(result.Names) + return result +} + +// ParseColumnList parses a comma delimited list of column names +func ParseColumnList(columns string) *ColumnList { + result := &ColumnList{ + Names: strings.Split(columns, ","), + } + result.Ordinals = NewColumnsMap(result.Names) + return result +} + +func (this *ColumnList) String() string { + return strings.Join(this.Names, ",") +} + +func (this *ColumnList) Equals(other *ColumnList) bool { + return reflect.DeepEqual(this.Names, other.Names) +} + +// IsSubsetOf returns 'true' when column names of this list are a subset of +// another list, in arbitrary order (order agnostic) +func (this *ColumnList) IsSubsetOf(other *ColumnList) bool { + for _, column := range this.Names { + if _, exists := other.Ordinals[column]; !exists { + return false + } + } + return true +} + +func (this *ColumnList) Len() int { + return len(this.Names) +} + // UniqueKey is the combination of a key's name and columns type UniqueKey struct { Name string @@ -51,6 +81,10 @@ func (this *UniqueKey) IsPrimary() bool { return this.Name == "PRIMARY" } +func (this *UniqueKey) Len() int { + return this.Columns.Len() +} + func (this *UniqueKey) String() string { return fmt.Sprintf("%s: %s; has nullable: %+v", this.Name, this.Columns, this.HasNullable) }